From 5adde638ef61154cd64c12d85bcc37b38cd22a4f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Beno=C3=AEt=20Sierro?=
Date: Tue, 3 Oct 2023 16:52:23 +0200
Subject: [PATCH] basic PropagationCollection

---
 src/scgenerator/spectra.py | 169 ++++++++++++++++++++++++++++---------
 tests/test_io_handlers.py  |  29 ++++++-
 2 files changed, 156 insertions(+), 42 deletions(-)

diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py
index a5bd938..03b375b 100644
--- a/src/scgenerator/spectra.py
+++ b/src/scgenerator/spectra.py
@@ -2,8 +2,10 @@ from __future__ import annotations
 
 import os
 import warnings
+from dataclasses import dataclass
+from functools import cached_property
 from pathlib import Path
-from typing import Callable, Generic, TypeVar, overload
+from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
 
 import numpy as np
 
@@ -21,6 +23,7 @@ from scgenerator.physics import pulse, units
 
 PARAMS_FN = "params.json"
 ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
+logger = get_logger(__name__)
 
 
 class Spectrum(np.ndarray):
@@ -147,6 +150,9 @@ class Spectrum(np.ndarray):
     freq_amp = afreq_amp
 
 
+NO_PARAMS = object()
+
+
 class Propagation(Generic[ParamsOrNone]):
     io: PropagationIOHandler
     parameters: ParamsOrNone
@@ -155,7 +161,7 @@ class Propagation(Generic[ParamsOrNone]):
     def __init__(
         self: Propagation[ParamsOrNone],
         io_handler: PropagationIOHandler,
-        params: ParamsOrNone,
+        params: ParamsOrNone = NO_PARAMS,
     ):
         """
         A propagation is the object that manages IO for one single propagation.
@@ -166,15 +172,18 @@ class Propagation(Generic[ParamsOrNone]):
         ----------
         io : PropagationIOHandler
             object that implements the PropagationIOHandler Protocol.
-        params : Parameters | None
-            simulations parameters. Those will be passed on to spectra when loaded.
-            if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
+        params : Parameters | None, optional
+            - if `Parameters`, those will be passed on to spectra when loaded to create `Spectrum`
+              objects.
+            - if explicitly None, loading spectra will result in normal numpy arrays rather than
+              `Spectrum` obj.
+            - if unspecified, parameters will be lazily loaded the first time the `parameters`
+              attribute is accessed (e.g. loading a spectrum).
         """
         self.io = io_handler
         self._current_index = len(self.io)
-        self.parameters = params
-        if self.parameters is not None:
-            self.z_positions = self.parameters.compute("z_targets")
+        if params is not NO_PARAMS:
+            self.parameters = params
 
     def __len__(self) -> int:
         return self._current_index
@@ -190,24 +199,23 @@ class Propagation(Generic[ParamsOrNone]):
     def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
         if isinstance(key, slice):
             return self._load_slice(key)
-        if isinstance(key, (float, np.floating)):
-            if self.parameters is None:
-                raise ValueError(f"cannot accept float key {key} when parameters is not set")
-            key = math.argclosest(self.z_positions, key)
         elif key < 0:
             self._warn_negative_index(key)
             key = len(self) + key
         array = self.io.load_spectrum(key)
-        if self.parameters is not None:
-            return Spectrum.from_params(array, self.parameters)
-        else:
-            return array
+        return self._spectrumize(array)
 
     def __setitem__(self, key: int, value: np.ndarray):
         if not isinstance(key, int):
             raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
         self.io.save_spectrum(key, np.asarray(value))
 
+    def _spectrumize(self, array: np.ndarray | Spectrum) -> np.ndarray | Spectrum:
+        if self.parameters is not None and not isinstance(array, Spectrum):
+            return Spectrum.from_params(array, self.parameters)
+        else:
+            return array
+
     @overload
     def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
         ...
@@ -220,17 +228,23 @@ class Propagation(Generic[ParamsOrNone]):
         self._warn_negative_index(key.start)
         self._warn_negative_index(key.stop)
         _iter = range(len(self))[key]
-        # if self.parameters is not None:
-        #     out = Spectrum.from_params(
-        #         np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
-        #     )
-        #     for i in _iter:
-        #         out[i] = self.io.load_spectrum(i)
-        # else:
-        out = np.array([self.io.load_spectrum(i) for i in _iter])
-        if self.parameters is not None:
-            out = Spectrum.from_params(out, self.parameters)
-        return out
+        return self._spectrumize(np.array([self.io.load_spectrum(i) for i in _iter]))
+
+    @cached_property
+    def parameters(self):
+        try:
+            return self.load_parameters()
+        except Exception as e:
+            logger.error(str(e))
+            return None
+
+    def load_parameters(self) -> Parameters:
+        params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
+        params.compile_in_place(exhaustive=True, strict=False)
+        for k, v in params.items():
+            if isinstance(v, DataFile):
+                v.io = self.io
+        return params
 
     def append(self, spectrum: np.ndarray):
         self.io.save_spectrum(self._current_index, np.asarray(spectrum))
@@ -246,6 +260,30 @@ class Propagation(Generic[ParamsOrNone]):
         return self._load_slice(slice(None))
 
 
