diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 043be3d..4b9812e 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,8 +1,10 @@ # ruff: noqa from scgenerator import io, math, operators, plotting from scgenerator.helpers import * +from scgenerator.io import MemoryIOHandler, ZipFileIOHandler from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace from scgenerator.parameter import Parameters from scgenerator.physics import fiber, materials, plasma, pulse, units from scgenerator.physics.units import PlotRange from scgenerator.solver import SimulationResult, integrate, solve43 +from scgenerator.spectra import Propagation, Spectrum diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 470ec17..87992b3 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -33,6 +33,7 @@ MANDATORY_PARAMETERS = { "name", "w", "t", + "l", "fft", "ifft", "w0", diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 9876486..6bc790d 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -72,6 +72,9 @@ class PropagationIOHandler(Protocol): def load_data(self, name: str) -> bytes: ... + def clear(self): + ... + class MemoryIOHandler: spectra: dict[int, np.ndarray] @@ -99,6 +102,10 @@ class MemoryIOHandler: def load_data(self, name: str) -> bytes: return self.data[name] + def clear(self): + self.spectra = {} + self.data = {} + class ZipFileIOHandler: file: BinaryIO @@ -153,6 +160,9 @@ class ZipFileIOHandler: with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file: return file.read() + def clear(self): + self.file.unlink(missing_ok=True) + @dataclass class DataFile: @@ -201,11 +211,15 @@ class DataFile: f"a bundled file prefixed with {self.prefix} " "must have a PropagationIOHandler attached" ) - if self.io is not None: + # a DataFile obj may have a useless io obj attached to it + if self.prefix is not None: return self.io.load_data(self.path) else: return Path(self.path).read_bytes() + def similar_to(self, other: DataFile) -> bool: + return Path(self.path).name == Path(other.path).name + def unique_name(base_name: str, existing: set[str]) -> str: name = base_name diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 7d36d04..89f3952 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -294,7 +294,7 @@ class Parameter: return is_value, v -@dataclass(repr=False) +@dataclass(repr=False, eq=False) class Parameters: """ This class defines each valid parameter's name, type and valid value. @@ -303,7 +303,7 @@ class Parameters: # internal machinery _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False) _p_names: ClassVar[Set[str]] = set() - _frozen: bool = field(init=False, default=False, repr=False) + frozen: bool = field(init=False, default=False, repr=False) # root name: str = Parameter(string, default="no name") @@ -460,28 +460,45 @@ class Parameters: self.__post_init__() def __setattr__(self, k, v): - if self._frozen and not k.endswith("_file"): + if self.frozen and not k.endswith("_file"): raise AttributeError( f"cannot set attribute to frozen {self.__class__.__name__} instance" ) object.__setattr__(self, k, v) + def __eq__(self, other: Parameters) -> bool: + if not isinstance(other, Parameters): + raise TypeError( + f"cannot compare {self.__class__.__name__!r} with {type(other).__name__!r}" + ) + for k, v in self.items(): + other_v = getattr(other, k) + if isinstance(v, DataFile) and not v.similar_to(other_v): + return False + if other_v != v: + return False + + return True + def items(self) -> Iterator[tuple[str, Any]]: for k, v in self._param_dico.items(): if v is None: continue yield k, v - def copy(self, freeze: bool = False) -> Parameters: + def copy(self, deep: bool = True, freeze: bool = False) -> Parameters: """create a deep copy of self. if freeze is True, the returned copy is read-only""" - params = Parameters(**deepcopy(self.strip_params_dict())) + if deep: + params = Parameters(**deepcopy(self.strip_params_dict())) + else: + params = Parameters(**self.strip_params_dict()) if freeze: params.freeze() return params def freeze(self): """render the current instance read-only. This is not reversible""" - self._frozen = True + self.frozen = True def to_json(self) -> str: d = self.dump_dict() @@ -536,7 +553,7 @@ class Parameters: else: return first - def compile(self, exhaustive=False) -> Parameters: + def compile(self, exhaustive=False, strict: bool = True) -> Parameters: """ Computes missing parameters and returns them in a frozen `Parameters` instance @@ -547,6 +564,8 @@ class Parameters: Depending on the specifics of the model and how the parameters were specified, there might be no difference between a normal compilation and an exhaustive one. by default False + strict : bool, optional + raise an exception when something cannot be computed, by default True Returns ------- @@ -560,16 +579,23 @@ class Parameters: When all the necessary parameters cannot be computed, a `ValueError` is raised. In most cases, this is due to underdetermination by the user. """ + obj = self.copy(deep=False, freeze=False) + obj.compile_in_place(exhaustive, strict) + return obj + + def compile_in_place(self, exhaustive: bool = False, strict: bool = True): to_compute = MANDATORY_PARAMETERS evaluator = self.get_evaluator() - try: - for k in to_compute: + for k in to_compute: + try: evaluator.compute(k) - except EvaluatorError as e: - raise ValueError( - "Could not compile the parameter set. Most likely, " - f"an essential value is missing\n{e}" - ) from None + except EvaluatorError as e: + if strict: + raise ValueError( + "Could not compile the parameter set. Most likely, " + f"an essential value is missing\n{e}" + ) from None + if exhaustive: for p in self._p_names: if p not in evaluator.main_map: @@ -577,11 +603,10 @@ class Parameters: evaluator.compute(p) except Exception: pass - computed = self.__class__( - **{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names} - ) - computed._frozen = True - return computed + self._param_dico |= { + k: v.value for k, v in evaluator.main_map.items() if k in self._p_names + } + self.freeze() def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str: """return a pretty formatted string describing the parameters""" diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index d6219d6..edc89ec 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -3,14 +3,23 @@ from __future__ import annotations import os import warnings from pathlib import Path +from typing import Callable import numpy as np from scgenerator import math -from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name +from scgenerator.io import ( + DataFile, + MemoryIOHandler, + PropagationIOHandler, + ZipFileIOHandler, + unique_name, +) from scgenerator.parameter import Parameters from scgenerator.physics import pulse, units +PARAMS_FN = "params.json" + class Spectrum(np.ndarray): w: np.ndarray @@ -25,7 +34,7 @@ class Spectrum(np.ndarray): # add the new attribute to the created instance obj.w = params.compute("w") obj.t = params.compute("t") - obj.t = params.compute("t") + obj.l = params.compute("l") obj.ifft = params.compute("ifft") # Finally, we must return the newly created object: @@ -116,16 +125,13 @@ class Spectrum(np.ndarray): class Propagation: io: PropagationIOHandler - params: Parameters + parameters: Parameters _current_index: int - PARAMS_FN = "params.json" - def __init__( self, io_handler: PropagationIOHandler, - params: Parameters | None = None, - bundle_data: bool = False, + params: Parameters, ): """ A propagation is the object that manages IO for one single propagation. @@ -141,31 +147,7 @@ class Propagation: """ self.io = io_handler self._current_index = len(self.io) - - new_params = params is not None - if not new_params: - if bundle_data: - raise ValueError( - "cannot bundle data to existing Propagation. Create a new one instead" - ) - try: - params_data = self.io.load_data(self.PARAMS_FN) - params = Parameters.from_json(params_data.decode()) - except KeyError: - raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None - - self.params = params - - if bundle_data: - if not params._frozen: - warnings.warn( - f"Parameters instance {params.name!r} is not frozen but will be saved " - "in an unmodifiable state anyway." - ) - _bundle_external_files(self.params, self.io) - - if new_params: - self.io.save_data(self.PARAMS_FN, self.params.to_json().encode()) + self.parameters = params def __len__(self) -> int: return self._current_index @@ -174,9 +156,9 @@ class Propagation: if isinstance(key, slice): return self._load_slice(key) if isinstance(key, (float, np.floating)): - key = math.argclosest(self.params.compute("z_targets"), key) + key = math.argclosest(self.parameters.compute("z_targets"), key) array = self.io.load_spectrum(key) - return Spectrum(array, self.params) + return Spectrum(array, self.parameters) def __setitem__(self, key: int, value: np.ndarray): if not isinstance(key, int): @@ -185,19 +167,78 @@ class Propagation: def _load_slice(self, key: slice) -> Spectrum: _iter = range(len(self))[key] - out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params) + out = Spectrum( + np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters + ) for i in _iter: out[i] = self.io.load_spectrum(i) return out def append(self, spectrum: np.ndarray): - self.io.save_spectrum(self._current_index, spectrum.asarray()) + self.io.save_spectrum(self._current_index, np.asarray(spectrum)) self._current_index += 1 def load_all(self) -> Spectrum: return self._load_slice(slice(None)) +def load_all(path: os.PathLike) -> Spectrum: + io = ZipFileIOHandler(path) + return Propagation(io).load_all() + + +def propagation( + file_or_params: os.PathLike | Parameters, + params: Parameters | None = None, + bundle_data: bool = False, +) -> Propagation: + file = None + if isinstance(file_or_params, Parameters): + params = file_or_params + else: + file = Path(file_or_params) + + if file is not None and file.exists(): + io = ZipFileIOHandler(file) + return _open_existing_propagation(io) + + if params is None: + raise ValueError("Parameters must be specified to create new simulation") + + if file is not None: + io = ZipFileIOHandler(file) + else: + io = MemoryIOHandler() + + try: + return _create_new_propagation(io, params, bundle_data) + except Exception as e: + io.clear() + raise e + + +def _open_existing_propagation(io: PropagationIOHandler) -> Propagation: + params = Parameters.from_json(io.load_data(PARAMS_FN).decode()) + params.compile_in_place(exhaustive=True, strict=False) + for k, v in params.items(): + if isinstance(v, DataFile): + v.io = io + return Propagation(io, params) + + +def _create_new_propagation( + io: PropagationIOHandler, params: Parameters, bundle_data: bool +) -> Propagation: + if params.frozen: + params = params.copy() + else: + params = params.compile(exhaustive=True, strict=False) + if bundle_data: + _bundle_external_files(params, io) + io.save_data(PARAMS_FN, params.to_json().encode()) + return Propagation(io, params) + + def _bundle_external_files(params: Parameters, io: PropagationIOHandler): """copies every external file specified in the parameters and saves it""" existing_files = set(io.keys()) @@ -211,8 +252,3 @@ def _bundle_external_files(params: Parameters, io: PropagationIOHandler): existing_files.add(value.path) io.save_data(value.path, data) - - -def load_all(path: os.PathLike) -> Spectrum: - io = ZipFileIOHandler(path) - return Propagation(io).load_all() diff --git a/tests/test_io_handlers.py b/tests/test_io_handlers.py index 4492441..e1a0457 100644 --- a/tests/test_io_handlers.py +++ b/tests/test_io_handlers.py @@ -7,7 +7,7 @@ import pytest from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name from scgenerator.parameter import Parameters -from scgenerator.spectra import Propagation +from scgenerator.spectra import PARAMS_FN, propagation PARAMS = dict( name="PM2000D sech pulse", @@ -78,13 +78,34 @@ def test_memory(): assert len(io) == 8 +def test_reopen(tmp_path: Path): + zpath = tmp_path / "file.zip" + params = Parameters(**PARAMS) + prop = propagation(zpath, params) + + prop2 = propagation(zpath) + + assert prop.parameters == prop2.parameters + + +def test_clear(tmp_path: Path): + params = Parameters(**PARAMS) + zpath = tmp_path / "file.zip" + prop = propagation(zpath, params) + + assert zpath.exists() + assert zpath.read_bytes() != b"" + + prop.io.clear() + + assert not zpath.exists() + + def test_zip_bundle(tmp_path: Path): params = Parameters(**PARAMS) - io = ZipFileIOHandler(tmp_path / "file.zip") - prop = Propagation(io, params.copy(True)) - assert (tmp_path / "file.zip").exists() - assert (tmp_path / "file.zip").read_bytes() != b"" + with pytest.raises(FileNotFoundError): + propagation(tmp_path / "file2.zip", params, bundle_data=True) new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz") new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz") @@ -92,30 +113,32 @@ def test_zip_bundle(tmp_path: Path): params.effective_area_file.path = str(new_aeff_path) params.freeze() - io = ZipFileIOHandler(tmp_path / "file2.zip") - prop2 = Propagation(io, params, True) + prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True) - assert params.dispersion_file.path == new_disp_path.name - assert params.dispersion_file.prefix == "zip" - assert params.effective_area_file.path == new_aeff_path.name - assert params.effective_area_file.prefix == "zip" + assert prop2.parameters.dispersion_file.path == new_disp_path.name + assert prop2.parameters.dispersion_file.prefix == "zip" + assert prop2.parameters.effective_area_file.path == new_aeff_path.name + assert prop2.parameters.effective_area_file.prefix == "zip" - with ZipFile(tmp_path / "file2.zip", "r") as zfile: + with ZipFile(tmp_path / "file3.zip", "r") as zfile: with zfile.open(new_aeff_path.name) as file: assert file.read() == new_aeff_path.read_bytes() with zfile.open(new_disp_path.name) as file: assert file.read() == new_disp_path.read_bytes() - with zfile.open(Propagation.PARAMS_FN) as file: + with zfile.open(PARAMS_FN) as file: df = json.loads(file.read().decode()) assert ( - df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path + df["dispersion_file"] + == prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name ) assert ( df["effective_area_file"] - == params.effective_area_file.prefix + "::" + params.effective_area_file.path + == prop2.parameters.effective_area_file.prefix + + "::" + + Path(params.effective_area_file.path).name )