From 57c593cf4ff60a6682d5d133c4960020d00dc1f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 24 Jul 2023 08:23:04 +0200 Subject: [PATCH] removed a bunch of stuff Removed: - Variationer - FileConfiguration - Scripts (slurm, ...) - CLI --- .gitignore | 10 +- pyproject.toml | 4 + src/scgenerator/__init__.py | 7 +- src/scgenerator/evaluator.py | 10 +- src/scgenerator/logger.py | 1 - src/scgenerator/operators.py | 5 +- src/scgenerator/parameter.py | 304 +++------------------ src/scgenerator/pbar.py | 189 ------------- src/scgenerator/physics/fiber.py | 25 +- src/scgenerator/physics/materials.py | 37 ++- src/scgenerator/scripts/__init__.py | 346 ------------------------ src/scgenerator/scripts/slurm_submit.py | 157 ----------- src/scgenerator/solver.py | 2 + src/scgenerator/spectra.py | 1 + src/scgenerator/utils.py | 8 +- src/scgenerator/variationer.py | 336 ----------------------- 16 files changed, 87 insertions(+), 1355 deletions(-) delete mode 100644 src/scgenerator/pbar.py delete mode 100644 src/scgenerator/scripts/__init__.py delete mode 100644 src/scgenerator/scripts/slurm_submit.py delete mode 100644 src/scgenerator/variationer.py diff --git a/.gitignore b/.gitignore index 4d56e8f..d32a76c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .DS_store .idea **/*.npy +.conda-env pyrightconfig.json @@ -17,15 +18,8 @@ __pycache__ tmp* paths.json scgenerator_log* +scgenerator.log .scgenerator_tmp sc-*.log .vscode - - -# latex -*.aux -*.fdb_latexmk -*.fls -*.log -*.synctex.gz diff --git a/pyproject.toml b/pyproject.toml index 06890c3..bc1d55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ [tool.ruff] line-length = 100 +ignore = ["E741"] [tool.ruff.pydocstyle] convention = "numpy" @@ -34,3 +35,6 @@ convention = "numpy" [tool.black] line-length = 100 +[tool.isort] +profile = "black" + diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index b225bda..31d8466 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,10 +1,9 @@ -# # flake8: noqa +# isort: skip_file +# ruff: noqa from scgenerator import math, operators, plotting from scgenerator.helpers import * from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace -from scgenerator.parameter import FileConfiguration, Parameters +from scgenerator.parameter import Parameters from scgenerator.physics import fiber, materials, plasma, pulse, units from scgenerator.physics.units import PlotRange from scgenerator.solver import integrate, solve43 -from scgenerator.utils import (Paths, _open_config, open_single_config, - simulations_list) diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 7f9867e..d099e1b 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -312,7 +312,7 @@ default_rules: list[Rule] = [ Rule("w_num", len, ["w"]), Rule("dw", lambda w: w[1] - w[0]), Rule(["fft", "ifft"], utils.fft_functions, priorities=1), - Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)), + Rule("wavelength_window", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)), # Pulse Rule("field_0", pulse.finalize_pulse), Rule(["input_time", "input_field"], pulse.load_custom_field), @@ -393,7 +393,7 @@ default_rules: list[Rule] = [ Rule( "V_eff_arr", fiber.V_eff_step_index, - ["l", "core_radius", "numerical_aperture", "interpolation_range"], + ["l", "core_radius", "numerical_aperture", "wavelength_window"], ), Rule("n2", materials.gas_n2), Rule("n2", lambda: 2.2e-20, priorities=-1), @@ -403,7 +403,7 @@ default_rules: list[Rule] = [ # Raman Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w), Rule("raman_fraction", fiber.raman_fraction), - Rule("raman_fraction", lambda:0, priorities=-1), + Rule("raman_fraction", lambda: 0, priorities=-1), # loss Rule("alpha_arr", fiber.scalar_loss), Rule("alpha_arr", fiber.safe_capillary_loss, conditions=dict(loss="capillary")), @@ -434,7 +434,7 @@ envelope_rules = default_rules + [ Rule("beta2_arr", fiber.dispersion_from_coefficients), Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), Rule( - ["wl_for_disp", "beta2_arr", "interpolation_range"], + ["wl_for_disp", "beta2_arr", "wavelength_window"], fiber.load_custom_dispersion, priorities=[2, 2, 2], ), @@ -442,7 +442,7 @@ envelope_rules = default_rules + [ Rule("gamma_op", operators.variable_gamma, priorities=2), Rule("gamma_op", operators.constant_quantity, ["gamma_arr"], priorities=1), Rule("gamma_op", lambda w_num, gamma: operators.constant_quantity(np.ones(w_num) * gamma)), - Rule("gamma_op", operators.no_op_freq, priorities=-1), + Rule("gamma_op", lambda: operators.constant_quantity(0.0), priorities=-1), Rule("ss_op", lambda w_c, w0: operators.constant_quantity(w_c / w0)), Rule("ss_op", lambda: operators.constant_quantity(0), priorities=-1), Rule("spm_op", operators.envelope_spm), diff --git a/src/scgenerator/logger.py b/src/scgenerator/logger.py index c3f925a..42eee72 100644 --- a/src/scgenerator/logger.py +++ b/src/scgenerator/logger.py @@ -48,7 +48,6 @@ def configure_logger(logger: logging.Logger): updated logger """ if not hasattr(logger, "already_configured"): - print_lvl = lvl_map.get(log_print_level(), logging.NOTSET) file_lvl = lvl_map.get(log_file_level(), logging.NOTSET) diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index a986daa..08af32f 100755 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Callable import numpy as np + from scgenerator import math from scgenerator.logger import get_logger from scgenerator.physics import fiber, materials, plasma, pulse, units @@ -266,8 +267,7 @@ def constant_wave_vector( ################################################## -def envelope_raman(hr_w:np.ndarra, raman_fraction: float) -> FieldOperator: - +def envelope_raman(hr_w: np.ndarra, raman_fraction: float) -> FieldOperator: def operate(field: np.ndarray, z: float) -> np.ndarray: return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field))) @@ -336,7 +336,6 @@ def ionization( N0 = number_density(z) plasma_info = plasma_obj(field, N0) - # state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0 # state.stats["electron_density"] = plasma_info.electron_density[-1] return plasma_info.polarization diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 34bcf0e..0d69f55 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -1,32 +1,50 @@ from __future__ import annotations import datetime as datetime_module -import enum import os -import time from copy import copy from dataclasses import dataclass, field, fields from functools import lru_cache, wraps from math import isnan from pathlib import Path -from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar +from typing import (Any, Callable, ClassVar, Iterable, Iterator, Set, Type, + TypeVar) import numpy as np -from scgenerator import env, utils -from scgenerator.const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__ +from scgenerator import utils +from scgenerator.const import MANDATORY_PARAMETERS, __version__ from scgenerator.errors import EvaluatorError from scgenerator.evaluator import Evaluator -from scgenerator.logger import get_logger from scgenerator.operators import Qualifier, SpecOperator -from scgenerator.utils import fiber_folder, update_path_name -from scgenerator.variationer import VariationDescriptor, Variationer +from scgenerator.utils import update_path_name T = TypeVar("T") +DISPLAY_INFO = {} + + +def format_value(name: str, value) -> str: + if value is True or value is False: + return str(value) + elif isinstance(value, (float, int)): + try: + return DISPLAY_INFO[name](value) + except KeyError: + return format(value, ".9g") + elif isinstance(value, np.ndarray): + return np.array2string(value) + elif isinstance(value, (list, tuple)): + return "-".join([str(v) for v in value]) + elif isinstance(value, str): + p = Path(value) + if p.exists(): + return p.stem + elif callable(value): + return getattr(value, "__name__", repr(value)) + return str(value) + # Validator - - @lru_cache def type_checker(*types): def _type_checker_wrapper(validator, n=None): @@ -224,7 +242,7 @@ class Parameter: pass if self.default is not None: Evaluator.register_default_param(self.name, self.default) - VariationDescriptor.register_formatter(self.name, self.display) + DISPLAY_INFO[self.name] = self.display def __get__(self, instance: Parameters, owner): if instance is None: @@ -382,7 +400,7 @@ class Parameters: dt: float = Parameter(in_range_excl(0, 10e-15)) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) step_size: float = Parameter(non_negative(float, int), default=0) - interpolation_range: tuple[float, float] = Parameter( + wavelength_window: tuple[float, float] = Parameter( validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9))) ) interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18))) @@ -469,11 +487,7 @@ class Parameters: exclude = exclude or [] if isinstance(exclude, str): exclude = [exclude] - p_pairs = [ - (k, VariationDescriptor.format_value(k, getattr(self, k))) - for k in params - if k not in exclude - ] + p_pairs = [(k, format_value(k, getattr(self, k))) for k in params if k not in exclude] max_left = max(len(el[0]) for el in p_pairs) max_right = max(len(el[1]) for el in p_pairs) return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs) @@ -544,262 +558,6 @@ class Parameters: return None -class AbstractConfiguration: - fiber_paths: list[Path] - num_sim: int - total_num_steps: int - worker_num: int - final_path: Path - - def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]: - raise NotImplementedError() - - def save_parameters(self): - raise NotImplementedError() - - -class FileConfiguration(AbstractConfiguration): - """ - Primary role is to load the final config file of the simulation and deduce every - simulatin that has to happen. Iterating through the Configuration obj yields a list of - parameter names and values that change throughout the simulation as well as parameter - obj with the output path of the simulation saved in its output_path attribute. - """ - - fiber_configs: list[utils.SubConfig] - master_config_dict: dict[str, Any] - num_fibers: int - repeat: int - z_num: int - overwrite: bool - all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"] - - @dataclass(frozen=True) - class __SimConfig: - descriptor: VariationDescriptor - config: dict[str, Any] - output_path: Path - - @property - def sim_num(self) -> int: - return len(self.descriptor.index) - - class State(enum.Enum): - COMPLETE = enum.auto() - PARTIAL = enum.auto() - ABSENT = enum.auto() - - class Action(enum.Enum): - RUN = enum.auto() - WAIT = enum.auto() - SKIP = enum.auto() - - def __init__( - self, - config_path: os.PathLike, - overwrite: bool = True, - wait: bool = False, - skip_callback: Callable[[int], None] = None, - final_output_path: os.PathLike = None, - ): - self.logger = get_logger(__name__) - self.wait = wait - - self.overwrite = overwrite - self.final_path, self.fiber_configs = utils.load_config_sequence(config_path) - self.final_path = env.get(env.OUTPUT_PATH, self.final_path) - if final_output_path is not None: - self.final_path = final_output_path - self.final_path = utils.ensure_folder( - Path(self.final_path), - mkdir=False, - prevent_overwrite=not self.overwrite, - ) - self.master_config_dict = self.fiber_configs[0].fixed | { - k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items() - } - self.name = self.final_path.name - self.z_num = 0 - self.total_num_steps = 0 - self.fiber_paths = [] - self.all_configs = {} - self.skip_callback = skip_callback - self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2)) - self.repeat = self.master_config_dict.get("repeat", 1) - self.variationer = Variationer() - - fiber_names = set() - self.num_fibers = 0 - for i, config in enumerate(self.fiber_configs): - config.fixed.setdefault("name", Parameters.name.default) - self.z_num += config.fixed["z_num"] - fiber_names.add(config.fixed["name"]) - self.variationer.append(config.variable) - self.fiber_paths.append( - utils.ensure_folder( - self.final_path / fiber_folder(i, self.name, Path(config.fixed["name"]).name), - mkdir=False, - prevent_overwrite=not self.overwrite, - ) - ) - self.__validate_variable(config.variable) - self.num_fibers += 1 - Evaluator.evaluate_default( - self.master_config_dict - | config.fixed - | {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()}, - True, - ) - self.num_sim = self.variationer.var_num() - self.total_num_steps = sum( - config.fixed["z_num"] * self.variationer.var_num(i) - for i, config in enumerate(self.fiber_configs) - ) - - def __validate_variable(self, vary_dict_list: list[dict[str, list]]): - for vary_dict in vary_dict_list: - for k, v in vary_dict.items(): - p: Parameter = getattr(Parameters, k) - validator_list(p.validator)("variable " + k, v) - if k not in VALID_VARIABLE: - raise TypeError(f"{k!r} is not a valid variable parameter") - if len(v) == 0: - raise ValueError(f"variable parameter {k!r} must not be empty") - - def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]: - for i in range(self.num_fibers): - yield from self.iterate_single_fiber(i) - - def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]: - """iterates through the parameters of only one fiber. It takes care of recovering partially - completed simulations, skipping complete ones and waiting for the previous fiber to finish - - Parameters - ---------- - index : int - which fiber to iterate over - - Yields - ------- - __SimConfig - configuration obj - """ - if index < 0: - index = self.num_fibers + index - sim_dict: dict[Path, FileConfiguration.__SimConfig] = {} - for descriptor in self.variationer.iterate(index): - cfg = descriptor.update_config(self.fiber_configs[index].fixed) - if index > 0: - cfg["prev_data_dir"] = str( - self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True) - ) - p = utils.ensure_folder( - self.fiber_paths[index] / descriptor.formatted_descriptor(True), - not self.overwrite, - False, - ) - cfg["output_path"] = p - sim_config = self.__SimConfig(descriptor, cfg, p) - sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config - while len(sim_dict) > 0: - for data_dir, sim_config in sim_dict.items(): - task, config_dict = self.__decide(sim_config) - if task == self.Action.RUN: - sim_dict.pop(data_dir) - param_dict = sim_config.config - yield sim_config.descriptor, Parameters(**param_dict) - if "recovery_last_stored" in config_dict and self.skip_callback is not None: - self.skip_callback(config_dict["recovery_last_stored"]) - break - elif task == self.Action.SKIP: - sim_dict.pop(data_dir) - self.logger.debug(f"skipping {data_dir} as it is already complete") - if self.skip_callback is not None: - self.skip_callback(config_dict["z_num"]) - break - else: - self.logger.debug("sleeping while waiting for other simulations to complete") - time.sleep(1) - - def __decide( - self, sim_config: "FileConfiguration.__SimConfig" - ) -> tuple["FileConfiguration.Action", dict[str, Any]]: - """decide what to to with a particular simulation - - Parameters - ---------- - sim_config : __SimConfig - - Returns - ------- - str : Configuration.Action - what to do - config_dict : dict[str, Any] - config dictionary. The only key possibly modified is 'prev_data_dir', which - gets set if the simulation is partially completed - """ - if not self.wait: - return self.Action.RUN, sim_config.config - out_status, num = self.sim_status(sim_config.output_path, sim_config.config) - if out_status == self.State.COMPLETE: - return self.Action.SKIP, sim_config.config - elif out_status == self.State.PARTIAL: - sim_config.config["recovery_data_dir"] = str(sim_config.output_path) - sim_config.config["recovery_last_stored"] = num - return self.Action.RUN, sim_config.config - - if "prev_data_dir" in sim_config.config: - prev_data_path = Path(sim_config.config["prev_data_dir"]) - prev_status, _ = self.sim_status(prev_data_path) - if prev_status in {self.State.PARTIAL, self.State.ABSENT}: - return self.Action.WAIT, sim_config.config - return self.Action.RUN, sim_config.config - - def sim_status( - self, data_dir: Path, config_dict: dict[str, Any] = None - ) -> tuple["FileConfiguration.State", int]: - """returns the status of a simulation - - Parameters - ---------- - data_dir : Path - directory where simulation data is to be saved - config_dict : dict[str, Any], optional - configuration of the simulation. If None, will attempt to load - the params.toml file if present, by default None - - Returns - ------- - Configuration.State - status - """ - num = utils.find_last_spectrum_num(data_dir) - if config_dict is None: - try: - config_dict = utils.load_toml(data_dir / PARAM_FN) - except FileNotFoundError: - self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}") - return self.State.ABSENT, 0 - if num == config_dict["z_num"] - 1: - return self.State.COMPLETE, num - elif config_dict["z_num"] - 1 > num > 0: - return self.State.PARTIAL, num - elif num == 0: - return self.State.ABSENT, 0 - else: - raise ValueError(f"Too many spectra in {data_dir}") - - def save_parameters(self): - os.makedirs(self.final_path, exist_ok=True) - cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs] - utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs)) - - @property - def first(self) -> Parameters: - for _, param in self: - return param - - if __name__ == "__main__": numero = type_checker(int) diff --git a/src/scgenerator/pbar.py b/src/scgenerator/pbar.py deleted file mode 100644 index acc6c0e..0000000 --- a/src/scgenerator/pbar.py +++ /dev/null @@ -1,189 +0,0 @@ -import multiprocessing -import os -import random -import threading -import typing -from collections import abc -from io import StringIO -from pathlib import Path -from typing import Iterable, Union - -from tqdm import tqdm - -from scgenerator.env import pbar_policy - -T_ = typing.TypeVar("T_") - - -class PBars: - def __init__( - self, - task: Union[int, Iterable[T_]], - desc: str, - num_sub_bars: int = 0, - head_kwargs=None, - worker_kwargs=None, - ) -> "PBars": - """creates a PBars obj - - Parameters - ---------- - task : int | Iterable - if int : total length of the main task - if Iterable : behaves like tqdm - desc : str - description of the main task - num_sub_bars : int - number of sub-tasks - - """ - self.id = random.randint(100000, 999999) - try: - self.width = os.get_terminal_size().columns - except OSError: - self.width = 120 - if isinstance(task, abc.Iterable): - self.iterator: Iterable[T_] = iter(task) - self.num_tot: int = len(task) - else: - self.num_tot: int = task - self.iterator = None - - self.policy = pbar_policy() - if head_kwargs is None: - head_kwargs = dict() - if worker_kwargs is None: - worker_kwargs = dict( - total=1, - desc="Worker {worker_id}", - bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", - ) - if "print" not in pbar_policy(): - head_kwargs["file"] = worker_kwargs["file"] = StringIO() - self.width = 80 - head_kwargs["desc"] = desc - self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)] - for i in range(1, num_sub_bars + 1): - kwargs = {k: v for k, v in worker_kwargs.items()} - if "desc" in kwargs: - kwargs["desc"] = kwargs["desc"].format(worker_id=i) - self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs)) - self.print_path = Path( - f"progress {self.pbars[0].desc.replace('/', '')} {self.id}" - ).resolve() - self.close_ev = threading.Event() - if "file" in self.policy: - self.thread = threading.Thread(target=self.print_worker, daemon=True) - self.thread.start() - - def print(self): - if "file" not in self.policy: - return - s = [] - for pbar in self.pbars: - s.append(str(pbar)) - self.print_path.write_text("\n".join(s)) - - def print_worker(self): - while True: - if self.close_ev.wait(2.0): - return - self.print() - - def __iter__(self): - with self as pb: - for thing in self.iterator: - yield thing - pb.update() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def __getitem__(self, key): - return self.pbars[key] - - def update(self, i=None, value=1): - if i is None: - for pbar in self.pbars[1:]: - pbar.update(value) - elif i > 0: - self.pbars[i].update(value) - self.pbars[0].update() - - def append(self, pbar: tqdm): - self.pbars.append(pbar) - - def reset(self, i): - self.pbars[i].update(-self.pbars[i].n) - self.print() - - def close(self): - self.print() - self.close_ev.set() - if "file" in self.policy: - self.thread.join() - for pbar in self.pbars: - pbar.close() - - -class ProgressBarActor: - def __init__(self, name: str, num_workers: int, num_steps: int) -> None: - self.counters = [0 for _ in range(num_workers + 1)] - self.p_bars = PBars( - num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") - ) - - def update(self, worker_id: int, rel_pos: float = None) -> None: - """update a counter - - Parameters - ---------- - worker_id : int - id of the worker. 0 is the overall progress - rel_pos : float, optional - if None, increase the counter by one, if set, will set - the counter to the specified value (instead of incrementing it), by default None - """ - if rel_pos is None: - self.counters[worker_id] += 1 - else: - self.counters[worker_id] = rel_pos - - def update_pbars(self): - for counter, pbar in zip(self.counters, self.p_bars.pbars): - pbar.update(counter - pbar.n) - - def close(self): - self.p_bars.close() - - -def progress_worker( - name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue -): - """keeps track of progress on a separate thread - - Parameters - ---------- - num_steps : int - total number of steps, used for the main progress bar (position 0) - progress_queue : multiprocessing.Queue - values are either - Literal[0] : stop the worker and close the progress bars - tuple[int, float] : worker id and relative progress between 0 and 1 - """ - with PBars( - num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") - ) as pbars: - while True: - raw = progress_queue.get() - if raw == 0: - return - i, rel_pos = raw - if i > 0: - pbars[i].update(rel_pos - pbars[i].n) - pbars[0].update() - elif i == 0: - pbars[0].update(rel_pos) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 96af0b0..7610c65 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -18,7 +18,7 @@ T = TypeVar("T") def lambda_for_envelope_dispersion( - l: np.ndarray, interpolation_range: tuple[float, float] + l: np.ndarray, wavelength_window: tuple[float, float] ) -> tuple[np.ndarray, np.ndarray]: """Returns a wl vector for dispersion calculation in envelope mode @@ -30,10 +30,10 @@ def lambda_for_envelope_dispersion( np.ndarray indices of the original l where the values are valid (i.e. without the two extra on each side) """ - su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0] - if l[su].min() > 1.01 * interpolation_range[0]: + su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0] + if l[su].min() > 1.01 * wavelength_window[0]: raise ValueError( - f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. " + f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. " f"Minimum of grid is {1e9*l[su].min():.1f}nm. Try a finer grid" ) @@ -48,7 +48,7 @@ def lambda_for_envelope_dispersion( def lambda_for_full_field_dispersion( - l: np.ndarray, interpolation_range: tuple[float, float] + l: np.ndarray, wavelength_window: tuple[float, float] ) -> tuple[np.ndarray, np.ndarray]: """Returns a wl vector for dispersion calculation in full field mode @@ -60,10 +60,10 @@ def lambda_for_full_field_dispersion( np.ndarray indices of the original l where the values are valid (i.e. without the two extra on each side) """ - su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0] - if l[su].min() > 1.01 * interpolation_range[0]: + su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0] + if l[su].min() > 1.01 * wavelength_window[0]: raise ValueError( - f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. " + f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. " "try a finer grid" ) fu = np.concatenate((su[:2] - 2, su, su[-2:] + 2)) @@ -385,7 +385,7 @@ def V_eff_step_index( l: T, core_radius: float, numerical_aperture: float, - interpolation_range: tuple[float, float] = None, + wavelength_window: tuple[float, float] = None, ) -> T: """computes the V parameter of a step-index fiber @@ -397,7 +397,7 @@ def V_eff_step_index( radius of the core numerical_aperture : float as a decimal number - interpolation_range : tuple[float, float], optional + wavelength_window : tuple[float, float], optional when provided, only computes V over this range, wavelengths outside this range will yield V=inf, by default None @@ -407,8 +407,8 @@ def V_eff_step_index( V parameter """ pi2cn = 2 * pi * core_radius * numerical_aperture - if interpolation_range is not None and isinstance(l, np.ndarray): - low, high = interpolation_range + if wavelength_window is not None and isinstance(l, np.ndarray): + low, high = wavelength_window l = np.where((l >= low) & (l <= high), l, np.inf) return pi2cn / l @@ -805,7 +805,6 @@ def delayed_raman_w(t: np.ndarray, raman_type: str) -> tuple[np.ndarray, float]: return hr_w, raman_fraction(raman_type) - def fast_poly_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)): """ dispersive operator diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 94736cd..61afbef 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass, field from functools import cache -from typing import TypeVar +from typing import Any, TypeVar import numpy as np @@ -106,14 +106,16 @@ class Gas: chi3_0: float ionization_energy: float | None + _raw_sellmeier: dict[str, Any] + def __init__(self, gas_name: str): self.name = gas_name - self.mat_dico = utils.load_material_dico(gas_name) - self.atomic_mass = self.mat_dico["atomic_mass"] - self.atomic_number = self.mat_dico["atomic_number"] - self.ionization_energy = self.mat_dico.get("ionization_energy") + self._raw_sellmeier = utils.load_material_dico(gas_name) + self.atomic_mass = self._raw_sellmeier["atomic_mass"] + self.atomic_number = self._raw_sellmeier["atomic_number"] + self.ionization_energy = self._raw_sellmeier.get("ionization_energy") - s = self.mat_dico.get("sellmeier", {}) + s = self._raw_sellmeier.get("sellmeier", {}) self.sellmeier = Sellmeier( **{ newk: s.get(k, None) @@ -124,7 +126,7 @@ class Gas: if k in s } ) - kerr = self.mat_dico["kerr"] + kerr = self._raw_sellmeier["kerr"] n2_0 = kerr["n2"] self._kerr_wl = kerr.get("wavelength", 800e-9) self.chi3_0 = ( @@ -212,18 +214,23 @@ class Gas: Raises ---------- - ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution + ValueError : Since the Van der Waals equation is a cubic one, there could be more than one + real, positive solution """ logger = get_logger(__name__) if pressure == 0: return 0 - a = self.mat_dico.get("a", 0) - b = self.mat_dico.get("b", 0) - pressure = self.mat_dico["sellmeier"].get("P0", 101325) if pressure is None else pressure + a = self._raw_sellmeier.get("a", 0) + b = self._raw_sellmeier.get("b", 0) + pressure = ( + self._raw_sellmeier["sellmeier"].get("P0", 101325) if pressure is None else pressure + ) temperature = ( - self.mat_dico["sellmeier"].get("T0", 273.15) if temperature is None else temperature + self._raw_sellmeier["sellmeier"].get("T0", 273.15) + if temperature is None + else temperature ) ap = a / NA**2 bp = b / NA @@ -302,10 +309,10 @@ class Gas: return Z**3 / (16 * ns**4) * 5.14220670712125e11 def get(self, key, default=None): - return self.mat_dico.get(key, default) + return self._raw_sellmeier.get(key, default) def __getitem__(self, key): - return self.mat_dico[key] + return self._raw_sellmeier[key] def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name!r})" @@ -317,7 +324,7 @@ def n_gas_2(wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature return Sellmeier.load(gas_name).n_gas_2(wl_for_disp, temperature, pressure) -def pressure_from_gradient(ratio, p0, p1): +def pressure_from_gradient(ratio: float, p0: float, p1: float) -> float: """returns the pressure as function of distance with eq. 20 in Markos et al. (2017) Parameters ---------- diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py deleted file mode 100644 index 2c344a7..0000000 --- a/src/scgenerator/scripts/__init__.py +++ /dev/null @@ -1,346 +0,0 @@ -import os -import re -from pathlib import Path -from typing import Any, Iterable, Optional - -import matplotlib.pyplot as plt -import numpy as np -from cycler import cycler -from tqdm import tqdm - -from scgenerator import env, math -from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN -from scgenerator.parameter import FileConfiguration, Parameters -from scgenerator.physics import fiber, units -from scgenerator.plotting import plot_setup, transform_2D_propagation, get_extent -from scgenerator.spectra import SimulationSeries -from scgenerator.utils import _open_config, auto_crop, save_toml, simulations_list, load_toml, load_spectrum - - -def fingerprint(params: Parameters): - h1 = hash(params.field_0.tobytes()) - h2 = tuple(params.beta2_coefficients) - return h1, h2 - - -def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): - for k, v in opts.items(): - if k in ["skip"]: - opts[k] = int(v) - if v == "True": - opts[k] = True - elif v == "False": - opts[k] = False - dir_list = simulations_list(sim_dir) - if len(dir_list) == 0: - dir_list = [sim_dir] - limits = [ - tuple(func(el) for func, el in zip([float, float, str], lim.split(","))) for lim in limits - ] - with tqdm(total=len(dir_list) * max(1, len(limits))) as bar: - for p in dir_list: - pulse = SimulationSeries(p) - if not limits: - limits = [ - ( - pulse.params.interpolation_range[0] * 1e9, - pulse.params.interpolation_range[1] * 1e9, - "nm", - ) - ] - for left, right, unit in limits: - path, fig, ax = plot_setup( - pulse.path.parent - / ( - pulse.path.name - + PARAM_SEPARATOR - + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}" - ) - ) - fig.suptitle(p.name) - pulse.plot_2D( - left, - right, - unit, - ax, - **opts, - ) - bar.update() - if show: - plt.show() - else: - fig.savefig(path, bbox_inches="tight") - plt.close(fig) - - -def plot_init_field_spec( - config_path: Path, - lim_t: tuple[float, float] = None, - lim_l: tuple[float, float] = None, -): - fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7)) - all_labels = [] - already_plotted = set() - for style, lbl, params in plot_helper(config_path): - if (bbb := hash(params.field_0.tobytes())) not in already_plotted: - already_plotted.add(bbb) - else: - continue - - lbl = plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params) - all_labels.append(lbl) - finish_plot(fig, left, right, all_labels, params) - - -def plot_dispersion(config_path: Path, lim: tuple[float, float] = None): - fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7)) - left.grid() - right.grid() - all_labels = [] - already_plotted = set() - loss_ax = None - plt.sca(left) - for style, lbl, params in plot_helper(config_path): - if params.alpha_arr is not None and loss_ax is None: - loss_ax = right.twinx() - if (bbb := tuple(params.beta2_coefficients)) not in already_plotted: - already_plotted.add(bbb) - else: - continue - - lbl = plot_1_dispersion(lim, left, right, style, lbl, params, loss_ax) - all_labels.append(lbl) - finish_plot(fig, right, all_labels, params) - - -def plot_init( - config_path: Path, - lim_field: tuple[float, float] = None, - lim_spec: tuple[float, float] = None, - lim_disp: tuple[float, float] = None, -): - fig, ((tl, tr), (bl, br)) = plt.subplots(2, 2, figsize=(14, 10)) - loss_ax = None - tl.grid() - tr.grid() - all_labels = [] - already_plotted = set() - for style, lbl, params in plot_helper(config_path): - if params.alpha_arr is not None and loss_ax is None: - loss_ax = tr.twinx() - if (fp := fingerprint(params)) not in already_plotted: - already_plotted.add(fp) - else: - continue - lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params, loss_ax) - lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params) - all_labels.append(lbl) - print(params.pretty_str(exclude="beta2_coefficients")) - finish_plot(fig, tr, all_labels, params) - - -def plot_1_init_spec_field( - lim_t: Optional[tuple[float, float]], - lim_l: Optional[tuple[float, float]], - left: plt.Axes, - right: plt.Axes, - style: dict[str, Any], - lbl: str, - params: Parameters, -): - field = math.abs2(params.field_0) - spec = math.abs2(params.spec_0) - t = units.fs.inv(params.t) - wl = units.nm.inv(params.w) - - lbl += f" max at {wl[spec.argmax()]:.1f} nm" - - mt = np.ones_like(t, dtype=bool) - if lim_t is not None: - mt &= t >= lim_t[0] - mt &= t <= lim_t[1] - else: - mt = auto_crop(t, field) - ml = np.ones_like(wl, dtype=bool) - if lim_l is not None: - ml &= wl >= lim_l[0] - ml &= wl <= lim_l[1] - else: - ml = auto_crop(wl, spec) - - left.plot(t[mt], field[mt]) - right.plot(wl[ml], spec[ml], label=" ", **style) - return lbl - - -def plot_1_dispersion( - lim: Optional[tuple[float, float]], - left: plt.Axes, - right: plt.Axes, - style: dict[str, Any], - lbl: list[str], - params: Parameters, - loss: plt.Axes = None, -): - beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients) - wl = units.m.inv(params.w) - D = fiber.beta2_to_D(beta_arr, wl) * 1e6 - - zdw = math.all_zeros(wl, beta_arr) - zdw = zdw[(zdw >= params.interpolation_range[0]) & (zdw <= params.interpolation_range[1])] - if len(zdw) > 0: - zdw = zdw[np.argmin(abs(zdw - params.wavelength))] - lbl += f" ZDW at {zdw*1e9:.1f}nm" - else: - lbl += "" - - m = np.ones_like(wl, dtype=bool) - if lim is None: - lim = params.interpolation_range - m &= wl >= (lim[0] if lim[0] < 1 else lim[0] * 1e-9) - m &= wl <= (lim[1] if lim[1] < 1 else lim[1] * 1e-9) - - info_str = ( - rf"$\lambda_{{\mathrm{{min}}}}={np.min(params.l[params.l>0])*1e9:.1f}$ nm" - + f"\nlower interpolation limit : {params.interpolation_range[0]*1e9:.1f} nm\n" - + f"max time delay : {params.t.max()*1e12:.1f} ps" - ) - - left.annotate( - info_str, - xy=(1, 1), - xytext=(-12, -12), - xycoords="axes fraction", - textcoords="offset points", - va="top", - ha="right", - backgroundcolor=(1, 1, 1, 0.4), - ) - - m = np.argwhere(m)[:, 0] - m = np.array(sorted(m, key=lambda el: wl[el])) - - if len(m) == 0: - raise ValueError(f"nothing to plot in the range {lim!r}") - - # plot D - right.plot(1e9 * wl[m], D[m], label=" ", **style) - right.set_ylabel(units.D_ps_nm_km.label) - - # plot beta2 - left.plot(units.nm.inv(params.w[m]), units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style) - left.set_ylabel(units.beta2_fs_cm.label) - - left.set_xlabel(units.nm.label) - right.set_xlabel("wavelength (nm)") - - if params.alpha_arr is not None and loss is not None: - loss.plot(1e9 * wl[m], params.alpha_arr[m], c="r", ls="--") - loss.set_ylabel("loss (1/m)", color="r") - loss.set_yscale("log") - loss.tick_params(axis="y", labelcolor="r") - - return lbl - - -def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], params: Parameters): - fig.suptitle(params.name) - plt.tight_layout() - - handles, _ = legend_axes.get_legend_handles_labels() - - legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace")) - - out_path = env.output_path() - - show = out_path is None - if not show: - file_name = out_path.stem + ".pdf" - out_path = out_path.parent / file_name - if ( - out_path.exists() - and input(f"{out_path.name} already exsits, overwrite ? (y/[n])\n > ") != "y" - ): - show = True - else: - fig.savefig(out_path, bbox_inches="tight") - if show: - plt.show() - - -def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: - cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) - for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)): - yield style, descriptor.branch.formatted_descriptor(), params - - -def convert_params(params_file: os.PathLike): - p = Path(params_file) - if p.name == PARAM_FN: - d = _open_config(params_file) - save_toml(params_file, d) - print(f"converted {p}") - else: - - - - - - - - for pp in p.glob(PARAM_FN): - convert_params(pp) - for pp in p.glob("fiber*"): - if pp.is_dir(): - convert_params(pp) - - -def partial_plot(root: os.PathLike, lim: str = None): - path = Path(root) - fig, ax = plt.subplots(figsize=(12, 8)) - fig.suptitle(path.name) - spec_list = sorted( - path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0]) - ) - - - - - - params = Parameters(**load_toml(path / "params.toml")) - params.z_targets = params.z_targets[: len(spec_list)] - raw_values = np.array([load_spectrum(s) for s in spec_list]) - if lim is None: - plot_range = units.PlotRange( - 0.5 * params.interpolation_range[0] * 1e9, - 1.1 * params.interpolation_range[1] * 1e9, - "nm", - ) - else: - left_u, right_u, unit = lim.split(",") - plot_range = units.PlotRange(float(left_u), float(right_u), unit) - if plot_range.unit.type == "TIME": - values = params.ifft(raw_values) - log = False - vmin = None - else: - values = raw_values - log = "2D" - vmin = -60 - - x, y, values = transform_2D_propagation( - values, - plot_range, - params, - log=log, - ) - ax.imshow( - values, - origin="lower", - aspect="auto", - vmin=vmin, - interpolation="nearest", - extent=get_extent(x, y), - ) - - return ax diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py deleted file mode 100644 index c1e804e..0000000 --- a/src/scgenerator/scripts/slurm_submit.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import os -import re -import shutil -import subprocess -from datetime import datetime, timedelta -from pathlib import Path -from typing import Tuple - -import numpy as np - -from ..utils import Paths -from ..parameter import FileConfiguration - - -def primes(n): - prime_factors = [] - d = 2 - while d * d <= n: - while (n % d) == 0: - prime_factors.append(d) - n //= d - d += 1 - if n > 1: - prime_factors.append(n) - return prime_factors - - -def balance(n, lim=(32, 32)): - factors = primes(n) - if len(factors) == 1: - factors = primes(n + 1) - a, b, x, y = 1, 1, 1, 1 - while len(factors) > 0 and x <= lim[0] and y <= lim[1]: - a = x - b = y - if y >= x: - x *= factors.pop(0) - else: - y *= factors.pop() - return a, b - - -def distribute( - num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32) -) -> Tuple[int, int]: - if nodes is None and cpus_per_node is None: - balanced = balance(num, lim) - if num > max(lim): - while np.product(balanced) < min(lim): - num += 1 - balanced = balance(num, lim) - nodes = min(balanced) - cpus_per_node = max(balanced) - - elif nodes is None: - nodes = num // cpus_per_node - while nodes > lim[0]: - nodes //= 2 - elif cpus_per_node is None: - cpus_per_node = num // nodes - while cpus_per_node > lim[1]: - cpus_per_node //= 2 - return nodes, cpus_per_node - - -def format_time(t): - try: - t = float(t) - return timedelta(minutes=t) - except ValueError: - return t - - -def create_parser(): - parser = argparse.ArgumentParser(description="submit a job to a slurm cluster") - parser.add_argument("config", help="path to the toml configuration file") - parser.add_argument( - "-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss" - ) - parser.add_argument( - "-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node" - ) - parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required") - parser.add_argument( - "--environment-setup", - required=False, - default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && " - "export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info", - help="commands to run to setup the environement (default : activate the sc environment with conda)", - ) - parser.add_argument( - "--command", default="run", choices=["run", "resume", "merge"], help="command to run" - ) - parser.add_argument("--dependency", default=None, help="sbatch dependency argument") - return parser - - -def copy_starting_files(): - for name in ["start_worker", "start_head"]: - path = Paths.get(name) - file_name = os.path.split(path)[1] - shutil.copy(path, file_name) - mode = os.stat(file_name) - os.chmod(file_name, 0o100 | mode.st_mode) - - -def main(): - - command_map = dict(run="Propagate", resume="Resuming", merge="Merging") - - parser = create_parser() - template = Paths.gets("submit_job_template") - args = parser.parse_args() - - if args.dependency is None: - args.dependency = "" - else: - args.dependency = f"#SBATCH --dependency={args.dependency}" - - if not re.match(r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", args.time) and not re.match( - r"^[0-9]+$", args.time - ): - - raise ValueError( - "time format must be an integer number of minute or must match the pattern hh:mm:ss" - ) - - config = FileConfiguration(args.config) - final_name = config.final_path - sim_num = config.num_sim - - if args.command == "merge": - args.nodes = 1 - args.cpus_per_node = 1 - else: - args.nodes, args.cpus_per_node = distribute(config.num_sim, args.nodes, args.cpus_per_node) - - submit_path = Path( - "submit " + final_name.replace("/", "") + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" - ) - tmp_path = Path("submit tmp.sh") - - job_name = f"supercontinuum {final_name}" - submit_sh = template.format(job_name=job_name, **vars(args)) - - tmp_path.write_text(submit_sh) - subprocess.run(["sbatch", "--test-only", str(tmp_path)]) - submit = input( - f"{command_map[args.command]} {sim_num} pulses from config {args.config} with {args.cpus_per_node} cpus" - + f" per node on {args.nodes} nodes for {format_time(args.time)} ? (y/[n])\n" - ) - if submit.lower() in ["y", "yes"]: - submit_path.write_text(submit_sh) - copy_starting_files() - subprocess.run(["sbatch", submit_path]) - tmp_path.unlink() diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index 3a82dce..987a920 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -6,6 +6,7 @@ from typing import Any, Iterator, Sequence import numba import numpy as np + from scgenerator.math import abs2 from scgenerator.operators import SpecOperator from scgenerator.utils import TimedMessage @@ -133,6 +134,7 @@ def solve43( targets = list(sorted(set(targets))) if targets[0] == 0: targets.pop(0) + h = min(h, targets[0] / 2) step_ind = 0 msg = TimedMessage(2) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index f4e0128..92b5c5a 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -405,6 +405,7 @@ class SimulatedFiber: return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind)) else: return load_spectrum(self.path / SPEC1_FN.format(z_ind)) + psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt) def __repr__(self) -> str: return f"{self.__class__.__name__}(path={self.path})" diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 9b08772..5b0aac5 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -23,15 +23,13 @@ import pkg_resources as pkg import tomli import tomli_w -from scgenerator.const import (PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, - SPEC1_FN, Z_FN) +from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN from scgenerator.errors import DuplicateParameterError from scgenerator.logger import get_logger T_ = TypeVar("T_") - class TimedMessage: def __init__(self, interval: float = 10.0): self.interval = datetime.timedelta(seconds=interval) @@ -179,7 +177,8 @@ def _open_config(path: os.PathLike): return dico -def resolve_relative_paths(d:dict[str, Any], root:os.PathLike | None=None): + +def resolve_relative_paths(d: dict[str, Any], root: os.PathLike | None = None): root = Path(root) if root is not None else Path.cwd() for k, v in d.items(): if isinstance(v, MutableMapping): @@ -192,7 +191,6 @@ def resolve_relative_paths(d:dict[str, Any], root:os.PathLike | None=None): d[k] = str(root / v) - def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: if (f_list := dico.pop("INCLUDE", None)) is not None: if isinstance(f_list, str): diff --git a/src/scgenerator/variationer.py b/src/scgenerator/variationer.py deleted file mode 100644 index ce2558b..0000000 --- a/src/scgenerator/variationer.py +++ /dev/null @@ -1,336 +0,0 @@ -import itertools -from collections.abc import MutableMapping, Sequence -from math import prod -from pathlib import Path -from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union - -import numpy as np -from pydantic import validator -from pydantic.main import BaseModel - -from scgenerator.const import PARAM_SEPARATOR - -T = TypeVar("T") - - -class VariationSpecsError(ValueError): - pass - - -class Variationer: - """ - manages possible combinations of values given dicts of lists - - Example - ------- - ``` - >> var = Variationer([dict(a=[1, 2]), [dict(b=["000", "111"], c=["a", "-1"])]]) - list(v.raw_descr for v in var.iterate()) - - [ - ((("a", 1),), (("b", "000"), ("c", "a"))), - ((("a", 1),), (("b", "111"), ("c", "-1"))), - ((("a", 2),), (("b", "000"), ("c", "a"))), - ((("a", 2),), (("b", "111"), ("c", "-1"))), - ] - ``` - """ - - all_indices: list[list[int]] - all_dicts: list[list[dict[str, list]]] - - def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None): - self.all_indices = [] - self.all_dicts = [] - if variables is not None: - for i, el in enumerate(variables): - self.append(el) - - def append(self, var_list: Union[list[MutableMapping], MutableMapping]): - """append a list of variable parameter sets - each call to append creates a new group of parameters - - Parameters - ---------- - var_list : Union[list[MutableMapping], MutableMapping] - each dict in the list is treated as an independent parameter - this means that if for one dict, len > 1, the lists of possible values - must be the same length - - Example - ------- - `>> append([dict(wavelength=[800e-9, 900e-9], power=[1e3, 2e3]), dict(length=[3e-2, 3.5e-2, 4e-2])])` - - means that for every parameter variations, wavelength=800e-9 will always occur when power=1e3 and - vice versa, while length is free to vary independently - - Raises - ------ - VariationSpecsError - raised when possible values lists in a same dict are not the same length - """ - if not isinstance(var_list, Sequence): - var_list = [{k: v} for k, v in var_list.items()] - else: - var_list = list(var_list) - num_vars = [] - for d in var_list: - values = list(d.values()) - if len(values) > 0: - len_to_test = len(values[0]) - if not all(len(v) == len_to_test for v in values[1:]): - raise VariationSpecsError( - "variable items should all have the same number of parameters" - ) - num_vars.append(len_to_test) - if len(num_vars) == 0: - num_vars = [1] - self.all_indices.append(num_vars) - self.all_dicts.append(var_list) - - def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]: - index = self.__index(index) - flattened_indices = sum(self.all_indices[: index + 1], []) - index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]]) - ranges = [range(i) for i in flattened_indices] - for r in itertools.product(*ranges): - out: list[list[tuple[str, Any]]] = [] - indicies: list[list[int]] = [] - for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])): - out.append([]) - indicies.append([]) - for value_index, var_d in zip(r[start:end], self.all_dicts[i]): - for k, v in var_d.items(): - out[-1].append((k, v[value_index])) - indicies[-1].append(value_index) - yield VariationDescriptor(raw_descr=out, index=indicies) - - def __index(self, index: int) -> int: - if index < 0: - index = len(self.all_indices) + index - return index - - def var_num(self, index: int = -1) -> int: - index = self.__index(index) - return max(1, prod(prod(el) for el in self.all_indices[: index + 1])) - - -class VariationDescriptor(BaseModel): - raw_descr: tuple[tuple[tuple[str, Any], ...], ...] - index: tuple[tuple[int, ...], ...] - separator: str = "fiber" - _format_registry: dict[str, Callable[..., str]] = {} - __ids: dict[int, int] = {} - - @classmethod - def register_formatter(cls, p_name: str, func: Callable[..., str]): - """register a function that formats a particular parameter - - Parameters - ---------- - p_name : str - name of the parameter - func : Callable[..., str] - function that takes as single argument the value of the parameter and returns a string - """ - cls._format_registry[p_name] = func - - @classmethod - def format_value(cls, name: str, value) -> str: - if value is True or value is False: - return str(value) - elif isinstance(value, (float, int)): - try: - return cls._format_registry[name](value) - except KeyError: - return format(value, ".9g") - elif isinstance(value, (list, tuple, np.ndarray)): - return "-".join([str(v) for v in value]) - elif isinstance(value, str): - p = Path(value) - if p.exists(): - return p.stem - elif callable(value): - return getattr(value, "__name__", repr(value)) - return str(value) - - class Config: - allow_mutation = False - - def formatted_descriptor(self, add_identifier=False) -> str: - """formats a variable list into a str such that each simulation has a unique - directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations) - branch identifier can added at the beginning. - - Parameters - ---------- - add_identifier : bool - add unique simulation and parameter-set identifiers - - Returns - ------- - str - simulation descriptor - """ - str_list = [] - - for p_name, p_value in self.flat: - ps, vs = self._format_single_pair(p_name, p_value) - str_list.append(ps + PARAM_SEPARATOR + vs) - tmp_name = PARAM_SEPARATOR.join(str_list) - if not add_identifier: - return tmp_name - return ( - self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name - ) - - def _format_single_pair(self, p_name: str, p_value: Any) -> tuple[str, str]: - ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "") - vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") - return ps, vs - - def __getitem__(self, key) -> "VariationDescriptor": - return VariationDescriptor( - raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator - ) - - def __str__(self) -> str: - return self.formatted_descriptor(add_identifier=False) - - def __lt__(self, other: "VariationDescriptor") -> bool: - return self.raw_descr < other.raw_descr - - def __le__(self, other: "VariationDescriptor") -> bool: - return self.raw_descr <= other.raw_descr - - def __gt__(self, other: "VariationDescriptor") -> bool: - return self.raw_descr > other.raw_descr - - def __ge__(self, other: "VariationDescriptor") -> bool: - return self.raw_descr >= other.raw_descr - - def __eq__(self, other: "VariationDescriptor") -> bool: - return self.raw_descr == other.raw_descr - - def __hash__(self) -> int: - return hash(self.raw_descr) - - def __contains__(self, other: "VariationDescriptor") -> bool: - return all(el in self.raw_descr for el in other.raw_descr) - - def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]: - """updates a dictionary with the value of the descriptor - - Parameters - ---------- - cfg : dict[str, Any] - dict to be updated - index : int, optional - index of the fiber from which to apply the parameters, by default -1 - - Returns - ------- - dict[str, Any] - same as cfg but with key from the descriptor added/updated. - """ - out_cfg = cfg.copy() - out_cfg.pop("variable", None) - return out_cfg | {k: v for k, v in self.raw_descr[index]} - - def iter_parents(self) -> Iterator["VariationDescriptor"]: - if (p := self.parent) is not None: - yield from p.iter_parents() - yield self - - @property - def short(self) -> str: - """shortened description of the simulation""" - return " ".join( - self._format_single_pair(p, v)[1] for p, v in self.flat if p not in {"fiber", "num"} - ) - - @property - def flat(self) -> list[tuple[str, Any]]: - out = [] - for n, variable_list in enumerate(self.raw_descr): - out += [ - (self.separator, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)), - *variable_list, - ] - return out - - @property - def branch(self) -> "BranchDescriptor": - descr: list[list[tuple[str, Any]]] = [] - ind: list[list[int]] = [] - for i, l in enumerate(self.raw_descr): - descr.append([]) - ind.append([]) - for j, (k, v) in enumerate(l): - if k != "num": - descr[-1].append((k, v)) - ind[-1].append(self.index[i][j]) - return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator) - - @property - def identifier(self) -> str: - unique_id = hash(str(self.flat)) - self.__ids.setdefault(unique_id, len(self.__ids)) - return "u_" + str(self.__ids[unique_id]) - - @property - def parent(self) -> Optional["VariationDescriptor"]: - if len(self.raw_descr) < 2: - return None - return VariationDescriptor( - raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator - ) - - @property - def num_fibers(self) -> int: - return len(self.raw_descr) - - -class BranchDescriptor(VariationDescriptor): - __ids: dict[int, int] = {} - - @property - def identifier(self) -> str: - branch_id = hash(str(self.flat)) - self.__ids.setdefault(branch_id, len(self.__ids)) - return "b_" + str(self.__ids[branch_id]) - - @validator("raw_descr") - def validate_raw_descr(cls, v): - return tuple(tuple(el for el in variable if el[0] != "num") for variable in v) - - -class DescriptorDict(Generic[T]): - def __init__(self, dico: dict[VariationDescriptor, T] = None): - self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {} - if dico is not None: - for k, v in dico.items(): - self[k] = v - - def __setitem__(self, key: VariationDescriptor, value: T): - if not isinstance(key, VariationDescriptor): - raise TypeError("key must be a VariationDescriptor instance") - self.dico[key.raw_descr] = (key, value) - - def __getitem__( - self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]] - ) -> T: - if isinstance(key, VariationDescriptor): - return self.dico[key.raw_descr][1] - else: - return self.dico[key][1] - - def items(self) -> Iterator[tuple[VariationDescriptor, T]]: - for k, v in self.dico.items(): - yield k, v[1] - - def keys(self) -> list[VariationDescriptor]: - return [v[0] for v in self.dico.values()] - - def values(self) -> list[T]: - return [v[1] for v in self.dico.values()]