diff --git a/play.py b/play.py index 9f687f3..d0cc8a8 100644 --- a/play.py +++ b/play.py @@ -4,22 +4,3 @@ import scgenerator as sc import matplotlib.pyplot as plt from pathlib import Path from pprint import pprint - - -def _main(): - print(os.getcwd()) - for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"): - print(params.fiber_map) - - -def main(): - drr = os.getcwd() - os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations") - try: - _main() - finally: - os.chdir(drr) - - -if __name__ == "__main__": - main() diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 997e495..6cb8fcb 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,9 +1,10 @@ -from . import math, utils +from . import math from .math import abs2, argclosest, span from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .spectra import Pulse, Spectrum -from .utils import Paths, open_config, parameter -from .utils.parameter import Configuration, Parameters -from .utils.utils import PlotRange +from ._utils import Paths, open_config, parameter +from ._utils.parameter import Configuration, Parameters +from ._utils.utils import PlotRange +from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/_utils/__init__.py similarity index 98% rename from src/scgenerator/utils/__init__.py rename to src/scgenerator/_utils/__init__.py index fdcbcea..03c8859 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/_utils/__init__.py @@ -25,7 +25,7 @@ import pkg_resources as pkg import toml from tqdm import tqdm -from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ +from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__ from ..env import pbar_policy from ..logger import get_logger @@ -143,7 +143,8 @@ def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: return dico -def load_toml(descr: str) -> dict[str, Any]: +def load_toml(descr: os.PathLike) -> dict[str, Any]: + descr = str(descr) if ":" in descr: path, entry = descr.split(":", 1) with open(path) as file: @@ -188,6 +189,7 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]] params.setdefault("variable", {}) configs.append(loaded_config | params) configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"] + configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1))) return Path(final_path), configs @@ -341,7 +343,7 @@ def merge_path_tree( for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)): z_arr.append(z) - spec_out_name = SPECN_FN.format(i) + spec_out_name = SPECN_FN1.format(i) np.save(destination / spec_out_name, merged_spectra) if z_callback is not None: z_callback(i) diff --git a/src/scgenerator/utils/cache.py b/src/scgenerator/_utils/cache.py similarity index 100% rename from src/scgenerator/utils/cache.py rename to src/scgenerator/_utils/cache.py diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/_utils/parameter.py similarity index 94% rename from src/scgenerator/utils/parameter.py rename to src/scgenerator/_utils/parameter.py index fa43c58..f7d7c1a 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/_utils/parameter.py @@ -14,14 +14,15 @@ from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequen import numpy as np from numpy.lib import isin -from scgenerator.utils import ensure_folder, variationer -from .. import math, utils +from .. import math from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ from ..errors import EvaluatorError, NoDefaultError from ..logger import get_logger from ..physics import fiber, materials, pulse, units -from ..utils.variationer import VariationDescriptor, Variationer +from .._utils.variationer import VariationDescriptor, Variationer +from .. import _utils as utils +from .. import env from .utils import func_rewrite, _mock_function, get_arg_names T = TypeVar("T") @@ -312,13 +313,6 @@ class Parameter: return f"{num_str} {unit}" -def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]: - if isinstance(d, dict): - return [(float(k), v) for k, v in d.items()] - else: - return [(float(k), v) for k, v in d] - - @dataclass class Parameters: """ @@ -432,15 +426,13 @@ class Parameters: const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) beta_func: Callable[[float], list[float]] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator) - fiber_map: list[tuple[float, str]] = Parameter( - validator_list(type_checker(tuple)), converter=fiber_map_converter - ) + + num: int = Parameter(non_negative(int)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) version: str = Parameter(string) def prepare_for_dump(self) -> dict[str, Any]: param = asdict(self) - param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])] param = Parameters.strip_params_dict(param) param["datetime"] = datetime_module.datetime.now() param["version"] = __version__ @@ -816,7 +808,9 @@ class Configuration: self.overwrite = overwrite self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) self.final_path = utils.ensure_folder( - self.final_path, mkdir=False, prevent_overwrite=not self.overwrite + Path(env.get(env.OUTPUT_PATH, self.final_path)), + mkdir=False, + prevent_overwrite=not self.overwrite, ) self.master_config = self.fiber_configs[0] self.name = self.final_path.name @@ -868,23 +862,8 @@ class Configuration: def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]: for i in range(self.num_fibers): for sim_config in self.iterate_single_fiber(i): - - if i > 0: - sim_config.config["prev_data_dir"] = str( - self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor() - ) params = Parameters(**sim_config.config) params.compute() - fiber_map = [] - for j in range(i + 1): - this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config - if j > 0: - prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config - length = prev_conf["length"] + fiber_map[j - 1][0] - else: - length = 0.0 - fiber_map.append((length, this_conf["name"])) - params.fiber_map = fiber_map yield sim_config.descriptor, params def iterate_single_fiber( @@ -903,18 +882,21 @@ class Configuration: __SimConfig configuration obj """ - sim_dict: dict[Path, self.__SimConfig] = {} - for descr in self.variationer.iterate(index): - cfg = descr.update_config(self.fiber_configs[index]) - p = ensure_folder( - self.fiber_paths[index] / descr.formatted_descriptor(), + sim_dict: dict[Path, Configuration.__SimConfig] = {} + for descriptor in self.variationer.iterate(index): + cfg = descriptor.update_config(self.fiber_configs[index]) + if index > 0: + cfg["prev_data_dir"] = str( + self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True) + ) + p = utils.ensure_folder( + self.fiber_paths[index] / descriptor.formatted_descriptor(True), not self.overwrite, False, ) cfg["output_path"] = str(p) - sim_config = self.__SimConfig(descr, cfg, p) - sim_dict[p] = sim_config - self.all_configs[sim_config.descriptor.index] = sim_config + sim_config = self.__SimConfig(descriptor, cfg, p) + sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config while len(sim_dict) > 0: for data_dir, sim_config in sim_dict.items(): task, config_dict = self.__decide(sim_config) @@ -1001,9 +983,12 @@ class Configuration: raise ValueError(f"Too many spectra in {data_dir}") def save_parameters(self): - for config, sim_dir in zip(self.fiber_configs, self.fiber_paths): - os.makedirs(sim_dir, exist_ok=True) - utils.save_toml(sim_dir / f"initial_config.toml", config) + os.makedirs(self.final_path, exist_ok=True) + cfgs = [ + cfg | dict(variable=self.variationer.all_dicts[i]) + for i, cfg in enumerate(self.fiber_configs) + ] + utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs)) @property def first(self) -> Parameters: diff --git a/src/scgenerator/utils/utils.py b/src/scgenerator/_utils/utils.py similarity index 69% rename from src/scgenerator/utils/utils.py rename to src/scgenerator/_utils/utils.py index a18aacb..44eadf3 100644 --- a/src/scgenerator/utils/utils.py +++ b/src/scgenerator/_utils/utils.py @@ -1,12 +1,17 @@ import inspect +import os import re +from collections import defaultdict from functools import cache +from pathlib import Path from string import printable as str_printable from typing import Callable import numpy as np from pydantic import BaseModel +from .._utils import load_toml, save_toml +from ..const import PARAM_FN, Z_FN from ..physics.units import get_unit @@ -144,3 +149,55 @@ def _mock_function(num_args: int, num_returns: int) -> Callable: out_func = scope[func_name] out_func.__module__ = "evaluator" return out_func + + +def combine_simulations(path: Path, dest: Path = None): + """combines raw simulations into one folder per branch + + Parameters + ---------- + path : Path + source of the simulations (must contain u_xx directories) + dest : Path, optional + if given, moves the simulations to dest, by default None + """ + paths: dict[str, list[Path]] = defaultdict(list) + if dest is None: + dest = path + + for p in path.glob("u_*b_*"): + if p.is_dir(): + paths[p.name.split()[1]].append(p) + for l in paths.values(): + l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) + for pulses in paths.values(): + new_path = dest / update_path(pulses[0].name) + os.makedirs(new_path, exist_ok=True) + for num, pulse in enumerate(pulses): + params_ok = False + for file in pulse.glob("*"): + if file.name == PARAM_FN: + if not params_ok: + update_params(new_path, file) + params_ok = True + else: + file.unlink() + elif file.name == Z_FN: + file.rename(new_path / file.name) + else: + file.rename(new_path / (file.stem + f"_{num}" + file.suffix)) + pulse.rmdir() + + +def update_params(new_path: Path, file: Path): + params = load_toml(file) + if (p := params.get("prev_data_dir")) is not None: + p = Path(p) + params["prev_data_dir"] = str(p.parent / update_path(p.name)) + params["output_path"] = str(new_path) + save_toml(new_path / PARAM_FN, params) + file.unlink() + + +def update_path(p: str) -> str: + return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p) diff --git a/src/scgenerator/utils/variationer.py b/src/scgenerator/_utils/variationer.py similarity index 87% rename from src/scgenerator/utils/variationer.py rename to src/scgenerator/_utils/variationer.py index d53777d..fdd6368 100644 --- a/src/scgenerator/utils/variationer.py +++ b/src/scgenerator/_utils/variationer.py @@ -116,6 +116,7 @@ class VariationDescriptor(utils.HashableBaseModel): index: tuple[tuple[int, ...], ...] separator: str = "fiber" _format_registry: dict[str, Callable[..., str]] = {} + __ids: dict[int, int] = {} def __str__(self) -> str: return self.formatted_descriptor(add_identifier=False) @@ -173,7 +174,19 @@ class VariationDescriptor(utils.HashableBaseModel): raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator ) - def update_config(self, cfg: dict[str, Any]): + def update_config(self, cfg: dict[str, Any]) -> dict[str, Any]: + """updates a dictionary with the value of the descriptor + + Parameters + ---------- + cfg : dict[str, Any] + dict to be updated + + Returns + ------- + dict[str, Any] + same as cfg but with key from the descriptor added/updated. + """ return cfg | {k: v for k, v in self.raw_descr[-1]} @property @@ -188,17 +201,22 @@ class VariationDescriptor(utils.HashableBaseModel): @property def branch(self) -> "BranchDescriptor": - for i in reversed(range(len(self.raw_descr))): - for j in reversed(range(len(self.raw_descr[i]))): - if self.raw_descr[i][j][0] == "num": - del self.raw_descr[i][j] - return VariationDescriptor( - raw_descr=self.raw_descr, index=self.index, separator=self.separator - ) + descr = [] + ind = [] + for i, l in enumerate(self.raw_descr): + descr.append([]) + ind.append([]) + for j, (k, v) in enumerate(l): + if k != "num": + descr[-1].append((k, v)) + ind[-1].append(self.index[i][j]) + return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator) @property def identifier(self) -> str: - return "u_" + utils.to_62(hash(str(self.flat))) + unique_id = hash(str(self.flat)) + self.__ids.setdefault(unique_id, len(self.__ids)) + return "u_" + str(self.__ids[unique_id]) class BranchDescriptor(VariationDescriptor): @@ -208,7 +226,7 @@ class BranchDescriptor(VariationDescriptor): def identifier(self) -> str: branch_id = hash(str(self.flat)) self.__ids.setdefault(branch_id, len(self.__ids)) - return str(self.__ids[branch_id]) + return "b_" + str(self.__ids[branch_id]) @validator("raw_descr") def validate_raw_descr(cls, v): diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 2532fdf..92b81ab 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -1,14 +1,13 @@ import argparse import os import re -import subprocess -import sys from collections import ChainMap from pathlib import Path import numpy as np -from .. import const, env, scripts, utils +from .. import const, env, scripts +from .. import _utils as utils from ..logger import get_logger from ..physics.fiber import dispersion_coefficients from ..physics.simulate import SequencialSimulations, run_simulation diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index c679236..b8843a2 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -20,7 +20,8 @@ def pbar_format(worker_id: int): SPEC1_FN = "spectrum_{}.npy" -SPECN_FN = "spectra_{}.npy" +SPECN_FN1 = "spectra_{}.npy" +SPEC1_FN_N = "spectrum_{}_{}.npy" Z_FN = "z.npy" PARAM_FN = "params.toml" PARAM_SEPARATOR = " " diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 46c0d6e..13f3545 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]: return tmp -def get(key: str) -> Any: +def get(key: str, default=None) -> Any: str_value = os.environ.get(key) if isinstance(str_value, str): try: @@ -58,7 +58,7 @@ def get(key: str) -> Any: return t(str_value) except (ValueError, KeyError): pass - return None + return default def all_environ() -> Dict[str, str]: diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 20db3dc..505c6b7 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -3,7 +3,7 @@ from typing import Union import numpy as np from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros -from .utils.cache import np_cache +from ._utils.cache import np_cache pi = np.pi c = 299792458.0 diff --git a/src/scgenerator/physics/__init__.py b/src/scgenerator/physics/__init__.py index 82bc01e..24c3937 100644 --- a/src/scgenerator/physics/__init__.py +++ b/src/scgenerator/physics/__init__.py @@ -10,8 +10,8 @@ from scipy.optimize import minimize_scalar from .. import math from . import fiber, materials, units, pulse -from .. import utils -from ..utils import cache +from .. import _utils +from .._utils import cache T = TypeVar("T") diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index ccddbb1..4a27249 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -7,9 +7,9 @@ from scipy.interpolate import interp1d from ..logger import get_logger -from .. import utils +from .. import _utils from ..math import abs2, argclosest, power_fact, u_nm -from ..utils.cache import np_cache +from .._utils.cache import np_cache from . import materials as mat from . import units from .units import c, pi diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 812e60e..b9572e6 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -5,7 +5,7 @@ from scipy.integrate import cumulative_trapezoid from ..logger import get_logger from . import units -from .. import utils +from .. import _utils from .units import NA, c, kB, me, e, hbar diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 6a07c70..ce8ab29 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline from scipy.optimize import minimize_scalar from scipy.optimize.optimize import OptimizeResult -from scgenerator import utils - from ..defaults import default_plotting from ..logger import get_logger from ..math import * diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 477d9c5..46b9897 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -9,9 +9,11 @@ from typing import Any, Generator, Type, Union import numpy as np from send2trash import send2trash -from .. import env, utils +from .. import env +from .. import _utils as utils +from .._utils.utils import combine_simulations from ..logger import get_logger -from ..utils.parameter import Configuration, Parameters +from .._utils.parameter import Configuration, Parameters from . import pulse from .fiber import create_non_linear_op, fast_dispersion_op @@ -718,17 +720,9 @@ def run_simulation( sim = new_simulation(config, method) sim.run() - path_trees = utils.build_path_trees(config.fiber_paths[-1]) - final_name = env.get(env.OUTPUT_PATH) - if final_name is None: - final_name = config.final_path - - utils.merge(final_name, path_trees) - try: - send2trash(config.fiber_paths) - except (PermissionError, OSError): - get_logger(__name__).error("Could not send temporary directories to trash") + for path in config.fiber_paths: + combine_simulations(path) def new_simulation( diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 80089fe..f648bb9 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -2,7 +2,6 @@ # For example, nm(X) means "I give the number X in nm, figure out the ang. freq." # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... -from dataclasses import dataclass from typing import Callable, TypeVar, Union import numpy as np diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 712bf12..6163520 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -14,8 +14,8 @@ from .const import PARAM_SEPARATOR from .defaults import default_plotting as defaults from .math import abs2, span from .physics import pulse, units -from .utils.parameter import Parameters -from .utils.utils import PlotRange, sort_axis +from ._utils.parameter import Parameters +from ._utils.utils import PlotRange, sort_axis RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 029268f..176f236 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -12,12 +12,11 @@ from ..const import PARAM_FN, PARAM_SEPARATOR from ..physics import fiber, units from ..plotting import plot_setup from ..spectra import Pulse -from ..utils import auto_crop, open_config, save_toml, translate_parameters -from ..utils.parameter import ( +from .._utils import auto_crop, open_config, save_toml, translate_parameters +from .._utils.parameter import ( Configuration, Parameters, ) -from ..utils.variationer import VariationDescriptor def fingerprint(params: Parameters): diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index a977f35..3fc4b6b 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -9,8 +9,8 @@ from typing import Tuple import numpy as np -from ..utils import Paths -from ..utils.parameter import Configuration +from .._utils import Paths +from .._utils.parameter import Configuration def primes(n): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index fdd6b78..82a83b8 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import os from collections.abc import Sequence from pathlib import Path -from typing import Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, Optional, Union import matplotlib.pyplot as plt import numpy as np +from pydantic import BaseModel, DirectoryPath, root_validator from . import math -from .const import SPECN_FN +from ._utils import load_spectrum +from ._utils.parameter import Parameters +from ._utils.utils import PlotRange +from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N from .logger import get_logger from .physics import pulse, units from .plotting import ( @@ -16,9 +22,87 @@ from .plotting import ( single_position_plot, transform_2D_propagation, ) -from .utils.parameter import Parameters -from .utils.utils import PlotRange -from .utils import load_spectrum + + +class SimulationSeries: + path: Path + params: Parameters + total_length: float + total_num_steps: int + previous: SimulationSeries = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, path: os.PathLike): + self.path = Path(path) + self.params = Parameters.load(self.path / PARAM_FN) + if self.params.prev_data_dir is not None: + self.previous = SimulationSeries(self.params.prev_data_dir) + self.total_length = self.accumulate_params("length") + self.total_num_steps = self.accumulate_params("z_num") + + def fiber_map(self): + lengths = self.all_params("length") + return [ + (this[0], following[1]) for this, following in zip(lengths, [(None, 0.0)] + lengths) + ] + + def all_params(self, key: str) -> list[tuple[str, Any]]: + """returns the value of a parameter for each fiber + + Parameters + ---------- + key : str + name of the parameter + + Returns + ------- + list[tuple[str, Any]] + list of (fiber_name, param_value) tuples + """ + return list(reversed(self._all_params(key, []))) + + def accumulate_params(self, key: str) -> Any: + """returns the sum of all the values a parameter takes. Useful to + get the total length of the fiber, the total number of steps, etc. + + Parameters + ---------- + key : str + name of the parameter + + Returns + ------- + Any + final sum + """ + return sum(el[1] for el in self.all_params(key)) + + def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray: + return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind)) + + def _all_params(self, key: str, l: list) -> list: + l.append((self.params.name, getattr(self.params, key))) + if self.previous is not None: + return self.previous._all_params(key, l) + return l + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})" + + def __eq__(self, other: SimulationSeries) -> bool: + return ( + self.path == other.path + and self.params == other.params + and self.previous == other.previous + ) + + def __contains__(self, other: SimulationSeries) -> bool: + if other is self or other == self: + return True + if self.previous is not None: + return other in self.previous class Spectrum(np.ndarray): @@ -129,6 +213,23 @@ class Spectrum(np.ndarray): class Pulse(Sequence): + path: Path + default_ind: Optional[int] + params: Parameters + z: np.ndarray + namx: int + t: np.ndarray + w: np.ndarray + w_order: np.ndarray + + def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse": + try: + if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2: + return super().__new__(LegacyPulse) + except FileNotFoundError: + pass + return super().__new__(cls) + def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): """load a data folder as a pulse @@ -144,36 +245,6 @@ class Pulse(Sequence): FileNotFoundError path does not contain proper data """ - self.logger = get_logger(__name__) - self.path = Path(path) - self.default_ind = default_ind - - if not self.path.is_dir(): - raise FileNotFoundError(f"Folder {self.path} does not exist") - - self.params = Parameters.load(self.path / "params.toml") - self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) - if self.params.fiber_map is None: - self.params.fiber_map = [(0.0, self.params.name)] - - try: - self.z = np.load(os.path.join(path, "z.npy")) - except FileNotFoundError: - if self.params is not None: - self.z = self.params.z_targets - else: - raise - self.nmax = len(list(self.path.glob("spectra_*.npy"))) - if self.nmax <= 0: - raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") - - self.t = self.params.t - w = math.wspace(self.t) + units.m(self.params.wavelength) - self.w_order = np.argsort(w) - self.w = w - self.wl = units.m.inv(self.w) - self.params.w = self.w - self.params.z_targets = self.z def __iter__(self): """ @@ -190,73 +261,6 @@ class Pulse(Sequence): def __getitem__(self, key) -> Spectrum: return self.all_spectra(key) - def intensity(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_int, - FREQ=self._to_freq_int, - AFREQ=self._to_afreq_int, - TIME=self._to_time_int, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_int(self, spectrum): - return units.to_WL(math.abs2(spectrum), spectrum.wl) - - def _to_freq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_afreq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_time_int(self, spectrum): - return math.abs2(np.fft.ifft(spectrum)) - - def amplitude(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_amp, - FREQ=self._to_freq_amp, - AFREQ=self._to_afreq_amp, - TIME=self._to_time_amp, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_amp(self, spectrum): - return ( - np.sqrt( - units.to_WL( - math.abs2(spectrum), - spectrum.wl, - ) - ) - * spectrum - / np.abs(spectrum) - ) - - def _to_freq_amp(self, spectrum): - return spectrum - - def _to_afreq_amp(self, spectrum): - return spectrum - - def _to_time_amp(self, spectrum): - return np.fft.ifft(spectrum) - def all_spectra(self, ind=None) -> Spectrum: """ loads the data already simulated. @@ -305,12 +309,7 @@ class Pulse(Sequence): return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) def _load1(self, i: int): - if i < 0: - i = self.nmax + i - spec = load_spectrum(self.path / SPECN_FN.format(i)) - spec = np.atleast_2d(spec) - spec = Spectrum(spec, self.params) - return spec + pass def plot_2D( self, @@ -412,3 +411,46 @@ class Pulse(Sequence): index """ return math.argclosest(self.z, z) + + +class LegacyPulse(Pulse): + def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): + print("old init called", path, default_ind) + self.logger = get_logger(__name__) + self.path = Path(path) + self.default_ind = default_ind + + if not self.path.is_dir(): + raise FileNotFoundError(f"Folder {self.path} does not exist") + + self.params = Parameters.load(self.path / "params.toml") + self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) + if self.params.fiber_map is None: + self.params.fiber_map = [(0.0, self.params.name)] + + try: + self.z = np.load(os.path.join(path, "z.npy")) + except FileNotFoundError: + if self.params is not None: + self.z = self.params.z_targets + else: + raise + self.nmax = len(list(self.path.glob("spectra_*.npy"))) + if self.nmax <= 0: + raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") + + self.t = self.params.t + w = math.wspace(self.t) + units.m(self.params.wavelength) + self.w_order = np.argsort(w) + self.w = w + self.wl = units.m.inv(self.w) + self.params.w = self.w + self.params.z_targets = self.z + + def _load1(self, i: int): + if i < 0: + i = self.nmax + i + spec = load_spectrum(self.path / SPECN_FN1.format(i)) + spec = np.atleast_2d(spec) + spec = Spectrum(spec, self.params) + return spec diff --git a/testing/long_tests/test_recovery_param_seq.py b/testing/long_tests/test_recovery_param_seq.py deleted file mode 100644 index 5d73367..0000000 --- a/testing/long_tests/test_recovery_param_seq.py +++ /dev/null @@ -1,35 +0,0 @@ -import shutil -import unittest - -import toml -from scgenerator import logger -from send2trash import send2trash - -TMP = "testing/.tmp" - - -class TestRecoveryParamSequence(unittest.TestCase): - def setUp(self): - shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP) - self.conf = toml.load(TMP + "/initial_config.toml") - io.set_data_folder(55, TMP) - - def test_remaining_simulations_count(self): - param_seq = initialize.RecoveryParamSequence(self.conf, 55) - self.assertEqual(5, len(param_seq)) - - def test_only_one_to_complete(self): - param_seq = initialize.RecoveryParamSequence(self.conf, 55) - i = 0 - for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq): - i += 1 - self.assertEqual(expected, "recovery_last_stored" in params) - - self.assertEqual(5, i) - - def tearDown(self): - send2trash(TMP) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_initialize.py b/testing/test_initialize.py deleted file mode 100644 index edd3d90..0000000 --- a/testing/test_initialize.py +++ /dev/null @@ -1,216 +0,0 @@ -import unittest -from copy import deepcopy - -import numpy as np -import toml -from scgenerator import defaults, utils, math -from scgenerator.errors import * -from scgenerator.physics import pulse, units -from scgenerator.utils.parameter import Config, Parameters - - -def load_conf(name): - with open("testing/configs/" + name + ".toml") as file: - conf = toml.load(file) - return conf - - -def conf_maker(folder): - def conf(name): - return load_conf(folder + "/" + name) - - return conf - - -class TestParamSequence(unittest.TestCase): - def iterconf(self, files): - conf = conf_maker("param_sequence") - for path in files: - yield init.ParamSequence(conf(path)) - - def test_no_repeat_in_sub_folder_names(self): - for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): - l = [] - s = [] - for vary_list, _ in utils.required_simulations(param_seq.config): - self.assertNotIn(vary_list, l) - self.assertNotIn(utils.format_variable_list(vary_list), s) - l.append(vary_list) - s.append(utils.format_variable_list(vary_list)) - - def test_no_variations_yields_only_num_and_id(self): - for param_seq in self.iterconf(["no_variations"]): - for vary_list, _ in utils.required_simulations(param_seq.config): - self.assertEqual(vary_list[1][0], "num") - self.assertEqual(vary_list[0][0], "id") - self.assertEqual(2, len(vary_list)) - - -class TestInitializeMethods(unittest.TestCase): - def test_validate_types(self): - conf = lambda s: load_conf("validate_types/" + s) - - with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"): - init.Config(**conf("bad2")) - - with self.assertRaisesRegex(TypeError, "value must be of type "): - init.Config(**conf("bad3")) - - with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"): - init.Config(**conf("bad4")) - - with self.assertRaisesRegex( - TypeError, "'variable intensity_noise' value must be of type " - ): - init.Config(**conf("bad5")) - - with self.assertRaisesRegex(ValueError, "'repeat' must be positive"): - init.Config(**conf("bad6")) - - with self.assertRaisesRegex( - ValueError, "variable parameter 'intensity_noise' must not be empty" - ): - init.Config(**conf("bad7")) - - self.assertIsNone(init.Config(**conf("good")).hr_w) - - def test_ensure_consistency(self): - conf = lambda s: load_conf("ensure_consistency/" + s) - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['t0', 'width'\]' is required and no defaults have been set", - ): - init.Config(**conf("bad1")) - - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set", - ): - init.Config(**conf("bad2")) - - with self.assertRaisesRegex( - MissingParameterError, - r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set", - ): - init.Config(**conf("bad3")) - - with self.assertRaisesRegex( - DuplicateParameterError, - r"got multiple values for parameter 'width'", - ): - init.Config(**conf("bad4")) - - with self.assertRaisesRegex( - MissingParameterError, - r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set", - ): - init.Config(**conf("bad5")) - - with self.assertRaisesRegex( - MissingParameterError, - r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set", - ): - init.Config(**conf("bad6")) - - self.assertLessEqual( - {"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items() - ) - - self.assertIsNone(init.Config(**conf("good4")).gamma) - - self.assertLessEqual( - {"raman_type": "agrawal"}.items(), - init.Config(**conf("good2")).__dict__.items(), - ) - - self.assertLessEqual( - {"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items() - ) - - self.assertLessEqual( - {"capillary_nested": 0, "capillary_resonance_strengths": []}.items(), - init.Config(**conf("good4")).__dict__.items(), - ) - - self.assertLessEqual( - dict(he_mode=(1, 1)).items(), - init.Config(**conf("good5")).__dict__.items(), - ) - - self.assertLessEqual( - dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(), - init.Config(**conf("good5")).__dict__.items(), - ) - - def setup_conf_custom_field(self, path) -> Parameters: - - conf = load_conf(path) - conf = Parameters(**conf) - init.build_sim_grid_in_place(conf) - return conf - - def test_setup_custom_field(self): - d = np.load("testing/configs/custom_field/init_field.npz") - t = d["time"] - field = d["field"] - conf = self.setup_conf_custom_field("custom_field/no_change") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/peak_power") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0) - self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4) - self.assertTrue(result) - self.assertNotAlmostEqual(conf.wavelength, 1593e-9) - - conf = self.setup_conf_custom_field("custom_field/mean_power") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/recover1") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/recover2") - result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field( - conf - ) - self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0) - self.assertTrue(result) - - conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") - result = Parameters(**conf) - self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9) - - conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") - conf.wavelength = 1593e-9 - result = Parameters(**conf) - - conf = load_conf("custom_field/wavelength_shift2") - conf = init.Config(**conf) - for target, (variable, config) in zip( - [1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf) - ): - init.build_sim_grid_in_place(conf) - self.assertAlmostEqual( - units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target - ) - print(config.wavelength, target) - - -if __name__ == "__main__": - conf = conf_maker("validate_types") - - unittest.main() diff --git a/testing/test_pulse.py b/testing/test_pulse.py deleted file mode 100644 index 6a632d4..0000000 --- a/testing/test_pulse.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest -from scgenerator.physics.pulse import conform_pulse_params - - -class TestPulseMethods(unittest.TestCase): - def test_conform_pulse_params(self): - self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6)) - self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6)) - - self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6))) - self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6))) - - with self.assertRaisesRegex( - TypeError, "when soliton number is desired, both gamma and beta2 must be specified" - ): - conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01) - with self.assertRaisesRegex( - TypeError, "when soliton number is desired, both gamma and beta2 must be specified" - ): - conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01) - - self.assertEqual( - 5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6)) - ) - self.assertEqual( - 5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6)) - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_utils.py b/testing/test_utils.py deleted file mode 100644 index ef277eb..0000000 --- a/testing/test_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -import unittest - -import numpy as np -import toml -from scgenerator import utils - - -def load_conf(name): - with open("testing/configs/" + name + ".toml") as file: - conf = toml.load(file) - return conf - - -def conf_maker(folder, val=True): - def conf(name): - if val: - return initialize.Config(**load_conf(folder + "/" + name)) - else: - return initialize.Config(**load_conf(folder + "/" + name)) - - return conf - - -class TestUtilsMethods(unittest.TestCase): - def test_count_variations(self): - conf = conf_maker("count_variations") - - for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]: - self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary"))) - - def test_format_value(self): - values = [ - 122e-6, - True, - ["raman", "ss"], - np.arange(5), - 1.123, - 1.1230001, - 0.002e122, - 12.3456e-9, - ] - s = [ - "0.000122", - "True", - "raman-ss", - "0-1-2-3-4", - "1.123", - "1.1230001", - "2e+119", - "1.23456e-08", - ] - - for value, target in zip(values, s): - self.assertEqual(target, utils.format_value(value)) - - def test_override_config(self): - conf = conf_maker("override", False) - old = conf("initial_config") - new = conf("fiber2") - - over = utils.override_config(vars(new), old) - self.assertNotIn("input_transmission", over.variable) - self.assertIsNone(over.input_transmission) - - -if __name__ == "__main__": - unittest.main() diff --git a/testing/test_variationer.py b/testing/test_variationer.py new file mode 100644 index 0000000..67033e2 --- /dev/null +++ b/testing/test_variationer.py @@ -0,0 +1,54 @@ +from pydantic import main +import scgenerator as sc + + +def test_descriptor(): + # Same branch + var1 = sc.VariationDescriptor( + raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]] + ) + var2 = sc.VariationDescriptor( + raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]] + ) + assert var1.branch.identifier == "b_0" + assert var1.identifier != var1.branch.identifier + assert var1.identifier != var2.identifier + assert var1.branch.identifier == var2.branch.identifier + + # different branch + var3 = sc.VariationDescriptor( + raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]] + ) + assert var1.branch.identifier != var3.branch.identifier + assert var1.formatted_descriptor() != var2.formatted_descriptor() + assert var1.formatted_descriptor() != var3.formatted_descriptor() + + +def test_variationer(): + var = sc.Variationer( + [ + dict(a=[1, 2], num=[0, 1, 2]), + [dict(b=["000", "111"], c=["a", "-1"])], + dict(), + dict(), + [dict(aaa=[True, False], bb=[1, 3])], + ] + ) + assert var.var_num(0) == 6 + assert var.var_num(1) == 12 + assert var.var_num() == 24 + + cfg = dict(bb=None) + branches = set() + for descr in var.iterate(): + assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1]) + branches.add(descr.branch.identifier) + assert len(branches) == 8 + + +def main(): + test_descriptor() + + +if __name__ == "__main__": + main()