Files
scgenerator/tests/test_io_handlers.py
Benoît Sierro 27de20e6ca small refactor
2023-10-04 08:21:37 +02:00

202 lines
6.1 KiB
Python

import json
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import PARAMS_FN, propagation, propagation_series
PARAMS = dict(
name="PM2000D sech pulse",
# pulse
wavelength=1546e-9,
# field_file="./Pos30000New.npz",
width=1.5e-12,
shape="sech",
repetition_rate=40e6,
# fiber
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
effective_area_file="./PM2000D_A_eff_marcuse.npz",
n2=4.5e-20,
# simulation
raman_type="measured",
quantum_noise=True,
interpolation_degree=11,
z_num=8,
length=1.5,
t_num=512,
dt=5e-15,
)
def test_file(tmp_path: Path):
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = ZipFileIOHandler(tmp_path / "test.zip")
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
assert new_params is not params
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
def test_memory():
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = MemoryIOHandler()
assert len(io) == 0
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(io) == 8
def test_reopen(tmp_path: Path):
zpath = tmp_path / "file.zip"
params = Parameters(**PARAMS)
prop = propagation(zpath, params)
prop2 = propagation(zpath)
assert prop.parameters == prop2.parameters
def test_clear(tmp_path: Path):
params = Parameters(**PARAMS)
zpath = tmp_path / "file.zip"
prop = propagation(zpath, params)
assert zpath.exists()
assert zpath.read_bytes() != b""
prop.io.clear()
assert not zpath.exists()
def test_overwrite(tmp_path: Path):
params = Parameters(**PARAMS)
zpath = tmp_path / "file.zip"
_ = propagation(zpath, params)
orig_file = zpath.read_bytes()
with pytest.raises(FileExistsError):
_ = propagation(zpath, params)
_ = propagation(zpath, params, overwrite=True)
assert zpath.read_bytes() != orig_file
assert len(zpath.read_bytes()) == len(orig_file)
def test_zip_data_copy(tmp_path: Path):
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
params = Parameters(
**(PARAMS | dict(dispersion_file=new_disp_path, effective_area_file=new_aeff_path))
)
_ = propagation(tmp_path / "file.zip", params, bundle_data=True)
prop = propagation(tmp_path / "file.zip")
assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes()
assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes()
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)
with pytest.raises(FileNotFoundError):
propagation(tmp_path / "file2.zip", params, bundle_data=True)
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
params.dispersion_file.path = str(new_disp_path)
params.effective_area_file.path = str(new_aeff_path)
params.freeze()
prop2 = propagation(tmp_path / "file3.zip", params, bundle_data=True)
assert prop2.parameters.dispersion_file.path == new_disp_path.name
assert prop2.parameters.dispersion_file.prefix == "zip"
assert prop2.parameters.effective_area_file.path == new_aeff_path.name
assert prop2.parameters.effective_area_file.prefix == "zip"
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()
with zfile.open(new_disp_path.name) as file:
assert file.read() == new_disp_path.read_bytes()
with zfile.open(PARAMS_FN) as file:
df = json.loads(file.read().decode())
assert (
df["dispersion_file"]
== prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
)
assert (
df["effective_area_file"]
== prop2.parameters.effective_area_file.prefix
+ "::"
+ Path(params.effective_area_file.path).name
)
def test_unique_name():
existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy"
existing = {"spec", "spec0"}
assert unique_name("spec", existing) == "spec_0"
existing = {"spec", "spec_0"}
assert unique_name("spec", existing) == "spec_1"
def test_propagation_series(tmp_path: Path):
params = Parameters(**PARAMS)
with pytest.raises(ValueError):
specs, _ = propagation_series([])
flist = [tmp_path / f"prop{i}.zip" for i in range(10)]
for i, f in enumerate(flist):
params.name = f"prop {i}"
prop = propagation(f, params)
for _ in range(params.z_num):
prop.append(np.zeros(params.t_num, dtype=complex))
assert set(flist) == set(tmp_path.glob("*.zip"))
specs, propagations = propagation_series(flist)
assert specs.shape == (10, params.z_num, params.t_num)
assert all(prop.parameters.name == f"prop {i}" for i, prop in enumerate(propagations))