diff --git a/func_rewrite.py b/func_rewrite.py deleted file mode 100644 index 8ccc359..0000000 --- a/func_rewrite.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Callable -import inspect -import re - - -def get_arg_names(func: Callable) -> list[str]: - spec = inspect.getfullargspec(func) - args = spec.args - if spec.defaults is not None and len(spec.defaults) > 0: - args = args[: -len(spec.defaults)] - return args - - -def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - return scope[tmp_name] - - -def lol(a, b=None, c=None): - print(f"{a=}, {b=}, {c=}") - - -def main(): - lol1 = func_rewrite(lol, ["c"]) - print(inspect.getfullargspec(lol1)) - lol2 = func_rewrite(lol, ["b"]) - print(inspect.getfullargspec(lol2)) - - -if __name__ == "__main__": - main() diff --git a/play.py b/play.py index 0fb33a2..446c52a 100644 --- a/play.py +++ b/play.py @@ -1,6 +1,18 @@ -from tqdm import tqdm -import time -import random +from scgenerator import Parameters +import os -for i in tqdm(range(100), smoothing=0): - time.sleep(random.random()) + +def main(): + cwd = os.getcwd() + try: + os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations") + + pa = Parameters.load("PM1550+PM2000D/PM1550_PM2000D raman_test/initial_config_0.toml") + + print(pa) + finally: + os.chdir(cwd) + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index f53314d..3d3f2e6 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -2,15 +2,15 @@ from . import initialize, io, math, utils from .initialize import ( Config, ContinuationParamSequence, - Params, + Parameters, ParamSequence, RecoveryParamSequence, ) -from .io import Paths, load_params, load_toml +from .io import Paths, load_toml from .math import abs2, argclosest, span from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, new_simulation, resume_simulations from .physics.units import PlotRange from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .spectra import Pulse, Spectrum -from .utils.parameter import BareConfig, BareParams +from .utils.parameter import BareConfig, Parameters diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 3db90c7..4f6f7b2 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -15,25 +15,12 @@ from .utils import override_config, required_simulations from .utils.evaluator import Evaluator from .utils.parameter import ( BareConfig, - BareParams, + Parameters, hc_model_specific_parameters, mandatory_parameters, ) -@dataclass -class Params(BareParams): - @classmethod - def from_bare(cls, bare: BareParams): - param_dict = {k: v for k, v in asdict(bare).items() if v is not None} - evaluator = Evaluator.default() - evaluator.set(**param_dict) - for p_name in mandatory_parameters: - evaluator.compute(p_name) - new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict} - return cls(**new_param_dict) - - @dataclass class Config(BareConfig): @classmethod @@ -222,11 +209,11 @@ class ParamSequence: self.update_num_sim() - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" - for variable_list, bare_params in required_simulations(self.config): - yield variable_list, Params.from_bare(bare_params) + for variable_list, params in required_simulations(self.config): + yield variable_list, params def __len__(self): return self.num_sim @@ -259,19 +246,19 @@ class ContinuationParamSequence(ParamSequence): new config """ self.prev_sim_dir = Path(prev_sim_dir) - self.bare_configs = io.load_config_sequence(new_config.previous_config_file) + self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file) self.bare_configs.append(new_config) self.bare_configs[0] = Config.from_bare(self.bare_configs[0]) final_config = utils.final_config_from_sequence(*self.bare_configs) super().__init__(final_config) - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" - for variable_list, bare_params in required_simulations(*self.bare_configs): + for variable_list, params in required_simulations(*self.bare_configs): prev_data_dir = self.find_prev_data_dirs(variable_list)[0] - bare_params.prev_data_dir = str(prev_data_dir.resolve()) - yield variable_list, Params.from_bare(bare_params) + params.prev_data_dir = str(prev_data_dir.resolve()) + yield variable_list, params def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]: """finds the previous simulation data that this new config should start from @@ -324,7 +311,7 @@ class RecoveryParamSequence(ParamSequence): self.prev_sim_dir = None if self.config.prev_sim_dir is not None: self.prev_sim_dir = Path(self.config.prev_sim_dir) - init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") + init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml") self.prev_variable_lists = [ ( set(variable_list[1:]), @@ -357,17 +344,17 @@ class RecoveryParamSequence(ParamSequence): self.num_steps += not_started * self.config.z_num self.single_sim = self.num_sim == 1 - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: - for variable_list, bare_params in required_simulations(self.config): + def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]: + for variable_list, params in required_simulations(self.config): data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list) if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0: if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None: - bare_params.prev_data_dir = str(prev_data_dir) - yield variable_list, Params.from_bare(bare_params) + params.prev_data_dir = str(prev_data_dir) + yield variable_list, params elif io.num_left_to_propagate(data_dir, self.config.z_num) != 0: - yield variable_list, recover_params(bare_params, data_dir) + yield variable_list, params + "Needs to rethink recovery procedure" else: continue @@ -417,7 +404,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]: """ previous = None - configs = io.load_config_sequence(*configs) + configs = BareConfig.load_sequence(*configs) for config in configs: # if (p := Path(config)).is_dir(): # config = p / "initial_config.toml" @@ -487,22 +474,20 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]: # raise TypeError("not enough parameter to determine time vector") -def recover_params(params: BareParams, data_folder: Path) -> Params: - params = Params.from_bare(params) - try: - prev = io.load_params(data_folder / "params.toml") - build_sim_grid_in_place(prev) - except FileNotFoundError: - prev = BareParams() - for k, v in filter(lambda el: el[1] is not None, vars(prev).items()): - if getattr(params, k) is None: - setattr(params, k, v) - num, last_spectrum = io.load_last_spectrum(data_folder) - params.spec_0 = last_spectrum - params.field_0 = np.fft.ifft(last_spectrum) - params.recovery_last_stored = num - params.cons_qty = np.load(data_folder / "cons_qty.npy") - return params +# def recover_params(params: Parameters, data_folder: Path) -> Parameters: +# try: +# prev = Parameters.load(data_folder / "params.toml") +# except FileNotFoundError: +# prev = Parameters() +# for k, v in filter(lambda el: el[1] is not None, vars(prev).items()): +# if getattr(params, k) is None: +# setattr(params, k, v) +# num, last_spectrum = io.load_last_spectrum(data_folder) +# params.spec_0 = last_spectrum +# params.field_0 = np.fft.ifft(last_spectrum) +# params.recovery_last_stored = num +# params.cons_qty = np.load(data_folder / "cons_qty.npy") +# return params # def build_sim_grid( diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index ea57208..6e7c173 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -1,20 +1,22 @@ +from __future__ import annotations + import itertools import os import shutil from pathlib import Path -from typing import Any, Callable, Dict, Generator, List, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Tuple import numpy as np import pkg_resources as pkg import toml +from tqdm.std import Bar -from . import env, utils +from scgenerator.utils.parameter import BareConfig + +from . import env from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ from .env import TMP_FOLDER_KEY_BASE - -from .errors import IncompleteDataFolderError from .logger import get_logger -from .utils.parameter import BareConfig, BareParams, translate PathTree = List[Tuple[Path, ...]] @@ -98,7 +100,9 @@ def save_toml(path: os.PathLike, dico): return dico -def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path: +def save_parameters( + params: dict[str, Any], destination_dir: Path, file_name: str = "params.toml" +) -> Path: """saves a parameter dictionary. Note that is does remove some entries, particularly those that take a lot of space ("t", "w", ...) @@ -114,87 +118,17 @@ def save_parameters(params: BareParams, destination_dir: Path, file_name="params Path path to newly created the paramter file """ - param = params.prepare_for_dump() file_path = destination_dir / file_name file_path.parent.mkdir(exist_ok=True) # save toml of the simulation with open(file_path, "w") as file: - toml.dump(param, file, encoder=toml.TomlNumpyEncoder()) + toml.dump(params, file, encoder=toml.TomlNumpyEncoder()) return file_path -def load_params(path: os.PathLike) -> BareParams: - """loads a parameters toml files and converts data to appropriate type - It is advised to run initialize.build_sim_grid to recover some parameters that are not saved. - - Parameters - ---------- - path : PathLike - path to the toml - - Returns - ---------- - BareParams - params obj - """ - params = load_toml(path) - try: - return BareParams(**params) - except TypeError: - return BareParams(**dict(translate(p, v) for p, v in params.items())) - - -def load_config(path: os.PathLike) -> BareConfig: - """loads a parameters toml files and converts data to appropriate type - It is advised to run initialize.build_sim_grid to recover some parameters that are not saved. - - Parameters - ---------- - path : PathLike - path to the toml - - Returns - ---------- - BareParams - config obj - """ - config = load_toml(path) - return BareConfig(**config) - - -def load_config_sequence(*config_paths: os.PathLike) -> list[BareConfig]: - """Loads a sequence of - - Parameters - ---------- - config_paths : os.PathLike - either one path (the last config containing previous_config_file parameter) - or a list of config path in the order they have to be simulated - - Returns - ------- - list[BareConfig] - all loaded configs - """ - if config_paths[0] is None: - return [] - all_configs = [load_config(config_paths[0])] - if len(config_paths) == 1: - while True: - if all_configs[0].previous_config_file is not None: - all_configs.insert(0, load_config(all_configs[0].previous_config_file)) - else: - break - else: - for i, path in enumerate(config_paths[1:]): - all_configs.append(load_config(path)) - all_configs[i + 1].previous_config_file = config_paths[i] - return all_configs - - def load_material_dico(name: str) -> dict[str, Any]: """loads a material dictionary Parameters @@ -228,74 +162,6 @@ def get_data_dirs(sim_dir: Path) -> List[Path]: return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()] -def check_data_integrity(sub_folders: List[Path], init_z_num: int): - """checks the integrity and completeness of a simulation data folder - - Parameters - ---------- - path : str - path to the data folder - init_z_num : int - z_num as specified by the initial configuration file - - Raises - ------ - IncompleteDataFolderError - raised if not all spectra are present in any folder - """ - - for sub_folder in utils.PBars(sub_folders, "Checking integrity"): - if num_left_to_propagate(sub_folder, init_z_num) != 0: - raise IncompleteDataFolderError( - f"not enough spectra of the specified {init_z_num} found in {sub_folder}" - ) - - -def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int: - """checks if a propagation has completed - - Parameters - ---------- - sub_folder : Path - path to the sub folder containing the spectra - init_z_num : int - number of z position to store as specified in the master config file - - Returns - ------- - bool - True if the propagation has completed - - Raises - ------ - IncompleteDataFolderError - raised if init_z_num doesn't match that specified in the individual parameter file - """ - z_num = load_toml(sub_folder / "params.toml")["z_num"] - num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing - - if z_num != init_z_num: - raise IncompleteDataFolderError( - f"initial config specifies {init_z_num} spectra per" - + f" but the parameter file in {sub_folder} specifies {z_num}" - ) - - return z_num - num_spectra - - -def find_last_spectrum_num(data_dir: Path): - for num in itertools.count(1): - p_to_test = data_dir / SPEC1_FN.format(num) - if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0: - return num - 1 - - -def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]: - """return the last spectrum stored in path as well as its id""" - num = find_last_spectrum_num(data_dir) - return num, np.load(data_dir / SPEC1_FN.format(num)) - - def update_appended_params(source: Path, destination: Path, z: Sequence): z_num = len(z) params = load_toml(source) diff --git a/src/scgenerator/physics/__init__.py b/src/scgenerator/physics/__init__.py index 6f5ed25..4e5154e 100644 --- a/src/scgenerator/physics/__init__.py +++ b/src/scgenerator/physics/__init__.py @@ -8,8 +8,7 @@ from typing import TypeVar import numpy as np from scipy.optimize import minimize_scalar -from ..io import load_material_dico -from .. import math +from .. import math, io from . import fiber, materials, units, pulse T = TypeVar("T") @@ -54,7 +53,7 @@ def material_dispersion( order = np.argsort(w) - material_dico = load_material_dico(material) + material_dico = io.load_material_dico(material) if ideal: n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1 else: diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 0cf4486..844e56d 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -2,10 +2,8 @@ from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, T import numpy as np from numpy.ma import core -import toml from numpy.fft import fft, ifft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly -from numpy.polynomial.polynomial import Polynomial from scipy.interpolate import interp1d from ..logger import get_logger diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 33fdcf6..c6b8409 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -318,7 +318,8 @@ def L_sol(L_D): def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: - return io.load_last_spectrum(Path(prev_data_dir))[1] + num = utils.find_last_spectrum_num(data_dir) + return np.load(data_dir / SPEC1_FN.format(num)) def load_field_file( diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 211943f..139dc37 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -8,6 +8,7 @@ from typing import Dict, List, Tuple, Type, Union import numpy as np from .. import env, initialize, io, utils +from ..utils import Parameters, BareConfig from ..const import PARAM_SEPARATOR from ..errors import IncompleteDataFolderError from ..logger import get_logger @@ -23,29 +24,29 @@ except ModuleNotFoundError: class RK4IP: def __init__( self, - params: initialize.Params, + params: Parameters, save_data=False, job_identifier="", task_id=0, ): """A 1D solver using 4th order Runge-Kutta in the interaction picture + Parameters + ---------- Parameters - ---------- - params : Params - parameters of the simulation - save_data : bool, optional - save calculated spectra to disk, by default False - job_identifier : str, optional - string identifying the parameter set, by default "" - task_id : int, optional - unique identifier of the session, by default 0 + parameters of the simulation + save_data : bool, optional + save calculated spectra to disk, by default False + job_identifier : str, optional + string identifying the parameter set, by default "" + task_id : int, optional + unique identifier of the session, by default 0 """ self.set(params, save_data, job_identifier, task_id) def set( self, - params: initialize.Params, + params: Parameters, save_data=False, job_identifier="", task_id=0, @@ -306,7 +307,7 @@ class RK4IP: class SequentialRK4IP(RK4IP): def __init__( self, - params: initialize.Params, + params: Parameters, pbars: utils.PBars, save_data=False, job_identifier="", @@ -327,7 +328,7 @@ class SequentialRK4IP(RK4IP): class MutliProcRK4IP(RK4IP): def __init__( self, - params: initialize.Params, + params: Parameters, p_queue: multiprocessing.Queue, worker_id: int, save_data=False, @@ -353,7 +354,7 @@ class RayRK4IP(RK4IP): def set( self, - params: initialize.Params, + params: Parameters, p_actor, worker_id: int, save_data=False, @@ -445,7 +446,9 @@ class Simulations: self.sim_dir = io.get_sim_dir( self.id, path_if_new=Path(self.name + PARAM_SEPARATOR + "tmp") ) - io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml") + io.save_parameters( + self.param_seq.config.prepare_for_dump(), self.sim_dir, file_name="initial_config.toml" + ) self.sim_jobs_per_node = 1 @@ -467,19 +470,19 @@ class Simulations: def _run_available(self): for variable, params in self.param_seq: v_list_str = utils.format_variable_list(variable) - io.save_parameters(params, self.sim_dir / v_list_str) + io.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str) self.new_sim(v_list_str, params) self.finish() - def new_sim(self, v_list_str: str, params: initialize.Params): + def new_sim(self, v_list_str: str, params: Parameters): """responsible to launch a new simulation Parameters ---------- v_list_str : str string that uniquely identifies the simulation as returned by utils.format_variable_list - params : initialize.Params + params : Parameters computed parameters """ raise NotImplementedError() @@ -507,7 +510,7 @@ class SequencialSimulations(Simulations, priority=0): super().__init__(param_seq, task_id=task_id) self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1) - def new_sim(self, v_list_str: str, params: initialize.Params): + def new_sim(self, v_list_str: str, params: Parameters): self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}") SequentialRK4IP( params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id @@ -556,7 +559,7 @@ class MultiProcSimulations(Simulations, priority=1): worker.start() super().run() - def new_sim(self, v_list_str: str, params: initialize.Params): + def new_sim(self, v_list_str: str, params: Parameters): self.queue.put((v_list_str, params), block=True, timeout=None) def finish(self): @@ -579,7 +582,7 @@ class MultiProcSimulations(Simulations, priority=1): p_queue: multiprocessing.Queue, ): while True: - raw_data: Tuple[List[tuple], initialize.Params] = queue.get() + raw_data: Tuple[List[tuple], Parameters] = queue.get() if raw_data == 0: queue.task_done() return @@ -635,7 +638,7 @@ class RaySimulations(Simulations, priority=2): .remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps) ) - def new_sim(self, v_list_str: str, params: initialize.Params): + def new_sim(self, v_list_str: str, params: Parameters): while self.num_submitted >= self.sim_jobs_total: self.collect_1_job() @@ -685,7 +688,7 @@ def run_simulation_sequence( method=None, prev_sim_dir: os.PathLike = None, ): - configs = io.load_config_sequence(*config_files) + configs = BareConfig.load_sequence(*config_files) prev = prev_sim_dir for config in configs: diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index f778297..df1f73e 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -5,7 +5,7 @@ from typing import Callable, TypeVar, Union from dataclasses import dataclass -from ..utils.parameter import Parameter, boolean, type_checker +from ..utils import parameter import numpy as np from numpy import pi @@ -183,10 +183,10 @@ def is_unit(name, value): @dataclass class PlotRange: - left: float = Parameter(type_checker(int, float)) - right: float = Parameter(type_checker(int, float)) - unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit) - conserved_quantity: bool = Parameter(boolean, default=True) + left: float = parameter.Parameter(parameter.type_checker(int, float)) + right: float = parameter.Parameter(parameter.type_checker(int, float)) + unit: Callable[[float], float] = parameter.Parameter(is_unit, converter=get_unit) + conserved_quantity: bool = parameter.Parameter(parameter.boolean, default=True) def __str__(self): return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 2d96efc..52cd429 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -21,7 +21,7 @@ from . import io, math from .defaults import default_plotting as defaults from .math import abs2, make_uniform_1D, span from .physics import pulse, units -from .utils.parameter import BareConfig, BareParams +from .utils.parameter import BareConfig, Parameters RangeType = Tuple[float, float, Union[str, Callable]] NO_LIM = object() @@ -263,7 +263,7 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0 def propagation_plot( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, ax: plt.Axes, log: Union[int, float, bool, str] = "1D", vmin: float = None, @@ -281,7 +281,7 @@ def propagation_plot( raw values, either complex fields or complex spectra plt_range : Union[units.PlotRange, RangeType] time, wavelength or frequency range - params : BareParams + params : Parameters parameters of the simulation log : Union[int, float, bool, str], optional what kind of log to apply, see apply_log for details. by default "1D" @@ -418,7 +418,7 @@ def plot_2D( def transform_2D_propagation( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, log: Union[int, float, bool, str] = "1D", skip: int = 1, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -430,7 +430,7 @@ def transform_2D_propagation( values to transform plt_range : Union[units.PlotRange, RangeType] range - params : BareParams + params : Parameters parameters of the simulation log : Union[int, float, bool, str], optional see apply_log, by default "1D" @@ -469,7 +469,7 @@ def transform_2D_propagation( def mean_values_plot( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, ax: plt.Axes, log: Union[float, int, str, bool] = False, vmin: float = None, @@ -511,7 +511,7 @@ def mean_values_plot( def transform_mean_values( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, log: Union[bool, int, float] = False, spacing: Union[int, float] = 1, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -523,7 +523,7 @@ def transform_mean_values( values to transform plt_range : Union[units.PlotRange, RangeType] x axis specifications - params : BareParams + params : Parameters parameters of the simulation log : Union[bool, int, float], optional see transform_1D_values for details, by default False @@ -637,7 +637,7 @@ def plot_mean( def single_position_plot( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, ax: plt.Axes, log: Union[str, int, float, bool] = False, vmin: float = None, @@ -712,7 +712,7 @@ def plot_1D( def transform_1D_values( values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], - params: BareParams, + params: Parameters, log: Union[int, float, bool] = False, spacing: Union[int, float] = 1, ) -> tuple[np.ndarray, np.ndarray]: @@ -724,7 +724,7 @@ def transform_1D_values( values to plot, may be complex plt_range : Union[units.PlotRange, RangeType] plot range specification, either (min, max, unit) or a PlotRange obj - params : BareParams + params : Parameters parameters of the simulations log : Union[int, float, bool], optional if True, will convert to dB relative to max. If a float or int, whill @@ -767,7 +767,7 @@ def plot_spectrogram( values: np.ndarray, x_range: RangeType, y_range: RangeType, - params: BareParams, + params: Parameters, t_res: int = None, gate_width: float = None, log: bool = "2D", @@ -790,7 +790,7 @@ def plot_spectrogram( units : function to convert from the desired units to rad/s or to time. common functions are already defined in scgenerator.physics.units look there for more details - params : BareParams + params : Parameters parameters of the simulations log : bool, optional whether to compute the logarithm of the spectrogram @@ -954,7 +954,7 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr def prep_plot_axis( - values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: BareParams + values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: Parameters ) -> tuple[bool, np.ndarray, units.PlotRange]: is_spectrum = values.dtype == "complex" if not isinstance(plt_range, units.PlotRange): diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 751ca20..a935a73 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import numpy as np from tqdm import tqdm -from ..utils.parameter import BareParams +from ..utils.parameter import Parameters from ..const import PARAM_SEPARATOR from ..initialize import ParamSequence @@ -19,7 +19,7 @@ from ..plotting import plot_setup from .. import env, math -def fingerprint(params: BareParams): +def fingerprint(params: Parameters): h1 = hash(params.field_0.tobytes()) h2 = tuple(params.beta2_coefficients) return h1, h2 @@ -160,7 +160,7 @@ def plot_1_dispersion( right: plt.Axes, style: dict[str, Any], lbl: list[str], - params: BareParams, + params: Parameters, loss: plt.Axes = None, ): beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients) @@ -253,7 +253,7 @@ def finish_plot(fig, legend_axes, all_labels, params): plt.show() -def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams]]: +def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) pseq = ParamSequence(config_path) for style, (variables, params) in zip(cc, pseq): diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 3345624..8a47a8a 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -10,7 +10,8 @@ from typing import Tuple import numpy as np from ..initialize import validate_config_sequence -from ..io import Paths, load_config +from ..io import Paths +from ..utils.parameter import BareConfig def primes(n): @@ -127,7 +128,7 @@ def main(): ) if args.command == "merge": - final_name = load_config(Path(args.configs[0]) / "initial_config.toml").name + final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name sim_num = "many" args.nodes = 1 args.cpus_per_node = 1 diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index ed1897a..8d95d5c 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -11,13 +11,13 @@ from .const import SPECN_FN from .logger import get_logger from .physics import pulse, units from .plotting import mean_values_plot, propagation_plot, single_position_plot -from .utils.parameter import BareParams +from .utils.parameter import Parameters class Spectrum(np.ndarray): - params: BareParams + params: Parameters - def __new__(cls, input_array, params: BareParams): + def __new__(cls, input_array, params: Parameters): # Input array is an already formed ndarray instance # We first cast to be our class type obj = np.asarray(input_array).view(cls) @@ -144,7 +144,7 @@ class Pulse(Sequence): if not self.path.is_dir(): raise FileNotFoundError(f"Folder {self.path} does not exist") - self.params = io.load_params(self.path / "params.toml") + self.params = Parameters.load(self.path / "params.toml") initialize.build_sim_grid_in_place(self.params) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 1cb6343..4700857 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -16,15 +16,17 @@ from dataclasses import asdict, replace from io import StringIO from pathlib import Path from typing import Any, Iterable, Iterator, TypeVar, Union +from ..errors import IncompleteDataFolderError import numpy as np from numpy.lib.arraysetops import isin from tqdm import tqdm from .. import env -from ..const import PARAM_SEPARATOR +from ..const import PARAM_SEPARATOR, SPEC1_FN from ..math import * -from .parameter import BareConfig, BareParams +from .. import io +from .parameter import BareConfig, Parameters T_ = TypeVar("T_") @@ -212,7 +214,7 @@ def format_value(name: str, value) -> str: return str(value) elif isinstance(value, (float, int)): try: - return getattr(BareParams, name).display(value) + return getattr(Parameters, name).display(value) except AttributeError: return format(value, ".9g") elif isinstance(value, (list, tuple, np.ndarray)): @@ -226,7 +228,7 @@ def format_value(name: str, value) -> str: def pretty_format_value(name: str, value) -> str: try: - return getattr(BareParams, name).display(value) + return getattr(Parameters, name).display(value) except AttributeError: return name + PARAM_SEPARATOR + str(value) @@ -248,12 +250,74 @@ def pretty_format_from_sim_name(name: str) -> str: out = [] for key, value in zip(s[::2], s[1::2]): try: - out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))] + out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))] except (AttributeError, ValueError): out.append(key + PARAM_SEPARATOR + value) return PARAM_SEPARATOR.join(out) +def check_data_integrity(sub_folders: list[Path], init_z_num: int): + """checks the integrity and completeness of a simulation data folder + + Parameters + ---------- + path : str + path to the data folder + init_z_num : int + z_num as specified by the initial configuration file + + Raises + ------ + IncompleteDataFolderError + raised if not all spectra are present in any folder + """ + + for sub_folder in PBars(sub_folders, "Checking integrity"): + if num_left_to_propagate(sub_folder, init_z_num) != 0: + raise IncompleteDataFolderError( + f"not enough spectra of the specified {init_z_num} found in {sub_folder}" + ) + + +def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int: + """checks if a propagation has completed + + Parameters + ---------- + sub_folder : Path + path to the sub folder containing the spectra + init_z_num : int + number of z position to store as specified in the master config file + + Returns + ------- + bool + True if the propagation has completed + + Raises + ------ + IncompleteDataFolderError + raised if init_z_num doesn't match that specified in the individual parameter file + """ + z_num = io.load_toml(sub_folder / "params.toml")["z_num"] + num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing + + if z_num != init_z_num: + raise IncompleteDataFolderError( + f"initial config specifies {init_z_num} spectra per" + + f" but the parameter file in {sub_folder} specifies {z_num}" + ) + + return z_num - num_spectra + + +def find_last_spectrum_num(data_dir: Path): + for num in itertools.count(1): + p_to_test = data_dir / SPEC1_FN.format(num) + if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0: + return num - 1 + + def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]: """given a config with "variable" parameters, iterates through every possible combination, yielding a a list of (parameter_name, value) tuples and a full config dictionary. @@ -268,7 +332,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any] Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]] variable_list : a list of (name, value) tuple of parameter name and value that are variable. - params : a dict[str, Any] to be fed to Params + params : a dict[str, Any] to be fed to Parameters """ possible_keys = [] possible_ranges = [] @@ -294,7 +358,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any] def required_simulations( *configs: BareConfig, -) -> Iterator[tuple[list[tuple[str, Any]], BareParams]]: +) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]: """takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different parameter set and iterates through every single necessary simulation @@ -317,7 +381,7 @@ def required_simulations( for j in range(configs[0].repeat or 1): variable_ind = [("id", i)] + variable_only + [("num", j)] i += 1 - yield variable_ind, BareParams(**params_dict) + yield variable_ind, Parameters(**params_dict) def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]: diff --git a/src/scgenerator/utils/evaluator.py b/src/scgenerator/utils/evaluator.py deleted file mode 100644 index be5ac3d..0000000 --- a/src/scgenerator/utils/evaluator.py +++ /dev/null @@ -1,392 +0,0 @@ -import itertools -import re -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Callable, Optional, TypeVar, Union - -import numpy as np - -from .. import math -from ..logger import get_logger -from ..physics import fiber, materials, pulse, units - -T = TypeVar("T") -import inspect - - -class EvaluatorError(Exception): - pass - - -class Rule: - def __init__( - self, - target: Union[str, list[Optional[str]]], - func: Callable, - args: list[str] = None, - priorities: Union[int, list[int]] = None, - conditions: dict[str, str] = None, - ): - targets = list(target) if isinstance(target, (list, tuple)) else [target] - self.func = func - if priorities is None: - priorities = [1] * len(targets) - elif isinstance(priorities, (int, float, np.integer, np.floating)): - priorities = [priorities] - self.targets = dict(zip(targets, priorities)) - if args is None: - args = get_arg_names(func) - self.args = args - self.conditions = conditions or {} - - def __repr__(self) -> str: - return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" - - @classmethod - def deduce( - cls, - target: Union[str, list[Optional[str]]], - func: Callable, - kwarg_names: list[str], - n_var: int, - args_const: list[str] = None, - ) -> list["Rule"]: - """given a function that doesn't need all its keyword arguemtn specified, will - return a list of Rule obj, one for each combination of n_var specified kwargs - - Parameters - ---------- - target : str | list[str | None] - name of the variable(s) that func returns - func : Callable - function to work with - kwarg_names : list[str] - list of all kwargs of the function to be used - n_var : int - how many shoulf be used per rule - arg_const : list[str], optional - override the name of the positional arguments - - Returns - ------- - list[Rule] - list of all possible rules - - Example - ------- - >> def lol(a, b=None, c=None): - pass - >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) - [ - Rule(targets={'d': 1}, func=, args=['a', 'b']), - Rule(targets={'d': 1}, func=, args=['a', 'c']) - ] - """ - rules: list[cls] = [] - for var_possibility in itertools.combinations(kwarg_names, n_var): - - new_func = func_rewrite(func, list(var_possibility), args_const) - - rules.append(cls(target, new_func)) - return rules - - -@dataclass -class EvalStat: - priority: float = np.inf - - -class Evaluator: - @classmethod - def default(cls) -> "Evaluator": - evaluator = cls() - evaluator.append(*default_rules) - return evaluator - - def __init__(self): - self.rules: dict[str, list[Rule]] = defaultdict(list) - self.params = {} - self.__curent_lookup = set() - self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) - self.logger = get_logger(__name__) - - def append(self, *rule: Rule): - for r in rule: - for t in r.targets: - if t is not None: - self.rules[t].append(r) - self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) - - def set(self, **params: Any): - self.params.update(params) - for k in params: - self.eval_stats[k].priority = np.inf - - def reset(self): - self.params = {} - self.eval_stats = defaultdict(EvalStat) - - def compute(self, target: str) -> Any: - """computes a target - - Parameters - ---------- - target : str - name of the target - - Returns - ------- - Any - return type of the target function - - Raises - ------ - EvaluatorError - a cyclic dependence exists - KeyError - there is no saved rule for the target - """ - value = self.params.get(target) - if value is None: - if target in self.__curent_lookup: - raise EvaluatorError( - "cyclic dependency detected : " - f"{target!r} seems to depend on itself, " - f"please provide a value for at least one variable in {self.__curent_lookup}" - ) - else: - self.__curent_lookup.add(target) - - if len(self.rules[target]) == 0: - raise EvaluatorError(f"no rule for {target}") - - error = None - for ii, rule in enumerate( - filter(lambda r: self.validate_condition(r), reversed(self.rules[target])) - ): - self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}") - try: - args = [self.compute(k) for k in rule.args] - returned_values = rule.func(*args) - if len(rule.targets) == 1: - returned_values = [returned_values] - for ((param_name, param_priority), returned_value) in zip( - rule.targets.items(), returned_values - ): - if ( - param_name == target - or param_name not in self.params - or self.eval_stats[param_name].priority < param_priority - ): - self.logger.info( - f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}" - ) - self.params[param_name] = returned_value - self.eval_stats[param_name] = param_priority - if param_name == target: - value = returned_value - break - except (EvaluatorError, KeyError) as e: - error = e - continue - - if value is None and error is not None: - raise error - - self.__curent_lookup.remove(target) - return value - - def validate_condition(self, rule: Rule) -> bool: - return all(self.compute(k) == v for k, v in rule.conditions.items()) - - def __call__(self, target: str, args: list[str] = None): - """creates a wrapper that adds decorated functions to the set of rules - - Parameters - ---------- - target : str - name of the target - args : list[str], optional - list of name of arguments. Automatically deduced from function signature if - not provided, by default None - """ - - def wrapper(func): - self.append(Rule(target, func, args)) - return func - - return wrapper - - -def get_arg_names(func: Callable) -> list[str]: - spec = inspect.getfullargspec(func) - args = spec.args - if spec.defaults is not None and len(spec.defaults) > 0: - args = args[: -len(spec.defaults)] - return args - - -def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - out_func = scope[tmp_name] - out_func.__module__ = "evaluator" - return out_func - - -default_rules: list[Rule] = [ - # Grid - *Rule.deduce( - ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"], - math.build_sim_grid, - ["time_window", "t_num", "dt"], - 2, - ), - # Pulse - Rule("spec_0", np.fft.fft, ["field_0"]), - Rule("field_0", np.fft.ifft, ["spec_0"]), - Rule("spec_0", pulse.load_previous_spectrum, priorities=3), - Rule( - ["pre_field_0", "peak_power", "energy", "width"], - pulse.load_field_file, - [ - "field_file", - "t", - "peak_power", - "energy", - "intensity_noise", - "noise_correlation", - "quantum_noise", - "w_c", - "w0", - "time_window", - "dt", - ], - priorities=[2, 1, 1, 1], - ), - Rule("pre_field_0", pulse.initial_field, priorities=1), - Rule( - "field_0", - pulse.add_shot_noise, - ["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"], - ), - Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), - Rule("peak_power", pulse.soliton_num_to_peak_power), - Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), - Rule("energy", pulse.mean_power_to_energy), - Rule("t0", pulse.width_to_t0), - Rule("t0", pulse.soliton_num_to_t0), - Rule("width", pulse.t0_to_width), - Rule("soliton_num", pulse.soliton_num), - Rule("L_D", pulse.L_D), - Rule("L_NL", pulse.L_NL), - Rule("L_sol", pulse.L_sol), - # Fiber Dispersion - Rule("wl_for_disp", fiber.lambda_for_dispersion), - Rule("w_for_disp", units.m, ["wl_for_disp"]), - Rule( - "beta2_coefficients", - fiber.dispersion_coefficients, - ["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"], - ), - Rule("beta2_arr", fiber.beta2), - Rule("beta2_arr", fiber.dispersion_from_coefficients), - Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), - Rule( - ["wl_for_disp", "beta2_arr", "interpolation_range"], - fiber.load_custom_dispersion, - priorities=[2, 2, 2], - ), - Rule("hr_w", fiber.delayed_raman_w), - Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")), - Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")), - Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")), - Rule( - "n_eff", - fiber.n_eff_pcf, - ["wl_for_disp", "pitch", "pitch_ratio"], - conditions=dict(model="pcf"), - ), - Rule("capillary_spacing", fiber.HCARF_gap), - # Fiber nonlinearity - Rule("A_eff", fiber.A_eff_from_V), - Rule("A_eff", fiber.A_eff_from_diam), - Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")), - Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1), - Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), - Rule("A_eff_arr", fiber.load_custom_A_eff), - Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), - Rule( - "V_eff", - fiber.V_parameter_koshiba, - ["wavelength", "pitch", "pitch_ratio"], - conditions=dict(model="pcf"), - ), - Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]), - Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), - Rule("V_eff_arr", fiber.V_eff_marcuse), - Rule("gamma", lambda gamma_arr: gamma_arr[0]), - Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]), - # Fiber loss - Rule("alpha", fiber.compute_capillary_loss), - Rule("alpha", fiber.load_custom_loss), - # gas - Rule("n_gas_2", materials.n_gas_2), -] - - -def main(): - import matplotlib.pyplot as plt - - evalor = Evaluator() - evalor.append(*default_rules) - evalor.set( - **{ - "length": 1, - "z_num": 128, - "wavelength": 1500e-9, - "interpolation_degree": 8, - "interpolation_range": (500e-9, 2200e-9), - "t_num": 16384, - "dt": 1e-15, - "shape": "gaussian", - "repetition_rate": 40e6, - "width": 30e-15, - "mean_power": 100e-3, - "n2": 2.4e-20, - "A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz", - "model": "pcf", - "quantum_noise": True, - "pitch": 1.2e-6, - "pitch_ratio": 0.5, - } - ) - evalor.compute("z_targets") - print(evalor.params.keys()) - print(evalor.params["l"][evalor.params["l"] > 0].min()) - evalor.compute("spec_0") - plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2) - plt.yscale("log") - plt.show() - print(evalor.compute("gamma")) - print(evalor.compute("beta2")) - from pprint import pprint - - -if __name__ == "__main__": - main() diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index fffc323..6981f23 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -1,15 +1,21 @@ import datetime as datetime_module +import inspect +import itertools +import re +from collections import defaultdict from copy import copy from dataclasses import asdict, dataclass from functools import lru_cache -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TypeVar - +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +import os import numpy as np +from tqdm.std import Bar +from .. import math from ..const import __version__ - -# from .evaluator import Rule, Evaluator -# from ..physics import pulse, fiber, materials +from ..logger import get_logger +from .. import io +from ..physics import fiber, materials, pulse, units T = TypeVar("T") @@ -365,10 +371,10 @@ mandatory_parameters = [ @dataclass -class BareParams: +class Parameters: """ - This class defines each valid parameter's name, type and valid value but doesn't provide - any method to act on those. For that, use initialize.Params + This class defines each valid parameter's name, type and valid value. Initializing + such an obj will automatically compute all possible parameters """ # root @@ -476,11 +482,25 @@ class BareParams: def prepare_for_dump(self) -> Dict[str, Any]: param = asdict(self) - param = BareParams.strip_params_dict(param) + param = Parameters.strip_params_dict(param) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ return param + def __post_init__(self): + param_dict = {k: v for k, v in asdict(self).items() if v is not None} + evaluator = Evaluator.default() + evaluator.set(**param_dict) + for p_name in mandatory_parameters: + evaluator.compute(p_name) + for k, v in evaluator.params.items(): + if k in param_dict: + setattr(self, k, v) + + @classmethod + def load(cls, path: os.PathLike) -> "Parameters": + return cls(**io.load_toml(path)) + @staticmethod def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]: """prepares a dictionary for serialization. Some keys may not be preserved @@ -513,7 +533,7 @@ class BareParams: if not isinstance(value, types): continue if isinstance(value, dict): - out[key] = BareParams.strip_params_dict(value) + out[key] = Parameters.strip_params_dict(value) elif isinstance(value, np.ndarray) and value.dtype == complex: continue else: @@ -525,9 +545,382 @@ class BareParams: return out +class EvaluatorError(Exception): + pass + + +class Rule: + def __init__( + self, + target: Union[str, list[Optional[str]]], + func: Callable, + args: list[str] = None, + priorities: Union[int, list[int]] = None, + conditions: dict[str, str] = None, + ): + targets = list(target) if isinstance(target, (list, tuple)) else [target] + self.func = func + if priorities is None: + priorities = [1] * len(targets) + elif isinstance(priorities, (int, float, np.integer, np.floating)): + priorities = [priorities] + self.targets = dict(zip(targets, priorities)) + if args is None: + args = get_arg_names(func) + self.args = args + self.conditions = conditions or {} + + def __repr__(self) -> str: + return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" + + @classmethod + def deduce( + cls, + target: Union[str, list[Optional[str]]], + func: Callable, + kwarg_names: list[str], + n_var: int, + args_const: list[str] = None, + ) -> list["Rule"]: + """given a function that doesn't need all its keyword arguemtn specified, will + return a list of Rule obj, one for each combination of n_var specified kwargs + + Parameters + ---------- + target : str | list[str | None] + name of the variable(s) that func returns + func : Callable + function to work with + kwarg_names : list[str] + list of all kwargs of the function to be used + n_var : int + how many shoulf be used per rule + arg_const : list[str], optional + override the name of the positional arguments + + Returns + ------- + list[Rule] + list of all possible rules + + Example + ------- + >> def lol(a, b=None, c=None): + pass + >> print(Rule.deduce(["d"], lol, ["b", "c"], 1)) + [ + Rule(targets={'d': 1}, func=, args=['a', 'b']), + Rule(targets={'d': 1}, func=, args=['a', 'c']) + ] + """ + rules: list[cls] = [] + for var_possibility in itertools.combinations(kwarg_names, n_var): + + new_func = func_rewrite(func, list(var_possibility), args_const) + + rules.append(cls(target, new_func)) + return rules + + @dataclass -class BareConfig(BareParams): - variable: dict = VariableParameter(BareParams) +class EvalStat: + priority: float = np.inf + + +class Evaluator: + @classmethod + def default(cls) -> "Evaluator": + evaluator = cls() + evaluator.append(*default_rules) + return evaluator + + def __init__(self): + self.rules: dict[str, list[Rule]] = defaultdict(list) + self.params = {} + self.__curent_lookup = set() + self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) + self.logger = get_logger(__name__) + + def append(self, *rule: Rule): + for r in rule: + for t in r.targets: + if t is not None: + self.rules[t].append(r) + self.rules[t].sort(key=lambda el: el.targets[t], reverse=True) + + def set(self, **params: Any): + self.params.update(params) + for k in params: + self.eval_stats[k].priority = np.inf + + def reset(self): + self.params = {} + self.eval_stats = defaultdict(EvalStat) + + def compute(self, target: str) -> Any: + """computes a target + + Parameters + ---------- + target : str + name of the target + + Returns + ------- + Any + return type of the target function + + Raises + ------ + EvaluatorError + a cyclic dependence exists + KeyError + there is no saved rule for the target + """ + value = self.params.get(target) + if value is None: + if target in self.__curent_lookup: + raise EvaluatorError( + "cyclic dependency detected : " + f"{target!r} seems to depend on itself, " + f"please provide a value for at least one variable in {self.__curent_lookup}" + ) + else: + self.__curent_lookup.add(target) + + if len(self.rules[target]) == 0: + raise EvaluatorError(f"no rule for {target}") + + error = None + for ii, rule in enumerate( + filter(lambda r: self.validate_condition(r), reversed(self.rules[target])) + ): + self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}") + try: + args = [self.compute(k) for k in rule.args] + returned_values = rule.func(*args) + if len(rule.targets) == 1: + returned_values = [returned_values] + for ((param_name, param_priority), returned_value) in zip( + rule.targets.items(), returned_values + ): + if ( + param_name == target + or param_name not in self.params + or self.eval_stats[param_name].priority < param_priority + ): + self.logger.info( + f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}" + ) + self.params[param_name] = returned_value + self.eval_stats[param_name] = param_priority + if param_name == target: + value = returned_value + break + except (EvaluatorError, KeyError) as e: + error = e + continue + + if value is None and error is not None: + raise error + + self.__curent_lookup.remove(target) + return value + + def validate_condition(self, rule: Rule) -> bool: + return all(self.compute(k) == v for k, v in rule.conditions.items()) + + def __call__(self, target: str, args: list[str] = None): + """creates a wrapper that adds decorated functions to the set of rules + + Parameters + ---------- + target : str + name of the target + args : list[str], optional + list of name of arguments. Automatically deduced from function signature if + not provided, by default None + """ + + def wrapper(func): + self.append(Rule(target, func, args)) + return func + + return wrapper + + +def get_arg_names(func: Callable) -> list[str]: + spec = inspect.getfullargspec(func) + args = spec.args + if spec.defaults is not None and len(spec.defaults) > 0: + args = args[: -len(spec.defaults)] + return args + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + out_func = scope[tmp_name] + out_func.__module__ = "evaluator" + return out_func + + +default_rules: list[Rule] = [ + # Grid + *Rule.deduce( + ["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"], + math.build_sim_grid, + ["time_window", "t_num", "dt"], + 2, + ), + # Pulse + Rule("spec_0", np.fft.fft, ["field_0"]), + Rule("field_0", np.fft.ifft, ["spec_0"]), + Rule("spec_0", pulse.load_previous_spectrum, priorities=3), + Rule( + ["pre_field_0", "peak_power", "energy", "width"], + pulse.load_field_file, + [ + "field_file", + "t", + "peak_power", + "energy", + "intensity_noise", + "noise_correlation", + "quantum_noise", + "w_c", + "w0", + "time_window", + "dt", + ], + priorities=[2, 1, 1, 1], + ), + Rule("pre_field_0", pulse.initial_field, priorities=1), + Rule( + "field_0", + pulse.add_shot_noise, + ["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"], + ), + Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), + Rule("peak_power", pulse.soliton_num_to_peak_power), + Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), + Rule("energy", pulse.mean_power_to_energy), + Rule("t0", pulse.width_to_t0), + Rule("t0", pulse.soliton_num_to_t0), + Rule("width", pulse.t0_to_width), + Rule("soliton_num", pulse.soliton_num), + Rule("L_D", pulse.L_D), + Rule("L_NL", pulse.L_NL), + Rule("L_sol", pulse.L_sol), + # Fiber Dispersion + Rule("wl_for_disp", fiber.lambda_for_dispersion), + Rule("w_for_disp", units.m, ["wl_for_disp"]), + Rule( + "beta2_coefficients", + fiber.dispersion_coefficients, + ["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"], + ), + Rule("beta2_arr", fiber.beta2), + Rule("beta2_arr", fiber.dispersion_from_coefficients), + Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), + Rule( + ["wl_for_disp", "beta2_arr", "interpolation_range"], + fiber.load_custom_dispersion, + priorities=[2, 2, 2], + ), + Rule("hr_w", fiber.delayed_raman_w), + Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")), + Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")), + Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")), + Rule( + "n_eff", + fiber.n_eff_pcf, + ["wl_for_disp", "pitch", "pitch_ratio"], + conditions=dict(model="pcf"), + ), + Rule("capillary_spacing", fiber.HCARF_gap), + # Fiber nonlinearity + Rule("A_eff", fiber.A_eff_from_V), + Rule("A_eff", fiber.A_eff_from_diam), + Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")), + Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1), + Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), + Rule("A_eff_arr", fiber.load_custom_A_eff), + Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), + Rule( + "V_eff", + fiber.V_parameter_koshiba, + ["wavelength", "pitch", "pitch_ratio"], + conditions=dict(model="pcf"), + ), + Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]), + Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")), + Rule("V_eff_arr", fiber.V_eff_marcuse), + Rule("gamma", lambda gamma_arr: gamma_arr[0]), + Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]), + # Fiber loss + Rule("alpha", fiber.compute_capillary_loss), + Rule("alpha", fiber.load_custom_loss), + # gas + Rule("n_gas_2", materials.n_gas_2), +] + + +@dataclass +class BareConfig(Parameters): + variable: dict = VariableParameter(Parameters) + + def __post_init__(self): + pass + + @classmethod + def load(cls, path: os.PathLike) -> "BareConfig": + return cls(**io.load_toml(path)) + + @classmethod + def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]: + """Loads a sequence of + + Parameters + ---------- + config_paths : os.PathLike + either one path (the last config containing previous_config_file parameter) + or a list of config path in the order they have to be simulated + + Returns + ------- + list[BareConfig] + all loaded configs + """ + if config_paths[0] is None: + return [] + all_configs = [cls.load(config_paths[0])] + if len(config_paths) == 1: + while True: + if all_configs[0].previous_config_file is not None: + all_configs.insert(0, cls.load(all_configs[0].previous_config_file)) + else: + break + else: + for i, path in enumerate(config_paths[1:]): + all_configs.append(cls.load(path)) + all_configs[i + 1].previous_config_file = config_paths[i] + return all_configs if __name__ == "__main__": diff --git a/testing/test_initialize.py b/testing/test_initialize.py index 882f5f0..d406453 100644 --- a/testing/test_initialize.py +++ b/testing/test_initialize.py @@ -7,7 +7,7 @@ import toml from scgenerator import defaults, utils, math from scgenerator.errors import * from scgenerator.physics import pulse, units -from scgenerator.utils.parameter import BareConfig, BareParams +from scgenerator.utils.parameter import BareConfig, Parameters def load_conf(name): @@ -143,10 +143,10 @@ class TestInitializeMethods(unittest.TestCase): init.Config(**conf("good5")).__dict__.items(), ) - def setup_conf_custom_field(self, path) -> BareParams: + def setup_conf_custom_field(self, path) -> Parameters: conf = load_conf(path) - conf = BareParams(**conf) + conf = Parameters(**conf) init.build_sim_grid_in_place(conf) return conf @@ -192,12 +192,12 @@ class TestInitializeMethods(unittest.TestCase): self.assertTrue(result) conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") - result = init.Params.from_bare(conf) + result = Parameters(**conf) self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9) conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") conf.wavelength = 1593e-9 - result = init.Params.from_bare(conf) + result = Parameters(**conf) conf = load_conf("custom_field/wavelength_shift2") conf = init.Config(**conf) diff --git a/testing/test_new_iterator.py b/testing/test_new_iterator.py deleted file mode 100644 index e15fd90..0000000 --- a/testing/test_new_iterator.py +++ /dev/null @@ -1,12 +0,0 @@ -import scgenerator as sc -from pathlib import Path -import os - -os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/") - -root = Path("PM1550+PMHNLF+PM1550+PM2000") - -confs = sc.io.load_config_sequence(root / "4_PM2000.toml") -final = sc.utils.final_config_from_sequence(*confs) - -print(final)