diff --git a/README.md b/README.md index 01896a2..cb4bcf2 100644 --- a/README.md +++ b/README.md @@ -260,14 +260,8 @@ time_window: float total length of the temporal grid in s ### optional -behaviors: list of str {"spm", "raman", "ss"} - spm is self-phase modulation - raman is raman effect - ss is self-steepening - default : ["spm", "ss"] - -raman_type: str {"measured", "stolen", "agrawal"} - type of Raman effect. Default is "agrawal". +raman_type: str {"measured", "stolen", "agrawal"}, optional + type of Raman effect. Specifying this parameter has the effect of turning on Raman effect ideal_gas: bool if True, use the ideal gas law. Otherwise, use van der Waals equation. default : False @@ -285,7 +279,7 @@ step_size: float if given, sets a constant step size rather than adapting it. parallel: bool - whether to run simulations in parallel with the available ressources. default : false + whether to run simulations in parallel with the available resources. default : false repeat: int how many simulations to run per parameter set. default : 1 diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index e4adb8b..7fd4376 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa from . import math from .legacy import convert_sim_folder from .math import abs2, argclosest, span diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index b20e441..5a68d20 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -64,7 +64,6 @@ VALID_VARIABLE = { "width", "t0", "soliton_num", - "behaviors", "raman_type", "tolerated_error", "step_size", @@ -85,15 +84,23 @@ MANDATORY_PARAMETERS = [ "input_transmission", "z_targets", "length", - "beta2_coefficients", - "gamma_arr", - "behaviors", "adapt_step_size", "tolerated_error", - "dynamic_dispersion", "recovery_last_stored", "output_path", "repeat", "linear_operator", "nonlinear_operator", ] + +ROOT_PARAMETERS = [ + "repeat", + "num", + "dt", + "t_num", + "time_window", + "step_size", + "tolerated_error", + "width", + "shape", +] diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index d07571a..3834c8d 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -349,7 +349,7 @@ default_rules: list[Rule] = [ Rule("A_eff", fiber.A_eff_marcatili, priorities=-2), Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), Rule("A_eff_arr", fiber.load_custom_A_eff), - Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), + # Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), Rule( "V_eff", fiber.V_parameter_koshiba, @@ -364,6 +364,7 @@ default_rules: list[Rule] = [ ["l", "core_radius", "numerical_aperture", "interpolation_range"], ), Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1), + Rule("gamma", fiber.gamma_parameter), Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]), Rule("n2", materials.gas_n2), Rule("n2", lambda: 2.2e-20, priorities=-1), diff --git a/src/scgenerator/legacy.py b/src/scgenerator/legacy.py index 535f7ab..ac02a7a 100644 --- a/src/scgenerator/legacy.py +++ b/src/scgenerator/legacy.py @@ -1,6 +1,6 @@ -from genericpath import exists import os import sys +from collections import MutableMapping from pathlib import Path from typing import Any, Set @@ -9,8 +9,8 @@ import toml from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1 from .parameter import Configuration, Parameters -from .utils import save_parameters from .pbar import PBars +from .utils import save_parameters from .variationer import VariationDescriptor @@ -87,6 +87,45 @@ def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, pbar.update() +def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: + """translate parameters name and value from older versions of the program + + Parameters + ---------- + d : dict[str, Any] + [description] + + Returns + ------- + dict[str, Any] + [description] + """ + old_names = dict( + interp_degree="interpolation_degree", + beta="beta2_coefficients", + interp_range="interpolation_range", + ) + wl_limits_old = ["lower_wavelength_interp_limit", "upper_wavelength_interp_limit"] + defaults_to_add = dict(repeat=1) + new = {} + if len(set(wl_limits_old) & d.keys()) == 2: + new["interpolation_range"] = (d[wl_limits_old[0]], d[wl_limits_old[1]]) + for k, v in d.items(): + if k == "error_ok": + new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v + elif k == "behaviors": + beh = d["behaviors"] + if "raman" in beh: + new["raman_type"] = d["raman_type"] + new["spm"] = "spm" in beh + new["self_steepening"] = "ss" in beh + elif isinstance(v, MutableMapping): + new[k] = translate_parameters(v) + else: + new[old_names.get(k, k)] = v + return defaults_to_add | new + + def main(): convert_sim_folder(sys.argv[1]) diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index d65271c..0037f2a 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -1,7 +1,6 @@ from typing import Union import numpy as np -from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros from .cache import np_cache @@ -172,50 +171,6 @@ def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: return indft_matrix(t, f) @ a -def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"): - """Interpolates a 2D array with the help of griddata - Parameters - ---------- - values : 2D array of real values - x_axis : x-coordinates of values - y_axis : y-coordinates of values - method : method of interpolation to be passed to griddata - Returns - ---------- - array of shape n - """ - xx, yy = np.meshgrid(x_axis, y_axis) - xx = xx.flatten() - yy = yy.flatten() - - if not isinstance(n, tuple): - n = (n, n) - - # old_points = np.array([gridx.ravel(), gridy.ravel()]) - - newx, newy = np.meshgrid(np.linspace(*span(x_axis), n[0]), np.linspace(*span(y_axis), n[1])) - - print("interpolating") - out = griddata((xx, yy), values.flatten(), (newx, newy), method=method, fill_value=0) - print("interpolating done!") - return out.reshape(n[1], n[0]) - - -def make_uniform_1D(values, x_axis, n=1024, method="linear"): - """Interpolates a 2D array with the help of interp1d - Parameters - ---------- - values : 1D array of real values - x_axis : x-coordinates of values - method : method of interpolation to be passed to interp1d - Returns - ---------- - array of length n - """ - xx = np.linspace(*span(x_axis), len(x_axis)) - return interp1d(x_axis, values, kind=method)(xx) - - def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray: """find all the x values such that y(x)=0 with linear interpolation""" pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0] diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index 4e0d8da..44902b8 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -17,7 +17,7 @@ from .physics import fiber, pulse class SpectrumDescriptor: name: str - value: np.ndarray + value: np.ndarray = None def __set__(self, instance, value): instance.spec2 = math.abs2(value) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 1db9484..2a3e4fa 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -12,13 +12,13 @@ from typing import Any, Callable, Iterable, Iterator, TypeVar, Union import numpy as np -from . import env, utils -from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS +from . import env, legacy, utils +from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ +from .evaluator import Evaluator from .logger import get_logger +from .operators import LinearOperator, NonLinearOperator from .utils import fiber_folder, update_path_name from .variationer import VariationDescriptor, Variationer -from .evaluator import Evaluator -from .operators import NonLinearOperator, LinearOperator T = TypeVar("T") @@ -312,11 +312,9 @@ class Parameters: t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) # simulation - behaviors: tuple[str] = Parameter( - validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss") - ) - parallel: bool = Parameter(boolean, default=True) raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) + self_steepening: bool = Parameter(boolean, default=True) + spm: bool = Parameter(boolean, default=True) ideal_gas: bool = Parameter(boolean, default=False) repeat: int = Parameter(positive(int), default=1) t_num: int = Parameter(positive(int)) @@ -329,6 +327,7 @@ class Parameters: interpolation_degree: int = Parameter(positive(int), default=8) prev_sim_dir: str = Parameter(string) recovery_last_stored: int = Parameter(non_negative(int), default=0) + parallel: bool = Parameter(boolean, default=True) worker_num: int = Parameter(positive(int)) # computed @@ -459,9 +458,9 @@ class Configuration: obj with the output path of the simulation saved in its output_path attribute. """ - fiber_configs: list[dict[str, Any]] + fiber_configs: list[utils.SubConfig] vary_dicts: list[dict[str, list]] - master_config: dict[str, Any] + master_config_dict: dict[str, Any] fiber_paths: list[Path] num_sim: int num_fibers: int @@ -515,51 +514,47 @@ class Configuration: mkdir=False, prevent_overwrite=not self.overwrite, ) - self.master_config = self.fiber_configs[0].copy() + self.master_config_dict = self.fiber_configs[0].fixed | { + k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items() + } self.name = self.final_path.name self.z_num = 0 self.total_num_steps = 0 self.fiber_paths = [] self.all_configs = {} self.skip_callback = skip_callback - self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2)) - self.repeat = self.master_config.get("repeat", 1) + self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2)) + self.repeat = self.master_config_dict.get("repeat", 1) self.variationer = Variationer() fiber_names = set() self.num_fibers = 0 for i, config in enumerate(self.fiber_configs): - config.setdefault("name", Parameters.name.default) - self.z_num += config["z_num"] - fiber_names.add(config["name"]) - vary_dict_list: list[dict[str, list]] = config.pop("variable") - self.variationer.append(vary_dict_list) + config.fixed.setdefault("name", Parameters.name.default) + self.z_num += config.fixed["z_num"] + fiber_names.add(config.fixed["name"]) + self.variationer.append(config.variable) self.fiber_paths.append( utils.ensure_folder( - self.final_path / fiber_folder(i, self.name, config["name"]), + self.final_path / fiber_folder(i, self.name, config.fixed["name"]), mkdir=False, prevent_overwrite=not self.overwrite, ) ) - self.__validate_variable(vary_dict_list) + self.__validate_variable(config.variable) self.num_fibers += 1 Evaluator.evaluate_default( - self.__build_base_config() - | config - | {k: v[0] for vary_dict in vary_dict_list for k, v in vary_dict.items()}, + self.master_config_dict + | config.fixed + | {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()}, True, ) self.num_sim = self.variationer.var_num() self.total_num_steps = sum( - config["z_num"] * self.variationer.var_num(i) + config.fixed["z_num"] * self.variationer.var_num(i) for i, config in enumerate(self.fiber_configs) ) - self.parallel = self.master_config.get("parallel", Parameters.parallel.default) - - def __build_base_config(self): - cfg = self.master_config.copy() - vary: list[dict[str, list]] = cfg.pop("variable") - return cfg | {k: v[0] for vary_dict in vary for k, v in vary_dict.items()} + self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default) def __validate_variable(self, vary_dict_list: list[dict[str, list]]): for vary_dict in vary_dict_list: @@ -593,7 +588,7 @@ class Configuration: index = self.num_fibers + index sim_dict: dict[Path, Configuration.__SimConfig] = {} for descriptor in self.variationer.iterate(index): - cfg = descriptor.update_config(self.fiber_configs[index]) + cfg = descriptor.update_config(self.fiber_configs[index].fixed) if index > 0: cfg["prev_data_dir"] = str( self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True) @@ -611,7 +606,8 @@ class Configuration: task, config_dict = self.__decide(sim_config) if task == self.Action.RUN: sim_dict.pop(data_dir) - yield sim_config.descriptor, Parameters(**sim_config.config) + param_dict = legacy.translate_parameters(sim_config.config) + yield sim_config.descriptor, Parameters(**param_dict) if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break @@ -695,10 +691,7 @@ class Configuration: def save_parameters(self): os.makedirs(self.final_path, exist_ok=True) - cfgs = [ - cfg | dict(variable=self.variationer.all_dicts[i]) - for i, cfg in enumerate(self.fiber_configs) - ] + cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs] utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs)) @property diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index e7133ab..151edf1 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -14,14 +14,15 @@ from dataclasses import dataclass from functools import cache from pathlib import Path from string import printable as str_printable -from typing import Any, Callable, MutableMapping, Sequence, TypeVar +from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set import numpy as np import pkg_resources as pkg import toml -from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN +from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN, ROOT_PARAMETERS from .logger import get_logger +from .errors import DuplicateParameterError T_ = TypeVar("T_") @@ -74,39 +75,51 @@ class Paths: return os.path.join(cls.get("plots"), name) -class ConfigFileParser: - path: Path - repeat: int - master: ConfigFileParser.SubConfig - configs: list[ConfigFileParser.SubConfig] +@dataclass(init=False) +class SubConfig: + fixed: dict[str, Any] + variable: list[dict[str, list]] + fixed_keys: Set[str] + variable_keys: Set[str] - @dataclass - class SubConfig: - fixed: dict[str, Any] - variable: dict[str, list] + def __init__(self, dico: dict[str, Any]): + dico = dico.copy() + self.variable = conform_variable_entry(dico.pop("variable", [])) + self.fixed = dico + self.__update - def __init__(self, path: os.PathLike): - self.path = Path(path) - fiber_list: list[dict[str, Any]] - if self.path.name.lower().endswith(".toml"): - loaded_config = _open_config(self.path) - fiber_list = loaded_config.pop("Fiber") - else: - loaded_config = dict(name=self.path.name) - fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))] + def __update(self): + self.variable_keys = set() + self.fixed_keys = set() + for dico in self.variable: + for key in dico: + if key in self.variable_keys: + raise DuplicateParameterError(f"{key} is specified twice") + self.variable_keys.add(key) + for key in self.fixed: + if key in self.variable_keys: + raise DuplicateParameterError(f"{key} is specified twice") + self.fixed_keys.add(key) - if len(fiber_list) == 0: - raise ValueError(f"No fiber in config {self.path}") - configs = [] - for i, params in enumerate(fiber_list): - configs.append(loaded_config | params) - for root_vary, first_vary in itertools.product( - loaded_config["variable"], configs[0]["variable"] - ): - if len(common := root_vary.keys() & first_vary.keys()) != 0: - raise ValueError(f"These variable keys are specified twice : {common!r}") - configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"} - configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1))))) + def weak_update(self, other: SubConfig = None, **kwargs): + """similar to a dict update method put prioritizes existing values + + Parameters + ---------- + other : SubConfig + other obj + """ + if other is None: + other = SubConfig(kwargs) + self.fixed = other.fixed | self.fixed + self.variable = other.variable + self.variable + self.__update() + + +def conform_variable_entry(d) -> list[dict[str, list]]: + if isinstance(d, MutableMapping): + d = [{k: v} for k, v in d.items()] + return d def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: @@ -141,23 +154,11 @@ def _open_config(path: os.PathLike): path = conform_toml_path(path) dico = resolve_loadfile_arg(load_toml(path)) - dico = standardize_variable_dicts(dico) if "Fiber" not in dico: dico = dict(name=path.name, Fiber=[dico]) return dico -def standardize_variable_dicts(dico: dict[str, Any]): - if "Fiber" in dico: - dico["Fiber"] = [standardize_variable_dicts(fiber) for fiber in dico["Fiber"]] - if (var := dico.get("variable")) is not None: - if isinstance(var, MutableMapping): - dico["variable"] = [var] - else: - dico["variable"] = [{}] - return dico - - def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: if (f_list := dico.pop("INCLUDE", None)) is not None: if isinstance(f_list, str): @@ -196,7 +197,7 @@ def save_toml(path: os.PathLike, dico): return dico -def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]: +def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]: """loads a configuration file Parameters @@ -213,28 +214,26 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]] """ path = Path(path) - fiber_list: list[dict[str, Any]] if path.name.lower().endswith(".toml"): - loaded_config = _open_config(path) - fiber_list = loaded_config.pop("Fiber") + master_config_dict = _open_config(path) + fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")] + master_config = SubConfig(master_config_dict) else: - loaded_config = dict(name=path.name) - fiber_list = [_open_config(p) for p in sorted(path.glob("initial_config*.toml"))] + master_config = SubConfig(dict(name=path.name)) + fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))] if len(fiber_list) == 0: raise ValueError(f"No fiber in config {path}") - final_path = loaded_config.get("name") - configs = [] - for i, params in enumerate(fiber_list): - configs.append(loaded_config | params) - for root_vary, first_vary in itertools.product( - loaded_config["variable"], configs[0]["variable"] - ): - if len(common := root_vary.keys() & first_vary.keys()) != 0: - raise ValueError(f"These variable keys are specified twice : {common!r}") - configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"} - configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1))))) - return Path(final_path), configs + for fiber in fiber_list: + fiber.weak_update(master_config) + if "num" not in fiber_list[0].variable_keys: + repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1))) + fiber_list[0].weak_update(variable=dict(num=repeat_arg)) + for p_name in ROOT_PARAMETERS: + if any(p_name in conf.variable_keys for conf in fiber_list[1:]): + raise ValueError(f"{p_name} should only be specified in the root or first fiber") + configs = fiber_list + return Path(master_config.fixed["name"]), configs @cache @@ -340,27 +339,6 @@ def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray ) -def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: - old_names = dict( - interp_degree="interpolation_degree", - beta="beta2_coefficients", - interp_range="interpolation_range", - ) - deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"} - defaults_to_add = dict(repeat=1) - new = {} - for k, v in d.items(): - if k == "error_ok": - new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v - elif k in deleted_names: - continue - elif isinstance(v, MutableMapping): - new[k] = translate_parameters(v) - else: - new[old_names.get(k, k)] = v - return defaults_to_add | new - - def to_62(i: int) -> str: arr = [] if i == 0: @@ -445,7 +423,7 @@ def combine_simulations(path: Path, dest: Path = None): for l in paths.values(): try: l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) - except ValueError: + except TypeError: pass for pulses in paths.values(): new_path = dest / update_path_name(pulses[0].name) diff --git a/src/scgenerator/variationer.py b/src/scgenerator/variationer.py index b7cedd4..5576069 100644 --- a/src/scgenerator/variationer.py +++ b/src/scgenerator/variationer.py @@ -79,7 +79,7 @@ class Variationer: len_to_test = len(values[0]) if not all(len(v) == len_to_test for v in values[1:]): raise VariationSpecsError( - f"variable items should all have the same number of parameters" + "variable items should all have the same number of parameters" ) num_vars.append(len_to_test) if len(num_vars) == 0: