diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 53ef5df..406062c 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import datetime import json +import os +from dataclasses import dataclass from pathlib import Path -from typing import Sequence +from typing import BinaryIO, Protocol, Sequence +from zipfile import ZipFile +import numpy as np import pkg_resources @@ -11,27 +17,207 @@ def data_file(path: str) -> Path: return Path(pkg_resources.resource_filename("scgenerator", path)) -class DatetimeEncoder(json.JSONEncoder): +class CustomEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.date, datetime.datetime)): return obj.isoformat() + elif isinstance(obj, np.ndarray): + return tuple(obj) + elif isinstance(obj, DataFile): + return obj.__json__() -def decode_datetime_hook(obj): +def _decode_datetime(s: str) -> datetime.date | datetime.datetime | str: + try: + return datetime.datetime.fromisoformat(s) + except Exception: + pass + try: + return datetime.date.fromisoformat(s) + except Exception: + pass + return s + + +def custom_decode_hook(obj): + """ + Some simple, non json-compatible objects are encoded as str, this function reconstructs the + original object and sets it in the decoded json structure + """ for k, v in obj.items(): - if not isinstance(v, str): - continue - try: - dt = datetime.datetime.fromisoformat(v) - except Exception: - try: - dt = datetime.date.fromisoformat(v) - except Exception: - continue - obj[k] = dt + if isinstance(v, str): + obj[k] = _decode_datetime(v) + elif isinstance(v, list): + obj[k] = tuple(v) return obj +class PropagationIOHandler(Protocol): + def __len__(self) -> int: + ... + + def keys(self) -> list[str]: + ... + + def save_spectrum(self, index: int, spectrum: np.ndarray): + ... + + def load_spectrum(self, index: int) -> np.ndarray: + ... + + def save_data(self, name: str, data: bytes): + ... + + def load_data(self, name: str) -> bytes: + ... + + +class MemoryIOHandler: + spectra: dict[int, np.ndarray] + data: dict[str, bytes] + + def __init__(self): + self.spectra = {} + self.data = {} + + def __len__(self) -> int: + return len(self.spectra) + + def keys(self) -> list[str]: + return list(self.data.keys()) + + def save_spectrum(self, index: int, spectrum: np.ndarray): + self.spectra[index] = spectrum + + def load_spectrum(self, index: int) -> np.ndarray: + return self.spectra[index] + + def save_data(self, name: str, data: bytes): + self.data[name] = data + + def load_data(self, name: str) -> bytes: + return self.data[name] + + +class ZipFileIOHandler: + file: BinaryIO + SPECTRUM_FN = "spectra/spectrum_{}.npy" + + def __init__(self, file: os.PathLike): + """ + Create a IO handler to be used in Propagation. + This handler saves spectra as numpy `.npy` files. + + Parameters + ---------- + file : os.PathLike + path to the desired zip file. Will be created if it doesn't exist. + Passing in an already opened file-like object is not supported at the moment + """ + self.file = Path(file) + if not self.file.exists(): + ZipFile(self.file, "w").close() + + def __len__(self) -> int: + with ZipFile(self.file, "r") as zip_file: + i = 0 + while True: + try: + zip_file.open(self.SPECTRUM_FN.format(i)) + except KeyError: + return i + i += 1 + + def keys(self) -> list[str]: + with ZipFile(self.file, "r") as zip_file: + return zip_file.namelist() + + def save_spectrum(self, index: int, spectrum: np.ndarray): + with ZipFile(self.file, "a") as zip_file, zip_file.open( + self.SPECTRUM_FN.format(index), "w" + ) as file: + np.lib.format.write_array(file, spectrum, allow_pickle=False) + + def load_spectrum(self, index: int) -> np.ndarray: + with ZipFile(self.file, "r") as zip_file, zip_file.open( + self.SPECTRUM_FN.format(index), "r" + ) as file: + return np.load(file) + + def save_data(self, name: str, data: bytes): + with ZipFile(self.file, "a") as zip_file, zip_file.open(name, "w") as file: + file.write(data) + + def load_data(self, name: str) -> bytes: + with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file: + return file.read() + + +@dataclass +class DataFile: + """ + Holds information about external files necessary for a simulation. + In the current implementation, only reading data from the file is supported + """ + + prefix: str | None + path: str + io: PropagationIOHandler + + PREFIX_SEP = "::" + + @classmethod + def from_str(cls, s: str, io: PropagationIOHandler | None = None) -> DataFile: + if cls.PREFIX_SEP in s: + prefix, path = s.split(cls.PREFIX_SEP, 1) + else: + prefix = None + path = s + return cls(prefix, path, io) + + @classmethod + def validate(cls, name: str, data: str | DataFile) -> DataFile: + """To be used to automatically construct a DataFile when creating a Parameters obj""" + if isinstance(data, cls): + return data + elif isinstance(data, str): + return cls.from_str(data) + elif isinstance(data, Path): + return cls.from_str(os.fspath(data)) + else: + raise TypeError( + f"{name!r} must be a path or a bundled file specifier, not a {type(data)!r}" + ) + + def __json__(self) -> str: + if self.prefix is None: + return os.fspath(Path(self.path)) + return self.prefix + self.PREFIX_SEP + self.path + + def load_data(self) -> bytes: + if self.prefix is not None and self.io is None: + raise ValueError( + f"a bundled file prefixed with {self.prefix} " + "must have a PropagationIOHandler attached" + ) + if self.io is not None: + return self.io.load_data(self.path) + else: + return Path(self.path).read_bytes() + + +def unique_name(base_name: str, existing: set[str]) -> str: + name = base_name + p = Path(base_name) + base = p.stem + ext = p.suffix + i = 0 + while name in existing: + name = f"{base}_{i}{ext}" + i += 1 + return name + + def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]): if len(left_elements) == 0: left_elements = [""] diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 71318ba..7d36d04 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -4,7 +4,7 @@ import datetime as datetime_module import json import os from copy import copy, deepcopy -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from functools import lru_cache, wraps from math import isnan from pathlib import Path @@ -12,12 +12,10 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV import numpy as np -from scgenerator import utils from scgenerator.const import MANDATORY_PARAMETERS, __version__ from scgenerator.evaluator import Evaluator, EvaluatorError -from scgenerator.io import DatetimeEncoder, decode_datetime_hook +from scgenerator.io import CustomEncoder, DataFile, custom_decode_hook from scgenerator.operators import Qualifier, SpecOperator -from scgenerator.utils import update_path_name T = TypeVar("T") DISPLAY_INFO = {} @@ -77,6 +75,11 @@ def string(name, n): raise ValueError(f"{name!r} must not be empty") +def low_string(name, n): + string(name, n) + return n.lower() + + def in_range_excl(_min, _max): @type_checker(float, int) def _in_range(name, n): @@ -255,9 +258,6 @@ class Parameter: del instance._param_dico[self.name] def __set__(self, instance: Parameters, value): - if instance._frozen: - raise AttributeError("Parameters instance is frozen and can no longer be modified") - if isinstance(value, Parameter): if self.default is not None: instance._param_dico[self.name] = copy(self.default) @@ -288,9 +288,9 @@ class Parameter: except TypeError: is_value = True if is_value: - if self.converter is not None: - v = self.converter(v) - self._validator(self.name, v) + ret_val = self._validator(self.name, v) + if ret_val is not None: + v = ret_val return is_value, v @@ -307,7 +307,6 @@ class Parameters: # root name: str = Parameter(string, default="no name") - output_path: Path = Parameter(type_checker(Path), converter=Path) # fiber input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) @@ -315,10 +314,10 @@ class Parameters: n2: float = Parameter(non_negative(float, int)) chi3: float = Parameter(non_negative(float, int)) loss: str = Parameter(literal("capillary")) - loss_file: str = Parameter(string) + loss_file: DataFile = Parameter(DataFile.validate) effective_mode_diameter: float = Parameter(positive(float, int)) effective_area: float = Parameter(non_negative(float, int)) - effective_area_file: str = Parameter(string) + effective_area_file: DataFile = Parameter(DataFile.validate) numerical_aperture: float = Parameter(in_range_excl(0, 1)) pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1)) @@ -326,7 +325,7 @@ class Parameters: he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1)) fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9)) beta2_coefficients: Iterable[float] = Parameter(num_list) - dispersion_file: str = Parameter(string) + dispersion_file: DataFile = Parameter(DataFile.validate) model: str = Parameter( literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), ) @@ -345,7 +344,7 @@ class Parameters: capillary_nested: int = Parameter(non_negative(int), default=0) # gas - gas_name: str = Parameter(string, converter=str.lower, default="vacuum") + gas_name: str = Parameter(low_string, default="vacuum") pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar")) pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar")) pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar")) @@ -353,7 +352,7 @@ class Parameters: plasma_density: float = Parameter(non_negative(float, int), default=0) # pulse - field_file: str = Parameter(string) + field_file: DataFile = Parameter(DataFile.validate) input_time: np.ndarray = Parameter(type_checker(np.ndarray)) input_field: np.ndarray = Parameter(type_checker(np.ndarray)) repetition_rate: float = Parameter( @@ -380,11 +379,9 @@ class Parameters: # simulation full_field: bool = Parameter(boolean, default=False) integration_scheme: str = Parameter( - literal("erk43", "erk54", "cqe", "sd", "constant"), - converter=str.lower, - default="erk43", + literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43" ) - raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) + raman_type: str = Parameter(literal("measured", "agrawal", "stolen")) raman_fraction: float = Parameter(non_negative(float, int)) spm: bool = Parameter(boolean, default=True) repeat: int = Parameter(positive(int), default=1) @@ -437,7 +434,8 @@ class Parameters: @classmethod def from_json(cls, s: str) -> Parameters: - return cls(**json.loads(s, object_hook=decode_datetime_hook)) + decoded = json.loads(s, object_hook=custom_decode_hook) + return cls(**decoded) @classmethod def load(cls, path: os.PathLike) -> Parameters: @@ -462,18 +460,32 @@ class Parameters: self.__post_init__() def __setattr__(self, k, v): - if self._frozen: + if self._frozen and not k.endswith("_file"): raise AttributeError( f"cannot set attribute to frozen {self.__class__.__name__} instance" ) object.__setattr__(self, k, v) - def copy(self) -> Parameters: - return Parameters(**deepcopy(self.strip_params_dict())) + def items(self) -> Iterator[tuple[str, Any]]: + for k, v in self._param_dico.items(): + if v is None: + continue + yield k, v + + def copy(self, freeze: bool = False) -> Parameters: + """create a deep copy of self. if freeze is True, the returned copy is read-only""" + params = Parameters(**deepcopy(self.strip_params_dict())) + if freeze: + params.freeze() + return params + + def freeze(self): + """render the current instance read-only. This is not reversible""" + self._frozen = True def to_json(self) -> str: d = self.dump_dict() - return json.dumps(d, cls=DatetimeEncoder, default=list) + return json.dumps(d, cls=CustomEncoder, indent=4) def get_evaluator(self): evaluator = Evaluator.default(self.full_field) @@ -611,7 +623,7 @@ class Parameters: "linear_op", "c_to_a_factor", } - types = (np.ndarray, float, int, str, list, tuple, Path) + types = (np.ndarray, float, int, str, list, tuple, Path, DataFile) c = deepcopy if copy else lambda x: x out = {} for key, value in self._param_dico.items(): diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 45a3772..e7aeb8d 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1,3 +1,4 @@ +from io import BytesIO from typing import Iterable, TypeVar import numpy as np @@ -7,6 +8,7 @@ from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d from scgenerator import io +from scgenerator.io import DataFile from scgenerator.math import argclosest, u_nm from scgenerator.physics import materials as mat from scgenerator.physics import units @@ -653,7 +655,7 @@ def saitoh_paramters(pcf_pitch_ratio: float) -> tuple[float, float]: return A, B -def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.ndarray: +def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray: """ loads custom effective area file @@ -669,14 +671,14 @@ def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.nd np.ndarray, shape (n,) wl-dependent effective mode field area """ - data = np.load(effective_area_file) + data = np.load(BytesIO(effective_area_file.load_data())) effective_area = data.get("A_eff", data.get("effective_area")) wl = data["wavelength"] return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l) -def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray]: - disp_file = np.load(dispersion_file) +def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]: + disp_file = np.load(BytesIO(dispersion_file.load_data())) wl_for_disp = disp_file["wavelength"] interp_range = (np.min(wl_for_disp), np.max(wl_for_disp)) D = disp_file["dispersion"] @@ -684,7 +686,7 @@ def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray return wl_for_disp, beta2, interp_range -def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray: +def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray: """ loads a npz loss file that contains a wavelength and a loss entry @@ -700,7 +702,7 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray: np.ndarray, shape (n,) loss in 1/m units """ - loss_data = np.load(loss_file) + loss_data = np.load(BytesIO(loss_file.load_data())) wl = loss_data["wavelength"] loss = loss_data["loss"] return interp1d(wl, loss, fill_value=0, bounds_error=False)(l) diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 07c785e..4ece788 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -12,6 +12,7 @@ n is the number of spectra at the same z position and nt is the size of the time import itertools import os from dataclasses import astuple, dataclass +from io import BytesIO from pathlib import Path from typing import Literal, Tuple, TypeVar @@ -25,6 +26,7 @@ from scipy.optimize._optimize import OptimizeResult from scgenerator import math from scgenerator.defaults import default_plotting +from scgenerator.io import DataFile from scgenerator.physics import units c = 299792458.0 @@ -410,8 +412,9 @@ def interp_custom_field( return field_0 -def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]: - field_data = np.load(field_file) +def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]: + data = field_file.load_data() + field_data = np.load(BytesIO(data)) return field_data["time"], field_data["field"] diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index acd9e51..d6219d6 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,37 +1,32 @@ from __future__ import annotations import os +import warnings from pathlib import Path -from typing import Callable, Iterator, Optional, Union -import matplotlib.pyplot as plt import numpy as np from scgenerator import math -from scgenerator.const import PARAM_FN, SPEC1_FN, SPEC1_FN_N -from scgenerator.logger import get_logger +from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name from scgenerator.parameter import Parameters from scgenerator.physics import pulse, units -from scgenerator.physics.units import PlotRange -from scgenerator.plotting import ( - mean_values_plot, - propagation_plot, - single_position_plot, - transform_1D_values, - transform_2D_propagation, -) -from scgenerator.utils import load_spectrum, load_toml, simulations_list class Spectrum(np.ndarray): - params: Parameters + w: np.ndarray + t: np.ndarray + l: np.ndarray + ifft: Callable[[np.ndarray], np.ndarray] def __new__(cls, input_array, params: Parameters): # Input array is an already formed ndarray instance # We first cast to be our class type obj = np.asarray(input_array).view(cls) # add the new attribute to the created instance - obj.params = params + obj.w = params.compute("w") + obj.t = params.compute("t") + obj.t = params.compute("t") + obj.ifft = params.compute("ifft") # Finally, we must return the newly created object: return obj @@ -40,14 +35,17 @@ class Spectrum(np.ndarray): # see InfoArray.__array_finalize__ for comments if obj is None: return - self.params = getattr(obj, "params", None) + self.w = getattr(obj, "w", None) + self.t = getattr(obj, "t", None) + self.l = getattr(obj, "l", None) + self.ifft = getattr(obj, "ifft", None) def __getitem__(self, key) -> "Spectrum": return super().__getitem__(key) @property def wl_int(self): - return units.to_WL(math.abs2(self), self.params.l) + return units.to_WL(math.abs2(self), self.l) @property def freq_int(self): @@ -59,13 +57,13 @@ class Spectrum(np.ndarray): @property def time_int(self): - return math.abs2(self.params.ifft(self)) + return math.abs2(self.ifft(self)) def amplitude(self, unit): if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.params.w) + x_axis = unit.inv(self.w) else: - x_axis = unit.inv(self.params.t) + x_axis = unit.inv(self.t) order = np.argsort(x_axis) func = dict( @@ -84,7 +82,7 @@ class Spectrum(np.ndarray): np.sqrt( units.to_WL( math.abs2(self), - self.params.l, + self.l, ) ) * self @@ -106,311 +104,115 @@ class Spectrum(np.ndarray): @property def wl_max(self): if self.ndim == 1: - return self.params.l[np.argmax(self.wl_int, axis=-1)] + return self.l[np.argmax(self.wl_int, axis=-1)] return np.array([s.wl_max for s in self]) def mask_wl(self, pos: float, width: float) -> Spectrum: - return self * np.exp( - -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2) - ) + return self * np.exp(-(((self.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)) def measure(self) -> tuple[float, float, float]: - return pulse.measure_field(self.params.t, self.time_amp) + return pulse.measure_field(self.t, self.time_amp) -class SimulationSeries: - """ - SimulationsSeries are the interface the user should use to load and - interact with simulation data. The object loads each fiber of the simulation - into a separate object and exposes convenience methods to make the series behave - as a single fiber. - - It should be noted that the last spectrum of a fiber and the first one of the next - fibers are identical. Therefore, SimulationSeries object will return fewer datapoints - than when manually mergin the corresponding data. - - """ - - path: Path - fibers: list[SimulatedFiber] +class Propagation: + io: PropagationIOHandler params: Parameters - z_indices: list[tuple[int, int]] - fiber_positions: list[tuple[str, float]] + _current_index: int - def __init__(self, path: os.PathLike): + PARAMS_FN = "params.json" + + def __init__( + self, + io_handler: PropagationIOHandler, + params: Parameters | None = None, + bundle_data: bool = False, + ): """ - Create a SimulationSeries + A propagation is the object that manages IO for one single propagation. + It is recommended to use the "memory_propagation" and "zip_propagation" convenience + factories instead of creating this object directly. Parameters ---------- - path : os.PathLike - path to the last fiber of the series - - Raises - ------ - FileNotFoundError - No simulation found in specified directory + io : PropagationIOHandler + object that implements the PropagationIOHandler Protocol. + params : Parameters + simulations parameters. Those will be saved via the """ - self.logger = get_logger() - for self.path in simulations_list(path): - break - else: - raise FileNotFoundError(f"No simulation in {path}") - self.fibers = [SimulatedFiber(self.path)] - while (p := self.fibers[-1].params.prev_data_dir) is not None: - p = Path(p) - if not p.is_absolute(): - p = Path(self.fibers[-1].params.output_path) / p - self.fibers.append(SimulatedFiber(p)) - self.fibers = self.fibers[::-1] + self.io = io_handler + self._current_index = len(self.io) - self.fiber_positions = [(self.fibers[0].params.name, 0.0)] - self.params = Parameters(**self.fibers[0].params.dump_dict(False, False)) - z_targets = list(self.params.z_targets) - self.z_indices = [(0, j) for j in range(self.params.z_num)] - for i, fiber in enumerate(self.fibers[1:]): - self.fiber_positions.append((fiber.params.name, z_targets[-1])) - z_targets += list(fiber.params.z_targets[1:] + z_targets[-1]) - self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)] - self.params.z_targets = np.array(z_targets) - self.params.length = self.params.z_targets[-1] - self.params.z_num = len(self.params.z_targets) + new_params = params is not None + if not new_params: + if bundle_data: + raise ValueError( + "cannot bundle data to existing Propagation. Create a new one instead" + ) + try: + params_data = self.io.load_data(self.PARAMS_FN) + params = Parameters.from_json(params_data.decode()) + except KeyError: + raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None - def spectra( - self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 - ) -> Spectrum: - ... - if z_descr is None: - out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices] - else: - if isinstance(z_descr, (float, np.floating)): - fib_ind, z_ind = self.z_ind(z_descr) - else: - fib_ind, z_ind = self.z_indices[z_descr] - out = self.fibers[fib_ind].spectra(z_ind, sim_ind) - return Spectrum(out, self.params) + self.params = params - def fields( - self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 - ) -> Spectrum: - return self.params.ifft(self.spectra(z_descr, sim_ind)) + if bundle_data: + if not params._frozen: + warnings.warn( + f"Parameters instance {params.name!r} is not frozen but will be saved " + "in an unmodifiable state anyway." + ) + _bundle_external_files(self.params, self.io) - def z_ind(self, pos: float) -> tuple[int, int]: - if 0 <= pos <= self.params.length: - ind = np.argmin(np.abs(self.params.z_targets - pos)) - return self.z_indices[ind] - else: - raise ValueError(f"cannot match z={pos} with max length of {self.params.length}") + if new_params: + self.io.save_data(self.PARAMS_FN, self.params.to_json().encode()) - def plot_2D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - sim_ind: int = 0, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, None, sim_ind) - return propagation_plot(vals, plot_range, self.params, ax, **kwargs) + def __len__(self) -> int: + return self._current_index - def plot_values_2D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - sim_ind: int = 0, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, None, sim_ind) - return transform_2D_propagation(vals, plot_range, self.params, **kwargs) + def __getitem__(self, key: int | slice) -> Spectrum: + if isinstance(key, slice): + return self._load_slice(key) + if isinstance(key, (float, np.floating)): + key = math.argclosest(self.params.compute("z_targets"), key) + array = self.io.load_spectrum(key) + return Spectrum(array, self.params) - def plot_1D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - z_pos: int, - sim_ind: int = 0, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind) - return single_position_plot(vals, plot_range, self.params, ax, **kwargs) + def __setitem__(self, key: int, value: np.ndarray): + if not isinstance(key, int): + raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}") + self.io.save_spectrum(key, np.asarray(value)) - def plot_values_1D( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - z_pos: int, - sim_ind: int = 0, - **kwargs, - ) -> tuple[np.ndarray, np.ndarray]: - """ - gives the desired values already tranformes according to the give range + def _load_slice(self, key: slice) -> Spectrum: + _iter = range(len(self))[key] + out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params) + for i in _iter: + out[i] = self.io.load_spectrum(i) + return out - Parameters - ---------- - left : float - leftmost limit in unit - right : float - rightmost limit in unit - unit : Union[Callable[[float], float], str] - unit - z_pos : Union[int, float] - position either as an index (int) or a real position (float) - sim_ind : Optional[int] - which simulation to take when more than one are present + def append(self, spectrum: np.ndarray): + self.io.save_spectrum(self._current_index, spectrum.asarray()) + self._current_index += 1 - Returns - ------- - np.ndarray - x axis - np.ndarray - y values - """ - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind) - return transform_1D_values(vals, plot_range, self.params, **kwargs) - - def plot_mean( - self, - left: float, - right: float, - unit: Union[Callable[[float], float], str], - ax: plt.Axes, - z_pos: int, - **kwargs, - ): - plot_range = PlotRange(left, right, unit) - vals = self.retrieve_plot_values(plot_range, z_pos, None) - return mean_values_plot(vals, plot_range, self.params, ax, **kwargs) - - def retrieve_plot_values( - self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int] - ): - if plot_range.unit.type == "TIME": - return self.fields(z_pos, sim_ind) - else: - return self.spectra(z_pos, sim_ind) - - def rin_propagation( - self, left: float, right: float, unit: str - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - returns the RIN as function of unit and z - - Parameters - ---------- - left : float - left limit in unit - right : float - right limit in unit - unit : str - unit descriptor - - Returns - ------- - x : np.ndarray, shape (nt,) - x axis - y : np.ndarray, shape (z_num, ) - y axis - rin_prop : np.ndarray, shape (z_num, nt) - RIN - """ - spectra = [] - for spec in np.moveaxis(self.spectra(None, None), 1, 0): - x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False) - spectra.append(tmp) - return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1)) - - # Magic methods - - def __iter__(self) -> Iterator[Spectrum]: - for i, j in self.z_indices: - yield self.fibers[i].spectra(j, None) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(path={self.path})" - - def __eq__(self, other: SimulationSeries) -> bool: - return self.path == other.path and self.params == other.params - - def __contains__(self, fiber: SimulatedFiber) -> bool: - return fiber in self.fibers - - def __getitem__(self, key) -> Spectrum: - if isinstance(key, tuple): - return self.spectra(*key) - else: - return self.spectra(key, None) + def load_all(self) -> Spectrum: + return self._load_slice(slice(None)) -class SimulatedFiber: - params: Parameters - t: np.ndarray - w: np.ndarray +def _bundle_external_files(params: Parameters, io: PropagationIOHandler): + """copies every external file specified in the parameters and saves it""" + existing_files = set(io.keys()) + for _, value in params.items(): + if isinstance(value, DataFile): + data = value.load_data() - def __init__(self, path: os.PathLike): - self.path = Path(path) - self.params = Parameters(**load_toml(self.path / PARAM_FN)) - self.params.output_path = str(self.path.resolve()) - self.t = self.params.t - self.w = self.params.w - self.z = self.params.z_targets + value.io = io + value.prefix = "zip" + value.path = unique_name(Path(value.path).name, existing_files) + existing_files.add(value.path) - def spectra( - self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 - ) -> np.ndarray: - if z_descr is None: - out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)] - else: - if isinstance(z_descr, (float, np.floating)): - return self.spectra(self.z_ind(z_descr), sim_ind) - else: - z_ind = z_descr + io.save_data(value.path, data) - if z_ind < 0: - z_ind = self.params.z_num + z_ind - if sim_ind is None: - out = [self._load_1(z_ind, i) for i in range(self.params.repeat)] - else: - out = self._load_1(z_ind) - return Spectrum(out, self.params) - - def z_ind(self, pos: float) -> int: - if 0 <= pos <= self.z[-1]: - return np.argmin(np.abs(self.z - pos)) - else: - raise ValueError(f"cannot match z={pos} with max length of {self.params.length}") - - def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray: - """ - loads a spectrum file - - Parameters - ---------- - z_ind : int - z_index relative to the entire simulation - sim_ind : int, optional - simulation index, used when repeated simulations with same parameters are ran, - by default 0 - - Returns - ------- - np.ndarray - loaded spectrum file - """ - if sim_ind > 0: - return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind)) - else: - return load_spectrum(self.path / SPEC1_FN.format(z_ind)) - psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(path={self.path})" +def load_all(path: os.PathLike) -> Spectrum: + io = ZipFileIOHandler(path) + return Propagation(io).load_all() diff --git a/tests/data/PM2000D_2 extrapolated 4 0.npz b/tests/data/PM2000D_2 extrapolated 4 0.npz new file mode 100644 index 0000000..1a9da5e Binary files /dev/null and b/tests/data/PM2000D_2 extrapolated 4 0.npz differ diff --git a/tests/data/PM2000D_A_eff_marcuse.npz b/tests/data/PM2000D_A_eff_marcuse.npz new file mode 100644 index 0000000..41b572e Binary files /dev/null and b/tests/data/PM2000D_A_eff_marcuse.npz differ diff --git a/tests/test_io_handlers.py b/tests/test_io_handlers.py new file mode 100644 index 0000000..4492441 --- /dev/null +++ b/tests/test_io_handlers.py @@ -0,0 +1,124 @@ +import json +from pathlib import Path +from zipfile import ZipFile + +import numpy as np +import pytest + +from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name +from scgenerator.parameter import Parameters +from scgenerator.spectra import Propagation + +PARAMS = dict( + name="PM2000D sech pulse", + # pulse + wavelength=1546e-9, + # field_file="./Pos30000New.npz", + width=1.5e-12, + shape="sech", + repetition_rate=40e6, + # fiber + dispersion_file="./PM2000D_2 extrapolated 4 0.npz", + effective_area_file="./PM2000D_A_eff_marcuse.npz", + wavelength_window=(400e-9, 4000e-9), + n2=4.5e-20, + # simulation + raman_type="measured", + quantum_noise=True, + interpolation_degree=11, + z_num=128, + length=1.5, + t_num=512, + dt=5e-15, +) + + +def test_file(tmp_path: Path): + params = Parameters(**PARAMS) + stuff = np.random.rand(8, 512) + io = ZipFileIOHandler(tmp_path / "test.zip") + io.save_data("params.json", params.to_json().encode()) + for i, spec in enumerate(stuff): + io.save_spectrum(i, spec) + + new_params = Parameters.from_json(io.load_data("params.json").decode()) + assert new_params is not params + for k, v in params.items(): + v_new = getattr(new_params, k) + if isinstance(v, DataFile): + assert Path(v.path).name == Path(v_new.path).name + else: + assert v == getattr(new_params, k) + + for i in range(8): + assert np.all(io.load_spectrum(i) == stuff[i]) + + assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8 + + +def test_memory(): + params = Parameters(**PARAMS) + stuff = np.random.rand(8, 512) + io = MemoryIOHandler() + assert len(io) == 0 + io.save_data("params.json", params.to_json().encode()) + for i, spec in enumerate(stuff): + io.save_spectrum(i, spec) + + new_params = Parameters.from_json(io.load_data("params.json").decode()) + for k, v in params.items(): + v_new = getattr(new_params, k) + if isinstance(v, DataFile): + assert Path(v.path).name == Path(v_new.path).name + else: + assert v == getattr(new_params, k) + for i in range(8): + assert np.all(io.load_spectrum(i) == stuff[i]) + + assert len(io) == 8 + + +def test_zip_bundle(tmp_path: Path): + params = Parameters(**PARAMS) + io = ZipFileIOHandler(tmp_path / "file.zip") + prop = Propagation(io, params.copy(True)) + + assert (tmp_path / "file.zip").exists() + assert (tmp_path / "file.zip").read_bytes() != b"" + + new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz") + new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz") + params.dispersion_file.path = str(new_disp_path) + params.effective_area_file.path = str(new_aeff_path) + params.freeze() + + io = ZipFileIOHandler(tmp_path / "file2.zip") + prop2 = Propagation(io, params, True) + + assert params.dispersion_file.path == new_disp_path.name + assert params.dispersion_file.prefix == "zip" + assert params.effective_area_file.path == new_aeff_path.name + assert params.effective_area_file.prefix == "zip" + + with ZipFile(tmp_path / "file2.zip", "r") as zfile: + with zfile.open(new_aeff_path.name) as file: + assert file.read() == new_aeff_path.read_bytes() + with zfile.open(new_disp_path.name) as file: + assert file.read() == new_disp_path.read_bytes() + + with zfile.open(Propagation.PARAMS_FN) as file: + df = json.loads(file.read().decode()) + + assert ( + df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path + ) + + assert ( + df["effective_area_file"] + == params.effective_area_file.prefix + "::" + params.effective_area_file.path + ) + + +def test_unique_name(): + existing = {"spec.npy", "spec_0.npy"} + assert unique_name("spec.npy", existing) == "spec_1.npy"