Corrected Spectrograms

This commit is contained in:
Benoît Sierro
2022-10-20 11:12:18 +02:00
parent 243f41feec
commit 256ff1d36e
6 changed files with 200 additions and 74 deletions

View File

@@ -359,6 +359,32 @@ def _polynom_extrapolation_in_place(y: np.ndarray, left_ind: int, right_ind: int
return y
@numba.jit(nopython=True)
def linear_interp_2d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
new_vals = np.zeros((len(old_y), len(new_x)))
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
equal = new_x == old_x[0]
inds = np.searchsorted(old_x, new_x[interpolable])
for i, val in enumerate(old_y):
new_vals[i][interpolable] = val[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
val[inds] - val[inds - 1]
) / (old_x[inds] - old_x[inds - 1])
new_vals[i][equal] = val[0]
return new_vals
@numba.jit(nopython=True)
def linear_interp_1d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
new_vals = np.zeros(len(new_x))
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
inds = np.searchsorted(old_x, new_x[interpolable])
new_vals[interpolable] = old_y[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
old_y[inds] - old_y[inds - 1]
) / (old_x[inds] - old_x[inds - 1])
new_vals[new_x == old_x[0]] = old_y[0]
return new_vals
def envelope_ind(
signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False
) -> tuple[np.ndarray, np.ndarray]:

View File

@@ -374,7 +374,7 @@ class Parameters:
t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15))
dt: float = Parameter(in_range_excl(0, 10e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(non_negative(float, int), default=0)
interpolation_range: tuple[float, float] = Parameter(

View File

@@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
from pathlib import Path
from typing import Literal, Tuple, TypeVar
import numba
import matplotlib.pyplot as plt
import numpy as np
from numpy import pi
@@ -570,7 +571,9 @@ def ideal_compressed_pulse(spectra):
return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0)))))
def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False):
def spectrogram(
time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False, w_ind: np.ndarray = None
):
"""
returns the spectorgram of the field given in values
@@ -594,14 +597,62 @@ def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift
delays : 1D array of size t_res
new time axis
"""
t_lim = t_win / 2
delays = np.linspace(-t_lim, t_lim, t_res)
spec = np.zeros((t_res, len(time)))
if isinstance(t_win, tuple):
left, right = t_win
else:
t_lim = t_win / 2
left, right = -t_lim, t_lim
delays = np.linspace(left, right, t_res)
spec = np.zeros((t_res, len(time) if w_ind is None else len(w_ind)))
for i, delay in enumerate(delays):
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
spec[i] = math.abs2(fft(masked))
if shift:
spec[i] = fftshift(spec[i])
if shift:
spec = fftshift(spec, axes=1)
return spec, delays
def spectrogram_interp(
time: np.ndarray,
delays: np.ndarray,
values: np.ndarray,
old_w: np.ndarray,
w_ind: np.ndarray,
new_w: np.ndarray,
gate_width=200e-15,
):
"""
returns the spectorgram of the field already interpolated along the frequency axis
Parameters
----------
time : 1D array-like
time in the co-moving frame of reference
values : 1D array-like
field array that matches the time array
t_res : int, optional
how many "bins" the time array is subdivided into. Default : 256
t_win : float, optional
total time window (=length of time) over which the spectrogram is computed. Default : 24e-12
gate_width : float, optional
width of the gaussian gate function (=sqrt(2 log(2)) * FWHM). Default : 200e-15
Returns
----------
spec : 2D array
real 2D spectrogram
delays : 1D array of size t_res
new time axis
"""
spec = np.zeros((len(delays), len(new_w)))
for i, delay in enumerate(delays):
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
spec[i] = math.linear_interp_1d(old_w, math.abs2(fft(masked)[w_ind]), new_w)
return spec, delays

View File

@@ -180,6 +180,7 @@ class RK4IP:
state, self.params.linear_operator, self.params.nonlinear_operator
)
with warnings.catch_warnings():
# catch overflows as errors
warnings.filterwarnings("error", category=RuntimeWarning)
for state in integrator:

