stable version before big units breaking change

This commit is contained in:
Benoît Sierro
2023-09-25 10:31:51 +02:00
parent 9381ebe3f3
commit 2f9b5005a5
4 changed files with 84 additions and 91 deletions

View File

@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.3.11" version = "0.3.12"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
license = {file = "LICENSE"} license = { file = "LICENSE" }
classifiers = [ classifiers = [
"License :: OSI Approved :: MIT", "License :: OSI Approved :: MIT",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
@@ -40,4 +40,3 @@ line-length = 100
[tool.isort] [tool.isort]
profile = "black" profile = "black"
skip = ["__init__.py"] skip = ["__init__.py"]

View File

@@ -38,6 +38,12 @@ def total_extent(*vec: np.ndarray) -> float:
return right - left return right - left
def span_above(arr: np.ndarray, threshold: float) -> tuple[int, int]:
"""returns the first and last index where the array is above the specified threshold"""
ind = np.where(arr >= threshold)[0]
return np.min(ind), np.max(ind)
def argclosest(array: np.ndarray, target: float | int | Sequence[float | int]) -> int | np.ndarray: def argclosest(array: np.ndarray, target: float | int | Sequence[float | int]) -> int | np.ndarray:
""" """
returns the index/indices corresponding to the closest matches of target in array returns the index/indices corresponding to the closest matches of target in array
@@ -250,7 +256,6 @@ def irfftfreq(freq: np.ndarray, retstep: bool = False):
def iwspace(w: np.ndarray, retstep: bool = False): def iwspace(w: np.ndarray, retstep: bool = False):
"""invserse of wspace: recovers the (symmetric) time array corresponsding to `w`""" """invserse of wspace: recovers the (symmetric) time array corresponsding to `w`"""
df = (w[1] - w[0]) * 0.5 / np.pi df = (w[1] - w[0]) * 0.5 / np.pi
print(df)
nt = len(w) nt = len(w)
period = 1 / df period = 1 / df
dt = period / nt dt = period / nt

View File

@@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union from typing import Any, Callable, Literal, Optional, Sequence, Union
import matplotlib.gridspec as gs import matplotlib.gridspec as gs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -18,6 +18,7 @@ from scgenerator.math import abs2, linear_interp_2d, span
from scgenerator.parameter import Parameters from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units from scgenerator.physics import pulse, units
from scgenerator.physics.units import PlotRange, sort_axis from scgenerator.physics.units import PlotRange, sort_axis
from scgenerator.spectra import Propagation, Spectrum
RangeType = tuple[float, float, Union[str, Callable]] RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object() NO_LIM = object()
@@ -435,9 +436,6 @@ def transform_2D_propagation(
x_axis: np.ndarray = None, x_axis: np.ndarray = None,
y_axis: np.ndarray = None, y_axis: np.ndarray = None,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
skip: int = 1,
params: Parameters = None,
conserved_quantity: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
transforms raws values into plottable values transforms raws values into plottable values
@@ -454,14 +452,6 @@ def transform_2D_propagation(
corresponding y values in SI units corresponding y values in SI units
log : Union[int, float, bool, str], optional log : Union[int, float, bool, str], optional
see apply_log, by default "1D" see apply_log, by default "1D"
params : Parameters, optional
parameters of the simulation, used to automatically fill in x and y axes
skip : int, optional
take one every skip values, by default 1 (take all values)
conserved_quantity : bool, optional
if the target axis is wavelength, the transformation is not linear has to be corrected.
This is necessary when values is interpreted as averaged over a bin (e.g. amplitude),
but shouldn't be used when it's not the case (e.g. coherence). by default True
Returns Returns
------- -------
@@ -477,10 +467,6 @@ def transform_2D_propagation(
ValueError ValueError
incorrect shape incorrect shape
""" """
x_axis = get_x_axis(plt_range, x_axis, params)
if y_axis is None and params is not None:
y_axis = params.z_targets
if values.ndim != 2: if values.ndim != 2:
raise ValueError(f"shape was {values.shape}. Can only plot 2D array") raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
is_complex, plt_range = prep_plot_axis(values, plt_range) is_complex, plt_range = prep_plot_axis(values, plt_range)
@@ -489,33 +475,10 @@ def transform_2D_propagation(
# if params.full_field and plt_range.unit.type == "TIME": # if params.full_field and plt_range.unit.type == "TIME":
# values = envelope_2d(x_axis, values) # values = envelope_2d(x_axis, values)
x_axis, values = uniform_axis(x_axis, values, plt_range, conserved_quantity) x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None, conserved_quantity) y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
values = apply_log(values, log) values = apply_log(values, log)
return x_axis[::skip], y_axis, values[:, ::skip] return x_axis, y_axis, values
def uniform_2d(
old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray
) -> np.ndarray:
"""
interpolates a 2d array according to the provides old and new axis
Parameters
----------
old_x : np.ndarray, shape (n,)
old_y : np.ndarray, shape (m,)
new_x : np.ndarray, shape (N,)
new_y : np.ndarray, shape (M,)
values : np.ndarray, shape (m, n)
Returns
-------
np.ndarray, shape (M, N)
"""
values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x)
values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y)
return values
def get_x_axis(plt_range, x_axis, params) -> np.ndarray: def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
@@ -961,7 +924,6 @@ def uniform_axis(
axis: np.ndarray, axis: np.ndarray,
values: np.ndarray, values: np.ndarray,
new_axis_spec: Union[PlotRange, RangeType, str], new_axis_spec: Union[PlotRange, RangeType, str],
conserved_quantity: bool = True,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
""" """
given some values(axis), creates a new uniformly spaced axis and interpolates given some values(axis), creates a new uniformly spaced axis and interpolates
@@ -1006,8 +968,6 @@ def uniform_axis(
new_axis = tmp_axis new_axis = tmp_axis
values = values[:, ind] values = values[:, ind]
else: else:
if plt_range.unit.type == "WL" and conserved_quantity:
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis) values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
return new_axis, values.squeeze() return new_axis, values.squeeze()
@@ -1192,3 +1152,25 @@ def annotate_fwhm(
x, y = (left, v_max / 2) x, y = (left, v_max / 2)
trans = offset_copy(ax.transData, ax.get_figure(), offset, 0, "points") trans = offset_copy(ax.transData, ax.get_figure(), offset, 0, "points")
ax.text(x, y, arrow_label, transform=trans, **text_kwargs) ax.text(x, y, arrow_label, transform=trans, **text_kwargs)
def summary_plot(
specs: Spectrum,
z: Sequence[float] | None = None,
wl_range: PlotRange | None = None,
t_range: PlotRange | None = None,
db_min: float = -50.0,
):
wl_int = specs.wl_int
time_int = specs.time_int
if wl_range is None:
imin, imax = math.span_above(wl_int, wl_int.max() * 1e-6)
wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm")
if t_range is None:
imin, imax = math.span_above(time_int, time_int.max() * 1e-6)
t_range = PlotRange(specs.t[imin] * 1e15, specs.t[imax] * 1e15, "fs")
fig, (left, right) = plt.subplots(1, 2)
transform_2D_propagation(wl_int, wl_range, specs.w, z)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from pathlib import Path from pathlib import Path
from typing import Callable, Generic, TypeVar, overload from typing import Callable, Generic, TypeVar, overload
@@ -46,6 +47,7 @@ class Spectrum(np.ndarray):
# We first cast to be our class type # We first cast to be our class type
obj = np.asarray(input_array).view(cls) obj = np.asarray(input_array).view(cls)
# add the new attribute to the created instance # add the new attribute to the created instance
obj.order = np.argsort(w)
obj.w = w obj.w = w
if t is not None: if t is not None:
obj.t = t obj.t = t
@@ -70,44 +72,36 @@ class Spectrum(np.ndarray):
self.w = getattr(obj, "w", None) self.w = getattr(obj, "w", None)
self.t = getattr(obj, "t", None) self.t = getattr(obj, "t", None)
self.l = getattr(obj, "l", None) self.l = getattr(obj, "l", None)
self.order = getattr(obj, "order", None)
self.ifft = getattr(obj, "ifft", None) self.ifft = getattr(obj, "ifft", None)
def __getitem__(self, key) -> "Spectrum": def __getitem__(self, key) -> "Spectrum":
return super().__getitem__(key) return super().__getitem__(key)
@property
def wl_disp(self):
return self.l[self.order][::-1]
@property
def w_disp(self):
return self.w[self.order]
@property @property
def wl_int(self): def wl_int(self):
return units.to_WL(math.abs2(self), self.l) return units.to_WL(math.abs2(self), self.l)[self.order][::-1]
@property @property
def freq_int(self): def freq_int(self):
return math.abs2(self) return math.abs2(self.freq_amp)
@property @property
def afreq_int(self): def afreq_int(self):
return math.abs2(self) return math.abs2(self.freq_amp)
@property @property
def time_int(self): def time_int(self):
return math.abs2(self.ifft(self)) return math.abs2(self.ifft(self))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self.wl_amp,
FREQ=self.freq_amp,
AFREQ=self.afreq_amp,
TIME=self.time_amp,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
@property @property
def wl_amp(self): def wl_amp(self):
return ( return (
@@ -119,15 +113,15 @@ class Spectrum(np.ndarray):
) )
* self * self
/ np.abs(self) / np.abs(self)
) )[self.order][::-1]
@property @property
def freq_amp(self): def freq_amp(self):
return self return self[self.order]
@property @property
def afreq_amp(self): def afreq_amp(self):
return self return self[self.order]
@property @property
def time_amp(self): def time_amp(self):
@@ -180,6 +174,8 @@ class Propagation(Generic[ParamsOrNone]):
self.io = io_handler self.io = io_handler
self._current_index = len(self.io) self._current_index = len(self.io)
self.parameters = params self.parameters = params
if self.parameters is not None:
self.z_positions = self.parameters.compute("z_targets")
def __len__(self) -> int: def __len__(self) -> int:
return self._current_index return self._current_index
@@ -198,8 +194,9 @@ class Propagation(Generic[ParamsOrNone]):
if isinstance(key, (float, np.floating)): if isinstance(key, (float, np.floating)):
if self.parameters is None: if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set") raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.parameters.compute("z_targets"), key) key = math.argclosest(self.z_positions, key)
elif key < 0: elif key < 0:
self._warn_negative_index(key)
key = len(self) + key key = len(self) + key
array = self.io.load_spectrum(key) array = self.io.load_spectrum(key)
if self.parameters is not None: if self.parameters is not None:
@@ -221,21 +218,31 @@ class Propagation(Generic[ParamsOrNone]):
... ...
def _load_slice(self, key: slice) -> Spectrum: def _load_slice(self, key: slice) -> Spectrum:
self._warn_negative_index(key.start)
self._warn_negative_index(key.stop)
_iter = range(len(self))[key] _iter = range(len(self))[key]
if self.parameters is not None: # if self.parameters is not None:
out = Spectrum.from_params( # out = Spectrum.from_params(
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters # np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
) # )
for i in _iter: # for i in _iter:
out[i] = self.io.load_spectrum(i) # out[i] = self.io.load_spectrum(i)
else: # else:
out = np.array([self.io.load_spectrum(i) for i in _iter]) out = np.array([self.io.load_spectrum(i) for i in _iter])
if self.parameters is not None:
out = Spectrum.from_params(out, self.parameters)
return out return out
def append(self, spectrum: np.ndarray): def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, np.asarray(spectrum)) self.io.save_spectrum(self._current_index, np.asarray(spectrum))
self._current_index += 1 self._current_index += 1
def _warn_negative_index(self, index: int | None):
if (index is not None and index >= 0) or self.parameters is None:
return
if self._current_index < len(self.z_positions):
warnings.warn(f"attempting to access index {index} on an incomplete propagation obj")
def load_all(self) -> Spectrum: def load_all(self) -> Spectrum:
return self._load_slice(slice(None)) return self._load_slice(slice(None))