compute function written

This commit is contained in:
Benoît Sierro
2021-08-27 09:28:45 +02:00
parent 4bcf5add60
commit 6c869d5c6c
6 changed files with 78 additions and 71 deletions

View File

@@ -6,31 +6,32 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
from numpy import pi
from . import io, utils from . import io, utils
from .defaults import default_parameters from .defaults import default_parameters
from .errors import * from .errors import *
from .logger import get_logger from .logger import get_logger
from .math import power_fact from .utils import override_config, required_simulations
from .physics import fiber, pulse, units from .utils.evaluator import Evaluator
from .utils import override_config, required_simulations, evaluator from .utils.parameter import (
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters BareConfig,
BareParams,
global_evaluator = evaluator.Evaluator() hc_model_specific_parameters,
mandatory_parameters,
)
@dataclass @dataclass
class Params(BareParams): class Params(BareParams):
@classmethod @classmethod
def from_bare(cls, bare: BareParams): def from_bare(cls, bare: BareParams):
return cls(**asdict(bare)) param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
evaluator = Evaluator.default()
def __post_init__(self): evaluator.set(**param_dict)
self.compute() for p_name in mandatory_parameters:
evaluator.compute(p_name)
def compute(self): new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
logger = get_logger(__name__) return cls(**new_param_dict)
@dataclass @dataclass

View File

@@ -12,7 +12,6 @@ from ..logger import get_logger
from .. import io from .. import io
from ..math import abs2, argclosest, power_fact, u_nm from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
from ..utils.cache import np_cache from ..utils.cache import np_cache
from . import materials as mat from . import materials as mat
from . import units from . import units

View File

