diff --git a/pyproject.toml b/pyproject.toml index 656d353..88f3295 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,27 +4,27 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.11" +version = "0.3.12" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] -license = {file = "LICENSE"} +license = { file = "LICENSE" } classifiers = [ - "License :: OSI Approved :: MIT", - "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT", + "Programming Language :: Python :: 3", ] requires-python = ">=3.10" keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"] dependencies = [ - "numpy", - "scipy", - "matplotlib", - "tomli", - "tomli_w", - "numba", - "tqdm", - "pydantic", - "pydantic-settings", + "numpy", + "scipy", + "matplotlib", + "tomli", + "tomli_w", + "numba", + "tqdm", + "pydantic", + "pydantic-settings", ] [tool.ruff] @@ -40,4 +40,3 @@ line-length = 100 [tool.isort] profile = "black" skip = ["__init__.py"] - diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 1c7a655..1c542b1 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -38,6 +38,12 @@ def total_extent(*vec: np.ndarray) -> float: return right - left +def span_above(arr: np.ndarray, threshold: float) -> tuple[int, int]: + """returns the first and last index where the array is above the specified threshold""" + ind = np.where(arr >= threshold)[0] + return np.min(ind), np.max(ind) + + def argclosest(array: np.ndarray, target: float | int | Sequence[float | int]) -> int | np.ndarray: """ returns the index/indices corresponding to the closest matches of target in array @@ -250,7 +256,6 @@ def irfftfreq(freq: np.ndarray, retstep: bool = False): def iwspace(w: np.ndarray, retstep: bool = False): """invserse of wspace: recovers the (symmetric) time array corresponsding to `w`""" df = (w[1] - w[0]) * 0.5 / np.pi - print(df) nt = len(w) period = 1 / df dt = period / nt diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index e8008c8..a72538b 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Sequence, Union import matplotlib.gridspec as gs import matplotlib.pyplot as plt @@ -18,6 +18,7 @@ from scgenerator.math import abs2, linear_interp_2d, span from scgenerator.parameter import Parameters from scgenerator.physics import pulse, units from scgenerator.physics.units import PlotRange, sort_axis +from scgenerator.spectra import Propagation, Spectrum RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() @@ -435,9 +436,6 @@ def transform_2D_propagation( x_axis: np.ndarray = None, y_axis: np.ndarray = None, log: Union[int, float, bool, str] = "1D", - skip: int = 1, - params: Parameters = None, - conserved_quantity: bool = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ transforms raws values into plottable values @@ -454,14 +452,6 @@ def transform_2D_propagation( corresponding y values in SI units log : Union[int, float, bool, str], optional see apply_log, by default "1D" - params : Parameters, optional - parameters of the simulation, used to automatically fill in x and y axes - skip : int, optional - take one every skip values, by default 1 (take all values) - conserved_quantity : bool, optional - if the target axis is wavelength, the transformation is not linear has to be corrected. - This is necessary when values is interpreted as averaged over a bin (e.g. amplitude), - but shouldn't be used when it's not the case (e.g. coherence). by default True Returns ------- @@ -477,10 +467,6 @@ def transform_2D_propagation( ValueError incorrect shape """ - x_axis = get_x_axis(plt_range, x_axis, params) - if y_axis is None and params is not None: - y_axis = params.z_targets - if values.ndim != 2: raise ValueError(f"shape was {values.shape}. Can only plot 2D array") is_complex, plt_range = prep_plot_axis(values, plt_range) @@ -489,33 +475,10 @@ def transform_2D_propagation( # if params.full_field and plt_range.unit.type == "TIME": # values = envelope_2d(x_axis, values) - x_axis, values = uniform_axis(x_axis, values, plt_range, conserved_quantity) - y_axis, values.T[:] = uniform_axis(y_axis, values.T, None, conserved_quantity) + x_axis, values = uniform_axis(x_axis, values, plt_range) + y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) values = apply_log(values, log) - return x_axis[::skip], y_axis, values[:, ::skip] - - -def uniform_2d( - old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray -) -> np.ndarray: - """ - interpolates a 2d array according to the provides old and new axis - - Parameters - ---------- - old_x : np.ndarray, shape (n,) - old_y : np.ndarray, shape (m,) - new_x : np.ndarray, shape (N,) - new_y : np.ndarray, shape (M,) - values : np.ndarray, shape (m, n) - - Returns - ------- - np.ndarray, shape (M, N) - """ - values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x) - values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y) - return values + return x_axis, y_axis, values def get_x_axis(plt_range, x_axis, params) -> np.ndarray: @@ -961,7 +924,6 @@ def uniform_axis( axis: np.ndarray, values: np.ndarray, new_axis_spec: Union[PlotRange, RangeType, str], - conserved_quantity: bool = True, ) -> tuple[np.ndarray, np.ndarray]: """ given some values(axis), creates a new uniformly spaced axis and interpolates @@ -1006,8 +968,6 @@ def uniform_axis( new_axis = tmp_axis values = values[:, ind] else: - if plt_range.unit.type == "WL" and conserved_quantity: - values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) values = linear_interp_2d(tmp_axis, values[:, ind], new_axis) return new_axis, values.squeeze() @@ -1192,3 +1152,25 @@ def annotate_fwhm( x, y = (left, v_max / 2) trans = offset_copy(ax.transData, ax.get_figure(), offset, 0, "points") ax.text(x, y, arrow_label, transform=trans, **text_kwargs) + + +def summary_plot( + specs: Spectrum, + z: Sequence[float] | None = None, + wl_range: PlotRange | None = None, + t_range: PlotRange | None = None, + db_min: float = -50.0, +): + wl_int = specs.wl_int + time_int = specs.time_int + + if wl_range is None: + imin, imax = math.span_above(wl_int, wl_int.max() * 1e-6) + wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm") + + if t_range is None: + imin, imax = math.span_above(time_int, time_int.max() * 1e-6) + t_range = PlotRange(specs.t[imin] * 1e15, specs.t[imax] * 1e15, "fs") + + fig, (left, right) = plt.subplots(1, 2) + transform_2D_propagation(wl_int, wl_range, specs.w, z) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index e2bda18..61e0d61 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import warnings from pathlib import Path from typing import Callable, Generic, TypeVar, overload @@ -46,6 +47,7 @@ class Spectrum(np.ndarray): # We first cast to be our class type obj = np.asarray(input_array).view(cls) # add the new attribute to the created instance + obj.order = np.argsort(w) obj.w = w if t is not None: obj.t = t @@ -70,44 +72,36 @@ class Spectrum(np.ndarray): self.w = getattr(obj, "w", None) self.t = getattr(obj, "t", None) self.l = getattr(obj, "l", None) + self.order = getattr(obj, "order", None) self.ifft = getattr(obj, "ifft", None) def __getitem__(self, key) -> "Spectrum": return super().__getitem__(key) + @property + def wl_disp(self): + return self.l[self.order][::-1] + + @property + def w_disp(self): + return self.w[self.order] + @property def wl_int(self): - return units.to_WL(math.abs2(self), self.l) + return units.to_WL(math.abs2(self), self.l)[self.order][::-1] @property def freq_int(self): - return math.abs2(self) + return math.abs2(self.freq_amp) @property def afreq_int(self): - return math.abs2(self) + return math.abs2(self.freq_amp) @property def time_int(self): return math.abs2(self.ifft(self)) - def amplitude(self, unit): - if unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = unit.inv(self.w) - else: - x_axis = unit.inv(self.t) - - order = np.argsort(x_axis) - func = dict( - WL=self.wl_amp, - FREQ=self.freq_amp, - AFREQ=self.afreq_amp, - TIME=self.time_amp, - )[unit.type] - - for spec in self: - yield x_axis[order], func(spec)[:, order] - @property def wl_amp(self): return ( @@ -119,15 +113,15 @@ class Spectrum(np.ndarray): ) * self / np.abs(self) - ) + )[self.order][::-1] @property def freq_amp(self): - return self + return self[self.order] @property def afreq_amp(self): - return self + return self[self.order] @property def time_amp(self): @@ -180,6 +174,8 @@ class Propagation(Generic[ParamsOrNone]): self.io = io_handler self._current_index = len(self.io) self.parameters = params + if self.parameters is not None: + self.z_positions = self.parameters.compute("z_targets") def __len__(self) -> int: return self._current_index @@ -198,8 +194,9 @@ class Propagation(Generic[ParamsOrNone]): if isinstance(key, (float, np.floating)): if self.parameters is None: raise ValueError(f"cannot accept float key {key} when parameters is not set") - key = math.argclosest(self.parameters.compute("z_targets"), key) + key = math.argclosest(self.z_positions, key) elif key < 0: + self._warn_negative_index(key) key = len(self) + key array = self.io.load_spectrum(key) if self.parameters is not None: @@ -221,21 +218,31 @@ class Propagation(Generic[ParamsOrNone]): ... def _load_slice(self, key: slice) -> Spectrum: + self._warn_negative_index(key.start) + self._warn_negative_index(key.stop) _iter = range(len(self))[key] + # if self.parameters is not None: + # out = Spectrum.from_params( + # np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters + # ) + # for i in _iter: + # out[i] = self.io.load_spectrum(i) + # else: + out = np.array([self.io.load_spectrum(i) for i in _iter]) if self.parameters is not None: - out = Spectrum.from_params( - np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters - ) - for i in _iter: - out[i] = self.io.load_spectrum(i) - else: - out = np.array([self.io.load_spectrum(i) for i in _iter]) + out = Spectrum.from_params(out, self.parameters) return out def append(self, spectrum: np.ndarray): self.io.save_spectrum(self._current_index, np.asarray(spectrum)) self._current_index += 1 + def _warn_negative_index(self, index: int | None): + if (index is not None and index >= 0) or self.parameters is None: + return + if self._current_index < len(self.z_positions): + warnings.warn(f"attempting to access index {index} on an incomplete propagation obj") + def load_all(self) -> Spectrum: return self._load_slice(slice(None))