new: allow Propagation w/o params
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable, Generic, TypeVar, overload
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -15,10 +14,12 @@ from scgenerator.io import (
|
|||||||
ZipFileIOHandler,
|
ZipFileIOHandler,
|
||||||
unique_name,
|
unique_name,
|
||||||
)
|
)
|
||||||
|
from scgenerator.logger import get_logger
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.physics import pulse, units
|
from scgenerator.physics import pulse, units
|
||||||
|
|
||||||
PARAMS_FN = "params.json"
|
PARAMS_FN = "params.json"
|
||||||
|
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
@@ -138,27 +139,28 @@ class Spectrum(np.ndarray):
|
|||||||
return pulse.measure_field(self.t, self.time_amp)
|
return pulse.measure_field(self.t, self.time_amp)
|
||||||
|
|
||||||
|
|
||||||
class Propagation:
|
class Propagation(Generic[ParamsOrNone]):
|
||||||
io: PropagationIOHandler
|
io: PropagationIOHandler
|
||||||
parameters: Parameters
|
parameters: ParamsOrNone
|
||||||
_current_index: int
|
_current_index: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self: Propagation[ParamsOrNone],
|
||||||
io_handler: PropagationIOHandler,
|
io_handler: PropagationIOHandler,
|
||||||
params: Parameters,
|
params: ParamsOrNone,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A propagation is the object that manages IO for one single propagation.
|
A propagation is the object that manages IO for one single propagation.
|
||||||
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
|
It is recommended to use the "propagation" convenience factories instead of creating this
|
||||||
factories instead of creating this object directly.
|
object directly.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
io : PropagationIOHandler
|
io : PropagationIOHandler
|
||||||
object that implements the PropagationIOHandler Protocol.
|
object that implements the PropagationIOHandler Protocol.
|
||||||
params : Parameters
|
params : Parameters | None
|
||||||
simulations parameters. Those will be saved via the
|
simulations parameters. Those will be passed on to spectra when loaded.
|
||||||
|
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
|
||||||
"""
|
"""
|
||||||
self.io = io_handler
|
self.io = io_handler
|
||||||
self._current_index = len(self.io)
|
self._current_index = len(self.io)
|
||||||
@@ -167,26 +169,47 @@ class Propagation:
|
|||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self._current_index
|
return self._current_index
|
||||||
|
|
||||||
def __getitem__(self, key: int | slice) -> Spectrum:
|
@overload
|
||||||
|
def __getitem__(self: Propagation[Parameters], key: int | slice) -> Spectrum:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self: Propagation[None], key: int | slice) -> np.ndarray:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
return self._load_slice(key)
|
return self._load_slice(key)
|
||||||
if isinstance(key, (float, np.floating)):
|
if isinstance(key, (float, np.floating)):
|
||||||
|
if self.parameters is None:
|
||||||
|
raise ValueError(f"cannot accept float key {key} when parameters is not set")
|
||||||
key = math.argclosest(self.parameters.compute("z_targets"), key)
|
key = math.argclosest(self.parameters.compute("z_targets"), key)
|
||||||
elif key < 0:
|
elif key < 0:
|
||||||
key = len(self) + key
|
key = len(self) + key
|
||||||
array = self.io.load_spectrum(key)
|
array = self.io.load_spectrum(key)
|
||||||
return Spectrum(array, self.parameters)
|
if self.parameters is not None:
|
||||||
|
return Spectrum(array, self.parameters)
|
||||||
|
else:
|
||||||
|
return array
|
||||||
|
|
||||||
def __setitem__(self, key: int, value: np.ndarray):
|
def __setitem__(self, key: int, value: np.ndarray):
|
||||||
if not isinstance(key, int):
|
if not isinstance(key, int):
|
||||||
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||||
self.io.save_spectrum(key, np.asarray(value))
|
self.io.save_spectrum(key, np.asarray(value))
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def _load_slice(self: Propagation[None], key: slice) -> np.ndarray:
|
||||||
|
...
|
||||||
|
|
||||||
def _load_slice(self, key: slice) -> Spectrum:
|
def _load_slice(self, key: slice) -> Spectrum:
|
||||||
_iter = range(len(self))[key]
|
_iter = range(len(self))[key]
|
||||||
out = Spectrum(
|
out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex)
|
||||||
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
if self.parameters is not None:
|
||||||
)
|
out = Spectrum(out, self.parameters)
|
||||||
for i in _iter:
|
for i in _iter:
|
||||||
out[i] = self.io.load_spectrum(i)
|
out[i] = self.io.load_spectrum(i)
|
||||||
return out
|
return out
|
||||||
@@ -204,6 +227,7 @@ def propagation(
|
|||||||
params: Parameters | None = None,
|
params: Parameters | None = None,
|
||||||
bundle_data: bool = False,
|
bundle_data: bool = False,
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
|
load_parameters: bool = True,
|
||||||
) -> Propagation:
|
) -> Propagation:
|
||||||
file = None
|
file = None
|
||||||
if isinstance(file_or_params, Parameters):
|
if isinstance(file_or_params, Parameters):
|
||||||
@@ -213,7 +237,7 @@ def propagation(
|
|||||||
|
|
||||||
if file is not None and file.exists() and params is None:
|
if file is not None and file.exists() and params is None:
|
||||||
io = ZipFileIOHandler(file)
|
io = ZipFileIOHandler(file)
|
||||||
return _open_existing_propagation(io)
|
return _open_existing_propagation(io, load_parameters=load_parameters)
|
||||||
|
|
||||||
if params is None:
|
if params is None:
|
||||||
raise ValueError("Parameters must be specified to create new simulation")
|
raise ValueError("Parameters must be specified to create new simulation")
|
||||||
@@ -235,7 +259,11 @@ def propagation(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
|
def _open_existing_propagation(
|
||||||
|
io: PropagationIOHandler, load_parameters: bool = True
|
||||||
|
) -> Propagation:
|
||||||
|
if not load_parameters:
|
||||||
|
return Propagation(io, None)
|
||||||
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
||||||
params.compile_in_place(exhaustive=True, strict=False)
|
params.compile_in_place(exhaustive=True, strict=False)
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user