diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 0976b8a..937e38a 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,9 +1,8 @@ from __future__ import annotations import os -import warnings from pathlib import Path -from typing import Callable +from typing import Callable, Generic, TypeVar, overload import numpy as np @@ -15,10 +14,12 @@ from scgenerator.io import ( ZipFileIOHandler, unique_name, ) +from scgenerator.logger import get_logger from scgenerator.parameter import Parameters from scgenerator.physics import pulse, units PARAMS_FN = "params.json" +ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None) class Spectrum(np.ndarray): @@ -138,27 +139,28 @@ class Spectrum(np.ndarray): return pulse.measure_field(self.t, self.time_amp) -class Propagation: +class Propagation(Generic[ParamsOrNone]): io: PropagationIOHandler - parameters: Parameters + parameters: ParamsOrNone _current_index: int def __init__( - self, + self: Propagation[ParamsOrNone], io_handler: PropagationIOHandler, - params: Parameters, + params: ParamsOrNone, ): """ A propagation is the object that manages IO for one single propagation. - It is recommended to use the "memory_propagation" and "zip_propagation" convenience - factories instead of creating this object directly. + It is recommended to use the "propagation" convenience factories instead of creating this + object directly. Parameters ---------- io : PropagationIOHandler object that implements the PropagationIOHandler Protocol. - params : Parameters - simulations parameters. Those will be saved via the + params : Parameters | None + simulations parameters. Those will be passed on to spectra when loaded. + if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj. """ self.io = io_handler self._current_index = len(self.io) @@ -167,26 +169,47 @@ class Propagation: def __len__(self) -> int: return self._current_index - def __getitem__(self, key: int | slice) -> Spectrum: + @overload + def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum: + ... + + @overload + def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray: + ... + + def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray: if isinstance(key, slice): return self._load_slice(key) if isinstance(key, (float, np.floating)): + if self.parameters is None: + raise ValueError(f"cannot accept float key {key} when parameters is not set") key = math.argclosest(self.parameters.compute("z_targets"), key) elif key < 0: key = len(self) + key array = self.io.load_spectrum(key) - return Spectrum(array, self.parameters) + if self.parameters is not None: + return Spectrum(array, self.parameters) + else: + return array def __setitem__(self, key: int, value: np.ndarray): if not isinstance(key, int): raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}") self.io.save_spectrum(key, np.asarray(value)) + @overload + def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: + ... + + @overload + def _load_slice(self: Propagation[None], key: slice) -> np.ndarray: + ... + def _load_slice(self, key: slice) -> Spectrum: _iter = range(len(self))[key] - out = Spectrum( - np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters - ) + out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex) + if self.parameters is not None: + out = Spectrum(out, self.parameters) for i in _iter: out[i] = self.io.load_spectrum(i) return out @@ -204,6 +227,7 @@ def propagation( params: Parameters | None = None, bundle_data: bool = False, overwrite: bool = False, + load_parameters: bool = True, ) -> Propagation: file = None if isinstance(file_or_params, Parameters): @@ -213,7 +237,7 @@ def propagation( if file is not None and file.exists() and params is None: io = ZipFileIOHandler(file) - return _open_existing_propagation(io) + return _open_existing_propagation(io, load_parameters=load_parameters) if params is None: raise ValueError("Parameters must be specified to create new simulation") @@ -235,7 +259,11 @@ def propagation( raise e -def _open_existing_propagation(io: PropagationIOHandler) -> Propagation: +def _open_existing_propagation( + io: PropagationIOHandler, load_parameters: bool = True +) -> Propagation: + if not load_parameters: + return Propagation(io, None) params = Parameters.from_json(io.load_data(PARAMS_FN).decode()) params.compile_in_place(exhaustive=True, strict=False) for k, v in params.items():