cleanup io handling and Propagation

This commit is contained in:
Benoît Sierro
2023-08-09 11:21:57 +02:00
parent a650169443
commit 7bb15871c3
6 changed files with 177 additions and 76 deletions

View File

@@ -1,8 +1,10 @@
# ruff: noqa
from scgenerator import io, math, operators, plotting
from scgenerator.helpers import *
from scgenerator.io import MemoryIOHandler, ZipFileIOHandler
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
from scgenerator.parameter import Parameters
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.physics.units import PlotRange
from scgenerator.solver import SimulationResult, integrate, solve43
from scgenerator.spectra import Propagation, Spectrum

View File

@@ -33,6 +33,7 @@ MANDATORY_PARAMETERS = {
"name",
"w",
"t",
"l",
"fft",
"ifft",
"w0",

View File

@@ -72,6 +72,9 @@ class PropagationIOHandler(Protocol):
def load_data(self, name: str) -> bytes:
...
def clear(self):
...
class MemoryIOHandler:
spectra: dict[int, np.ndarray]
@@ -99,6 +102,10 @@ class MemoryIOHandler:
def load_data(self, name: str) -> bytes:
return self.data[name]
def clear(self):
self.spectra = {}
self.data = {}
class ZipFileIOHandler:
file: BinaryIO
@@ -153,6 +160,9 @@ class ZipFileIOHandler:
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
return file.read()
def clear(self):
self.file.unlink(missing_ok=True)
@dataclass
class DataFile:
@@ -201,11 +211,15 @@ class DataFile:
f"a bundled file prefixed with {self.prefix} "
"must have a PropagationIOHandler attached"
)
if self.io is not None:
# a DataFile obj may have a useless io obj attached to it
if self.prefix is not None:
return self.io.load_data(self.path)
else:
return Path(self.path).read_bytes()
def similar_to(self, other: DataFile) -> bool:
return Path(self.path).name == Path(other.path).name
def unique_name(base_name: str, existing: set[str]) -> str:
name = base_name

View File

