change: units and io refactoring

This commit is contained in:
Benoît Sierro
2023-09-12 11:55:35 +02:00
parent 292896bdce
commit ae09fdd3d0
7 changed files with 58 additions and 27 deletions

View File

@@ -1,7 +0,0 @@
* plot cli command should show by default, not save
* logger should hook to pbar when it exists
* find a way to make evaluator debugging easier
# Ideas
* have a `scan` section in the config to control parameter scanning more granularly. That way, the user provides a single default value for every necessary parameter and then calls the single-sim or scan command to run the config accordingly.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "scgenerator"
version = "0.3.10"
version = "0.3.11"
description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -6,6 +6,7 @@ import json
import os
import re
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import BinaryIO, Protocol, Sequence
from zipfile import ZipFile
@@ -218,7 +219,7 @@ class DataFile:
return os.fspath(Path(self.path))
return self.prefix + self.PREFIX_SEP + self.path
def load_data(self) -> bytes:
def load_bytes(self) -> bytes:
if self.prefix is not None and self.io is None:
raise ValueError(
f"a bundled file prefixed with {self.prefix} "
@@ -230,10 +231,53 @@ class DataFile:
else:
return Path(self.path).read_bytes()
def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]:
    """
    Load the content of the file as a tuple of numpy arrays.

    The loader is chosen from the extension of `self.path` (case-insensitive):
    `.npz` archives go through `load_npz_data`, `.npy` files through `np.load`
    and anything else is decoded and parsed as delimited text by `load_txt_data`.

    Parameters
    ----------
    *labels : str | tuple[str, ...]
        only used for `.npz` archives. Each label is either one key or a group
        of alternative keys, of which the first one present is used. With no
        label, every array of the archive is returned.

    Returns
    -------
    tuple[np.ndarray, ...]
        `.npz`: one array per label, `.npy`: the stored array split along its
        first axis, text: one array per column.
    """
    raw_data = self.load_bytes()
    # `str.split()` without argument splits on whitespace, which would leave the
    # whole path as "extension"; the suffix is what follows the last dot.
    extension = Path(self.path).suffix.lower()
    if extension == ".npz":
        # arrays are read eagerly by `load_npz_data`, the archive can be closed after
        with np.load(BytesIO(raw_data)) as df:
            return load_npz_data(df, *labels)
    if extension == ".npy":
        return tuple(np.load(BytesIO(raw_data)))
    return load_txt_data(raw_data.decode())
def similar_to(self, other: DataFile) -> bool:
    """Tell whether both data files have the same file name, directories being ignored."""
    own_name = Path(self.path).name
    other_name = Path(other.path).name
    return own_name == other_name
def load_txt_data(s: str) -> tuple[np.ndarray, ...]:
    """
    Parse delimited text into one array per column.

    Comma, any run of whitespace (spaces and/or tabs) and semicolon are tried in
    that order as column separator; the first one that parses every line wins.

    Parameters
    ----------
    s : str
        text content, one row per line

    Returns
    -------
    tuple[np.ndarray, ...]
        one 1D array per column, also when there is a single row or a single column

    Raises
    ------
    ValueError
        no separator allows to parse the text as numbers
    """
    lines = s.splitlines()
    last_error = None
    # `None` lets numpy split on any amount of whitespace, which covers single
    # spaces, tabs and space-aligned columns alike.
    for delimiter in (",", None, ";"):
        try:
            # ndmin=2 keeps the (row, column) layout for single-row/single-column
            # data, so that transposing always yields columns and never scalars
            table = np.loadtxt(lines, delimiter=delimiter, ndmin=2)
        except ValueError as error:
            last_error = error
            continue
        return tuple(table.T)
    raise ValueError("Could not load text data as numpy array") from last_error
def load_npz_data(
    df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...]
) -> tuple[np.ndarray, ...]:
    """
    Extract arrays from an opened `.npz` archive.

    Parameters
    ----------
    df : np.lib.npyio.NpzFile
        archive as returned by `np.load`
    *labels : str | tuple[str, ...]
        keys to extract, in order. A plain label must exist in the archive. A
        list/tuple/set of labels is a group of aliases: the first alias found in
        the archive is used. With no label at all, every array is returned in
        the order in which it is stored.

    Returns
    -------
    tuple[np.ndarray, ...]
        one array per label (or per stored array when no label is given)

    Raises
    ------
    KeyError
        a plain label, or every alias of a group, is missing from the archive
    """
    if not labels:
        # `NpzFile.values()` is only available on numpy versions where NpzFile
        # is a Mapping; going through `files` works everywhere, in the same order
        return tuple(df[name] for name in df.files)

    out = []
    for key in labels:
        if not isinstance(key, (list, tuple, set)):
            out.append(df[key])
            continue
        for possible_key in key:
            try:
                out.append(df[possible_key])
            except KeyError:
                continue
            else:
                # first alias present wins
                break
        else:
            # loop exhausted without a single hit
            raise KeyError(f"no key of {key!r} present in {df}")
    return tuple(out)
def unique_name(base_name: str, existing: set[str]) -> str:
name = base_name
p = Path(base_name)

View File

@@ -672,17 +672,15 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) ->
np.ndarray, shape (n,)
wl-dependent effective mode field area
"""
data = np.load(BytesIO(effective_area_file.load_data()))
effective_area = data.get("A_eff", data.get("effective_area"))
wl = data["wavelength"]
wl, effective_area = effective_area_file.load_arrays(
("wavelength", "wl"), ("A_eff", "effective_area")
)
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
disp_file = np.load(BytesIO(dispersion_file.load_data()))
wl_for_disp = disp_file["wavelength"]
wl_for_disp, D = dispersion_file.load_arrays(("wavelength", "wl"), "dispersion")
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
D = disp_file["dispersion"]
beta2 = D_to_beta2(D, wl_for_disp)
return wl_for_disp, beta2, interp_range
@@ -703,9 +701,7 @@ def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
np.ndarray, shape (n,)
loss in 1/m units
"""
loss_data = np.load(BytesIO(loss_file.load_data()))
wl = loss_data["wavelength"]
loss = loss_data["loss"]
wl, loss = loss_file.load_arrays(("wavelength", "wl"), "loss")
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)

View File

@@ -460,9 +460,7 @@ def interp_custom_field(
def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
data = field_file.load_data()
field_data = np.load(BytesIO(data))
return field_data["time"], field_data["field"]
return field_file.load_arrays("time", "field")
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:

View File

@@ -1042,15 +1042,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
if log is not False:
if isinstance(log, (float, int, np.floating, np.integer)) and log is not True:
values = units.to_log(values, ref=log)
values = math.to_dB(values, ref=log)
elif log == "2D":
values = units.to_log2D(values)
values = math.to_dB(values, ref=values.max())
elif log == "1D" or log is True:
values = np.apply_along_axis(units.to_log, -1, values)
values = np.apply_along_axis(math.to_dB, -1, values)
elif log == "smooth 1D":
ref = np.max(values, axis=1)
ind = np.argmax((ref[:-1] - ref[1:]) < 0)
values = units.to_log(values, ref=np.max(ref[ind:]))
values = math.to_dB(values, ref=np.max(ref[ind:]))
else:
raise ValueError(f"Log argument {log} not recognized")
return values

View File

@@ -126,8 +126,8 @@ def test_zip_data_copy(tmp_path: Path):
prop = propagation(tmp_path / "file.zip")
assert prop.parameters.effective_area_file.load_data() == new_aeff_path.read_bytes()
assert prop.parameters.dispersion_file.load_data() == new_disp_path.read_bytes()
assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes()
assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes()
def test_zip_bundle(tmp_path: Path):