small refactor

This commit is contained in:
Benoît Sierro
2023-10-04 08:21:37 +02:00
parent 5adde638ef
commit 27de20e6ca
2 changed files with 10 additions and 6 deletions

View File

@@ -241,9 +241,9 @@ class Propagation(Generic[ParamsOrNone]):
def load_parameters(self) -> Parameters: def load_parameters(self) -> Parameters:
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode()) params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
params.compile_in_place(exhaustive=True, strict=False) params.compile_in_place(exhaustive=True, strict=False)
for k, v in params.items(): for _, value in params.items():
if isinstance(v, DataFile): if isinstance(value, DataFile):
v.io = self.io value.io = self.io
return params return params
def append(self, spectrum: np.ndarray): def append(self, spectrum: np.ndarray):
@@ -326,7 +326,8 @@ def propagation(
if params is None: if params is None:
raise ValueError( raise ValueError(
f"{file} doesn't exist, but no parameters have been passed to create a new propagation" f"{file} doesn't exist, but no parameters have been "
"specified to create a new propagation"
) )
if file is not None: if file is not None:
@@ -383,7 +384,9 @@ def propagation_series(
rest = tqdm(rest) rest = tqdm(rest)
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters) if index is None:
index = slice(None)
spectrum = Spectrum.from_params([prop[index] for prop in rest], parameters)
for prop in propagations: for prop in propagations:
del prop.parameters del prop.parameters

View File

@@ -184,7 +184,7 @@ def test_unique_name():
def test_propagation_series(tmp_path: Path): def test_propagation_series(tmp_path: Path):
params = Parameters(**PARAMS) params = Parameters(**PARAMS)
with pytest.raises(ValueError): with pytest.raises(ValueError):
specs, props = propagation_series([]) specs, _ = propagation_series([])
flist = [tmp_path / f"prop{i}.zip" for i in range(10)] flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
for i, f in enumerate(flist): for i, f in enumerate(flist):
@@ -196,5 +196,6 @@ def test_propagation_series(tmp_path: Path):
assert set(flist) == set(tmp_path.glob("*.zip")) assert set(flist) == set(tmp_path.glob("*.zip"))
specs, propagations = propagation_series(flist) specs, propagations = propagation_series(flist)
assert specs.shape == (10, params.z_num, params.t_num) assert specs.shape == (10, params.z_num, params.t_num)
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations)) assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))