From ff7e78b1b52b970798227c0d0ce83b28931f33fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 5 Sep 2023 09:05:39 +0200 Subject: [PATCH] change: grid and units improvements --- pyproject.toml | 2 +- src/scgenerator/io.py | 2 +- src/scgenerator/math.py | 34 +++++++++++++++++++++++++++++++++- src/scgenerator/noise.py | 20 +------------------- src/scgenerator/spectra.py | 30 +++++++++++++++++++++++------- src/scgenerator/utils.py | 1 - tests/test_grid.py | 7 +++++++ 7 files changed, 66 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0db193e..1888a9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.9" +version = "0.3.10" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 9cd9521..e550964 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -28,7 +28,7 @@ class TimedMessage: def data_file(path: str) -> Path: """returns a `Path` object pointing to the desired data file included in `scgenerator`""" - return importlib.resources.path("scgenerator", "data") / path + return importlib.resources.files("scgenerator") / "data" / path class CustomEncoder(json.JSONEncoder): diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index ca3dbb3..1c7a655 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -140,7 +140,7 @@ def to_dB(arr: np.ndarray, ref=None, axis=None) -> np.ndarray: ref = np.max(arr) m = arr / ref above_0 = m > 0 - m = 10 * np.log10(m, out=np.zeros_like(m) - 10 * np.log10(m[above_0].min()), where=above_0) + m = 10 * np.log10(m, out=np.ones_like(m) * (10 * np.log10(m[above_0].min())), where=above_0) return m @@ -229,6 +229,38 @@ def tspace(time_window: float | None = None, t_num: float | None = None, dt: flo raise TypeError("not enough parameter to determine time vector") +def irfftfreq(freq: np.ndarray, retstep: bool = False): + """ + Given an array of positive only frequency, this returns the corresponding time array centered + around 0 that will be aligned with the `numpy.fft.irfft` of a spectrum aligned with `freq`. + if `retstep` is True, the sample spacing is returned as well + """ + df = freq[1] - freq[0] + nt = (len(freq) - 1) * 2 + period = 1 / df + dt = period / nt + + t = np.linspace(-(period - dt) / 2, (period - dt) / 2, nt) + if retstep: + return t, dt + else: + return t + + +def iwspace(w: np.ndarray, retstep: bool = False): + """invserse of wspace: recovers the (symmetric) time array corresponsding to `w`""" + df = (w[1] - w[0]) * 0.5 / np.pi + print(df) + nt = len(w) + period = 1 / df + dt = period / nt + t = np.linspace(-(period - dt) / 2, (period - dt) / 2, nt) + if retstep: + return t, dt + else: + return t + + def dt_from_min_wl(wl_min: float, wavelength: float) -> float: return 0.5 * 1 / c * 1 / (1 / wl_min - 1 / wavelength) diff --git a/src/scgenerator/noise.py b/src/scgenerator/noise.py index 1d6b6bd..6b86ce0 100644 --- a/src/scgenerator/noise.py +++ b/src/scgenerator/noise.py @@ -147,7 +147,7 @@ class NoiseMeasurement: freq, spec = self.sample_spectrum(nt, dt) if phase is None: phase = 2 * np.pi * np.random.rand(len(freq)) - time, dt = irfftfreq(freq, True) + time, dt = math.irfftfreq(freq, True) amp = np.sqrt(spec) * np.exp(1j * phase) signal = np.fft.irfft(amp) * np.sqrt(0.5 * nt / dt) @@ -163,24 +163,6 @@ class NoiseMeasurement: return integrated_noise(self.freq, self.psd) -def irfftfreq(freq: np.ndarray, retstep: bool = False): - """ - Given an array of positive only frequency, this returns the corresponding time array centered - around 0 that will be aligned with the `numpy.fft.irfft` of a spectrum aligned with `freq`. - if `retstep` is True, the sample spacing is returned as well - """ - df = freq[1] - freq[0] - nt = (len(freq) - 1) * 2 - period = 1 / df - dt = period / nt - - t = np.linspace(-(period - dt) / 2, (period - dt) / 2, nt) - if retstep: - return t, dt - else: - return t - - def log_power(x): return 10 * np.log10(np.abs(np.where(x == 0, 1e-7, x))) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 7be6999..fd28f37 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -28,15 +28,31 @@ class Spectrum(np.ndarray): l: np.ndarray ifft: Callable[[np.ndarray], np.ndarray] - def __new__(cls, input_array, params: Parameters): + @classmethod + def from_params(cls, array, params: Parameters): + w = params.compute("w") + t = params.compute("t") + ifft = params.compute("ifft") + return cls(array, w, t, ifft) + + def __new__( + cls, + input_array, + w: np.ndarray, + t: np.ndarray | None = None, + ifft: Callable[[np.ndarray], np.ndarray] = np.fft.ifft, + ): # Input array is an already formed ndarray instance # We first cast to be our class type obj = np.asarray(input_array).view(cls) # add the new attribute to the created instance - obj.w = params.compute("w") - obj.t = params.compute("t") - obj.l = params.compute("l") - obj.ifft = params.compute("ifft") + obj.w = w + if t is not None: + obj.t = t + else: + obj.t = math.iwspace(obj.w) + obj.ifft = ifft + obj.l = 2 * np.pi * units.c / obj.w if not (len(obj.w) == len(obj.t) == len(obj.l) == obj.shape[-1]): raise ValueError( @@ -187,7 +203,7 @@ class Propagation(Generic[ParamsOrNone]): key = len(self) + key array = self.io.load_spectrum(key) if self.parameters is not None: - return Spectrum(array, self.parameters) + return Spectrum.from_params(array, self.parameters) else: return array @@ -207,7 +223,7 @@ class Propagation(Generic[ParamsOrNone]): def _load_slice(self, key: slice) -> Spectrum: _iter = range(len(self))[key] if self.parameters is not None: - out = Spectrum( + out = Spectrum.from_params( np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters ) for i in _iter: diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index c04d0c2..4d48a43 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -5,7 +5,6 @@ scgenerator module but some function may be used in any python program """ from __future__ import annotations -import datetime import itertools import json import os diff --git a/tests/test_grid.py b/tests/test_grid.py index cb4cad3..3980b2f 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -20,3 +20,10 @@ def test_wl_dispersion(): wl = sc.units.m.inv(w + sc.units.nm(1546)) wl_disp, ind_disp = sc.fiber.lambda_for_envelope_dispersion(wl, (950e-9, 4000e-9)) assert all(np.diff(wl_disp) > 0) + + +def test_iwspace(): + t = sc.tspace(dt=15.6, t_num=512) + assert sc.math.iwspace(sc.wspace(t)) == pytest.approx(t) + assert sc.math.iwspace(sc.wspace(t) + 4564568456.4) != pytest.approx(t) + assert sc.math.iwspace(sc.wspace(t) + 4564568456.4) == pytest.approx(t, rel=1e-3)