new: allow Propagation w/o params

This commit is contained in:
Benoît Sierro
2023-08-23 09:39:19 +02:00
parent 97014de0c0
commit 808ce7dd48

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Callable
from typing import Callable, Generic, TypeVar, overload
import numpy as np
@@ -15,10 +14,12 @@ from scgenerator.io import (
ZipFileIOHandler,
unique_name,
)
from scgenerator.logger import get_logger
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units
PARAMS_FN = "params.json"
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
class Spectrum(np.ndarray):
@@ -138,27 +139,28 @@ class Spectrum(np.ndarray):
return pulse.measure_field(self.t, self.time_amp)
class Propagation:
class Propagation(Generic[ParamsOrNone]):
io: PropagationIOHandler
parameters: Parameters
parameters: ParamsOrNone
_current_index: int
def __init__(
self,
self: Propagation[ParamsOrNone],
io_handler: PropagationIOHandler,
params: Parameters,
params: ParamsOrNone,
):
"""
A propagation is the object that manages IO for one single propagation.
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
factories instead of creating this object directly.
It is recommended to use the "propagation" convenience factories instead of creating this
object directly.
Parameters
----------
io : PropagationIOHandler
object that implements the PropagationIOHandler Protocol.
params : Parameters
simulations parameters. Those will be saved via the
params : Parameters | None
simulations parameters. Those will be passed on to spectra when loaded.
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
"""
self.io = io_handler
self._current_index = len(self.io)
@@ -167,26 +169,47 @@ class Propagation:
def __len__(self) -> int:
return self._current_index
def __getitem__(self, key: int | slice) -> Spectrum:
@overload
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
...
@overload
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
...
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice):
return self._load_slice(key)
if isinstance(key, (float, np.floating)):
if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.parameters.compute("z_targets"), key)
elif key < 0:
key = len(self) + key
array = self.io.load_spectrum(key)
if self.parameters is not None:
return Spectrum(array, self.parameters)
else:
return array
def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value))
@overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
...
@overload
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
...
def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key]
out = Spectrum(
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
)
out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex)
if self.parameters is not None:
out = Spectrum(out, self.parameters)
for i in _iter:
out[i] = self.io.load_spectrum(i)
return out
@@ -204,6 +227,7 @@ def propagation(
params: Parameters | None = None,
bundle_data: bool = False,
overwrite: bool = False,
load_parameters: bool = True,
) -> Propagation:
file = None
if isinstance(file_or_params, Parameters):
@@ -213,7 +237,7 @@ def propagation(
if file is not None and file.exists() and params is None:
io = ZipFileIOHandler(file)
return _open_existing_propagation(io)
return _open_existing_propagation(io, load_parameters=load_parameters)
if params is None:
raise ValueError("Parameters must be specified to create new simulation")
@@ -235,7 +259,11 @@ def propagation(
raise e
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
def _open_existing_propagation(
io: PropagationIOHandler, load_parameters: bool = True
) -> Propagation:
if not load_parameters:
return Propagation(io, None)
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():