Clean up io handling and Propagation

Benoît Sierro
2023-08-09 11:21:57 +02:00
parent a650169443
commit 7bb15871c3
6 changed files with 177 additions and 76 deletions

View File

@@ -1,8 +1,10 @@
 # ruff: noqa
 from scgenerator import io, math, operators, plotting
 from scgenerator.helpers import *
+from scgenerator.io import MemoryIOHandler, ZipFileIOHandler
 from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
 from scgenerator.parameter import Parameters
 from scgenerator.physics import fiber, materials, plasma, pulse, units
 from scgenerator.physics.units import PlotRange
 from scgenerator.solver import SimulationResult, integrate, solve43
+from scgenerator.spectra import Propagation, Spectrum

View File

@@ -33,6 +33,7 @@ MANDATORY_PARAMETERS = {
     "name",
     "w",
     "t",
+    "l",
     "fft",
     "ifft",
     "w0",

View File

@@ -72,6 +72,9 @@ class PropagationIOHandler(Protocol):
     def load_data(self, name: str) -> bytes:
         ...
 
+    def clear(self):
+        ...
+
 
 class MemoryIOHandler:
     spectra: dict[int, np.ndarray]
@@ -99,6 +102,10 @@ class MemoryIOHandler:
     def load_data(self, name: str) -> bytes:
         return self.data[name]
 
+    def clear(self):
+        self.spectra = {}
+        self.data = {}
+
 
 class ZipFileIOHandler:
     file: BinaryIO
@@ -153,6 +160,9 @@ class ZipFileIOHandler:
         with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
             return file.read()
 
+    def clear(self):
+        self.file.unlink(missing_ok=True)
+
 
 @dataclass
 class DataFile:
@@ -201,11 +211,15 @@ class DataFile:
                 f"a bundled file prefixed with {self.prefix} "
                 "must have a PropagationIOHandler attached"
             )
-        if self.io is not None:
+        # a DataFile obj may have a useless io obj attached to it
+        if self.prefix is not None:
             return self.io.load_data(self.path)
         else:
             return Path(self.path).read_bytes()
 
+    def similar_to(self, other: DataFile) -> bool:
+        return Path(self.path).name == Path(other.path).name
+
 
 def unique_name(base_name: str, existing: set[str]) -> str:
     name = base_name
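
The new `clear()` method gives every handler a uniform way to discard what it has written: `MemoryIOHandler` resets its dicts, `ZipFileIOHandler` deletes the archive, and the `PropagationIOHandler` protocol now requires the method. A minimal sketch of how a caller might lean on it to roll back a partially written propagation, assuming `MemoryIOHandler()` takes no constructor arguments (the helper name `write_all` is purely illustrative):

import numpy as np

from scgenerator.io import MemoryIOHandler


def write_all(io, spectra):
    # save each spectrum; wipe everything the handler has stored if any write fails
    try:
        for i, spec in enumerate(spectra):
            io.save_spectrum(i, np.asarray(spec))
    except Exception:
        io.clear()
        raise


io = MemoryIOHandler()
write_all(io, [np.zeros(128, dtype=complex)])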

View File

@@ -294,7 +294,7 @@ class Parameter:
         return is_value, v
 
 
-@dataclass(repr=False)
+@dataclass(repr=False, eq=False)
 class Parameters:
     """
     This class defines each valid parameter's name, type and valid value.
@@ -303,7 +303,7 @@ class Parameters:
     # internal machinery
     _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
     _p_names: ClassVar[Set[str]] = set()
-    _frozen: bool = field(init=False, default=False, repr=False)
+    frozen: bool = field(init=False, default=False, repr=False)
 
     # root
     name: str = Parameter(string, default="no name")
@@ -460,28 +460,45 @@ class Parameters:
         self.__post_init__()
 
     def __setattr__(self, k, v):
-        if self._frozen and not k.endswith("_file"):
+        if self.frozen and not k.endswith("_file"):
             raise AttributeError(
                 f"cannot set attribute to frozen {self.__class__.__name__} instance"
             )
         object.__setattr__(self, k, v)
 
+    def __eq__(self, other: Parameters) -> bool:
+        if not isinstance(other, Parameters):
+            raise TypeError(
+                f"cannot compare {self.__class__.__name__!r} with {type(other).__name__!r}"
+            )
+        for k, v in self.items():
+            other_v = getattr(other, k)
+            if isinstance(v, DataFile) and not v.similar_to(other_v):
+                return False
+            if other_v != v:
+                return False
+        return True
+
     def items(self) -> Iterator[tuple[str, Any]]:
         for k, v in self._param_dico.items():
             if v is None:
                 continue
             yield k, v
 
-    def copy(self, freeze: bool = False) -> Parameters:
+    def copy(self, deep: bool = True, freeze: bool = False) -> Parameters:
         """create a deep copy of self. if freeze is True, the returned copy is read-only"""
-        params = Parameters(**deepcopy(self.strip_params_dict()))
+        if deep:
+            params = Parameters(**deepcopy(self.strip_params_dict()))
+        else:
+            params = Parameters(**self.strip_params_dict())
         if freeze:
             params.freeze()
         return params
 
     def freeze(self):
         """render the current instance read-only. This is not reversible"""
-        self._frozen = True
+        self.frozen = True
 
     def to_json(self) -> str:
         d = self.dump_dict()
@@ -536,7 +553,7 @@ class Parameters:
         else:
             return first
 
-    def compile(self, exhaustive=False) -> Parameters:
+    def compile(self, exhaustive=False, strict: bool = True) -> Parameters:
         """
         Computes missing parameters and returns them in a frozen `Parameters` instance
@@ -547,6 +564,8 @@
             Depending on the specifics of the model and how the parameters were specified, there
             might be no difference between a normal compilation and an exhaustive one.
             by default False
+        strict : bool, optional
+            raise an exception when something cannot be computed, by default True
 
         Returns
         -------
@@ -560,16 +579,23 @@
         When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
         cases, this is due to underdetermination by the user.
         """
+        obj = self.copy(deep=False, freeze=False)
+        obj.compile_in_place(exhaustive, strict)
+        return obj
+
+    def compile_in_place(self, exhaustive: bool = False, strict: bool = True):
         to_compute = MANDATORY_PARAMETERS
         evaluator = self.get_evaluator()
-        try:
-            for k in to_compute:
+        for k in to_compute:
+            try:
                 evaluator.compute(k)
-        except EvaluatorError as e:
-            raise ValueError(
-                "Could not compile the parameter set. Most likely, "
-                f"an essential value is missing\n{e}"
-            ) from None
+            except EvaluatorError as e:
+                if strict:
+                    raise ValueError(
+                        "Could not compile the parameter set. Most likely, "
+                        f"an essential value is missing\n{e}"
+                    ) from None
         if exhaustive:
             for p in self._p_names:
                 if p not in evaluator.main_map:
@@ -577,11 +603,10 @@
                         evaluator.compute(p)
                     except Exception:
                         pass
-        computed = self.__class__(
-            **{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
-        )
-        computed._frozen = True
-        return computed
+        self._param_dico |= {
+            k: v.value for k, v in evaluator.main_map.items() if k in self._p_names
+        }
+        self.freeze()
 
     def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
         """return a pretty formatted string describing the parameters"""

View File

@@ -3,14 +3,23 @@ from __future__ import annotations
 import os
 import warnings
 from pathlib import Path
-from typing import Callable
 
 import numpy as np
 
 from scgenerator import math
-from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
+from scgenerator.io import (
+    DataFile,
+    MemoryIOHandler,
+    PropagationIOHandler,
+    ZipFileIOHandler,
+    unique_name,
+)
 from scgenerator.parameter import Parameters
 from scgenerator.physics import pulse, units
 
+PARAMS_FN = "params.json"
+
 
 class Spectrum(np.ndarray):
     w: np.ndarray
@@ -25,7 +34,7 @@ class Spectrum(np.ndarray):
         # add the new attribute to the created instance
         obj.w = params.compute("w")
         obj.t = params.compute("t")
-        obj.t = params.compute("t")
+        obj.l = params.compute("l")
         obj.ifft = params.compute("ifft")
 
         # Finally, we must return the newly created object:
@@ -116,16 +125,13 @@
 class Propagation:
     io: PropagationIOHandler
-    params: Parameters
+    parameters: Parameters
     _current_index: int
 
-    PARAMS_FN = "params.json"
-
     def __init__(
         self,
         io_handler: PropagationIOHandler,
-        params: Parameters | None = None,
-        bundle_data: bool = False,
+        params: Parameters,
     ):
         """
         A propagation is the object that manages IO for one single propagation.
@@ -141,31 +147,7 @@
         """
         self.io = io_handler
         self._current_index = len(self.io)
-
-        new_params = params is not None
-        if not new_params:
-            if bundle_data:
-                raise ValueError(
-                    "cannot bundle data to existing Propagation. Create a new one instead"
-                )
-            try:
-                params_data = self.io.load_data(self.PARAMS_FN)
-                params = Parameters.from_json(params_data.decode())
-            except KeyError:
-                raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
-        self.params = params
-
-        if bundle_data:
-            if not params._frozen:
-                warnings.warn(
-                    f"Parameters instance {params.name!r} is not frozen but will be saved "
-                    "in an unmodifiable state anyway."
-                )
-            _bundle_external_files(self.params, self.io)
-        if new_params:
-            self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
+        self.parameters = params
 
     def __len__(self) -> int:
         return self._current_index
@@ -174,9 +156,9 @@
         if isinstance(key, slice):
             return self._load_slice(key)
         if isinstance(key, (float, np.floating)):
-            key = math.argclosest(self.params.compute("z_targets"), key)
+            key = math.argclosest(self.parameters.compute("z_targets"), key)
         array = self.io.load_spectrum(key)
-        return Spectrum(array, self.params)
+        return Spectrum(array, self.parameters)
 
     def __setitem__(self, key: int, value: np.ndarray):
         if not isinstance(key, int):
@@ -185,19 +167,78 @@
     def _load_slice(self, key: slice) -> Spectrum:
         _iter = range(len(self))[key]
-        out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
+        out = Spectrum(
+            np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
+        )
         for i in _iter:
             out[i] = self.io.load_spectrum(i)
         return out
 
     def append(self, spectrum: np.ndarray):
-        self.io.save_spectrum(self._current_index, spectrum.asarray())
+        self.io.save_spectrum(self._current_index, np.asarray(spectrum))
         self._current_index += 1
 
     def load_all(self) -> Spectrum:
         return self._load_slice(slice(None))
 
 
+def load_all(path: os.PathLike) -> Spectrum:
+    io = ZipFileIOHandler(path)
+    return Propagation(io).load_all()
+
+
+def propagation(
+    file_or_params: os.PathLike | Parameters,
+    params: Parameters | None = None,
+    bundle_data: bool = False,
+) -> Propagation:
+    file = None
+    if isinstance(file_or_params, Parameters):
+        params = file_or_params
+    else:
+        file = Path(file_or_params)
+
+    if file is not None and file.exists():
+        io = ZipFileIOHandler(file)
+        return _open_existing_propagation(io)
+
+    if params is None:
+        raise ValueError("Parameters must be specified to create new simulation")
+
+    if file is not None:
+        io = ZipFileIOHandler(file)
+    else:
+        io = MemoryIOHandler()
+    try:
+        return _create_new_propagation(io, params, bundle_data)
+    except Exception as e:
+        io.clear()
+        raise e
+
+
+def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
+    params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
+    params.compile_in_place(exhaustive=True, strict=False)
+    for k, v in params.items():
+        if isinstance(v, DataFile):
+            v.io = io
+    return Propagation(io, params)
+
+
+def _create_new_propagation(
+    io: PropagationIOHandler, params: Parameters, bundle_data: bool
+) -> Propagation:
+    if params.frozen:
+        params = params.copy()
+    else:
+        params = params.compile(exhaustive=True, strict=False)
+    if bundle_data:
+        _bundle_external_files(params, io)
+    io.save_data(PARAMS_FN, params.to_json().encode())
+    return Propagation(io, params)
+
+
 def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
     """copies every external file specified in the parameters and saves it"""
     existing_files = set(io.keys())
@@ -211,8 +252,3 @@ def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
             existing_files.add(value.path)
             io.save_data(value.path, data)
-
-
-def load_all(path: os.PathLike) -> Spectrum:
-    io = ZipFileIOHandler(path)
-    return Propagation(io).load_all()
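
The module-level `propagation()` factory now takes over the three cases the old `Propagation.__init__` juggled: an existing zip file is reopened and its `params.json` read back, a new path gets a fresh `ZipFileIOHandler`, and passing only a `Parameters` instance falls back to an in-memory handler, with `io.clear()` rolling back a half-created file on failure. A sketch of the resulting usage; the file name and the bare `Parameters(...)` construction are placeholders:

from pathlib import Path

from scgenerator.parameter import Parameters
from scgenerator.spectra import propagation

params = Parameters(name="example run")  # placeholder parameter set
zpath = Path("run.zip")

prop = propagation(zpath, params)        # path does not exist yet -> new propagation
# ... run the simulation, prop.append(spectrum) for each saved step ...

reopened = propagation(zpath)            # path exists -> parameters loaded from params.json
assert prop.parameters == reopened.parameters

in_memory = propagation(params)          # no path at all -> MemoryIOHandler backend

This is the same round trip that `test_reopen` exercises in the updated test file below.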

View File

@@ -7,7 +7,7 @@ import pytest
 
 from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
 from scgenerator.parameter import Parameters
-from scgenerator.spectra import Propagation
+from scgenerator.spectra import PARAMS_FN, propagation
 
 PARAMS = dict(
     name="PM2000D sech pulse",
@@ -78,13 +78,34 @@ def test_memory():
     assert len(io) == 8
 
 
+def test_reopen(tmp_path: Path):
+    zpath = tmp_path / "file.zip"
+    params = Parameters(**PARAMS)
+    prop = propagation(zpath, params)
+    prop2 = propagation(zpath)
+    assert prop.parameters == prop2.parameters
+
+
+def test_clear(tmp_path: Path):
+    params = Parameters(**PARAMS)
+    zpath = tmp_path / "file.zip"
+    prop = propagation(zpath, params)
+    assert zpath.exists()
+    assert zpath.read_bytes() != b""
+    prop.io.clear()
+    assert not zpath.exists()
+
+
 def test_zip_bundle(tmp_path: Path):
     params = Parameters(**PARAMS)
-    io = ZipFileIOHandler(tmp_path / "file.zip")
-    prop = Propagation(io, params.copy(True))
-    assert (tmp_path / "file.zip").exists()
-    assert (tmp_path / "file.zip").read_bytes() != b""
+
+    with pytest.raises(FileNotFoundError):
+        propagation(tmp_path / "file2.zip", params, bundle_data=True)
 
     new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
     new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
@@ -92,30 +113,32 @@ def test_zip_bundle(tmp_path: Path):
     params.effective_area_file.path = str(new_aeff_path)
     params.freeze()
 
-    io = ZipFileIOHandler(tmp_path / "file2.zip")
-    prop2 = Propagation(io, params, True)
+    prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True)
 
-    assert params.dispersion_file.path == new_disp_path.name
-    assert params.dispersion_file.prefix == "zip"
-    assert params.effective_area_file.path == new_aeff_path.name
-    assert params.effective_area_file.prefix == "zip"
+    assert prop2.parameters.dispersion_file.path == new_disp_path.name
+    assert prop2.parameters.dispersion_file.prefix == "zip"
+    assert prop2.parameters.effective_area_file.path == new_aeff_path.name
+    assert prop2.parameters.effective_area_file.prefix == "zip"
 
-    with ZipFile(tmp_path / "file2.zip", "r") as zfile:
+    with ZipFile(tmp_path / "file3.zip", "r") as zfile:
         with zfile.open(new_aeff_path.name) as file:
             assert file.read() == new_aeff_path.read_bytes()
         with zfile.open(new_disp_path.name) as file:
            assert file.read() == new_disp_path.read_bytes()
-        with zfile.open(Propagation.PARAMS_FN) as file:
+        with zfile.open(PARAMS_FN) as file:
            df = json.loads(file.read().decode())
            assert (
-                df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
+                df["dispersion_file"]
+                == prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
            )
            assert (
                df["effective_area_file"]
-                == params.effective_area_file.prefix + "::" + params.effective_area_file.path
+                == prop2.parameters.effective_area_file.prefix
+                + "::"
+                + Path(params.effective_area_file.path).name
            )