+@dataclass
+class PropagationCollection:
+    propagations: list[Propagation]
+    parameters: Parameters
+    z: np.ndarray
+    t: np.ndarray
+    w: np.ndarray
+    wl: np.ndarray
+
+    @overload
+    def __getitem__(self, key: int) -> Propagation:
+        ...
+
+    @overload
+    def __getitem__(self, key: slice) -> list[Propagation]:
+        ...
+
+    def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
+        return self.propagations[key]
+
+    def __iter__(self) -> Iterator[Propagation]:
+        yield from self.propagations
+
+
 def propagation(
     file_or_params: os.PathLike | Parameters,
     params: Parameters | None = None,
@@ -284,11 +322,12 @@ def propagation(
         file = Path(file_or_params)
 
     if file is not None and file.exists() and params is None:
-        io = ZipFileIOHandler(file)
-        return _open_existing_propagation(io, load_parameters=load_parameters)
+        return open_existing_propagation(file, load_parameters=load_parameters)
 
     if params is None:
-        raise ValueError("Parameters must be specified to create new simulation")
+        raise ValueError(
+            f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
+        )
 
     if file is not None:
         if file.exists() and params is not None:
@@ -307,17 +346,67 @@ def propagation(
         raise e
 
 
-def _open_existing_propagation(
-    io: PropagationIOHandler, load_parameters: bool = True
-) -> Propagation:
-    if not load_parameters:
+def propagation_series(
+    files: Sequence[os.PathLike], index: int | slice | None = None, progress_bar: bool = False
+) -> tuple[Spectrum, PropagationCollection]:
+    """
+    loads an existing sequence of propagations
+
+    Parameters
+    ----------
+    files : Sequence[os.PathLike]
+        sequence (e.g. list) of propagation zip files.
+    index : slice | int | None, optional
+        what index of each propagation to load (e.g. `None`->all spectra, `-1`->last spectrum),
+        by default None
+    progress_bar : bool, optional
+        print a progress bar to stderr as the files are loading (requires tqdm), by default False
+
+    Returns
+    -------
+    tuple[Spectrum, PropagationCollection]
+        loaded spectra and an object to work with propagations **built on the same grid**
+    """
+    if len(files) == 0:
+        raise ValueError("You must provide at least one file to build a propagation series")
+
+    propagations = [open_existing_propagation(f, load_parameters=False) for f in files]
+    parameters = propagations[0].load_parameters()
+
+    rest = propagations
+    if progress_bar:
+        try:
+            from tqdm import tqdm
+        except ModuleNotFoundError as e:
+            # tqdm is optional: log the problem and keep loading without a progress bar
+            logger.error(str(e))
+        else:
+            rest = tqdm(rest)
+
+    spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
+    for prop in propagations:
+        del prop.parameters
+
+    return spectrum, PropagationCollection(
+        propagations,
+        parameters,
+        z=parameters.z_targets[: spectrum.shape[0]],
+        t=spectrum.t,
+        w=spectrum.w_disp,
+        wl=spectrum.wl_disp,
+    )
+
+
+def open_existing_propagation(file: os.PathLike, load_parameters: bool = True) -> Propagation:
+    file = Path(file)
+    if not file.exists():
+        raise FileNotFoundError(f"no propagation found at {file}")
+    io = ZipFileIOHandler(file)
+
+    if load_parameters:
+        return Propagation(io)
+    else:
         return Propagation(io, None)
-    params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
-    params.compile_in_place(exhaustive=True, strict=False)
-    for k, v in params.items():
-        if isinstance(v, DataFile):
-            v.io = io
-    return Propagation(io, params)
 
 
 def _create_new_propagation(
diff --git a/tests/test_io_handlers.py b/tests/test_io_handlers.py
index f8d648c..65198ee 100644
--- a/tests/test_io_handlers.py
+++ b/tests/test_io_handlers.py
@@ -7,7 +7,7 @@ import pytest
 
 from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
 from scgenerator.parameter import Parameters
-from scgenerator.spectra import PARAMS_FN, propagation
+from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
 
 PARAMS = dict(
     name="PM2000D sech pulse",
@@ -25,7 +25,7 @@ PARAMS = dict(
     raman_type="measured",
     quantum_noise=True,
     interpolation_degree=11,
-    z_num=128,
+    z_num=8,
     length=1.5,
     t_num=512,
     dt=5e-15,
@@ -173,3 +173,28 @@ def test_zip_bundle(tmp_path: Path):
 def test_unique_name():
     existing = {"spec.npy", "spec_0.npy"}
     assert unique_name("spec.npy", existing) == "spec_1.npy"
+
+    existing = {"spec", "spec0"}
+    assert unique_name("spec", existing) == "spec_0"
+
+    existing = {"spec", "spec_0"}
+    assert unique_name("spec", existing) == "spec_1"
+
+
+def test_propagation_series(tmp_path: Path):
+    params = Parameters(**PARAMS)
+    with pytest.raises(ValueError):
+        specs, props = propagation_series([])
+
+    flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
+    for i, f in enumerate(flist):
+        params.name = f"prop {i}"
+        prop = propagation(f, params)
+        for _ in range(params.z_num):
+            prop.append(np.zeros(params.t_num, dtype=complex))
+
+    assert set(flist) == set(tmp_path.glob("*.zip"))
+
+    specs, propagations = propagation_series(flist)
+    assert specs.shape == (10, params.z_num, params.t_num)
+    assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))