new: allow Propagation w/o params
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Callable, Generic, TypeVar, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -15,10 +14,12 @@ from scgenerator.io import (
|
||||
ZipFileIOHandler,
|
||||
unique_name,
|
||||
)
|
||||
from scgenerator.logger import get_logger
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.physics import pulse, units
|
||||
|
||||
PARAMS_FN = "params.json"
|
||||
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
@@ -138,27 +139,28 @@ class Spectrum(np.ndarray):
|
||||
return pulse.measure_field(self.t, self.time_amp)
|
||||
|
||||
|
||||
class Propagation:
|
||||
class Propagation(Generic[ParamsOrNone]):
|
||||
io: PropagationIOHandler
|
||||
parameters: Parameters
|
||||
parameters: ParamsOrNone
|
||||
_current_index: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self: Propagation[ParamsOrNone],
|
||||
io_handler: PropagationIOHandler,
|
||||
params: Parameters,
|
||||
params: ParamsOrNone,
|
||||
):
|
||||
"""
|
||||
A propagation is the object that manages IO for one single propagation.
|
||||
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
|
||||
factories instead of creating this object directly.
|
||||
It is recommended to use the "propagation" convenience factories instead of creating this
|
||||
object directly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
io : PropagationIOHandler
|
||||
object that implements the PropagationIOHandler Protocol.
|
||||
params : Parameters
|
||||
simulations parameters. Those will be saved via the
|
||||
params : Parameters | None
|
||||
simulations parameters. Those will be passed on to spectra when loaded.
|
||||
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
|
||||
"""
|
||||
self.io = io_handler
|
||||
self._current_index = len(self.io)
|
||||
@@ -167,26 +169,47 @@ class Propagation:
|
||||
def __len__(self) -> int:
|
||||
return self._current_index
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Spectrum:
|
||||
@overload
|
||||
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||
if isinstance(key, slice):
|
||||
return self._load_slice(key)
|
||||
if isinstance(key, (float, np.floating)):
|
||||
if self.parameters is None:
|
||||
raise ValueError(f"cannot accept float key {key} when parameters is not set")
|
||||
key = math.argclosest(self.parameters.compute("z_targets"), key)
|
||||
elif key < 0:
|
||||
key = len(self) + key
|
||||
array = self.io.load_spectrum(key)
|
||||
if self.parameters is not None:
|
||||
return Spectrum(array, self.parameters)
|
||||
else:
|
||||
return array
|
||||
|
||||
def __setitem__(self, key: int, value: np.ndarray):
|
||||
if not isinstance(key, int):
|
||||
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||
self.io.save_spectrum(key, np.asarray(value))
|
||||
|
||||
@overload
|
||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
|
||||
...
|
||||
|
||||
def _load_slice(self, key: slice) -> Spectrum:
|
||||
_iter = range(len(self))[key]
|
||||
out = Spectrum(
|
||||
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
||||
)
|
||||
out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex)
|
||||
if self.parameters is not None:
|
||||
out = Spectrum(out, self.parameters)
|
||||
for i in _iter:
|
||||
out[i] = self.io.load_spectrum(i)
|
||||
return out
|
||||
@@ -204,6 +227,7 @@ def propagation(
|
||||
params: Parameters | None = None,
|
||||
bundle_data: bool = False,
|
||||
overwrite: bool = False,
|
||||
load_parameters: bool = True,
|
||||
) -> Propagation:
|
||||
file = None
|
||||
if isinstance(file_or_params, Parameters):
|
||||
@@ -213,7 +237,7 @@ def propagation(
|
||||
|
||||
if file is not None and file.exists() and params is None:
|
||||
io = ZipFileIOHandler(file)
|
||||
return _open_existing_propagation(io)
|
||||
return _open_existing_propagation(io, load_parameters=load_parameters)
|
||||
|
||||
if params is None:
|
||||
raise ValueError("Parameters must be specified to create new simulation")
|
||||
@@ -235,7 +259,11 @@ def propagation(
|
||||
raise e
|
||||
|
||||
|
||||
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
|
||||
def _open_existing_propagation(
|
||||
io: PropagationIOHandler, load_parameters: bool = True
|
||||
) -> Propagation:
|
||||
if not load_parameters:
|
||||
return Propagation(io, None)
|
||||
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
||||
params.compile_in_place(exhaustive=True, strict=False)
|
||||
for k, v in params.items():
|
||||
|
||||
Reference in New Issue
Block a user