basic PropagationCollection
This commit is contained in:
@@ -2,8 +2,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Generic, TypeVar, overload
|
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -21,6 +23,7 @@ from scgenerator.physics import pulse, units
|
|||||||
|
|
||||||
PARAMS_FN = "params.json"
|
PARAMS_FN = "params.json"
|
||||||
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
@@ -147,6 +150,9 @@ class Spectrum(np.ndarray):
|
|||||||
freq_amp = afreq_amp
|
freq_amp = afreq_amp
|
||||||
|
|
||||||
|
|
||||||
|
NO_PARAMS = object()
|
||||||
|
|
||||||
|
|
||||||
class Propagation(Generic[ParamsOrNone]):
|
class Propagation(Generic[ParamsOrNone]):
|
||||||
io: PropagationIOHandler
|
io: PropagationIOHandler
|
||||||
parameters: ParamsOrNone
|
parameters: ParamsOrNone
|
||||||
@@ -155,7 +161,7 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self: Propagation[ParamsOrNone],
|
self: Propagation[ParamsOrNone],
|
||||||
io_handler: PropagationIOHandler,
|
io_handler: PropagationIOHandler,
|
||||||
params: ParamsOrNone,
|
params: ParamsOrNone = NO_PARAMS,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A propagation is the object that manages IO for one single propagation.
|
A propagation is the object that manages IO for one single propagation.
|
||||||
@@ -166,15 +172,18 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
----------
|
----------
|
||||||
io : PropagationIOHandler
|
io : PropagationIOHandler
|
||||||
object that implements the PropagationIOHandler Protocol.
|
object that implements the PropagationIOHandler Protocol.
|
||||||
params : Parameters | None
|
params : Parameters | None, optional
|
||||||
simulations parameters. Those will be passed on to spectra when loaded.
|
- if `Parameters`, those will be passed on to spectra when loaded to create `Spectrum`
|
||||||
if None, loading spectra will result in normal numpy arrays rather than `Spectrum` obj.
|
objects.
|
||||||
|
- if explicitly None, loading spectra will result in normal numpy arrays rather than
|
||||||
|
`Spectrum` obj.
|
||||||
|
- if unspecified, parameters will be lazily loaded the first time the `parameters`
|
||||||
|
attribute is accessed (e.g. loading a spectrum).
|
||||||
"""
|
"""
|
||||||
self.io = io_handler
|
self.io = io_handler
|
||||||
self._current_index = len(self.io)
|
self._current_index = len(self.io)
|
||||||
|
if params is not NO_PARAMS:
|
||||||
self.parameters = params
|
self.parameters = params
|
||||||
if self.parameters is not None:
|
|
||||||
self.z_positions = self.parameters.compute("z_targets")
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self._current_index
|
return self._current_index
|
||||||
@@ -190,24 +199,23 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
def __getitem__(self, key: int | slice) -> Spectrum | np.ndarray:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
return self._load_slice(key)
|
return self._load_slice(key)
|
||||||
if isinstance(key, (float, np.floating)):
|
|
||||||
if self.parameters is None:
|
|
||||||
raise ValueError(f"cannot accept float key {key} when parameters is not set")
|
|
||||||
key = math.argclosest(self.z_positions, key)
|
|
||||||
elif key < 0:
|
elif key < 0:
|
||||||
self._warn_negative_index(key)
|
self._warn_negative_index(key)
|
||||||
key = len(self) + key
|
key = len(self) + key
|
||||||
array = self.io.load_spectrum(key)
|
array = self.io.load_spectrum(key)
|
||||||
if self.parameters is not None:
|
return self._spectrumize(array)
|
||||||
return Spectrum.from_params(array, self.parameters)
|
|
||||||
else:
|
|
||||||
return array
|
|
||||||
|
|
||||||
def __setitem__(self, key: int, value: np.ndarray):
|
def __setitem__(self, key: int, value: np.ndarray):
|
||||||
if not isinstance(key, int):
|
if not isinstance(key, int):
|
||||||
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||||
self.io.save_spectrum(key, np.asarray(value))
|
self.io.save_spectrum(key, np.asarray(value))
|
||||||
|
|
||||||
|
def _spectrumize(self, array: np.ndarray | Spectrum) -> np.ndarray | Spectrum:
|
||||||
|
if self.parameters is not None and not isinstance(array, Spectrum):
|
||||||
|
return Spectrum.from_params(array, self.parameters)
|
||||||
|
else:
|
||||||
|
return array
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
def _load_slice(self: Propagation[Parameters], key: slice) -> Spectrum:
|
||||||
...
|
...
|
||||||
@@ -220,17 +228,23 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
self._warn_negative_index(key.start)
|
self._warn_negative_index(key.start)
|
||||||
self._warn_negative_index(key.stop)
|
self._warn_negative_index(key.stop)
|
||||||
_iter = range(len(self))[key]
|
_iter = range(len(self))[key]
|
||||||
# if self.parameters is not None:
|
return self._spectrumize(np.array([self.io.load_spectrum(i) for i in _iter]))
|
||||||
# out = Spectrum.from_params(
|
|
||||||
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
@cached_property
|
||||||
# )
|
def parameters(self):
|
||||||
# for i in _iter:
|
try:
|
||||||
# out[i] = self.io.load_spectrum(i)
|
return self.load_parameters()
|
||||||
# else:
|
except Exception as e:
|
||||||
out = np.array([self.io.load_spectrum(i) for i in _iter])
|
logger.error(str(e))
|
||||||
if self.parameters is not None:
|
return None
|
||||||
out = Spectrum.from_params(out, self.parameters)
|
|
||||||
return out
|
def load_parameters(self) -> Parameters:
|
||||||
|
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
|
||||||
|
params.compile_in_place(exhaustive=True, strict=False)
|
||||||
|
for k, v in params.items():
|
||||||
|
if isinstance(v, DataFile):
|
||||||
|
v.io = self.io
|
||||||
|
return params
|
||||||
|
|
||||||
def append(self, spectrum: np.ndarray):
|
def append(self, spectrum: np.ndarray):
|
||||||
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
|
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
|
||||||
@@ -246,6 +260,30 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
return self._load_slice(slice(None))
|
return self._load_slice(slice(None))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PropagationCollection:
|
||||||
|
propagations: list[Propagation]
|
||||||
|
parameters: Parameters
|
||||||
|
z: np.ndarray
|
||||||
|
t: np.ndarray
|
||||||
|
w: np.ndarray
|
||||||
|
wl: np.ndarray
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, key: int) -> Propagation:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, key: slice) -> list[Propagation]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __getitem__(self, key: int | slice) -> Propagation | list[Propagation]:
|
||||||
|
return self.propagations[key]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Propagation]:
|
||||||
|
yield from self.propagations
|
||||||
|
|
||||||
|
|
||||||
def propagation(
|
def propagation(
|
||||||
file_or_params: os.PathLike | Parameters,
|
file_or_params: os.PathLike | Parameters,
|
||||||
params: Parameters | None = None,
|
params: Parameters | None = None,
|
||||||
@@ -284,11 +322,12 @@ def propagation(
|
|||||||
file = Path(file_or_params)
|
file = Path(file_or_params)
|
||||||
|
|
||||||
if file is not None and file.exists() and params is None:
|
if file is not None and file.exists() and params is None:
|
||||||
io = ZipFileIOHandler(file)
|
return open_existing_propagation(file, load_parameters=load_parameters)
|
||||||
return _open_existing_propagation(io, load_parameters=load_parameters)
|
|
||||||
|
|
||||||
if params is None:
|
if params is None:
|
||||||
raise ValueError("Parameters must be specified to create new simulation")
|
raise ValueError(
|
||||||
|
f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
|
||||||
|
)
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
if file.exists() and params is not None:
|
if file.exists() and params is not None:
|
||||||
@@ -307,17 +346,67 @@ def propagation(
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def _open_existing_propagation(
|
def propagation_series(
|
||||||
io: PropagationIOHandler, load_parameters: bool = True
|
files: Sequence[os.PathLike], index: int | slice | None = None, progress_bar: bool = False
|
||||||
) -> Propagation:
|
) -> tuple[Spectrum, PropagationCollection]:
|
||||||
if not load_parameters:
|
"""
|
||||||
|
loads an existing sequence of propagation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
files : Sequence[os.PathLike]
|
||||||
|
sequence (e.g. list) of propagation zip files.
|
||||||
|
index : slice | int | None, optional
|
||||||
|
what index of each propagation to load (e.g. `None`->all spectra, `-1`->last spectrum),
|
||||||
|
by default None
|
||||||
|
progress_bar : bool, optional
|
||||||
|
print a progress bar to stderr as the files are loading (requires tqdm), by default False
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PropagationCollection
|
||||||
|
convenient object to work with multiple propagations **built on the same grid**
|
||||||
|
"""
|
||||||
|
if len(files) == 0:
|
||||||
|
raise ValueError("You must provide at least one file to build a propagation series")
|
||||||
|
|
||||||
|
propagations = [open_existing_propagation(f, load_parameters=False) for f in files]
|
||||||
|
parameters = propagations[0].load_parameters()
|
||||||
|
|
||||||
|
rest = propagations
|
||||||
|
if progress_bar:
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
pass
|
||||||
|
|
||||||
|
rest = tqdm(rest)
|
||||||
|
|
||||||
|
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
|
||||||
|
for prop in propagations:
|
||||||
|
del prop.parameters
|
||||||
|
|
||||||
|
return spectrum, PropagationCollection(
|
||||||
|
propagations,
|
||||||
|
parameters,
|
||||||
|
z=parameters.z_targets[: spectrum.shape[0]],
|
||||||
|
t=spectrum.t,
|
||||||
|
w=spectrum.w_disp,
|
||||||
|
wl=spectrum.wl_disp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def open_existing_propagation(file: os.PathLike, load_parameters: bool = True) -> Propagation:
|
||||||
|
file = Path(file)
|
||||||
|
if not file.exists():
|
||||||
|
raise FileNotFoundError(f"no propagation found at {file}")
|
||||||
|
io = ZipFileIOHandler(file)
|
||||||
|
|
||||||
|
if load_parameters:
|
||||||
|
return Propagation(io)
|
||||||
|
else:
|
||||||
return Propagation(io, None)
|
return Propagation(io, None)
|
||||||
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
|
||||||
params.compile_in_place(exhaustive=True, strict=False)
|
|
||||||
for k, v in params.items():
|
|
||||||
if isinstance(v, DataFile):
|
|
||||||
v.io = io
|
|
||||||
return Propagation(io, params)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_new_propagation(
|
def _create_new_propagation(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.spectra import PARAMS_FN, propagation
|
from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
|
||||||
|
|
||||||
PARAMS = dict(
|
PARAMS = dict(
|
||||||
name="PM2000D sech pulse",
|
name="PM2000D sech pulse",
|
||||||
@@ -25,7 +25,7 @@ PARAMS = dict(
|
|||||||
raman_type="measured",
|
raman_type="measured",
|
||||||
quantum_noise=True,
|
quantum_noise=True,
|
||||||
interpolation_degree=11,
|
interpolation_degree=11,
|
||||||
z_num=128,
|
z_num=8,
|
||||||
length=1.5,
|
length=1.5,
|
||||||
t_num=512,
|
t_num=512,
|
||||||
dt=5e-15,
|
dt=5e-15,
|
||||||
@@ -173,3 +173,28 @@ def test_zip_bundle(tmp_path: Path):
|
|||||||
def test_unique_name():
|
def test_unique_name():
|
||||||
existing = {"spec.npy", "spec_0.npy"}
|
existing = {"spec.npy", "spec_0.npy"}
|
||||||
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
||||||
|
|
||||||
|
existing = {"spec", "spec0"}
|
||||||
|
assert unique_name("spec", existing) == "spec_0"
|
||||||
|
|
||||||
|
existing = {"spec", "spec_0"}
|
||||||
|
assert unique_name("spec", existing) == "spec_1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_propagation_series(tmp_path: Path):
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
specs, props = propagation_series([])
|
||||||
|
|
||||||
|
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
|
||||||
|
for i, f in enumerate(flist):
|
||||||
|
params.name = f"prop {i}"
|
||||||
|
prop = propagation(f, params)
|
||||||
|
for _ in range(params.z_num):
|
||||||
|
prop.append(np.zeros(params.t_num, dtype=complex))
|
||||||
|
|
||||||
|
assert set(flist) == set(tmp_path.glob("*.zip"))
|
||||||
|
|
||||||
|
specs, propagations = propagation_series(flist)
|
||||||
|
assert specs.shape == (10, params.z_num, params.t_num)
|
||||||
|
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))
|
||||||
|
|||||||
Reference in New Issue
Block a user