change: units and io refactoring

Benoît Sierro
2023-09-12 11:55:35 +02:00
parent 292896bdce
commit ae09fdd3d0
7 changed files with 58 additions and 27 deletions

View File

@@ -1,7 +0,0 @@
-* plot cli command should show by default, not save
-* logger should hook to pbar when it exists
-* find a way to make evaluator debugging easier
-# Ideas
-* have a `scan` section in the config to more granularly control parameter scanning. That way, the user provides single default values to all necessary parameters and then call the single-sim or scan commands to run the config accordingly.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "scgenerator"
-version = "0.3.10"
+version = "0.3.11"
 description = "Simulate nonlinear pulse propagation in optical fibers"
 readme = "README.md"
 authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -6,6 +6,7 @@ import json
 import os
 import re
 from dataclasses import dataclass
+from io import BytesIO
 from pathlib import Path
 from typing import BinaryIO, Protocol, Sequence
 from zipfile import ZipFile
@@ -218,7 +219,7 @@ class DataFile:
             return os.fspath(Path(self.path))
         return self.prefix + self.PREFIX_SEP + self.path

-    def load_data(self) -> bytes:
+    def load_bytes(self) -> bytes:
         if self.prefix is not None and self.io is None:
             raise ValueError(
                 f"a bundled file prefixed with {self.prefix} "
@@ -230,10 +231,53 @@ class DataFile:
         else:
             return Path(self.path).read_bytes()

+    def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]:
+        raw_data = self.load_bytes()
+        extension = self.path.lower().split(".")[-1]
+        if extension == "npz":
+            df = np.load(BytesIO(raw_data))
+            return load_npz_data(df, *labels)
+        elif extension == "npy":
+            return tuple(np.load(BytesIO(raw_data)))
+        else:
+            return load_txt_data(raw_data.decode())
+
     def similar_to(self, other: DataFile) -> bool:
         return Path(self.path).name == Path(other.path).name


+def load_txt_data(s: str) -> tuple[np.ndarray, ...]:
+    lines = s.splitlines()
+    for delimiter in ", \t;":
+        try:
+            return tuple(np.loadtxt(lines, delimiter=delimiter).T)
+        except ValueError:
+            continue
+    raise ValueError("Could not load text data as numpy array")
+
+
+def load_npz_data(
+    df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...]
+) -> tuple[np.ndarray, ...]:
+    if not labels:
+        return tuple(df.values())
+    out = []
+    for key in labels:
+        if not isinstance(key, (list, tuple, set)):
+            out.append(df[key])
+            continue
+        for possible_key in key:
+            try:
+                out.append(df[possible_key])
+            except KeyError:
+                continue
+            else:
+                break
+        else:
+            raise KeyError(f"no key of {key!r} present in {df}")
+    return tuple(out)
+
+
 def unique_name(base_name: str, existing: set[str]) -> str:
     name = base_name
     p = Path(base_name)
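For orientation, here is a minimal usage sketch of the new DataFile.load_arrays API (not part of the commit itself; the file name and keys are hypothetical, and it assumes a DataFile can be constructed from a bare path). A plain string label must match a stored array exactly, while a tuple of labels is tried left to right as fallback keys for one array:

# hypothetical example: an .npz archive storing "wl" and "dispersion" arrays
df = DataFile("dispersion.npz")

# tuple label: "wavelength" is tried first, then "wl"
wl, D = df.load_arrays(("wavelength", "wl"), "dispersion")

# no labels: return every array stored in the file
all_arrays = df.load_arrays()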

View File

@@ -672,17 +672,15 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) ->
     np.ndarray, shape (n,)
         wl-dependent effective mode field area
     """
-    data = np.load(BytesIO(effective_area_file.load_data()))
-    effective_area = data.get("A_eff", data.get("effective_area"))
-    wl = data["wavelength"]
+    wl, effective_area = effective_area_file.load_arrays(
+        ("wavelength", "wl"), ("A_eff", "effective_area")
+    )
     return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)


 def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
-    disp_file = np.load(BytesIO(dispersion_file.load_data()))
-    wl_for_disp = disp_file["wavelength"]
+    wl_for_disp, D = dispersion_file.load_arrays(("wavelength", "wl"), "dispersion")
     interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
-    D = disp_file["dispersion"]
     beta2 = D_to_beta2(D, wl_for_disp)
     return wl_for_disp, beta2, interp_range
@@ -703,9 +701,7 @@ def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
     np.ndarray, shape (n,)
         loss in 1/m units
     """
-    loss_data = np.load(BytesIO(loss_file.load_data()))
-    wl = loss_data["wavelength"]
-    loss = loss_data["loss"]
+    wl, loss = loss_file.load_arrays(("wavelength", "wl"), "loss")
     return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)

View File

@@ -460,9 +460,7 @@ def interp_custom_field(
 def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
-    data = field_file.load_data()
-    field_data = np.load(BytesIO(data))
-    return field_data["time"], field_data["field"]
+    return field_file.load_arrays("time", "field")


 def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:

View File

@@ -1042,15 +1042,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
     if log is not False:
         if isinstance(log, (float, int, np.floating, np.integer)) and log is not True:
-            values = units.to_log(values, ref=log)
+            values = math.to_dB(values, ref=log)
         elif log == "2D":
-            values = units.to_log2D(values)
+            values = math.to_dB(values, ref=values.max())
         elif log == "1D" or log is True:
-            values = np.apply_along_axis(units.to_log, -1, values)
+            values = np.apply_along_axis(math.to_dB, -1, values)
         elif log == "smooth 1D":
             ref = np.max(values, axis=1)
             ind = np.argmax((ref[:-1] - ref[1:]) < 0)
-            values = units.to_log(values, ref=np.max(ref[ind:]))
+            values = math.to_dB(values, ref=np.max(ref[ind:]))
         else:
             raise ValueError(f"Log argument {log} not recognized")
     return values
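Every branch now routes through a single math.to_dB helper instead of the units.to_log variants. Its implementation is not shown in this diff; a minimal sketch consistent with the call sites above, assuming the usual power-ratio convention, might look like:

import numpy as np

def to_dB(values: np.ndarray, ref: float | None = None) -> np.ndarray:
    # sketch only: normalize to a reference (the maximum by default)
    # and convert to decibels; the package's actual to_dB may differ
    if ref is None:
        ref = values.max()
    return 10 * np.log10(values / ref)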

View File

@@ -126,8 +126,8 @@ def test_zip_data_copy(tmp_path: Path):
     prop = propagation(tmp_path / "file.zip")

-    assert prop.parameters.effective_area_file.load_data() == new_aeff_path.read_bytes()
-    assert prop.parameters.dispersion_file.load_data() == new_disp_path.read_bytes()
+    assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes()
+    assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes()


 def test_zip_bundle(tmp_path: Path):