From 2a3d222d85ad0d040c8af94a10350814736dcbe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 18 Oct 2021 11:29:17 +0200 Subject: [PATCH] new ConfigFile class --- src/scgenerator/__init__.py | 21 +- src/scgenerator/_utils/utils.py | 260 ------------------ src/scgenerator/{_utils => }/cache.py | 0 src/scgenerator/cli/cli.py | 2 +- src/scgenerator/{_utils => }/legacy.py | 4 +- src/scgenerator/math.py | 2 +- src/scgenerator/{_utils => }/parameter.py | 30 +- src/scgenerator/{_utils => }/pbar.py | 2 +- src/scgenerator/physics/__init__.py | 5 +- src/scgenerator/physics/fiber.py | 8 +- src/scgenerator/physics/materials.py | 11 +- src/scgenerator/physics/simulate.py | 14 +- src/scgenerator/physics/units.py | 74 ++++- src/scgenerator/plotting.py | 6 +- src/scgenerator/scripts/__init__.py | 8 +- src/scgenerator/scripts/slurm_submit.py | 4 +- src/scgenerator/spectra.py | 6 +- .../{_utils/__init__.py => utils.py} | 221 ++++++++++++++- src/scgenerator/{_utils => }/variationer.py | 3 +- 19 files changed, 330 insertions(+), 351 deletions(-) delete mode 100644 src/scgenerator/_utils/utils.py rename src/scgenerator/{_utils => }/cache.py (100%) rename src/scgenerator/{_utils => }/legacy.py (96%) rename src/scgenerator/{_utils => }/parameter.py (98%) rename src/scgenerator/{_utils => }/pbar.py (99%) rename src/scgenerator/{_utils/__init__.py => utils.py} (57%) rename src/scgenerator/{_utils => }/variationer.py (99%) diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index c5f3af3..e4adb8b 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,25 +1,20 @@ from . import math +from .legacy import convert_sim_folder from .math import abs2, argclosest, span +from .parameter import Configuration, Parameters from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation +from .physics.units import PlotRange from .plotting import ( + get_extent, mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot, - transform_2D_propagation, transform_1D_values, + transform_2D_propagation, transform_mean_values, - get_extent, -) -from .spectra import Spectrum, SimulationSeries -from ._utils import Paths, _open_config, parameter, open_single_config -from ._utils.parameter import Configuration, Parameters -from ._utils.utils import PlotRange -from ._utils.legacy import convert_sim_folder -from ._utils.variationer import ( - Variationer, - VariationDescriptor, - VariationSpecsError, - DescriptorDict, ) +from .spectra import SimulationSeries, Spectrum +from .utils import Paths, _open_config, open_single_config +from .variationer import DescriptorDict, VariationDescriptor, Variationer, VariationSpecsError diff --git a/src/scgenerator/_utils/utils.py b/src/scgenerator/_utils/utils.py deleted file mode 100644 index 2931f09..0000000 --- a/src/scgenerator/_utils/utils.py +++ /dev/null @@ -1,260 +0,0 @@ -import inspect -import os -import re -from collections import defaultdict -from functools import cache -from pathlib import Path -from string import printable as str_printable -from typing import Any, Callable, Iterator, Set - -import numpy as np -import toml -from pydantic import BaseModel - -from .._utils import load_toml, save_toml -from ..const import PARAM_FN, PARAM_SEPARATOR, Z_FN -from ..physics.units import get_unit - - -class HashableBaseModel(BaseModel): - """Pydantic BaseModel that's immutable and can be hashed""" - - def __hash__(self) -> int: - return hash(type(self)) + sum(hash(v) for v in self.__dict__.values()) - - class Config: - allow_mutation = False - - -def to_62(i: int) -> str: - arr = [] - if i == 0: - return "0" - i = abs(i) - while i: - i, value = divmod(i, 62) - arr.append(str_printable[value]) - return "".join(reversed(arr)) - - -class PlotRange(HashableBaseModel): - left: float - right: float - unit: Callable[[float], float] - conserved_quantity: bool = True - - def __init__(self, left, right, unit, **kwargs): - super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs) - - def __str__(self): - return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" - - def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - return sort_axis(axis, self) - - def __iter__(self): - yield self.left - yield self.right - yield self.unit.__name__ - - -def sort_axis( - axis: np.ndarray, plt_range: PlotRange -) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - """ - given an axis, returns this axis cropped according to the given range, converted and sorted - - Parameters - ---------- - axis : 1D array containing the original axis (usual the w or t array) - plt_range : tupple (min, max, conversion_function) used to crop the axis - - Returns - ------- - cropped : the axis cropped, converted and sorted - indices : indices to use to slice and sort other array in the same fashion - extent : tupple with min and max of cropped - - Example - ------- - w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) - t = np.linspace(-10, 10, 400) - W, T = np.meshgrid(w, t) - y = np.exp(-W**2 - T**2) - - # Define ranges - rw = (-4, 4, s) - rt = (-2, 6, s) - - w, cw = sort_axis(w, rw) - t, ct = sort_axis(t, rt) - - # slice y according to the given ranges - y = y[ct][:, cw] - """ - if isinstance(plt_range, tuple): - plt_range = PlotRange(*plt_range) - r = np.array((plt_range.left, plt_range.right), dtype="float") - - indices = np.arange(len(axis))[ - (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) - ] - cropped = axis[indices] - order = np.argsort(plt_range.unit.inv(cropped)) - indices = indices[order] - cropped = cropped[order] - out_ax = plt_range.unit.inv(cropped) - - return out_ax, indices, (out_ax[0], out_ax[-1]) - - -def get_arg_names(func: Callable) -> list[str]: - # spec = inspect.getfullargspec(func) - # args = spec.args - # if spec.defaults is not None and len(spec.defaults) > 0: - # args = args[: -len(spec.defaults)] - # return args - return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty] - - -def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - out_func = scope[tmp_name] - out_func.__module__ = "evaluator" - return out_func - - -@cache -def _mock_function(num_args: int, num_returns: int) -> Callable: - arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) - return_str = ", ".join("True" for _ in range(num_returns)) - func_name = f"__mock_{num_args}_{num_returns}" - func_str = f"def {func_name}({arg_str}):\n return {return_str}" - scope = {} - exec(func_str, scope) - out_func = scope[func_name] - out_func.__module__ = "evaluator" - return out_func - - -def combine_simulations(path: Path, dest: Path = None): - """combines raw simulations into one folder per branch - - Parameters - ---------- - path : Path - source of the simulations (must contain u_xx directories) - dest : Path, optional - if given, moves the simulations to dest, by default None - """ - paths: dict[str, list[Path]] = defaultdict(list) - if dest is None: - dest = path - - for p in path.glob("u_*b_*"): - if p.is_dir(): - paths[p.name.split()[1]].append(p) - for l in paths.values(): - l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) - for pulses in paths.values(): - new_path = dest / update_path(pulses[0].name) - os.makedirs(new_path, exist_ok=True) - for num, pulse in enumerate(pulses): - params_ok = False - for file in pulse.glob("*"): - if file.name == PARAM_FN: - if not params_ok: - update_params(new_path, file) - params_ok = True - else: - file.unlink() - elif file.name == Z_FN: - file.rename(new_path / file.name) - elif file.name.startswith("spectr") and num == 0: - file.rename(new_path / file.name) - else: - file.rename(new_path / (file.stem + f"_{num}" + file.suffix)) - pulse.rmdir() - - -def update_params(new_path: Path, file: Path): - params = load_toml(file) - if (p := params.get("prev_data_dir")) is not None: - p = Path(p) - params["prev_data_dir"] = str(p.parent / update_path(p.name)) - params["output_path"] = str(new_path) - save_toml(new_path / PARAM_FN, params) - file.unlink() - - -def save_parameters( - params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN -) -> Path: - """saves a parameter dictionary. Note that is does remove some entries, particularly - those that take a lot of space ("t", "w", ...) - - Parameters - ---------- - params : dict[str, Any] - dictionary to save - destination_dir : Path - destination directory - - Returns - ------- - Path - path to newly created the paramter file - """ - file_path = destination_dir / file_name - os.makedirs(file_path.parent, exist_ok=True) - - # save toml of the simulation - with open(file_path, "w") as file: - toml.dump(params, file, encoder=toml.TomlNumpyEncoder()) - - return file_path - - -def update_path(p: str) -> str: - return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p) - - -def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str: - return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name]) - - -def simulations_list(path: os.PathLike) -> list[Path]: - """finds simulations folders contained in a parent directory - - Parameters - ---------- - path : os.PathLike - parent path - - Returns - ------- - list[Path] - Absolute Path to the simulation folder - """ - paths: list[Path] = [] - for pwd, _, files in os.walk(path): - if PARAM_FN in files: - paths.append(Path(pwd)) - paths.sort(key=lambda el: el.parent.name) - return [p for p in paths if p.parent.name == paths[-1].parent.name] diff --git a/src/scgenerator/_utils/cache.py b/src/scgenerator/cache.py similarity index 100% rename from src/scgenerator/_utils/cache.py rename to src/scgenerator/cache.py diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 4d6ed39..ab6f8fc 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -7,7 +7,7 @@ from pathlib import Path import matplotlib.pyplot as plt import numpy as np -from .. import _utils as utils +from .. import utils from .. import const, env, scripts from ..logger import get_logger from ..physics.fiber import dispersion_coefficients diff --git a/src/scgenerator/_utils/legacy.py b/src/scgenerator/legacy.py similarity index 96% rename from src/scgenerator/_utils/legacy.py rename to src/scgenerator/legacy.py index 35cf4fd..535f7ab 100644 --- a/src/scgenerator/_utils/legacy.py +++ b/src/scgenerator/legacy.py @@ -7,11 +7,11 @@ from typing import Any, Set import numpy as np import toml -from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN +from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1 from .parameter import Configuration, Parameters from .utils import save_parameters from .pbar import PBars -from .variationer import VariationDescriptor, Variationer +from .variationer import VariationDescriptor def load_config(path: os.PathLike) -> dict[str, Any]: diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 4b6f069..adf4bbb 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -3,7 +3,7 @@ from typing import Union import numpy as np from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros -from ._utils.cache import np_cache +from .cache import np_cache pi = np.pi c = 299792458.0 diff --git a/src/scgenerator/_utils/parameter.py b/src/scgenerator/parameter.py similarity index 98% rename from src/scgenerator/_utils/parameter.py rename to src/scgenerator/parameter.py index 90dda50..8f0f6e9 100644 --- a/src/scgenerator/_utils/parameter.py +++ b/src/scgenerator/parameter.py @@ -8,34 +8,22 @@ import os import re import time from collections import defaultdict -from copy import copy, deepcopy +from copy import copy from dataclasses import asdict, dataclass, fields -from functools import cache, lru_cache +from functools import lru_cache from pathlib import Path -from typing import ( - Any, - Callable, - Generator, - Iterable, - Iterator, - Literal, - Optional, - Sequence, - TypeVar, - Union, -) +from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union import numpy as np from numpy.lib import isin -from .. import _utils as utils -from .. import env, math -from .._utils.variationer import VariationDescriptor, Variationer -from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ -from ..errors import EvaluatorError, NoDefaultError -from ..logger import get_logger -from ..physics import fiber, materials, pulse, units +from . import env, math, utils +from .const import PARAM_FN, __version__ +from .errors import EvaluatorError, NoDefaultError +from .logger import get_logger +from .physics import fiber, materials, pulse, units from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path +from .variationer import VariationDescriptor, Variationer T = TypeVar("T") diff --git a/src/scgenerator/_utils/pbar.py b/src/scgenerator/pbar.py similarity index 99% rename from src/scgenerator/_utils/pbar.py rename to src/scgenerator/pbar.py index 37db473..f7fab52 100644 --- a/src/scgenerator/_utils/pbar.py +++ b/src/scgenerator/pbar.py @@ -10,7 +10,7 @@ from typing import Iterable, Union from tqdm import tqdm -from ..env import pbar_policy +from .env import pbar_policy T_ = typing.TypeVar("T_") diff --git a/src/scgenerator/physics/__init__.py b/src/scgenerator/physics/__init__.py index 4048118..f63631d 100644 --- a/src/scgenerator/physics/__init__.py +++ b/src/scgenerator/physics/__init__.py @@ -10,7 +10,8 @@ from scipy.optimize import minimize_scalar from .. import math from . import fiber, materials, units, pulse -from .._utils import cache, load_material_dico +from ..cache import np_cache +from ..utils import load_material_dico T = TypeVar("T") @@ -21,7 +22,7 @@ def group_delay_to_gdd(wavelength: np.ndarray, group_delay: np.ndarray) -> np.nd return gdd -@cache.np_cache +@np_cache def material_dispersion( wavelengths: np.ndarray, material: str, diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 8bac854..0ad35e2 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1,21 +1,19 @@ from typing import Any, Iterable, Literal, TypeVar import numpy as np -from numpy.fft import fft, ifft from numpy import e +from numpy.fft import fft, ifft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d +from .. import utils +from ..cache import np_cache from ..logger import get_logger - -from .. import _utils as utils from ..math import abs2, argclosest, power_fact, u_nm -from .._utils.cache import np_cache from . import materials as mat from . import units from .units import c, pi - pipi = 2 * pi T = TypeVar("T") diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 4d35c65..43dee9a 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -1,13 +1,14 @@ from typing import Any, Callable + import numpy as np import scipy.special from scipy.integrate import cumulative_trapezoid +from .. import utils +from ..cache import np_cache from ..logger import get_logger from . import units -from .. import _utils -from .units import NA, c, kB, me, e, hbar -from .._utils.cache import np_cache +from .units import NA, c, e, hbar, kB, me @np_cache @@ -15,7 +16,7 @@ def n_gas_2( wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature: float, ideal_gas: bool ): """Returns the sqare of the index of refraction of the specified gas""" - material_dico = _utils.load_material_dico(gas_name) + material_dico = utils.load_material_dico(gas_name) if ideal_gas: n_gas_2 = sellmeier(wl_for_disp, material_dico, pressure, temperature) + 1 @@ -218,7 +219,7 @@ def gas_n2(gas_name: str, pressure: float, temperature: float) -> float: float n2 in m2/W """ - return non_linear_refractive_index(_utils.load_material_dico(gas_name), pressure, temperature) + return non_linear_refractive_index(utils.load_material_dico(gas_name), pressure, temperature) def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray: diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 1ee4c1f..83def58 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -7,17 +7,13 @@ from pathlib import Path from typing import Any, Generator, Type, Union import numpy as np -from send2trash import send2trash -from .. import env -from .. import _utils as utils -from .._utils.utils import combine_simulations, save_parameters +from .. import utils from ..logger import get_logger -from .._utils.parameter import Configuration, Parameters -from .._utils.pbar import PBars, ProgressBarActor, progress_worker +from ..parameter import Configuration, Parameters +from ..pbar import PBars, ProgressBarActor, progress_worker from . import pulse from .fiber import create_non_linear_op, fast_dispersion_op -from scgenerator._utils import pbar try: import ray @@ -505,7 +501,7 @@ class Simulations: for variable, params in self.configuration: params.compute() v_list_str = variable.formatted_descriptor(True) - save_parameters(params.prepare_for_dump(), Path(params.output_path)) + utils.save_parameters(params.prepare_for_dump(), Path(params.output_path)) self.new_sim(v_list_str, params) self.finish() @@ -737,7 +733,7 @@ def run_simulation( sim.run() for path in config.fiber_paths: - combine_simulations(path) + utils.combine_simulations(path) def new_simulation( diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index f648bb9..b03823f 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -3,7 +3,7 @@ # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... from typing import Callable, TypeVar, Union - +from operator import itemgetter import numpy as np from numpy import pi @@ -274,3 +274,75 @@ def to_log2D(arr, ref=None): m = arr / ref m = 10 * np.log10(m, out=np.zeros_like(m) - 100, where=m > 0) return m + + +class PlotRange(tuple): + left: float = property(itemgetter(0)) + right: float = property(itemgetter(1)) + unit: Callable[[float], float] = property(itemgetter(2)) + conserved_quantity: bool = property(itemgetter(3)) + __slots__ = [] + + def __new__(cls, left, right, unit, conserved_quantity=True): + return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity)) + + def __iter__(self): + yield self.left + yield self.right + yield self.unit.__name__ + + def __repr__(self): + return "PlotRange" + super().__repr__() + + def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + return sort_axis(axis, self) + + +def sort_axis( + axis: np.ndarray, plt_range: PlotRange +) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + """ + given an axis, returns this axis cropped according to the given range, converted and sorted + + Parameters + ---------- + axis : 1D array containing the original axis (usual the w or t array) + plt_range : tupple (min, max, conversion_function) used to crop the axis + + Returns + ------- + cropped : the axis cropped, converted and sorted + indices : indices to use to slice and sort other array in the same fashion + extent : tupple with min and max of cropped + + Example + ------- + w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) + t = np.linspace(-10, 10, 400) + W, T = np.meshgrid(w, t) + y = np.exp(-W**2 - T**2) + + # Define ranges + rw = (-4, 4, s) + rt = (-2, 6, s) + + w, cw = sort_axis(w, rw) + t, ct = sort_axis(t, rt) + + # slice y according to the given ranges + y = y[ct][:, cw] + """ + if not isinstance(plt_range, PlotRange): + plt_range = PlotRange(*plt_range) + r = np.array((plt_range.left, plt_range.right), dtype="float") + + indices = np.arange(len(axis))[ + (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) + ] + cropped = axis[indices] + order = np.argsort(plt_range.unit.inv(cropped)) + indices = indices[order] + cropped = cropped[order] + out_ax = plt_range.unit.inv(cropped) + + return out_ax, indices, (out_ax[0], out_ax[-1]) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 30afa42..1df1851 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -11,13 +11,13 @@ from scipy.interpolate import UnivariateSpline from scipy.interpolate.interpolate import interp1d from . import math -from ._utils import load_spectrum -from ._utils.parameter import Parameters -from ._utils.utils import PlotRange, sort_axis from .const import PARAM_SEPARATOR, SPEC1_FN from .defaults import default_plotting as defaults from .math import abs2, span +from .parameter import Parameters from .physics import pulse, units +from .physics.units import PlotRange, sort_axis +from .utils import load_spectrum RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 1d5fc8f..a50f34e 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -9,15 +9,11 @@ from tqdm import tqdm from .. import env, math from ..const import PARAM_FN, PARAM_SEPARATOR +from ..parameter import Configuration, Parameters from ..physics import fiber, units from ..plotting import plot_setup from ..spectra import SimulationSeries -from .._utils import auto_crop, _open_config, save_toml, translate_parameters -from .._utils.parameter import ( - Configuration, - Parameters, -) -from .._utils.utils import simulations_list +from ..utils import _open_config, auto_crop, save_toml, simulations_list, translate_parameters def fingerprint(params: Parameters): diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 3fc4b6b..1c134ca 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -9,8 +9,8 @@ from typing import Tuple import numpy as np -from .._utils import Paths -from .._utils.parameter import Configuration +from ..utils import Paths +from ..parameter import Configuration def primes(n): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index eede238..915118a 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -8,18 +8,18 @@ import matplotlib.pyplot as plt import numpy as np from . import math -from ._utils import load_spectrum -from ._utils.parameter import Parameters -from ._utils.utils import PlotRange, simulations_list from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N from .logger import get_logger +from .parameter import Parameters from .physics import pulse, units +from .physics.units import PlotRange from .plotting import ( mean_values_plot, propagation_plot, single_position_plot, transform_2D_propagation, ) +from .utils import load_spectrum, simulations_list class Spectrum(np.ndarray): diff --git a/src/scgenerator/_utils/__init__.py b/src/scgenerator/utils.py similarity index 57% rename from src/scgenerator/_utils/__init__.py rename to src/scgenerator/utils.py index ea0184f..c81a68b 100644 --- a/src/scgenerator/_utils/__init__.py +++ b/src/scgenerator/utils.py @@ -3,33 +3,27 @@ This files includes utility functions designed more or less to be used specifica scgenerator module but some function may be used in any python program """ - from __future__ import annotations - +from dataclasses import dataclass +import inspect import itertools import os -from collections import abc +import re +from collections import defaultdict +from functools import cache from pathlib import Path from string import printable as str_printable -from functools import cache -from typing import Any, MutableMapping, Sequence, TypeVar - +from typing import Any, Callable, MutableMapping, Sequence, TypeVar import numpy as np -from numpy.lib.arraysetops import isin import pkg_resources as pkg import toml -from tqdm import tqdm -import itertools - -from ..const import SPEC1_FN, __version__ -from ..logger import get_logger +from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN +from .logger import get_logger T_ = TypeVar("T_") -PathTree = list[tuple[Path, ...]] - class Paths: _data_files = [ @@ -79,6 +73,42 @@ class Paths: return os.path.join(cls.get("plots"), name) +class ConfigFileParser: + path: Path + repeat: int + master: ConfigFileParser.SubConfig + configs: list[ConfigFileParser.SubConfig] + + @dataclass + class SubConfig: + fixed: dict[str, Any] + variable: dict[str, list] + + def __init__(self, path: os.PathLike): + self.path = Path(path) + fiber_list: list[dict[str, Any]] + if self.path.name.lower().endswith(".toml"): + loaded_config = _open_config(self.path) + fiber_list = loaded_config.pop("Fiber") + else: + loaded_config = dict(name=self.path.name) + fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))] + + if len(fiber_list) == 0: + raise ValueError(f"No fiber in config {self.path}") + final_path = loaded_config.get("name") + configs = [] + for i, params in enumerate(fiber_list): + configs.append(loaded_config | params) + for root_vary, first_vary in itertools.product( + loaded_config["variable"], configs[0]["variable"] + ): + if len(common := root_vary.keys() & first_vary.keys()) != 0: + raise ValueError(f"These variable keys are specified twice : {common!r}") + configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"} + configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1))))) + + def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: prev_data_dir = Path(prev_data_dir) num = find_last_spectrum_num(prev_data_dir) @@ -329,3 +359,166 @@ def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: else: new[old_names.get(k, k)] = v return defaults_to_add | new + + +def to_62(i: int) -> str: + arr = [] + if i == 0: + return "0" + i = abs(i) + while i: + i, value = divmod(i, 62) + arr.append(str_printable[value]) + return "".join(reversed(arr)) + + +def get_arg_names(func: Callable) -> list[str]: + # spec = inspect.getfullargspec(func) + # args = spec.args + # if spec.defaults is not None and len(spec.defaults) > 0: + # args = args[: -len(spec.defaults)] + # return args + return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty] + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + out_func = scope[tmp_name] + out_func.__module__ = "evaluator" + return out_func + + +@cache +def _mock_function(num_args: int, num_returns: int) -> Callable: + arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) + return_str = ", ".join("True" for _ in range(num_returns)) + func_name = f"__mock_{num_args}_{num_returns}" + func_str = f"def {func_name}({arg_str}):\n return {return_str}" + scope = {} + exec(func_str, scope) + out_func = scope[func_name] + out_func.__module__ = "evaluator" + return out_func + + +def combine_simulations(path: Path, dest: Path = None): + """combines raw simulations into one folder per branch + + Parameters + ---------- + path : Path + source of the simulations (must contain u_xx directories) + dest : Path, optional + if given, moves the simulations to dest, by default None + """ + paths: dict[str, list[Path]] = defaultdict(list) + if dest is None: + dest = path + + for p in path.glob("u_*b_*"): + if p.is_dir(): + paths[p.name.split()[1]].append(p) + for l in paths.values(): + l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) + for pulses in paths.values(): + new_path = dest / update_path(pulses[0].name) + os.makedirs(new_path, exist_ok=True) + for num, pulse in enumerate(pulses): + params_ok = False + for file in pulse.glob("*"): + if file.name == PARAM_FN: + if not params_ok: + update_params(new_path, file) + params_ok = True + else: + file.unlink() + elif file.name == Z_FN: + file.rename(new_path / file.name) + elif file.name.startswith("spectr") and num == 0: + file.rename(new_path / file.name) + else: + file.rename(new_path / (file.stem + f"_{num}" + file.suffix)) + pulse.rmdir() + + +def update_params(new_path: Path, file: Path): + params = load_toml(file) + if (p := params.get("prev_data_dir")) is not None: + p = Path(p) + params["prev_data_dir"] = str(p.parent / update_path(p.name)) + params["output_path"] = str(new_path) + save_toml(new_path / PARAM_FN, params) + file.unlink() + + +def save_parameters( + params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN +) -> Path: + """saves a parameter dictionary. Note that is does remove some entries, particularly + those that take a lot of space ("t", "w", ...) + + Parameters + ---------- + params : dict[str, Any] + dictionary to save + destination_dir : Path + destination directory + + Returns + ------- + Path + path to newly created the paramter file + """ + file_path = destination_dir / file_name + os.makedirs(file_path.parent, exist_ok=True) + + # save toml of the simulation + with open(file_path, "w") as file: + toml.dump(params, file, encoder=toml.TomlNumpyEncoder()) + + return file_path + + +def update_path(p: str) -> str: + return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p) + + +def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str: + return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name]) + + +def simulations_list(path: os.PathLike) -> list[Path]: + """finds simulations folders contained in a parent directory + + Parameters + ---------- + path : os.PathLike + parent path + + Returns + ------- + list[Path] + Absolute Path to the simulation folder + """ + paths: list[Path] = [] + for pwd, _, files in os.walk(path): + if PARAM_FN in files: + paths.append(Path(pwd)) + paths.sort(key=lambda el: el.parent.name) + return [p for p in paths if p.parent.name == paths[-1].parent.name] diff --git a/src/scgenerator/_utils/variationer.py b/src/scgenerator/variationer.py similarity index 99% rename from src/scgenerator/_utils/variationer.py rename to src/scgenerator/variationer.py index 538ca4f..eb96c49 100644 --- a/src/scgenerator/_utils/variationer.py +++ b/src/scgenerator/variationer.py @@ -8,8 +8,7 @@ import numpy as np from pydantic import validator from pydantic.main import BaseModel -from ..const import PARAM_SEPARATOR -from . import utils +from .const import PARAM_SEPARATOR T = TypeVar("T")