rudimentary new io data structure
This commit is contained in:
124
tests/test_io_handlers.py
Normal file
124
tests/test_io_handlers.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.spectra import Propagation
|
||||
|
||||
PARAMS = dict(
|
||||
name="PM2000D sech pulse",
|
||||
# pulse
|
||||
wavelength=1546e-9,
|
||||
# field_file="./Pos30000New.npz",
|
||||
width=1.5e-12,
|
||||
shape="sech",
|
||||
repetition_rate=40e6,
|
||||
# fiber
|
||||
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
|
||||
effective_area_file="./PM2000D_A_eff_marcuse.npz",
|
||||
wavelength_window=(400e-9, 4000e-9),
|
||||
n2=4.5e-20,
|
||||
# simulation
|
||||
raman_type="measured",
|
||||
quantum_noise=True,
|
||||
interpolation_degree=11,
|
||||
z_num=128,
|
||||
length=1.5,
|
||||
t_num=512,
|
||||
dt=5e-15,
|
||||
)
|
||||
|
||||
|
||||
def test_file(tmp_path: Path):
|
||||
params = Parameters(**PARAMS)
|
||||
stuff = np.random.rand(8, 512)
|
||||
io = ZipFileIOHandler(tmp_path / "test.zip")
|
||||
io.save_data("params.json", params.to_json().encode())
|
||||
for i, spec in enumerate(stuff):
|
||||
io.save_spectrum(i, spec)
|
||||
|
||||
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||
assert new_params is not params
|
||||
for k, v in params.items():
|
||||
v_new = getattr(new_params, k)
|
||||
if isinstance(v, DataFile):
|
||||
assert Path(v.path).name == Path(v_new.path).name
|
||||
else:
|
||||
assert v == getattr(new_params, k)
|
||||
|
||||
for i in range(8):
|
||||
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||
|
||||
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
|
||||
|
||||
|
||||
def test_memory():
|
||||
params = Parameters(**PARAMS)
|
||||
stuff = np.random.rand(8, 512)
|
||||
io = MemoryIOHandler()
|
||||
assert len(io) == 0
|
||||
io.save_data("params.json", params.to_json().encode())
|
||||
for i, spec in enumerate(stuff):
|
||||
io.save_spectrum(i, spec)
|
||||
|
||||
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||
for k, v in params.items():
|
||||
v_new = getattr(new_params, k)
|
||||
if isinstance(v, DataFile):
|
||||
assert Path(v.path).name == Path(v_new.path).name
|
||||
else:
|
||||
assert v == getattr(new_params, k)
|
||||
for i in range(8):
|
||||
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||
|
||||
assert len(io) == 8
|
||||
|
||||
|
||||
def test_zip_bundle(tmp_path: Path):
|
||||
params = Parameters(**PARAMS)
|
||||
io = ZipFileIOHandler(tmp_path / "file.zip")
|
||||
prop = Propagation(io, params.copy(True))
|
||||
|
||||
assert (tmp_path / "file.zip").exists()
|
||||
assert (tmp_path / "file.zip").read_bytes() != b""
|
||||
|
||||
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
|
||||
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
|
||||
params.dispersion_file.path = str(new_disp_path)
|
||||
params.effective_area_file.path = str(new_aeff_path)
|
||||
params.freeze()
|
||||
|
||||
io = ZipFileIOHandler(tmp_path / "file2.zip")
|
||||
prop2 = Propagation(io, params, True)
|
||||
|
||||
assert params.dispersion_file.path == new_disp_path.name
|
||||
assert params.dispersion_file.prefix == "zip"
|
||||
assert params.effective_area_file.path == new_aeff_path.name
|
||||
assert params.effective_area_file.prefix == "zip"
|
||||
|
||||
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
|
||||
with zfile.open(new_aeff_path.name) as file:
|
||||
assert file.read() == new_aeff_path.read_bytes()
|
||||
with zfile.open(new_disp_path.name) as file:
|
||||
assert file.read() == new_disp_path.read_bytes()
|
||||
|
||||
with zfile.open(Propagation.PARAMS_FN) as file:
|
||||
df = json.loads(file.read().decode())
|
||||
|
||||
assert (
|
||||
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
|
||||
)
|
||||
|
||||
assert (
|
||||
df["effective_area_file"]
|
||||
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
|
||||
)
|
||||
|
||||
|
||||
def test_unique_name():
|
||||
existing = {"spec.npy", "spec_0.npy"}
|
||||
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
||||
Reference in New Issue
Block a user