basic PropagationCollection

This commit is contained in:
Benoît Sierro
2023-10-03 16:52:23 +02:00
parent a903bfcb5e
commit 5adde638ef
2 changed files with 156 additions and 42 deletions

View File

@@ -2,8 +2,10 @@ from __future__ import annotations
import os import os
import warnings import warnings
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Callable, Generic, TypeVar, overload from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
import numpy as np import numpy as np
@@ -21,6 +23,7 @@ from scgenerator.physics import pulse, units
PARAMS_FN = "params.json" PARAMS_FN = "params.json"
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None) ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
logger = get_logger(__name__)
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
@@ -147,6 +150,9 @@ class Spectrum(np.ndarray):
freq_amp = afreq_amp freq_amp = afreq_amp
NO_PARAMS = object()
class Propagation(Generic[ParamsOrNone]): class Propagation(Generic[ParamsOrNone]):
io: PropagationIOHandler io: PropagationIOHandler
parameters: ParamsOrNone parameters: ParamsOrNone
@@ -155,7 +161,7 @@ class Propagation(Generic[ParamsOrNone]):
def __init__( def __init__(
self: Propagation[ParamsOrNone], self: Propagation[ParamsOrNone],
io_handler: PropagationIOHandler, io_handler: PropagationIOHandler,
params: ParamsOrNone, params: ParamsOrNone = NO_PARAMS,
): ):
""" """
A propagation is the object that manages IO for one single propagation. A propagation is the object that manages IO for one single propagation.
@@ -166,15 +172,18 @@ class Propagation(Generic[ParamsOrNone]):
---------- ----------
io : PropagationIOHandler io : PropagationIOHandler
object that implements the PropagationIOHandler Protocol. object that implements the PropagationIOHandler Protocol.
params : Parameters | None params : Parameters | None, optional
simulations parameters. Those will be passed on to spectra when loaded. - if `Parameters`, those will be passed on to spectra when loaded to create `Spectrum`
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj. objects.
- if explicitly None, loading spectra will result in normal numpy arrays rather than
`Spectrum` obj.
- if unspecified, parameters will be lazily loaded the first time the `parameters`
attribute is accessed (e.g. loading a spectrum).
""" """
self.io = io_handler self.io = io_handler
self._current_index = len(self.io) self._current_index = len(self.io)
self.parameters = params if params is not NO_PARAMS:
if self.parameters is not None: self.parameters = params
self.z_positions = self.parameters.compute("z_targets")
def __len__(self) -> int: def __len__(self) -> int:
return self._current_index return self._current_index
@@ -190,24 +199,23 @@ class Propagation(Generic[ParamsOrNone]):
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray: def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
if isinstance(key, slice): if isinstance(key, slice):
return self._load_slice(key) return self._load_slice(key)
if isinstance(key, (float, np.floating)):
if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.z_positions, key)
elif key < 0: elif key < 0:
self._warn_negative_index(key) self._warn_negative_index(key)
key = len(self) + key key = len(self) + key
array = self.io.load_spectrum(key) array = self.io.load_spectrum(key)
if self.parameters is not None: return self._spectrumize(array)
return Spectrum.from_params(array, self.parameters)
else:
return array
def __setitem__(self, key: int, value: np.ndarray): def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int): if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}") raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value)) self.io.save_spectrum(key, np.asarray(value))
def _spectrumize(self, array: np.ndarray | Spectrum) -> np.ndarray | Spectrum:
if self.parameters is not None and not isinstance(array, Spectrum):
return Spectrum.from_params(array, self.parameters)
else:
return array
@overload @overload
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum: def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
... ...
@@ -220,17 +228,23 @@ class Propagation(Generic[ParamsOrNone]):
self._warn_negative_index(key.start) self._warn_negative_index(key.start)
self._warn_negative_index(key.stop) self._warn_negative_index(key.stop)
_iter = range(len(self))[key] _iter = range(len(self))[key]
# if self.parameters is not None: return self._spectrumize(np.array([self.io.load_spectrum(i) for i in _iter]))
# out = Spectrum.from_params(
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters @cached_property
# ) def parameters(self):
# for i in _iter: try:
# out[i] = self.io.load_spectrum(i) return self.load_parameters()
# else: except Exception as e:
out = np.array([self.io.load_spectrum(i) for i in _iter]) logger.error(str(e))
if self.parameters is not None: return None
out = Spectrum.from_params(out, self.parameters)
return out def load_parameters(self) -> Parameters:
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():
if isinstance(v, DataFile):
v.io = self.io
return params
def append(self, spectrum: np.ndarray): def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, np.asarray(spectrum)) self.io.save_spectrum(self._current_index, np.asarray(spectrum))
@@ -246,6 +260,30 @@ class Propagation(Generic[ParamsOrNone]):
return self._load_slice(slice(None)) return self._load_slice(slice(None))
@dataclass
class PropagationCollection:
propagations: list[Propagation]
parameters: Parameters
z: np.ndarray
t: np.ndarray
w: np.ndarray
wl: np.ndarray
@overload
def __getitem__(self, key: int) -> Propagation:
...
@overload
def __getitem__(self, key: slice) -> list[Propagation]:
...
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
return self.propagations[key]
def __iter__(self) -> Iterator[Propagation]:
yield from self.propagations
def propagation( def propagation(
file_or_params: os.PathLike | Parameters, file_or_params: os.PathLike | Parameters,
params: Parameters | None = None, params: Parameters | None = None,
@@ -284,11 +322,12 @@ def propagation(
file = Path(file_or_params) file = Path(file_or_params)
if file is not None and file.exists() and params is None: if file is not None and file.exists() and params is None:
io = ZipFileIOHandler(file) return open_existing_propagation(file, load_parameters=load_parameters)
return _open_existing_propagation(io, load_parameters=load_parameters)
if params is None: if params is None:
raise ValueError("Parameters must be specified to create new simulation") raise ValueError(
f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
)
if file is not None: if file is not None:
if file.exists() and params is not None: if file.exists() and params is not None:
@@ -307,17 +346,67 @@ def propagation(
raise e raise e
def _open_existing_propagation( def propagation_series(
io: PropagationIOHandler, load_parameters: bool = True files: Sequence[os.PathLike], index: int | slice | None = None, progress_bar: bool = False
) -> Propagation: ) -> tuple[Spectrum, PropagationCollection]:
if not load_parameters: """
loads an existing sequence of propagation
Parameters
----------
files : Sequence[os.PathLike]
sequence (e.g. list) of propagation zip files.
index : slice | int | None, optional
what index of each propagation to load (e.g. `None`->all spectra, `-1`->last spectrum),
by default None
progress_bar : bool, optional
print a progress bar to stderr as the files are loading (requires tqdm), by default False
Returns
-------
PropagationCollection
convenient object to work with multiple propagations **built on the same grid**
"""
if len(files) == 0:
raise ValueError("You must provide at least one file to build a propagation series")
propagations = [open_existing_propagation(f, load_parameters=False) for f in files]
parameters = propagations[0].load_parameters()
rest = propagations
if progress_bar:
try:
from tqdm import tqdm
except ModuleNotFoundError as e:
logger.error(str(e))
pass
rest = tqdm(rest)
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
for prop in propagations:
del prop.parameters
return spectrum, PropagationCollection(
propagations,
parameters,
z=parameters.z_targets[: spectrum.shape[0]],
t=spectrum.t,
w=spectrum.w_disp,
wl=spectrum.wl_disp,
)
def open_existing_propagation(file: os.PathLike, load_parameters: bool = True) -> Propagation:
file = Path(file)
if not file.exists():
raise FileNotFoundError(f"no propagation found at {file}")
io = ZipFileIOHandler(file)
if load_parameters:
return Propagation(io)
else:
return Propagation(io, None) return Propagation(io, None)
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():
if isinstance(v, DataFile):
v.io = io
return Propagation(io, params)
def _create_new_propagation( def _create_new_propagation(

View File

@@ -7,7 +7,7 @@ import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters from scgenerator.parameter import Parameters
from scgenerator.spectra import PARAMS_FN, propagation from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
PARAMS = dict( PARAMS = dict(
name="PM2000D sech pulse", name="PM2000D sech pulse",
@@ -25,7 +25,7 @@ PARAMS = dict(
raman_type="measured", raman_type="measured",
quantum_noise=True, quantum_noise=True,
interpolation_degree=11, interpolation_degree=11,
z_num=128, z_num=8,
length=1.5, length=1.5,
t_num=512, t_num=512,
dt=5e-15, dt=5e-15,
@@ -173,3 +173,28 @@ def test_zip_bundle(tmp_path: Path):
def test_unique_name(): def test_unique_name():
existing = {"spec.npy", "spec_0.npy"} existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy" assert unique_name("spec.npy", existing) == "spec_1.npy"
existing = {"spec", "spec0"}
assert unique_name("spec", existing) == "spec_0"
existing = {"spec", "spec_0"}
assert unique_name("spec", existing) == "spec_1"
def test_propagation_series(tmp_path: Path):
params = Parameters(**PARAMS)
with pytest.raises(ValueError):
specs, props = propagation_series([])
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
for i, f in enumerate(flist):
params.name = f"prop {i}"
prop = propagation(f, params)
for _ in range(params.z_num):
prop.append(np.zeros(params.t_num, dtype=complex))
assert set(flist) == set(tmp_path.glob("*.zip"))
specs, propagations = propagation_series(flist)
assert specs.shape == (10, params.z_num, params.t_num)
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))