stable version before big units breaking change
This commit is contained in:
@@ -4,27 +4,27 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "scgenerator"
|
||||
version = "0.3.11"
|
||||
version = "0.3.12"
|
||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||
license = {file = "LICENSE"}
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT",
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT",
|
||||
"Programming Language :: Python :: 3",
|
||||
]
|
||||
requires-python = ">=3.10"
|
||||
keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"]
|
||||
dependencies = [
|
||||
"numpy",
|
||||
"scipy",
|
||||
"matplotlib",
|
||||
"tomli",
|
||||
"tomli_w",
|
||||
"numba",
|
||||
"tqdm",
|
||||
"pydantic",
|
||||
"pydantic-settings",
|
||||
"numpy",
|
||||
"scipy",
|
||||
"matplotlib",
|
||||
"tomli",
|
||||
"tomli_w",
|
||||
"numba",
|
||||
"tqdm",
|
||||
"pydantic",
|
||||
"pydantic-settings",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
@@ -40,4 +40,3 @@ line-length = 100
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
skip = ["__init__.py"]
|
||||
|
||||
|
||||
@@ -38,6 +38,12 @@ def total_extent(*vec: np.ndarray) -> float:
|
||||
return right - left
|
||||
|
||||
|
||||
def span_above(arr: np.ndarray, threshold: float) -> tuple[int, int]:
|
||||
"""returns the first and last index where the array is above the specified threshold"""
|
||||
ind = np.where(arr >= threshold)[0]
|
||||
return np.min(ind), np.max(ind)
|
||||
|
||||
|
||||
def argclosest(array: np.ndarray, target: float | int | Sequence[float | int]) -> int | np.ndarray:
|
||||
"""
|
||||
returns the index/indices corresponding to the closest matches of target in array
|
||||
@@ -250,7 +256,6 @@ def irfftfreq(freq: np.ndarray, retstep: bool = False):
|
||||
def iwspace(w: np.ndarray, retstep: bool = False):
|
||||
"""invserse of wspace: recovers the (symmetric) time array corresponsding to `w`"""
|
||||
df = (w[1] - w[0]) * 0.5 / np.pi
|
||||
print(df)
|
||||
nt = len(w)
|
||||
period = 1 / df
|
||||
dt = period / nt
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
from typing import Any, Callable, Literal, Optional, Sequence, Union
|
||||
|
||||
import matplotlib.gridspec as gs
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -18,6 +18,7 @@ from scgenerator.math import abs2, linear_interp_2d, span
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.physics import pulse, units
|
||||
from scgenerator.physics.units import PlotRange, sort_axis
|
||||
from scgenerator.spectra import Propagation, Spectrum
|
||||
|
||||
RangeType = tuple[float, float, Union[str, Callable]]
|
||||
NO_LIM = object()
|
||||
@@ -435,9 +436,6 @@ def transform_2D_propagation(
|
||||
x_axis: np.ndarray = None,
|
||||
y_axis: np.ndarray = None,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
skip: int = 1,
|
||||
params: Parameters = None,
|
||||
conserved_quantity: bool = True,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
transforms raws values into plottable values
|
||||
@@ -454,14 +452,6 @@ def transform_2D_propagation(
|
||||
corresponding y values in SI units
|
||||
log : Union[int, float, bool, str], optional
|
||||
see apply_log, by default "1D"
|
||||
params : Parameters, optional
|
||||
parameters of the simulation, used to automatically fill in x and y axes
|
||||
skip : int, optional
|
||||
take one every skip values, by default 1 (take all values)
|
||||
conserved_quantity : bool, optional
|
||||
if the target axis is wavelength, the transformation is not linear has to be corrected.
|
||||
This is necessary when values is interpreted as averaged over a bin (e.g. amplitude),
|
||||
but shouldn't be used when it's not the case (e.g. coherence). by default True
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -477,10 +467,6 @@ def transform_2D_propagation(
|
||||
ValueError
|
||||
incorrect shape
|
||||
"""
|
||||
x_axis = get_x_axis(plt_range, x_axis, params)
|
||||
if y_axis is None and params is not None:
|
||||
y_axis = params.z_targets
|
||||
|
||||
if values.ndim != 2:
|
||||
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
|
||||
is_complex, plt_range = prep_plot_axis(values, plt_range)
|
||||
@@ -489,33 +475,10 @@ def transform_2D_propagation(
|
||||
# if params.full_field and plt_range.unit.type == "TIME":
|
||||
# values = envelope_2d(x_axis, values)
|
||||
|
||||
x_axis, values = uniform_axis(x_axis, values, plt_range, conserved_quantity)
|
||||
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None, conserved_quantity)
|
||||
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
||||
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
||||
values = apply_log(values, log)
|
||||
return x_axis[::skip], y_axis, values[:, ::skip]
|
||||
|
||||
|
||||
def uniform_2d(
|
||||
old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
interpolates a 2d array according to the provides old and new axis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_x : np.ndarray, shape (n,)
|
||||
old_y : np.ndarray, shape (m,)
|
||||
new_x : np.ndarray, shape (N,)
|
||||
new_y : np.ndarray, shape (M,)
|
||||
values : np.ndarray, shape (m, n)
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape (M, N)
|
||||
"""
|
||||
values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x)
|
||||
values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y)
|
||||
return values
|
||||
return x_axis, y_axis, values
|
||||
|
||||
|
||||
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
||||
@@ -961,7 +924,6 @@ def uniform_axis(
|
||||
axis: np.ndarray,
|
||||
values: np.ndarray,
|
||||
new_axis_spec: Union[PlotRange, RangeType, str],
|
||||
conserved_quantity: bool = True,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
given some values(axis), creates a new uniformly spaced axis and interpolates
|
||||
@@ -1006,8 +968,6 @@ def uniform_axis(
|
||||
new_axis = tmp_axis
|
||||
values = values[:, ind]
|
||||
else:
|
||||
if plt_range.unit.type == "WL" and conserved_quantity:
|
||||
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
|
||||
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
||||
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
|
||||
return new_axis, values.squeeze()
|
||||
@@ -1192,3 +1152,25 @@ def annotate_fwhm(
|
||||
x, y = (left, v_max / 2)
|
||||
trans = offset_copy(ax.transData, ax.get_figure(), offset, 0, "points")
|
||||
ax.text(x, y, arrow_label, transform=trans, **text_kwargs)
|
||||
|
||||
|
||||
def summary_plot(
|
||||
specs: Spectrum,
|
||||
z: Sequence[float] | None = None,
|
||||
wl_range: PlotRange | None = None,
|
||||
t_range: PlotRange | None = None,
|
||||
db_min: float = -50.0,
|
||||
):
|
||||
wl_int = specs.wl_int
|
||||
time_int = specs.time_int
|
||||
|
||||
if wl_range is None:
|
||||
imin, imax = math.span_above(wl_int, wl_int.max() * 1e-6)
|
||||
wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm")
|
||||
|
||||
if t_range is None:
|
||||
imin, imax = math.span_above(time_int, time_int.max() * 1e-6)
|
||||
t_range = PlotRange(specs.t[imin] * 1e15, specs.t[imax] * 1e15, "fs")
|
||||
|
||||
fig, (left, right) = plt.subplots(1, 2)
|
||||
transform_2D_propagation(wl_int, wl_range, specs.w, z)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generic, TypeVar, overload
|
||||
|
||||
@@ -46,6 +47,7 @@ class Spectrum(np.ndarray):
|
||||
# We first cast to be our class type
|
||||
obj = np.asarray(input_array).view(cls)
|
||||
# add the new attribute to the created instance
|
||||
obj.order = np.argsort(w)
|
||||
obj.w = w
|
||||
if t is not None:
|
||||
obj.t = t
|
||||
@@ -70,44 +72,36 @@ class Spectrum(np.ndarray):
|
||||
self.w = getattr(obj, "w", None)
|
||||
self.t = getattr(obj, "t", None)
|
||||
self.l = getattr(obj, "l", None)
|
||||
self.order = getattr(obj, "order", None)
|
||||
self.ifft = getattr(obj, "ifft", None)
|
||||
|
||||
def __getitem__(self, key) -> "Spectrum":
|
||||
return super().__getitem__(key)
|
||||
|
||||
@property
|
||||
def wl_disp(self):
|
||||
return self.l[self.order][::-1]
|
||||
|
||||
@property
|
||||
def w_disp(self):
|
||||
return self.w[self.order]
|
||||
|
||||
@property
|
||||
def wl_int(self):
|
||||
return units.to_WL(math.abs2(self), self.l)
|
||||
return units.to_WL(math.abs2(self), self.l)[self.order][::-1]
|
||||
|
||||
@property
|
||||
def freq_int(self):
|
||||
return math.abs2(self)
|
||||
return math.abs2(self.freq_amp)
|
||||
|
||||
@property
|
||||
def afreq_int(self):
|
||||
return math.abs2(self)
|
||||
return math.abs2(self.freq_amp)
|
||||
|
||||
@property
|
||||
def time_int(self):
|
||||
return math.abs2(self.ifft(self))
|
||||
|
||||
def amplitude(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
WL=self.wl_amp,
|
||||
FREQ=self.freq_amp,
|
||||
AFREQ=self.afreq_amp,
|
||||
TIME=self.time_amp,
|
||||
)[unit.type]
|
||||
|
||||
for spec in self:
|
||||
yield x_axis[order], func(spec)[:, order]
|
||||
|
||||
@property
|
||||
def wl_amp(self):
|
||||
return (
|
||||
@@ -119,15 +113,15 @@ class Spectrum(np.ndarray):
|
||||
)
|
||||
* self
|
||||
/ np.abs(self)
|
||||
)
|
||||
)[self.order][::-1]
|
||||
|
||||
@property
|
||||
def freq_amp(self):
|
||||
return self
|
||||
return self[self.order]
|
||||
|
||||
@property
|
||||
def afreq_amp(self):
|
||||
return self
|
||||
return self[self.order]
|
||||
|
||||
@property
|
||||
def time_amp(self):
|
||||
@@ -180,6 +174,8 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
self.io = io_handler
|
||||
self._current_index = len(self.io)
|
||||
self.parameters = params
|
||||
if self.parameters is not None:
|
||||
self.z_positions = self.parameters.compute("z_targets")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._current_index
|
||||
@@ -198,8 +194,9 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
if isinstance(key, (float, np.floating)):
|
||||
if self.parameters is None:
|
||||
raise ValueError(f"cannot accept float key {key} when parameters is not set")
|
||||
key = math.argclosest(self.parameters.compute("z_targets"), key)
|
||||
key = math.argclosest(self.z_positions, key)
|
||||
elif key < 0:
|
||||
self._warn_negative_index(key)
|
||||
key = len(self) + key
|
||||
array = self.io.load_spectrum(key)
|
||||
if self.parameters is not None:
|
||||
@@ -221,21 +218,31 @@ class Propagation(Generic[ParamsOrNone]):
|
||||
...
|
||||
|
||||
def _load_slice(self, key: slice) -> Spectrum:
|
||||
self._warn_negative_index(key.start)
|
||||
self._warn_negative_index(key.stop)
|
||||
_iter = range(len(self))[key]
|
||||
# if self.parameters is not None:
|
||||
# out = Spectrum.from_params(
|
||||
# np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
||||
# )
|
||||
# for i in _iter:
|
||||
# out[i] = self.io.load_spectrum(i)
|
||||
# else:
|
||||
out = np.array([self.io.load_spectrum(i) for i in _iter])
|
||||
if self.parameters is not None:
|
||||
out = Spectrum.from_params(
|
||||
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
|
||||
)
|
||||
for i in _iter:
|
||||
out[i] = self.io.load_spectrum(i)
|
||||
else:
|
||||
out = np.array([self.io.load_spectrum(i) for i in _iter])
|
||||
out = Spectrum.from_params(out, self.parameters)
|
||||
return out
|
||||
|
||||
def append(self, spectrum: np.ndarray):
|
||||
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
|
||||
self._current_index += 1
|
||||
|
||||
def _warn_negative_index(self, index: int | None):
|
||||
if (index is not None and index >= 0) or self.parameters is None:
|
||||
return
|
||||
if self._current_index < len(self.z_positions):
|
||||
warnings.warn(f"attempting to access index {index} on an incomplete propagation obj")
|
||||
|
||||
def load_all(self) -> Spectrum:
|
||||
return self._load_slice(slice(None))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user