diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index d202ba9..3db90c7 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -6,31 +6,32 @@ from typing import Any, Dict, Iterator, List, Tuple, Union from collections import defaultdict import numpy as np -from numpy import pi from . import io, utils from .defaults import default_parameters from .errors import * from .logger import get_logger -from .math import power_fact -from .physics import fiber, pulse, units -from .utils import override_config, required_simulations, evaluator -from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters - -global_evaluator = evaluator.Evaluator() +from .utils import override_config, required_simulations +from .utils.evaluator import Evaluator +from .utils.parameter import ( + BareConfig, + BareParams, + hc_model_specific_parameters, + mandatory_parameters, +) @dataclass class Params(BareParams): @classmethod def from_bare(cls, bare: BareParams): - return cls(**asdict(bare)) - - def __post_init__(self): - self.compute() - - def compute(self): - logger = get_logger(__name__) + param_dict = {k: v for k, v in asdict(bare).items() if v is not None} + evaluator = Evaluator.default() + evaluator.set(**param_dict) + for p_name in mandatory_parameters: + evaluator.compute(p_name) + new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict} + return cls(**new_param_dict) @dataclass diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 6389dea..0cf4486 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -12,7 +12,6 @@ from ..logger import get_logger from .. import io from ..math import abs2, argclosest, power_fact, u_nm -from ..utils.parameter import BareConfig, BareParams, hc_model_specific_parameters from ..utils.cache import np_cache from . import materials as mat from . import units diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index e36784d..33fdcf6 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -30,7 +30,6 @@ from ..defaults import default_plotting from ..logger import get_logger from ..math import * from ..plotting import plot_setup -from ..utils.parameter import BareParams from . import units c = 299792458.0 @@ -343,54 +342,6 @@ def load_field_file( return field_0, peak_power, energy, width -def setup_custom_field(params: BareParams) -> bool: - """sets up a custom field function if necessary and returns - True if it did so, False otherwise - - Parameters - ---------- - params : Dict[str, Any] - params dictionary - - Returns - ------- - bool - True if the field has been modified - """ - field_0 = params.field_0 - width = params.width - peak_power = params.peak_power - energy = params.energy - - did_set = True - - if params.prev_data_dir is not None: - spec = io.load_last_spectrum(Path(params.prev_data_dir))[1] - field_0 = np.fft.ifft(spec) - elif params.field_file is not None: - field_data = np.load(params.field_file) - field_interp = interp1d( - field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) - ) - field_0 = field_interp(params.t) - - field_0 = field_0 * modify_field_ratio( - params.t, - field_0, - params.peak_power, - params.energy, - params.intensity_noise, - ) - width, peak_power, energy = measure_field(params.t, field_0) - else: - did_set = False - - if did_set: - field_0 = field_0 * np.sqrt(params.input_transmission) - - return did_set, width, peak_power, energy, field_0 - - def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float: """ finds a new wavelength parameter such that the maximum of the spectrum corresponding @@ -481,6 +432,14 @@ def shot_noise(w_c, w0, T, dt): return out +def add_shot_noise( + field_0: np.ndarray, quantum_noise: bool, w_c: bool, w0: float, time_window: float, dt: float +) -> np.ndarray: + if quantum_noise: + field_0 = field_0 + shot_noise(w_c, w0, time_window, dt) + return field_0 + + def mean_phase(spectra): """computes the mean phase of spectra Parameter diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 5178277..211943f 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -80,7 +80,7 @@ class RK4IP: self.raman_type = params.raman_type self.hr_w = params.hr_w self.adapt_step_size = params.adapt_step_size - self.error_ok = params.error_ok + self.error_ok = params.tolerated_error self.dynamic_dispersion = params.dynamic_dispersion self.starting_num = params.recovery_last_stored diff --git a/src/scgenerator/utils/evaluator.py b/src/scgenerator/utils/evaluator.py index d80dfc6..be5ac3d 100644 --- a/src/scgenerator/utils/evaluator.py +++ b/src/scgenerator/utils/evaluator.py @@ -97,6 +97,12 @@ class EvalStat: class Evaluator: + @classmethod + def default(cls) -> "Evaluator": + evaluator = cls() + evaluator.append(*default_rules) + return evaluator + def __init__(self): self.rules: dict[str, list[Rule]] = defaultdict(list) self.params = {} @@ -111,7 +117,7 @@ class Evaluator: self.rules[t].append(r) self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) - def update(self, **params: Any): + def set(self, **params: Any): self.params.update(params) for k in params: self.eval_stats[k].priority = np.inf @@ -256,9 +262,29 @@ default_rules: list[Rule] = [ Rule("field_0", np.fft.ifft, ["spec_0"]), Rule("spec_0", pulse.load_previous_spectrum, priorities=3), Rule( - ["field_0", "peak_power", "energy", "width"], pulse.load_field_file, priorities=[2, 1, 1, 1] + ["pre_field_0", "peak_power", "energy", "width"], + pulse.load_field_file, + [ + "field_file", + "t", + "peak_power", + "energy", + "intensity_noise", + "noise_correlation", + "quantum_noise", + "w_c", + "w0", + "time_window", + "dt", + ], + priorities=[2, 1, 1, 1], + ), + Rule("pre_field_0", pulse.initial_field, priorities=1), + Rule( + "field_0", + pulse.add_shot_noise, + ["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"], ), - Rule("field_0", pulse.initial_field, priorities=1), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.soliton_num_to_peak_power), Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), @@ -329,7 +355,7 @@ def main(): evalor = Evaluator() evalor.append(*default_rules) - evalor.update( + evalor.set( **{ "length": 1, "z_num": 128, @@ -343,8 +369,9 @@ def main(): "width": 30e-15, "mean_power": 100e-3, "n2": 2.4e-20, - "A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_max.npz", + "A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz", "model": "pcf", + "quantum_noise": True, "pitch": 1.2e-6, "pitch_ratio": 0.5, } @@ -354,6 +381,7 @@ def main(): print(evalor.params["l"][evalor.params["l"] > 0].min()) evalor.compute("spec_0") plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2) + plt.yscale("log") plt.show() print(evalor.compute("gamma")) print(evalor.compute("beta2")) diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index e8ccfbc..fffc323 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -342,6 +342,27 @@ hc_model_specific_parameters = dict( ) """dependecy map only includes actual fiber parameters and exclude gas parameters""" +mandatory_parameters = [ + "name", + "w_c", + "w", + "w0", + "w_power_fact", + "alpha", + "spec_0", + "z_targets", + "length", + "beta2_coefficients", + "gamma_arr", + "behaviors", + "raman_type", + "hr_w", + "adapt_step_size", + "tollerated_error", + "dynamic_dispersion", + "recovery_last_stored", +] + @dataclass class BareParams: @@ -445,7 +466,6 @@ class BareParams: L_sol: float = Parameter(non_negative(float, int)) dynamic_dispersion: bool = Parameter(boolean) adapt_step_size: bool = Parameter(boolean) - error_ok: float = Parameter(positive(float)) hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray))