stable version before big units breaking change

This commit is contained in:
Benoît Sierro
2023-09-25 10:31:51 +02:00
parent 9381ebe3f3
commit 2f9b5005a5
4 changed files with 84 additions and 91 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "scgenerator"
version = "0.3.11"
version = "0.3.12"
description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
@@ -40,4 +40,3 @@ line-length = 100
[tool.isort]
profile = "black"
skip = ["__init__.py"]

View File

@@ -38,6 +38,12 @@ def total_extent(*vec: np.ndarray) -> float:
return right - left
def span_above(arr: np.ndarray, threshold: float) -> tuple[int, int]:
"""returns the first and last index where the array is above the specified threshold"""
ind = np.where(arr >= threshold)[0]
return np.min(ind), np.max(ind)
def argclosest(array: np.ndarray, target: float | int | Sequence[float | int]) -> int | np.ndarray:
"""
returns the index/indices corresponding to the closest matches of target in array
@@ -250,7 +256,6 @@ def irfftfreq(freq: np.ndarray, retstep: bool = False):
def iwspace(w: np.ndarray, retstep: bool = False):
"""invserse of wspace: recovers the (symmetric) time array corresponsding to `w`"""
df = (w[1] - w[0]) * 0.5 / np.pi
print(df)
nt = len(w)
period = 1 / df
dt = period / nt

View File

@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Sequence, Union
import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
@@ -18,6 +18,7 @@ from scgenerator.math import abs2, linear_interp_2d, span
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units
from scgenerator.physics.units import PlotRange, sort_axis
from scgenerator.spectra import Propagation, Spectrum
RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()
@@ -435,9 +436,6 @@ def transform_2D_propagation(
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
log: Union[int, float, bool, str] = "1D",
skip: int = 1,
params: Parameters = None,
conserved_quantity: bool = True,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
transforms raws values into plottable values
@@ -454,14 +452,6 @@ def transform_2D_propagation(
corresponding y values in SI units
log : Union[int, float, bool, str], optional
see apply_log, by default "1D"
params : Parameters, optional
parameters of the simulation, used to automatically fill in x and y axes
skip : int, optional
take one every skip values, by default 1 (take all values)
conserved_quantity : bool, optional
if the target axis is wavelength, the transformation is not linear has to be corrected.
This is necessary when values is interpreted as averaged over a bin (e.g. amplitude),
but shouldn't be used when it's not the case (e.g. coherence). by default True
Returns
-------
@@ -477,10 +467,6 @@ def transform_2D_propagation(
ValueError
incorrect shape
"""
x_axis = get_x_axis(plt_range, x_axis, params)
if y_axis is None and params is not None:
y_axis = params.z_targets
if values.ndim != 2:
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
is_complex, plt_range = prep_plot_axis(values, plt_range)
@@ -489,33 +475,10 @@ def transform_2D_propagation(
# if params.full_field and plt_range.unit.type == "TIME":
# values = envelope_2d(x_axis, values)
x_axis, values = uniform_axis(x_axis, values, plt_range, conserved_quantity)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None, conserved_quantity)
x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
values = apply_log(values, log)
return x_axis[::skip], y_axis, values[:, ::skip]
def uniform_2d(
old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray
) -> np.ndarray:
"""
interpolates a 2d array according to the provides old and new axis
Parameters
----------
old_x : np.ndarray, shape (n,)
old_y : np.ndarray, shape (m,)
new_x : np.ndarray, shape (N,)
new_y : np.ndarray, shape (M,)
values : np.ndarray, shape (m, n)
Returns
-------
np.ndarray, shape (M, N)
"""
values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x)
values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y)
return values
return x_axis, y_axis, values
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
@@ -961,7 +924,6 @@ def uniform_axis(
axis: np.ndarray,
values: np.ndarray,
new_axis_spec: Union[PlotRange, RangeType, str],
conserved_quantity: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""
given some values(axis), creates a new uniformly spaced axis and interpolates
@@ -1006,8 +968,6 @@ def uniform_axis(
new_axis = tmp_axis
values = values[:, ind]
else:
if plt_range.unit.type == "WL" and conserved_quantity:
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
return new_axis, values.squeeze()
@@ -1192,3 +1152,25 @@ def annotate_fwhm(
x, y = (left, v_max / 2)
trans = offset_copy(ax.transData, ax.get_figure(), offset, 0, "points")
ax.text(x, y, arrow_label, transform=trans, **text_kwargs)
def summary_plot(
specs: Spectrum,
z: Sequence[float] | None = None,
wl_range: PlotRange | None = None,
t_range: PlotRange | None = None,
db_min: float = -50.0,
):
wl_int = specs.wl_int
time_int = specs.time_int
if wl_range is None:
imin, imax = math.span_above(wl_int, wl_int.max() * 1e-6)
wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm")
if t_range is None:
imin, imax = math.span_above(time_int, time_int.max() * 1e-6)
t_range = PlotRange(specs.t[imin] * 1e15, specs.t[imax] * 1e15, "fs")
fig, (left, right) = plt.subplots(1, 2)
transform_2D_propagation(wl_int, wl_range, specs.w, z)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Callable, Generic, TypeVar, overload
@@ -46,6 +47,7 @@ class Spectrum(np.ndarray):
# We first cast to be our class type
obj = np.asarray(input_array).view(cls)
# add the new attribute to the created instance
obj.order = np.argsort(w)
obj.w = w
if t is not None:
obj.t = t
@@ -70,44 +72,36 @@ class Spectrum(np.ndarray):
self.w = getattr(obj, "w", None)
self.t = getattr(obj, "t", None)
self.l = getattr(obj, "l", None)
self.order = getattr(obj, "order", None)
self.ifft = getattr(obj, "ifft", None)
def __getitem__(self, key) -> "Spectrum":
return super().__getitem__(key)
@property
def wl_disp(self):
return self.l[self.order][::-1]
@property
def w_disp(self):
return self.w[self.order]
@property
def wl_int(self):
return units.to_WL(math.abs2(self), self.l)
return units.to_WL(math.abs2(self), self.l)[self.order][::-1]
@property
def freq_int(self):
return math.abs2(self)
return math.abs2(self.freq_amp)
@property
def afreq_int(self):
return math.abs2(self)
return math.abs2(self.freq_amp)
@property
def time_int(self):
return math.abs2(self.ifft(self))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self.wl_amp,
FREQ=self.freq_amp,
AFREQ=self.afreq_amp,
TIME=self.time_amp,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
@property
def wl_amp(self):
return (
@@ -119,15 +113,15 @@ class Spectrum(np.ndarray):
)
* self
/ np.abs(self)
)
)[self.order][::-1]
@property
def freq_amp(self):
return self
return self[self.order]
@property
def afreq_amp(self):
return self
return self[self.order]
@property
def time_amp(self):
@@ -180,6 +174,8 @@ class Propagation(Generic[ParamsOrNone]):
self.io = io_handler
self._current_index = len(self.io)
self.parameters = params
if self.parameters is not None:
self.z_positions = self.parameters.compute("z_targets")
def __len__(self) -> int:
return self._current_index
@@ -198,8 +194,9 @@ class Propagation(Generic[ParamsOrNone]):
if isinstance(key, (float, np.floating)):
if self.parameters is None:
raise ValueError(f"cannot accept float key {key} when parameters is not set")
key = math.argclosest(self.parameters.compute("z_targets"), key)
key = math.argclosest(self.z_positions, key)
elif key < 0:
self._warn_negative_index(key)
key = len(self) + key
array = self.io.load_spectrum(key)
if self.parameters is not None:
@@ -221,21 +218,31 @@ class Propagation(Generic[ParamsOrNone]):
...
def _load_slice(self, key: slice) -> Spectrum:
self._warn_negative_index(key.start)
self._warn_negative_index(key.stop)
_iter = range(len(self))[key]
if self.parameters is not None:
out = Spectrum.from_params(
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
)
for i in _iter:
out[i] = self.io.load_spectrum(i)
else:
# if self.parameters is not None:
# out = Spectrum.from_params(
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
# )
# for i in _iter:
# out[i] = self.io.load_spectrum(i)
# else:
out = np.array([self.io.load_spectrum(i) for i in _iter])
if self.parameters is not None:
out = Spectrum.from_params(out, self.parameters)
return out
def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
self._current_index += 1
def _warn_negative_index(self, index: int | None):
if (index is not None and index >= 0) or self.parameters is None:
return
if self._current_index < len(self.z_positions):
warnings.warn(f"attempting to access index {index} on an incomplete propagation obj")
def load_all(self) -> Spectrum:
return self._load_slice(slice(None))