rudimentary new io data structure

This commit is contained in:
Benoît Sierro
2023-08-08 10:59:08 +02:00
parent 7b6e33ca0f
commit 98fa32c24b
8 changed files with 475 additions and 346 deletions

124
tests/test_io_handlers.py Normal file
View File

@@ -0,0 +1,124 @@
import json
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation
PARAMS = dict(
name="PM2000D sech pulse",
# pulse
wavelength=1546e-9,
# field_file="./Pos30000New.npz",
width=1.5e-12,
shape="sech",
repetition_rate=40e6,
# fiber
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
effective_area_file="./PM2000D_A_eff_marcuse.npz",
wavelength_window=(400e-9, 4000e-9),
n2=4.5e-20,
# simulation
raman_type="measured",
quantum_noise=True,
interpolation_degree=11,
z_num=128,
length=1.5,
t_num=512,
dt=5e-15,
)
def test_file(tmp_path: Path):
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = ZipFileIOHandler(tmp_path / "test.zip")
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
assert new_params is not params
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
def test_memory():
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = MemoryIOHandler()
assert len(io) == 0
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(io) == 8
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)
io = ZipFileIOHandler(tmp_path / "file.zip")
prop = Propagation(io, params.copy(True))
assert (tmp_path / "file.zip").exists()
assert (tmp_path / "file.zip").read_bytes() != b""
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
params.dispersion_file.path = str(new_disp_path)
params.effective_area_file.path = str(new_aeff_path)
params.freeze()
io = ZipFileIOHandler(tmp_path / "file2.zip")
prop2 = Propagation(io, params, True)
assert params.dispersion_file.path == new_disp_path.name
assert params.dispersion_file.prefix == "zip"
assert params.effective_area_file.path == new_aeff_path.name
assert params.effective_area_file.prefix == "zip"
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()
with zfile.open(new_disp_path.name) as file:
assert file.read() == new_disp_path.read_bytes()
with zfile.open(Propagation.PARAMS_FN) as file:
df = json.loads(file.read().decode())
assert (
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
)
assert (
df["effective_area_file"]
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
)
def test_unique_name():
existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy"