diff --git a/src/scgenerator/_utils/legacy.py b/src/scgenerator/_utils/legacy.py index 407e040..e198f55 100644 --- a/src/scgenerator/_utils/legacy.py +++ b/src/scgenerator/_utils/legacy.py @@ -27,84 +27,67 @@ def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str, def convert_sim_folder(path: os.PathLike): - path = Path(path) + path = Path(path).resolve() config_paths, configs = load_config_sequence(path) master_config = dict(name=path.name, Fiber=configs) with open(path / "initial_config.toml", "w") as f: toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder()) - configuration = Configuration(path / "initial_config.toml") - new_fiber_paths: list[Path] = [ - path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs) - ] - for p in new_fiber_paths: - p.mkdir(exist_ok=True) - repeat = configs[0].get("repeat", 1) - + configuration = Configuration(path / "initial_config.toml", final_output_path=path) pbar = PBars(configuration.total_num_steps, "Converting") - old_paths: dict[Path, VariationDescriptor] = { - path / descr.branch.formatted_descriptor(): (descr, param.final_path) - for descr, param in configuration - } + new_paths: dict[VariationDescriptor, Parameters] = dict(configuration) + old_paths: Set[Path] = set() + old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = [] + for descriptor, params in configuration.iterate_single_fiber(-1): + old_path = path / descriptor.branch.formatted_descriptor() + if not Path(old_path).is_dir(): + raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.") + old_paths.add(old_path) + for d in descriptor.iter_parents(): + z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1]) + z_limits = (z_num_start, z_num_start + params.z_num) + old2new.append((old_path, d, new_paths[d], z_limits)) - # create map from old to new path - - pprint(old_paths) - quit() - for p in old_paths: - if not p.is_dir(): - raise FileNotFoundError(f"missing {p} from {path}") processed_paths: Set[Path] = set() - for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0 - vary_parts = old_variation_path.name.split("fiber")[1:] - identifiers = [ - "".join("fiber" + el for el in vary_parts[: i + 1]).strip() - for i in range(len(vary_parts)) - ] - cum_z_num = 0 - for i, (fiber_path, new_identifier) in enumerate(zip(new_fiber_paths, identifiers)): - config = descriptor.update_config(configs[i], i) - new_variation_path = fiber_path / new_identifier - z_num = config["z_num"] - move = new_variation_path not in processed_paths - os.makedirs(new_variation_path, exist_ok=True) - processed_paths.add(new_variation_path) + processed_specs: Set[VariationDescriptor] = set() - for spec_num in range(cum_z_num, cum_z_num + z_num): - old_spec = old_variation_path / SPECN_FN1.format(spec_num) - if move: - spec_data = np.load(old_spec) - for j, spec1 in enumerate(spec_data): - if j == 0: - np.save( - new_variation_path / SPEC1_FN.format(spec_num - cum_z_num), spec1 - ) - else: - np.save( - new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j), - spec1, - ) - pbar.update() - else: - pbar.update(value=repeat) - old_spec.unlink() - if move: - if i > 0: - config["prev_data_dir"] = str( - (new_fiber_paths[i - 1] / identifiers[i - 1]).resolve() - ) - params = Parameters(**config) - params.compute() - save_parameters(params.prepare_for_dump(), new_variation_path) - cum_z_num += z_num - (old_variation_path / PARAM_FN).unlink() - (old_variation_path / Z_FN).unlink() - old_variation_path.rmdir() + for old_path, descr, new_params, (start_z, end_z) in old2new: + move_specs = descr not in processed_specs + processed_specs.add(descr) + if (parent := descr.parent) is not None: + new_params.prev_data_dir = str(new_paths[parent].final_path) + save_parameters(new_params.prepare_for_dump(), new_params.final_path) + for spec_num in range(start_z, end_z): + old_spec = old_path / SPECN_FN1.format(spec_num) + if move_specs: + _mv_specs(pbar, new_params, start_z, spec_num, old_spec) + old_spec.unlink() + if old_path not in processed_paths: + (old_path / PARAM_FN).unlink() + (old_path / Z_FN).unlink() + processed_paths.add(old_path) + + for old_path in processed_paths: + old_path.rmdir() for cp in config_paths: cp.unlink() +def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path): + os.makedirs(new_params.final_path, exist_ok=True) + spec_data = np.load(old_spec) + for j, spec1 in enumerate(spec_data): + if j == 0: + np.save(new_params.final_path / SPEC1_FN.format(spec_num - start_z), spec1) + else: + np.save( + new_params.final_path / SPEC1_FN_N.format(spec_num - start_z, j), + spec1, + ) + pbar.update() + + def main(): convert_sim_folder(sys.argv[1]) diff --git a/src/scgenerator/_utils/parameter.py b/src/scgenerator/_utils/parameter.py index 2d871ac..0f463d5 100644 --- a/src/scgenerator/_utils/parameter.py +++ b/src/scgenerator/_utils/parameter.py @@ -12,7 +12,18 @@ from copy import copy, deepcopy from dataclasses import asdict, dataclass, fields from functools import cache, lru_cache from pathlib import Path -from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union +from typing import ( + Any, + Callable, + Generator, + Iterable, + Iterator, + Literal, + Optional, + Sequence, + TypeVar, + Union, +) import numpy as np from numpy.lib import isin @@ -523,7 +534,7 @@ class Parameters(_AbstractParameters): @property def final_path(self) -> Path: if self.output_path is not None: - return update_path(self.output_path) + return Path(update_path(self.output_path)) return None @@ -820,22 +831,26 @@ class Configuration: def __init__( self, - final_config_path: os.PathLike, + config_path: os.PathLike, overwrite: bool = True, wait: bool = False, skip_callback: Callable[[int], None] = None, + final_output_path: os.PathLike = None, ): self.logger = get_logger(__name__) self.wait = wait self.overwrite = overwrite - self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) + self.final_path, self.fiber_configs = utils.load_config_sequence(config_path) + self.final_path = env.get(env.OUTPUT_PATH, self.final_path) + if final_output_path is not None: + self.final_path = final_output_path self.final_path = utils.ensure_folder( - Path(env.get(env.OUTPUT_PATH, self.final_path)), + Path(self.final_path), mkdir=False, prevent_overwrite=not self.overwrite, ) - self.master_config = self.fiber_configs[0] + self.master_config = self.fiber_configs[0].copy() self.name = self.final_path.name self.z_num = 0 self.total_num_steps = 0 @@ -874,7 +889,7 @@ class Configuration: self.parallel = self.master_config.get("parallel", Parameters.parallel.default) def __build_base_config(self): - cfg = self.fiber_configs[0].copy() + cfg = self.master_config.copy() vary = cfg.pop("variable", {}) return cfg | {k: v[0] for k, v in vary.items()} @@ -887,15 +902,11 @@ class Configuration: if len(v) == 0: raise ValueError(f"variable parameter {k!r} must not be empty") - def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]: + def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]: for i in range(self.num_fibers): - for sim_config in self.iterate_single_fiber(i): - params = Parameters(**sim_config.config) - yield sim_config.descriptor, params + yield from self.iterate_single_fiber(i) - def iterate_single_fiber( - self, index: int - ) -> Generator["Configuration.__SimConfig", None, None]: + def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]: """iterates through the parameters of only one fiber. It takes care of recovering partially completed simulations, skipping complete ones and waiting for the previous fiber to finish @@ -909,6 +920,8 @@ class Configuration: __SimConfig configuration obj """ + if index < 0: + index = self.num_fibers + index sim_dict: dict[Path, Configuration.__SimConfig] = {} for descriptor in self.variationer.iterate(index): cfg = descriptor.update_config(self.fiber_configs[index]) @@ -929,7 +942,7 @@ class Configuration: task, config_dict = self.__decide(sim_config) if task == self.Action.RUN: sim_dict.pop(data_dir) - yield sim_config + yield sim_config.descriptor, Parameters(**sim_config.config) if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break diff --git a/src/scgenerator/_utils/variationer.py b/src/scgenerator/_utils/variationer.py index 8ca2e40..cbd5e8c 100644 --- a/src/scgenerator/_utils/variationer.py +++ b/src/scgenerator/_utils/variationer.py @@ -121,6 +121,19 @@ class VariationDescriptor(BaseModel): _format_registry: dict[str, Callable[..., str]] = {} __ids: dict[int, int] = {} + @classmethod + def register_formatter(cls, p_name: str, func: Callable[..., str]): + """register a function that formats a particular parameter + + Parameters + ---------- + p_name : str + name of the parameter + func : Callable[..., str] + function that takes as single argument the value of the parameter and returns a string + """ + cls._format_registry[p_name] = func + class Config: allow_mutation = False @@ -152,19 +165,6 @@ class VariationDescriptor(BaseModel): self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name ) - @classmethod - def register_formatter(cls, p_name: str, func: Callable[..., str]): - """register a function that formats a particular parameter - - Parameters - ---------- - p_name : str - name of the parameter - func : Callable[..., str] - function that takes as single argument the value of the parameter and returns a string - """ - cls._format_registry[p_name] = func - def format_value(self, name: str, value) -> str: if value is True or value is False: return str(value) @@ -201,9 +201,15 @@ class VariationDescriptor(BaseModel): def __ge__(self, other: "VariationDescriptor") -> bool: return self.raw_descr >= other.raw_descr + def __eq__(self, other: "VariationDescriptor") -> bool: + return self.raw_descr == other.raw_descr + def __hash__(self) -> int: return hash(self.raw_descr) + def __contains__(self, other: "VariationDescriptor") -> bool: + return all(el in self.raw_descr for el in other.raw_descr) + def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]: """updates a dictionary with the value of the descriptor @@ -223,6 +229,11 @@ class VariationDescriptor(BaseModel): out_cfg.pop("variable", None) return out_cfg | {k: v for k, v in self.raw_descr[index]} + def iter_parents(self) -> Iterator["VariationDescriptor"]: + if (p := self.parent) is not None: + yield from p.iter_parents() + yield self + @property def flat(self) -> list[tuple[str, Any]]: out = [] @@ -260,6 +271,10 @@ class VariationDescriptor(BaseModel): raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator ) + @property + def num_fibers(self) -> int: + return len(self.raw_descr) + class BranchDescriptor(VariationDescriptor): __ids: dict[int, int] = {} diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 6a9e376..2072806 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,20 +1,17 @@ from __future__ import annotations import os -import warnings -from collections.abc import Sequence from pathlib import Path -from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union +from typing import Any, Callable, Iterator, Optional, Union import matplotlib.pyplot as plt import numpy as np -from pydantic import BaseModel, DirectoryPath, root_validator from . import math from ._utils import load_spectrum from ._utils.parameter import Parameters from ._utils.utils import PlotRange, iter_simulations -from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1 +from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N from .logger import get_logger from .physics import pulse, units from .plotting import ( @@ -111,7 +108,7 @@ class Spectrum(np.ndarray): return self.params.l[np.argmax(self.wl_int, axis=-1)] return np.array([s.wl_max for s in self]) - def mask_wl(self, pos: float, width: float) -> "Spectrum": + def mask_wl(self, pos: float, width: float) -> Spectrum: return self * np.exp( -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2) ) @@ -353,300 +350,3 @@ class SimulationSeries: return self.spectra(*key) else: return self.spectra(key, None) - - -class Pulse(Sequence): - def __new__(cls, path: os.PathLike): - warnings.warn( - "You are using the legacy version of the pulse loader. " - "Please consider updating your data with scgenerator.convert_sim_folder " - "and loading data with the SimulationSeries class" - ) - if (Path(path) / SPECN_FN1.format(0)).exists(): - return LegacyPulse(path) - return SimulationSeries(path) - - def __getitem__(self, key) -> Spectrum: - raise NotImplementedError() - - -class LegacyPulse(Sequence): - def __init__(self, path: os.PathLike): - """load a data folder as a pulse - - Parameters - ---------- - path : os.PathLike - path to the data (folder containing .npy files) - default_ind : int | Iterable[int], optional - default indices to be loaded, by default None - - Raises - ------ - FileNotFoundError - path does not contain proper data - """ - self.logger = get_logger(__name__) - self.path = Path(path) - - if not self.path.is_dir(): - raise FileNotFoundError(f"Folder {self.path} does not exist") - - self.params = Parameters.load(self.path / "params.toml") - self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) - if self.params.fiber_map is None: - self.params.fiber_map = [(0.0, self.params.name)] - - try: - self.z = np.load(os.path.join(path, "z.npy")) - except FileNotFoundError: - if self.params is not None: - self.z = self.params.z_targets - else: - raise - self.nmax = len(list(self.path.glob("spectra_*.npy"))) - if self.nmax <= 0: - raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") - - self.t = self.params.t - w = math.wspace(self.t) + units.m(self.params.wavelength) - self.w_order = np.argsort(w) - self.w = w - self.wl = units.m.inv(self.w) - self.params.w = self.w - self.params.z_targets = self.z - - def __iter__(self): - """ - similar to all_spectra but works as an iterator - """ - - self.logger.debug(f"iterating through {self.path}") - for i in range(self.nmax): - yield self._load1(i) - - def __len__(self): - return self.nmax - - def __getitem__(self, key) -> Spectrum: - return self.all_spectra(key) - - def intensity(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_int, - FREQ=self._to_freq_int, - AFREQ=self._to_afreq_int, - TIME=self._to_time_int, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_int(self, spectrum): - return units.to_WL(math.abs2(spectrum), spectrum.wl) - - def _to_freq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_afreq_int(self, spectrum): - return math.abs2(spectrum) - - def _to_time_int(self, spectrum): - return math.abs2(np.fft.ifft(spectrum)) - - def amplitude(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self._to_wl_amp, - FREQ=self._to_freq_amp, - AFREQ=self._to_afreq_amp, - TIME=self._to_time_amp, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - - def _to_wl_amp(self, spectrum): - return ( - np.sqrt( - units.to_WL( - math.abs2(spectrum), - spectrum.wl, - ) - ) - * spectrum - / np.abs(spectrum) - ) - - def _to_freq_amp(self, spectrum): - return spectrum - - def _to_afreq_amp(self, spectrum): - return spectrum - - def _to_time_amp(self, spectrum): - return np.fft.ifft(spectrum) - - def all_spectra(self, ind=None) -> Spectrum: - """ - loads the data already simulated. - defauft shape is (z_targets, n, nt) - - Parameters - ---------- - ind : int or list of int - if only certain spectra are desired - Returns - ---------- - spectra : array of shape (nz, m, nt) - array of complex spectra (pulse at nz positions consisting - of nm simulation on a nt size grid) - """ - - self.logger.debug(f"opening {self.path}") - - # Check if file exists and assert how many z positions there are - - if ind is None: - ind = range(self.nmax) - if isinstance(ind, (int, np.integer)): - ind = [ind] - elif isinstance(ind, (float, np.floating)): - ind = [self.z_ind(ind)] - elif isinstance(ind[0], (float, np.floating)): - ind = [self.z_ind(ii) for ii in ind] - - # Load the spectra - spectra = [] - for i in ind: - spectra.append(self._load1(i)) - spectra = Spectrum(spectra, self.params) - - self.logger.debug(f"all spectra from {self.path} successfully loaded") - if len(ind) == 1: - return spectra[0] - else: - return spectra - - def all_fields(self, ind=None): - return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) - - def _load1(self, i: int): - if i < 0: - i = self.nmax + i - spec = load_spectrum(self.path / SPECN_FN1.format(i)) - spec = np.atleast_2d(spec) - spec = Spectrum(spec, self.params) - return spec - - def plot_2D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - z_pos: Union[int, Iterable[int]] = None, - sim_ind: int = 0, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind) - return propagation_plot(vals, plot_range, self.params, ax, **kwargs) - - def plot_1D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - z_pos: int, - sim_ind: int = 0, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind) - return single_position_plot(vals, plot_range, self.params, ax, **kwargs) - - def plot_mean( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - z_pos: int, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, slice(None)) - return mean_values_plot(vals, plot_range, self.params, ax, **kwargs) - - def retrieve_plot_values( - self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int] - ): - - if plot_range.unit.type == "TIME": - vals = self.all_fields(ind=z_pos) - else: - vals = self.all_spectra(ind=z_pos) - - if sim_ind is None: - return vals - elif z_pos is None: - return vals[:, sim_ind] - else: - return vals[sim_ind] - - def rin_propagation( - self, left: float, right: float, unit: str - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """returns the RIN as function of unit and z - - Parameters - ---------- - left : float - left limit in unit - right : float - right limit in unit - unit : str - unit descriptor - - Returns - ------- - x : np.ndarray, shape (nt,) - x axis - y : np.ndarray, shape (z_num, ) - y axis - rin_prop : np.ndarray, shape (z_num, nt) - RIN - """ - spectra = [] - for spec in np.moveaxis(self.all_spectra(), 1, 0): - x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False) - spectra.append(tmp) - return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1)) - - def z_ind(self, z: float) -> int: - """return the closest z index to the given target - - Parameters - ---------- - z : float - target - - Returns - ------- - int - index - """ - return math.argclosest(self.z, z)