View File

@@ -47,6 +47,7 @@ class UnitMap(dict):
def chained_function(x: _T) -> _T:
return c2.inv(c1(x))
chained_function.name = f"{name_1}_to_{name_2}"
chained_function.__name__ = f"{name_1}_to_{name_2}"
chained_function.__doc__ = f"converts x from {name_1} to {name_2}"
@@ -124,17 +125,17 @@ class unit:
return func
@unit("WL", r"Wavelength $\lambda$ (m)")
@unit("WL", r"Wavelength λ (m)")
def m(l: _T) -> _T:
return 2 * pi * c / l
@unit("WL", r"Wavelength $\lambda$ (nm)")
@unit("WL", r"Wavelength λ (nm)")
def nm(l: _T) -> _T:
return 2 * pi * c / (l * 1e-9)
@unit("WL", r"Wavelength $\lambda$ (μm)")
@unit("WL", r"Wavelength λ (μm)")
def um(l: _T) -> _T:
return 2 * pi * c / (l * 1e-6)
@@ -418,6 +419,10 @@ class PlotRange(tuple):
conserved_quantity: bool = property(itemgetter(3))
__slots__ = []
@property
def must_correct_wl(self) -> bool:
return self.unit.type == "WL" and self.conserved_quantity
def __new__(cls, left, right, unit, conserved_quantity=True):
return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))
@@ -461,8 +466,8 @@ def sort_axis(
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
w, cw, _ = sort_axis(w, rw)
t, ct, _ = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]

View File