@@ -294,7 +294,7 @@ class Parameter:
return is_value, v
@dataclass(repr=False)
@dataclass(repr=False, eq=False)
class Parameters:
"""
This class defines each valid parameter's name, type and valid value.
@@ -303,7 +303,7 @@ class Parameters:
# internal machinery
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_p_names: ClassVar[Set[str]] = set()
_frozen: bool = field(init=False, default=False, repr=False)
frozen: bool = field(init=False, default=False, repr=False)
# root
name: str = Parameter(string, default="no name")
@@ -460,28 +460,45 @@ class Parameters:
self.__post_init__()
def __setattr__(self, k, v):
if self._frozen and not k.endswith("_file"):
if self.frozen and not k.endswith("_file"):
raise AttributeError(
f"cannot set attribute to frozen {self.__class__.__name__} instance"
)
object.__setattr__(self, k, v)
def __eq__(self, other: Parameters) -> bool:
if not isinstance(other, Parameters):
raise TypeError(
f"cannot compare {self.__class__.__name__!r} with {type(other).__name__!r}"
)
for k, v in self.items():
other_v = getattr(other, k)
if isinstance(v, DataFile) and not v.similar_to(other_v):
return False
if other_v != v:
return False
return True
def items(self) -> Iterator[tuple[str, Any]]:
for k, v in self._param_dico.items():
if v is None:
continue
yield k, v
def copy(self, freeze: bool = False) -> Parameters:
def copy(self, deep: bool = True, freeze: bool = False) -> Parameters:
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
params = Parameters(**deepcopy(self.strip_params_dict()))
if deep:
params = Parameters(**deepcopy(self.strip_params_dict()))
else:
params = Parameters(**self.strip_params_dict())
if freeze:
params.freeze()
return params
def freeze(self):
"""render the current instance read-only. This is not reversible"""
self._frozen = True
self.frozen = True
def to_json(self) -> str:
d = self.dump_dict()
@@ -536,7 +553,7 @@ class Parameters:
else:
return first
def compile(self, exhaustive=False) -> Parameters:
def compile(self, exhaustive=False, strict: bool = True) -> Parameters:
"""
Computes missing parameters and returns them in a frozen `Parameters` instance
@@ -547,6 +564,8 @@ class Parameters:
Depending on the specifics of the model and how the parameters were specified, there
might be no difference between a normal compilation and an exhaustive one.
by default False
strict : bool, optional
raise an exception when something cannot be computed, by default True
Returns
-------
@@ -560,16 +579,23 @@ class Parameters:
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
cases, this is due to underdetermination by the user.
"""
obj = self.copy(deep=False, freeze=False)
obj.compile_in_place(exhaustive, strict)
return obj
def compile_in_place(self, exhaustive: bool = False, strict: bool = True):
to_compute = MANDATORY_PARAMETERS
evaluator = self.get_evaluator()
try:
for k in to_compute:
for k in to_compute:
try:
evaluator.compute(k)
except EvaluatorError as e:
raise ValueError(
"Could not compile the parameter set. Most likely, "
f"an essential value is missing\n{e}"
) from None
except EvaluatorError as e:
if strict:
raise ValueError(
"Could not compile the parameter set. Most likely, "
f"an essential value is missing\n{e}"
) from None
if exhaustive:
for p in self._p_names:
if p not in evaluator.main_map:
@@ -577,11 +603,10 @@ class Parameters:
evaluator.compute(p)
except Exception:
pass
computed = self.__class__(
**{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
)
computed._frozen = True
return computed
self._param_dico |= {
k: v.value for k, v in evaluator.main_map.items() if k in self._p_names
}
self.freeze()
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
"""return a pretty formatted string describing the parameters"""

View File

@@ -3,14 +3,23 @@ from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Callable
import numpy as np
from scgenerator import math
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
from scgenerator.io import (
DataFile,
MemoryIOHandler,
PropagationIOHandler,
ZipFileIOHandler,
unique_name,
)
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units
PARAMS_FN = "params.json"
class Spectrum(np.ndarray):
w: np.ndarray
@@ -25,7 +34,7 @@ class Spectrum(np.ndarray):
# add the new attribute to the created instance
obj.w = params.compute("w")
obj.t = params.compute("t")
obj.t = params.compute("t")
obj.l = params.compute("l")
obj.ifft = params.compute("ifft")
# Finally, we must return the newly created object:
@@ -116,16 +125,13 @@ class Spectrum(np.ndarray):
class Propagation:
io: PropagationIOHandler
params: Parameters
parameters: Parameters
_current_index: int
PARAMS_FN = "params.json"
def __init__(
self,
io_handler: PropagationIOHandler,
params: Parameters | None = None,
bundle_data: bool = False,
params: Parameters,
):
"""
A propagation is the object that manages IO for one single propagation.
@@ -141,31 +147,7 @@ class Propagation:
"""
self.io = io_handler
self._current_index = len(self.io)
new_params = params is not None
if not new_params:
if bundle_data:
raise ValueError(
"cannot bundle data to existing Propagation. Create a new one instead"
)
try:
params_data = self.io.load_data(self.PARAMS_FN)
params = Parameters.from_json(params_data.decode())
except KeyError:
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
self.params = params
if bundle_data:
if not params._frozen:
warnings.warn(
f"Parameters instance {params.name!r} is not frozen but will be saved "
"in an unmodifiable state anyway."
)
_bundle_external_files(self.params, self.io)
if new_params:
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
self.parameters = params
def __len__(self) -> int:
return self._current_index
@@ -174,9 +156,9 @@ class Propagation:
if isinstance(key, slice):
return self._load_slice(key)
if isinstance(key, (float, np.floating)):
key = math.argclosest(self.params.compute("z_targets"), key)
key = math.argclosest(self.parameters.compute("z_targets"), key)
array = self.io.load_spectrum(key)
return Spectrum(array, self.params)
return Spectrum(array, self.parameters)
def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int):
@@ -185,19 +167,78 @@ class Propagation:
def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key]
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
out = Spectrum(
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
)
for i in _iter:
out[i] = self.io.load_spectrum(i)
return out
def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, spectrum.asarray())
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
self._current_index += 1
def load_all(self) -> Spectrum:
return self._load_slice(slice(None))
def load_all(path: os.PathLike) -> Spectrum:
io = ZipFileIOHandler(path)
return Propagation(io).load_all()
def propagation(
file_or_params: os.PathLike | Parameters,
params: Parameters | None = None,
bundle_data: bool = False,
) -> Propagation:
file = None
if isinstance(file_or_params, Parameters):
params = file_or_params
else:
file = Path(file_or_params)
if file is not None and file.exists():
io = ZipFileIOHandler(file)
return _open_existing_propagation(io)
if params is None:
raise ValueError("Parameters must be specified to create new simulation")
if file is not None:
io = ZipFileIOHandler(file)
else:
io = MemoryIOHandler()
try:
return _create_new_propagation(io, params, bundle_data)
except Exception as e:
io.clear()
raise e
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items():
if isinstance(v, DataFile):
v.io = io
return Propagation(io, params)
def _create_new_propagation(
io: PropagationIOHandler, params: Parameters, bundle_data: bool
) -> Propagation:
if params.frozen:
params = params.copy()
else:
params = params.compile(exhaustive=True, strict=False)
if bundle_data:
_bundle_external_files(params, io)
io.save_data(PARAMS_FN, params.to_json().encode())
return Propagation(io, params)
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
"""copies every external file specified in the parameters and saves it"""
existing_files = set(io.keys())
@@ -211,8 +252,3 @@ def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
existing_files.add(value.path)
io.save_data(value.path, data)
def load_all(path: os.PathLike) -> Spectrum:
io = ZipFileIOHandler(path)
return Propagation(io).load_all()

View File

@@ -7,7 +7,7 @@ import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation
from scgenerator.spectra import PARAMS_FN, propagation
PARAMS = dict(
name="PM2000D sech pulse",
@@ -78,13 +78,34 @@ def test_memory():
assert len(io) == 8
def test_reopen(tmp_path: Path):
zpath = tmp_path / "file.zip"
params = Parameters(**PARAMS)
prop = propagation(zpath, params)
prop2 = propagation(zpath)
assert prop.parameters == prop2.parameters
def test_clear(tmp_path: Path):
params = Parameters(**PARAMS)
zpath = tmp_path / "file.zip"
prop = propagation(zpath, params)
assert zpath.exists()
assert zpath.read_bytes() != b""
prop.io.clear()
assert not zpath.exists()
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)
io = ZipFileIOHandler(tmp_path / "file.zip")
prop = Propagation(io, params.copy(True))
assert (tmp_path / "file.zip").exists()
assert (tmp_path / "file.zip").read_bytes() != b""
with pytest.raises(FileNotFoundError):
propagation(tmp_path / "file2.zip", params, bundle_data=True)
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
@@ -92,30 +113,32 @@ def test_zip_bundle(tmp_path: Path):
params.effective_area_file.path = str(new_aeff_path)
params.freeze()
io = ZipFileIOHandler(tmp_path / "file2.zip")
prop2 = Propagation(io, params, True)
prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True)
assert params.dispersion_file.path == new_disp_path.name
assert params.dispersion_file.prefix == "zip"
assert params.effective_area_file.path == new_aeff_path.name
assert params.effective_area_file.prefix == "zip"
assert prop2.parameters.dispersion_file.path == new_disp_path.name
assert prop2.parameters.dispersion_file.prefix == "zip"
assert prop2.parameters.effective_area_file.path == new_aeff_path.name
assert prop2.parameters.effective_area_file.prefix == "zip"
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()
with zfile.open(new_disp_path.name) as file:
assert file.read() == new_disp_path.read_bytes()
with zfile.open(Propagation.PARAMS_FN) as file:
with zfile.open(PARAMS_FN) as file:
df = json.loads(file.read().decode())
assert (
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
df["dispersion_file"]
== prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
)
assert (
df["effective_area_file"]
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
== prop2.parameters.effective_area_file.prefix
+ "::"
+ Path(params.effective_area_file.path).name
)