diff --git a/README.md b/README.md
index d5890e5..5405150 100644
--- a/README.md
+++ b/README.md
@@ -141,8 +141,8 @@ hasan :
 capillary_num : int
     number of capillaries
-capillary_outer_d : float, optional if g is specified
-    outer diameter of the capillaries
+capillary_radius : float, optional if g is specified
+    outer radius of the capillaries
 capillary_thickness : float
     thickness of the capillary walls
diff --git a/play.py b/play.py
index 9f687f3..d0cc8a8 100644
--- a/play.py
+++ b/play.py
@@ -4,22 +4,3 @@ import scgenerator as sc
 import matplotlib.pyplot as plt
 from pathlib import Path
 from pprint import pprint
-
-
-def _main():
-    print(os.getcwd())
-    for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"):
-        print(params.fiber_map)
-
-
-def main():
-    drr = os.getcwd()
-    os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
-    try:
-        _main()
-    finally:
-        os.chdir(drr)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py
index 4c9e73d..c5f3af3 100644
--- a/src/scgenerator/__init__.py
+++ b/src/scgenerator/__init__.py
@@ -1,8 +1,25 @@
-from . import math, utils
+from . import math
 from .math import abs2, argclosest, span
 from .physics import fiber, materials, pulse, simulate, units
 from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
-from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
-from .spectra import Pulse, Spectrum
-from .utils import Paths, open_config, parameter
-from .utils.parameter import Configuration, Parameters, PlotRange
+from .plotting import (
+    mean_values_plot,
+    plot_spectrogram,
+    propagation_plot,
+    single_position_plot,
+    transform_2D_propagation,
+    transform_1D_values,
+    transform_mean_values,
+    get_extent,
+)
+from .spectra import Spectrum, SimulationSeries
+from ._utils import Paths, _open_config, parameter, open_single_config
+from ._utils.parameter import Configuration, Parameters
+from ._utils.utils import PlotRange
+from ._utils.legacy import convert_sim_folder
+from ._utils.variationer import (
+    Variationer,
+    VariationDescriptor,
+    VariationSpecsError,
+    DescriptorDict,
+)
diff --git a/src/scgenerator/_utils/__init__.py b/src/scgenerator/_utils/__init__.py
new file mode 100644
index 0000000..b540d9e
--- /dev/null
+++ b/src/scgenerator/_utils/__init__.py
@@ -0,0 +1,325 @@
"""
This file includes utility functions designed more or less to be used specifically with the
scgenerator module, but some functions may be used in any python program

"""

from __future__ import annotations

import itertools
import multiprocessing
import os
import random
import re
import shutil
import threading
from collections import abc
from io import StringIO
from pathlib import Path
from string import printable as str_printable
from functools import cache
from typing import Any, Callable, Generator, Iterable, MutableMapping, Sequence, TypeVar, Union


import numpy as np
import pkg_resources as pkg
import toml
from tqdm import tqdm

from .pbar import PBars
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__
from ..env import pbar_policy
from ..logger import get_logger

T_ = TypeVar("T_")

PathTree = list[tuple[Path, ...]]


class Paths:
    _data_files = [
        "materials.toml",
        "hr_t.npz",
        "submit_job_template.txt",
        "start_worker.sh",
        "start_head.sh",
    ]

    paths = {
        f.split(".")[0]: os.path.abspath(
            pkg.resource_filename("scgenerator", os.path.join("data", f))
        )
        for f in _data_files
    }

    @classmethod
    def get(cls, key):
        if key not in cls.paths:
            if os.path.exists("paths.toml"):
                with open("paths.toml") as file:
                    paths_dico = toml.load(file)
                for k, v in paths_dico.items():
                    cls.paths[k] = v
            if key not in cls.paths:
                get_logger(__name__).info(
                    f"{key} was not found in path index, returning current working directory."
                )
                cls.paths[key] = os.getcwd()
        return cls.paths[key]

    @classmethod
    def gets(cls, key):
        """returns the content of the specified file as a string"""
        with open(cls.get(key)) as file:
            return file.read()

    @classmethod
    def plot(cls, name):
        """returns the path to the specified plot. Used to save new plots
        Example
        -------
        fig.savefig(Paths.plot("figure5.pdf"))
        """
        return os.path.join(cls.get("plots"), name)


def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
    prev_data_dir = Path(prev_data_dir)
    num = find_last_spectrum_num(prev_data_dir)
    return load_spectrum(prev_data_dir / SPEC1_FN.format(num))


@cache
def load_spectrum(folder: os.PathLike) -> np.ndarray:
    return np.load(folder)


def conform_toml_path(path: os.PathLike) -> str:
    path: str = str(path)
    if not path.lower().endswith(".toml"):
        path = path + ".toml"
    return path


def open_single_config(path: os.PathLike) -> dict[str, Any]:
    d = _open_config(path)
    f = d.pop("Fiber")[0]
    return d | f


def _open_config(path: os.PathLike):
    """returns a dictionary parsed from the specified toml file.
    It also handles an 'INCLUDE' entry that fills
    otherwise unspecified keys with what's in the INCLUDE file(s)"""

    path = conform_toml_path(path)
    dico = resolve_loadfile_arg(load_toml(path))

    dico.setdefault("variable", {})
    for key in {"simulation", "fiber", "gas", "pulse"} & dico.keys():
        section = dico.pop(key)
        dico["variable"].update(section.pop("variable", {}))
        dico.update(section)
    if len(dico["variable"]) == 0:
        dico.pop("variable")
    return dico


def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
    if (f_list := dico.pop("INCLUDE", None)) is not None:
        if isinstance(f_list, str):
            f_list = [f_list]
        for to_load in f_list:
            loaded = load_toml(to_load)
            for k, v in loaded.items():
                if k not in dico and k not in dico.get("variable", {}):
                    dico[k] = v
    for k, v in dico.items():
        if isinstance(v, MutableMapping):
            dico[k] = resolve_loadfile_arg(v)
        elif isinstance(v, Sequence):
            for i, vv in enumerate(v):
                if isinstance(vv, MutableMapping):
                    dico[k][i] = resolve_loadfile_arg(vv)
    return dico


def load_toml(descr: os.PathLike) -> dict[str, Any]:
    descr = str(descr)
    if ":" in descr:
        path, entry = descr.split(":", 1)
        with open(path) as file:
            return toml.load(file)[entry]
    else:
        with open(descr) as file:
            return toml.load(file)


def save_toml(path: os.PathLike, dico):
    """saves a dictionary into a toml file"""
    path = conform_toml_path(path)
    with open(path, mode="w") as file:
        toml.dump(dico, file)
    return dico


def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
    """loads a configuration file

    Parameters
    ----------
    path : os.PathLike
        path to the config toml file or a directory containing config files

    Returns
    -------
    final_path : Path
        output name of the simulation
    list[dict[str, Any]]
        one config per fiber

    """
    path = Path(path)
    fiber_list: list[dict[str, Any]]
    if path.name.lower().endswith(".toml"):
        loaded_config = _open_config(path)
        fiber_list = loaded_config.pop("Fiber")
    else:
        loaded_config = dict(name=path.name)
        fiber_list = [_open_config(p) for p in sorted(path.glob("initial_config*.toml"))]

    if len(fiber_list) == 0:
        raise ValueError(f"No fiber in config {path}")
    final_path = loaded_config.get("name")
    configs = []
    for params in fiber_list:
        params.setdefault("variable", {})
        configs.append(loaded_config | params)
    configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
    configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1)))

    return Path(final_path), configs


def load_material_dico(name: str) -> dict[str, Any]:
    """loads a material dictionary
    Parameters
    ----------
    name : str
        name of the material
    Returns
    -------
    material_dico : dict
    """
    return toml.loads(Paths.gets("materials"))[name]


def save_data(data: np.ndarray, data_dir: Path, file_name: str):
    """saves a numpy array to disk

    Parameters
    ----------
    data : np.ndarray
        data to save
    data_dir : Path
        directory the file is saved in
    file_name : str
        file name
    """
    path = data_dir / file_name
    np.save(path, data)
    get_logger(__name__).debug(f"saved data in {path}")


def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Path:
    """ensures a folder exists and doesn't overwrite anything if required

    Parameters
    ----------
    path : Path
        desired path
    prevent_overwrite : bool, optional
        whether to create a new directory when one already exists, by default True
    mkdir : bool, optional
        whether to actually create the directory, by default True

    Returns
    -------
    Path
        final path
    """

    path = path.resolve()

    # is path root ?
    if len(path.parts) < 2:
        return path

    # is a part of path an existing *file* ?
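    # walk the path from the root: whenever an intermediate part turns out to be an
    # existing *file*, divert to an alternative folder name so that the final mkdir
    # never tries to treat a file as a directory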
    parts = path.parts
    path = Path(path.root)
    for part in parts:
        if path.is_file():
            path = ensure_folder(path, mkdir=mkdir, prevent_overwrite=False)
        path /= part

    folder_name = path.name

    for i in itertools.count():
        if not path.is_file() and (not prevent_overwrite or not path.is_dir()):
            if mkdir:
                path.mkdir(exist_ok=True)
            return path
        path = path.parent / (folder_name + f"_{i}")


def branch_id(branch: tuple[Path, ...]) -> str:
    return branch[-1].name.split()[1]


def find_last_spectrum_num(data_dir: Path):
    for num in itertools.count(1):
        p_to_test = data_dir / SPEC1_FN.format(num)
        if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
            return num - 1


def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray:
    threshold = y.min() + rel_thr * (y.max() - y.min())
    above_threshold = y > threshold
    ind = np.argsort(x)
    valid_ind = [
        np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
    ]
    ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
    width = len(ind_above)
    return np.concatenate(
        (
            np.arange(max(ind_above[0] - width, 0), ind_above[0]),
            ind_above,
            np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
        )
    )


def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
    old_names = dict(
        interp_degree="interpolation_degree",
        beta="beta2_coefficients",
        interp_range="interpolation_range",
    )
    deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
    defaults_to_add = dict(repeat=1)
    new = {}
    for k, v in d.items():
        if k == "error_ok":
            new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
        elif k in deleted_names:
            continue
        elif isinstance(v, MutableMapping):
            new[k] = translate_parameters(v)
        else:
            new[old_names.get(k, k)] = v
    return defaults_to_add | new
diff --git a/src/scgenerator/utils/cache.py b/src/scgenerator/_utils/cache.py
similarity index 100%
rename from src/scgenerator/utils/cache.py
rename to src/scgenerator/_utils/cache.py
diff --git a/src/scgenerator/_utils/legacy.py b/src/scgenerator/_utils/legacy.py
new file mode 100644
index 0000000..fbd2c62
--- /dev/null
+++ b/src/scgenerator/_utils/legacy.py
@@ -0,0 +1,99 @@
import os
import sys
from pathlib import Path
from typing import Any, Set

import numpy as np
import toml

from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
from .parameter import Configuration, Parameters
from .utils import fiber_folder, save_parameters
from .pbar import PBars
from .variationer import VariationDescriptor, Variationer


def load_config(path: os.PathLike) -> dict[str, Any]:
    with open(path) as file:
        d = toml.load(file)
    d.setdefault("variable", {})
    return d


def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str, Any]]]:
    paths = sorted(list(Path(path).glob("initial_config*.toml")))
    return paths, [load_config(cfg) for cfg in paths]


def convert_sim_folder(path: os.PathLike):
    path = Path(path).resolve()
    new_root = path.parent / "sc_legacy_converter" / path.name
    os.makedirs(new_root, exist_ok=True)
    config_paths, configs = load_config_sequence(path)
    master_config = dict(name=path.name, Fiber=configs)
    with open(new_root / "initial_config.toml", "w") as f:
        toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
    configuration = Configuration(path, 
final_output_path=new_root) + pbar = PBars(configuration.total_num_steps, "Converting") + + new_paths: dict[VariationDescriptor, Parameters] = dict(configuration) + old_paths: Set[Path] = set() + old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = [] + for descriptor, params in configuration.iterate_single_fiber(-1): + old_path = path / descriptor.branch.formatted_descriptor() + if not Path(old_path).is_dir(): + raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.") + old_paths.add(old_path) + for d in descriptor.iter_parents(): + z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1]) + z_limits = (z_num_start, z_num_start + params.z_num) + old2new.append((old_path, d, new_paths[d], z_limits)) + + processed_paths: Set[Path] = set() + processed_specs: Set[VariationDescriptor] = set() + + for old_path, descr, new_params, (start_z, end_z) in old2new: + move_specs = descr not in processed_specs + processed_specs.add(descr) + if (parent := descr.parent) is not None: + new_params.prev_data_dir = str(new_paths[parent].final_path) + save_parameters(new_params.prepare_for_dump(), new_params.final_path) + for spec_num in range(start_z, end_z): + old_spec = old_path / SPECN_FN1.format(spec_num) + if move_specs: + _mv_specs(pbar, new_params, start_z, spec_num, old_spec) + old_spec.unlink() + if old_path not in processed_paths: + (old_path / PARAM_FN).unlink() + (old_path / Z_FN).unlink() + processed_paths.add(old_path) + + for old_path in processed_paths: + old_path.rmdir() + + for cp in config_paths: + cp.unlink() + + +def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path): + os.makedirs(new_params.final_path, exist_ok=True) + spec_data = np.load(old_spec) + for j, spec1 in enumerate(spec_data): + if j == 0: + np.save(new_params.final_path / SPEC1_FN.format(spec_num - start_z), spec1) + else: + np.save( + new_params.final_path / SPEC1_FN_N.format(spec_num - start_z, j), + spec1, + ) + pbar.update() + + +def main(): + convert_sim_folder(sys.argv[1]) + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/_utils/parameter.py similarity index 67% rename from src/scgenerator/utils/parameter.py rename to src/scgenerator/_utils/parameter.py index ff161ea..039cedd 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/_utils/parameter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime as datetime_module import enum import inspect @@ -10,15 +12,30 @@ from copy import copy, deepcopy from dataclasses import asdict, dataclass, fields from functools import cache, lru_cache from pathlib import Path -from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union +from typing import ( + Any, + Callable, + Generator, + Iterable, + Iterator, + Literal, + Optional, + Sequence, + TypeVar, + Union, +) + import numpy as np from numpy.lib import isin -from .. import math, utils +from .. import _utils as utils +from .. 
import env, math
+from .._utils.variationer import VariationDescriptor, Variationer
 from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
 from ..errors import EvaluatorError, NoDefaultError
 from ..logger import get_logger
 from ..physics import fiber, materials, pulse, units
+from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path

 T = TypeVar("T")
@@ -38,7 +55,7 @@ VALID_VARIABLE = {
     "effective_mode_diameter",
     "core_radius",
     "capillary_num",
-    "capillary_outer_d",
+    "capillary_radius",
     "capillary_thickness",
     "capillary_spacing",
     "capillary_resonance_strengths",
@@ -69,6 +86,7 @@ VALID_VARIABLE = {
     "interpolation_degree",
     "ideal_gas",
     "length",
+    "num",
 }

 MANDATORY_PARAMETERS = [
@@ -256,7 +274,7 @@ class Parameter:
     ----------
     tpe : type
         type of the parameter
-    validators : Callable[[str, Any], None]
+    validator : Callable[[str, Any], None]
         signature : validator(name, value)
         must raise a ValueError when value doesn't fit the criteria checked by
         validator. name is passed to validator to be included in the error message
@@ -290,7 +308,6 @@ class Parameter:
         if isinstance(value, Parameter):
             default = None if self.default is None else copy(self.default)
             instance.__dict__[self.name] = default
-            # instance.__dict__[self.name] = None
         else:
             if value is not None:
                 if self.converter is not None:
@@ -298,7 +315,7 @@ class Parameter:
                 self.validator(self.name, value)
                 instance.__dict__[self.name] = value

-    def display(self, num: float):
+    def display(self, num: float) -> str:
         if self.display_info is None:
             return str(num)
         else:
@@ -309,18 +326,23 @@
         return f"{num_str} {unit}"


-def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]:
-    if isinstance(d, dict):
-        return [(float(k), v) for k, v in d.items()]
-    else:
-        return [(float(k), v) for k, v in d]
+@dataclass
+class _AbstractParameters:
+    @classmethod
+    def __init_subclass__(cls):
+        cls.register_param_formatters()
+
+    @classmethod
+    def register_param_formatters(cls):
+        for k, v in cls.__dict__.items():
+            if isinstance(v, Parameter):
+                VariationDescriptor.register_formatter(k, v.display)


 @dataclass
-class Parameters:
+class Parameters(_AbstractParameters):
     """
-    This class defines each valid parameter's name, type and valid value. Initializing
-    such an obj will automatically compute all possible parameters
+    This class defines each valid parameter's name, type and valid value.
""" # root @@ -352,7 +374,7 @@ class Parameters: ) length: float = Parameter(non_negative(float, int)) capillary_num: int = Parameter(positive(int)) - capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3)) + capillary_radius: float = Parameter(in_range_excl(0, 1e-3)) capillary_thickness: float = Parameter(in_range_excl(0, 1e-3)) capillary_spacing: float = Parameter(in_range_excl(0, 1e-3)) capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[]) @@ -430,15 +452,13 @@ class Parameters: const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) beta_func: Callable[[float], list[float]] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator) - fiber_map: list[tuple[float, str]] = Parameter( - validator_list(type_checker(tuple)), converter=fiber_map_converter - ) + + num: int = Parameter(non_negative(int)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) def prepare_for_dump(self) -> dict[str, Any]: param = asdict(self) - param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])] param = Parameters.strip_params_dict(param) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ @@ -461,7 +481,7 @@ class Parameters: @classmethod def load(cls, path: os.PathLike) -> "Parameters": - return cls(**utils.open_config(path)) + return cls(**utils._open_config(path)) @classmethod def load_and_compute(cls, path: os.PathLike) -> "Parameters": @@ -512,6 +532,12 @@ class Parameters: return out + @property + def final_path(self) -> Path: + if self.output_path is not None: + return Path(update_path(self.output_path)) + return None + class Rule: def __init__( @@ -769,9 +795,12 @@ class Configuration: obj with the output path of the simulation saved in its output_path attribute. """ - master_configs: list[dict[str, Any]] - sim_dirs: list[Path] + fiber_configs: list[dict[str, Any]] + vary_dicts: list[dict[str, list]] + master_config: dict[str, Any] + fiber_paths: list[Path] num_sim: int + num_fibers: int repeat: int z_num: int total_num_steps: int @@ -779,19 +808,17 @@ class Configuration: parallel: bool overwrite: bool final_path: str - all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] - all_configs_list: list[list["Configuration.__SimConfig"]] + all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] @dataclass(frozen=True) class __SimConfig: - vary_list: list[tuple[str, Any]] + descriptor: VariationDescriptor config: dict[str, Any] output_path: Path - index: tuple[tuple[int, ...], ...] 
@property def sim_num(self) -> int: - return len(self.index) + return len(self.descriptor.index) class State(enum.Enum): COMPLETE = enum.auto() @@ -805,57 +832,70 @@ class Configuration: def __init__( self, - final_config_path: os.PathLike, + config_path: os.PathLike, overwrite: bool = True, + wait: bool = False, skip_callback: Callable[[int], None] = None, + final_output_path: os.PathLike = None, ): self.logger = get_logger(__name__) + self.wait = wait - self.master_configs, self.final_path = utils.load_config_sequence(final_config_path) - if self.final_path is None: - self.final_path = Parameters.name.default - self.name = Path(self.final_path).name + self.overwrite = overwrite + self.final_path, self.fiber_configs = utils.load_config_sequence(config_path) + self.final_path = env.get(env.OUTPUT_PATH, self.final_path) + if final_output_path is not None: + self.final_path = final_output_path + self.final_path = utils.ensure_folder( + Path(self.final_path), + mkdir=False, + prevent_overwrite=not self.overwrite, + ) + self.master_config = self.fiber_configs[0].copy() + self.name = self.final_path.name self.z_num = 0 self.total_num_steps = 0 - self.sim_dirs = [] - self.overwrite = overwrite + self.fiber_paths = [] + self.all_configs = {} self.skip_callback = skip_callback - self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2)) - self.repeat = self.master_configs[0].get("repeat", 1) + self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2)) + self.repeat = self.master_config.get("repeat", 1) + self.variationer = Variationer() - names = set() - for i, config in enumerate(self.master_configs): + fiber_names = set() + self.num_fibers = 0 + for i, config in enumerate(self.fiber_configs): + config.setdefault("name", Parameters.name.default) self.z_num += config["z_num"] - config.setdefault("name", f"{Parameters.name.default} {i}") - given_name = config["name"] - fn_i = 0 - while config["name"] in names: - config["name"] = given_name + f"_{fn_i}" - fn_i += 1 - names.add(config["name"]) - - self.sim_dirs.append( + fiber_names.add(config["name"]) + vary_dict = config.pop("variable") + self.variationer.append(vary_dict) + self.fiber_paths.append( utils.ensure_folder( - Path("_".join(["_", self.name, Path(config["name"]).name, "_"])), + self.final_path / fiber_folder(i, self.name, config["name"]), mkdir=False, prevent_overwrite=not self.overwrite, ) ) - self.__validate_variable(config) - self.__compute_sim_dirs() - [Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list] - self.num_sim = len(self.all_configs_list[-1]) + self.__validate_variable(vary_dict) + self.num_fibers += 1 + Evaluator.evaluate_default( + self.__build_base_config() | config | {k: v[0] for k, v in vary_dict.items()}, True + ) + self.num_sim = self.variationer.var_num() self.total_num_steps = sum( - config["z_num"] * len(self.all_configs_list[i]) - for i, config in enumerate(self.master_configs) + config["z_num"] * self.variationer.var_num(i) + for i, config in enumerate(self.fiber_configs) ) - self.final_sim_dir = utils.ensure_folder( - Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite - ) - self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default) + self.parallel = self.master_config.get("parallel", Parameters.parallel.default) - def __validate_variable(self, config: dict[str, Any]): - for k, v in config.get("variable", {}).items(): + def __build_base_config(self): + cfg = 
self.master_config.copy() + vary = cfg.pop("variable", {}) + return cfg | {k: v[0] for k, v in vary.items()} + + def __validate_variable(self, vary_dict: dict[str, list]): + for k, v in vary_dict.items(): p = getattr(Parameters, k) validator_list(p.validator)("variable " + k, v) if k not in VALID_VARIABLE: @@ -863,76 +903,47 @@ class Configuration: if len(v) == 0: raise ValueError(f"variable parameter {k!r} must not be empty") - def __compute_sim_dirs(self): - self.all_configs_dict = {} - self.all_configs_list = [] - self.master_configs[0]["variable"]["num"] = list( - range(self.master_configs[0].get("repeat", 1)) - ) - dp = DataPather([c["variable"] for c in self.master_configs]) - for i, conf in enumerate(self.master_configs): - self.all_configs_list.append([]) - for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i): - this_conf = conf.copy() - if i > 0: - prev_path = utils.ensure_folder( - self.sim_dirs[i - 1] / prev_path, not self.overwrite, False - ) - this_conf["prev_data_dir"] = str(prev_path) + def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]: + for i in range(self.num_fibers): + yield from self.iterate_single_fiber(i) - this_path = utils.ensure_folder( - self.sim_dirs[i] / this_path, not self.overwrite, False - ) - this_conf.pop("variable") - conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf - self.all_configs_dict[sim_index] = self.__SimConfig( - this_vary, conf_to_use, this_path, sim_index - ) - self.all_configs_list[i].append(self.all_configs_dict[sim_index]) - - def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]: - for i, sim_config_list in enumerate(self.all_configs_list): - for sim_config, params in self.__iter_1_sim(sim_config_list): - fiber_map = [] - for j in range(i + 1): - this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config - if j > 0: - prev_conf = self.all_configs_dict[sim_config.index[:j]].config - length = prev_conf["length"] + fiber_map[j - 1][0] - else: - length = 0.0 - fiber_map.append((length, this_conf["name"])) - params.output_path = str(sim_config.output_path) - params.fiber_map = fiber_map - yield sim_config.vary_list, params - - def __iter_1_sim( - self, configs: list["Configuration.__SimConfig"] - ) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]: + def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]: """iterates through the parameters of only one fiber. 
It takes care of recovering partially completed simulations, skipping complete ones and waiting for the previous fiber to finish Parameters ---------- - configs : list[__SimConfig] - list of configuration obj + index : int + which fiber to iterate over Yields ------- __SimConfig configuration obj - Parameters - computed Parameters obj """ - sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs} + if index < 0: + index = self.num_fibers + index + sim_dict: dict[Path, Configuration.__SimConfig] = {} + for descriptor in self.variationer.iterate(index): + cfg = descriptor.update_config(self.fiber_configs[index]) + if index > 0: + cfg["prev_data_dir"] = str( + self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True) + ) + p = utils.ensure_folder( + self.fiber_paths[index] / descriptor.formatted_descriptor(True), + not self.overwrite, + False, + ) + cfg["output_path"] = str(p) + sim_config = self.__SimConfig(descriptor, cfg, p) + sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config while len(sim_dict) > 0: for data_dir, sim_config in sim_dict.items(): task, config_dict = self.__decide(sim_config) if task == self.Action.RUN: sim_dict.pop(data_dir) - p = Parameters(**config_dict) - p.compute() - yield sim_config, p + yield sim_config.descriptor, Parameters(**sim_config.config) if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break @@ -957,12 +968,14 @@ class Configuration: Returns ------- - str : {'run', 'wait', 'skip'} + str : Configuration.Action what to do config_dict : dict[str, Any] config dictionary. The only key possibly modified is 'prev_data_dir', which gets set if the simulation is partially completed """ + if not self.wait: + return self.Action.RUN, sim_config.config out_status, num = self.sim_status(sim_config.output_path, sim_config.config) if out_status == self.State.COMPLETE: return self.Action.SKIP, sim_config.config @@ -999,7 +1012,7 @@ class Configuration: num = utils.find_last_spectrum_num(data_dir) if config_dict is None: try: - config_dict = utils.open_config(data_dir / PARAM_FN) + config_dict = utils._open_config(data_dir / PARAM_FN) except FileNotFoundError: self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}") return self.State.ABSENT, 0 @@ -1013,9 +1026,12 @@ class Configuration: raise ValueError(f"Too many spectra in {data_dir}") def save_parameters(self): - for config, sim_dir in zip(self.master_configs, self.sim_dirs): - os.makedirs(sim_dir, exist_ok=True) - utils.save_toml(sim_dir / f"initial_config.toml", config) + os.makedirs(self.final_path, exist_ok=True) + cfgs = [ + cfg | dict(variable=self.variationer.all_dicts[i]) + for i, cfg in enumerate(self.fiber_configs) + ] + utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs)) @property def first(self) -> Parameters: @@ -1023,336 +1039,6 @@ class Configuration: return param -class DataPather: - def __init__(self, dl: list[dict[str, Any]]): - self.dict_list = dl - - def vary_list_iterator( - self, index: int - ) -> Generator[tuple[tuple[tuple[int, ...]], list[list[tuple[str, Any]]]], None, None]: - """iterates through every possible combination of a list of dict of lists - - Parameters - ---------- - index : int - up to where in the stored dict_list to go - - Yields - ------- - list[list[tuple[str, Any]]] - list of list of (key, value) pairs - - Example - ------- - - self.dict_list = [{a:[56, 57], b:["?", "!"]}, {c:[0, 
-1]}] -> - [ - [[(a, 56), (b, "?")], [(c, 0)]], - [[(a, 56), (b, "?")], [(c, 1)]], - [[(a, 56), (b, "!")], [(c, 0)]], - [[(a, 56), (b, "!")], [(c, 1)]], - [[(a, 57), (b, "?")], [(c, 0)]], - [[(a, 57), (b, "?")], [(c, 1)]], - [[(a, 57), (b, "!")], [(c, 0)]], - [[(a, 57), (b, "!")], [(c, 1)]], - ] - """ - if index < 0: - index = len(self.dict_list) - index - d_tem_list = [el for d in self.dict_list[: index + 1] for el in d.items()] - dict_pos = np.cumsum([0] + [len(d) for d in self.dict_list[: index + 1]]) - ranges = [range(len(l)) for _, l in d_tem_list] - - for r in itertools.product(*ranges): - flat = [(d_tem_list[i][0], d_tem_list[i][1][j]) for i, j in enumerate(r)] - pos = tuple(r) - out = [flat[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:])] - pos = tuple(pos[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:])) - yield pos, out - - def all_vary_list(self, index): - for sim_index, l in self.vary_list_iterator(index): - unique_vary: list[tuple[str, Any]] = [] - for ll in l[: index + 1]: - for pname, pval in ll: - for i, (pn, _) in enumerate(unique_vary): - if pn == pname: - del unique_vary[i] - break - unique_vary.append((pname, pval)) - yield sim_index, format_variable_list( - reduce_all_variable(l[:index]), add_iden=True - ), format_variable_list(reduce_all_variable(l), add_iden=True), unique_vary - - def __repr__(self): - return f"DataPather([{', '.join(repr(d) for d in self.dict_list)}])" - - -@dataclass(frozen=True) -class PlotRange: - left: float = Parameter(type_checker(int, float)) - right: float = Parameter(type_checker(int, float)) - unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit) - conserved_quantity: bool = Parameter(boolean, default=True) - - def __post_init__(self): - if self.left >= self.right: - raise ValueError( - f"left value {self.left!r} must be strictly smaller than right value {self.right!r}" - ) - - def __str__(self): - return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" - - def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - return sort_axis(axis, self) - - def __iter__(self): - yield self.left - yield self.right - yield self.unit.__name__ - - -def sort_axis( - axis: np.ndarray, plt_range: PlotRange -) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - """convert an array according to the given range - - Parameters - ---------- - axis : np.ndarray, shape (n,) - array - plt_range : PlotRange - range to crop in - - Returns - ------- - np.ndarray - new array converted to the desired unit and cropped in the given range - np.ndarray - indices of the concerved values - tuple[float, float] - actual minimum and maximum of the new axis - - Example - ------- - >> sort_axis([18.0, 19.0, 20.0, 13.0, 15.2], PlotRange(1400, 1900, "cm")) - ([1520.0, 1800.0, 1900.0], [4, 0, 1], (1520.0, 1900.0)) - """ - if isinstance(plt_range, tuple): - plt_range = PlotRange(*plt_range) - - masked = np.ma.array(axis, mask=~np.isfinite(axis)) - converted = plt_range.unit.inv(masked) - converted[(converted < plt_range.left) | (converted > plt_range.right)] = np.ma.masked - indices = np.arange(len(axis))[~converted.mask] - cropped = converted.compressed() - order = cropped.argsort() - - return cropped[order], indices[order], (cropped.min(), cropped.max()) - - -def get_arg_names(func: Callable) -> list[str]: - spec = inspect.getfullargspec(func) - args = spec.args - if spec.defaults is not None and len(spec.defaults) > 0: - args = args[: -len(spec.defaults)] - return args - - 
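The removed `DataPather` machinery is superseded by `Variationer` in `_utils/variationer.py` (added below). A minimal sketch of the equivalent iteration under the new API, reusing the values from the old `vary_list_iterator` docstring:

```python
from scgenerator import Variationer

# one entry per fiber; keys passed in a plain dict vary independently
var = Variationer([dict(a=[56, 57], b=["?", "!"]), dict(c=[0, -1])])

for descriptor in var.iterate():
    # raw_descr holds one tuple of (name, value) pairs per fiber, e.g.
    # ((("a", 56), ("b", "?")), (("c", 0),))
    print(descriptor.formatted_descriptor())
```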
-def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - out_func = scope[tmp_name] - out_func.__module__ = "evaluator" - return out_func - - -@cache -def _mock_function(num_args: int, num_returns: int) -> Callable: - if not isinstance(num_args, int) and isinstance(num_returns, int): - raise TypeError(f"num_args and num_returns must be int") - arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) - return_str = ", ".join("True" for _ in range(num_returns)) - func_name = f"__mock_{num_args}_{num_returns}" - func_str = f"def {func_name}({arg_str}):\n return {return_str}" - scope = {} - exec(func_str, scope) - out_func = scope[func_name] - out_func.__module__ = "evaluator" - return out_func - - -def format_variable_list(l: list[tuple[str, Any]], add_iden=False) -> str: - """formats a variable list into a str such that each simulation has a unique - directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations) - branch identifier are added at the beginning. - - Parameters - ---------- - l : list[tuple[str, Any]] - list of variable parameters - add_iden : bool - add unique simulation and parameter-set identifiers - - Returns - ------- - str - directory name - """ - str_list = [] - for p_name, p_value in l: - ps = p_name.replace("/", "").replace(PARAM_SEPARATOR, "") - vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") - str_list.append(ps + PARAM_SEPARATOR + vs) - tmp_name = PARAM_SEPARATOR.join(str_list) - if not add_iden: - return tmp_name - unique_id = unique_identifier(l) - branch_id = branch_identifier(l) - return unique_id + PARAM_SEPARATOR + branch_id + PARAM_SEPARATOR + tmp_name - - -def branch_identifier(l): - branch_id = "b_" + utils.to_62(hash(str([el for el in l if el[0] != "num"]))) - return branch_id - - -def unique_identifier(l): - unique_id = "u_" + utils.to_62(hash(str(l))) - return unique_id - - -def format_value(name: str, value) -> str: - if value is True or value is False: - return str(value) - elif isinstance(value, (float, int)): - try: - return getattr(Parameters, name).display(value) - except AttributeError: - return format(value, ".9g") - elif isinstance(value, (list, tuple, np.ndarray)): - return "-".join([str(v) for v in value]) - elif isinstance(value, str): - p = Path(value) - if p.exists(): - return p.stem - return str(value) - - -def pretty_format_value(name: str, value) -> str: - try: - return getattr(Parameters, name).display(value) - except AttributeError: - return name + PARAM_SEPARATOR + str(value) - - -def pretty_format_from_sim_name(name: str) -> str: - """formats a pretty version of a simulation directory - - Parameters - ---------- - name : str - name of the simulation (directory name) - - Returns - ------- - str - prettier name - """ - s = name.split(PARAM_SEPARATOR) - out = [] - for key, value in zip(s[::2], s[1::2]): - try: - out += [key.replace("_", " 
"), getattr(Parameters, key).display(float(value))] - except (AttributeError, ValueError): - out.append(key + PARAM_SEPARATOR + value) - return PARAM_SEPARATOR.join(out) - - -def variable_iterator( - config: dict[str, Any], first: bool -) -> Generator[tuple[list[tuple[str, Any]], dict[str, Any]], None, None]: - """given a config with "variable" parameters, iterates through every possible combination, - yielding a a list of (parameter_name, value) tuples and a full config dictionary. - - Parameters - ---------- - config : BareConfig - initial config obj - first : int - whether it is the first fiber or not (only the first fiber get a sim number) - - Yields - ------- - Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]] - variable_list : a list of (name, value) tuple of parameter name and value that are variable. - - params : a dict[str, Any] to be fed to Parameters - """ - possible_keys = [] - possible_ranges = [] - - for key, values in config.get("variable", {}).items(): - possible_keys.append(key) - possible_ranges.append(range(len(values))) - - combinations = itertools.product(*possible_ranges) - - master_index = 0 - repeat = config.get("repeat", 1) if first else 1 - for combination in combinations: - indiv_config = {} - variable_list = [] - for i, key in enumerate(possible_keys): - parameter_value = config["variable"][key][combination[i]] - indiv_config[key] = parameter_value - variable_list.append((key, parameter_value)) - param_dict = deepcopy(config) - param_dict.pop("variable") - param_dict.update(indiv_config) - for repeat_index in range(repeat): - # variable_ind = [("id", master_index)] + variable_list - variable_ind = variable_list - if first: - variable_ind += [("num", repeat_index)] - yield variable_ind, param_dict - master_index += 1 - - -def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]: - out = [] - for n, variable_list in enumerate(all_variable): - out += [("fiber", "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)), *variable_list] - return out - - -def strip_vary_list(all_variable: T) -> T: - if len(all_variable) == 0: - return all_variable - elif isinstance(all_variable[0], Sequence) and ( - len(all_variable[0]) == 0 or not isinstance(all_variable[0][0], str) - ): - return [strip_vary_list(el) for el in all_variable] - else: - return [el for el in all_variable if el[0] != "num"] - - default_rules: list[Rule] = [ # Grid *Rule.deduce( @@ -1417,6 +1103,7 @@ default_rules: list[Rule] = [ priorities=[2, 2, 2], ), Rule("hr_w", fiber.delayed_raman_w), + Rule("n_gas_2", materials.n_gas_2), Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")), Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")), Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")), @@ -1426,12 +1113,13 @@ default_rules: list[Rule] = [ ["wl_for_disp", "pitch", "pitch_ratio"], conditions=dict(model="pcf"), ), - Rule("capillary_spacing", fiber.HCARF_gap), + Rule("capillary_spacing", fiber.capillary_spacing_hasan), # Fiber nonlinearity Rule("A_eff", fiber.A_eff_from_V), Rule("A_eff", fiber.A_eff_from_diam), Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")), Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1), + Rule("A_eff", fiber.A_eff_marcatili, priorities=-2), Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]), Rule("A_eff_arr", fiber.load_custom_A_eff), Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1), diff --git a/src/scgenerator/_utils/pbar.py 
b/src/scgenerator/_utils/pbar.py new file mode 100644 index 0000000..37db473 --- /dev/null +++ b/src/scgenerator/_utils/pbar.py @@ -0,0 +1,189 @@ +import multiprocessing +import os +import random +import threading +import typing +from collections import abc +from io import StringIO +from pathlib import Path +from typing import Iterable, Union + +from tqdm import tqdm + +from ..env import pbar_policy + +T_ = typing.TypeVar("T_") + + +class PBars: + def __init__( + self, + task: Union[int, Iterable[T_]], + desc: str, + num_sub_bars: int = 0, + head_kwargs=None, + worker_kwargs=None, + ) -> "PBars": + """creates a PBars obj + + Parameters + ---------- + task : int | Iterable + if int : total length of the main task + if Iterable : behaves like tqdm + desc : str + description of the main task + num_sub_bars : int + number of sub-tasks + + """ + self.id = random.randint(100000, 999999) + try: + self.width = os.get_terminal_size().columns + except OSError: + self.width = 80 + if isinstance(task, abc.Iterable): + self.iterator: Iterable[T_] = iter(task) + self.num_tot: int = len(task) + else: + self.num_tot: int = task + self.iterator = None + + self.policy = pbar_policy() + if head_kwargs is None: + head_kwargs = dict() + if worker_kwargs is None: + worker_kwargs = dict( + total=1, + desc="Worker {worker_id}", + bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", + ) + if "print" not in pbar_policy(): + head_kwargs["file"] = worker_kwargs["file"] = StringIO() + self.width = 80 + head_kwargs["desc"] = desc + self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)] + for i in range(1, num_sub_bars + 1): + kwargs = {k: v for k, v in worker_kwargs.items()} + if "desc" in kwargs: + kwargs["desc"] = kwargs["desc"].format(worker_id=i) + self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs)) + self.print_path = Path( + f"progress {self.pbars[0].desc.replace('/', '')} {self.id}" + ).resolve() + self.close_ev = threading.Event() + if "file" in self.policy: + self.thread = threading.Thread(target=self.print_worker, daemon=True) + self.thread.start() + + def print(self): + if "file" not in self.policy: + return + s = [] + for pbar in self.pbars: + s.append(str(pbar)) + self.print_path.write_text("\n".join(s)) + + def print_worker(self): + while True: + if self.close_ev.wait(2.0): + return + self.print() + + def __iter__(self): + with self as pb: + for thing in self.iterator: + yield thing + pb.update() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __getitem__(self, key): + return self.pbars[key] + + def update(self, i=None, value=1): + if i is None: + for pbar in self.pbars[1:]: + pbar.update(value) + elif i > 0: + self.pbars[i].update(value) + self.pbars[0].update() + + def append(self, pbar: tqdm): + self.pbars.append(pbar) + + def reset(self, i): + self.pbars[i].update(-self.pbars[i].n) + self.print() + + def close(self): + self.print() + self.close_ev.set() + if "file" in self.policy: + self.thread.join() + for pbar in self.pbars: + pbar.close() + + +class ProgressBarActor: + def __init__(self, name: str, num_workers: int, num_steps: int) -> None: + self.counters = [0 for _ in range(num_workers + 1)] + self.p_bars = PBars( + num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") + ) + + def update(self, worker_id: int, rel_pos: float = None) -> None: + """update a counter + + Parameters + ---------- + worker_id : int + id of the worker. 
0 is the overall progress + rel_pos : float, optional + if None, increase the counter by one, if set, will set + the counter to the specified value (instead of incrementing it), by default None + """ + if rel_pos is None: + self.counters[worker_id] += 1 + else: + self.counters[worker_id] = rel_pos + + def update_pbars(self): + for counter, pbar in zip(self.counters, self.p_bars.pbars): + pbar.update(counter - pbar.n) + + def close(self): + self.p_bars.close() + + +def progress_worker( + name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue +): + """keeps track of progress on a separate thread + + Parameters + ---------- + num_steps : int + total number of steps, used for the main progress bar (position 0) + progress_queue : multiprocessing.Queue + values are either + Literal[0] : stop the worker and close the progress bars + tuple[int, float] : worker id and relative progress between 0 and 1 + """ + with PBars( + num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") + ) as pbars: + while True: + raw = progress_queue.get() + if raw == 0: + return + i, rel_pos = raw + if i > 0: + pbars[i].update(rel_pos - pbars[i].n) + pbars[0].update() + elif i == 0: + pbars[0].update(rel_pos) diff --git a/src/scgenerator/_utils/utils.py b/src/scgenerator/_utils/utils.py new file mode 100644 index 0000000..2931f09 --- /dev/null +++ b/src/scgenerator/_utils/utils.py @@ -0,0 +1,260 @@ +import inspect +import os +import re +from collections import defaultdict +from functools import cache +from pathlib import Path +from string import printable as str_printable +from typing import Any, Callable, Iterator, Set + +import numpy as np +import toml +from pydantic import BaseModel + +from .._utils import load_toml, save_toml +from ..const import PARAM_FN, PARAM_SEPARATOR, Z_FN +from ..physics.units import get_unit + + +class HashableBaseModel(BaseModel): + """Pydantic BaseModel that's immutable and can be hashed""" + + def __hash__(self) -> int: + return hash(type(self)) + sum(hash(v) for v in self.__dict__.values()) + + class Config: + allow_mutation = False + + +def to_62(i: int) -> str: + arr = [] + if i == 0: + return "0" + i = abs(i) + while i: + i, value = divmod(i, 62) + arr.append(str_printable[value]) + return "".join(reversed(arr)) + + +class PlotRange(HashableBaseModel): + left: float + right: float + unit: Callable[[float], float] + conserved_quantity: bool = True + + def __init__(self, left, right, unit, **kwargs): + super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs) + + def __str__(self): + return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" + + def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + return sort_axis(axis, self) + + def __iter__(self): + yield self.left + yield self.right + yield self.unit.__name__ + + +def sort_axis( + axis: np.ndarray, plt_range: PlotRange +) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + """ + given an axis, returns this axis cropped according to the given range, converted and sorted + + Parameters + ---------- + axis : 1D array containing the original axis (usual the w or t array) + plt_range : tupple (min, max, conversion_function) used to crop the axis + + Returns + ------- + cropped : the axis cropped, converted and sorted + indices : indices to use to slice and sort other array in the same fashion + extent : tupple with min and max of cropped + + Example + ------- + w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) + t 
= np.linspace(-10, 10, 400) + W, T = np.meshgrid(w, t) + y = np.exp(-W**2 - T**2) + + # Define ranges + rw = (-4, 4, s) + rt = (-2, 6, s) + + w, cw = sort_axis(w, rw) + t, ct = sort_axis(t, rt) + + # slice y according to the given ranges + y = y[ct][:, cw] + """ + if isinstance(plt_range, tuple): + plt_range = PlotRange(*plt_range) + r = np.array((plt_range.left, plt_range.right), dtype="float") + + indices = np.arange(len(axis))[ + (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) + ] + cropped = axis[indices] + order = np.argsort(plt_range.unit.inv(cropped)) + indices = indices[order] + cropped = cropped[order] + out_ax = plt_range.unit.inv(cropped) + + return out_ax, indices, (out_ax[0], out_ax[-1]) + + +def get_arg_names(func: Callable) -> list[str]: + # spec = inspect.getfullargspec(func) + # args = spec.args + # if spec.defaults is not None and len(spec.defaults) > 0: + # args = args[: -len(spec.defaults)] + # return args + return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty] + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + out_func = scope[tmp_name] + out_func.__module__ = "evaluator" + return out_func + + +@cache +def _mock_function(num_args: int, num_returns: int) -> Callable: + arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) + return_str = ", ".join("True" for _ in range(num_returns)) + func_name = f"__mock_{num_args}_{num_returns}" + func_str = f"def {func_name}({arg_str}):\n return {return_str}" + scope = {} + exec(func_str, scope) + out_func = scope[func_name] + out_func.__module__ = "evaluator" + return out_func + + +def combine_simulations(path: Path, dest: Path = None): + """combines raw simulations into one folder per branch + + Parameters + ---------- + path : Path + source of the simulations (must contain u_xx directories) + dest : Path, optional + if given, moves the simulations to dest, by default None + """ + paths: dict[str, list[Path]] = defaultdict(list) + if dest is None: + dest = path + + for p in path.glob("u_*b_*"): + if p.is_dir(): + paths[p.name.split()[1]].append(p) + for l in paths.values(): + l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) + for pulses in paths.values(): + new_path = dest / update_path(pulses[0].name) + os.makedirs(new_path, exist_ok=True) + for num, pulse in enumerate(pulses): + params_ok = False + for file in pulse.glob("*"): + if file.name == PARAM_FN: + if not params_ok: + update_params(new_path, file) + params_ok = True + else: + file.unlink() + elif file.name == Z_FN: + file.rename(new_path / file.name) + elif file.name.startswith("spectr") and num == 0: + file.rename(new_path / file.name) + else: + file.rename(new_path / (file.stem + f"_{num}" + file.suffix)) + pulse.rmdir() + + +def update_params(new_path: Path, file: Path): + params = load_toml(file) + if (p := 
params.get("prev_data_dir")) is not None: + p = Path(p) + params["prev_data_dir"] = str(p.parent / update_path(p.name)) + params["output_path"] = str(new_path) + save_toml(new_path / PARAM_FN, params) + file.unlink() + + +def save_parameters( + params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN +) -> Path: + """saves a parameter dictionary. Note that is does remove some entries, particularly + those that take a lot of space ("t", "w", ...) + + Parameters + ---------- + params : dict[str, Any] + dictionary to save + destination_dir : Path + destination directory + + Returns + ------- + Path + path to newly created the paramter file + """ + file_path = destination_dir / file_name + os.makedirs(file_path.parent, exist_ok=True) + + # save toml of the simulation + with open(file_path, "w") as file: + toml.dump(params, file, encoder=toml.TomlNumpyEncoder()) + + return file_path + + +def update_path(p: str) -> str: + return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p) + + +def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str: + return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name]) + + +def simulations_list(path: os.PathLike) -> list[Path]: + """finds simulations folders contained in a parent directory + + Parameters + ---------- + path : os.PathLike + parent path + + Returns + ------- + list[Path] + Absolute Path to the simulation folder + """ + paths: list[Path] = [] + for pwd, _, files in os.walk(path): + if PARAM_FN in files: + paths.append(Path(pwd)) + paths.sort(key=lambda el: el.parent.name) + return [p for p in paths if p.parent.name == paths[-1].parent.name] diff --git a/src/scgenerator/_utils/variationer.py b/src/scgenerator/_utils/variationer.py new file mode 100644 index 0000000..cbd5e8c --- /dev/null +++ b/src/scgenerator/_utils/variationer.py @@ -0,0 +1,321 @@ +from math import prod +import itertools +from collections.abc import MutableMapping, Sequence +from pathlib import Path +from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union + +import numpy as np +from pydantic import validator +from pydantic.main import BaseModel + +from ..const import PARAM_SEPARATOR +from . 
import utils + +T = TypeVar("T") + + +class VariationSpecsError(ValueError): + pass + + +class Variationer: + """ + manages possible combinations of values given dicts of lists + + Example + ------- + `>> var = Variationer([dict(a=[1, 2]), [dict(b=["000", "111"], c=["a", "-1"])]]) + list(v.raw_descr for v in var.iterate()) + + [ + ((("a", 1),), (("b", "000"), ("c", "a"))), + ((("a", 1),), (("b", "111"), ("c", "-1"))), + ((("a", 2),), (("b", "000"), ("c", "a"))), + ((("a", 2),), (("b", "111"), ("c", "-1"))), + ]` + + """ + + all_indices: list[list[int]] + all_dicts: list[list[dict[str, list]]] + + def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None): + self.all_indices = [] + self.all_dicts = [] + if variables is not None: + for i, el in enumerate(variables): + self.append(el) + + def append(self, var_list: Union[list[MutableMapping], MutableMapping]): + """append a list of variable parameter sets + each call to append creates a new group of parameters + + Parameters + ---------- + var_list : Union[list[MutableMapping], MutableMapping] + each dict in the list is treated as an independent parameter + this means that if for one dict, len > 1, the lists of possible values + must be the same length + + Example + ------- + `>> append([dict(wavelength=[800e-9, 900e-9], power=[1e3, 2e3]), dict(length=[3e-2, 3.5e-2, 4e-2])])` + + means that for every parameter variations, wavelength=800e-9 will always occur when power=1e3 and + vice versa, while length is free to vary independently + + Raises + ------ + VariationSpecsError + raised when possible values lists in a same dict are not the same length + """ + if not isinstance(var_list, Sequence): + var_list = [{k: v} for k, v in var_list.items()] + else: + var_list = list(var_list) + num_vars = [] + for d in var_list: + values = list(d.values()) + len_to_test = len(values[0]) + if not all(len(v) == len_to_test for v in values[1:]): + raise VariationSpecsError( + f"variable items should all have the same number of parameters" + ) + num_vars.append(len_to_test) + if len(num_vars) == 0: + num_vars = [1] + self.all_indices.append(num_vars) + self.all_dicts.append(var_list) + + def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]: + index = self.__index(index) + flattened_indices = sum(self.all_indices[: index + 1], []) + index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]]) + ranges = [range(i) for i in flattened_indices] + for r in itertools.product(*ranges): + out: list[list[tuple[str, Any]]] = [] + indicies: list[list[int]] = [] + for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])): + out.append([]) + indicies.append([]) + for value_index, var_d in zip(r[start:end], self.all_dicts[i]): + for k, v in var_d.items(): + out[-1].append((k, v[value_index])) + indicies[-1].append(value_index) + yield VariationDescriptor(raw_descr=out, index=indicies) + + def __index(self, index: int) -> int: + if index < 0: + index = len(self.all_indices) + index + return index + + def var_num(self, index: int = -1) -> int: + index = self.__index(index) + return max(1, prod(prod(el) for el in self.all_indices[: index + 1])) + + +class VariationDescriptor(BaseModel): + raw_descr: tuple[tuple[tuple[str, Any], ...], ...] + index: tuple[tuple[int, ...], ...] 
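    # raw_descr: one tuple of (parameter name, value) pairs per fiber
    # index: matching per-fiber positions of each value in its "variable" list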
+ separator: str = "fiber" + _format_registry: dict[str, Callable[..., str]] = {} + __ids: dict[int, int] = {} + + @classmethod + def register_formatter(cls, p_name: str, func: Callable[..., str]): + """register a function that formats a particular parameter + + Parameters + ---------- + p_name : str + name of the parameter + func : Callable[..., str] + function that takes as single argument the value of the parameter and returns a string + """ + cls._format_registry[p_name] = func + + class Config: + allow_mutation = False + + def formatted_descriptor(self, add_identifier=False) -> str: + """formats a variable list into a str such that each simulation has a unique + directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations) + branch identifier can added at the beginning. + + Parameters + ---------- + add_identifier : bool + add unique simulation and parameter-set identifiers + + Returns + ------- + str + simulation descriptor + """ + str_list = [] + + for p_name, p_value in self.flat: + ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "") + vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") + str_list.append(ps + PARAM_SEPARATOR + vs) + tmp_name = PARAM_SEPARATOR.join(str_list) + if not add_identifier: + return tmp_name + return ( + self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name + ) + + def format_value(self, name: str, value) -> str: + if value is True or value is False: + return str(value) + elif isinstance(value, (float, int)): + try: + return self._format_registry[name](value) + except KeyError: + return format(value, ".9g") + elif isinstance(value, (list, tuple, np.ndarray)): + return "-".join([str(v) for v in value]) + elif isinstance(value, str): + p = Path(value) + if p.exists(): + return p.stem + return str(value) + + def __getitem__(self, key) -> "VariationDescriptor": + return VariationDescriptor( + raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator + ) + + def __str__(self) -> str: + return self.formatted_descriptor(add_identifier=False) + + def __lt__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr < other.raw_descr + + def __le__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr <= other.raw_descr + + def __gt__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr > other.raw_descr + + def __ge__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr >= other.raw_descr + + def __eq__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr == other.raw_descr + + def __hash__(self) -> int: + return hash(self.raw_descr) + + def __contains__(self, other: "VariationDescriptor") -> bool: + return all(el in self.raw_descr for el in other.raw_descr) + + def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]: + """updates a dictionary with the value of the descriptor + + Parameters + ---------- + cfg : dict[str, Any] + dict to be updated + index : int, optional + index of the fiber from which to apply the parameters, by default -1 + + Returns + ------- + dict[str, Any] + same as cfg but with key from the descriptor added/updated. 
+ """ + out_cfg = cfg.copy() + out_cfg.pop("variable", None) + return out_cfg | {k: v for k, v in self.raw_descr[index]} + + def iter_parents(self) -> Iterator["VariationDescriptor"]: + if (p := self.parent) is not None: + yield from p.iter_parents() + yield self + + @property + def flat(self) -> list[tuple[str, Any]]: + out = [] + for n, variable_list in enumerate(self.raw_descr): + out += [ + (self.separator, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)), + *variable_list, + ] + return out + + @property + def branch(self) -> "BranchDescriptor": + descr: list[list[tuple[str, Any]]] = [] + ind: list[list[int]] = [] + for i, l in enumerate(self.raw_descr): + descr.append([]) + ind.append([]) + for j, (k, v) in enumerate(l): + if k != "num": + descr[-1].append((k, v)) + ind[-1].append(self.index[i][j]) + return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator) + + @property + def identifier(self) -> str: + unique_id = hash(str(self.flat)) + self.__ids.setdefault(unique_id, len(self.__ids)) + return "u_" + str(self.__ids[unique_id]) + + @property + def parent(self) -> Optional["VariationDescriptor"]: + if len(self.raw_descr) < 2: + return None + return VariationDescriptor( + raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator + ) + + @property + def num_fibers(self) -> int: + return len(self.raw_descr) + + +class BranchDescriptor(VariationDescriptor): + __ids: dict[int, int] = {} + + @property + def identifier(self) -> str: + branch_id = hash(str(self.flat)) + self.__ids.setdefault(branch_id, len(self.__ids)) + return "b_" + str(self.__ids[branch_id]) + + @validator("raw_descr") + def validate_raw_descr(cls, v): + return tuple(tuple(el for el in variable if el[0] != "num") for variable in v) + + +class DescriptorDict(Generic[T]): + def __init__(self, dico: dict[VariationDescriptor, T] = None): + self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {} + if dico is not None: + for k, v in dico.items(): + self[k] = v + + def __setitem__(self, key: VariationDescriptor, value: T): + if not isinstance(key, VariationDescriptor): + raise TypeError("key must be a VariationDescriptor instance") + self.dico[key.raw_descr] = (key, value) + + def __getitem__( + self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]] + ) -> T: + if isinstance(key, VariationDescriptor): + return self.dico[key.raw_descr][1] + else: + return self.dico[key][1] + + def items(self) -> Iterator[tuple[VariationDescriptor, T]]: + for k, v in self.dico.items(): + yield k, v[1] + + def keys(self) -> list[VariationDescriptor]: + return [v[0] for v in self.dico.values()] + + def values(self) -> list[T]: + return [v[1] for v in self.dico.values()] diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 2532fdf..92b81ab 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -1,14 +1,13 @@ import argparse import os import re -import subprocess -import sys from collections import ChainMap from pathlib import Path import numpy as np -from .. import const, env, scripts, utils +from .. import const, env, scripts +from .. 
import _utils as utils from ..logger import get_logger from ..physics.fiber import dispersion_coefficients from ..physics.simulate import SequencialSimulations, run_simulation diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index c679236..b8843a2 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -20,7 +20,8 @@ def pbar_format(worker_id: int): SPEC1_FN = "spectrum_{}.npy" -SPECN_FN = "spectra_{}.npy" +SPECN_FN1 = "spectra_{}.npy" +SPEC1_FN_N = "spectrum_{}_{}.npy" Z_FN = "z.npy" PARAM_FN = "params.toml" PARAM_SEPARATOR = " " diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 46c0d6e..13f3545 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]: return tmp -def get(key: str) -> Any: +def get(key: str, default=None) -> Any: str_value = os.environ.get(key) if isinstance(str_value, str): try: @@ -58,7 +58,7 @@ def get(key: str) -> Any: return t(str_value) except (ValueError, KeyError): pass - return None + return default def all_environ() -> Dict[str, str]: diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 20db3dc..505c6b7 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -3,7 +3,7 @@ from typing import Union import numpy as np from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros -from .utils.cache import np_cache +from ._utils.cache import np_cache pi = np.pi c = 299792458.0 diff --git a/src/scgenerator/physics/__init__.py b/src/scgenerator/physics/__init__.py index 36709a8..4048118 100644 --- a/src/scgenerator/physics/__init__.py +++ b/src/scgenerator/physics/__init__.py @@ -10,8 +10,7 @@ from scipy.optimize import minimize_scalar from .. import math from . import fiber, materials, units, pulse -from .. import utils -from ..utils import cache +from .._utils import cache, load_material_dico T = TypeVar("T") @@ -62,7 +61,7 @@ def material_dispersion( ) return disp else: - material_dico = utils.load_material_dico(material) + material_dico = load_material_dico(material) if ideal: n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1 else: diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 07653c8..eaa6370 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -8,9 +8,9 @@ from scipy.interpolate import interp1d from ..logger import get_logger -from .. import utils +from .. import _utils as utils from ..math import abs2, argclosest, power_fact, u_nm -from ..utils.cache import np_cache +from .._utils.cache import np_cache from . import materials as mat from . import units from .units import c, pi @@ -49,27 +49,6 @@ def is_dynamic_dispersion(pressure=None): return out -def HCARF_gap(core_radius: float, capillary_num: int, capillary_outer_d: float): - """computes the gap length between capillaries of a hollow core anti-resonance fiber - - Parameters - ---------- - core_radius : float - radius of the core (m) (from cented to edge of a capillary) - capillary_num : int - number of capillaries - capillary_outer_d : float - diameter of the capillaries including the wall thickness(m). 
The core together with the microstructure has a diameter of 2R + 2d
-
-    Returns
-    -------
-    gap : float
-    """
-    return (core_radius + capillary_outer_d / 2) * 2 * np.sin(
-        pi / capillary_num
-    ) - capillary_outer_d
-
-
 def gvd_from_n_eff(n_eff: np.ndarray, wl_for_disp: np.ndarray):
     """computes the dispersion parameter D from an effective index of refraction n_eff
     Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is
@@ -193,6 +172,30 @@ def n_eff_marcatili_adjusted(wl_for_disp, n_gas_2, core_radius, he_mode=(1, 1),
     return np.sqrt(n_gas_2 - (wl_for_disp * u / (pipi * corrected_radius)) ** 2)
 
 
+def A_eff_marcatili(core_radius: float) -> float:
+    """Effective mode-field area for the fundamental mode of hollow capillaries
+
+    Parameters
+    ----------
+    core_radius : float
+        radius of the core
+
+    Returns
+    -------
+    float
+        effective mode field area
+    """
+    return 1.5 * core_radius ** 2
+
+
+def capillary_spacing_hasan(
+    capillary_num: int, capillary_radius: float, core_radius: float
+) -> float:
+    """computes the gap between the capillaries of a hollow-core anti-resonant fiber (Hasan model)"""
+    return (
+        2 * (capillary_radius + core_radius) * np.sin(np.pi / capillary_num) - 2 * capillary_radius
+    )
+
+
 @np_cache
 def n_eff_hasan(
     wl_for_disp: np.ndarray,
diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py
index 812e60e..1adb5c4 100644
--- a/src/scgenerator/physics/materials.py
+++ b/src/scgenerator/physics/materials.py
@@ -5,14 +5,14 @@ from scipy.integrate import cumulative_trapezoid
 
 from ..logger import get_logger
 from . import units
-from .. import utils
+from .. import _utils
 from .units import NA, c, kB, me, e, hbar
 
 
 def n_gas_2(
-    wl_for_disp: np.ndarray, gas: str, pressure: float, temperature: float, ideal_gas: bool
+    wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature: float, ideal_gas: bool
 ):
-    material_dico = utils.load_material_dico(gas)
+    material_dico = _utils.load_material_dico(gas_name)
 
     if ideal_gas:
         n_gas_2 = sellmeier(wl_for_disp, material_dico, pressure, temperature) + 1
diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py
index ae570f1..f9e5102 100644
--- a/src/scgenerator/physics/pulse.py
+++ b/src/scgenerator/physics/pulse.py
@@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline
 from scipy.optimize import minimize_scalar
 from scipy.optimize.optimize import OptimizeResult
 
-from scgenerator import utils
-
 from ..defaults import default_plotting
 from ..logger import get_logger
 from ..math import *
@@ -820,7 +818,8 @@ def find_lobe_limits(x_axis, values, debug="", already_sorted=True):
             )
             ax.legend()
             fig.savefig(out_path, bbox_inches="tight")
-            plt.close()
+            if fig is not None:
+                plt.close(fig)
     else:
         good_roots, left_lim, right_lim = _select_roots(d_spline, d_roots, dd_roots, fwhm_pos)
 
diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py
index dae90b7..39bd3c4 100644
--- a/src/scgenerator/physics/simulate.py
+++ b/src/scgenerator/physics/simulate.py
@@ -9,11 +9,15 @@ from typing import Any, Generator, Type, Union
 import numpy as np
 from send2trash import send2trash
 
-from .. import env, utils
+from .. import env
+from .. import _utils as utils
+from .._utils.utils import combine_simulations, save_parameters
 from ..logger import get_logger
-from ..utils.parameter import Configuration, Parameters, format_variable_list
+from .._utils.parameter import Configuration, Parameters
+from .._utils.pbar import PBars, ProgressBarActor, progress_worker
+from . import pulse
 from .fiber import create_non_linear_op, fast_dispersion_op
+from scgenerator._utils import pbar
 
 try:
     import ray
@@ -215,6 +219,17 @@ class RK4IP:
         return self.stored_spectra
 
     def irun(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
+        """run the simulation as a generator object
+
+        Yields
+        -------
+        int
+            current simulation step
+        int
+            current number of spectra returned
+        np.ndarray
+            spectrum
+        """
 
         # Print introduction
         self.logger.debug(
@@ -332,7 +347,7 @@ class SequentialRK4IP(RK4IP):
     def __init__(
         self,
         params: Parameters,
-        pbars: utils.PBars,
+        pbars: PBars,
         save_data=False,
         job_identifier="",
         task_id=0,
@@ -466,14 +481,14 @@ class Simulations:
 
         self.configuration = configuration
 
-        self.name = self.configuration.final_path
-        self.sim_dir = self.configuration.final_sim_dir
+        self.name = self.configuration.name
+        self.sim_dir = self.configuration.final_path
 
         self.configuration.save_parameters()
 
         self.sim_jobs_per_node = 1
 
     def finished_and_complete(self):
-        for sim in self.configuration.all_configs_dict.values():
+        for sim in self.configuration.all_configs.values():
             if (
                 self.configuration.sim_status(sim.output_path)[0]
                 != self.configuration.State.COMPLETE
@@ -487,8 +502,9 @@ class Simulations:
 
     def _run_available(self):
         for variable, params in self.configuration:
-            v_list_str = format_variable_list(variable, add_iden=True)
-            utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
+            params.compute()
+            v_list_str = variable.formatted_descriptor(True)
+            save_parameters(params.prepare_for_dump(), Path(params.output_path))
 
             self.new_sim(v_list_str, params)
         self.finish()
@@ -525,8 +541,10 @@ class SequencialSimulations(Simulations, priority=0):
     def __init__(self, configuration: Configuration, task_id):
         super().__init__(configuration, task_id=task_id)
-        self.pbars = utils.PBars(
-            self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1
+        self.pbars = PBars(
+            self.configuration.total_num_steps,
+            "Simulating " + self.configuration.final_path.name,
+            1,
         )
         self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
@@ -567,7 +585,7 @@ class MultiProcSimulations(Simulations, priority=1):
             for i in range(self.sim_jobs_per_node)
         ]
         self.p_worker = multiprocessing.Process(
-            target=utils.progress_worker,
+            target=progress_worker,
             args=(
                 Path(self.configuration.final_path).name,
                 self.sim_jobs_per_node,
@@ -656,7 +674,7 @@ class RaySimulations(Simulations, priority=2):
         self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total))
         self.num_submitted = 0
         self.rolling_id = 0
-        self.p_actor = ray.remote(utils.ProgressBarActor).remote(
+        self.p_actor = ray.remote(ProgressBarActor).remote(
            self.configuration.final_path, self.sim_jobs_total, self.configuration.total_num_steps
         )
         self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num))
@@ -712,21 +730,13 @@ def run_simulation(
     config_file: os.PathLike,
     method: Union[str, Type[Simulations]] = None,
 ):
-    config = Configuration(config_file)
+    config = Configuration(config_file, wait=True)
 
     sim = new_simulation(config, method)
     sim.run()
-    path_trees = utils.build_path_trees(config.sim_dirs[-1])
 
-    final_name = env.get(env.OUTPUT_PATH)
-    if final_name is None:
-        final_name = config.final_path
-
-    utils.merge(final_name, path_trees)
-    try:
-        send2trash(config.sim_dirs)
-    except (PermissionError, OSError):
-        get_logger(__name__).error("Could not send temporary directories to trash")
+    for path in config.fiber_paths:
+        
combine_simulations(path) def new_simulation( @@ -762,6 +772,8 @@ def parallel_RK4IP( ]: logger = get_logger(__name__) params = list(Configuration(config)) + for _, param in params: + param.compute() n = len(params) z_num = params[0][1].z_num diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 80089fe..f648bb9 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -2,7 +2,6 @@ # For example, nm(X) means "I give the number X in nm, figure out the ang. freq." # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... -from dataclasses import dataclass from typing import Callable, TypeVar, Union import numpy as np diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index fbaa903..6163520 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR from .defaults import default_plotting as defaults from .math import abs2, span from .physics import pulse, units -from .utils.parameter import Parameters, PlotRange, sort_axis +from ._utils.parameter import Parameters +from ._utils.utils import PlotRange, sort_axis RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index ede8211..c740781 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -11,14 +11,13 @@ from .. import env, math from ..const import PARAM_FN, PARAM_SEPARATOR from ..physics import fiber, units from ..plotting import plot_setup -from ..spectra import Pulse -from ..utils import auto_crop, open_config, save_toml, translate_parameters -from ..utils.parameter import ( +from ..spectra import SimulationSeries +from .._utils import auto_crop, _open_config, save_toml, translate_parameters +from .._utils.parameter import ( Configuration, Parameters, - pretty_format_from_sim_name, - pretty_format_value, ) +from .._utils.utils import simulations_list def fingerprint(params: Parameters): @@ -33,7 +32,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): opts[k] = int(v) if k in {"log", "renormalize"}: opts[k] = True if v == "True" else False - dir_list = list(p for p in sim_dir.glob("*") if p.is_dir()) + dir_list = simulations_list(sim_dir) if len(dir_list) == 0: dir_list = [sim_dir] limits = [ @@ -41,12 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): ] with tqdm(total=len(dir_list) * len(limits)) as bar: for p in dir_list: - pulse = Pulse(p) + pulse = SimulationSeries(p) for left, right, unit in limits: path, fig, ax = plot_setup( pulse.path.parent / ( - pretty_format_from_sim_name(pulse.path.name) + pulse.path.name + PARAM_SEPARATOR + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}" ) @@ -259,7 +258,7 @@ def finish_plot(fig, legend_axes, all_labels, params): def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]: cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) - pseq = Configuration(open_config(config_path)) + pseq = Configuration(_open_config(config_path)) for style, (variables, params) in zip(cc, pseq): lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] yield style, lbl, params @@ -268,7 +267,7 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters def convert_params(params_file: os.PathLike): p = Path(params_file) if p.name == PARAM_FN: - d = 
open_config(params_file) + d = _open_config(params_file) d = translate_parameters(d) save_toml(params_file, d) print(f"converted {p}") diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index a977f35..3fc4b6b 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -9,8 +9,8 @@ from typing import Tuple import numpy as np -from ..utils import Paths -from ..utils.parameter import Configuration +from .._utils import Paths +from .._utils.parameter import Configuration def primes(n): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 1ec018c..fe9279d 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import os -from collections.abc import Sequence from pathlib import Path -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Iterator, Optional, Union import matplotlib.pyplot as plt import numpy as np from . import math -from .const import SPECN_FN +from ._utils import load_spectrum +from ._utils.parameter import Parameters +from ._utils.utils import PlotRange, simulations_list +from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N from .logger import get_logger from .physics import pulse, units from .plotting import ( @@ -16,8 +20,6 @@ from .plotting import ( single_position_plot, transform_2D_propagation, ) -from .utils.parameter import Parameters, PlotRange -from .utils import load_spectrum class Spectrum(np.ndarray): @@ -42,18 +44,6 @@ class Spectrum(np.ndarray): def __getitem__(self, key) -> "Spectrum": return super().__getitem__(key) - def energy(self) -> Union[np.ndarray, float]: - if self.ndim == 1: - m = np.argwhere(self.params.l > 0)[:, 0] - m = np.array(sorted(m, key=lambda el: self.params.l[el])) - return np.trapz(self.wl_int[m], self.params.l[m]) - else: - return np.array([s.energy() for s in self]) - - def crop_wl(self, left: float, right: float) -> np.ndarray: - cond = (self.params.l >= left) & (self.params.l <= right) - return cond - @property def wl_int(self): return units.to_WL(math.abs2(self), self.params.l) @@ -118,7 +108,7 @@ class Spectrum(np.ndarray): return self.params.l[np.argmax(self.wl_int, axis=-1)] return np.array([s.wl_max for s in self]) - def mask_wl(self, pos: float, width: float) -> "Spectrum": + def mask_wl(self, pos: float, width: float) -> Spectrum: return self * np.exp( -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2) ) @@ -127,189 +117,105 @@ class Spectrum(np.ndarray): return pulse.measure_field(self.params.t, self.time_amp) -class Pulse(Sequence): - def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): - """load a data folder as a pulse +class SimulationSeries: + path: Path + params: Parameters + total_length: float + total_num_steps: int + previous: SimulationSeries = None + fiber_lengths: list[tuple[str, float]] + fiber_positions: list[tuple[str, float]] + z_inds: np.ndarray - Parameters - ---------- - path : os.PathLike - path to the data (folder containing .npy files) - default_ind : int | Iterable[int], optional - default indices to be loaded, by default None - - Raises - ------ - FileNotFoundError - path does not contain proper data - """ - self.logger = get_logger(__name__) - self.path = Path(path) - self.default_ind = default_ind - - if not self.path.is_dir(): - raise FileNotFoundError(f"Folder {self.path} does not exist") - - self.params = 
Parameters.load(self.path / "params.toml") + def __init__(self, path: os.PathLike): + self.logger = get_logger() + for self.path in simulations_list(path): + break + else: + raise FileNotFoundError(f"No simulation in {path}") + self.params = Parameters.load(self.path / PARAM_FN) self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) - if self.params.fiber_map is None: - self.params.fiber_map = [(0.0, self.params.name)] - - try: - self.z = np.load(os.path.join(path, "z.npy")) - except FileNotFoundError: - if self.params is not None: - self.z = self.params.z_targets - else: - raise - self.nmax = len(list(self.path.glob("spectra_*.npy"))) - if self.nmax <= 0: - raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") - self.t = self.params.t - w = math.wspace(self.t) + units.m(self.params.wavelength) - self.w_order = np.argsort(w) - self.w = w - self.wl = units.m.inv(self.w) - self.params.w = self.w - self.params.z_targets = self.z + self.w = self.params.w + if self.params.prev_data_dir is not None: + self.previous = SimulationSeries(self.params.prev_data_dir) + self.total_length = self.accumulate_params("length") + self.total_num_steps = self.accumulate_params("z_num") + self.z_inds = np.arange(len(self.params.z_targets)) + self.z = self.params.z_targets + if self.previous is not None: + self.z += self.previous.params.z_targets[-1] + self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets)) + self.z_inds += self.previous.z_inds[-1] + 1 + self.fiber_lengths = self.all_params("length") + self.fiber_positions = [ + (this[0], following[1]) + for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths) + ] - def __iter__(self): - """ - similar to all_spectra but works as an iterator - """ - - self.logger.debug(f"iterating through {self.path}") - for i in range(self.nmax): - yield self._load1(i) - - def __len__(self): - return self.nmax - - def __getitem__(self, key) -> Spectrum: - return self.all_spectra(key) - - def intensity(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_int, - FREQ=self._to_freq_int, - AFREQ=self._to_afreq_int, - TIME=self._to_time_int, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_int(self, spectrum): - return units.to_WL(math.abs2(spectrum), spectrum.wl) - - def _to_freq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_afreq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_time_int(self, spectrum): - return math.abs2(np.fft.ifft(spectrum)) - - def amplitude(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_amp, - FREQ=self._to_freq_amp, - AFREQ=self._to_afreq_amp, - TIME=self._to_time_amp, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_amp(self, spectrum): - return ( - np.sqrt( - units.to_WL( - math.abs2(spectrum), - spectrum.wl, - ) - ) - * spectrum - / np.abs(spectrum) - ) - - def _to_freq_amp(self, spectrum): - return spectrum - - def _to_afreq_amp(self, spectrum): - return spectrum - - def _to_time_amp(self, spectrum): - return np.fft.ifft(spectrum) - - def all_spectra(self, ind=None) -> Spectrum: - """ - loads the data already simulated. 
- defauft shape is (z_targets, n, nt) + def all_params(self, key: str) -> list[tuple[str, Any]]: + """returns the value of a parameter for each fiber Parameters ---------- - ind : int or list of int - if only certain spectra are desired + key : str + name of the parameter + Returns - ---------- - spectra : array of shape (nz, m, nt) - array of complex spectra (pulse at nz positions consisting - of nm simulation on a nt size grid) + ------- + list[tuple[str, Any]] + list of (fiber_name, param_value) tuples """ + return list(reversed(self._all_params(key, []))) - self.logger.debug(f"opening {self.path}") + def accumulate_params(self, key: str) -> Any: + """returns the sum of all the values a parameter takes. Useful to + get the total length of the fiber, the total number of steps, etc. - # Check if file exists and assert how many z positions there are + Parameters + ---------- + key : str + name of the parameter - if ind is None: - if self.default_ind is None: - ind = range(self.nmax) - else: - ind = self.default_ind - if isinstance(ind, (int, np.integer)): - ind = [ind] - elif isinstance(ind, (float, np.floating)): - ind = [self.z_ind(ind)] - elif isinstance(ind[0], (float, np.floating)): - ind = [self.z_ind(ii) for ii in ind] + Returns + ------- + Any + final sum + """ + return sum(el[1] for el in self.all_params(key)) - # Load the spectra - spectra = [] - for i in ind: - spectra.append(self._load1(i)) - spectra = Spectrum(spectra, self.params) - - self.logger.debug(f"all spectra from {self.path} successfully loaded") - if len(ind) == 1: - return spectra[0] + def spectra( + self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 + ) -> Spectrum: + if z_descr is None: + out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)] else: - return spectra + if isinstance(z_descr, (float, np.floating)): + if self.z[0] <= z_descr <= self.z[-1]: + z_ind = self.z_inds[np.argmin(np.abs(self.z - z_descr))] + elif 0 <= z_descr < self.z[0]: + return self.previous.spectra(z_descr, sim_ind) + else: + raise ValueError( + f"cannot match z={z_descr} with max length of {self.total_length}" + ) + else: + z_ind = z_descr - def all_fields(self, ind=None): - return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) + if z_ind < self.z_inds[0]: + return self.previous.spectra(z_ind, sim_ind) + if sim_ind is None: + out = [self._load_1(z_ind, i) for i in range(self.params.repeat)] + else: + out = self._load_1(z_ind) + return Spectrum(out, self.params) - def _load1(self, i: int): - if i < 0: - i = self.nmax + i - spec = load_spectrum(self.path / SPECN_FN.format(i)) - spec = np.atleast_2d(spec) - spec = Spectrum(spec, self.params) - return spec + def fields( + self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 + ) -> Spectrum: + return np.fft.ifft(self.spectra(z_descr, sim_ind)) + + # Plotting def plot_2D( self, @@ -317,12 +223,11 @@ class Pulse(Sequence): right: float, unit: Union[Callable[[float], float], str], ax: plt.Axes, - z_pos: Union[int, Iterable[int]] = None, sim_ind: int = 0, **kwargs, ): plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind) + vals = self.retrieve_plot_values(plot_range, None, sim_ind) return propagation_plot(vals, plot_range, self.params, ax, **kwargs) def plot_1D( @@ -349,7 +254,7 @@ class Pulse(Sequence): **kwargs, ): plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, slice(None)) + vals = self.retrieve_plot_values(plot_range, z_pos, None) 
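+        # sim_ind=None: load every repeat simulation at each position (see SimulationSeries.spectra)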
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs) def retrieve_plot_values( @@ -357,16 +262,9 @@ class Pulse(Sequence): ): if plot_range.unit.type == "TIME": - vals = self.all_fields(ind=z_pos) + return self.fields(z_pos, sim_ind) else: - vals = self.all_spectra(ind=z_pos) - - if sim_ind is None: - return vals - elif z_pos is None: - return vals[:, sim_ind] - else: - return vals[sim_ind] + return self.spectra(z_pos, sim_ind) def rin_propagation( self, left: float, right: float, unit: str @@ -392,22 +290,63 @@ class Pulse(Sequence): RIN """ spectra = [] - for spec in np.moveaxis(self.all_spectra(), 1, 0): + for spec in np.moveaxis(self.spectra(None, None), 1, 0): x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False) spectra.append(tmp) return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1)) - def z_ind(self, z: float) -> int: - """return the closest z index to the given target + # Private + + def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray: + """loads a spectrum file Parameters ---------- - z : float - target + z_ind : int + z_index relative to the entire simulation + sim_ind : int, optional + simulation index, used when repeated simulations with same parameters are ran, by default 0 Returns ------- - int - index + np.ndarray + loaded spectrum file """ - return math.argclosest(self.z, z) + if sim_ind > 0: + return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind)) + else: + return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0])) + + def _all_params(self, key: str, l: list) -> list: + l.append((self.params.name, getattr(self.params, key))) + if self.previous is not None: + return self.previous._all_params(key, l) + return l + + # Magic methods + + def __iter__(self) -> Iterator[Spectrum]: + for i in range(self.total_num_steps): + yield self.spectra(i, None) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})" + + def __eq__(self, other: SimulationSeries) -> bool: + return ( + self.path == other.path + and self.params == other.params + and self.previous == other.previous + ) + + def __contains__(self, other: SimulationSeries) -> bool: + if other is self or other == self: + return True + if self.previous is not None: + return other in self.previous + + def __getitem__(self, key) -> Spectrum: + if isinstance(key, tuple): + return self.spectra(*key) + else: + return self.spectra(key, None) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py deleted file mode 100644 index 9ca5ac3..0000000 --- a/src/scgenerator/utils/__init__.py +++ /dev/null @@ -1,677 +0,0 @@ -""" -This files includes utility functions designed more or less to be used specifically with the -scgenerator module but some function may be used in any python program - -""" - -from __future__ import annotations - -import itertools -import multiprocessing -import os -import random -import re -import shutil -import threading -from collections import abc -from io import StringIO -from pathlib import Path -from string import printable as str_printable -from functools import cache -from typing import Any, Callable, Generator, Iterable, MutableMapping, Sequence, TypeVar, Union - -import numpy as np -import pkg_resources as pkg -import toml -from tqdm import tqdm - -from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ -from ..env import pbar_policy -from ..logger import get_logger - -T_ = TypeVar("T_") - -PathTree = 
list[tuple[Path, ...]] - - -class Paths: - _data_files = [ - "materials.toml", - "hr_t.npz", - "submit_job_template.txt", - "start_worker.sh", - "start_head.sh", - ] - - paths = { - f.split(".")[0]: os.path.abspath( - pkg.resource_filename("scgenerator", os.path.join("data", f)) - ) - for f in _data_files - } - - @classmethod - def get(cls, key): - if key not in cls.paths: - if os.path.exists("paths.toml"): - with open("paths.toml") as file: - paths_dico = toml.load(file) - for k, v in paths_dico.items(): - cls.paths[k] = v - if key not in cls.paths: - get_logger(__name__).info( - f"{key} was not found in path index, returning current working directory." - ) - cls.paths[key] = os.getcwd() - - return cls.paths[key] - - @classmethod - def gets(cls, key): - """returned the specified file as a string""" - with open(cls.get(key)) as file: - return file.read() - - @classmethod - def plot(cls, name): - """returns the paths to the specified plot. Used to save new plot - example - --------- - fig.savefig(Paths.plot("figure5.pdf")) - """ - return os.path.join(cls.get("plots"), name) - - -def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: - prev_data_dir = Path(prev_data_dir) - num = find_last_spectrum_num(prev_data_dir) - return load_spectrum(prev_data_dir / SPEC1_FN.format(num)) - - -@cache -def load_spectrum(folder: os.PathLike) -> np.ndarray: - return np.load(folder) - - -def conform_toml_path(path: os.PathLike) -> str: - path: str = str(path) - if not path.lower().endswith(".toml"): - path = path + ".toml" - return path - - -def open_single_config(path: os.PathLike) -> dict[str, Any]: - d = open_config(path) - f = d.pop("Fiber")[0] - return d | f - - -def open_config(path: os.PathLike): - """returns a dictionary parsed from the specified toml file - This also handle having a 'INCLUDE' argument that will fill - otherwise unspecified keys with what's in the INCLUDE file(s)""" - - path = conform_toml_path(path) - dico = resolve_loadfile_arg(load_toml(path)) - - dico.setdefault("variable", {}) - for key in {"simulation", "fiber", "gas", "pulse"} & dico.keys(): - section = dico.pop(key) - dico["variable"].update(section.pop("variable", {})) - dico.update(section) - if len(dico["variable"]) == 0: - dico.pop("variable") - return dico - - -def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: - if (f_list := dico.pop("INCLUDE", None)) is not None: - if isinstance(f_list, str): - f_list = [f_list] - for to_load in f_list: - loaded = load_toml(to_load) - for k, v in loaded.items(): - if k not in dico and k not in dico.get("variable", {}): - dico[k] = v - for k, v in dico.items(): - if isinstance(v, MutableMapping): - dico[k] = resolve_loadfile_arg(v) - elif isinstance(v, Sequence): - for i, vv in enumerate(v): - if isinstance(vv, MutableMapping): - dico[k][i] = resolve_loadfile_arg(vv) - return dico - - -def load_toml(descr: str) -> dict[str, Any]: - if ":" in descr: - path, entry = descr.split(":", 1) - with open(path) as file: - return toml.load(file)[entry] - else: - with open(descr) as file: - return toml.load(file) - - -def save_toml(path: os.PathLike, dico): - """saves a dictionary into a toml file""" - path = conform_toml_path(path) - with open(path, mode="w") as file: - toml.dump(dico, file) - return dico - - -def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]: - loaded_config = open_config(final_config_path) - final_name = loaded_config.get("name") - fiber_list = loaded_config.pop("Fiber") - configs = [] - if fiber_list is not None: 
- master_variable = loaded_config.get("variable", {}) - for i, params in enumerate(fiber_list): - params.setdefault("variable", master_variable if i == 0 else {}) - if i == 0: - params["variable"] |= master_variable - configs.append(loaded_config | params) - else: - configs.append(loaded_config) - while "previous_config_file" in configs[0]: - configs.insert(0, open_config(configs[0]["previous_config_file"])) - configs[0].setdefault("variable", {}) - for pre, nex in zip(configs[:-1], configs[1:]): - variable = nex.pop("variable", {}) - nex.update({k: v for k, v in pre.items() if k not in nex}) - nex["variable"] = variable - - return configs, final_name - - -def save_parameters( - params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN -) -> Path: - """saves a parameter dictionary. Note that is does remove some entries, particularly - those that take a lot of space ("t", "w", ...) - - Parameters - ---------- - params : dict[str, Any] - dictionary to save - destination_dir : Path - destination directory - - Returns - ------- - Path - path to newly created the paramter file - """ - file_path = destination_dir / file_name - os.makedirs(file_path.parent, exist_ok=True) - - # save toml of the simulation - with open(file_path, "w") as file: - toml.dump(params, file, encoder=toml.TomlNumpyEncoder()) - - return file_path - - -def load_material_dico(name: str) -> dict[str, Any]: - """loads a material dictionary - Parameters - ---------- - name : str - name of the material - Returns - ---------- - material_dico : dict - """ - return toml.loads(Paths.gets("materials"))[name] - - -def update_appended_params(source: Path, destination: Path, z: Sequence): - z_num = len(z) - params = open_config(source) - params["z_num"] = z_num - params["length"] = float(z[-1] - z[0]) - for p_name in ["recovery_data_dir", "prev_data_dir", "output_path"]: - if p_name in params: - del params[p_name] - save_toml(destination, params) - - -def to_62(i: int) -> str: - arr = [] - if i == 0: - return "0" - i = abs(i) - while i: - i, value = divmod(i, 62) - arr.append(str_printable[value]) - return "".join(reversed(arr)) - - -def build_path_trees(sim_dir: Path) -> list[PathTree]: - sim_dir = sim_dir.resolve() - path_branches: list[tuple[Path, ...]] = [] - to_check = list(sim_dir.glob("*fiber*num*")) - with PBars(len(to_check), desc="Building path trees") as pbar: - for branch in map(build_path_branch, to_check): - if branch is not None: - path_branches.append(branch) - pbar.update() - path_trees = group_path_branches(path_branches) - return path_trees - - -def build_path_branch(data_dir: Path) -> tuple[Path, ...]: - if not data_dir.is_dir(): - return None - path_branch = [data_dir] - while ( - prev_sim_path := open_config(path_branch[-1] / PARAM_FN).get("prev_data_dir") - ) is not None: - p = Path(prev_sim_path).resolve() - if not p.exists(): - p = Path(*p.parts[-2:]).resolve() - path_branch.append(p) - return tuple(reversed(path_branch)) - - -def group_path_branches(path_branches: list[tuple[Path, ...]]) -> list[PathTree]: - """groups path lists - - [ - ("a/id 0 wavelength 100 num 0"," b/id 0 wavelength 100 num 0"), - ("a/id 2 wavelength 100 num 1"," b/id 2 wavelength 100 num 1"), - ("a/id 1 wavelength 200 num 0"," b/id 1 wavelength 200 num 0"), - ("a/id 3 wavelength 200 num 1"," b/id 3 wavelength 200 num 1") - ] - -> - [ - ( - ("a/id 0 wavelength 100 num 0", "a/id 2 wavelength 100 num 1"), - ("b/id 0 wavelength 100 num 0", "b/id 2 wavelength 100 num 1"), - ) - ( - ("a/id 1 wavelength 200 num 0", "a/id 3 
wavelength 200 num 1"), - ("b/id 1 wavelength 200 num 0", "b/id 3 wavelength 200 num 1"), - ) - ] - - - Parameters - ---------- - path_branches : list[tuple[Path, ...]] - each element of the list is a path to a folder containing data of one simulation - - Returns - ------- - list[PathTree] - list of PathTrees to be used in merge - """ - sort_key = lambda el: el[0] - - size = len(path_branches[0]) - out_trees_map: dict[str, dict[int, dict[int, Path]]] = {} - for branch in path_branches: - b_id = branch_id(branch) - out_trees_map.setdefault(b_id, {i: {} for i in range(size)}) - for sim_part, data_dir in enumerate(branch): - num = re.search(r"(?<=num )[0-9]+", data_dir.name)[0] - out_trees_map[b_id][sim_part][int(num)] = data_dir - - return [ - tuple( - tuple(w for _, w in sorted(v.items(), key=sort_key)) - for __, v in sorted(d.items(), key=sort_key) - ) - for d in out_trees_map.values() - ] - - -def merge_path_tree( - path_tree: PathTree, destination: Path, z_callback: Callable[[int], None] = None -): - """given a path tree, copies the file into the right location - - Parameters - ---------- - path_tree : PathTree - elements of the list returned by group_path_branches - destination : Path - dir where to save the data - """ - z_arr: list[float] = [] - - destination.mkdir(exist_ok=True) - - for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)): - z_arr.append(z) - spec_out_name = SPECN_FN.format(i) - np.save(destination / spec_out_name, merged_spectra) - if z_callback is not None: - z_callback(i) - d = np.diff(z_arr) - d[d < 0] = 0 - z_arr = np.concatenate(([z_arr[0]], np.cumsum(d))) - np.save(destination / Z_FN, z_arr) - update_appended_params(path_tree[-1][0] / PARAM_FN, destination / PARAM_FN, z_arr) - - -def merge_spectra( - path_tree: PathTree, -) -> Generator[tuple[float, np.ndarray], None, None]: - for same_sim_paths in path_tree: - z_arr = np.load(same_sim_paths[0] / Z_FN) - for i, z in enumerate(z_arr): - spectra: list[np.ndarray] = [] - for data_dir in same_sim_paths: - spec = np.load(data_dir / SPEC1_FN.format(i)) - spectra.append(spec) - yield z, np.atleast_2d(spectra) - - -def merge(destination: os.PathLike, path_trees: list[PathTree] = None): - - destination = ensure_folder(Path(destination)) - - z_num = 0 - prev_z_num = 0 - - for i, sim_dir in enumerate(sim_dirs(path_trees)): - conf = sim_dir / "initial_config.toml" - shutil.copy( - conf, - destination / f"initial_config_{i}.toml", - ) - prev_z_num = open_config(conf).get("z_num", prev_z_num) - z_num += prev_z_num - - pbars = PBars( - len(path_trees) * z_num, "Merging", 1, worker_kwargs=dict(total=z_num, desc="current pos") - ) - for path_tree in path_trees: - pbars.reset(1) - iden_items = path_tree[-1][0].name.split()[2:] - for i, p_name in list(enumerate(iden_items))[-2::-2]: - if p_name == "num": - del iden_items[i + 1] - del iden_items[i] - iden = PARAM_SEPARATOR.join(iden_items) - merge_path_tree(path_tree, destination / iden, z_callback=lambda i: pbars.update(1)) - - -def sim_dirs(path_trees: list[PathTree]) -> Generator[Path, None, None]: - for p in path_trees[0]: - yield p[0].parent - - -def save_data(data: np.ndarray, data_dir: Path, file_name: str): - """saves numpy array to disk - - Parameters - ---------- - data : np.ndarray - data to save - file_name : str - file name - task_id : int - id that uniquely identifies the process - identifier : str, optional - identifier in the main data folder of the task, by default "" - """ - path = data_dir / file_name - np.save(path, data) - 
get_logger(__name__).debug(f"saved data in {path}") - return - - -def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Path: - """ensure a folder exists and doesn't overwrite anything if required - - Parameters - ---------- - path : Path - desired path - prevent_overwrite : bool, optional - whether to create a new directory when one already exists, by default True - - Returns - ------- - Path - final path - """ - - path = path.resolve() - - # is path root ? - if len(path.parts) < 2: - return path - - # is a part of path an existing *file* ? - parts = path.parts - path = Path(path.root) - for part in parts: - if path.is_file(): - path = ensure_folder(path, mkdir=mkdir, prevent_overwrite=False) - path /= part - - folder_name = path.name - - for i in itertools.count(): - if not path.is_file() and (not prevent_overwrite or not path.is_dir()): - if mkdir: - path.mkdir(exist_ok=True) - return path - path = path.parent / (folder_name + f"_{i}") - - -class PBars: - def __init__( - self, - task: Union[int, Iterable[T_]], - desc: str, - num_sub_bars: int = 0, - head_kwargs=None, - worker_kwargs=None, - ) -> "PBars": - - self.id = random.randint(100000, 999999) - try: - self.width = os.get_terminal_size().columns - except OSError: - self.width = 80 - if isinstance(task, abc.Iterable): - self.iterator: Iterable[T_] = iter(task) - self.num_tot: int = len(task) - else: - self.num_tot: int = task - self.iterator = None - - self.policy = pbar_policy() - if head_kwargs is None: - head_kwargs = dict() - if worker_kwargs is None: - worker_kwargs = dict( - total=1, - desc="Worker {worker_id}", - bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", - ) - if "print" not in pbar_policy(): - head_kwargs["file"] = worker_kwargs["file"] = StringIO() - self.width = 80 - head_kwargs["desc"] = desc - self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)] - for i in range(1, num_sub_bars + 1): - kwargs = {k: v for k, v in worker_kwargs.items()} - if "desc" in kwargs: - kwargs["desc"] = kwargs["desc"].format(worker_id=i) - self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs)) - self.print_path = Path( - f"progress {self.pbars[0].desc.replace('/', '')} {self.id}" - ).resolve() - self.close_ev = threading.Event() - if "file" in self.policy: - self.thread = threading.Thread(target=self.print_worker, daemon=True) - self.thread.start() - - def print(self): - if "file" not in self.policy: - return - s = [] - for pbar in self.pbars: - s.append(str(pbar)) - self.print_path.write_text("\n".join(s)) - - def print_worker(self): - while True: - if self.close_ev.wait(2.0): - return - self.print() - - def __iter__(self): - with self as pb: - for thing in self.iterator: - yield thing - pb.update() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def __getitem__(self, key): - return self.pbars[key] - - def update(self, i=None, value=1): - if i is None: - for pbar in self.pbars[1:]: - pbar.update(value) - elif i > 0: - self.pbars[i].update(value) - self.pbars[0].update() - - def append(self, pbar: tqdm): - self.pbars.append(pbar) - - def reset(self, i): - self.pbars[i].update(-self.pbars[i].n) - self.print() - - def close(self): - self.print() - self.close_ev.set() - if "file" in self.policy: - self.thread.join() - for pbar in self.pbars: - pbar.close() - - -class ProgressBarActor: - def __init__(self, name: str, num_workers: int, num_steps: int) -> None: - self.counters = [0 
for _ in range(num_workers + 1)] - self.p_bars = PBars( - num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") - ) - - def update(self, worker_id: int, rel_pos: float = None) -> None: - """update a counter - - Parameters - ---------- - worker_id : int - id of the worker. 0 is the overall progress - rel_pos : float, optional - if None, increase the counter by one, if set, will set - the counter to the specified value (instead of incrementing it), by default None - """ - if rel_pos is None: - self.counters[worker_id] += 1 - else: - self.counters[worker_id] = rel_pos - - def update_pbars(self): - for counter, pbar in zip(self.counters, self.p_bars.pbars): - pbar.update(counter - pbar.n) - - def close(self): - self.p_bars.close() - - -def progress_worker( - name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue -): - """keeps track of progress on a separate thread - - Parameters - ---------- - num_steps : int - total number of steps, used for the main progress bar (position 0) - progress_queue : multiprocessing.Queue - values are either - Literal[0] : stop the worker and close the progress bars - tuple[int, float] : worker id and relative progress between 0 and 1 - """ - with PBars( - num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") - ) as pbars: - while True: - raw = progress_queue.get() - if raw == 0: - return - i, rel_pos = raw - if i > 0: - pbars[i].update(rel_pos - pbars[i].n) - pbars[0].update() - elif i == 0: - pbars[0].update(rel_pos) - - -def branch_id(branch: tuple[Path, ...]) -> str: - return branch[-1].name.split()[1] - - -def find_last_spectrum_num(data_dir: Path): - for num in itertools.count(1): - p_to_test = data_dir / SPEC1_FN.format(num) - if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0: - return num - 1 - - -def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray: - threshold = y.min() + rel_thr * (y.max() - y.min()) - above_threshold = y > threshold - ind = np.argsort(x) - valid_ind = [ - np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k - ] - ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0] - width = len(ind_above) - return np.concatenate( - ( - np.arange(max(ind_above[0] - width, 0), ind_above[0]), - ind_above, - np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)), - ) - ) - - -def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: - old_names = dict( - interp_degree="interpolation_degree", - beta="beta2_coefficients", - interp_range="interpolation_range", - ) - deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"} - defaults_to_add = dict(repeat=1) - new = {} - for k, v in d.items(): - if k == "error_ok": - new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v - elif k in deleted_names: - continue - elif isinstance(v, MutableMapping): - new[k] = translate_parameters(v) - else: - new[old_names.get(k, k)] = v - return defaults_to_add | new diff --git a/testing/long_tests/test_recovery_param_seq.py b/testing/long_tests/test_recovery_param_seq.py deleted file mode 100644 index 5d73367..0000000 --- a/testing/long_tests/test_recovery_param_seq.py +++ /dev/null @@ -1,35 +0,0 @@ -import shutil -import unittest - -import toml -from scgenerator import logger -from send2trash import send2trash - -TMP = "testing/.tmp" - - -class TestRecoveryParamSequence(unittest.TestCase): - def setUp(self): - 
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP) - self.conf = toml.load(TMP + "/initial_config.toml") - io.set_data_folder(55, TMP) - - def test_remaining_simulations_count(self): - param_seq = initialize.RecoveryParamSequence(self.conf, 55) - self.assertEqual(5, len(param_seq)) - - def test_only_one_to_complete(self): - param_seq = initialize.RecoveryParamSequence(self.conf, 55) - i = 0 - for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq): - i += 1 - self.assertEqual(expected, "recovery_last_stored" in params) - - self.assertEqual(5, i) - - def tearDown(self): - send2trash(TMP) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_initialize.py b/testing/test_initialize.py deleted file mode 100644 index edd3d90..0000000 --- a/testing/test_initialize.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -from copy import deepcopy - -import numpy as np -import toml -from scgenerator import defaults, utils, math -from scgenerator.errors import * -from scgenerator.physics import pulse, units -from scgenerator.utils.parameter import Config, Parameters - - -def load_conf(name): - with open("testing/configs/" + name + ".toml") as file: - conf = toml.load(file) - return conf - - -def conf_maker(folder): - def conf(name): - return load_conf(folder + "/" + name) - - return conf - - -class TestParamSequence(unittest.TestCase): - def iterconf(self, files): - conf = conf_maker("param_sequence") - for path in files: - yield init.ParamSequence(conf(path)) - - def test_no_repeat_in_sub_folder_names(self): - for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): - l = [] - s = [] - for vary_list, _ in utils.required_simulations(param_seq.config): - self.assertNotIn(vary_list, l) - self.assertNotIn(utils.format_variable_list(vary_list), s) - l.append(vary_list) - s.append(utils.format_variable_list(vary_list)) - - def test_no_variations_yields_only_num_and_id(self): - for param_seq in self.iterconf(["no_variations"]): - for vary_list, _ in utils.required_simulations(param_seq.config): - self.assertEqual(vary_list[1][0], "num") - self.assertEqual(vary_list[0][0], "id") - self.assertEqual(2, len(vary_list)) - - -class TestInitializeMethods(unittest.TestCase): - def test_validate_types(self): - conf = lambda s: load_conf("validate_types/" + s) - - with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"): - init.Config(**conf("bad2")) - - with self.assertRaisesRegex(TypeError, "value must be of type "): - init.Config(**conf("bad3")) - - with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"): - init.Config(**conf("bad4")) - - with self.assertRaisesRegex( - TypeError, "'variable intensity_noise' value must be of type " - ): - init.Config(**conf("bad5")) - - with self.assertRaisesRegex(ValueError, "'repeat' must be positive"): - init.Config(**conf("bad6")) - - with self.assertRaisesRegex( - ValueError, "variable parameter 'intensity_noise' must not be empty" - ): - init.Config(**conf("bad7")) - - self.assertIsNone(init.Config(**conf("good")).hr_w) - - def test_ensure_consistency(self): - conf = lambda s: load_conf("ensure_consistency/" + s) - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['t0', 'width'\]' is required and no defaults have been set", - ): - init.Config(**conf("bad1")) - - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 
'soliton_num' is specified and no defaults have been set", - ): - init.Config(**conf("bad2")) - - with self.assertRaisesRegex( - MissingParameterError, - r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set", - ): - init.Config(**conf("bad3")) - - with self.assertRaisesRegex( - DuplicateParameterError, - r"got multiple values for parameter 'width'", - ): - init.Config(**conf("bad4")) - - with self.assertRaisesRegex( - MissingParameterError, - r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set", - ): - init.Config(**conf("bad5")) - - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set", - ): - init.Config(**conf("bad6")) - - self.assertLessEqual( - {"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items() - ) - - self.assertIsNone(init.Config(**conf("good4")).gamma) - - self.assertLessEqual( - {"raman_type": "agrawal"}.items(), - init.Config(**conf("good2")).__dict__.items(), - ) - - self.assertLessEqual( - {"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items() - ) - - self.assertLessEqual( - {"capillary_nested": 0, "capillary_resonance_strengths": []}.items(), - init.Config(**conf("good4")).__dict__.items(), - ) - - self.assertLessEqual( - dict(he_mode=(1, 1)).items(), - init.Config(**conf("good5")).__dict__.items(), - ) - - self.assertLessEqual( - dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(), - init.Config(**conf("good5")).__dict__.items(), - ) - - def setup_conf_custom_field(self, path) -> Parameters: - - conf = load_conf(path) - conf = Parameters(**conf) - init.build_sim_grid_in_place(conf) - return conf - - def test_setup_custom_field(self): - d = np.load("testing/configs/custom_field/init_field.npz") - t = d["time"] - field = d["field"] - conf = self.setup_conf_custom_field("custom_field/no_change") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/peak_power") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0) - self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4) - self.assertTrue(result) - self.assertNotAlmostEqual(conf.wavelength, 1593e-9) - - conf = self.setup_conf_custom_field("custom_field/mean_power") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/recover1") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/recover2") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") - result = Parameters(**conf) - 
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9) - - conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") - conf.wavelength = 1593e-9 - result = Parameters(**conf) - - conf = load_conf("custom_field/wavelength_shift2") - conf = init.Config(**conf) - for target, (variable, config) in zip( - [1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf) - ): - init.build_sim_grid_in_place(conf) - self.assertAlmostEqual( - units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target - ) - print(config.wavelength, target) - - -if __name__ == "__main__": - conf = conf_maker("validate_types") - - unittest.main() diff --git a/testing/test_pulse.py b/testing/test_pulse.py deleted file mode 100644 index 6a632d4..0000000 --- a/testing/test_pulse.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest -from scgenerator.physics.pulse import conform_pulse_params - - -class TestPulseMethods(unittest.TestCase): - def test_conform_pulse_params(self): - self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6)) - - self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6))) - - with self.assertRaisesRegex( - TypeError, "when soliton number is desired, both gamma and beta2 must be specified" - ): - conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01) - with self.assertRaisesRegex( - TypeError, "when soliton number is desired, both gamma and beta2 must be specified" - ): - conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01) - - self.assertEqual( - 5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6)) - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_utils.py b/testing/test_utils.py deleted file mode 100644 index ef277eb..0000000 --- a/testing/test_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -import unittest - -import numpy as np -import toml -from scgenerator import utils - - -def load_conf(name): - with open("testing/configs/" + name + ".toml") as file: - conf = toml.load(file) - return conf - - -def conf_maker(folder, val=True): - def conf(name): - if val: - return initialize.Config(**load_conf(folder + "/" + name)) - else: - return initialize.Config(**load_conf(folder + "/" + name)) - - return conf - - -class TestUtilsMethods(unittest.TestCase): - def test_count_variations(self): - conf = conf_maker("count_variations") - - for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]: - self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary"))) - - def test_format_value(self): - values = [ - 122e-6, - True, - ["raman", "ss"], - np.arange(5), - 1.123, - 1.1230001, - 0.002e122, - 12.3456e-9, - ] - s = [ - "0.000122", - "True", - 
"raman-ss", - "0-1-2-3-4", - "1.123", - "1.1230001", - "2e+119", - "1.23456e-08", - ] - - for value, target in zip(values, s): - self.assertEqual(target, utils.format_value(value)) - - def test_override_config(self): - conf = conf_maker("override", False) - old = conf("initial_config") - new = conf("fiber2") - - over = utils.override_config(vars(new), old) - self.assertNotIn("input_transmission", over.variable) - self.assertIsNone(over.input_transmission) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_variationer.py b/testing/test_variationer.py new file mode 100644 index 0000000..67033e2 --- /dev/null +++ b/testing/test_variationer.py @@ -0,0 +1,54 @@ +from pydantic import main +import scgenerator as sc + + +def test_descriptor(): + # Same branch + var1 = sc.VariationDescriptor( + raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]] + ) + var2 = sc.VariationDescriptor( + raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]] + ) + assert var1.branch.identifier == "b_0" + assert var1.identifier != var1.branch.identifier + assert var1.identifier != var2.identifier + assert var1.branch.identifier == var2.branch.identifier + + # different branch + var3 = sc.VariationDescriptor( + raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]] + ) + assert var1.branch.identifier != var3.branch.identifier + assert var1.formatted_descriptor() != var2.formatted_descriptor() + assert var1.formatted_descriptor() != var3.formatted_descriptor() + + +def test_variationer(): + var = sc.Variationer( + [ + dict(a=[1, 2], num=[0, 1, 2]), + [dict(b=["000", "111"], c=["a", "-1"])], + dict(), + dict(), + [dict(aaa=[True, False], bb=[1, 3])], + ] + ) + assert var.var_num(0) == 6 + assert var.var_num(1) == 12 + assert var.var_num() == 24 + + cfg = dict(bb=None) + branches = set() + for descr in var.iterate(): + assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1]) + branches.add(descr.branch.identifier) + assert len(branches) == 8 + + +def main(): + test_descriptor() + + +if __name__ == "__main__": + main()