basic PropagationCollection
This commit is contained in:
@@ -2,8 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generic, TypeVar, overload
|
||||
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -21,6 +23,7 @@ from scgenerator.physics import pulse, units
|
||||
|
||||
PARAMS_FN = "params.json"
|
||||
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
@@ -147,6 +150,9 @@ class Spectrum(np.ndarray):
|
||||
freq_amp = afreq_amp
|
||||
|
||||
|
||||
NO_PARAMS = object()
|
||||
|
||||
|
||||
class Propagation(Generic[ParamsOrNone]):
|
||||
io: PropagationIOHandler
|
||||
parameters: ParamsOrNone
|
||||
@@ -155,7 +161,7 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
def __init__(
|
||||
self: Propagation[ParamsOrNone],
|
||||
io_handler: PropagationIOHandler,
|
||||
params: ParamsOrNone,
|
||||
params: ParamsOrNone = NO_PARAMS,
|
||||
):
|
||||
"""
|
||||
A propagation is the object that manages IO for one single propagation.
|
||||
@@ -166,15 +172,18 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
----------
|
||||
io : PropagationIOHandler
|
||||
object that implements the PropagationIOHandler Protocol.
|
||||
params : Parameters | None
|
||||
simulations parameters. Those will be passed on to spectra when loaded.
|
||||
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
|
||||
params : Parameters | None, optional
|
||||
- if `Parameters`, those will be passed on to spectra when loaded to create `Spectrum`
|
||||
objects.
|
||||
- if explicitly None, loading spectra will result in normal numpy arrays rather than
|
||||
`Spectrum` obj.
|
||||
- if unspecified, parameters will be lazily loaded the first time the `parameters`
|
||||
attribute is accessed (e.g. loading a spectrum).
|
||||
"""
|
||||
self.io = io_handler
|
||||
self._current_index = len(self.io)
|
||||
if params is not NO_PARAMS:
|
||||
self.parameters = params
|
||||
if self.parameters is not None:
|
||||
self.z_positions = self.parameters.compute("z_targets")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._current_index
|
||||
@@ -190,24 +199,23 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||
if isinstance(key, slice):
|
||||
return self._load_slice(key)
|
||||
if isinstance(key, (float, np.floating)):
|
||||
if self.parameters is None:
|
||||
raise ValueError(f"cannot accept float key {key} when parameters is not set")
|
||||
key = math.argclosest(self.z_positions, key)
|
||||
elif key < 0:
|
||||
self._warn_negative_index(key)
|
||||
key = len(self) + key
|
||||
array = self.io.load_spectrum(key)
|
||||
if self.parameters is not None:
|
||||
return Spectrum.from_params(array, self.parameters)
|
||||
else:
|
||||
return array
|
||||
return self._spectrumize(array)
|
||||
|
||||
def __setitem__(self, key: int, value: np.ndarray):
|
||||
if not isinstance(key, int):
|
||||
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||
self.io.save_spectrum(key, np.asarray(value))
|
||||
|
||||
def _spectrumize(self, array: np.ndarray | Spectrum) -> np.ndarray | Spectrum:
|
||||
if self.parameters is not None and not isinstance(array, Spectrum):
|
||||
return Spectrum.from_params(array, self.parameters)
|
||||
else:
|
||||
return array
|
||||
|
||||
@overload
|
||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
||||
...
|
||||
@@ -220,17 +228,23 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
self._warn_negative_index(key.start)
|
||||
self._warn_negative_index(key.stop)
|
||||
_iter = range(len(self))[key]
|
||||
# if self.parameters is not None:
|
||||
# out = Spectrum.from_params(
|
||||
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
||||
# )
|
||||
# for i in _iter:
|
||||
# out[i] = self.io.load_spectrum(i)
|
||||
# else:
|
||||
out = np.array([self.io.load_spectrum(i) for i in _iter])
|
||||
if self.parameters is not None:
|
||||
out = Spectrum.from_params(out, self.parameters)
|
||||
return out
|
||||
return self._spectrumize(np.array([self.io.load_spectrum(i) for i in _iter]))
|
||||
|
||||
@cached_property
|
||||
def parameters(self):
|
||||
try:
|
||||
return self.load_parameters()
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
return None
|
||||
|
||||
def load_parameters(self) -> Parameters:
|
||||
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
|
||||
params.compile_in_place(exhaustive=True, strict=False)
|
||||
for k, v in params.items():
|
||||
if isinstance(v, DataFile):
|
||||
v.io = self.io
|
||||
return params
|
||||
|
||||
def append(self, spectrum: np.ndarray):
|
||||
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
|
||||
@@ -246,6 +260,30 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
return self._load_slice(slice(None))
|
||||
|
||||
|
||||
@dataclass
|
||||
class PropagationCollection:
|
||||
propagations: list[Propagation]
|
||||
parameters: Parameters
|
||||
z: np.ndarray
|
||||
t: np.ndarray
|
||||
w: np.ndarray
|
||||
wl: np.ndarray
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: int) -> Propagation:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, key: slice) -> list[Propagation]:
|
||||
...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
|
||||
return self.propagations[key]
|
||||
|
||||
def __iter__(self) -> Iterator[Propagation]:
|
||||
yield from self.propagations
|
||||
|
||||
|
||||
def propagation(
|
||||
file_or_params: os.PathLike | Parameters,
|
||||
params: Parameters | None = None,
|
||||
@@ -284,11 +322,12 @@ def propagation(
|
||||
file = Path(file_or_params)
|
||||
|
||||
if file is not None and file.exists() and params is None:
|
||||
io = ZipFileIOHandler(file)
|
||||
return _open_existing_propagation(io, load_parameters=load_parameters)
|
||||
return open_existing_propagation(file, load_parameters=load_parameters)
|
||||
|
||||
if params is None:
|
||||
raise ValueError("Parameters must be specified to create new simulation")
|
||||
raise ValueError(
|
||||
f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
|
||||
)
|
||||
|
||||
if file is not None:
|
||||
if file.exists() and params is not None:
|
||||
@@ -307,17 +346,67 @@ def propagation(
|
||||
raise e
|
||||
|
||||
|
||||
def _open_existing_propagation(
|
||||
io: PropagationIOHandler, load_parameters: bool = True
|
||||
) -> Propagation:
|
||||
if not load_parameters:
|
||||
def propagation_series(
|
||||
files: Sequence[os.PathLike], index: int | slice | None = None, progress_bar: bool = False
|
||||
) -> tuple[Spectrum, PropagationCollection]:
|
||||
"""
|
||||
loads an existing sequence of propagation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
files : Sequence[os.PathLike]
|
||||
sequence (e.g. list) of propagation zip files.
|
||||
index : slice | int | None, optional
|
||||
what index of each propagation to load (e.g. `None`->all spectra, `-1`->last spectrum),
|
||||
by default None
|
||||
progress_bar : bool, optional
|
||||
print a progress bar to stderr as the files are loading (requires tqdm), by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
PropagationCollection
|
||||
convenient object to work with multiple propagations **built on the same grid**
|
||||
"""
|
||||
if len(files) == 0:
|
||||
raise ValueError("You must provide at least one file to build a propagation series")
|
||||
|
||||
propagations = [open_existing_propagation(f, load_parameters=False) for f in files]
|
||||
parameters = propagations[0].load_parameters()
|
||||
|
||||
rest = propagations
|
||||
if progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(str(e))
|
||||
pass
|
||||
|
||||
rest = tqdm(rest)
|
||||
|
||||
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
|
||||
for prop in propagations:
|
||||
del prop.parameters
|
||||
|
||||
return spectrum, PropagationCollection(
|
||||
propagations,
|
||||
parameters,
|
||||
z=parameters.z_targets[: spectrum.shape[0]],
|
||||
t=spectrum.t,
|
||||
w=spectrum.w_disp,
|
||||
wl=spectrum.wl_disp,
|
||||
)
|
||||
|
||||
|
||||
def open_existing_propagation(file: os.PathLike, load_parameters: bool = True) -> Propagation:
|
||||
file = Path(file)
|
||||
if not file.exists():
|
||||
raise FileNotFoundError(f"no propagation found at {file}")
|
||||
io = ZipFileIOHandler(file)
|
||||
|
||||
if load_parameters:
|
||||
return Propagation(io)
|
||||
else:
|
||||
return Propagation(io, None)
|
||||
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
||||
params.compile_in_place(exhaustive=True, strict=False)
|
||||
for k, v in params.items():
|
||||
if isinstance(v, DataFile):
|
||||
v.io = io
|
||||
return Propagation(io, params)
|
||||
|
||||
|
||||
def _create_new_propagation(
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.spectra import PARAMS_FN, propagation
|
||||
from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
|
||||
|
||||
PARAMS = dict(
|
||||
name="PM2000D sech pulse",
|
||||
@@ -25,7 +25,7 @@ PARAMS = dict(
|
||||
raman_type="measured",
|
||||
quantum_noise=True,
|
||||
interpolation_degree=11,
|
||||
z_num=128,
|
||||
z_num=8,
|
||||
length=1.5,
|
||||
t_num=512,
|
||||
dt=5e-15,
|
||||
@@ -173,3 +173,28 @@ def test_zip_bundle(tmp_path: Path):
|
||||
def test_unique_name():
|
||||
existing = {"spec.npy", "spec_0.npy"}
|
||||
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
||||
|
||||
existing = {"spec", "spec0"}
|
||||
assert unique_name("spec", existing) == "spec_0"
|
||||
|
||||
existing = {"spec", "spec_0"}
|
||||
assert unique_name("spec", existing) == "spec_1"
|
||||
|
||||
|
||||
def test_propagation_series(tmp_path: Path):
|
||||
params = Parameters(**PARAMS)
|
||||
with pytest.raises(ValueError):
|
||||
specs, props = propagation_series([])
|
||||
|
||||
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
|
||||
for i, f in enumerate(flist):
|
||||
params.name = f"prop {i}"
|
||||
prop = propagation(f, params)
|
||||
for _ in range(params.z_num):
|
||||
prop.append(np.zeros(params.t_num, dtype=complex))
|
||||
|
||||
assert set(flist) == set(tmp_path.glob("*.zip"))
|
||||
|
||||
specs, propagations = propagation_series(flist)
|
||||
assert specs.shape == (10, params.z_num, params.t_num)
|
||||
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))
|
||||
|
||||
Reference in New Issue
Block a user