@@ -30,7 +30,6 @@ from ..defaults import default_plotting
from ..logger import get_logger from ..logger import get_logger
from ..math import * from ..math import *
from ..plotting import plot_setup from ..plotting import plot_setup
from ..utils.parameter import BareParams
from . import units from . import units
c = 299792458.0 c = 299792458.0
@@ -343,54 +342,6 @@ def load_field_file(
return field_0, peak_power, energy, width return field_0, peak_power, energy, width
def setup_custom_field(params: BareParams) -> bool:
"""sets up a custom field function if necessary and returns
True if it did so, False otherwise
Parameters
----------
params : Dict[str, Any]
params dictionary
Returns
-------
bool
True if the field has been modified
"""
field_0 = params.field_0
width = params.width
peak_power = params.peak_power
energy = params.energy
did_set = True
if params.prev_data_dir is not None:
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
field_0 = np.fft.ifft(spec)
elif params.field_file is not None:
field_data = np.load(params.field_file)
field_interp = interp1d(
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
)
field_0 = field_interp(params.t)
field_0 = field_0 * modify_field_ratio(
params.t,
field_0,
params.peak_power,
params.energy,
params.intensity_noise,
)
width, peak_power, energy = measure_field(params.t, field_0)
else:
did_set = False
if did_set:
field_0 = field_0 * np.sqrt(params.input_transmission)
return did_set, width, peak_power, energy, field_0
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float: def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
""" """
finds a new wavelength parameter such that the maximum of the spectrum corresponding finds a new wavelength parameter such that the maximum of the spectrum corresponding
@@ -481,6 +432,14 @@ def shot_noise(w_c, w0, T, dt):
return out return out
def add_shot_noise(
field_0: np.ndarray, quantum_noise: bool, w_c: bool, w0: float, time_window: float, dt: float
) -> np.ndarray:
if quantum_noise:
field_0 = field_0 + shot_noise(w_c, w0, time_window, dt)
return field_0
def mean_phase(spectra): def mean_phase(spectra):
"""computes the mean phase of spectra """computes the mean phase of spectra
Parameter Parameter

View File

@@ -80,7 +80,7 @@ class RK4IP:
self.raman_type = params.raman_type self.raman_type = params.raman_type
self.hr_w = params.hr_w self.hr_w = params.hr_w
self.adapt_step_size = params.adapt_step_size self.adapt_step_size = params.adapt_step_size
self.error_ok = params.error_ok self.error_ok = params.tolerated_error
self.dynamic_dispersion = params.dynamic_dispersion self.dynamic_dispersion = params.dynamic_dispersion
self.starting_num = params.recovery_last_stored self.starting_num = params.recovery_last_stored

View File

@@ -97,6 +97,12 @@ class EvalStat:
class Evaluator: class Evaluator:
@classmethod
def default(cls) -> "Evaluator":
evaluator = cls()
evaluator.append(*default_rules)
return evaluator
def __init__(self): def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list) self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {} self.params = {}
@@ -111,7 +117,7 @@ class Evaluator:
self.rules[t].append(r) self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def update(self, **params: Any): def set(self, **params: Any):
self.params.update(params) self.params.update(params)
for k in params: for k in params:
self.eval_stats[k].priority = np.inf self.eval_stats[k].priority = np.inf
@@ -256,9 +262,29 @@ default_rules: list[Rule] = [
Rule("field_0", np.fft.ifft, ["spec_0"]), Rule("field_0", np.fft.ifft, ["spec_0"]),
Rule("spec_0", pulse.load_previous_spectrum, priorities=3), Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
Rule( Rule(
["field_0", "peak_power", "energy", "width"], pulse.load_field_file, priorities=[2, 1, 1, 1] ["pre_field_0", "peak_power", "energy", "width"],
pulse.load_field_file,
[
"field_file",
"t",
"peak_power",
"energy",
"intensity_noise",
"noise_correlation",
"quantum_noise",
"w_c",
"w0",
"time_window",
"dt",
],
priorities=[2, 1, 1, 1],
),
Rule("pre_field_0", pulse.initial_field, priorities=1),
Rule(
"field_0",
pulse.add_shot_noise,
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
), ),
Rule("field_0", pulse.initial_field, priorities=1),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
Rule("peak_power", pulse.soliton_num_to_peak_power), Rule("peak_power", pulse.soliton_num_to_peak_power),
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
@@ -329,7 +355,7 @@ def main():
evalor = Evaluator() evalor = Evaluator()
evalor.append(*default_rules) evalor.append(*default_rules)
evalor.update( evalor.set(
**{ **{
"length": 1, "length": 1,
"z_num": 128, "z_num": 128,
@@ -343,8 +369,9 @@ def main():
"width": 30e-15, "width": 30e-15,
"mean_power": 100e-3, "mean_power": 100e-3,
"n2": 2.4e-20, "n2": 2.4e-20,
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_max.npz", "A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
"model": "pcf", "model": "pcf",
"quantum_noise": True,
"pitch": 1.2e-6, "pitch": 1.2e-6,
"pitch_ratio": 0.5, "pitch_ratio": 0.5,
} }
@@ -354,6 +381,7 @@ def main():
print(evalor.params["l"][evalor.params["l"] > 0].min()) print(evalor.params["l"][evalor.params["l"] > 0].min())
evalor.compute("spec_0") evalor.compute("spec_0")
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2) plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
plt.yscale("log")
plt.show() plt.show()
print(evalor.compute("gamma")) print(evalor.compute("gamma"))
print(evalor.compute("beta2")) print(evalor.compute("beta2"))

View File

@@ -342,6 +342,27 @@ hc_model_specific_parameters = dict(
) )
"""dependecy map only includes actual fiber parameters and exclude gas parameters""" """dependecy map only includes actual fiber parameters and exclude gas parameters"""
mandatory_parameters = [
"name",
"w_c",
"w",
"w0",
"w_power_fact",
"alpha",
"spec_0",
"z_targets",
"length",
"beta2_coefficients",
"gamma_arr",
"behaviors",
"raman_type",
"hr_w",
"adapt_step_size",
"tollerated_error",
"dynamic_dispersion",
"recovery_last_stored",
]
@dataclass @dataclass
class BareParams: class BareParams:
@@ -445,7 +466,6 @@ class BareParams:
L_sol: float = Parameter(non_negative(float, int)) L_sol: float = Parameter(non_negative(float, int))
dynamic_dispersion: bool = Parameter(boolean) dynamic_dispersion: bool = Parameter(boolean)
adapt_step_size: bool = Parameter(boolean) adapt_step_size: bool = Parameter(boolean)
error_ok: float = Parameter(positive(float))
hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray))