diff --git a/TODO.md b/TODO.md deleted file mode 100644 index ec9062b..0000000 --- a/TODO.md +++ /dev/null @@ -1,7 +0,0 @@ -* plot cli command should show by default, not save -* logger should hook to pbar when it exists -* find a way to make evaluator debugging easier - -# Ideas - -* have a `scan` section in the config to more granularly control parameter scanning. That way, the user provides single default values to all necessary parameters and then call the single-sim or scan commands to run the config accordingly. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index bd4d253..656d353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.10" +version = "0.3.11" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index e550964..18f6fea 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -6,6 +6,7 @@ import json import os import re from dataclasses import dataclass +from io import BytesIO from pathlib import Path from typing import BinaryIO, Protocol, Sequence from zipfile import ZipFile @@ -218,7 +219,7 @@ class DataFile: return os.fspath(Path(self.path)) return self.prefix + self.PREFIX_SEP + self.path - def load_data(self) -> bytes: + def load_bytes(self) -> bytes: if self.prefix is not None and self.io is None: raise ValueError( f"a bundled file prefixed with {self.prefix} " @@ -230,10 +231,53 @@ class DataFile: else: return Path(self.path).read_bytes() + def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]: + raw_data = self.load_bytes() + extension = self.path.lower().split()[-1] + if extension == "npz": + df = np.load(BytesIO(raw_data)) + return load_npz_data(df, *labels) + elif extension == "npy": + return tuple(np.load(BytesIO(raw_data))) + else: + return load_txt_data(raw_data.decode()) + def similar_to(self, other: DataFile) -> bool: return Path(self.path).name == Path(other.path).name +def load_txt_data(s: str) -> tuple[np.ndarray, ...]: + lines = s.splitlines() + for delimiter in ", \t;": + try: + return tuple(np.loadtxt(lines, delimiter=delimiter).T) + except ValueError: + continue + raise ValueError("Could not load text data as numpy array") + + +def load_npz_data( + df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...] +) -> tuple[np.ndarray, ...]: + if not labels: + return tuple(df.values()) + out = [] + for key in labels: + if not isinstance(key, (list, tuple, set)): + out.append(df[key]) + continue + for possible_key in key: + try: + out.append(df[possible_key]) + except KeyError: + continue + else: + break + else: + raise KeyError(f"no key of {key!r} present in {df}") + return tuple(out) + + def unique_name(base_name: str, existing: set[str]) -> str: name = base_name p = Path(base_name) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 094c722..548f6ec 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -672,17 +672,15 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray, shape (n,) wl-dependent effective mode field area """ - data = np.load(BytesIO(effective_area_file.load_data())) - effective_area = data.get("A_eff", data.get("effective_area")) - wl = data["wavelength"] + wl, effective_area = effective_area_file.load_arrays( + ("wavelength", "wl"), ("A_eff", "effective_area") + ) return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l) def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]: - disp_file = np.load(BytesIO(dispersion_file.load_data())) - wl_for_disp = disp_file["wavelength"] + wl_for_disp, D = dispersion_file.load_arrays(("wavelength", "wl"), "dispersion") interp_range = (np.min(wl_for_disp), np.max(wl_for_disp)) - D = disp_file["dispersion"] beta2 = D_to_beta2(D, wl_for_disp) return wl_for_disp, beta2, interp_range @@ -703,9 +701,7 @@ def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray: np.ndarray, shape (n,) loss in 1/m units """ - loss_data = np.load(BytesIO(loss_file.load_data())) - wl = loss_data["wavelength"] - loss = loss_data["loss"] + wl, loss = loss_file.load_arrays(("wavelength", "wl"), "loss") return interp1d(wl, loss, fill_value=0, bounds_error=False)(l) diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index c68ef62..64a066b 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -460,9 +460,7 @@ def interp_custom_field( def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]: - data = field_file.load_data() - field_data = np.load(BytesIO(data)) - return field_data["time"], field_data["field"] + return field_file.load_arrays("time", "field") def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float: diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 8595ecb..e8008c8 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1042,15 +1042,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr if log is not False: if isinstance(log, (float, int, np.floating, np.integer)) and log is not True: - values = units.to_log(values, ref=log) + values = math.to_dB(values, ref=log) elif log == "2D": - values = units.to_log2D(values) + values = math.to_dB(values, ref=values.max()) elif log == "1D" or log is True: - values = np.apply_along_axis(units.to_log, -1, values) + values = np.apply_along_axis(math.to_dB, -1, values) elif log == "smooth 1D": ref = np.max(values, axis=1) ind = np.argmax((ref[:-1] - ref[1:]) < 0) - values = units.to_log(values, ref=np.max(ref[ind:])) + values = math.to_dB(values, ref=np.max(ref[ind:])) else: raise ValueError(f"Log argument {log} not recognized") return values diff --git a/tests/test_io_handlers.py b/tests/test_io_handlers.py index f712d59..036b605 100644 --- a/tests/test_io_handlers.py +++ b/tests/test_io_handlers.py @@ -126,8 +126,8 @@ def test_zip_data_copy(tmp_path: Path): prop = propagation(tmp_path / "file.zip") - assert prop.parameters.effective_area_file.load_data() == new_aeff_path.read_bytes() - assert prop.parameters.dispersion_file.load_data() == new_disp_path.read_bytes() + assert prop.parameters.effective_area_file.load_bytes() == new_aeff_path.read_bytes() + assert prop.parameters.dispersion_file.load_bytes() == new_disp_path.read_bytes() def test_zip_bundle(tmp_path: Path):