small refactor
This commit is contained in:
@@ -241,9 +241,9 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
def load_parameters(self) -> Parameters:
|
def load_parameters(self) -> Parameters:
|
||||||
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
|
params = Parameters.from_json(self.io.load_data(PARAMS_FN).decode())
|
||||||
params.compile_in_place(exhaustive=True, strict=False)
|
params.compile_in_place(exhaustive=True, strict=False)
|
||||||
for k, v in params.items():
|
for _, value in params.items():
|
||||||
if isinstance(v, DataFile):
|
if isinstance(value, DataFile):
|
||||||
v.io = self.io
|
value.io = self.io
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def append(self, spectrum: np.ndarray):
|
def append(self, spectrum: np.ndarray):
|
||||||
@@ -326,7 +326,8 @@ def propagation(
|
|||||||
|
|
||||||
if params is None:
|
if params is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{file} doesn't exist, but no parameters have been passed to create a new propagation"
|
f"{file} doesn't exist, but no parameters have been "
|
||||||
|
"specified to create a new propagation"
|
||||||
)
|
)
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
@@ -383,7 +384,9 @@ def propagation_series(
|
|||||||
|
|
||||||
rest = tqdm(rest)
|
rest = tqdm(rest)
|
||||||
|
|
||||||
spectrum = Spectrum.from_params([prop[:] for prop in rest], parameters)
|
if index is None:
|
||||||
|
index = slice(None)
|
||||||
|
spectrum = Spectrum.from_params([prop[index] for prop in rest], parameters)
|
||||||
for prop in propagations:
|
for prop in propagations:
|
||||||
del prop.parameters
|
del prop.parameters
|
||||||
|
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ def test_unique_name():
|
|||||||
def test_propagation_series(tmp_path: Path):
|
def test_propagation_series(tmp_path: Path):
|
||||||
params = Parameters(**PARAMS)
|
params = Parameters(**PARAMS)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
specs, props = propagation_series([])
|
specs, _ = propagation_series([])
|
||||||
|
|
||||||
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
|
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
|
||||||
for i, f in enumerate(flist):
|
for i, f in enumerate(flist):
|
||||||
@@ -196,5 +196,6 @@ def test_propagation_series(tmp_path: Path):
|
|||||||
assert set(flist) == set(tmp_path.glob("*.zip"))
|
assert set(flist) == set(tmp_path.glob("*.zip"))
|
||||||
|
|
||||||
specs, propagations = propagation_series(flist)
|
specs, propagations = propagation_series(flist)
|
||||||
|
|
||||||
assert specs.shape == (10, params.z_num, params.t_num)
|
assert specs.shape == (10, params.z_num, params.t_num)
|
||||||
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))
|
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))
|
||||||
|
|||||||
Reference in New Issue
Block a user