diff --git a/src/scgenerator/cli/__main__.py b/src/scgenerator/cli/__main__.py new file mode 100644 index 0000000..9ae637f --- /dev/null +++ b/src/scgenerator/cli/__main__.py @@ -0,0 +1,4 @@ +from .cli import main + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index f8350dc..9a20003 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -105,7 +105,7 @@ def prep_ray(args): def resume_sim(args): method = prep_ray(args) - sim = resume_simulations(args.sim_dir, method=method) + sim = resume_simulations(Path(args.sim_dir), method=method) sim.run() run_simulation_sequence( *args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name diff --git a/src/scgenerator/cli/new_config.py b/src/scgenerator/cli/new_config.py deleted file mode 100644 index fd7a6e6..0000000 --- a/src/scgenerator/cli/new_config.py +++ /dev/null @@ -1,78 +0,0 @@ -from .. import const -import toml - -valid_commands = ["finish", "next"] - - -class Configurator: - def __init__(self, name): - self.config = dict(name=name, fiber=dict(), gas=dict(), pulse=dict(), simulation=dict()) - - def list_input(self): - answer = "" - while answer == "": - answer = input("Please enter a list of values (one per line)\n") - - out = [self.process_input(answer)] - - while answer != "": - answer = input() - out.append(self.process_input(answer)) - - return out[:-1] - - def process_input(self, s): - try: - return int(s) - except ValueError: - pass - - try: - return float(s) - except ValueError: - pass - - return s - - def accept(self, question, default=True): - question += " ([y]/n)" if default else " (y/[n])" - question += "\n" - inp = input(question) - - yes_str = ["y", "yes"] - if default: - yes_str.append("") - - return inp.lower() in yes_str - - def print_current(self, config: dict): - print(toml.dumps(config)) - - def get(self, section, param_name): - question = f"Please enter a value for the parameter '{param_name}'\n" - valid = const.valid_param_types[section][param_name] - - is_valid = False - value = None - - while not is_valid: - answer = input(question) - if answer == "variable" and param_name in const.valid_variable[section]: - value = self.list_input() - print(value) - is_valid = all(valid(v) for v in value) - else: - value = self.process_input(answer) - is_valid = valid(value) - - return value - - def ask_next_command(self): - s = "" - raw_input = input(s).split(" ") - return raw_input[0], raw_input[1:] - - def main(self): - editing = True - while editing: - command, args = self.ask_next_command() diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 3d6c881..7e24687 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -1,6 +1,3 @@ -import numpy as np -from collections import namedtuple - __version__ = "0.1.0" @@ -19,243 +16,6 @@ def pbar_format(worker_id: int): ) -##### - - -def in_range_excl(func, r): - def _in_range(n): - if not func(n): - return False - return n > r[0] and n < r[1] - - _in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (exclusive) " - return _in_range - - -def in_range_incl(func, r): - def _in_range(n): - if not func(n): - return False - return n >= r[0] and n <= r[1] - - _in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (inclusive)" - return _in_range - - -def num(n): - """must be a single, real, non-negative number""" - return isinstance(n, (float, int)) and n >= 0 - - -def integer(n): - """must be a strictly positive integer""" - return isinstance(n, int) and n > 0 - - -def boolean(b): - """must be a boolean""" - return type(b) == bool - - -def behaviors(l): - """must be a valid list of behaviors""" - for s in l: - if s.lower() not in ["spm", "raman", "ss"]: - return False - return True - - -def beta(l): - """must be a valid beta array""" - for n in l: - if not isinstance(n, (float, int)): - return False - return True - - -def field_0(f): - return isinstance(f, (str, tuple, list, np.ndarray)) - - -def he_mode(mode): - """must be a valide HE mode""" - if not isinstance(mode, (list, tuple)): - return False - if not len(mode) == 2: - return False - for m in mode: - if not integer(m): - return False - return True - - -def fit_parameters(param): - """must be a valide fitting parameter tuple of the mercatili_adjusted model""" - if not isinstance(param, (list, tuple)): - return False - if not len(param) == 2: - return False - for n in param: - if not integer(n): - return False - return True - - -def string(l=None): - if l is None: - - def _string(s): - return isinstance(s, str) - - _string.__doc__ = f"must be a str" - else: - - def _string(s): - return isinstance(s, str) and s.lower() in l - - _string.__doc__ = f"must be a str matching one of {l}" - - return _string - - -def capillary_resonance_strengths(l): - """must be a list of non-zero, real number""" - if not isinstance(l, (list, tuple)): - return False - for m in l: - if not num(m): - return False - return True - - -def capillary_nested(n): - """must be a non negative integer""" - return isinstance(n, int) and n >= 0 - - -valid_param_types = dict( - root=dict( - name=string(), - prev_data_dir=string(), - ), - fiber=dict( - input_transmission=in_range_incl(num, (0, 1)), - gamma=num, - n2=num, - effective_mode_diameter=num, - A_eff=num, - pitch=in_range_excl(num, (0, 1e-3)), - pitch_ratio=in_range_excl(num, (0, 1)), - core_radius=in_range_excl(num, (0, 1e-3)), - he_mode=he_mode, - fit_parameters=fit_parameters, - beta=beta, - dispersion_file=string(), - model=string(["pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"]), - length=in_range_excl(num, (0, 1e9)), - capillary_num=integer, - capillary_outer_d=in_range_excl(num, (0, 1e-3)), - capillary_thickness=in_range_excl(num, (0, 1e-3)), - capillary_spacing=in_range_excl(num, (0, 1e-3)), - capillary_resonance_strengths=capillary_resonance_strengths, - capillary_nested=capillary_nested, - ), - gas=dict( - gas_name=string(["vacuum", "helium", "air"]), - pressure=num, - temperature=num, - plasma_density=num, - ), - pulse=dict( - field_0=field_0, - field_file=string(), - repetition_rate=num, - peak_power=num, - mean_power=num, - energy=num, - soliton_num=num, - quantum_noise=boolean, - shape=string(["gaussian", "sech"]), - wavelength=in_range_excl(num, (100e-9, 3000e-9)), - intensity_noise=in_range_incl(num, (0, 1)), - width=in_range_excl(num, (0, 1e-9)), - t0=in_range_excl(num, (0, 1e-9)), - ), - simulation=dict( - behaviors=behaviors, - parallel=boolean, - raman_type=string(["measured", "agrawal", "stolen"]), - ideal_gas=boolean, - repeat=integer, - t_num=integer, - z_num=integer, - time_window=num, - dt=in_range_excl(num, (0, 5e-15)), - tolerated_error=in_range_excl(num, (1e-15, 1e-5)), - step_size=num, - lower_wavelength_interp_limit=in_range_excl(num, (100e-9, 3000e-9)), - upper_wavelength_interp_limit=in_range_excl(num, (100e-9, 5000e-9)), - frep=num, - prev_sim_dir=string(), - readjust_wavelength=boolean, - ), -) - -hc_model_specific_parameters = dict( - marcatili=["core_radius", "he_mode"], - marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"], - hasan=[ - "core_radius", - "capillary_num", - "capillary_thickness", - "capillary_resonance_strengths", - "capillary_nested", - "capillary_spacing", - "capillary_outer_d", - ], -) -"""dependecy map only includes actual fiber parameters and exclude gas parameters""" - -valid_variable = dict( - fiber=[ - "beta", - "gamma", - "pitch", - "pitch_ratio", - "core_radius", - "capillary_num", - "capillary_outer_d", - "capillary_thickness", - "capillary_spacing", - "capillary_resonance_strengths", - "capillary_nested", - "he_mode", - "fit_parameters", - "input_transmission", - "n2", - ], - gas=["pressure", "temperature", "gas_name", "plasma_density"], - pulse=[ - "peak_power", - "mean_power", - "energy", - "quantum_noise", - "shape", - "wavelength", - "intensity_noise", - "width", - "soliton_num", - ], - simulation=[ - "behaviors", - "raman_type", - "tolerated_error", - "step_size", - "ideal_gas", - "readjust_wavelength", - ], -) - ENVIRON_KEY_BASE = "SCGENERATOR_" PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY" LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY" diff --git a/src/scgenerator/defaults.py b/src/scgenerator/defaults.py index f414b25..6e6dd03 100644 --- a/src/scgenerator/defaults.py +++ b/src/scgenerator/defaults.py @@ -1,6 +1,5 @@ import matplotlib.pyplot as plt - -from .errors import MissingParameterError +from pathlib import Path default_parameters = dict( input_transmission=1.0, @@ -28,6 +27,7 @@ default_parameters = dict( upper_wavelength_interp_limit=1900e-9, ideal_gas=False, readjust_wavelength=False, + recovery_last_stored=0, ) default_plotting = dict( @@ -36,7 +36,7 @@ default_plotting = dict( vmin=-40, vmax=0, vmax_with_headroom=2, - name="plot", + out_path=Path("plot"), avg_main_to_coherence_ratio=4, avg_line_labels=["individual values", "mean"], muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)), @@ -57,76 +57,3 @@ default_plotting = dict( text_topright_style=dict(verticalalignment="top", horizontalalignment="right"), text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"), ) - - -def get(section_dict, param, **kwargs): - """checks if param is in the parameter section dict and attempts to fill in a default value - - Parameters - ---------- - section_dict : dict - the parameters section {fiber, pulse, simulation, root} sub-dictionary - param : str - the name of the parameter (dict key) - kwargs : any - key word arguments passed to the MissingParameterError constructor - - Returns - ------- - dict - the updated section_dict dictionary - - Raises - ------ - MissingFiberParameterError - raised when a parameter is missing and no default exists - """ - - # whether the parameter is in the right place and valid is checked elsewhere, - # here, we just make sure it is present. - if param not in section_dict and param not in section_dict.get("variable", {}): - try: - section_dict[param] = default_parameters[param] - # LOG - except KeyError: - raise MissingParameterError(param, **kwargs) - return section_dict - - -def get_fiber(section_dict, param, **kwargs): - """wrapper for fiber parameters that depend on fiber model""" - return get(section_dict, param, fiber_model=section_dict["model"], **kwargs) - - -def get_multiple(section_dict, params, num, **kwargs): - """similar to th get method but works with several parameters - - Parameters - ---------- - section_dict : dict - the parameters section {fiber, pulse, simulation, root}, sub-dictionary - params : list of str - names of the required parameters - num : int - how many of the parameters in params are required - - Returns - ------- - dict - the updated section_dict - - Raises - ------ - MissingParameterError - raised when not enough parameters are provided and no defaults exist - """ - gotten = 0 - for param in params: - try: - section_dict = get(section_dict, param, **kwargs) - gotten += 1 - except MissingParameterError: - pass - if gotten >= num: - return section_dict - raise MissingParameterError(params, num_required=num, **kwargs) diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 591ea53..3cc445b 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -1,11 +1,10 @@ import os -from pathlib import Path from typing import Dict, Literal, Optional, Set -from .const import ENVIRON_KEY_BASE, PBAR_POLICY, LOG_POLICY, TMP_FOLDER_KEY_BASE +from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE -def data_folder(task_id: int) -> Optional[Path]: +def data_folder(task_id: int) -> Optional[str]: idstr = str(int(task_id)) tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr) return tmp diff --git a/src/scgenerator/errors.py b/src/scgenerator/errors.py index 40a229c..9fa79b4 100644 --- a/src/scgenerator/errors.py +++ b/src/scgenerator/errors.py @@ -34,18 +34,3 @@ class DuplicateParameterError(Exception): class IncompleteDataFolderError(FileNotFoundError): pass - - -# class MissingFiberParameterError(MissingParameterError): -# def __init__(self, param, model): -# self.param = param -# self.model = model -# super().__init__( -# f"'{self.param}' is a required parameter for fiber model '{self.model}' and no default value is set" -# ) - - -# class MissingPulseParameterError(MissingParameterError): -# def __init__(self, param): -# self.param = param -# super().__init__(f"'{self.param}' is a required pulse parameter and no default value is set") diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index b29cd85..ea4ca66 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -1,52 +1,359 @@ import os from collections.abc import Mapping -from typing import Any, Dict, Iterator, List, Set, Tuple, Union +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterator, List, Tuple, Union import numpy as np from numpy import pi -from scipy.interpolate.interpolate import interp1d -from tqdm import tqdm -from pathlib import Path -from . import defaults, io, utils -from .const import hc_model_specific_parameters, valid_param_types, valid_variable +from . import io, utils +from .defaults import default_parameters from .errors import * from .logger import get_logger -from .math import abs2, length, power_fact +from .math import abs2, power_fact from .physics import fiber, pulse, units from .utils import count_variations, override_config, required_simulations +from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters -class ParamSequence(Mapping): - def __init__(self, config: Union[Dict[str, Any], os.PathLike]): - if not isinstance(config, Mapping): - config = io.load_toml(config) - self.config = validate(config) - self.name = self.config["name"] +@dataclass +class Params(BareParams): + @classmethod + def from_bare(cls, bare: BareParams): + return cls(**asdict(bare)) + + def __post_init__(self): + self.compute() + + def compute(self): + logger = get_logger(__name__) + + ( + self.z_targets, + self.t, + self.time_window, + self.t_num, + self.dt, + self.w_c, + self.w0, + self.w, + self.w_power_fact, + ) = build_sim_grid( + self.length, self.z_num, self.wavelength, self.time_window, self.t_num, self.dt + ) + + # Initial field may influence the grid + if self.mean_power is not None: + self.energy = self.mean_power / self.repetition_rate + ( + custom_field, + self.width, + self.peak_power, + self.energy, + self.field_0, + ) = pulse.setup_custom_field(self) + if self.readjust_wavelength: + delta_w = self.w_c[np.argmax(abs2(np.fft.fft(self.field_0)))] + logger.debug(f"adjusted w by {delta_w}") + self.wavelength = units.m.inv(units.m(self.wavelength) - delta_w) + self.w_c, self.w0, self.w, self.w_power_fact = update_frequency_domain( + self.t, self.wavelength + ) + + if self.step_size is not None: + self.error_ok = self.step_size + self.adapt_step_size = False + else: + self.error_ok = self.tolerated_error + self.adapt_step_size = True + + # FIBER + self.interp_range = [ + max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), + min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), + ] + + temp_gamma = None + if self.effective_mode_diameter is not None: + self.A_eff = (self.effective_mode_diameter / 2) ** 2 * pi + if self.beta is not None: + self.beta = np.array(self.beta) + self.dynamic_dispersion = False + else: + self.dynamic_dispersion = fiber.is_dynamic_dispersion(self.pressure) + self.beta, temp_gamma = fiber.compute_dispersion(self) + if self.dynamic_dispersion: + self.gamma_func = temp_gamma + self.beta_func = self.beta + self.beta = self.beta_func(0) + temp_gamma = temp_gamma(0) + + if self.gamma is None: + self.gamma = temp_gamma + logger.info(f"using computed \u0263 = {self.gamma:.2e} W/m^2") + + # Raman response + if "raman" in self.behaviors: + self.hr_w = fiber.delayed_raman_w(self.t, self.dt, self.raman_type) + + # GENERIC PULSE + if not custom_field: + custom_field = False + ( + self.width, + self.t0, + self.peak_power, + self.energy, + self.soliton_num, + ) = pulse.conform_pulse_params( + self.shape, + self.width, + self.t0, + self.peak_power, + self.energy, + self.soliton_num, + self.gamma, + self.beta, + ) + logger.info(f"computed initial N = {self['soliton_num']:.3g}") + + self.L_D = self.t0 ** 2 / abs(self.beta[0]) + self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf + self.L_sol = pi / 2 * self.L_D + + # Technical noise + if self.intensity_noise is not None and self.intensity_noise > 0: + delta_int, delta_T0 = pulse.technical_noise(self.intensity_noise) + self["peak_power"] *= delta_int + self["t0"] *= delta_T0 + self["width"] *= delta_T0 + + self.field_0 = pulse.initial_field(self.t, self.shape, self.t0, self.peak_power) + + if self.quantum_noise: + self.field_0 = self.field_0 + pulse.shot_noise( + self.w_c, self.w0, self.time_window, self.dt + ) + + self.spec_0 = np.fft.fft(self.field_0) + + def build_sim_grid(self): + ( + self.z_targets, + self.t, + self.time_window, + self.t_num, + self.dt, + self.w_c, + self.w0, + self.w, + self.w_power_fact, + ) = build_sim_grid( + self.length, self.z_num, self.wavelength, self.time_window, self.t_num, self.dt + ) + + +@dataclass +class Config(BareConfig): + @classmethod + def from_bare(cls, bare: BareConfig): + return cls(**asdict(bare)) + + def __post_init__(self): + for p_name, value in self.__dict__.items(): + if value is not None and p_name in self.variable: + raise DuplicateParameterError(f"got multiple values for parameter {p_name!r}") + self.setdefault("name", "no name") + self.fiber_consistency() + if self.model in hc_model_specific_parameters: + self.gas_consistency() + self.pulse_consistency() + self.simulation_consistency() + + def fiber_consistency(self): + if self.contains("beta"): + if not (self.contains("A_eff") or self.contains("effective_mode_diameter")): + self.gamma = self.get("gamma", specified_parameters=["beta"]) + self.setdefault("model", "custom") + + elif self.contains("dispersion_file"): + if not (self.contains("A_eff") or self.contains("effective_mode_diameter")): + fiber = self.get("gamma", specified_parameters=["dispersion_file"]) + self.setdefault("model", "custom") + + else: + fiber = self.get("model") + + if self.model == "pcf": + fiber = self.get_fiber("pitch") + fiber = self.get_fiber("pitch_ratio") + + elif self.model == "hasan": + fiber = self.get_multiple( + fiber, ["capillary_spacing", "capillary_outer_d"], 1, fiber_model="hasan" + ) + for param in [ + "core_radius", + "capillary_num", + "capillary_thickness", + "capillary_resonance_strengths", + "capillary_nested", + ]: + fiber = self.get_fiber(param) + else: + for param in hc_model_specific_parameters[self.model]: + fiber = self.get_fiber(param) + for param in ["length", "input_transmission"]: + fiber = self.get(param) + + def gas_consistency(self): + for param in ["gas_name", "temperature", "pressure", "plasma_density"]: + self.get(param, specified_params=["gas"]) + + def pulse_consistency(self): + for param in ["wavelength", "quantum_noise", "intensity_noise"]: + self.get(param) + + if not self.contains("field_file"): + self.get("shape") + + if self.contains("soliton_num"): + self.get_multiple( + ["peak_power", "mean_power", "energy", "width", "t0"], + 1, + specified_parameters=["soliton_num"], + ) + + else: + self.get_multiple(["t0", "width"], 1) + self.get_multiple(["peak_power", "energy", "mean_power"], 1) + if self.contains("mean_power"): + self.get("repetition_rate", specified_parameters=["mean_power"]) + + def simulation_consistency(self): + self.get_multiple(["dt", "t_num", "time_window"], 2) + + for param in [ + "behaviors", + "z_num", + "frep", + "tolerated_error", + "parallel", + "repeat", + "lower_wavelength_interp_limit", + "upper_wavelength_interp_limit", + "ideal_gas", + "readjust_wavelength", + "recovery_last_stored", + ]: + self.get(param) + + if ( + any(["raman" in l for l in self.variable.get("behaviors", [])]) + or "raman" in self.behaviors + ): + self.get("raman_type", specified_parameters=["raman"]) + + def contains(self, key): + return self.variable.get(key) is not None or getattr(self, key) is not None + + def get(self, param, **kwargs) -> Any: + """checks if param is in the parameter section dict and attempts to fill in a default value + + Parameters + ---------- + param : str + the name of the parameter (dict key) + kwargs : any + key word arguments passed to the MissingParameterError constructor + + Raises + ------ + MissingFiberParameterError + raised when a parameter is missing and no default exists + """ + + # whether the parameter is in the right place and valid is checked elsewhere, + # here, we just make sure it is present. + if not self.contains(param): + try: + setattr(self, param, default_parameters[param]) + except KeyError: + raise MissingParameterError(param, **kwargs) + + def get_fiber(self, param, **kwargs): + """wrapper for fiber parameters that depend on fiber model""" + self.get(param, fiber_model=self.model, **kwargs) + + def get_multiple(self, params, num, **kwargs): + """similar to the get method but works with several parameters + + Parameters + ---------- + params : list of str + names of the required parameters + num : int + how many of the parameters in params are required + + Raises + ------ + MissingParameterError + raised when not enough parameters are provided and no defaults exist + """ + gotten = 0 + for param in params: + try: + self.get(param, **kwargs) + gotten += 1 + except MissingParameterError: + pass + if gotten >= num: + return + raise MissingParameterError(params, num_required=num, **kwargs) + + def setdefault(self, param, value): + if getattr(self, param) is None: + setattr(self, param, value) + + +class ParamSequence: + def __init__(self, config_dict: Union[Dict[str, Any], os.PathLike, BareConfig]): + """creates a param sequence from a base config + + Parameters + ---------- + config_dict : Union[Dict[str, Any], os.PathLike, BareConfig] + Can be either a dictionary, a path to a config toml file or BareConfig obj + """ + if isinstance(config_dict, BareConfig): + self.config = config_dict + else: + if not isinstance(config_dict, Mapping): + config_dict = io.load_toml(config_dict) + self.config = Config(**config_dict) + self.name = self.config.name self.logger = get_logger(__name__) self.num_sim, self.num_variable = count_variations(self.config) - self.num_steps = self.num_sim * self.config["simulation"]["z_num"] + self.num_steps = self.num_sim * self.config.z_num self.single_sim = self.num_sim == 1 - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Dict[str, Any]]]: + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" - for variable_list, full_config in required_simulations(self.config): - yield variable_list, compute_init_parameters(full_config) + for variable_list, bare_params in required_simulations(self.config): + yield variable_list, Params.from_bare(bare_params) def __len__(self): return self.num_sim - def __getitem__(self, key): - return self.config[key[0]][key[1]] - - def __str__(self) -> str: + def __repr__(self) -> str: return f"dispatcher generated from config {self.name}" class ContinuationParamSequence(ParamSequence): - def __init__(self, prev_sim_dir: str, new_config: Dict[str, Any]): + def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]): """Parameter sequence that builds on a previous simulation but with a new configuration It is recommended that only the fiber and the number of points stored may be changed and changing other parameters could results in unexpected behaviors. The new config doesn't have to @@ -54,31 +361,29 @@ class ContinuationParamSequence(ParamSequence): Parameters ---------- - prev_sim_dir : str + prev_sim_dir : PathLike path to the folder of the previous simulation containing 'initial_config.toml' new_config : Dict[str, Any] new config """ self.prev_sim_dir = Path(prev_sim_dir) - init_config = io.load_previous_parameters( - os.path.join(self.prev_sim_dir, "initial_config.toml") - ) + init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") self.prev_variable_lists = [ (set(variable_list[1:]), self.prev_sim_dir / utils.format_variable_list(variable_list)) for variable_list, _ in required_simulations(init_config) ] - new_config = utils.override_config(new_config, init_config) + new_config = utils.override_config(new_config_dict, init_config) super().__init__(new_config) - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Dict[str, Any]]]: + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" - for variable_list, full_config in required_simulations(self.config): + for variable_list, bare_params in required_simulations(self.config): prev_data_dir = self.find_prev_data_dir(variable_list).resolve() - full_config["prev_data_dir"] = str(prev_data_dir) - yield variable_list, compute_init_parameters(full_config) + bare_params.prev_data_dir = str(prev_data_dir) + yield variable_list, Params.from_bare(bare_params) def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path: """finds the previous simulation data that this new config should start from @@ -109,33 +414,30 @@ class ContinuationParamSequence(ParamSequence): class RecoveryParamSequence(ParamSequence): - def __init__(self, config, task_id): - super().__init__(config) + def __init__(self, config_dict, task_id): + super().__init__(config_dict) self.id = task_id self.num_steps = 0 - z_num = config["simulation"]["z_num"] - started = self.num_sim + not_started = self.num_sim sub_folders = io.get_data_dirs(io.get_sim_dir(self.id)) for sub_folder in utils.PBars( sub_folders, "Initial recovery", head_kwargs=dict(unit="sim") ): - num_left = io.num_left_to_propagate(sub_folder, z_num) + num_left = io.num_left_to_propagate(sub_folder, self.config.z_num) if num_left == 0: self.num_sim -= 1 self.num_steps += num_left - started -= 1 + not_started -= 1 - self.num_steps += started * z_num + self.num_steps += not_started * self.config.z_num self.single_sim = self.num_sim == 1 self.prev_sim_dir = None - if "prev_sim_dir" in self.config.get("simulation", {}): - self.prev_sim_dir = Path(self.config["simulation"]["prev_sim_dir"]) - init_config = io.load_previous_parameters( - os.path.join(self.prev_sim_dir, "initial_config.toml") - ) + if self.config.prev_sim_dir is not None: + self.prev_sim_dir = Path(self.config.prev_sim_dir) + init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") self.prev_variable_lists = [ ( set(variable_list[1:]), @@ -144,17 +446,17 @@ class RecoveryParamSequence(ParamSequence): for variable_list, _ in required_simulations(init_config) ] - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: - for variable_list, params in required_simulations(self.config): + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: + for variable_list, bare_params in required_simulations(self.config): data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list) if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0: if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None: - params["prev_data_dir"] = str(prev_data_dir) - yield variable_list, compute_init_parameters(params) - elif io.num_left_to_propagate(data_dir, self.config["simulation"]["z_num"]) != 0: - yield variable_list, recover_params(params, data_dir) + bare_params.prev_data_dir = str(prev_data_dir) + yield variable_list, Params.from_bare(bare_params) + elif io.num_left_to_propagate(data_dir, self.config.z_num) != 0: + yield variable_list, recover_params(bare_params, data_dir) else: continue @@ -188,24 +490,7 @@ class RecoveryParamSequence(ParamSequence): ) -def validate(config: dict) -> dict: - """validates a configuration dictionary and attempts to fill in defaults - - Parameters - ---------- - config : dict - loaded configuration - - Returns - ------- - dict - updated configuration - """ - _validate_types(config) - return _ensure_consistency(config) - - -def validate_config_sequence(*configs: os.PathLike) -> Dict[str, Any]: +def validate_config_sequence(*configs: os.PathLike) -> Config: """validates a sequence of configs where all but the first one may have parameters missing @@ -224,8 +509,7 @@ def validate_config_sequence(*configs: os.PathLike) -> Dict[str, Any]: if (p := Path(config)).is_dir(): config = p / "initial_config.toml" dico = io.load_toml(config) - previous = override_config(dico, previous) - validate(previous) + previous = Config.from_bare(override_config(dico, previous)) return previous @@ -290,499 +574,44 @@ def tspace(time_window=None, t_num=None, dt=None): raise TypeError("not enough parameter to determine time vector") -def validate_single_parameter(section: str, key: str, value: Any): +def recover_params(params: BareParams, data_folder: Path) -> Params: + params = Params.from_bare(params) try: - func = valid_param_types[section][key] - except KeyError: - s = f"The parameter '{key}' does not belong " - if section == "root": - s += "at the root of the config file" - else: - s += f"in the category '{section}'" - s += ". Make sure it is a valid parameter in the first place" - raise TypeError(s) - if not func(value): - raise TypeError( - f"value '{value}' of type {type(value).__name__} for key '{key}' is not valid, {func.__doc__}" + prev = io.load_params(data_folder / "params.toml") + ( + prev.z_targets, + prev.t, + prev.time_window, + prev.t_num, + prev.dt, + prev.w_c, + prev.w0, + prev.w, + prev.w_power_fact, + ) = build_sim_grid( + prev.length, prev.z_num, prev.wavelength, prev.time_window, prev.t_num, prev.dt ) - return - - -def _validate_types(config): - """validates the data types in the initial config dictionary - - Parameters - ---------- - config : dict - the initial config dictionary - - Raises - ------ - TypeError - raised when a parameter has the wrong type - """ - - for domain, parameters in config.items(): - if isinstance(parameters, dict): - for param_name, param_value in parameters.items(): - if param_name == "variable": - for k_vary, v_vary in param_value.items(): - if not isinstance(v_vary, list): - raise TypeError(f"Variable parameters should be specified in a list") - - if len(v_vary) < 1: - raise ValueError( - f"Variable parameters lists should contain at least 1 element" - ) - - if k_vary not in valid_variable[domain]: - raise TypeError(f"'{k_vary}' is not a valid variable parameter") - - [ - validate_single_parameter(domain, k_vary, v_vary_indiv) - for v_vary_indiv in v_vary - ] - else: - validate_single_parameter(domain, param_name, param_value) - else: - validate_single_parameter("root", domain, parameters) - - -def _contains(sub_conf, param): - return param in sub_conf or param in sub_conf.get("variable", {}) - - -def _ensure_consistency_fiber(fiber: Dict[str, Any]): - """ensure the fiber sub-dictionary of the parameter set is consistent - - Parameters - ---------- - fiber : dict - dictionary containing the fiber parameters - - Returns - ------- - dict - the updated dictionary - - Raises - ------ - MissingParameterError - When at least one required parameter with no default is missing - """ - - if _contains(fiber, "beta"): - if not (_contains(fiber, "A_eff") or _contains(fiber, "effective_mode_diameter")): - fiber = defaults.get(fiber, "gamma", specified_parameters=["beta"]) - fiber.setdefault("model", "custom") - - elif _contains(fiber, "dispersion_file"): - if not (_contains(fiber, "A_eff") or _contains(fiber, "effective_mode_diameter")): - fiber = defaults.get(fiber, "gamma", specified_parameters=["dispersion_file"]) - fiber.setdefault("model", "custom") - - else: - fiber = defaults.get(fiber, "model") - - if fiber["model"] == "pcf": - fiber = defaults.get_fiber(fiber, "pitch") - fiber = defaults.get_fiber(fiber, "pitch_ratio") - - elif fiber["model"] == "hasan": - fiber = defaults.get_multiple( - fiber, ["capillary_spacing", "capillary_outer_d"], 1, fiber_model="hasan" - ) - for param in [ - "core_radius", - "capillary_num", - "capillary_thickness", - "capillary_resonance_strengths", - "capillary_nested", - ]: - fiber = defaults.get_fiber(fiber, param) - else: - for param in hc_model_specific_parameters[fiber["model"]]: - fiber = defaults.get_fiber(fiber, param) - for param in ["length", "input_transmission"]: - fiber = defaults.get(fiber, param) - return fiber - - -def _ensure_consistency_gas(gas): - """ensure the gas sub-dictionary of the parameter set is consistent - - Parameters - ---------- - gas : dict - dictionary containing the gas parameters - - Returns - ------- - dict - the updated dictionary - - Raises - ------ - MissingParameterError - When at least one required parameter with no default is missing - """ - for param in ["gas_name", "temperature", "pressure", "plasma_density"]: - gas = defaults.get(gas, param, specified_params=["gas"]) - return gas - - -def _ensure_consistency_pulse(pulse): - """ensure the pulse sub-dictionary of the parameter set is consistent - - Parameters - ---------- - pulse : dict - dictionary of the pulse section of parameters - - Returns - ------- - dict - the updated pulse dictionary - - Raises - ------ - MissingParameterError - When at least one required parameter with no default is missing - """ - for param in ["wavelength", "quantum_noise", "intensity_noise"]: - pulse = defaults.get(pulse, param) - - if not _contains(pulse, "field_file"): - pulse = defaults.get(pulse, "shape") - - if _contains(pulse, "soliton_num"): - pulse = defaults.get_multiple( - pulse, - ["peak_power", "mean_power", "energy", "width", "t0"], - 1, - specified_parameters=["soliton_num"], - ) - - else: - pulse = defaults.get_multiple(pulse, ["t0", "width"], 1) - pulse = defaults.get_multiple(pulse, ["peak_power", "energy", "mean_power"], 1) - if _contains(pulse, "mean_power"): - pulse = defaults.get(pulse, "repetition_rate", specified_parameters=["mean_power"]) - return pulse - - -def _ensure_consistency_simulation(simulation): - """ensure the simulation sub-dictionary of the parameter set is consistent - - Parameters - ---------- - pulse : dict - dictionary of the pulse section of parameters - - Returns - ------- - dict - the updated pulse dictionary - - Raises - ------ - MissingParameterError - When at least one required parameter with no default is missing - """ - simulation = defaults.get_multiple(simulation, ["dt", "t_num", "time_window"], 2) - - for param in [ - "behaviors", - "z_num", - "frep", - "tolerated_error", - "parallel", - "repeat", - "lower_wavelength_interp_limit", - "upper_wavelength_interp_limit", - "ideal_gas", - "readjust_wavelength", - ]: - simulation = defaults.get(simulation, param) - - if "raman" in simulation.get("behaviors", {}) or any( - ["raman" in l for l in simulation.get("variable", {}).get("behaviors", [])] - ): - simulation = defaults.get(simulation, "raman_type", specified_parameters=["raman"]) - return simulation - - -def _ensure_consistency(config): - """ensure the config dictionary is consistent and that certain parameters are set, - either by filling in defaults or by raising an error. This is not where new values are calculated. - - Parameters - ---------- - config : dict - original config dict loaded from the toml file - - Returns - ------- - dict - the consistent config dict - """ - - _validate_types(config) - - # ensure parameters are not specified multiple times - for sub_dict in valid_param_types.values(): - for param_name in sub_dict: - for set_param in config.values(): - if isinstance(set_param, dict): - if param_name in set_param and param_name in set_param.get("variable", {}): - raise DuplicateParameterError( - f"got multiple values for parameter '{param_name}'" - ) - - # ensure every required parameter has a value - config["name"] = config.get("name", "no name") - - config["fiber"] = _ensure_consistency_fiber(config.get("fiber", {})) - - if config["fiber"]["model"] in hc_model_specific_parameters: - config["gas"] = _ensure_consistency_gas(config.get("gas", {})) - - config["pulse"] = _ensure_consistency_pulse(config.get("pulse", {})) - config["simulation"] = _ensure_consistency_simulation(config.get("simulation", {})) - - return config - - -def recover_params(config: Dict[str, Any], data_folder: Path) -> Dict[str, Any]: - params = compute_init_parameters(config) - try: - prev_params = io.load_previous_parameters(data_folder / "params.toml") - prev_params = build_sim_grid(prev_params) except FileNotFoundError: - prev_params = {} - for k, v in prev_params.items(): - params.setdefault(k, v) + prev = BareParams() + for k, v in filter(lambda el: el[1] is not None, vars(prev).items()): + if getattr(params, k) is None: + setattr(params, k, v) num, last_spectrum = io.load_last_spectrum(data_folder) - params["spec_0"] = last_spectrum - params["field_0"] = np.fft.ifft(last_spectrum) - params["recovery_last_stored"] = num - params["cons_qty"] = np.load(data_folder / "cons_qty.npy") + params.spec_0 = last_spectrum + params.field_0 = np.fft.ifft(last_spectrum) + params.recovery_last_stored = num + params.cons_qty = np.load(data_folder / "cons_qty.npy") return params -def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: - """computes all derived values from a config dictionary - - Parameters - ---------- - config : dict - a configuration dictionary containing the pulse, fiber and simulation sections with no variable parameter. - a flattened parameters dictionary may be provided instead - Note : checking the validity of the configuration shall be done before calling this function. - - Returns - ------- - dict - a flattened dictionary (no fiber, pulse, simulation subsections) with all the necessary values to run RK4IP - """ - - logger = get_logger(__name__) - - # copy and flatten the config - params = {k: v for k, v in config.items() if isinstance(v, (str, int, float))} - for section in ["pulse", "fiber", "simulation", "gas"]: - for key, value in config.get(section, {}).items(): - params[key] = value - - params = build_sim_grid(params) - - # Initial field may influence the grid - if "mean_power" in params: - params["energy"] = params["mean_power"] / params["repetition_rate"] - custom_field = setup_custom_field(params) - - if "step_size" in params: - params["error_ok"] = params["step_size"] - params["adapt_step_size"] = False - else: - params["error_ok"] = params["tolerated_error"] - params["adapt_step_size"] = True - - # FIBER - params["interp_range"] = _interp_range( - params["w"], - params["upper_wavelength_interp_limit"], - params["lower_wavelength_interp_limit"], - ) - - temp_gamma = None - if "effective_mode_diameter" in params: - params["A_eff"] = (params["effective_mode_diameter"] / 2) ** 2 * pi - if "beta" in params: - params["beta"] = np.array(params["beta"]) - params["dynamic_dispersion"] = False - else: - params["dynamic_dispersion"] = fiber.is_dynamic_dispersion(params) - params["beta"], temp_gamma = fiber.dispersion_central(params["model"], params) - if params["dynamic_dispersion"]: - params["gamma_func"] = temp_gamma - params["beta_func"] = params["beta"] - params["beta"] = params["beta_func"](0) - temp_gamma = temp_gamma(0) - - if "gamma" not in params: - params["gamma"] = temp_gamma - logger.info(f"using computed \u0263 = {params['gamma']:.2e} W/m^2") - - # Raman response - if "raman" in params["behaviors"]: - params["hr_w"] = fiber.delayed_raman_w(params["t"], params["dt"], params["raman_type"]) - - # GENERIC PULSE - if not custom_field: - custom_field = False - params = _update_pulse_parameters(params) - logger.info(f"computed initial N = {params['soliton_num']:.3g}") - - params["L_D"] = params["t0"] ** 2 / abs(params["beta"][0]) - params["L_NL"] = 1 / (params["gamma"] * params["peak_power"]) if params["gamma"] else np.inf - params["L_sol"] = pi / 2 * params["L_D"] - - # Technical noise - if "intensity_noise" in params: - params = _technical_noise(params) - - params["field_0"] = pulse.initial_field( - params["t"], params["shape"], params["t0"], params["peak_power"] - ) - - if params["quantum_noise"]: - params["field_0"] = params["field_0"] + pulse.shot_noise( - params["w_c"], params["w0"], params["time_window"], params["dt"] - ) - - params["spec_0"] = np.fft.fft(params["field_0"]) - - return params - - -def setup_custom_field(params: Dict[str, Any]) -> bool: - """sets up a custom field function if necessary and returns - True if it did so, False otherwise - - Parameters - ---------- - params : Dict[str, Any] - params dictionary - - Returns - ------- - bool - True if the field has been modified - """ - logger = get_logger(__name__) - if "prev_data_dir" in params: - spec = io.load_last_spectrum(Path(params["prev_data_dir"]))[1] - params["field_0"] = np.fft.ifft(spec) * np.sqrt(params["input_transmission"]) - else: - if "field_file" in params: - field_data = np.load(params["field_file"]) - field_interp = interp1d( - field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) - ) - params["field_0"] = field_interp(params["t"]) - elif "field_0" in params: - params = _evalutate_custom_field_equation(params) - else: - return False - - params["field_0"] = params["field_0"] * pulse.modify_field_ratio( - params["t"], - params["field_0"], - params.get("peak_power"), - params.get("energy"), - params.get("intensity_noise"), - ) - params["width"], params["peak_power"], params["energy"] = pulse.measure_field( - params["t"], params["field_0"] - ) - if params.get("readjust_wavelength", False): - delta_w = params["w_c"][np.argmax(abs2(np.fft.fft(params["field_0"])))] - logger.debug(f"adjusted w by {delta_w}") - params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) - _update_frequency_domain(params) - return True - - -def _update_pulse_parameters(params): - ( - params["width"], - params["t0"], - params["peak_power"], - params["energy"], - params["soliton_num"], - ) = pulse.conform_pulse_params( - shape=params["shape"], - width=params.get("width", None), - t0=params.get("t0", None), - peak_power=params.get("peak_power", None), - energy=params.get("energy", None), - gamma=params["gamma"], - beta2=params["beta"][0], - ) - return params - - -def _evalutate_custom_field_equation(params): - field_info = params["field_0"] - if isinstance(field_info, str): - field_0 = eval( - field_info, - dict( - sin=np.sin, - cos=np.cos, - tan=np.tan, - exp=np.exp, - pi=np.pi, - sqrt=np.sqrt, - **params, - ), - ) - - params["field_0"] = field_0 - elif len(field_info) != params["t_num"]: - raise ValueError( - "initial field is given but doesn't match size and type with the time array" - ) - return params - - -def _technical_noise(params): - logger = get_logger(__name__) - - if params["intensity_noise"] > 0: - logger.info(f"intensity noise of {params['intensity_noise']}") - delta_int, delta_T0 = pulse.technical_noise(params["intensity_noise"]) - params["peak_power"] *= delta_int - params["t0"] *= delta_T0 - params["width"] *= delta_T0 - params = _update_pulse_parameters(params) - return params - - -def _interp_range(w, upper, lower): - # by default, the interpolation range of the dispersion polynomial stops exactly - # at the boundary of the frequency window we consider - - interp_range = [ - max(lower, units.m.inv(np.max(w[w > 0]))), - min(upper, units.m.inv(np.min(w[w > 0]))), - ] - - return interp_range - - -def build_sim_grid(params): +def build_sim_grid( + length: float, + z_num: int, + wavelength: float, + time_window: float = None, + t_num: int = None, + dt: float = None, +): """computes a bunch of values that relate to the simulation grid Parameters @@ -795,61 +624,19 @@ def build_sim_grid(params): dict updated parameter dictionary """ - t = params.get( - "t", - tspace( - time_window=params.get("time_window", None), - t_num=params.get("t_num", None), - dt=params.get("dt", None), - ), - ) - params["t"] = t - params["time_window"] = length(t) - params["dt"] = t[1] - t[0] - params["t_num"] = len(t) - params["z_targets"] = np.linspace(0, params["length"], params["z_num"]) - params = _update_frequency_domain(params) - return params + t = tspace(time_window, t_num, dt) + + time_window = t.max() - t.min() + dt = t[1] - t[0] + t_num = len(t) + z_targets = np.linspace(0, length, z_num) + w_c, w0, w, w_power_fact = update_frequency_domain(t, wavelength) + return z_targets, t, time_window, t_num, dt, w_c, w0, w, w_power_fact -def _update_frequency_domain(params): - w_c = wspace(params["t"]) - w0 = units.m(params["wavelength"]) - params["w0"] = w0 - params["w_c"] = w_c - params["w"] = w_c + w0 - params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)]) - return params - - -def sanitize_z_targets(z_targets): - """ - processes the 'z_targets' arguments and guarantees that: - - it is sorted - - it doesn't contain the same value twice - - it starts with 0 - Parameters - ---------- - z_targets : float, int or array-like - float or int : end point of the fiber starting from 0 - array-like of len(.) == 3 : `numpy.linspace` arguments - array-like of other length : target distances at which to store the spectra - Returns - ---------- - z_targets : list (mutability is important) - """ - if isinstance(z_targets, (float, int)): - z_targets = np.linspace(0, z_targets, defaults.default_parameters["length"]) - else: - z_targets = np.array(z_targets).flatten() - - if len(z_targets) == 3: - z_targets = np.linspace(*z_targets[:2], int(z_targets[2])) - - z_targets = list(set(value for value in z_targets if value >= 0)) - z_targets.sort() - - if 0 not in z_targets: - z_targets = [0] + z_targets - - return z_targets +def update_frequency_domain(t, wavelength): + w_c = wspace(t) + w0 = units.m(wavelength) + w = w_c + w0 + w_power_fact = np.array([power_fact(w_c, k) for k in range(2, 11)]) + return w_c, w0, w, w_power_fact diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 4f0d36a..f1d0f9f 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -1,3 +1,4 @@ +from dataclasses import asdict import itertools import os import shutil @@ -11,18 +12,17 @@ import toml from . import env, utils from .const import ( - __version__, - ENVIRON_KEY_BASE, PARAM_FN, PARAM_SEPARATOR, - PBAR_POLICY, SPEC1_FN, SPECN_FN, TMP_FOLDER_KEY_BASE, Z_FN, + __version__, ) from .errors import IncompleteDataFolderError from .logger import get_logger +from .utils.parameter import BareConfig, BareParams PathTree = List[Tuple[Path, ...]] @@ -88,6 +88,10 @@ def load_toml(path: os.PathLike): path = conform_toml_path(path) with open(path, mode="r") as file: dico = toml.load(file) + + for section in ["simulation", "fiber", "pulse", "gas"]: + dico.update(dico.pop(section, {})) + return dico @@ -99,52 +103,15 @@ def save_toml(path: os.PathLike, dico): return dico -def serializable(val): - """returns True if val is serializable into a Json file""" - types = (np.ndarray, float, int, str, list, tuple) - - out = isinstance(val, types) - if isinstance(val, np.ndarray): - out &= val.dtype != "complex" - return out - - -def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]: - """prepares a dictionary for serialization. Some keys may not be preserved - (dropped due to no conversion available) - - Parameters - ---------- - dico : dict - dictionary - """ - forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"] - types = (np.ndarray, float, int, str, list, tuple, dict) - out = {} - for key, value in dico.items(): - if key in forbiden_keys: - continue - if not isinstance(value, types): - continue - if isinstance(value, dict): - out[key] = prepare_for_serialization(value) - elif isinstance(value, np.ndarray) and value.dtype == complex: - continue - else: - out[key] = value - - return out - - -def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path: +def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path: """saves a parameter dictionary. Note that is does remove some entries, particularly those that take a lot of space ("t", "w", ...) Parameters ---------- - param_dict : Dict[str, Any] + params : Dict[str, Any] dictionary to save - data_dir : Path + destination_dir : Path destination directory Returns @@ -152,12 +119,8 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path: Path path to newly created the paramter file """ - param = param_dict.copy() - file_path = destination_dir / "params.toml" - - param = prepare_for_serialization(param) - param["datetime"] = datetime.now() - param["version"] = __version__ + param = params.prepare_for_dump() + file_path = destination_dir / file_name file_path.parent.mkdir(exist_ok=True) @@ -168,7 +131,7 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path: return file_path -def load_previous_parameters(path: os.PathLike): +def load_params(path: os.PathLike) -> BareParams: """loads a parameters toml files and converts data to appropriate type It is advised to run initialize.build_sim_grid to recover some parameters that are not saved. @@ -179,15 +142,29 @@ def load_previous_parameters(path: os.PathLike): Returns ---------- - dict - flattened parameters dictionary + BareParams + params obj """ params = load_toml(path) + return BareParams(**params) - for k, v in params.items(): - if isinstance(v, list) and isinstance(v[0], (float, int)): - params[k] = np.array(v) - return params + +def load_config(path: os.PathLike) -> BareConfig: + """loads a parameters toml files and converts data to appropriate type + It is advised to run initialize.build_sim_grid to recover some parameters that are not saved. + + Parameters + ---------- + path : PathLike + path to the toml + + Returns + ---------- + BareParams + config obj + """ + config = load_toml(path) + return BareConfig(**config) def load_material_dico(name): diff --git a/src/scgenerator/logger.py b/src/scgenerator/logger.py index a971056..50f5299 100644 --- a/src/scgenerator/logger.py +++ b/src/scgenerator/logger.py @@ -1,7 +1,6 @@ import logging -from typing import Optional -from .env import log_policy +from .env import log_policy # class DebugOnlyFileHandler(logging.FileHandler): # def __init__( diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index e7276fa..6e1b0c5 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -1,8 +1,8 @@ -from typing import Type, Union +from typing import Union + import numpy as np +from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros -from scipy.interpolate import interp1d, griddata -from numba import jit def span(*vec): @@ -54,7 +54,6 @@ def power_fact(x, n): raise TypeError(f"type {type(x)} of x not supported.") -@jit(nopython=True) def _power_fact_single(x, n): result = 1.0 for k in range(n): @@ -62,7 +61,6 @@ def _power_fact_single(x, n): return result -@jit(nopython=True) def _power_fact_array(x, n): result = np.ones(len(x), dtype=np.float64) for k in range(n): @@ -70,7 +68,6 @@ def _power_fact_array(x, n): return result -@jit(nopython=True) def abs2(z: np.ndarray) -> np.ndarray: return z.real ** 2 + z.imag ** 2 @@ -133,4 +130,4 @@ def make_uniform_1D(values, x_axis, n=1024, method="linear"): array of length n """ xx = np.linspace(*span(x_axis), len(x_axis)) - return interp1d(x_axis, values, kind=method)(xx) \ No newline at end of file + return interp1d(x_axis, values, kind=method)(xx) diff --git a/src/scgenerator/parameters.py b/src/scgenerator/parameters.py deleted file mode 100644 index 7e112a3..0000000 --- a/src/scgenerator/parameters.py +++ /dev/null @@ -1,36 +0,0 @@ -class Parameter: - """base class for parameters""" - - all = dict(fiber=dict(), pulse=dict(), gas=dict(), simulation=dict()) - help_message = "no help message lol" - - def __init_subclass__(cls, section): - Parameter.all[section][cls.__name__.lower()] = cls - - def __init__(self, s): - self.s = s - valid = True - try: - self.value = self._convert() - valid = self.valid() - except ValueError: - valid = False - - if not valid: - raise ValueError( - f"{self.__class__.__name__} {self.__class__.help_message}. input : {self.s}" - ) - - def _convert(self): - value = self.conversion_func(self.s) - return value - - -class Wavelength(Parameter, section="pulse"): - help_message = "must be a strictly positive real number" - - def valid(self): - return self.value > 0 - - def conversion_func(self, s: str) -> float: - return float(s) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 3fd3333..b97365a 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1,15 +1,14 @@ +from typing import Any, Dict, List, Tuple + import numpy as np -from numpy.lib import disp -from numpy.lib.arraysetops import isin import toml -from numba import jit from numpy.fft import fft, ifft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d from .. import io -from ..const import hc_model_specific_parameters from ..math import abs2, argclosest, power_fact, u_nm +from ..utils.parameter import BareParams, hc_model_specific_parameters from . import materials as mat from . import units from .units import c, pi @@ -25,7 +24,7 @@ def lambda_for_dispersion(): return np.linspace(190e-9, 3000e-9, 4000) -def is_dynamic_dispersion(params): +def is_dynamic_dispersion(pressure=None): """tests if the parameter dictionary implies that the dispersion profile of the fiber changes with z Parameters @@ -38,8 +37,8 @@ def is_dynamic_dispersion(params): bool : True if dispersion is supposed to change with z """ out = False - if "pressure" in params: - out |= isinstance(params["pressure"], (tuple, list)) and len(params["pressure"]) == 2 + if pressure is not None: + out |= isinstance(pressure, (tuple, list)) and len(pressure) == 2 return out @@ -483,7 +482,19 @@ def HCPCF_dispersion( return beta2(w, n_eff) -def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg): +def dynamic_HCPCF_dispersion( + lambda_: np.ndarray, + pressure_values: List[float], + core_radius: float, + fiber_model: str, + model_params: Dict[str, Any], + temperature: float, + ideal_gas: bool, + w0: float, + interp_range: Tuple[float, float], + material_dico: Dict[str, Any], + deg, +): """returns functions for beta2 coefficients and gamma instead of static values Parameters @@ -504,25 +515,22 @@ def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg): in the fiber """ - # store values because storing functions acts weird with dict - pressure_values = params["pressure"] - a = params["core_radius"] - fiber_model = params["fiber_model"] - model_params = {k: params[k] for k in hc_model_specific_parameters[fiber_model]} - temp = params["temperature"] - ideal_gas = params["ideal_gas"] - w0 = params["w0"] - interp_range = params["interp_range"] - - A_eff = 1.5 * a ** 2 + A_eff = 1.5 * core_radius ** 2 # defining function instead of storing every possilble value pressure = lambda r: mat.pressure_from_gradient(r, *pressure_values) beta2 = lambda r: HCPCF_dispersion( - lambda_, a, material_dico, fiber_model, model_params, pressure(r), temp, ideal_gas + lambda_, + core_radius, + material_dico, + fiber_model, + model_params, + pressure(r), + temperature, + ideal_gas, ) - n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temp) + n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temperature) ratio_range = np.linspace(0, 1, 256) gamma_grid = np.array([gamma_parameter(n2(r), w0, A_eff) for r in ratio_range]) @@ -640,7 +648,7 @@ def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None): return beta2, gamma -def dispersion_central(fiber_model, params, deg=8): +def compute_dispersion(params: BareParams, deg=8): """dispatch function depending on what type of fiber is used Parameters @@ -660,8 +668,8 @@ def dispersion_central(fiber_model, params, deg=8): nonlinear parameter """ - if "dispersion_file" in params: - disp_file = np.load(params["dispersion_file"]) + if params.dispersion_file is not None: + disp_file = np.load(params.dispersion_file) lambda_ = disp_file["wavelength"] D = disp_file["dispersion"] beta2 = D_to_beta2(D, lambda_) @@ -669,21 +677,20 @@ def dispersion_central(fiber_model, params, deg=8): else: lambda_ = lambda_for_dispersion() beta2 = np.zeros_like(lambda_) - fiber_model = fiber_model.lower() - if fiber_model == "pcf": + if params.model == "pcf": beta2, gamma = PCF_dispersion( lambda_, - params["pitch"], - params["pitch_ratio"], - w0=params["w0"], - n2=params.get("n2"), - A_eff=params.get("A_eff"), + params.pitch, + params.pitch_ratio, + w0=params.w0, + n2=params.n2, + A_eff=params.A_eff, ) else: # Load material info - gas_name = params["gas_name"] + gas_name = params.gas_name if gas_name == "vacuum": material_dico = None @@ -691,8 +698,20 @@ def dispersion_central(fiber_model, params, deg=8): material_dico = toml.loads(io.Paths.gets("gas"))[gas_name] # compute dispersion - if params.get("dynamic_dispersion", False): - return dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg) + if params.dynamic_dispersion: + return dynamic_HCPCF_dispersion( + lambda_, + params.pressure, + params.core_radius, + params.model, + {k: getattr(params, k) for k in hc_model_specific_parameters[params.model]}, + params.temperature, + params.ideal_gas, + params.w0, + params.interp_range, + material_dico, + deg, + ) else: # actually compute the dispersion @@ -700,31 +719,31 @@ def dispersion_central(fiber_model, params, deg=8): beta2 = HCPCF_dispersion( lambda_, material_dico, - fiber_model, - {k: params[k] for k in hc_model_specific_parameters[fiber_model]}, - params["pressure"], - params["temperature"], - params["ideal_gas"], + params.model, + {k: getattr(params, k) for k in hc_model_specific_parameters[params.model]}, + params.pressure, + params.temperature, + params.ideal_gas, ) if material_dico is not None: - A_eff = 1.5 * params["core_radius"] ** 2 + A_eff = 1.5 * params.core_radius ** 2 n2 = mat.non_linear_refractive_index( - material_dico, params["pressure"], params["temperature"] + material_dico, params.pressure, params.temperature ) - gamma = gamma_parameter(n2, params["w0"], A_eff) + gamma = gamma_parameter(n2, params.w0, A_eff) else: gamma = None # add plasma if wanted - if params["plasma_density"] > 0: - beta2 += plasma_dispersion(lambda_, params["plasma_density"]) + if params.plasma_density > 0: + beta2 += plasma_dispersion(lambda_, params.plasma_density) - beta2_coef = dispersion_coefficients(lambda_, beta2, params["w0"], params["interp_range"], deg) + beta2_coef = dispersion_coefficients(lambda_, beta2, params.w0, params.interp_range, deg) if gamma is None: - if "A_eff" in params: - gamma = gamma_parameter(params.get("n2", 2.6e-20), params["w0"], params["A_eff"]) + if params.A_eff is not None: + gamma = gamma_parameter(params.n2, params.w0, params.A_eff) else: gamma = 0 @@ -855,15 +874,13 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non """ # Compute raman response function if necessary + f_r = 0.18 if "raman" in behaviors: - if "hr_w" == None: - raise TypeError("freq-dependent Raman response must be give") - else: - if f_r is None: - if raman_type in ["stolen", "measured"]: - f_r = 0.18 - elif raman_type == "agrawal": - f_r = 0.245 + if hr_w is None: + raise ValueError("freq-dependent Raman response must be give") + if f_r is None: + if raman_type == "agrawal": + f_r = 0.245 if "spm" in behaviors: spm_part = lambda fi: (1 - f_r) * abs2(fi) @@ -875,7 +892,6 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non else: raman_part = lambda fi: 0 - spm_part = jit(spm_part, nopython=True) ss_part = w_c / w0 if "ss" in behaviors else 0 if isinstance(gamma, (float, int)): @@ -924,7 +940,6 @@ def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)): return -1j * out -@jit(nopython=True) def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr): for k in range(len(beta_arr) - 1, -1, -1): dispersion = dispersion + beta_arr[k] * power_fact_arr[k] diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 778a5d0..2e38a14 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -1,7 +1,6 @@ import numpy as np from ..logger import get_logger - from . import units from .units import NA, c, kB diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 00cb831..060296b 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -11,6 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time import itertools import os +from pathlib import Path from typing import Literal, Tuple import matplotlib.pyplot as plt @@ -18,13 +19,13 @@ import numpy as np from numpy import pi from numpy.fft import fft, fftshift, ifft from scipy.interpolate import UnivariateSpline -from numba import jit +from .. import io from ..defaults import default_plotting - from ..logger import get_logger -from ..plotting import plot_setup from ..math import * +from ..plotting import plot_setup +from ..utils.parameter import BareParams c = 299792458.0 hbar = 1.05457148e-34 @@ -205,6 +206,48 @@ def conform_pulse_params( return width, t0, peak_power, energy, soliton_num +def setup_custom_field(params: BareParams) -> bool: + """sets up a custom field function if necessary and returns + True if it did so, False otherwise + + Parameters + ---------- + params : Dict[str, Any] + params dictionary + + Returns + ------- + bool + True if the field has been modified + """ + field_0 = width = peak_power = energy = None + + did_set = True + + if params.prev_data_dir is not None: + spec = io.load_last_spectrum(Path(params.prev_data_dir))[1] + field_0 = np.fft.ifft(spec) * np.sqrt(params.input_transmission) + elif params.field_file is not None: + field_data = np.load(params.field_file) + field_interp = interp1d( + field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) + ) + field_0 = field_interp(params.t) + + field_0 = field_0 * modify_field_ratio( + params.t, + field_0, + params.peak_power, + params.energy, + params.intensity_noise, + ) + width, peak_power, energy = measure_field(params.t, field_0) + else: + did_set = False + + return did_set, width, peak_power, energy, field_0 + + def E0_to_P0(E0, t0, shape="gaussian"): """convert an initial total pulse energy to a pulse peak peak_power""" return E0 / (t0 * P0T0_to_E0_fac[shape]) @@ -223,12 +266,10 @@ def gauss_pulse(t, t0, P0, offset=0): return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2)) -@jit(nopython=True) def photon_number(spectrum, w, dw, gamma): return np.sum(1 / gamma * abs2(spectrum) / w * dw) -@jit(nopython=True) def pulse_energy(spectrum, w, dw, _): return np.sum(abs2(spectrum) * dw) diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 23277d0..a9c7df0 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -1,7 +1,8 @@ import multiprocessing import os from datetime import datetime -from typing import Any, Dict, List, Tuple, Type +from pathlib import Path +from typing import Dict, List, Tuple, Type import numpy as np @@ -18,7 +19,14 @@ except ModuleNotFoundError: class RK4IP: - def __init__(self, sim_params, save_data=False, job_identifier="", task_id=0, n_percent=10): + def __init__( + self, + params: initialize.Params, + save_data=False, + job_identifier="", + task_id=0, + n_percent=10, + ): """A 1D solver using 4th order Runge-Kutta in the interaction picture Parameters @@ -76,31 +84,29 @@ class RK4IP: self.logger = get_logger(self.job_identifier) self.resuming = False self.save_data = save_data - self._extract_params(sim_params) - self._setup_functions() - self.starting_num = sim_params.get("recovery_last_stored", 0) - self._setup_sim_parameters() - def _extract_params(self, params): - self.w_c = params.pop("w_c") - self.w0 = params.pop("w0") - self.w_power_fact = params.pop("w_power_fact") - self.spec_0 = params.pop("spec_0") - self.z_targets = params.pop("z_targets") - self.z_final = params.pop("length") - self.beta = params.pop("beta_func", params.pop("beta")) - self.gamma = params.pop("gamma_func", params.pop("gamma")) - self.behaviors = params.pop("behaviors") - self.raman_type = params.pop("raman_type", "stolen") - self.f_r = params.pop("f_r", 0) - self.hr_w = params.pop("hr_w", None) - self.adapt_step_size = params.pop("adapt_step_size", True) - self.error_ok = params.pop("error_ok") - self.dynamic_dispersion = params.pop("dynamic_dispersion", False) + self.w_c = params.w_c + self.w0 = params.w0 + self.w_power_fact = params.w_power_fact + self.spec_0 = params.spec_0 + self.z_targets = params.z_targets + self.z_final = params.length + self.beta = params.beta_func if params.beta_func is not None else params.beta + self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma + self.behaviors = params.behaviors + self.raman_type = params.raman_type + self.hr_w = params.hr_w + self.adapt_step_size = params.adapt_step_size + self.error_ok = params.error_ok + self.dynamic_dispersion = params.dynamic_dispersion + self.starting_num = params.recovery_last_stored + + self._setup_functions() + self._setup_sim_parameters() def _setup_functions(self): self.N_func = create_non_linear_op( - self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, self.f_r, self.hr_w + self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, hr_w=self.hr_w ) if self.dynamic_dispersion: @@ -303,7 +309,7 @@ class RK4IP: class SequentialRK4IP(RK4IP): def __init__( self, - sim_params, + params: initialize.Params, pbars: utils.PBars, save_data=False, job_identifier="", @@ -312,7 +318,7 @@ class SequentialRK4IP(RK4IP): ): self.pbars = pbars super().__init__( - sim_params, + params, save_data=save_data, job_identifier=job_identifier, task_id=task_id, @@ -326,7 +332,7 @@ class SequentialRK4IP(RK4IP): class MutliProcRK4IP(RK4IP): def __init__( self, - sim_params, + params: initialize.Params, p_queue: multiprocessing.Queue, worker_id: int, save_data=False, @@ -337,7 +343,7 @@ class MutliProcRK4IP(RK4IP): self.worker_id = worker_id self.p_queue = p_queue super().__init__( - sim_params, + params, save_data=save_data, job_identifier=job_identifier, task_id=task_id, @@ -351,7 +357,7 @@ class MutliProcRK4IP(RK4IP): class RayRK4IP(RK4IP): def __init__( self, - sim_params, + params: initialize.Params, p_actor, worker_id: int, save_data=False, @@ -362,7 +368,7 @@ class RayRK4IP(RK4IP): self.worker_id = worker_id self.p_actor = p_actor super().__init__( - sim_params, + params, save_data=save_data, job_identifier=job_identifier, task_id=task_id, @@ -414,7 +420,7 @@ class Simulations: if isinstance(method, str): method = Simulations.simulation_methods_dict[method] return method(param_seq, task_id) - elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]: + elif param_seq.num_sim > 1 and param_seq.config.parallel: return Simulations.get_best_method()(param_seq, task_id) else: return SequencialSimulations(param_seq, task_id) @@ -439,7 +445,7 @@ class Simulations: self.name = self.param_seq.name self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name) - io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config) + io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml") self.sim_jobs_per_node = 1 self.max_concurrent_jobs = np.inf @@ -447,9 +453,7 @@ class Simulations: @property def finished_and_complete(self): try: - io.check_data_integrity( - io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"] - ) + io.check_data_integrity(io.get_data_dirs(self.sim_dir), self.param_seq.config.z_num) return True except IncompleteDataFolderError: return False @@ -472,15 +476,15 @@ class Simulations: self.new_sim(v_list_str, params) self.finish() - def new_sim(self, v_list_str: str, params: dict): + def new_sim(self, v_list_str: str, params: initialize.Params): """responsible to launch a new simulation Parameters ---------- v_list_str : str string that uniquely identifies the simulation as returned by utils.format_variable_list - params : dict - a flattened parameter dictionary, as returned by initialize.compute_init_parameters + params : initialize.Params + computed parameters """ raise NotImplementedError() @@ -507,7 +511,7 @@ class SequencialSimulations(Simulations, priority=0): super().__init__(param_seq, task_id=task_id) self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1) - def new_sim(self, v_list_str: str, params: Dict[str, Any]): + def new_sim(self, v_list_str: str, params: initialize.Params): self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}") SequentialRK4IP( params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id @@ -517,7 +521,7 @@ class SequencialSimulations(Simulations, priority=0): pass def finish(self): - pass + self.pbars.close() class MultiProcSimulations(Simulations, priority=1): @@ -553,7 +557,7 @@ class MultiProcSimulations(Simulations, priority=1): worker.start() super().run() - def new_sim(self, v_list_str: str, params: dict): + def new_sim(self, v_list_str: str, params: initialize.Params): self.queue.put((v_list_str, params), block=True, timeout=None) def finish(self): @@ -576,7 +580,7 @@ class MultiProcSimulations(Simulations, priority=1): p_queue: multiprocessing.Queue, ): while True: - raw_data: Tuple[List[tuple], Dict[str, Any]] = queue.get() + raw_data: Tuple[List[tuple], initialize.Params] = queue.get() if raw_data == 0: queue.task_done() return @@ -635,7 +639,7 @@ class RaySimulations(Simulations, priority=2): .remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps) ) - def new_sim(self, v_list_str: str, params: dict): + def new_sim(self, v_list_str: str, params: initialize.Params): while len(self.jobs) >= self.sim_jobs_total: self._collect_1_job() @@ -707,28 +711,27 @@ def new_simulation( method: Type[Simulations] = None, ) -> Simulations: - config = io.load_toml(config_file) + config_dict = io.load_toml(config_file) if prev_sim_dir is not None: - config.setdefault("simulation", {}) - config["simulation"]["prev_sim_dir"] = str(prev_sim_dir) + config_dict["prev_sim_dir"] = str(prev_sim_dir) task_id = np.random.randint(1e9, 1e12) if prev_sim_dir is None: - param_seq = initialize.ParamSequence(config) + param_seq = initialize.ParamSequence(config_dict) else: - param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config) + param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict) print(f"{param_seq.name=}") return Simulations.new(param_seq, task_id, method) -def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simulations: +def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simulations: task_id = np.random.randint(1e9, 1e12) - config = io.load_toml(os.path.join(sim_dir, "initial_config.toml")) + config = io.load_toml(sim_dir / "initial_config.toml") io.set_data_folder(task_id, sim_dir) param_seq = initialize.RecoveryParamSequence(config, task_id) diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 1ad5603..a9cb97b 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -2,9 +2,10 @@ # For example, nm(X) means "I give the number X in nm, figure out the ang. freq." # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... -from numba.core.types.misc import Phantom +from typing import Callable, Union + import numpy as np -from numpy import isin, pi +from numpy import pi c = 299792458.0 hbar = 1.05457148e-34 @@ -217,7 +218,7 @@ units_map = dict( ) -def get_unit(unit): +def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]: if isinstance(unit, str): return units_map[unit] return unit diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index b45cb78..ea45460 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1,55 +1,47 @@ import os +from pathlib import Path +from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union import matplotlib.gridspec as gs import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import ListedColormap -from scgenerator.utils import variable_iterator from scipy.interpolate import UnivariateSpline from . import io, math -from .math import abs2, length, make_uniform_1D, span -from .physics import pulse, units from .defaults import default_plotting as defaults +from .math import abs2, make_uniform_1D, span +from .physics import pulse, units +from .utils.parameter import BareParams + +RangeType = Tuple[float, float, Union[str, Callable]] def plot_setup( - folder_name=None, - file_name=None, - file_type="png", - figsize=defaults["figsize"], - params=None, - mode="default", -): + out_path: Path, + file_type: str = "png", + figsize: Tuple[float, float] = defaults["figsize"], + mode: Literal["default", "coherence", "coherence_T"] = "default", +) -> Tuple[Path, plt.Figure, Union[plt.Axes, Tuple[plt.Axes]]]: """It should return : - a folder_name - a file name - a fig - an axis """ - file_name = defaults["name"] if file_name is None else file_name + out_path = defaults["name"] if out_path is None else out_path + plot_name = out_path.stem + out_dir = out_path.resolve().parent - if params is not None: - folder_name = params.get("plot.folder_name", folder_name) - file_name = params.get("plot.file_name", file_name) - file_type = params.get("plot.file_type", file_type) - figsize = params.get("plot.figsize", figsize) + file_name = plot_name + "." + file_type + out_path = out_dir / file_name - # ensure output folder_name exists - folder_name, file_name = ( - os.path.split(file_name) - if folder_name is None - else (folder_name, os.path.split(file_name)[1]) - ) - folder_name = os.path.join(io.Paths.get("plots"), folder_name) - if not os.path.exists(os.path.abspath(folder_name)): - os.makedirs(os.path.abspath(folder_name)) + os.makedirs(out_dir, exist_ok=True) # ensure no overwrite ind = 0 - while os.path.exists(os.path.join(folder_name, file_name + "_" + str(ind) + "." + file_type)): + while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists(): ind += 1 - file_name = file_name + "_" + str(ind) + "." + file_type if mode == "default": fig, ax = plt.subplots(figsize=figsize) @@ -78,7 +70,7 @@ def plot_setup( else: raise ValueError(f"mode {mode} not understood") - return folder_name, file_name, fig, ax + return full_path, fig, ax def draw_across(ax1, xy1, ax2, xy2, clip_on=False, **kwargs): @@ -297,9 +289,7 @@ def _finish_plot_2D( folder_name = "" if is_new_plot: - folder_name, file_name, fig, ax = plot_setup( - file_name=file_name, file_type=file_type, params=params - ) + out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) else: fig = ax.get_figure() @@ -345,8 +335,8 @@ def _finish_plot_2D( cbar.ax.set_ylabel(cbar_label) if is_new_plot: - fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200) - print(f"plot saved in {os.path.join(folder_name, file_name)}") + fig.savefig(out_path, bbox_inches="tight", dpi=200) + print(f"plot saved in {out_path}") if cbar_label is not None: return fig, ax, cbar.ax else: @@ -354,20 +344,20 @@ def _finish_plot_2D( def plot_spectrogram( - values, - x_range, - y_range, - params, - t_res=None, - gate_width=None, - log=True, - vmin=None, - vmax=None, - cbar_label="normalized intensity (dB)", - file_type="png", - file_name=None, - cmap=None, - ax=None, + values: np.ndarray, + x_range: RangeType, + y_range: RangeType, + params: BareParams, + t_res: int = None, + gate_width: float = None, + log: bool = True, + vmin: float = None, + vmax: float = None, + cbar_label: str = "normalized intensity (dB)", + file_type: str = "png", + file_name: str = None, + cmap: str = None, + ax: plt.Axes = None, ): """Plots a spectrogram given a complex field in the time domain Parameters @@ -382,7 +372,7 @@ def plot_spectrogram( units : function to convert from the desired units to rad/s or to time. common functions are already defined in scgenerator.physics.units look there for more details - params : dict + params : BareParams parameters of the simulations log : bool, optional whether to compute the logarithm of the spectrogram @@ -424,16 +414,16 @@ def plot_spectrogram( t_win = 2 * np.max(t_range[2](np.abs(t_range[:2]))) spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False) spec, new_t = pulse.spectrogram( - params["t"].copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None} + params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None} ) # Crop and reoder axis new_t, ind_t, _ = units.sort_axis(new_t, t_range) - new_f, ind_f, _ = units.sort_axis(params["w"], f_range) + new_f, ind_f, _ = units.sort_axis(params.w, f_range) values = spec[ind_t][:, ind_f] if f_range[2].type == "WL": values = np.apply_along_axis( - units.to_WL, 1, values, params["frep"], units.m(f_range[2].inv(new_f)) + units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f)) ) values = np.apply_along_axis(make_uniform_1D, 1, values, new_f) @@ -463,19 +453,19 @@ def plot_spectrogram( def plot_results_2D( - values, - plt_range, - params, - log="1D", - skip=16, - vmin=None, - vmax=None, - transpose=False, - cbar_label="normalized intensity (dB)", - file_type="png", - file_name=None, - cmap=None, - ax=None, + values: np.ndarray, + plt_range: RangeType, + params: BareParams, + log: Union[int, float, bool, str] = "1D", + skip: int = 16, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + cbar_label: Optional[str] = "normalized intensity (dB)", + file_type: str = "png", + file_name: str = None, + cmap: str = None, + ax: plt.Axes = None, ): """ plots 2D arrays and automatically saves the plots, as well as returns it @@ -540,27 +530,32 @@ def plot_results_2D( # make uniform if converting to wavelength if plt_range[2].type == "WL": if is_spectrum: - values = np.apply_along_axis(units.to_WL, 1, values, params.get("frep", 1), x_axis) + values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.array( [make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values] ) - z = params["z_targets"] - lim_diff = 1e-5 * np.max(z) - dz_s = np.diff(z) + lim_diff = 1e-5 * np.max(params.z_targets) + dz_s = np.diff(params.z_targets) if not np.all(np.diff(dz_s) < lim_diff): new_z = np.linspace( - *span(z), int(np.floor((np.max(z) - np.min(z)) / np.min(dz_s[dz_s > lim_diff]))) + *span(params.z_targets), + int( + np.floor( + (np.max(params.z_targets) - np.min(params.z_targets)) + / np.min(dz_s[dz_s > lim_diff]) + ) + ), ) values = np.array( - [make_uniform_1D(v, z, n=len(new_z), method="linear") for v in values.T] + [make_uniform_1D(v, params.z_targets, n=len(new_z), method="linear") for v in values.T] ).T - z = new_z + params.z_targets = new_z return _finish_plot_2D( values, x_axis, plt_range[2].label, - z, + params.z_targets, "propagation distance (m)", log, vmin, @@ -576,20 +571,20 @@ def plot_results_2D( def plot_results_1D( - values, - plt_range, - params, - log=False, - spacing=1, - vmin=None, - vmax=None, - ylabel=None, - yscaling=1, - file_type="pdf", - file_name=None, - ax=None, - line_label=None, - transpose=False, + values: np.ndarray, + plt_range: RangeType, + params: BareParams, + log: Union[str, int, float, bool] = False, + spacing: Union[int, float] = 1, + vmin: float = None, + vmax: float = None, + ylabel: str = None, + yscaling: float = 1, + file_type: str = "pdf", + file_name: str = None, + ax: plt.Axes = None, + line_label: str = None, + transpose: bool = False, **line_kwargs, ): """ @@ -656,7 +651,7 @@ def plot_results_1D( # make uniform if converting to wavelength if plt_range[2].type == "WL": if is_spectrum: - values = units.to_WL(values, params["frep"], units.m.inv(params["w"][ind])) + values = units.to_WL(values, params.frep, units.m.inv(params.w[ind])) # change the resolution if isinstance(spacing, float): @@ -683,9 +678,7 @@ def plot_results_1D( folder_name = "" if is_new_plot: - folder_name, file_name, fig, ax = plot_setup( - file_name=file_name, file_type=file_type, params=params - ) + out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) else: fig = ax.get_figure() if transpose: @@ -702,40 +695,40 @@ def plot_results_1D( ax.set_xlabel(plt_range[2].label) if is_new_plot: - fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200) - print(f"plot saved in {os.path.join(folder_name, file_name)}") + fig.savefig(out_path, bbox_inches="tight", dpi=200) + print(f"plot saved in {out_path}") return fig, ax, x_axis, values -def _prep_plot(values, plt_range, params): +def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams): is_spectrum = values.dtype == "complex" plt_range = (*plt_range[:2], units.get_unit(plt_range[2])) if plt_range[2].type in ["WL", "FREQ", "AFREQ"]: - x_axis = params["w"].copy() + x_axis = params.w.copy() else: - x_axis = params["t"].copy() + x_axis = params.t.copy() return is_spectrum, x_axis, plt_range def plot_avg( - values, - plt_range, - params, - log=False, - spacing=1, - vmin=None, - vmax=None, - ylabel=None, - yscaling=1, - renormalize=True, - add_coherence=False, - file_type="png", - file_name=None, - ax=None, - line_labels=None, - legend=True, - legend_kwargs={}, - transpose=False, + values: np.ndarray, + plt_range: RangeType, + params: BareParams, + log: Union[float, int, str, bool] = False, + spacing: Union[float, int] = 1, + vmin: float = None, + vmax: float = None, + ylabel: str = None, + yscaling: float = 1, + renormalize: bool = True, + add_coherence: bool = False, + file_type: str = "png", + file_name: str = None, + ax: plt.Axes = None, + line_labels: Tuple[str, str] = None, + legend: bool = True, + legend_kwargs: Dict[str, Any] = {}, + transpose: bool = False, ): """ plots 1D arrays and there mean and automatically saves the plots, as well as returns it @@ -817,8 +810,8 @@ def plot_avg( values *= yscaling mean_values = np.mean(values, axis=0) if plt_range[2].type == "WL" and renormalize: - values = np.apply_along_axis(units.to_WL, 1, values, params["frep"], x_axis) - mean_values = units.to_WL(mean_values, params["frep"], x_axis) + values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) + mean_values = units.to_WL(mean_values, params.frep, x_axis) # change the resolution if isinstance(spacing, float): @@ -852,12 +845,12 @@ def plot_avg( if is_new_plot: if add_coherence: mode = "coherence_T" if transpose else "coherence" - folder_name, file_name, fig, (top, bot) = plot_setup( - file_name=file_name, file_type=file_type, params=params, mode=mode + out_path, fig, (top, bot) = plot_setup( + out_path=Path(folder_name) / file_name, file_type=file_type, mode=mode ) else: - folder_name, file_name, fig, top = plot_setup( - file_name=file_name, file_type=file_type, params=params + out_path, fig, top = plot_setup( + out_path=Path(folder_name) / file_name, file_type=file_type ) bot = top else: @@ -923,8 +916,8 @@ def plot_avg( top.legend(custom_lines, line_labels, **legend_kwargs) if is_new_plot: - fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200) - print(f"plot saved in {os.path.join(folder_name, file_name)}") + fig.savefig(out_path, bbox_inches="tight", dpi=200) + print(f"plot saved in {out_path}") if top is bot: return fig, top @@ -984,46 +977,6 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6) return x_axis, np.squeeze(values) -def plot_dispersion_parameter(params, plt_range): - """ - Plots the dispersion parameter D as well as the beta2 parameter over the given range - """ - # TODO allow several curves, with legends, to be plotted - - x_axis = np.linspace(*plt_range[:2], 1000) - w_axis = plt_range[2](x_axis) - - if "disp_obj" in params: - D = params["disp_obj"].D_w(w_axis) - beta2 = params["disp_obj"].beta2_w(w_axis) - else: - print("no dispersion information given") - return - - fig, (ax_D, ax_beta2) = plt.subplots(1, 2) - - ax_D.plot(x_axis, 1e6 * D) - ax_D.plot( - x_axis, - 0 * x_axis, - ":", - c="k", - ) - ax_D.set_xlabel(plt_range[2].label) - ax_D.set_ylabel(r"Dispersion parameter $D$ ($\frac{\mathrm{ps}}{\mathrm{nm\ km}}$)") - - ax_beta2.plot(x_axis, 1e27 * beta2) - ax_beta2.plot( - x_axis, - 0 * x_axis, - ":", - c="k", - ) - ax_beta2.set_xlabel(plt_range[2].label) - ax_beta2.set_ylabel(r"$\beta_2$ parameter ($\frac{\mathrm{ps}^2}{\mathrm{km}}$)") - plt.show() - - def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)): """returns a new colormap based on "name" but that has a solid bacground (default=white)""" top = plt.get_cmap(name, 1024) diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index aa5a24a..94a7417 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -128,11 +128,11 @@ def main(): args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node) submit_path = Path( - "submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" + "submit " + final_config.name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" ) tmp_path = Path("submit tmp.sh") - job_name = f"supercontinuum {final_config['name']}" + job_name = f"supercontinuum {final_config.name}" submit_sh = template.format( job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args) ) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index a22f928..843373d 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,16 +1,14 @@ import os -from collections.abc import Mapping, Sequence -from glob import glob -from typing import Any, Dict, List, Tuple +from collections.abc import Sequence from pathlib import Path +from typing import Dict import numpy as np -from scgenerator.const import SPECN_FN - -from . import io, initialize, math -from .plotting import units +from . import initialize, io, math +from .const import SPECN_FN from .logger import get_logger +from .plotting import units class Spectrum(np.ndarray): @@ -43,7 +41,7 @@ class Pulse(Sequence): self.params = None try: - self.params = io.load_previous_parameters(self.path / "params.toml") + self.params = io.load_params(self.path / "params.toml") except FileNotFoundError: self.logger.info(f"parameters corresponding to {self.path} not found") diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils/__init__.py similarity index 69% rename from src/scgenerator/utils.py rename to src/scgenerator/utils/__init__.py index 7e6f636..50ba13d 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils/__init__.py @@ -4,24 +4,23 @@ scgenerator module but some function may be used in any python program """ - -import collections import itertools import multiprocessing import threading -import time from collections import abc from copy import deepcopy +from dataclasses import asdict, replace from io import StringIO from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar, Union +from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union import numpy as np from tqdm import tqdm -from . import env -from .const import PARAM_SEPARATOR, valid_variable -from .math import * +from .. import env +from ..const import PARAM_SEPARATOR +from ..math import * +from .parameter import BareConfig, BareParams T_ = TypeVar("T_") @@ -177,18 +176,11 @@ def progress_worker( pbars[0].update() -def count_variations(config: dict) -> Tuple[int, int]: +def count_variations(config: BareConfig) -> Tuple[int, int]: """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and variable_params_num is the number of distinct parameters that will vary.""" - sim_num = 1 - variable_params_num = 0 - - for section_name in valid_variable: - for array in config.get(section_name, {}).get("variable", {}).values(): - sim_num *= len(array) - variable_params_num += 1 - - sim_num *= config["simulation"].get("repeat", 1) + variable_params_num = len(config.variable) + sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat return sim_num, variable_params_num @@ -217,49 +209,45 @@ def format_value(value): return str(value) -def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: +def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: """given a config with "variable" parameters, iterates through every possible combination, yielding a a list of (parameter_name, value) tuples and a full config dictionary. Parameters ---------- - config : dict - initial config dictionary + config : BareConfig + initial config obj Yields ------- - Iterator[Tuple[List[Tuple[str, Any]], dict]] + Iterator[Tuple[List[Tuple[str, Any]], BareParams]] variable_list : a list of (name, value) tuple of parameter name and value that are variable. - dict : a config dictionary for one simulation + params : a BareParams obj for one simulation """ - indiv_config = deepcopy(config) - variable_dict = { - section_name: indiv_config.get(section_name, {}).pop("variable", {}) - for section_name in valid_variable - } - possible_keys = [] possible_ranges = [] - for section_name, section in variable_dict.items(): - for key in section: - arr = variable_dict[section_name][key] - possible_keys.append((section_name, key)) - possible_ranges.append(range(len(arr))) + for key, values in config.variable.items(): + possible_keys.append(key) + possible_ranges.append(range(len(values))) combinations = itertools.product(*possible_ranges) for combination in combinations: + indiv_config = {} variable_list = [] for i, key in enumerate(possible_keys): - parameter_value = variable_dict[key[0]][key[1]][combination[i]] - indiv_config[key[0]][key[1]] = parameter_value - variable_list.append((key[1], parameter_value)) - yield variable_list, indiv_config + parameter_value = config.variable[key][combination[i]] + indiv_config[key] = parameter_value + variable_list.append((key, parameter_value)) + param_dict = asdict(config) + param_dict.pop("variable") + param_dict.update(indiv_config) + yield variable_list, BareParams(**param_dict) -def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: +def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: """takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different parameter set and iterates through every single necessary simulation @@ -273,48 +261,19 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]] dict : a config dictionary for one simulation """ i = 0 # unique sim id - for variable_only, full_config in variable_iterator(config): - for j in range(config["simulation"]["repeat"]): + for variable_only, bare_params in variable_iterator(config): + for j in range(config.repeat): variable_ind = [("id", i)] + variable_only + [("num", j)] i += 1 - yield variable_ind, full_config + yield variable_ind, bare_params -def deep_update(d: Mapping, u: Mapping) -> dict: - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = deep_update(d.get(k, {}), v) - else: - d[k] = v - return d - - -def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str, Any]: +def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig: """makes sure all the parameters set in new are there, leaves untouched parameters in old""" if old is None: - return new - out = deepcopy(old) - for section_name, section in new.items(): - if isinstance(section, Mapping): - for param_name, value in section.items(): - if param_name == "variable" and isinstance(value, Mapping): - out[section_name].setdefault("variable", {}) - for p, v in value.items(): - # override previously unvariable param - if p in old[section_name]: - del out[section_name][p] - out[section_name]["variable"][p] = v - else: - # override previously variable param - if ( - "variable" in old[section_name] - and isinstance(old[section_name]["variable"], Mapping) - and param_name in old[section_name]["variable"] - ): - del out[section_name]["variable"][param_name] - if len(out[section_name]["variable"]) == 0: - del out[section_name["variable"]] - out[section_name][param_name] = value - else: - out[section_name] = section - return out + return BareConfig(**new) + variable = deepcopy(old.variable) + variable.update(new.pop("variable", {})) # add new variable + for k in new: + variable.pop(k) # remove old ones + return replace(old, variable=variable, **{k: None for k in variable}, **new) diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py new file mode 100644 index 0000000..b1cb124 --- /dev/null +++ b/src/scgenerator/utils/parameter.py @@ -0,0 +1,442 @@ +import datetime +from copy import copy +from dataclasses import asdict, dataclass +from functools import lru_cache +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +import numpy as np + +from ..const import __version__ + + +@lru_cache +def type_checker(*types): + def _type_checker_wrapper(validator, n=None): + if isinstance(validator, str) and n is not None: + _name = validator + validator = lambda *args: None + + def _type_checker_wrapped(name, n): + if not isinstance(n, types): + raise TypeError( + f"{name!r} value must be of type {' or '.join(format(t) for t in types)} " + f"instead of {type(n)}" + ) + validator(name, n) + + if n is None: + return _type_checker_wrapped + else: + _type_checker_wrapped(_name, n) + + return _type_checker_wrapper + + +@type_checker(str) +def string(name, n): + if len(n) == 0: + raise ValueError(f"{name!r} must not be empty") + + +def in_range_excl(_min, _max): + @type_checker(float, int) + def _in_range(name, n): + if n <= _min or n >= _max: + raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)") + + return _in_range + + +def in_range_incl(_min, _max): + @type_checker(float, int) + def _in_range(name, n): + if n < _min or n > _max: + raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)") + + return _in_range + + +def boolean(name, n): + if not n is True and not n is False: + raise ValueError(f"{name!r} must be True or False") + + +@lru_cache +def non_negative(*types): + @type_checker(*types) + def _non_negative(name, n): + if n < 0: + raise ValueError(f"{name!r} must be non negative") + + return _non_negative + + +@lru_cache +def positive(*types): + @type_checker(*types) + def _positive(name, n): + if n <= 0: + raise ValueError(f"{name!r} must be positive") + + return _positive + + +@type_checker(tuple, list) +def int_pair(name, t): + invalid = len(t) != 2 + for m in t: + if invalid or not isinstance(m, int): + raise ValueError(f"{name!r} must be a list or a tuple of 2 int") + + +def literal(*l): + l = set(l) + + @type_checker(str) + def _string(name, s): + if not s in l: + raise ValueError(f"{name!r} must be a str in {l}") + + return _string + + +def validator_list(validator): + """returns a new validator that applies validator to each el of an iterable""" + + @type_checker(list, tuple) + def _list_validator(name, l): + for i, el in enumerate(l): + validator(name + f"[{i}]", el) + + return _list_validator + + +def validator_or(*validators): + """combines many validators and raises an exception only if all of them raise an exception""" + + n = len(validators) + + def _or_validator(name, value): + errors = [] + for validator in validators: + try: + validator(name, value) + except (ValueError, TypeError) as e: + errors.append(e) + errors.sort(key=lambda el: isinstance(el, ValueError)) + if len(errors) == n: + raise errors[-1] + + return _or_validator + + +def validator_and(*validators): + def _and_validator(name, n): + for v in validators: + v(name, n) + + return _and_validator + + +@type_checker(list, tuple, np.ndarray) +def num_list(name, l): + for i, el in enumerate(l): + type_checker(int, float)(name + f"[{i}]", el) + + +def func_validator(name, n): + if not callable(n): + raise TypeError(f"{name!r} must be callable") + + +class Parameter: + def __init__(self, validator, converter=None, default=None): + """Single parameter + + Parameters + ---------- + tpe : type + type of the paramter + validators : Callable[[str, Any], None] + signature : validator(name, value) + must raise a ValueError when value doesn't fit the criteria checked by + validator. name is passed to validator to be included in the error message + converter : Callable, optional + converts a valid value (for example, str.lower), by default None + default : callable, optional + factory function for a default value (for example, list), by default None + """ + + self.validator = validator + self.converter = converter + self.default = default + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, instance, owner): + if not instance: + return self + return instance.__dict__[self.name] + + def __delete__(self, instance): + del instance.__dict__[self.name] + + def __set__(self, instance, value): + if isinstance(value, Parameter): + defaut = None if self.default is None else copy(self.default) + instance.__dict__[self.name] = defaut + else: + if value is not None: + self.validator(self.name, value) + if self.converter is not None: + value = self.converter(value) + instance.__dict__[self.name] = value + + +class VariableParameter: + def __init__(self, parameterBase): + self.pbase = parameterBase + + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, instance, owner): + if not instance: + return self + return instance.__dict__[self.name] + + def __delete__(self, instance): + del instance.__dict__[self.name] + + def __set__(self, instance, value: dict): + if isinstance(value, VariableParameter): + value = {} + else: + for k, v in value.items(): + if k not in valid_variable: + raise TypeError(f"{k!r} is not a valide variable parameter") + if len(v) == 0: + raise ValueError(f"variable parameter {k!r} must not be empty") + + p = getattr(self.pbase, k) + + for el in v: + p.validator(k, el) + instance.__dict__[self.name] = value + + +valid_variable = { + "beta", + "gamma", + "pitch", + "pitch_ratio", + "core_radius", + "capillary_num", + "capillary_outer_d", + "capillary_thickness", + "capillary_spacing", + "capillary_resonance_strengths", + "capillary_nested", + "he_mode", + "fit_parameters", + "input_transmission", + "n2", + "pressure", + "temperature", + "gas_name", + "plasma_density" "peak_power", + "mean_power", + "peak_power", + "energy", + "quantum_noise", + "shape", + "wavelength", + "intensity_noise", + "width", + "soliton_num", + "behaviors", + "raman_type", + "tolerated_error", + "step_size", + "ideal_gas", + "readjust_wavelength", +} + +hc_model_specific_parameters = dict( + marcatili=["core_radius", "he_mode"], + marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"], + hasan=[ + "core_radius", + "capillary_num", + "capillary_thickness", + "capillary_resonance_strengths", + "capillary_nested", + "capillary_spacing", + "capillary_outer_d", + ], +) +"""dependecy map only includes actual fiber parameters and exclude gas parameters""" + + +@dataclass +class BareParams: + """ + This class defines each valid parameter's name, type and valid value but doesn't provide + any method to act on those. For that, use initialize.Params + """ + + # root + name: str = Parameter(string) + prev_data_dir: str = Parameter(string) + + # # fiber + input_transmission: float = Parameter(in_range_incl(0, 1)) + gamma: float = Parameter(non_negative(float, int)) + n2: float = Parameter(non_negative(float, int)) + effective_mode_diameter: float = Parameter(positive(float, int)) + A_eff: float = Parameter(non_negative(float, int)) + pitch: float = Parameter(in_range_excl(0, 1e-3)) + pitch_ratio: float = Parameter(in_range_excl(0, 1)) + core_radius: float = Parameter(in_range_excl(0, 1e-3)) + he_mode: Tuple[int, int] = Parameter(int_pair) + fit_parameters: Tuple[int, int] = Parameter(int_pair) + beta: Iterable[float] = Parameter(num_list) + dispersion_file: str = Parameter(string) + model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom")) + length: float = Parameter(non_negative(float, int)) + capillary_num: int = Parameter(positive(int)) + capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3)) + capillary_thickness: float = Parameter(in_range_excl(0, 1e-3)) + capillary_spacing: float = Parameter(in_range_excl(0, 1e-3)) + capillary_resonance_strengths: Iterable[float] = Parameter(num_list) + capillary_nested: int = Parameter(non_negative(int)) + + # gas + gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower) + pressure: Union[float, Iterable[float]] = Parameter( + validator_or(non_negative(float, int), num_list) + ) + temperature: float = Parameter(positive(float, int)) + plasma_density: float = Parameter(non_negative(float, int)) + + # pulse + field_file: str = Parameter(string) + repetition_rate: float = Parameter(non_negative(float, int)) + peak_power: float = Parameter(positive(float, int)) + mean_power: float = Parameter(positive(float, int)) + energy: float = Parameter(positive(float, int)) + soliton_num: float = Parameter(positive(float, int)) + quantum_noise: bool = Parameter(boolean) + shape: str = Parameter(literal("gaussian", "sech")) + wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9)) + intensity_noise: float = Parameter(in_range_incl(0, 1)) + width: float = Parameter(in_range_excl(0, 1e-9)) + t0: float = Parameter(in_range_excl(0, 1e-9)) + + # simulation + behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss"))) + parallel: bool = Parameter(boolean) + raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) + ideal_gas: bool = Parameter(boolean) + repeat: int = Parameter(positive(int)) + t_num: int = Parameter(positive(int)) + z_num: int = Parameter(positive(int)) + time_window: float = Parameter(positive(float, int)) + dt: float = Parameter(in_range_excl(0, 5e-15)) + tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-5)) + step_size: float = Parameter(positive(float, int)) + lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9)) + upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9)) + frep: float = Parameter(positive(float, int)) + prev_sim_dir: str = Parameter(string) + readjust_wavelength: bool = Parameter(boolean) + recovery_last_stored: int = Parameter(non_negative(int)) + + # computed + field_0: np.ndarray = Parameter(type_checker(np.ndarray)) + spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) + w: np.ndarray = Parameter(type_checker(np.ndarray)) + w_c: np.ndarray = Parameter(type_checker(np.ndarray)) + t: np.ndarray = Parameter(type_checker(np.ndarray)) + L_D: float = Parameter(non_negative(float, int)) + L_NL: float = Parameter(non_negative(float, int)) + L_sol: float = Parameter(non_negative(float, int)) + dynamic_dispersion: bool = Parameter(boolean) + adapt_step_size: bool = Parameter(boolean) + error_ok: float = Parameter(positive(float)) + hr_w: np.ndarray = Parameter(type_checker(np.ndarray)) + z_targets: np.ndarray = Parameter(type_checker(np.ndarray)) + const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) + beta_func: Callable[[float], List[float]] = Parameter(func_validator) + gamma_func: Callable[[float], float] = Parameter(func_validator) + + def prepare_for_dump(self) -> Dict[str, Any]: + param = asdict(self) + param = BareParams.strip_params_dict(param) + param["datetime"] = datetime.datetime.now() + param["version"] = __version__ + return param + + @staticmethod + def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]: + """prepares a dictionary for serialization. Some keys may not be preserved + (dropped because they take a lot of space and can be exactly reconstructed) + + Parameters + ---------- + dico : dict + dictionary + """ + forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"] + types = (np.ndarray, float, int, str, list, tuple, dict) + out = {} + for key, value in dico.items(): + if key in forbiden_keys: + continue + if not isinstance(value, types): + continue + if isinstance(value, dict): + out[key] = BareParams.strip_params_dict(value) + elif isinstance(value, np.ndarray) and value.dtype == complex: + continue + else: + out[key] = value + + if "variable" in out and len(out["variable"]) == 0: + del out["variable"] + + return out + + +@dataclass +class BareConfig(BareParams): + variable: dict = VariableParameter(BareParams) + + +if __name__ == "__main__": + + numero = type_checker(int) + + @numero + def natural_number(name, n): + if n < 0: + raise ValueError(f"{name!r} must be positive") + + try: + numero("a", np.arange(45)) + except Exception as e: + print(e) + try: + natural_number("b", -1) + except Exception as e: + print(e) + try: + natural_number("c", 1.0) + except Exception as e: + print(e) + try: + natural_number("d", 1) + print("success !") + except Exception as e: + print(e) diff --git a/testing/test_new_params.py b/testing/test_new_params.py new file mode 100644 index 0000000..82e900d --- /dev/null +++ b/testing/test_new_params.py @@ -0,0 +1,37 @@ +from numba.core import config +from scgenerator.initialize import Config, Params, BareParams +from scgenerator.utils import variable_iterator, override_config +from scgenerator.io import load_toml +from pprint import pprint +from dataclasses import asdict + +dico = load_toml("testing/configs/ensure_consistency/good2.toml") +out = dict(variable=dict()) +for k, v in dico.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + if kk == "variable": + for kkk, vvv in vv.items(): + out["variable"][kkk] = vvv + else: + out[kk] = vv + +pprint(out) +p = Config(**out) +print(p) + +for l, c in variable_iterator(p): + print(l, c.width, c.intensity_noise) + print() + +config2 = override_config(dict(width=1.2e-13, variable=dict(peak_power=[1e5, 2e5])), p) +print( + f"{config2.variable=}", + f"{config2.intensity_noise=}", + f"{config2.width=}", + f"{config2.peak_power=}", +) + +par = BareParams() + +print(all(v is None for v in vars(par).values()))