change: units and io refactoring
TODO.md
@@ -1,7 +0,0 @@
-* plot cli command should show by default, not save
-* logger should hook to pbar when it exists
-* find a way to make evaluator debugging easier
-
-# Ideas
-
-* have a `scan` section in the config to more granularly control parameter scanning. That way, the user provides single default values to all necessary parameters and then call the single-sim or scan commands to run the config accordingly.

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "scgenerator"
-version = "0.3.10"
+version = "0.3.11"
 description = "Simulate nonlinear pulse propagation in optical fibers"
 readme = "README.md"
 authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

@@ -6,6 +6,7 @@ import json
 import os
 import re
 from dataclasses import dataclass
+from io import BytesIO
 from pathlib import Path
 from typing import BinaryIO, Protocol, Sequence
 from zipfile import ZipFile
@@ -218,7 +219,7 @@ class DataFile:
         return os.fspath(Path(self.path))
         return self.prefix + self.PREFIX_SEP + self.path
 
-    def load_data(self) -> bytes:
+    def load_bytes(self) -> bytes:
         if self.prefix is not None and self.io is None:
             raise ValueError(
                 f"a bundled file prefixed with {self.prefix} "
@@ -230,10 +231,53 @@ class DataFile:
         else:
             return Path(self.path).read_bytes()
 
+    def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]:
+        raw_data = self.load_bytes()
+        extension = self.path.lower().split(".")[-1]
+        if extension == "npz":
+            df = np.load(BytesIO(raw_data))
+            return load_npz_data(df, *labels)
+        elif extension == "npy":
+            return tuple(np.load(BytesIO(raw_data)))
+        else:
+            return load_txt_data(raw_data.decode())
+
     def similar_to(self, other: DataFile) -> bool:
         return Path(self.path).name == Path(other.path).name
 
 
+def load_txt_data(s: str) -> tuple[np.ndarray, ...]:
+    lines = s.splitlines()
+    for delimiter in ", \t;":
+        try:
+            return tuple(np.loadtxt(lines, delimiter=delimiter).T)
+        except ValueError:
+            continue
+    raise ValueError("Could not load text data as numpy array")
+
+
+def load_npz_data(
+    df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...]
+) -> tuple[np.ndarray, ...]:
+    if not labels:
+        return tuple(df.values())
+    out = []
+    for key in labels:
+        if not isinstance(key, (list, tuple, set)):
+            out.append(df[key])
+            continue
+        for possible_key in key:
+            try:
+                out.append(df[possible_key])
+            except KeyError:
+                continue
+            else:
+                break
+        else:
+            raise KeyError(f"no key of {key!r} present in {df}")
+    return tuple(out)
+
+
 def unique_name(base_name: str, existing: set[str]) -> str:
     name = base_name
     p = Path(base_name)
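As a quick aside (not part of the commit), the sketch below exercises the two new helpers exactly as defined above: the tuple-of-fallback-keys lookup in `load_npz_data` and the delimiter sniffing in `load_txt_data`. It assumes both names are importable from the patched module, whose path the excerpt does not show; the data and key names are invented for illustration.

```python
import numpy as np
from io import BytesIO

# assumes: load_npz_data and load_txt_data imported from the patched module

# A tuple label means "first key present wins": this archive only has "wl",
# so ("wavelength", "wl") falls back to it instead of raising KeyError.
buf = BytesIO()
np.savez(buf, wl=np.linspace(400e-9, 1600e-9, 5), dispersion=np.zeros(5))
buf.seek(0)
wl, D = load_npz_data(np.load(buf), ("wavelength", "wl"), "dispersion")

# Delimiter sniffing: ",", " " and "\t" all fail to parse this text, ";" works.
wl2, loss = load_txt_data("500e-9;1.2\n600e-9;0.9\n")
```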

@@ -672,17 +672,15 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) ->
     np.ndarray, shape (n,)
         wl-dependent effective mode field area
     """
-    data = np.load(BytesIO(effective_area_file.load_data()))
-    effective_area = data.get("A_eff", data.get("effective_area"))
-    wl = data["wavelength"]
+    wl, effective_area = effective_area_file.load_arrays(
+        ("wavelength", "wl"), ("A_eff", "effective_area")
+    )
     return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
 
 
 def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
-    disp_file = np.load(BytesIO(dispersion_file.load_data()))
-    wl_for_disp = disp_file["wavelength"]
+    wl_for_disp, D = dispersion_file.load_arrays(("wavelength", "wl"), "dispersion")
     interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
-    D = disp_file["dispersion"]
     beta2 = D_to_beta2(D, wl_for_disp)
     return wl_for_disp, beta2, interp_range
@@ -703,9 +701,7 @@ def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
     np.ndarray, shape (n,)
         loss in 1/m units
     """
-    loss_data = np.load(BytesIO(loss_file.load_data()))
-    wl = loss_data["wavelength"]
-    loss = loss_data["loss"]
+    wl, loss = loss_file.load_arrays(("wavelength", "wl"), "loss")
     return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
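Because all three loaders above now go through `DataFile.load_arrays`, a custom dispersion, loss, or effective-area file may be `.npz`, `.npy`, or delimited text. A hypothetical way to produce a compatible loss file (file name and values invented):

```python
import numpy as np

wl = np.linspace(400e-9, 2400e-9, 200)
loss = np.full_like(wl, 0.01)  # flat 0.01 / m loss
# "wl" is matched through the ("wavelength", "wl") fallback in load_custom_loss
np.savez("loss.npz", wl=wl, loss=loss)
```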
@@ -460,9 +460,7 @@ def interp_custom_field(
 
 
 def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
-    data = field_file.load_data()
-    field_data = np.load(BytesIO(data))
-    return field_data["time"], field_data["field"]
+    return field_file.load_arrays("time", "field")
 
 
 def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:

@@ -1042,15 +1042,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarray:
 
     if log is not False:
         if isinstance(log, (float, int, np.floating, np.integer)) and log is not True:
-            values = units.to_log(values, ref=log)
+            values = math.to_dB(values, ref=log)
         elif log == "2D":
-            values = units.to_log2D(values)
+            values = math.to_dB(values, ref=values.max())
         elif log == "1D" or log is True:
-            values = np.apply_along_axis(units.to_log, -1, values)
+            values = np.apply_along_axis(math.to_dB, -1, values)
         elif log == "smooth 1D":
             ref = np.max(values, axis=1)
             ind = np.argmax((ref[:-1] - ref[1:]) < 0)
-            values = units.to_log(values, ref=np.max(ref[ind:]))
+            values = math.to_dB(values, ref=np.max(ref[ind:]))
         else:
             raise ValueError(f"Log argument {log} not recognized")
     return values
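The hunk above only shows call sites of the new `math.to_dB` helper. A minimal sketch of what such a helper plausibly computes, assuming the usual power-ratio decibel definition with the array maximum as the default reference (the actual scgenerator implementation may differ, e.g. by clamping small values):

```python
import numpy as np

def to_dB(values: np.ndarray, ref: float | None = None) -> np.ndarray:
    # normalize to the reference (array maximum when none is given),
    # then convert the power ratio to decibels
    ref = np.max(values) if ref is None else ref
    return 10 * np.log10(values / ref)
```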
@@ -126,8 +126,8 @@ def test_zip_data_copy(tmp_path: Path):
 
     prop = propagation(tmp_path / "file.zip")
 
-    assert prop.parameters.effective_area_file.load_data() == new_aeff_path.read_bytes()
-    assert prop.parameters.dispersion_file.load_data() == new_disp_path.read_bytes()
+    assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes()
+    assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes()
 
 
 def test_zip_bundle(tmp_path: Path):