Corrected Spectrograms

This commit is contained in:
Benoît Sierro
2022-10-20 11:12:18 +02:00
parent 243f41feec
commit 256ff1d36e
6 changed files with 200 additions and 74 deletions

View File

@@ -359,6 +359,32 @@ def _polynom_extrapolation_in_place(y: np.ndarray, left_ind: int, right_ind: int
return y return y
@numba.jit(nopython=True)
def linear_interp_2d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
new_vals = np.zeros((len(old_y), len(new_x)))
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
equal = new_x == old_x[0]
inds = np.searchsorted(old_x, new_x[interpolable])
for i, val in enumerate(old_y):
new_vals[i][interpolable] = val[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
val[inds] - val[inds - 1]
) / (old_x[inds] - old_x[inds - 1])
new_vals[i][equal] = val[0]
return new_vals
@numba.jit(nopython=True)
def linear_interp_1d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
new_vals = np.zeros(len(new_x))
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
inds = np.searchsorted(old_x, new_x[interpolable])
new_vals[interpolable] = old_y[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
old_y[inds] - old_y[inds - 1]
) / (old_x[inds] - old_x[inds - 1])
new_vals[new_x == old_x[0]] = old_y[0]
return new_vals
def envelope_ind( def envelope_ind(
signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:

View File

@@ -374,7 +374,7 @@ class Parameters:
t_num: int = Parameter(positive(int)) t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int)) z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int)) time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15)) dt: float = Parameter(in_range_excl(0, 10e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(non_negative(float, int), default=0) step_size: float = Parameter(non_negative(float, int), default=0)
interpolation_range: tuple[float, float] = Parameter( interpolation_range: tuple[float, float] = Parameter(

View File

@@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
from pathlib import Path from pathlib import Path
from typing import Literal, Tuple, TypeVar from typing import Literal, Tuple, TypeVar
import numba
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -570,7 +571,9 @@ def ideal_compressed_pulse(spectra):
return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0))))) return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0)))))
def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False): def spectrogram(
time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False, w_ind: np.ndarray = None
):
""" """
returns the spectorgram of the field given in values returns the spectorgram of the field given in values
@@ -594,14 +597,62 @@ def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift
delays : 1D array of size t_res delays : 1D array of size t_res
new time axis new time axis
""" """
t_lim = t_win / 2 if isinstance(t_win, tuple):
delays = np.linspace(-t_lim, t_lim, t_res) left, right = t_win
spec = np.zeros((t_res, len(time))) else:
t_lim = t_win / 2
left, right = -t_lim, t_lim
delays = np.linspace(left, right, t_res)
spec = np.zeros((t_res, len(time) if w_ind is None else len(w_ind)))
for i, delay in enumerate(delays): for i, delay in enumerate(delays):
masked = values * np.exp(-(((time - delay) / gate_width) ** 2)) masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
spec[i] = math.abs2(fft(masked)) spec[i] = math.abs2(fft(masked))
if shift:
spec[i] = fftshift(spec[i]) if shift:
spec = fftshift(spec, axes=1)
return spec, delays
def spectrogram_interp(
time: np.ndarray,
delays: np.ndarray,
values: np.ndarray,
old_w: np.ndarray,
w_ind: np.ndarray,
new_w: np.ndarray,
gate_width=200e-15,
):
"""
returns the spectorgram of the field already interpolated along the frequency axis
Parameters
----------
time : 1D array-like
time in the co-moving frame of reference
values : 1D array-like
field array that matches the time array
t_res : int, optional
how many "bins" the time array is subdivided into. Default : 256
t_win : float, optional
total time window (=length of time) over which the spectrogram is computed. Default : 24e-12
gate_width : float, optional
width of the gaussian gate function (=sqrt(2 log(2)) * FWHM). Default : 200e-15
Returns
----------
spec : 2D array
real 2D spectrogram
delays : 1D array of size t_res
new time axis
"""
spec = np.zeros((len(delays), len(new_w)))
for i, delay in enumerate(delays):
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
spec[i] = math.linear_interp_1d(old_w, math.abs2(fft(masked)[w_ind]), new_w)
return spec, delays return spec, delays

View File