@@ -4,6 +4,7 @@ from typing import Any, Callable, Literal, Optional, Union
import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
import numba
import numpy as np
from matplotlib.colors import ListedColormap
from scipy.interpolate import UnivariateSpline
@@ -12,7 +13,7 @@ from scipy.interpolate.interpolate import interp1d
from . import math
from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults
from .math import abs2, span
from .math import abs2, span, linear_interp_2d
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange, sort_axis
@@ -256,8 +257,10 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
def propagation_plot(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,
ax: plt.Axes,
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
params: Parameters = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
log: Union[int, float, bool, str] = "1D",
renormalize: bool = False,
vmin: float = None,
@@ -295,10 +298,12 @@ def propagation_plot(
Axes obj on which to draw, by default None
"""
x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log, skip)
if renormalize and log is False:
x_axis, y_axis, values = transform_2D_propagation(
values, plt_range, x_axis, y_axis, log, skip, params
)
if renormalize and not log:
values = values / values.max()
if log is not False:
if log:
vmax = defaults["vmax"] if vmax is None else vmax
vmin = defaults["vmin"] if vmin is None else vmin
plot_2D(
@@ -320,7 +325,7 @@ def plot_2D(
values: np.ndarray,
x_axis: np.ndarray,
y_axis: np.ndarray,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]],
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
x_label: str = None,
y_label: str = None,
vmin: float = None,
@@ -367,6 +372,8 @@ def plot_2D(
cbar_ax = None
if isinstance(ax, tuple) and len(ax) > 1:
ax, cbar_ax = ax[0], ax[1]
elif ax is None:
ax = plt.gca()
fig = ax.get_figure()
@@ -414,10 +421,11 @@ def plot_2D(
def transform_2D_propagation(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
log: Union[int, float, bool, str] = "1D",
skip: int = 1,
y_axis=None,
params: Parameters = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""transforms raws values into plottable values
@@ -427,7 +435,11 @@ def transform_2D_propagation(
values to transform
plt_range : Union[PlotRange, RangeType]
range
params : Parameters
x_axis : np.ndarray
corresponding x values in SI units
y_axis : np.ndarray
corresponding y values in SI units
params : Parameters, optional
parameters of the simulation
log : Union[int, float, bool, str], optional
see apply_log, by default "1D"
@@ -448,16 +460,17 @@ def transform_2D_propagation(
ValueError
incorrect shape
"""
x_axis = get_x_axis(plt_range, x_axis, params)
if y_axis is None and params is not None:
y_axis = params.z_targets
if values.ndim != 2:
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
is_complex, plt_range = prep_plot_axis(values, plt_range)
if is_complex or any(values.ravel() < 0):
values = abs2(values)
# if params.full_field and plt_range.unit.type == "TIME":
# values = envelope_2d(x_axis, values)
if y_axis is None:
y_axis = params.z_targets
x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
@@ -465,6 +478,17 @@ def transform_2D_propagation(
return x_axis[::skip], y_axis, values[:, ::skip]
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
if x_axis is None and params is not None:
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
x_axis = params.w.copy()
else:
x_axis = params.t.copy()
if x_axis is None:
raise ValueError("No x axis specified")
return x_axis
def mean_values_plot(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
@@ -538,10 +562,11 @@ def transform_mean_values(
np.ndarray, shape (m, n)
all the values
"""
AAA
if values.ndim != 2:
print(f"Shape was {values.shape}. Can only plot 2D arrays")
return
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
is_complex, plt_range = prep_plot_axis(values, plt_range, params)
if is_complex:
values = abs2(values)
new_axis, ind, ext = sort_axis(x_axis, plt_range)
@@ -636,8 +661,9 @@ def plot_mean(
def single_position_plot(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,
ax: plt.Axes,
x_axis: np.ndarray = None,
ax: plt.Axes = None,
params: Parameters = None,
log: Union[str, int, float, bool] = False,
vmin: float = None,
vmax: float = None,
@@ -647,8 +673,7 @@ def single_position_plot(
y_label: str = None,
**line_kwargs,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
x_axis, values = transform_1D_values(values, plt_range, params, log, spacing)
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
if renormalize:
values = values / values.max()
@@ -664,7 +689,7 @@ def single_position_plot(
def plot_1D(
values: np.ndarray,
x_axis: np.ndarray,
ax: plt.Axes,
ax: Optional[plt.Axes],
x_label: str = None,
y_label: str = None,
vmin: float = None,
@@ -693,11 +718,18 @@ def plot_1D(
transpose : bool, optional
rotate the plot 90° counterclockwise, by default False
"""
if transpose:
if ax is None:
ax = plt.gca()
if transpose == -1:
(line,) = ax.plot(values, x_axis, **line_kwargs)
ax.set_xlim(vmax, vmin)
ax.set_xlabel(y_label)
ax.set_ylabel(x_label)
elif transpose == 1:
(line,) = ax.plot(values, x_axis, **line_kwargs)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin)
ax.set_xlim(vmin, vmax)
ax.set_xlabel(y_label)
ax.set_ylabel(x_label)
else:
@@ -711,7 +743,8 @@ def plot_1D(
def transform_1D_values(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,
x_axis: np.ndarray = None,
params: Parameters = None,
log: Union[int, float, bool] = False,
spacing: Union[int, float] = 1,
) -> tuple[np.ndarray, np.ndarray]:
@@ -737,9 +770,10 @@ def transform_1D_values(
tuple[np.ndarray, np.ndarray]
x axis and values
"""
x_axis = get_x_axis(plt_range, x_axis, params)
if len(values.shape) != 1:
raise ValueError("Can only plot 1D values")
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
is_complex, plt_range = prep_plot_axis(values, plt_range)
if is_complex:
values = abs2(values)
new_axis, ind, ext = sort_axis(x_axis, plt_range)
@@ -763,17 +797,20 @@ def transform_1D_values(
def plot_spectrogram(
values: np.ndarray,
x_range: RangeType,
y_range: RangeType,
params: Parameters,
t_res: int = None,
gate_width: float = None,
x_range: PlotRange,
y_range: PlotRange,
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
t_res: int = 512,
w_res: int = 512,
gate_width: float = 200e-15,
params: Parameters = None,
log: bool = "2D",
vmin: float = None,
vmax: float = None,
vmin: float = -50,
vmax: float = 0,
cbar_label: str = "normalized intensity (dB)",
cmap: str = None,
ax: plt.Axes = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
):
"""Plots a spectrogram given a complex field in the time domain
Parameters
@@ -781,7 +818,7 @@ def plot_spectrogram(
values : 2D array
axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl
example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber
x_range, y_range : tupple (min, max, units)
x_range, y_range : PlotRange
one of them must be time, the other one must be wl/freq
min, max : int or float
minimum and maximum values given in the desired units
@@ -798,10 +835,6 @@ def plot_spectrogram(
max value of the colorbar
cbar_label : str or None
label of the colorbar. Will not draw colorbar if None
file_type : str, optional
usually pdf or png
plt_name : str, optional
special name to give to the plot. A name is automatically assigned anyway
cmap : str, optional
colormap to be used in matplotlib.pyplot.imshow
ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional
@@ -812,35 +845,49 @@ def plot_spectrogram(
if values.ndim != 1:
print("plot_spectrogram can only plot 1D arrays")
return
x_range: PlotRange
y_range: PlotRange
_, x_axis, x_range = prep_plot_axis(values, x_range, params)
_, y_axis, y_range = prep_plot_axis(values, y_range, params)
x_axis = get_x_axis(x_range, x_axis, params)
y_axis = get_x_axis(y_range, y_axis, params)
_, x_range = prep_plot_axis(values, x_range)
_, y_range = prep_plot_axis(values, y_range)
if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"):
print("exactly one range must be a time range")
return
# 0 axis means x-axis -> determine final orientation of spectrogram
time_axis = 0 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1
if time_axis == 0:
time_axis = 1 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1
if time_axis == 1:
t_axis = x_axis
t_range = x_range
w_range = y_range
w_axis = y_axis
else:
t_axis = y_axis
t_range = y_range
w_range = x_range
w_axis = x_axis
old_w, w_ind, _ = w_range.sort_axis(w_axis)
new_w = np.linspace(w_range.left, w_range.right, w_res)
# Actually compute the spectrogram
t_win = 2 * np.max(t_range.unit(np.abs((t_range.left, t_range.right))))
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
spec, new_t = pulse.spectrogram(
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
delays = np.linspace(t_range.unit(t_range.left), t_range.unit(t_range.right), t_res)
spec, new_t = pulse.spectrogram_interp(
t_axis, delays, values, old_w, w_ind, new_w, gate_width=gate_width
)
if time_axis == 0:
x_axis = new_t
else:
y_axis = new_t
x_axis, spec = uniform_axis(x_axis, spec, x_range)
y_axis, spec.T[:] = uniform_axis(y_axis, spec.T, y_range)
new_t = t_range.unit.inv(new_t)
if w_range.must_correct_wl:
spec = np.apply_along_axis(units.to_WL, 1, spec, new_w)
if time_axis == 1:
spec = spec.T
x_axis = new_t
y_axis = new_w
else:
x_axis = new_w
y_axis = new_t
values = apply_log(spec, log)
@@ -849,8 +896,8 @@ def plot_spectrogram(
x_axis,
y_axis,
ax,
x_range.unit.label,
y_range.unit.label,
None,
None,
vmin,
vmax,
False,
@@ -907,7 +954,7 @@ def uniform_axis(
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
values = np.array([interp1d(tmp_axis, v[ind])(new_axis) for v in values])
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
return new_axis, values.squeeze()
@@ -954,16 +1001,12 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
def prep_plot_axis(
values: np.ndarray, plt_range: Union[PlotRange, RangeType], params: Parameters
) -> tuple[bool, np.ndarray, PlotRange]:
values: np.ndarray, plt_range: Union[PlotRange, RangeType]
) -> tuple[bool, PlotRange]:
is_spectrum = values.dtype == "complex"
if not isinstance(plt_range, PlotRange):
plt_range = PlotRange(*plt_range)
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = params.w.copy()
else:
x_axis = params.t.copy()
return is_spectrum, x_axis, plt_range
return is_spectrum, plt_range
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):