basic PropagationCollection

This commit is contained in:
Benoît Sierro
2023-10-03 16:52:23 +02:00
parent a903bfcb5e
commit 5adde638ef
2 changed files with 156 additions and 42 deletions

View File

@@ -2,8 +2,10 @@ from __future__ import annotations
import os
import warnings
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Callable, Generic, TypeVar, overload
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
import numpy as np
@@ -21,6 +23,7 @@ from scgenerator.physics import pulse, units
PARAMS_FN = "params.json"
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
logger = get_logger(__name__)
class Spectrum(np.ndarray):
@@ -147,6 +150,9 @@ class Spectrum(np.ndarray):
freq_amp = afreq_amp
NO_PARAMS = object()
class Propagation(Generic[ParamsOrNone]):
io: PropagationIOHandler
parameters: ParamsOrNone
@@ -155,7 +161,7 @@ class Propagation(Generic[ParamsOrNone]):
def __init__(
self: Propagation[ParamsOrNone],
io_handler: PropagationIOHandler,
params: ParamsOrNone,
params: ParamsOrNone = NO_PARAMS,
):
"""
A propagation is the object that manages IO for one single propagation.
@@ -166,15 +172,18 @@ class Propagation(Generic[ParamsOrNone]):
----------
io : PropagationIOHandler
object that implements the PropagationIOHandler Protocol.
params : Parameters | None
simulations parameters. Those will be passed on to spectra when loaded.
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
params : Parameters | None, optional
- if `Parameters`, those will be passed on to spectra when loaded to create `Spectrum`
objects.
- if explicitly None, loading spectra will result in normal numpy arrays rather than
`Spectrum` obj.
- if unspecified, parameters will be lazily loaded the first time the `parameters`
attribute is accessed (e.g. loading a spectrum).
"""
self.io = io_handler
self._current_index = len(self.io)
self.parameters = params
if self.parameters is not None:
self.z_positions = self.parameters.compute("z_targets")
if params is not NO_PARAMS:
self.parameters = params
def __len__(self) -> int:
return self._current_index
@@ -190,24 +199,23 @@ class Propagation(Generic[ParamsOrNone]):
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice):
return self._load_slice(key)
if isinstance(key, (float, np.floating)):
if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.z_positions, key)
elif key < 0:
self._warn_negative_index(key)
key = len(self) + key
array = self.io.load_spectrum(key)
if self.parameters is not None:
return Spectrum.from_params(array, self.parameters)
else:
return array
return self._spectrumize(array)
def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value))
def _spectrumize(self, array: np.ndarray | Spectrum) -> np.ndarray | Spectrum:
if self.parameters is not None and not isinstance(array, Spectrum):
return Spectrum.from_params(array, self.parameters)
else:
return array
@overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
...
@@ -220,17 +228,23 @@ class Propagation(Generic[ParamsOrNone]):
self._warn_negative_index(key.start)
self._warn_negative_index(key.stop)
_iter = range(len(self))[key]
# if self.parameters is not None:
# out = Spectrum.from_params(
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
# )
# for i in _iter:
# out[i] = self.io.load_spectrum(i)
# else:
out = np.array([self.io.load_spectrum(i) for i in _iter])
if self.parameters is not None:
out = Spectrum.from_params(out, self.parameters)
return out
return self._spectrumize(np.array([self.io.load_spectrum(i) for i in _iter]))
@cached_property
def parameters(self):
try:
return self.load_parameters()
except Exception as e:
logger.error(str(e))
return None
def load_parameters(self) -> Parameters:
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():
if isinstance(v, DataFile):
v.io = self.io
return params
def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
@@ -246,6 +260,30 @@ class Propagation(Generic[ParamsOrNone]):
return self._load_slice(slice(None))
@dataclass
class PropagationCollection:
propagations: list[Propagation]
parameters: Parameters
z: np.ndarray
t: np.ndarray
w: np.ndarray
wl: np.ndarray
@overload
def __getitem__(self, key: int) -> Propagation:
...
@overload
def __getitem__(self, key: slice) -> list[Propagation]:
...
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
return self.propagations[key]
def __iter__(self) -> Iterator[Propagation]:
yield from self.propagations
def propagation(
file_or_params: os.PathLike | Parameters,
params: Parameters | None = None,
@@ -284,11 +322,12 @@ def propagation(
file = Path(file_or_params)
if file is not None and file.exists() and params is None:
io = ZipFileIOHandler(file)
return _open_existing_propagation(io, load_parameters=load_parameters)
return open_existing_propagation(file, load_parameters=load_parameters)
if params is None:
raise ValueError("Parameters must be specified to create new simulation")
raise ValueError(
f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
)
if file is not None:
if file.exists() and params is not None:
@@ -307,17 +346,67 @@ def propagation(
raise e
def _open_existing_propagation(
io: PropagationIOHandler, load_parameters: bool = True
) -> Propagation:
if not load_parameters:
def propagation_series(
files: Sequence[os.PathLike], index: int | slice | None = None, progress_bar: bool = False
) -> tuple[Spectrum, PropagationCollection]:
"""
loads an existing sequence of propagation
Parameters
----------
files : Sequence[os.PathLike]
sequence (e.g. list) of propagation zip files.
index : slice | int | None, optional
what index of each propagation to load (e.g. `None`->all spectra, `-1`->last spectrum),
by default None
progress_bar : bool, optional
print a progress bar to stderr as the files are loading (requires tqdm), by default False
Returns
-------
PropagationCollection
convenient object to work with multiple propagations **built on the same grid**
"""
if len(files) == 0:
raise ValueError("You must provide at least one file to build a propagation series")
propagations = [open_existing_propagation(f, load_parameters=False) for f in files]
parameters = propagations[0].load_parameters()
rest = propagations
if progress_bar:
try:
from tqdm import tqdm
except ModuleNotFoundError as e:
logger.error(str(e))
pass
rest = tqdm(rest)
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
for prop in propagations:
del prop.parameters
return spectrum, PropagationCollection(
propagations,
parameters,
z=parameters.z_targets[: spectrum.shape[0]],
t=spectrum.t,
w=spectrum.w_disp,
wl=spectrum.wl_disp,
)
def open_existing_propagation(file: os.PathLike, load_parameters: bool = True) -> Propagation:
file = Path(file)
if not file.exists():
raise FileNotFoundError(f"no propagation found at {file}")
io = ZipFileIOHandler(file)
if load_parameters:
return Propagation(io)
else:
return Propagation(io, None)
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():
if isinstance(v, DataFile):
v.io = io
return Propagation(io, params)
def _create_new_propagation(

View File

@@ -7,7 +7,7 @@ import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import PARAMS_FN, propagation
from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
PARAMS = dict(
name="PM2000D sech pulse",
@@ -25,7 +25,7 @@ PARAMS = dict(
raman_type="measured",
quantum_noise=True,
interpolation_degree=11,
z_num=128,
z_num=8,
length=1.5,
t_num=512,
dt=5e-15,
@@ -173,3 +173,28 @@ def test_zip_bundle(tmp_path: Path):
def test_unique_name():
existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy"
existing = {"spec", "spec0"}
assert unique_name("spec", existing) == "spec_0"
existing = {"spec", "spec_0"}
assert unique_name("spec", existing) == "spec_1"
def test_propagation_series(tmp_path: Path):
params = Parameters(**PARAMS)
with pytest.raises(ValueError):
specs, props = propagation_series([])
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
for i, f in enumerate(flist):
params.name = f"prop {i}"
prop = propagation(f, params)
for _ in range(params.z_num):
prop.append(np.zeros(params.t_num, dtype=complex))
assert set(flist) == set(tmp_path.glob("*.zip"))
specs, propagations = propagation_series(flist)
assert specs.shape == (10, params.z_num, params.t_num)
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))