new: allow Propagation w/o params

This commit is contained in:
Benoît Sierro
2023-08-23 09:39:19 +02:00
parent 97014de0c0
commit 808ce7dd48

View File

@@ -1,9 +1,8 @@
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, Generic, TypeVar, overload
import numpy as np import numpy as np
@@ -15,10 +14,12 @@ from scgenerator.io import (
ZipFileIOHandler, ZipFileIOHandler,
unique_name, unique_name,
) )
from scgenerator.logger import get_logger
from scgenerator.parameter import Parameters from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units from scgenerator.physics import pulse, units
PARAMS_FN = "params.json" PARAMS_FN = "params.json"
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
@@ -138,27 +139,28 @@ class Spectrum(np.ndarray):
return pulse.measure_field(self.t, self.time_amp) return pulse.measure_field(self.t, self.time_amp)
class Propagation: class Propagation(Generic[ParamsOrNone]):
io: PropagationIOHandler io: PropagationIOHandler
parameters: Parameters parameters: ParamsOrNone
_current_index: int _current_index: int
def __init__( def __init__(
self, self: Propagation[ParamsOrNone],
io_handler: PropagationIOHandler, io_handler: PropagationIOHandler,
params: Parameters, params: ParamsOrNone,
): ):
""" """
A propagation is the object that manages IO for one single propagation. A propagation is the object that manages IO for one single propagation.
It is recommended to use the "memory_propagation" and "zip_propagation" convenience It is recommended to use the "propagation" convenience factories instead of creating this
factories instead of creating this object directly. object directly.
Parameters Parameters
---------- ----------
io : PropagationIOHandler io : PropagationIOHandler
object that implements the PropagationIOHandler Protocol. object that implements the PropagationIOHandler Protocol.
params : Parameters params : Parameters | None
simulations parameters. Those will be saved via the simulations parameters. Those will be passed on to spectra when loaded.
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
""" """
self.io = io_handler self.io = io_handler
self._current_index = len(self.io) self._current_index = len(self.io)
@@ -167,26 +169,47 @@ class Propagation:
def __len__(self) -> int: def __len__(self) -> int:
return self._current_index return self._current_index
def __getitem__(self, key: int | slice) -> Spectrum: @overload
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
...
@overload
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
...
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice): if isinstance(key, slice):
return self._load_slice(key) return self._load_slice(key)
if isinstance(key, (float, np.floating)): if isinstance(key, (float, np.floating)):
if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.parameters.compute("z_targets"), key) key = math.argclosest(self.parameters.compute("z_targets"), key)
elif key < 0: elif key < 0:
key = len(self) + key key = len(self) + key
array = self.io.load_spectrum(key) array = self.io.load_spectrum(key)
return Spectrum(array, self.parameters) if self.parameters is not None:
return Spectrum(array, self.parameters)
else:
return array
def __setitem__(self, key: int, value: np.ndarray): def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int): if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}") raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value)) self.io.save_spectrum(key, np.asarray(value))
@overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
...
@overload
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
...
def _load_slice(self, key: slice) -> Spectrum: def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key] _iter = range(len(self))[key]
out = Spectrum( out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex)
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters if self.parameters is not None:
) out = Spectrum(out, self.parameters)
for i in _iter: for i in _iter:
out[i] = self.io.load_spectrum(i) out[i] = self.io.load_spectrum(i)
return out return out
@@ -204,6 +227,7 @@ def propagation(
params: Parameters | None = None, params: Parameters | None = None,
bundle_data: bool = False, bundle_data: bool = False,
overwrite: bool = False, overwrite: bool = False,
load_parameters: bool = True,
) -> Propagation: ) -> Propagation:
file = None file = None
if isinstance(file_or_params, Parameters): if isinstance(file_or_params, Parameters):
@@ -213,7 +237,7 @@ def propagation(
if file is not None and file.exists() and params is None: if file is not None and file.exists() and params is None:
io = ZipFileIOHandler(file) io = ZipFileIOHandler(file)
return _open_existing_propagation(io) return _open_existing_propagation(io, load_parameters=load_parameters)
if params is None: if params is None:
raise ValueError("Parameters must be specified to create new simulation") raise ValueError("Parameters must be specified to create new simulation")
@@ -235,7 +259,11 @@ def propagation(
raise e raise e
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation: def _open_existing_propagation(
io: PropagationIOHandler, load_parameters: bool = True
) -> Propagation:
if not load_parameters:
return Propagation(io, None)
params = Parameters.from_json(io.load_data(PARAMS_FN).decode()) params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False) params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items(): for k, v in params.items():