163 lines
4.7 KiB
Python
163 lines
4.7 KiB
Python
import json
|
|
from pathlib import Path
|
|
from zipfile import ZipFile
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
|
from scgenerator.parameter import Parameters
|
|
from scgenerator.spectra import PARAMS_FN, propagation
|
|
|
|
# Keyword arguments shared by every test below to build a ``Parameters`` object.
PARAMS = {
    "name": "PM2000D sech pulse",
    # --- pulse ---
    "wavelength": 1546e-9,
    # field_file="./Pos30000New.npz",
    "width": 1.5e-12,
    "shape": "sech",
    "repetition_rate": 40e6,
    # --- fiber ---
    "dispersion_file": "./PM2000D_2 extrapolated 4 0.npz",
    "effective_area_file": "./PM2000D_A_eff_marcuse.npz",
    "wavelength_window": (400e-9, 4000e-9),
    "n2": 4.5e-20,
    # --- simulation ---
    "raman_type": "measured",
    "quantum_noise": True,
    "interpolation_degree": 11,
    "z_num": 128,
    "length": 1.5,
    "t_num": 512,
    "dt": 5e-15,
}
|
|
|
|
|
|
def test_file(tmp_path: Path):
    """Round-trip parameters and eight random spectra through a ZipFileIOHandler.

    The parameters are saved as JSON bytes under ``params.json`` and each
    spectrum is saved by index; everything is then loaded back and compared
    with what was written.
    """
    archive = tmp_path / "test.zip"

    params = Parameters(**PARAMS)
    stuff = np.random.rand(8, 512)
    io = ZipFileIOHandler(archive)

    # Write phase: parameters first, then one spectrum per row.
    io.save_data("params.json", params.to_json().encode())
    for index, spectrum in enumerate(stuff):
        io.save_spectrum(index, spectrum)

    # Read phase: the decoded parameters must be a distinct but equal object.
    serialized = io.load_data("params.json").decode()
    new_params = Parameters.from_json(serialized)
    assert new_params is not params

    for key, expected in params.items():
        actual = getattr(new_params, key)
        if not isinstance(expected, DataFile):
            assert expected == actual
            continue
        # DataFile entries are compared by file name only.
        assert Path(expected.path).name == Path(actual.path).name

    for index, expected_spectrum in enumerate(stuff):
        assert np.all(io.load_spectrum(index) == expected_spectrum)

    # A fresh handler on the same archive must report the same number of spectra.
    reopened = ZipFileIOHandler(archive)
    assert len(reopened) == len(io) == 8
|
|
|
|
|
|
def test_memory():
    """Round-trip parameters and eight random spectra through a MemoryIOHandler.

    Also checks that the handler reports length 0 before anything is saved
    and length 8 once all spectra have been written.
    """
    params = Parameters(**PARAMS)
    stuff = np.random.rand(8, 512)
    io = MemoryIOHandler()
    assert len(io) == 0

    # Write phase: parameters first, then one spectrum per row.
    io.save_data("params.json", params.to_json().encode())
    for index, spectrum in enumerate(stuff):
        io.save_spectrum(index, spectrum)

    # Read phase: rebuild the parameters from the stored JSON bytes.
    serialized = io.load_data("params.json").decode()
    new_params = Parameters.from_json(serialized)

    for key, expected in params.items():
        actual = getattr(new_params, key)
        if not isinstance(expected, DataFile):
            assert expected == actual
            continue
        # DataFile entries are compared by file name only.
        assert Path(expected.path).name == Path(actual.path).name

    for index, expected_spectrum in enumerate(stuff):
        assert np.all(io.load_spectrum(index) == expected_spectrum)

    assert len(io) == 8
|
|
|
|
|
|
def test_reopen(tmp_path: Path):
    """Check that reopening a propagation file yields the same parameters.

    The propagation is first created with explicit parameters, then opened
    again from the path alone; both must expose equal ``parameters``.
    """
    archive = tmp_path / "file.zip"
    created = propagation(archive, Parameters(**PARAMS))
    reopened = propagation(archive)

    assert created.parameters == reopened.parameters
|
|
|
|
|
|
def test_clear(tmp_path: Path):
    """Check that clearing the propagation's IO handler removes the zip file."""
    params = Parameters(**PARAMS)
    archive = tmp_path / "file.zip"
    prop = propagation(archive, params)

    # Creating the propagation must leave a non-empty file on disk.
    assert archive.exists()
    assert len(archive.read_bytes()) > 0

    prop.io.clear()

    assert not archive.exists()
|
|
|
|
|
|
def test_overwrite(tmp_path: Path):
    """Check that an existing propagation file is only replaced on request.

    Creating a propagation on an existing path must raise ``FileExistsError``
    unless ``overwrite=True`` is passed. After overwriting, the file content
    is expected to differ from the original while keeping the same size.
    """
    params = Parameters(**PARAMS)
    archive = tmp_path / "file.zip"

    propagation(archive, params)
    original_bytes = archive.read_bytes()

    # Second creation without the flag must be refused.
    with pytest.raises(FileExistsError):
        propagation(archive, params)

    propagation(archive, params, overwrite=True)
    rewritten_bytes = archive.read_bytes()

    assert rewritten_bytes != original_bytes
    assert len(rewritten_bytes) == len(original_bytes)
|
|
|
|
|
|
def test_zip_bundle(tmp_path: Path):
    """Check that ``bundle_data=True`` copies the data files into the zip archive.

    With the default (non-existent) data file paths, bundling must raise
    ``FileNotFoundError``. Once the paths point at the files under
    ``./tests/data``, the archive must contain byte-identical copies, the
    propagation's parameters must reference them by bare file name with the
    ``"zip"`` prefix, and the stored parameter JSON must use the
    ``"<prefix>::<file name>"`` form.
    """
    params = Parameters(**PARAMS)

    # The data files named in PARAMS are expected to be missing here.
    with pytest.raises(FileNotFoundError):
        propagation(tmp_path / "file2.zip", params, bundle_data=True)

    new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
    new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
    params.dispersion_file.path = str(new_disp_path)
    params.effective_area_file.path = str(new_aeff_path)
    params.freeze()

    bundle_path = tmp_path / "file3.zip"
    prop2 = propagation(bundle_path, params, bundle_data=True)

    assert prop2.parameters.dispersion_file.path == new_disp_path.name
    assert prop2.parameters.dispersion_file.prefix == "zip"
    assert prop2.parameters.effective_area_file.path == new_aeff_path.name
    assert prop2.parameters.effective_area_file.prefix == "zip"

    with ZipFile(bundle_path, "r") as zfile:
        # Bundled copies must match the source files byte for byte.
        for source in (new_aeff_path, new_disp_path):
            with zfile.open(source.name) as file:
                assert file.read() == source.read_bytes()

        with zfile.open(PARAMS_FN) as file:
            raw_json = file.read()
        df = json.loads(raw_json.decode())

    expected_disp = (
        prop2.parameters.dispersion_file.prefix + "::" + Path(params.dispersion_file.path).name
    )
    assert df["dispersion_file"] == expected_disp

    expected_aeff = (
        prop2.parameters.effective_area_file.prefix
        + "::"
        + Path(params.effective_area_file.path).name
    )
    assert df["effective_area_file"] == expected_aeff
|
|
|
|
|
|
def test_unique_name():
    """Check that ``unique_name`` skips names that are already taken.

    With ``spec.npy`` and ``spec_0.npy`` both present, the expected result
    is ``spec_1.npy``.
    """
    taken = {"spec.npy", "spec_0.npy"}
    candidate = unique_name("spec.npy", taken)
    assert candidate == "spec_1.npy"
|