@@ -180,6 +180,7 @@ class RK4IP:
state, self.params.linear_operator, self.params.nonlinear_operator state, self.params.linear_operator, self.params.nonlinear_operator
) )
with warnings.catch_warnings(): with warnings.catch_warnings():
# catch overflows as errors
warnings.filterwarnings("error", category=RuntimeWarning) warnings.filterwarnings("error", category=RuntimeWarning)
for state in integrator: for state in integrator:

View File

@@ -47,6 +47,7 @@ class UnitMap(dict):
def chained_function(x: _T) -> _T: def chained_function(x: _T) -> _T:
return c2.inv(c1(x)) return c2.inv(c1(x))
chained_function.name = f"{name_1}_to_{name_2}"
chained_function.__name__ = f"{name_1}_to_{name_2}" chained_function.__name__ = f"{name_1}_to_{name_2}"
chained_function.__doc__ = f"converts x from {name_1} to {name_2}" chained_function.__doc__ = f"converts x from {name_1} to {name_2}"
@@ -124,17 +125,17 @@ class unit:
return func return func
@unit("WL", r"Wavelength $\lambda$ (m)") @unit("WL", r"Wavelength λ (m)")
def m(l: _T) -> _T: def m(l: _T) -> _T:
return 2 * pi * c / l return 2 * pi * c / l
@unit("WL", r"Wavelength $\lambda$ (nm)") @unit("WL", r"Wavelength λ (nm)")
def nm(l: _T) -> _T: def nm(l: _T) -> _T:
return 2 * pi * c / (l * 1e-9) return 2 * pi * c / (l * 1e-9)
@unit("WL", r"Wavelength $\lambda$ (μm)") @unit("WL", r"Wavelength λ (μm)")
def um(l: _T) -> _T: def um(l: _T) -> _T:
return 2 * pi * c / (l * 1e-6) return 2 * pi * c / (l * 1e-6)
@@ -418,6 +419,10 @@ class PlotRange(tuple):
conserved_quantity: bool = property(itemgetter(3)) conserved_quantity: bool = property(itemgetter(3))
__slots__ = [] __slots__ = []
@property
def must_correct_wl(self) -> bool:
return self.unit.type == "WL" and self.conserved_quantity
def __new__(cls, left, right, unit, conserved_quantity=True): def __new__(cls, left, right, unit, conserved_quantity=True):
return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity)) return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))
@@ -461,8 +466,8 @@ def sort_axis(
rw = (-4, 4, s) rw = (-4, 4, s)
rt = (-2, 6, s) rt = (-2, 6, s)
w, cw = sort_axis(w, rw) w, cw, _ = sort_axis(w, rw)
t, ct = sort_axis(t, rt) t, ct, _ = sort_axis(t, rt)
# slice y according to the given ranges # slice y according to the given ranges
y = y[ct][:, cw] y = y[ct][:, cw]

View File

@@ -4,6 +4,7 @@ from typing import Any, Callable, Literal, Optional, Union
import matplotlib.gridspec as gs import matplotlib.gridspec as gs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numba
import numpy as np import numpy as np
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from scipy.interpolate import UnivariateSpline from scipy.interpolate import UnivariateSpline
@@ -12,7 +13,7 @@ from scipy.interpolate.interpolate import interp1d
from . import math from . import math
from .const import PARAM_SEPARATOR from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults from .defaults import default_plotting as defaults
from .math import abs2, span from .math import abs2, span, linear_interp_2d
from .parameter import Parameters from .parameter import Parameters
from .physics import pulse, units from .physics import pulse, units
from .physics.units import PlotRange, sort_axis from .physics.units import PlotRange, sort_axis
@@ -256,8 +257,10 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
def propagation_plot( def propagation_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, x_axis: np.ndarray = None,
ax: plt.Axes, y_axis: np.ndarray = None,
params: Parameters = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
renormalize: bool = False, renormalize: bool = False,
vmin: float = None, vmin: float = None,
@@ -295,10 +298,12 @@ def propagation_plot(
Axes obj on which to draw, by default None Axes obj on which to draw, by default None
""" """
x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log, skip) x_axis, y_axis, values = transform_2D_propagation(
if renormalize and log is False: values, plt_range, x_axis, y_axis, log, skip, params
)
if renormalize and not log:
values = values / values.max() values = values / values.max()
if log is not False: if log:
vmax = defaults["vmax"] if vmax is None else vmax vmax = defaults["vmax"] if vmax is None else vmax
vmin = defaults["vmin"] if vmin is None else vmin vmin = defaults["vmin"] if vmin is None else vmin
plot_2D( plot_2D(
@@ -320,7 +325,7 @@ def plot_2D(
values: np.ndarray, values: np.ndarray,
x_axis: np.ndarray, x_axis: np.ndarray,
y_axis: np.ndarray, y_axis: np.ndarray,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]], ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
x_label: str = None, x_label: str = None,
y_label: str = None, y_label: str = None,
vmin: float = None, vmin: float = None,
@@ -367,6 +372,8 @@ def plot_2D(
cbar_ax = None cbar_ax = None
if isinstance(ax, tuple) and len(ax) > 1: if isinstance(ax, tuple) and len(ax) > 1:
ax, cbar_ax = ax[0], ax[1] ax, cbar_ax = ax[0], ax[1]
elif ax is None:
ax = plt.gca()
fig = ax.get_figure() fig = ax.get_figure()
@@ -414,10 +421,11 @@ def plot_2D(
def transform_2D_propagation( def transform_2D_propagation(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
skip: int = 1, skip: int = 1,
y_axis=None, params: Parameters = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""transforms raws values into plottable values """transforms raws values into plottable values
@@ -427,7 +435,11 @@ def transform_2D_propagation(
values to transform values to transform
plt_range : Union[PlotRange, RangeType] plt_range : Union[PlotRange, RangeType]
range range
params : Parameters x_axis : np.ndarray
corresponding x values in SI units
y_axis : np.ndarray
corresponding y values in SI units
params : Parameters, optional
parameters of the simulation parameters of the simulation
log : Union[int, float, bool, str], optional log : Union[int, float, bool, str], optional
see apply_log, by default "1D" see apply_log, by default "1D"
@@ -448,16 +460,17 @@ def transform_2D_propagation(
ValueError ValueError
incorrect shape incorrect shape
""" """
x_axis = get_x_axis(plt_range, x_axis, params)
if y_axis is None and params is not None:
y_axis = params.z_targets
if values.ndim != 2: if values.ndim != 2:
raise ValueError(f"shape was {values.shape}. Can only plot 2D array") raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) is_complex, plt_range = prep_plot_axis(values, plt_range)
if is_complex or any(values.ravel() < 0): if is_complex or any(values.ravel() < 0):
values = abs2(values) values = abs2(values)
# if params.full_field and plt_range.unit.type == "TIME": # if params.full_field and plt_range.unit.type == "TIME":
# values = envelope_2d(x_axis, values) # values = envelope_2d(x_axis, values)
if y_axis is None:
y_axis = params.z_targets
x_axis, values = uniform_axis(x_axis, values, plt_range) x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
@@ -465,6 +478,17 @@ def transform_2D_propagation(
return x_axis[::skip], y_axis, values[:, ::skip] return x_axis[::skip], y_axis, values[:, ::skip]
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
if x_axis is None and params is not None:
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
x_axis = params.w.copy()
else:
x_axis = params.t.copy()
if x_axis is None:
raise ValueError("No x axis specified")
return x_axis
def mean_values_plot( def mean_values_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
@@ -538,10 +562,11 @@ def transform_mean_values(
np.ndarray, shape (m, n) np.ndarray, shape (m, n)
all the values all the values
""" """
AAA
if values.ndim != 2: if values.ndim != 2:
print(f"Shape was {values.shape}. Can only plot 2D arrays") print(f"Shape was {values.shape}. Can only plot 2D arrays")
return return
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) is_complex, plt_range = prep_plot_axis(values, plt_range, params)
if is_complex: if is_complex:
values = abs2(values) values = abs2(values)
new_axis, ind, ext = sort_axis(x_axis, plt_range) new_axis, ind, ext = sort_axis(x_axis, plt_range)
@@ -636,8 +661,9 @@ def plot_mean(
def single_position_plot( def single_position_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, x_axis: np.ndarray = None,
ax: plt.Axes, ax: plt.Axes = None,
params: Parameters = None,
log: Union[str, int, float, bool] = False, log: Union[str, int, float, bool] = False,
vmin: float = None, vmin: float = None,
vmax: float = None, vmax: float = None,
@@ -647,8 +673,7 @@ def single_position_plot(
y_label: str = None, y_label: str = None,
**line_kwargs, **line_kwargs,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: ) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
x_axis, values = transform_1D_values(values, plt_range, params, log, spacing)
if renormalize: if renormalize:
values = values / values.max() values = values / values.max()
@@ -664,7 +689,7 @@ def single_position_plot(
def plot_1D( def plot_1D(
values: np.ndarray, values: np.ndarray,
x_axis: np.ndarray, x_axis: np.ndarray,
ax: plt.Axes, ax: Optional[plt.Axes],
x_label: str = None, x_label: str = None,
y_label: str = None, y_label: str = None,
vmin: float = None, vmin: float = None,
@@ -693,11 +718,18 @@ def plot_1D(
transpose : bool, optional transpose : bool, optional
rotate the plot 90° counterclockwise, by default False rotate the plot 90° counterclockwise, by default False
""" """
if transpose: if ax is None:
ax = plt.gca()
if transpose == -1:
(line,) = ax.plot(values, x_axis, **line_kwargs)
ax.set_xlim(vmax, vmin)
ax.set_xlabel(y_label)
ax.set_ylabel(x_label)
elif transpose == 1:
(line,) = ax.plot(values, x_axis, **line_kwargs) (line,) = ax.plot(values, x_axis, **line_kwargs)
ax.yaxis.tick_right() ax.yaxis.tick_right()
ax.yaxis.set_label_position("right") ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin) ax.set_xlim(vmin, vmax)
ax.set_xlabel(y_label) ax.set_xlabel(y_label)
ax.set_ylabel(x_label) ax.set_ylabel(x_label)
else: else:
@@ -711,7 +743,8 @@ def plot_1D(
def transform_1D_values( def transform_1D_values(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, x_axis: np.ndarray = None,
params: Parameters = None,
log: Union[int, float, bool] = False, log: Union[int, float, bool] = False,
spacing: Union[int, float] = 1, spacing: Union[int, float] = 1,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
@@ -737,9 +770,10 @@ def transform_1D_values(
tuple[np.ndarray, np.ndarray] tuple[np.ndarray, np.ndarray]
x axis and values x axis and values
""" """
x_axis = get_x_axis(plt_range, x_axis, params)
if len(values.shape) != 1: if len(values.shape) != 1:
raise ValueError("Can only plot 1D values") raise ValueError("Can only plot 1D values")
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) is_complex, plt_range = prep_plot_axis(values, plt_range)
if is_complex: if is_complex:
values = abs2(values) values = abs2(values)
new_axis, ind, ext = sort_axis(x_axis, plt_range) new_axis, ind, ext = sort_axis(x_axis, plt_range)
@@ -763,17 +797,20 @@ def transform_1D_values(
def plot_spectrogram( def plot_spectrogram(
values: np.ndarray, values: np.ndarray,
x_range: RangeType, x_range: PlotRange,
y_range: RangeType, y_range: PlotRange,
params: Parameters, x_axis: np.ndarray = None,
t_res: int = None, y_axis: np.ndarray = None,
gate_width: float = None, t_res: int = 512,
w_res: int = 512,
gate_width: float = 200e-15,
params: Parameters = None,
log: bool = "2D", log: bool = "2D",
vmin: float = None, vmin: float = -50,
vmax: float = None, vmax: float = 0,
cbar_label: str = "normalized intensity (dB)", cbar_label: str = "normalized intensity (dB)",
cmap: str = None, cmap: str = None,
ax: plt.Axes = None, ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
): ):
"""Plots a spectrogram given a complex field in the time domain """Plots a spectrogram given a complex field in the time domain
Parameters Parameters
@@ -781,7 +818,7 @@ def plot_spectrogram(
values : 2D array values : 2D array
axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl
example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber
x_range, y_range : tupple (min, max, units) x_range, y_range : PlotRange
one of them must be time, the other one must be wl/freq one of them must be time, the other one must be wl/freq
min, max : int or float min, max : int or float
minimum and maximum values given in the desired units minimum and maximum values given in the desired units
@@ -798,10 +835,6 @@ def plot_spectrogram(
max value of the colorbar max value of the colorbar
cbar_label : str or None cbar_label : str or None
label of the colorbar. Will not draw colorbar if None label of the colorbar. Will not draw colorbar if None
file_type : str, optional
usually pdf or png
plt_name : str, optional
special name to give to the plot. A name is automatically assigned anyway
cmap : str, optional cmap : str, optional
colormap to be used in matplotlib.pyplot.imshow colormap to be used in matplotlib.pyplot.imshow
ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional
@@ -812,35 +845,49 @@ def plot_spectrogram(
if values.ndim != 1: if values.ndim != 1:
print("plot_spectrogram can only plot 1D arrays") print("plot_spectrogram can only plot 1D arrays")
return return
x_range: PlotRange x_axis = get_x_axis(x_range, x_axis, params)
y_range: PlotRange y_axis = get_x_axis(y_range, y_axis, params)
_, x_axis, x_range = prep_plot_axis(values, x_range, params) _, x_range = prep_plot_axis(values, x_range)
_, y_axis, y_range = prep_plot_axis(values, y_range, params) _, y_range = prep_plot_axis(values, y_range)
if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"): if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"):
print("exactly one range must be a time range") print("exactly one range must be a time range")
return return
# 0 axis means x-axis -> determine final orientation of spectrogram # 0 axis means x-axis -> determine final orientation of spectrogram
time_axis = 0 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1 time_axis = 1 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1
if time_axis == 0: if time_axis == 1:
t_axis = x_axis
t_range = x_range t_range = x_range
w_range = y_range
w_axis = y_axis
else: else:
t_axis = y_axis
t_range = y_range t_range = y_range
w_range = x_range
w_axis = x_axis
old_w, w_ind, _ = w_range.sort_axis(w_axis)
new_w = np.linspace(w_range.left, w_range.right, w_res)
# Actually compute the spectrogram # Actually compute the spectrogram
t_win = 2 * np.max(t_range.unit(np.abs((t_range.left, t_range.right)))) delays = np.linspace(t_range.unit(t_range.left), t_range.unit(t_range.right), t_res)
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False) spec, new_t = pulse.spectrogram_interp(
spec, new_t = pulse.spectrogram( t_axis, delays, values, old_w, w_ind, new_w, gate_width=gate_width
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
) )
if time_axis == 0:
x_axis = new_t
else:
y_axis = new_t
x_axis, spec = uniform_axis(x_axis, spec, x_range) new_t = t_range.unit.inv(new_t)
y_axis, spec.T[:] = uniform_axis(y_axis, spec.T, y_range)
if w_range.must_correct_wl:
spec = np.apply_along_axis(units.to_WL, 1, spec, new_w)
if time_axis == 1:
spec = spec.T
x_axis = new_t
y_axis = new_w
else:
x_axis = new_w
y_axis = new_t
values = apply_log(spec, log) values = apply_log(spec, log)
@@ -849,8 +896,8 @@ def plot_spectrogram(
x_axis, x_axis,
y_axis, y_axis,
ax, ax,
x_range.unit.label, None,
y_range.unit.label, None,
vmin, vmin,
vmax, vmax,
False, False,
@@ -907,7 +954,7 @@ def uniform_axis(
if plt_range.unit.type == "WL" and plt_range.conserved_quantity: if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
values = np.array([interp1d(tmp_axis, v[ind])(new_axis) for v in values]) values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
return new_axis, values.squeeze() return new_axis, values.squeeze()
@@ -954,16 +1001,12 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
def prep_plot_axis( def prep_plot_axis(
values: np.ndarray, plt_range: Union[PlotRange, RangeType], params: Parameters values: np.ndarray, plt_range: Union[PlotRange, RangeType]
) -> tuple[bool, np.ndarray, PlotRange]: ) -> tuple[bool, PlotRange]:
is_spectrum = values.dtype == "complex" is_spectrum = values.dtype == "complex"
if not isinstance(plt_range, PlotRange): if not isinstance(plt_range, PlotRange):
plt_range = PlotRange(*plt_range) plt_range = PlotRange(*plt_range)
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]: return is_spectrum, plt_range
x_axis = params.w.copy()
else:
x_axis = params.t.copy()
return is_spectrum, x_axis, plt_range
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)): def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):