change: units and io refactoring
This commit is contained in:
7
TODO.md
7
TODO.md
@@ -1,7 +0,0 @@
|
||||
* plot cli command should show by default, not save
|
||||
* logger should hook to pbar when it exists
|
||||
* find a way to make evaluator debugging easier
|
||||
|
||||
# Ideas
|
||||
|
||||
* have a `scan` section in the config to more granularly control parameter scanning. That way, the user provides single default values to all necessary parameters and then call the single-sim or scan commands to run the config accordingly.
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "scgenerator"
|
||||
version = "0.3.10"
|
||||
version = "0.3.11"
|
||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Protocol, Sequence
|
||||
from zipfile import ZipFile
|
||||
@@ -218,7 +219,7 @@ class DataFile:
|
||||
return os.fspath(Path(self.path))
|
||||
return self.prefix + self.PREFIX_SEP + self.path
|
||||
|
||||
def load_data(self) -> bytes:
|
||||
def load_bytes(self) -> bytes:
|
||||
if self.prefix is not None and self.io is None:
|
||||
raise ValueError(
|
||||
f"a bundled file prefixed with {self.prefix} "
|
||||
@@ -230,10 +231,53 @@ class DataFile:
|
||||
else:
|
||||
return Path(self.path).read_bytes()
|
||||
|
||||
def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]:
|
||||
raw_data = self.load_bytes()
|
||||
extension = self.path.lower().split()[-1]
|
||||
if extension == "npz":
|
||||
df = np.load(BytesIO(raw_data))
|
||||
return load_npz_data(df, *labels)
|
||||
elif extension == "npy":
|
||||
return tuple(np.load(BytesIO(raw_data)))
|
||||
else:
|
||||
return load_txt_data(raw_data.decode())
|
||||
|
||||
def similar_to(self, other: DataFile) -> bool:
|
||||
return Path(self.path).name == Path(other.path).name
|
||||
|
||||
|
||||
def load_txt_data(s: str) -> tuple[np.ndarray, ...]:
|
||||
lines = s.splitlines()
|
||||
for delimiter in ", \t;":
|
||||
try:
|
||||
return tuple(np.loadtxt(lines, delimiter=delimiter).T)
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError("Could not load text data as numpy array")
|
||||
|
||||
|
||||
def load_npz_data(
|
||||
df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...]
|
||||
) -> tuple[np.ndarray, ...]:
|
||||
if not labels:
|
||||
return tuple(df.values())
|
||||
out = []
|
||||
for key in labels:
|
||||
if not isinstance(key, (list, tuple, set)):
|
||||
out.append(df[key])
|
||||
continue
|
||||
for possible_key in key:
|
||||
try:
|
||||
out.append(df[possible_key])
|
||||
except KeyError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise KeyError(f"no key of {key!r} present in {df}")
|
||||
return tuple(out)
|
||||
|
||||
|
||||
def unique_name(base_name: str, existing: set[str]) -> str:
|
||||
name = base_name
|
||||
p = Path(base_name)
|
||||
|
||||
@@ -672,17 +672,15 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) ->
|
||||
np.ndarray, shape (n,)
|
||||
wl-dependent effective mode field area
|
||||
"""
|
||||
data = np.load(BytesIO(effective_area_file.load_data()))
|
||||
effective_area = data.get("A_eff", data.get("effective_area"))
|
||||
wl = data["wavelength"]
|
||||
wl, effective_area = effective_area_file.load_arrays(
|
||||
("wavelength", "wl"), ("A_eff", "effective_area")
|
||||
)
|
||||
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
|
||||
|
||||
|
||||
def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||
disp_file = np.load(BytesIO(dispersion_file.load_data()))
|
||||
wl_for_disp = disp_file["wavelength"]
|
||||
wl_for_disp, D = dispersion_file.load_arrays(("wavelength", "wl"), "dispersion")
|
||||
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
|
||||
D = disp_file["dispersion"]
|
||||
beta2 = D_to_beta2(D, wl_for_disp)
|
||||
return wl_for_disp, beta2, interp_range
|
||||
|
||||
@@ -703,9 +701,7 @@ def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
|
||||
np.ndarray, shape (n,)
|
||||
loss in 1/m units
|
||||
"""
|
||||
loss_data = np.load(BytesIO(loss_file.load_data()))
|
||||
wl = loss_data["wavelength"]
|
||||
loss = loss_data["loss"]
|
||||
wl, loss = loss_file.load_arrays(("wavelength", "wl"), "loss")
|
||||
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||
|
||||
|
||||
|
||||
@@ -460,9 +460,7 @@ def interp_custom_field(
|
||||
|
||||
|
||||
def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||
data = field_file.load_data()
|
||||
field_data = np.load(BytesIO(data))
|
||||
return field_data["time"], field_data["field"]
|
||||
return field_file.load_arrays("time", "field")
|
||||
|
||||
|
||||
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
|
||||
|
||||
@@ -1042,15 +1042,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
|
||||
|
||||
if log is not False:
|
||||
if isinstance(log, (float, int, np.floating, np.integer)) and log is not True:
|
||||
values = units.to_log(values, ref=log)
|
||||
values = math.to_dB(values, ref=log)
|
||||
elif log == "2D":
|
||||
values = units.to_log2D(values)
|
||||
values = math.to_dB(values, ref=values.max())
|
||||
elif log == "1D" or log is True:
|
||||
values = np.apply_along_axis(units.to_log, -1, values)
|
||||
values = np.apply_along_axis(math.to_dB, -1, values)
|
||||
elif log == "smooth 1D":
|
||||
ref = np.max(values, axis=1)
|
||||
ind = np.argmax((ref[:-1] - ref[1:]) < 0)
|
||||
values = units.to_log(values, ref=np.max(ref[ind:]))
|
||||
values = math.to_dB(values, ref=np.max(ref[ind:]))
|
||||
else:
|
||||
raise ValueError(f"Log argument {log} not recognized")
|
||||
return values
|
||||
|
||||
@@ -126,8 +126,8 @@ def test_zip_data_copy(tmp_path: Path):
|
||||
|
||||
prop = propagation(tmp_path / "file.zip")
|
||||
|
||||
assert prop.parameters.effective_area_file.load_data() == new_aeff_path.read_bytes()
|
||||
assert prop.parameters.dispersion_file.load_data() == new_disp_path.read_bytes()
|
||||
assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes()
|
||||
assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes()
|
||||
|
||||
|
||||
def test_zip_bundle(tmp_path: Path):
|
||||
|
||||
Reference in New Issue
Block a user