cleanup io handling and Propagation
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
# ruff: noqa
|
# ruff: noqa
|
||||||
from scgenerator import io, math, operators, plotting
|
from scgenerator import io, math, operators, plotting
|
||||||
from scgenerator.helpers import *
|
from scgenerator.helpers import *
|
||||||
|
from scgenerator.io import MemoryIOHandler, ZipFileIOHandler
|
||||||
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
|
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||||
from scgenerator.physics.units import PlotRange
|
from scgenerator.physics.units import PlotRange
|
||||||
from scgenerator.solver import SimulationResult, integrate, solve43
|
from scgenerator.solver import SimulationResult, integrate, solve43
|
||||||
|
from scgenerator.spectra import Propagation, Spectrum
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ MANDATORY_PARAMETERS = {
|
|||||||
"name",
|
"name",
|
||||||
"w",
|
"w",
|
||||||
"t",
|
"t",
|
||||||
|
"l",
|
||||||
"fft",
|
"fft",
|
||||||
"ifft",
|
"ifft",
|
||||||
"w0",
|
"w0",
|
||||||
|
|||||||
@@ -72,6 +72,9 @@ class PropagationIOHandler(Protocol):
|
|||||||
def load_data(self, name: str) -> bytes:
|
def load_data(self, name: str) -> bytes:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MemoryIOHandler:
|
class MemoryIOHandler:
|
||||||
spectra: dict[int, np.ndarray]
|
spectra: dict[int, np.ndarray]
|
||||||
@@ -99,6 +102,10 @@ class MemoryIOHandler:
|
|||||||
def load_data(self, name: str) -> bytes:
|
def load_data(self, name: str) -> bytes:
|
||||||
return self.data[name]
|
return self.data[name]
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.spectra = {}
|
||||||
|
self.data = {}
|
||||||
|
|
||||||
|
|
||||||
class ZipFileIOHandler:
|
class ZipFileIOHandler:
|
||||||
file: BinaryIO
|
file: BinaryIO
|
||||||
@@ -153,6 +160,9 @@ class ZipFileIOHandler:
|
|||||||
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
|
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
|
||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.file.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataFile:
|
class DataFile:
|
||||||
@@ -201,11 +211,15 @@ class DataFile:
|
|||||||
f"a bundled file prefixed with {self.prefix} "
|
f"a bundled file prefixed with {self.prefix} "
|
||||||
"must have a PropagationIOHandler attached"
|
"must have a PropagationIOHandler attached"
|
||||||
)
|
)
|
||||||
if self.io is not None:
|
# a DataFile obj may have a useless io obj attached to it
|
||||||
|
if self.prefix is not None:
|
||||||
return self.io.load_data(self.path)
|
return self.io.load_data(self.path)
|
||||||
else:
|
else:
|
||||||
return Path(self.path).read_bytes()
|
return Path(self.path).read_bytes()
|
||||||
|
|
||||||
|
def similar_to(self, other: DataFile) -> bool:
|
||||||
|
return Path(self.path).name == Path(other.path).name
|
||||||
|
|
||||||
|
|
||||||
def unique_name(base_name: str, existing: set[str]) -> str:
|
def unique_name(base_name: str, existing: set[str]) -> str:
|
||||||
name = base_name
|
name = base_name
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ class Parameter:
|
|||||||
return is_value, v
|
return is_value, v
|
||||||
|
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False, eq=False)
|
||||||
class Parameters:
|
class Parameters:
|
||||||
"""
|
"""
|
||||||
This class defines each valid parameter's name, type and valid value.
|
This class defines each valid parameter's name, type and valid value.
|
||||||
@@ -303,7 +303,7 @@ class Parameters:
|
|||||||
# internal machinery
|
# internal machinery
|
||||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||||
_p_names: ClassVar[Set[str]] = set()
|
_p_names: ClassVar[Set[str]] = set()
|
||||||
_frozen: bool = field(init=False, default=False, repr=False)
|
frozen: bool = field(init=False, default=False, repr=False)
|
||||||
|
|
||||||
# root
|
# root
|
||||||
name: str = Parameter(string, default="no name")
|
name: str = Parameter(string, default="no name")
|
||||||
@@ -460,28 +460,45 @@ class Parameters:
|
|||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __setattr__(self, k, v):
|
def __setattr__(self, k, v):
|
||||||
if self._frozen and not k.endswith("_file"):
|
if self.frozen and not k.endswith("_file"):
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"cannot set attribute to frozen {self.__class__.__name__} instance"
|
f"cannot set attribute to frozen {self.__class__.__name__} instance"
|
||||||
)
|
)
|
||||||
object.__setattr__(self, k, v)
|
object.__setattr__(self, k, v)
|
||||||
|
|
||||||
|
def __eq__(self, other: Parameters) -> bool:
|
||||||
|
if not isinstance(other, Parameters):
|
||||||
|
raise TypeError(
|
||||||
|
f"cannot compare {self.__class__.__name__!r} with {type(other).__name__!r}"
|
||||||
|
)
|
||||||
|
for k, v in self.items():
|
||||||
|
other_v = getattr(other, k)
|
||||||
|
if isinstance(v, DataFile) and not v.similar_to(other_v):
|
||||||
|
return False
|
||||||
|
if other_v != v:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def items(self) -> Iterator[tuple[str, Any]]:
|
def items(self) -> Iterator[tuple[str, Any]]:
|
||||||
for k, v in self._param_dico.items():
|
for k, v in self._param_dico.items():
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
yield k, v
|
yield k, v
|
||||||
|
|
||||||
def copy(self, freeze: bool = False) -> Parameters:
|
def copy(self, deep: bool = True, freeze: bool = False) -> Parameters:
|
||||||
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
|
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
|
||||||
|
if deep:
|
||||||
params = Parameters(**deepcopy(self.strip_params_dict()))
|
params = Parameters(**deepcopy(self.strip_params_dict()))
|
||||||
|
else:
|
||||||
|
params = Parameters(**self.strip_params_dict())
|
||||||
if freeze:
|
if freeze:
|
||||||
params.freeze()
|
params.freeze()
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
"""render the current instance read-only. This is not reversible"""
|
"""render the current instance read-only. This is not reversible"""
|
||||||
self._frozen = True
|
self.frozen = True
|
||||||
|
|
||||||
def to_json(self) -> str:
|
def to_json(self) -> str:
|
||||||
d = self.dump_dict()
|
d = self.dump_dict()
|
||||||
@@ -536,7 +553,7 @@ class Parameters:
|
|||||||
else:
|
else:
|
||||||
return first
|
return first
|
||||||
|
|
||||||
def compile(self, exhaustive=False) -> Parameters:
|
def compile(self, exhaustive=False, strict: bool = True) -> Parameters:
|
||||||
"""
|
"""
|
||||||
Computes missing parameters and returns them in a frozen `Parameters` instance
|
Computes missing parameters and returns them in a frozen `Parameters` instance
|
||||||
|
|
||||||
@@ -547,6 +564,8 @@ class Parameters:
|
|||||||
Depending on the specifics of the model and how the parameters were specified, there
|
Depending on the specifics of the model and how the parameters were specified, there
|
||||||
might be no difference between a normal compilation and an exhaustive one.
|
might be no difference between a normal compilation and an exhaustive one.
|
||||||
by default False
|
by default False
|
||||||
|
strict : bool, optional
|
||||||
|
raise an exception when something cannot be computed, by default True
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -560,16 +579,23 @@ class Parameters:
|
|||||||
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
|
When all the necessary parameters cannot be computed, a `ValueError` is raised. In most
|
||||||
cases, this is due to underdetermination by the user.
|
cases, this is due to underdetermination by the user.
|
||||||
"""
|
"""
|
||||||
|
obj = self.copy(deep=False, freeze=False)
|
||||||
|
obj.compile_in_place(exhaustive, strict)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def compile_in_place(self, exhaustive: bool = False, strict: bool = True):
|
||||||
to_compute = MANDATORY_PARAMETERS
|
to_compute = MANDATORY_PARAMETERS
|
||||||
evaluator = self.get_evaluator()
|
evaluator = self.get_evaluator()
|
||||||
try:
|
|
||||||
for k in to_compute:
|
for k in to_compute:
|
||||||
|
try:
|
||||||
evaluator.compute(k)
|
evaluator.compute(k)
|
||||||
except EvaluatorError as e:
|
except EvaluatorError as e:
|
||||||
|
if strict:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not compile the parameter set. Most likely, "
|
"Could not compile the parameter set. Most likely, "
|
||||||
f"an essential value is missing\n{e}"
|
f"an essential value is missing\n{e}"
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
if exhaustive:
|
if exhaustive:
|
||||||
for p in self._p_names:
|
for p in self._p_names:
|
||||||
if p not in evaluator.main_map:
|
if p not in evaluator.main_map:
|
||||||
@@ -577,11 +603,10 @@ class Parameters:
|
|||||||
evaluator.compute(p)
|
evaluator.compute(p)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
computed = self.__class__(
|
self._param_dico |= {
|
||||||
**{k: v.value for k, v in evaluator.main_map.items() if k in self._p_names}
|
k: v.value for k, v in evaluator.main_map.items() if k in self._p_names
|
||||||
)
|
}
|
||||||
computed._frozen = True
|
self.freeze()
|
||||||
return computed
|
|
||||||
|
|
||||||
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
def pretty_str(self, params: Iterable[str] = None, exclude=None) -> str:
|
||||||
"""return a pretty formatted string describing the parameters"""
|
"""return a pretty formatted string describing the parameters"""
|
||||||
|
|||||||
@@ -3,14 +3,23 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator import math
|
from scgenerator import math
|
||||||
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
|
from scgenerator.io import (
|
||||||
|
DataFile,
|
||||||
|
MemoryIOHandler,
|
||||||
|
PropagationIOHandler,
|
||||||
|
ZipFileIOHandler,
|
||||||
|
unique_name,
|
||||||
|
)
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.physics import pulse, units
|
from scgenerator.physics import pulse, units
|
||||||
|
|
||||||
|
PARAMS_FN = "params.json"
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
w: np.ndarray
|
w: np.ndarray
|
||||||
@@ -25,7 +34,7 @@ class Spectrum(np.ndarray):
|
|||||||
# add the new attribute to the created instance
|
# add the new attribute to the created instance
|
||||||
obj.w = params.compute("w")
|
obj.w = params.compute("w")
|
||||||
obj.t = params.compute("t")
|
obj.t = params.compute("t")
|
||||||
obj.t = params.compute("t")
|
obj.l = params.compute("l")
|
||||||
obj.ifft = params.compute("ifft")
|
obj.ifft = params.compute("ifft")
|
||||||
|
|
||||||
# Finally, we must return the newly created object:
|
# Finally, we must return the newly created object:
|
||||||
@@ -116,16 +125,13 @@ class Spectrum(np.ndarray):
|
|||||||
|
|
||||||
class Propagation:
|
class Propagation:
|
||||||
io: PropagationIOHandler
|
io: PropagationIOHandler
|
||||||
params: Parameters
|
parameters: Parameters
|
||||||
_current_index: int
|
_current_index: int
|
||||||
|
|
||||||
PARAMS_FN = "params.json"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
io_handler: PropagationIOHandler,
|
io_handler: PropagationIOHandler,
|
||||||
params: Parameters | None = None,
|
params: Parameters,
|
||||||
bundle_data: bool = False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A propagation is the object that manages IO for one single propagation.
|
A propagation is the object that manages IO for one single propagation.
|
||||||
@@ -141,31 +147,7 @@ class Propagation:
|
|||||||
"""
|
"""
|
||||||
self.io = io_handler
|
self.io = io_handler
|
||||||
self._current_index = len(self.io)
|
self._current_index = len(self.io)
|
||||||
|
self.parameters = params
|
||||||
new_params = params is not None
|
|
||||||
if not new_params:
|
|
||||||
if bundle_data:
|
|
||||||
raise ValueError(
|
|
||||||
"cannot bundle data to existing Propagation. Create a new one instead"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
params_data = self.io.load_data(self.PARAMS_FN)
|
|
||||||
params = Parameters.from_json(params_data.decode())
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
|
|
||||||
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
if bundle_data:
|
|
||||||
if not params._frozen:
|
|
||||||
warnings.warn(
|
|
||||||
f"Parameters instance {params.name!r} is not frozen but will be saved "
|
|
||||||
"in an unmodifiable state anyway."
|
|
||||||
)
|
|
||||||
_bundle_external_files(self.params, self.io)
|
|
||||||
|
|
||||||
if new_params:
|
|
||||||
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self._current_index
|
return self._current_index
|
||||||
@@ -174,9 +156,9 @@ class Propagation:
|
|||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
return self._load_slice(key)
|
return self._load_slice(key)
|
||||||
if isinstance(key, (float, np.floating)):
|
if isinstance(key, (float, np.floating)):
|
||||||
key = math.argclosest(self.params.compute("z_targets"), key)
|
key = math.argclosest(self.parameters.compute("z_targets"), key)
|
||||||
array = self.io.load_spectrum(key)
|
array = self.io.load_spectrum(key)
|
||||||
return Spectrum(array, self.params)
|
return Spectrum(array, self.parameters)
|
||||||
|
|
||||||
def __setitem__(self, key: int, value: np.ndarray):
|
def __setitem__(self, key: int, value: np.ndarray):
|
||||||
if not isinstance(key, int):
|
if not isinstance(key, int):
|
||||||
@@ -185,19 +167,78 @@ class Propagation:
|
|||||||
|
|
||||||
def _load_slice(self, key: slice) -> Spectrum:
|
def _load_slice(self, key: slice) -> Spectrum:
|
||||||
_iter = range(len(self))[key]
|
_iter = range(len(self))[key]
|
||||||
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
|
out = Spectrum(
|
||||||
|
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
||||||
|
)
|
||||||
for i in _iter:
|
for i in _iter:
|
||||||
out[i] = self.io.load_spectrum(i)
|
out[i] = self.io.load_spectrum(i)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def append(self, spectrum: np.ndarray):
|
def append(self, spectrum: np.ndarray):
|
||||||
self.io.save_spectrum(self._current_index, spectrum.asarray())
|
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
|
||||||
self._current_index += 1
|
self._current_index += 1
|
||||||
|
|
||||||
def load_all(self) -> Spectrum:
|
def load_all(self) -> Spectrum:
|
||||||
return self._load_slice(slice(None))
|
return self._load_slice(slice(None))
|
||||||
|
|
||||||
|
|
||||||
|
def load_all(path: os.PathLike) -> Spectrum:
|
||||||
|
io = ZipFileIOHandler(path)
|
||||||
|
return Propagation(io).load_all()
|
||||||
|
|
||||||
|
|
||||||
|
def propagation(
|
||||||
|
file_or_params: os.PathLike | Parameters,
|
||||||
|
params: Parameters | None = None,
|
||||||
|
bundle_data: bool = False,
|
||||||
|
) -> Propagation:
|
||||||
|
file = None
|
||||||
|
if isinstance(file_or_params, Parameters):
|
||||||
|
params = file_or_params
|
||||||
|
else:
|
||||||
|
file = Path(file_or_params)
|
||||||
|
|
||||||
|
if file is not None and file.exists():
|
||||||
|
io = ZipFileIOHandler(file)
|
||||||
|
return _open_existing_propagation(io)
|
||||||
|
|
||||||
|
if params is None:
|
||||||
|
raise ValueError("Parameters must be specified to create new simulation")
|
||||||
|
|
||||||
|
if file is not None:
|
||||||
|
io = ZipFileIOHandler(file)
|
||||||
|
else:
|
||||||
|
io = MemoryIOHandler()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _create_new_propagation(io, params, bundle_data)
|
||||||
|
except Exception as e:
|
||||||
|
io.clear()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def _open_existing_propagation(io: PropagationIOHandler) -> Propagation:
|
||||||
|
params = Parameters.from_json(io.load_data(PARAMS_FN).decode())
|
||||||
|
params.compile_in_place(exhaustive=True, strict=False)
|
||||||
|
for k, v in params.items():
|
||||||
|
if isinstance(v, DataFile):
|
||||||
|
v.io = io
|
||||||
|
return Propagation(io, params)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_new_propagation(
|
||||||
|
io: PropagationIOHandler, params: Parameters, bundle_data: bool
|
||||||
|
) -> Propagation:
|
||||||
|
if params.frozen:
|
||||||
|
params = params.copy()
|
||||||
|
else:
|
||||||
|
params = params.compile(exhaustive=True, strict=False)
|
||||||
|
if bundle_data:
|
||||||
|
_bundle_external_files(params, io)
|
||||||
|
io.save_data(PARAMS_FN, params.to_json().encode())
|
||||||
|
return Propagation(io, params)
|
||||||
|
|
||||||
|
|
||||||
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
|
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
|
||||||
"""copies every external file specified in the parameters and saves it"""
|
"""copies every external file specified in the parameters and saves it"""
|
||||||
existing_files = set(io.keys())
|
existing_files = set(io.keys())
|
||||||
@@ -211,8 +252,3 @@ def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
|
|||||||
existing_files.add(value.path)
|
existing_files.add(value.path)
|
||||||
|
|
||||||
io.save_data(value.path, data)
|
io.save_data(value.path, data)
|
||||||
|
|
||||||
|
|
||||||
def load_all(path: os.PathLike) -> Spectrum:
|
|
||||||
io = ZipFileIOHandler(path)
|
|
||||||
return Propagation(io).load_all()
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.spectra import Propagation
|
from scgenerator.spectra import PARAMS_FN, propagation
|
||||||
|
|
||||||
PARAMS = dict(
|
PARAMS = dict(
|
||||||
name="PM2000D sech pulse",
|
name="PM2000D sech pulse",
|
||||||
@@ -78,13 +78,34 @@ def test_memory():
|
|||||||
assert len(io) == 8
|
assert len(io) == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_reopen(tmp_path: Path):
|
||||||
|
zpath = tmp_path / "file.zip"
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
prop = propagation(zpath, params)
|
||||||
|
|
||||||
|
prop2 = propagation(zpath)
|
||||||
|
|
||||||
|
assert prop.parameters == prop2.parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear(tmp_path: Path):
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
zpath = tmp_path / "file.zip"
|
||||||
|
prop = propagation(zpath, params)
|
||||||
|
|
||||||
|
assert zpath.exists()
|
||||||
|
assert zpath.read_bytes() != b""
|
||||||
|
|
||||||
|
prop.io.clear()
|
||||||
|
|
||||||
|
assert not zpath.exists()
|
||||||
|
|
||||||
|
|
||||||
def test_zip_bundle(tmp_path: Path):
|
def test_zip_bundle(tmp_path: Path):
|
||||||
params = Parameters(**PARAMS)
|
params = Parameters(**PARAMS)
|
||||||
io = ZipFileIOHandler(tmp_path / "file.zip")
|
|
||||||
prop = Propagation(io, params.copy(True))
|
|
||||||
|
|
||||||
assert (tmp_path / "file.zip").exists()
|
with pytest.raises(FileNotFoundError):
|
||||||
assert (tmp_path / "file.zip").read_bytes() != b""
|
propagation(tmp_path / "file2.zip", params, bundle_data=True)
|
||||||
|
|
||||||
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
|
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
|
||||||
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
|
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
|
||||||
@@ -92,30 +113,32 @@ def test_zip_bundle(tmp_path: Path):
|
|||||||
params.effective_area_file.path = str(new_aeff_path)
|
params.effective_area_file.path = str(new_aeff_path)
|
||||||
params.freeze()
|
params.freeze()
|
||||||
|
|
||||||
io = ZipFileIOHandler(tmp_path / "file2.zip")
|
prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True)
|
||||||
prop2 = Propagation(io, params, True)
|
|
||||||
|
|
||||||
assert params.dispersion_file.path == new_disp_path.name
|
assert prop2.parameters.dispersion_file.path == new_disp_path.name
|
||||||
assert params.dispersion_file.prefix == "zip"
|
assert prop2.parameters.dispersion_file.prefix == "zip"
|
||||||
assert params.effective_area_file.path == new_aeff_path.name
|
assert prop2.parameters.effective_area_file.path == new_aeff_path.name
|
||||||
assert params.effective_area_file.prefix == "zip"
|
assert prop2.parameters.effective_area_file.prefix == "zip"
|
||||||
|
|
||||||
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
|
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
|
||||||
with zfile.open(new_aeff_path.name) as file:
|
with zfile.open(new_aeff_path.name) as file:
|
||||||
assert file.read() == new_aeff_path.read_bytes()
|
assert file.read() == new_aeff_path.read_bytes()
|
||||||
with zfile.open(new_disp_path.name) as file:
|
with zfile.open(new_disp_path.name) as file:
|
||||||
assert file.read() == new_disp_path.read_bytes()
|
assert file.read() == new_disp_path.read_bytes()
|
||||||
|
|
||||||
with zfile.open(Propagation.PARAMS_FN) as file:
|
with zfile.open(PARAMS_FN) as file:
|
||||||
df = json.loads(file.read().decode())
|
df = json.loads(file.read().decode())
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
|
df["dispersion_file"]
|
||||||
|
== prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
df["effective_area_file"]
|
df["effective_area_file"]
|
||||||
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
|
== prop2.parameters.effective_area_file.prefix
|
||||||
|
+ "::"
|
||||||
|
+ Path(params.effective_area_file.path).name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user