cleanup io handling and Propagation
@@ -1,8 +1,10 @@
# ruff: noqa
from scgenerator import io, math, operators, plotting
from scgenerator.helpers import *
from scgenerator.io import MemoryIOHandler, ZipFileIOHandler
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
from scgenerator.parameter import Parameters
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.physics.units import PlotRange
from scgenerator.solver import SimulationResult, integrate, solve43
from scgenerator.spectra import Propagation, Spectrum

@@ -33,6 +33,7 @@ MANDATORY_PARAMETERS = {
    "name",
    "w",
    "t",
    "l",
    "fft",
    "ifft",
    "w0",

@@ -72,6 +72,9 @@ class PropagationIOHandler(Protocol):
    def load_data(self, name: str) -> bytes:
        ...

    def clear(self):
        ...


class MemoryIOHandler:
    spectra: dict[int, np.ndarray]
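
The hunk above adds clear() to the PropagationIOHandler protocol. Because it is a typing.Protocol, any object with the right methods can back a Propagation. Below is a minimal in-memory sketch of a conforming handler; the method set is inferred from call sites elsewhere in this diff (save_data/load_data, save_spectrum/load_spectrum, keys, __len__, clear) and is not the library's own definition.

    # Sketch of a handler satisfying PropagationIOHandler as it is used in this diff.
    # The methods below are inferred from call sites, not copied from the library.
    import numpy as np

    class DictIOHandler:
        def __init__(self):
            self.spectra: dict[int, np.ndarray] = {}
            self.data: dict[str, bytes] = {}

        def save_data(self, name: str, data: bytes) -> None:
            self.data[name] = data

        def load_data(self, name: str) -> bytes:
            return self.data[name]

        def save_spectrum(self, index: int, spectrum: np.ndarray) -> None:
            self.spectra[index] = spectrum

        def load_spectrum(self, index: int) -> np.ndarray:
            return self.spectra[index]

        def keys(self):
            return self.data.keys()

        def __len__(self) -> int:
            return len(self.spectra)

        def clear(self) -> None:
            self.spectra = {}
            self.data = {}
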
@@ -99,6 +102,10 @@ class MemoryIOHandler:
    def load_data(self, name: str) -> bytes:
        return self.data[name]

    def clear(self):
        self.spectra = {}
        self.data = {}


class ZipFileIOHandler:
    file: BinaryIO
@@ -153,6 +160,9 @@ class ZipFileIOHandler:
        with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
            return file.read()

    def clear(self):
        self.file.unlink(missing_ok=True)


@dataclass
class DataFile:
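
Note that clear() has different semantics per handler: MemoryIOHandler drops its dictionaries, while ZipFileIOHandler deletes the archive itself via Path.unlink(missing_ok=True). A short usage sketch (the path-based ZipFileIOHandler constructor is inferred from the tests further down):

    from pathlib import Path
    from scgenerator.io import MemoryIOHandler, ZipFileIOHandler

    mem = MemoryIOHandler()
    mem.save_data("params.json", b"{}")
    mem.clear()                              # in-memory dicts are reset, nothing touches the disk

    zio = ZipFileIOHandler(Path("run.zip"))  # path-based constructor, as used in the tests
    zio.clear()                              # removes run.zip if it exists (missing_ok=True)
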
@@ -201,11 +211,15 @@ class DataFile:
                f"a bundled file prefixed with {self.prefix} "
                "must have a PropagationIOHandler attached"
            )
        if self.io is not None:
            # a DataFile obj may have a useless io obj attached to it
            if self.prefix is not None:
                return self.io.load_data(self.path)
        else:
            return Path(self.path).read_bytes()

    def similar_to(self, other: DataFile) -> bool:
        return Path(self.path).name == Path(other.path).name


def unique_name(base_name: str, existing: set[str]) -> str:
    name = base_name

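The new code path prefers the attached io handler whenever a prefix marks the file as bundled, and falls back to the filesystem otherwise, while similar_to() compares only the file names, so a bundled copy still matches its original source path. A minimal sketch of the name-based comparison, assuming DataFile can be built from the keyword fields seen in this diff (path, prefix, io), which is not shown here:

    from scgenerator.io import DataFile

    original = DataFile(path="tests/data/PM2000D_A_eff_marcuse.npz", prefix=None, io=None)
    bundled = DataFile(path="PM2000D_A_eff_marcuse.npz", prefix="zip", io=None)

    assert original.similar_to(bundled)  # only Path(...).name is compared
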
@@ -294,7 +294,7 @@ class Parameter:
        return is_value, v


@dataclass(repr=False)
@dataclass(repr=False, eq=False)
class Parameters:
    """
    This class defines each valid parameter's name, type and valid value.
@@ -303,7 +303,7 @@ class Parameters:
    # internal machinery
    _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
    _p_names: ClassVar[Set[str]] = set()
    _frozen: bool = field(init=False, default=False, repr=False)
    frozen: bool = field(init=False, default=False, repr=False)

    # root
    name: str = Parameter(string, default="no name")
@@ -460,28 +460,45 @@ class Parameters:
        self.__post_init__()

    def __setattr__(self, k, v):
        if self._frozen and not k.endswith("_file"):
        if self.frozen and not k.endswith("_file"):
            raise AttributeError(
                f"cannot set attribute to frozen {self.__class__.__name__} instance"
            )
        object.__setattr__(self, k, v)

    def __eq__(self, other: Parameters) -> bool:
        if not isinstance(other, Parameters):
            raise TypeError(
                f"cannot compare {self.__class__.__name__!r} with {type(other).__name__!r}"
            )
        for k, v in self.items():
            other_v = getattr(other, k)
            if isinstance(v, DataFile) and not v.similar_to(other_v):
                return False
            if other_v != v:
                return False

        return True

    def items(self) -> Iterator[tuple[str, Any]]:
        for k, v in self._param_dico.items():
            if v is None:
                continue
            yield k, v

    def copy(self, freeze: bool = False) -> Parameters:
    def copy(self, deep: bool = True, freeze: bool = False) -> Parameters:
        """create a deep copy of self. if freeze is True, the returned copy is read-only"""
        params = Parameters(**deepcopy(self.strip_params_dict()))
        if deep:
            params = Parameters(**deepcopy(self.strip_params_dict()))
        else:
            params = Parameters(**self.strip_params_dict())
        if freeze:
            params.freeze()
        return params

    def freeze(self):
        """render the current instance read-only. This is not reversible"""
        self._frozen = True
        self.frozen = True

    def to_json(self) -> str:
        d = self.dump_dict()
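
The frozen flag is now a public attribute, copy() gained a deep switch, and equality compares parameter values, with DataFile entries compared by file name only. A short usage sketch; the parameter values are placeholders, any keyword set accepted by Parameters behaves the same way:

    from scgenerator.parameter import Parameters

    params = Parameters(name="example")               # placeholder parameter set
    snapshot = params.copy(deep=True, freeze=True)    # independent, read-only copy
    shallow = params.copy(deep=False)                 # new instance, shares the underlying values

    assert snapshot == params                         # __eq__ walks items(), DataFiles match by name
    try:
        snapshot.name = "other"                       # frozen instances reject attribute assignment
    except AttributeError:
        pass                                          # ...except for *_file attributes
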
@@ -536,7 +553,7 @@ class Parameters:
        else:
            return first

    def compile(self, exhaustive=False) -> Parameters:
    def compile(self, exhaustive=False, strict: bool = True) -> Parameters:
        """
        Computes missing parameters and returns them in a frozen `Parameters` instance

@@ -547,6 +564,8 @@ class Parameters:
            Depending on the specifics of the model and how the parameters were specified, there
            might be no difference between a normal compilation and an exhaustive one.
            by default False
        strict : bool, optional
            raise an exception when something cannot be computed, by default True

        Returns
        -------
@@ -560,16 +579,23 @@ class Parameters:
            When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
            cases, this is due to underdetermination by the user.
        """
        obj = self.copy(deep=False, freeze=False)
        obj.compile_in_place(exhaustive, strict)
        return obj

    def compile_in_place(self, exhaustive: bool = False, strict: bool = True):
        to_compute = MANDATORY_PARAMETERS
        evaluator = self.get_evaluator()
        try:
            for k in to_compute:
        for k in to_compute:
            try:
                evaluator.compute(k)
            except EvaluatorError as e:
                raise ValueError(
                    "Could not compile the parameter set. Most likely, "
                    f"an essential value is missing\n{e}"
                ) from None
        except EvaluatorError as e:
            if strict:
                raise ValueError(
                    "Could not compile the parameter set. Most likely, "
                    f"an essential value is missing\n{e}"
                ) from None

        if exhaustive:
            for p in self._p_names:
                if p not in evaluator.main_map:
@@ -577,11 +603,10 @@ class Parameters:
                        evaluator.compute(p)
                    except Exception:
                        pass
        computed = self.__class__(
            **{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
        )
        computed._frozen = True
        return computed
        self._param_dico |= {
            k: v.value for k, v in evaluator.main_map.items() if k in self._p_names
        }
        self.freeze()

    def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
        """return a pretty formatted string describing the parameters"""

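compile() now accepts strict and delegates to the new compile_in_place(), which fills the parameter dictionary from the evaluator and freezes the instance instead of building a separate computed object. A sketch of both entry points; PARAMS stands for a complete parameter dict such as the one in the tests:

    from scgenerator.parameter import Parameters

    params = Parameters(**PARAMS)       # PARAMS: a parameter dict such as the one in the tests
    compiled = params.compile(exhaustive=True, strict=False)   # tolerant pass on a shallow copy
    assert compiled.frozen              # compile_in_place() freezes the result

    params.compile_in_place()           # in-place variant; strict by default, so this raises
                                        # ValueError if a mandatory value cannot be computed
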
@@ -3,14 +3,23 @@ from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Callable

import numpy as np

from scgenerator import math
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
from scgenerator.io import (
    DataFile,
    MemoryIOHandler,
    PropagationIOHandler,
    ZipFileIOHandler,
    unique_name,
)
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units

PARAMS_FN = "params.json"


class Spectrum(np.ndarray):
    w: np.ndarray
@@ -25,7 +34,7 @@ class Spectrum(np.ndarray):
        # add the new attribute to the created instance
        obj.w = params.compute("w")
        obj.t = params.compute("t")
        obj.t = params.compute("t")
        obj.l = params.compute("l")
        obj.ifft = params.compute("ifft")

        # Finally, we must return the newly created object:
@@ -116,16 +125,13 @@ class Spectrum(np.ndarray):

class Propagation:
    io: PropagationIOHandler
    params: Parameters
    parameters: Parameters
    _current_index: int

    PARAMS_FN = "params.json"

    def __init__(
        self,
        io_handler: PropagationIOHandler,
        params: Parameters | None = None,
        bundle_data: bool = False,
        params: Parameters,
    ):
        """
        A propagation is the object that manages IO for one single propagation.
@@ -141,31 +147,7 @@ class Propagation:
        """
        self.io = io_handler
        self._current_index = len(self.io)

        new_params = params is not None
        if not new_params:
            if bundle_data:
                raise ValueError(
                    "cannot bundle data to existing Propagation. Create a new one instead"
                )
            try:
                params_data = self.io.load_data(self.PARAMS_FN)
                params = Parameters.from_json(params_data.decode())
            except KeyError:
                raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None

        self.params = params

        if bundle_data:
            if not params._frozen:
                warnings.warn(
                    f"Parameters instance {params.name!r} is not frozen but will be saved "
                    "in an unmodifiable state anyway."
                )
            _bundle_external_files(self.params, self.io)

        if new_params:
            self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
        self.parameters = params

    def __len__(self) -> int:
        return self._current_index
@@ -174,9 +156,9 @@ class Propagation:
        if isinstance(key, slice):
            return self._load_slice(key)
        if isinstance(key, (float, np.floating)):
            key = math.argclosest(self.params.compute("z_targets"), key)
            key = math.argclosest(self.parameters.compute("z_targets"), key)
        array = self.io.load_spectrum(key)
        return Spectrum(array, self.params)
        return Spectrum(array, self.parameters)

    def __setitem__(self, key: int, value: np.ndarray):
        if not isinstance(key, int):
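
Aside from the rename to `parameters`, indexing behaviour is unchanged: an int returns one stored spectrum, a slice returns a stacked Spectrum, and a float is mapped to the nearest entry of z_targets via math.argclosest. A small sketch, assuming `prop` is an existing Propagation and the value is in the same unit as z_targets:

    first = prop[0]        # single stored spectrum, returned as a Spectrum array
    window = prop[0:4]     # slice: spectra stacked into one (n, t_num) Spectrum
    nearest = prop[0.5]    # float: spectrum whose z_targets entry is closest to 0.5
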
@@ -185,19 +167,78 @@ class Propagation:

    def _load_slice(self, key: slice) -> Spectrum:
        _iter = range(len(self))[key]
        out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
        out = Spectrum(
            np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
        )
        for i in _iter:
            out[i] = self.io.load_spectrum(i)
        return out

    def append(self, spectrum: np.ndarray):
        self.io.save_spectrum(self._current_index, spectrum.asarray())
        self.io.save_spectrum(self._current_index, np.asarray(spectrum))
        self._current_index += 1

    def load_all(self) -> Spectrum:
        return self._load_slice(slice(None))


def load_all(path: os.PathLike) -> Spectrum:
    io = ZipFileIOHandler(path)
    return Propagation(io).load_all()


def propagation(
    file_or_params: os.PathLike | Parameters,
    params: Parameters | None = None,
    bundle_data: bool = False,
) -> Propagation:
    file = None
    if isinstance(file_or_params, Parameters):
        params = file_or_params
    else:
        file = Path(file_or_params)

    if file is not None and file.exists():
        io = ZipFileIOHandler(file)
        return _open_existing_propagation(io)

    if params is None:
        raise ValueError("Parameters must be specified to create new simulation")

    if file is not None:
        io = ZipFileIOHandler(file)
    else:
        io = MemoryIOHandler()

    try:
        return _create_new_propagation(io, params, bundle_data)
    except Exception as e:
        io.clear()
        raise e


def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
    params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
    params.compile_in_place(exhaustive=True, strict=False)
    for k, v in params.items():
        if isinstance(v, DataFile):
            v.io = io
    return Propagation(io, params)


def _create_new_propagation(
    io: PropagationIOHandler, params: Parameters, bundle_data: bool
) -> Propagation:
    if params.frozen:
        params = params.copy()
    else:
        params = params.compile(exhaustive=True, strict=False)
    if bundle_data:
        _bundle_external_files(params, io)
    io.save_data(PARAMS_FN, params.to_json().encode())
    return Propagation(io, params)


def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
    """copies every external file specified in the parameters and saves it"""
    existing_files = set(io.keys())
@@ -211,8 +252,3 @@ def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
        existing_files.add(value.path)

        io.save_data(value.path, data)


def load_all(path: os.PathLike) -> Spectrum:
    io = ZipFileIOHandler(path)
    return Propagation(io).load_all()

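These hunks replace the old pattern of constructing Propagation directly with a module-level propagation() factory: an existing zip archive is reopened and its parameters read back from params.json, otherwise a new propagation is created, backed by a ZipFileIOHandler when a path is given and a MemoryIOHandler when only parameters are passed, with the handler cleared if creation fails. A usage sketch, with PARAMS standing for a complete parameter dict as in the tests:

    from pathlib import Path
    from scgenerator.parameter import Parameters
    from scgenerator.spectra import propagation

    params = Parameters(**PARAMS)

    prop = propagation(Path("run.zip"), params)   # new file: parameters compiled and saved inside
    # prop = propagation(params)                  # no path: everything stays in memory
    # prop = propagation(Path("run.zip"), params, bundle_data=True)  # also copy external files in

    reopened = propagation(Path("run.zip"))       # existing file: parameters are read back
    assert reopened.parameters == prop.parameters

    spectra = reopened.load_all()                 # every stored spectrum as one Spectrum array
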
@@ -7,7 +7,7 @@ import pytest

from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation
from scgenerator.spectra import PARAMS_FN, propagation

PARAMS = dict(
    name="PM2000D sech pulse",
@@ -78,13 +78,34 @@ def test_memory():
    assert len(io) == 8


def test_reopen(tmp_path: Path):
    zpath = tmp_path / "file.zip"
    params = Parameters(**PARAMS)
    prop = propagation(zpath, params)

    prop2 = propagation(zpath)

    assert prop.parameters == prop2.parameters


def test_clear(tmp_path: Path):
    params = Parameters(**PARAMS)
    zpath = tmp_path / "file.zip"
    prop = propagation(zpath, params)

    assert zpath.exists()
    assert zpath.read_bytes() != b""

    prop.io.clear()

    assert not zpath.exists()


def test_zip_bundle(tmp_path: Path):
    params = Parameters(**PARAMS)
    io = ZipFileIOHandler(tmp_path / "file.zip")
    prop = Propagation(io, params.copy(True))

    assert (tmp_path / "file.zip").exists()
    assert (tmp_path / "file.zip").read_bytes() != b""
    with pytest.raises(FileNotFoundError):
        propagation(tmp_path / "file2.zip", params, bundle_data=True)

    new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
    new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
@@ -92,30 +113,32 @@ def test_zip_bundle(tmp_path: Path):
    params.effective_area_file.path = str(new_aeff_path)
    params.freeze()

    io = ZipFileIOHandler(tmp_path / "file2.zip")
    prop2 = Propagation(io, params, True)
    prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True)

    assert params.dispersion_file.path == new_disp_path.name
    assert params.dispersion_file.prefix == "zip"
    assert params.effective_area_file.path == new_aeff_path.name
    assert params.effective_area_file.prefix == "zip"
    assert prop2.parameters.dispersion_file.path == new_disp_path.name
    assert prop2.parameters.dispersion_file.prefix == "zip"
    assert prop2.parameters.effective_area_file.path == new_aeff_path.name
    assert prop2.parameters.effective_area_file.prefix == "zip"

    with ZipFile(tmp_path / "file2.zip", "r") as zfile:
    with ZipFile(tmp_path / "file3.zip", "r") as zfile:
        with zfile.open(new_aeff_path.name) as file:
            assert file.read() == new_aeff_path.read_bytes()
        with zfile.open(new_disp_path.name) as file:
            assert file.read() == new_disp_path.read_bytes()

        with zfile.open(Propagation.PARAMS_FN) as file:
        with zfile.open(PARAMS_FN) as file:
            df = json.loads(file.read().decode())

    assert (
        df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
        df["dispersion_file"]
        == prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
    )

    assert (
        df["effective_area_file"]
        == params.effective_area_file.prefix + "::" + params.effective_area_file.path
        == prop2.parameters.effective_area_file.prefix
        + "::"
        + Path(params.effective_area_file.path).name
    )