Corrected Spectrograms
This commit is contained in:
@@ -359,6 +359,32 @@ def _polynom_extrapolation_in_place(y: np.ndarray, left_ind: int, right_ind: int
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@numba.jit(nopython=True)
|
||||||
|
def linear_interp_2d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
|
||||||
|
new_vals = np.zeros((len(old_y), len(new_x)))
|
||||||
|
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
|
||||||
|
equal = new_x == old_x[0]
|
||||||
|
inds = np.searchsorted(old_x, new_x[interpolable])
|
||||||
|
for i, val in enumerate(old_y):
|
||||||
|
new_vals[i][interpolable] = val[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
|
||||||
|
val[inds] - val[inds - 1]
|
||||||
|
) / (old_x[inds] - old_x[inds - 1])
|
||||||
|
new_vals[i][equal] = val[0]
|
||||||
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
|
@numba.jit(nopython=True)
|
||||||
|
def linear_interp_1d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray):
|
||||||
|
new_vals = np.zeros(len(new_x))
|
||||||
|
interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1])
|
||||||
|
inds = np.searchsorted(old_x, new_x[interpolable])
|
||||||
|
new_vals[interpolable] = old_y[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * (
|
||||||
|
old_y[inds] - old_y[inds - 1]
|
||||||
|
) / (old_x[inds] - old_x[inds - 1])
|
||||||
|
new_vals[new_x == old_x[0]] = old_y[0]
|
||||||
|
return new_vals
|
||||||
|
|
||||||
|
|
||||||
def envelope_ind(
|
def envelope_ind(
|
||||||
signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False
|
signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
|||||||
@@ -374,7 +374,7 @@ class Parameters:
|
|||||||
t_num: int = Parameter(positive(int))
|
t_num: int = Parameter(positive(int))
|
||||||
z_num: int = Parameter(positive(int))
|
z_num: int = Parameter(positive(int))
|
||||||
time_window: float = Parameter(positive(float, int))
|
time_window: float = Parameter(positive(float, int))
|
||||||
dt: float = Parameter(in_range_excl(0, 5e-15))
|
dt: float = Parameter(in_range_excl(0, 10e-15))
|
||||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||||
step_size: float = Parameter(non_negative(float, int), default=0)
|
step_size: float = Parameter(non_negative(float, int), default=0)
|
||||||
interpolation_range: tuple[float, float] = Parameter(
|
interpolation_range: tuple[float, float] = Parameter(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Tuple, TypeVar
|
from typing import Literal, Tuple, TypeVar
|
||||||
|
|
||||||
|
import numba
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi
|
from numpy import pi
|
||||||
@@ -570,7 +571,9 @@ def ideal_compressed_pulse(spectra):
|
|||||||
return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0)))))
|
return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0)))))
|
||||||
|
|
||||||
|
|
||||||
def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False):
|
def spectrogram(
|
||||||
|
time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False, w_ind: np.ndarray = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
returns the spectorgram of the field given in values
|
returns the spectorgram of the field given in values
|
||||||
|
|
||||||
@@ -594,14 +597,62 @@ def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift
|
|||||||
delays : 1D array of size t_res
|
delays : 1D array of size t_res
|
||||||
new time axis
|
new time axis
|
||||||
"""
|
"""
|
||||||
|
if isinstance(t_win, tuple):
|
||||||
|
left, right = t_win
|
||||||
|
else:
|
||||||
t_lim = t_win / 2
|
t_lim = t_win / 2
|
||||||
delays = np.linspace(-t_lim, t_lim, t_res)
|
left, right = -t_lim, t_lim
|
||||||
spec = np.zeros((t_res, len(time)))
|
|
||||||
|
delays = np.linspace(left, right, t_res)
|
||||||
|
|
||||||
|
spec = np.zeros((t_res, len(time) if w_ind is None else len(w_ind)))
|
||||||
for i, delay in enumerate(delays):
|
for i, delay in enumerate(delays):
|
||||||
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
|
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
|
||||||
spec[i] = math.abs2(fft(masked))
|
spec[i] = math.abs2(fft(masked))
|
||||||
|
|
||||||
if shift:
|
if shift:
|
||||||
spec[i] = fftshift(spec[i])
|
spec = fftshift(spec, axes=1)
|
||||||
|
return spec, delays
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram_interp(
|
||||||
|
time: np.ndarray,
|
||||||
|
delays: np.ndarray,
|
||||||
|
values: np.ndarray,
|
||||||
|
old_w: np.ndarray,
|
||||||
|
w_ind: np.ndarray,
|
||||||
|
new_w: np.ndarray,
|
||||||
|
gate_width=200e-15,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
returns the spectorgram of the field already interpolated along the frequency axis
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
time : 1D array-like
|
||||||
|
time in the co-moving frame of reference
|
||||||
|
values : 1D array-like
|
||||||
|
field array that matches the time array
|
||||||
|
t_res : int, optional
|
||||||
|
how many "bins" the time array is subdivided into. Default : 256
|
||||||
|
t_win : float, optional
|
||||||
|
total time window (=length of time) over which the spectrogram is computed. Default : 24e-12
|
||||||
|
gate_width : float, optional
|
||||||
|
width of the gaussian gate function (=sqrt(2 log(2)) * FWHM). Default : 200e-15
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
spec : 2D array
|
||||||
|
real 2D spectrogram
|
||||||
|
delays : 1D array of size t_res
|
||||||
|
new time axis
|
||||||
|
"""
|
||||||
|
|
||||||
|
spec = np.zeros((len(delays), len(new_w)))
|
||||||
|
for i, delay in enumerate(delays):
|
||||||
|
masked = values * np.exp(-(((time - delay) / gate_width) ** 2))
|
||||||
|
spec[i] = math.linear_interp_1d(old_w, math.abs2(fft(masked)[w_ind]), new_w)
|
||||||
|
|
||||||
return spec, delays
|
return spec, delays
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class RK4IP:
|
|||||||
state, self.params.linear_operator, self.params.nonlinear_operator
|
state, self.params.linear_operator, self.params.nonlinear_operator
|
||||||
)
|
)
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
|
# catch overflows as errors
|
||||||
warnings.filterwarnings("error", category=RuntimeWarning)
|
warnings.filterwarnings("error", category=RuntimeWarning)
|
||||||
for state in integrator:
|
for state in integrator:
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class UnitMap(dict):
|
|||||||
def chained_function(x: _T) -> _T:
|
def chained_function(x: _T) -> _T:
|
||||||
return c2.inv(c1(x))
|
return c2.inv(c1(x))
|
||||||
|
|
||||||
|
chained_function.name = f"{name_1}_to_{name_2}"
|
||||||
chained_function.__name__ = f"{name_1}_to_{name_2}"
|
chained_function.__name__ = f"{name_1}_to_{name_2}"
|
||||||
chained_function.__doc__ = f"converts x from {name_1} to {name_2}"
|
chained_function.__doc__ = f"converts x from {name_1} to {name_2}"
|
||||||
|
|
||||||
@@ -124,17 +125,17 @@ class unit:
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
@unit("WL", r"Wavelength $\lambda$ (m)")
|
@unit("WL", r"Wavelength λ (m)")
|
||||||
def m(l: _T) -> _T:
|
def m(l: _T) -> _T:
|
||||||
return 2 * pi * c / l
|
return 2 * pi * c / l
|
||||||
|
|
||||||
|
|
||||||
@unit("WL", r"Wavelength $\lambda$ (nm)")
|
@unit("WL", r"Wavelength λ (nm)")
|
||||||
def nm(l: _T) -> _T:
|
def nm(l: _T) -> _T:
|
||||||
return 2 * pi * c / (l * 1e-9)
|
return 2 * pi * c / (l * 1e-9)
|
||||||
|
|
||||||
|
|
||||||
@unit("WL", r"Wavelength $\lambda$ (μm)")
|
@unit("WL", r"Wavelength λ (μm)")
|
||||||
def um(l: _T) -> _T:
|
def um(l: _T) -> _T:
|
||||||
return 2 * pi * c / (l * 1e-6)
|
return 2 * pi * c / (l * 1e-6)
|
||||||
|
|
||||||
@@ -418,6 +419,10 @@ class PlotRange(tuple):
|
|||||||
conserved_quantity: bool = property(itemgetter(3))
|
conserved_quantity: bool = property(itemgetter(3))
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def must_correct_wl(self) -> bool:
|
||||||
|
return self.unit.type == "WL" and self.conserved_quantity
|
||||||
|
|
||||||
def __new__(cls, left, right, unit, conserved_quantity=True):
|
def __new__(cls, left, right, unit, conserved_quantity=True):
|
||||||
return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))
|
return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))
|
||||||
|
|
||||||
@@ -461,8 +466,8 @@ def sort_axis(
|
|||||||
rw = (-4, 4, s)
|
rw = (-4, 4, s)
|
||||||
rt = (-2, 6, s)
|
rt = (-2, 6, s)
|
||||||
|
|
||||||
w, cw = sort_axis(w, rw)
|
w, cw, _ = sort_axis(w, rw)
|
||||||
t, ct = sort_axis(t, rt)
|
t, ct, _ = sort_axis(t, rt)
|
||||||
|
|
||||||
# slice y according to the given ranges
|
# slice y according to the given ranges
|
||||||
y = y[ct][:, cw]
|
y = y[ct][:, cw]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any, Callable, Literal, Optional, Union
|
|||||||
|
|
||||||
import matplotlib.gridspec as gs
|
import matplotlib.gridspec as gs
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib.colors import ListedColormap
|
from matplotlib.colors import ListedColormap
|
||||||
from scipy.interpolate import UnivariateSpline
|
from scipy.interpolate import UnivariateSpline
|
||||||
@@ -12,7 +13,7 @@ from scipy.interpolate.interpolate import interp1d
|
|||||||
from . import math
|
from . import math
|
||||||
from .const import PARAM_SEPARATOR
|
from .const import PARAM_SEPARATOR
|
||||||
from .defaults import default_plotting as defaults
|
from .defaults import default_plotting as defaults
|
||||||
from .math import abs2, span
|
from .math import abs2, span, linear_interp_2d
|
||||||
from .parameter import Parameters
|
from .parameter import Parameters
|
||||||
from .physics import pulse, units
|
from .physics import pulse, units
|
||||||
from .physics.units import PlotRange, sort_axis
|
from .physics.units import PlotRange, sort_axis
|
||||||
@@ -256,8 +257,10 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
|
|||||||
def propagation_plot(
|
def propagation_plot(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
params: Parameters,
|
x_axis: np.ndarray = None,
|
||||||
ax: plt.Axes,
|
y_axis: np.ndarray = None,
|
||||||
|
params: Parameters = None,
|
||||||
|
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
|
||||||
log: Union[int, float, bool, str] = "1D",
|
log: Union[int, float, bool, str] = "1D",
|
||||||
renormalize: bool = False,
|
renormalize: bool = False,
|
||||||
vmin: float = None,
|
vmin: float = None,
|
||||||
@@ -295,10 +298,12 @@ def propagation_plot(
|
|||||||
Axes obj on which to draw, by default None
|
Axes obj on which to draw, by default None
|
||||||
|
|
||||||
"""
|
"""
|
||||||
x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log, skip)
|
x_axis, y_axis, values = transform_2D_propagation(
|
||||||
if renormalize and log is False:
|
values, plt_range, x_axis, y_axis, log, skip, params
|
||||||
|
)
|
||||||
|
if renormalize and not log:
|
||||||
values = values / values.max()
|
values = values / values.max()
|
||||||
if log is not False:
|
if log:
|
||||||
vmax = defaults["vmax"] if vmax is None else vmax
|
vmax = defaults["vmax"] if vmax is None else vmax
|
||||||
vmin = defaults["vmin"] if vmin is None else vmin
|
vmin = defaults["vmin"] if vmin is None else vmin
|
||||||
plot_2D(
|
plot_2D(
|
||||||
@@ -320,7 +325,7 @@ def plot_2D(
|
|||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
x_axis: np.ndarray,
|
x_axis: np.ndarray,
|
||||||
y_axis: np.ndarray,
|
y_axis: np.ndarray,
|
||||||
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]],
|
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
|
||||||
x_label: str = None,
|
x_label: str = None,
|
||||||
y_label: str = None,
|
y_label: str = None,
|
||||||
vmin: float = None,
|
vmin: float = None,
|
||||||
@@ -367,6 +372,8 @@ def plot_2D(
|
|||||||
cbar_ax = None
|
cbar_ax = None
|
||||||
if isinstance(ax, tuple) and len(ax) > 1:
|
if isinstance(ax, tuple) and len(ax) > 1:
|
||||||
ax, cbar_ax = ax[0], ax[1]
|
ax, cbar_ax = ax[0], ax[1]
|
||||||
|
elif ax is None:
|
||||||
|
ax = plt.gca()
|
||||||
|
|
||||||
fig = ax.get_figure()
|
fig = ax.get_figure()
|
||||||
|
|
||||||
@@ -414,10 +421,11 @@ def plot_2D(
|
|||||||
def transform_2D_propagation(
|
def transform_2D_propagation(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
params: Parameters,
|
x_axis: np.ndarray = None,
|
||||||
|
y_axis: np.ndarray = None,
|
||||||
log: Union[int, float, bool, str] = "1D",
|
log: Union[int, float, bool, str] = "1D",
|
||||||
skip: int = 1,
|
skip: int = 1,
|
||||||
y_axis=None,
|
params: Parameters = None,
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""transforms raws values into plottable values
|
"""transforms raws values into plottable values
|
||||||
|
|
||||||
@@ -427,7 +435,11 @@ def transform_2D_propagation(
|
|||||||
values to transform
|
values to transform
|
||||||
plt_range : Union[PlotRange, RangeType]
|
plt_range : Union[PlotRange, RangeType]
|
||||||
range
|
range
|
||||||
params : Parameters
|
x_axis : np.ndarray
|
||||||
|
corresponding x values in SI units
|
||||||
|
y_axis : np.ndarray
|
||||||
|
corresponding y values in SI units
|
||||||
|
params : Parameters, optional
|
||||||
parameters of the simulation
|
parameters of the simulation
|
||||||
log : Union[int, float, bool, str], optional
|
log : Union[int, float, bool, str], optional
|
||||||
see apply_log, by default "1D"
|
see apply_log, by default "1D"
|
||||||
@@ -448,16 +460,17 @@ def transform_2D_propagation(
|
|||||||
ValueError
|
ValueError
|
||||||
incorrect shape
|
incorrect shape
|
||||||
"""
|
"""
|
||||||
|
x_axis = get_x_axis(plt_range, x_axis, params)
|
||||||
|
if y_axis is None and params is not None:
|
||||||
|
y_axis = params.z_targets
|
||||||
|
|
||||||
if values.ndim != 2:
|
if values.ndim != 2:
|
||||||
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
|
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
|
||||||
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
|
is_complex, plt_range = prep_plot_axis(values, plt_range)
|
||||||
if is_complex or any(values.ravel() < 0):
|
if is_complex or any(values.ravel() < 0):
|
||||||
values = abs2(values)
|
values = abs2(values)
|
||||||
# if params.full_field and plt_range.unit.type == "TIME":
|
# if params.full_field and plt_range.unit.type == "TIME":
|
||||||
# values = envelope_2d(x_axis, values)
|
# values = envelope_2d(x_axis, values)
|
||||||
if y_axis is None:
|
|
||||||
y_axis = params.z_targets
|
|
||||||
|
|
||||||
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
||||||
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
||||||
@@ -465,6 +478,17 @@ def transform_2D_propagation(
|
|||||||
return x_axis[::skip], y_axis, values[:, ::skip]
|
return x_axis[::skip], y_axis, values[:, ::skip]
|
||||||
|
|
||||||
|
|
||||||
|
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
||||||
|
if x_axis is None and params is not None:
|
||||||
|
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
|
||||||
|
x_axis = params.w.copy()
|
||||||
|
else:
|
||||||
|
x_axis = params.t.copy()
|
||||||
|
if x_axis is None:
|
||||||
|
raise ValueError("No x axis specified")
|
||||||
|
return x_axis
|
||||||
|
|
||||||
|
|
||||||
def mean_values_plot(
|
def mean_values_plot(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
@@ -538,10 +562,11 @@ def transform_mean_values(
|
|||||||
np.ndarray, shape (m, n)
|
np.ndarray, shape (m, n)
|
||||||
all the values
|
all the values
|
||||||
"""
|
"""
|
||||||
|
AAA
|
||||||
if values.ndim != 2:
|
if values.ndim != 2:
|
||||||
print(f"Shape was {values.shape}. Can only plot 2D arrays")
|
print(f"Shape was {values.shape}. Can only plot 2D arrays")
|
||||||
return
|
return
|
||||||
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
|
is_complex, plt_range = prep_plot_axis(values, plt_range, params)
|
||||||
if is_complex:
|
if is_complex:
|
||||||
values = abs2(values)
|
values = abs2(values)
|
||||||
new_axis, ind, ext = sort_axis(x_axis, plt_range)
|
new_axis, ind, ext = sort_axis(x_axis, plt_range)
|
||||||
@@ -636,8 +661,9 @@ def plot_mean(
|
|||||||
def single_position_plot(
|
def single_position_plot(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
params: Parameters,
|
x_axis: np.ndarray = None,
|
||||||
ax: plt.Axes,
|
ax: plt.Axes = None,
|
||||||
|
params: Parameters = None,
|
||||||
log: Union[str, int, float, bool] = False,
|
log: Union[str, int, float, bool] = False,
|
||||||
vmin: float = None,
|
vmin: float = None,
|
||||||
vmax: float = None,
|
vmax: float = None,
|
||||||
@@ -647,8 +673,7 @@ def single_position_plot(
|
|||||||
y_label: str = None,
|
y_label: str = None,
|
||||||
**line_kwargs,
|
**line_kwargs,
|
||||||
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
||||||
|
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
|
||||||
x_axis, values = transform_1D_values(values, plt_range, params, log, spacing)
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
values = values / values.max()
|
values = values / values.max()
|
||||||
|
|
||||||
@@ -664,7 +689,7 @@ def single_position_plot(
|
|||||||
def plot_1D(
|
def plot_1D(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
x_axis: np.ndarray,
|
x_axis: np.ndarray,
|
||||||
ax: plt.Axes,
|
ax: Optional[plt.Axes],
|
||||||
x_label: str = None,
|
x_label: str = None,
|
||||||
y_label: str = None,
|
y_label: str = None,
|
||||||
vmin: float = None,
|
vmin: float = None,
|
||||||
@@ -693,11 +718,18 @@ def plot_1D(
|
|||||||
transpose : bool, optional
|
transpose : bool, optional
|
||||||
rotate the plot 90° counterclockwise, by default False
|
rotate the plot 90° counterclockwise, by default False
|
||||||
"""
|
"""
|
||||||
if transpose:
|
if ax is None:
|
||||||
|
ax = plt.gca()
|
||||||
|
if transpose == -1:
|
||||||
|
(line,) = ax.plot(values, x_axis, **line_kwargs)
|
||||||
|
ax.set_xlim(vmax, vmin)
|
||||||
|
ax.set_xlabel(y_label)
|
||||||
|
ax.set_ylabel(x_label)
|
||||||
|
elif transpose == 1:
|
||||||
(line,) = ax.plot(values, x_axis, **line_kwargs)
|
(line,) = ax.plot(values, x_axis, **line_kwargs)
|
||||||
ax.yaxis.tick_right()
|
ax.yaxis.tick_right()
|
||||||
ax.yaxis.set_label_position("right")
|
ax.yaxis.set_label_position("right")
|
||||||
ax.set_xlim(vmax, vmin)
|
ax.set_xlim(vmin, vmax)
|
||||||
ax.set_xlabel(y_label)
|
ax.set_xlabel(y_label)
|
||||||
ax.set_ylabel(x_label)
|
ax.set_ylabel(x_label)
|
||||||
else:
|
else:
|
||||||
@@ -711,7 +743,8 @@ def plot_1D(
|
|||||||
def transform_1D_values(
|
def transform_1D_values(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
params: Parameters,
|
x_axis: np.ndarray = None,
|
||||||
|
params: Parameters = None,
|
||||||
log: Union[int, float, bool] = False,
|
log: Union[int, float, bool] = False,
|
||||||
spacing: Union[int, float] = 1,
|
spacing: Union[int, float] = 1,
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
@@ -737,9 +770,10 @@ def transform_1D_values(
|
|||||||
tuple[np.ndarray, np.ndarray]
|
tuple[np.ndarray, np.ndarray]
|
||||||
x axis and values
|
x axis and values
|
||||||
"""
|
"""
|
||||||
|
x_axis = get_x_axis(plt_range, x_axis, params)
|
||||||
if len(values.shape) != 1:
|
if len(values.shape) != 1:
|
||||||
raise ValueError("Can only plot 1D values")
|
raise ValueError("Can only plot 1D values")
|
||||||
is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params)
|
is_complex, plt_range = prep_plot_axis(values, plt_range)
|
||||||
if is_complex:
|
if is_complex:
|
||||||
values = abs2(values)
|
values = abs2(values)
|
||||||
new_axis, ind, ext = sort_axis(x_axis, plt_range)
|
new_axis, ind, ext = sort_axis(x_axis, plt_range)
|
||||||
@@ -763,17 +797,20 @@ def transform_1D_values(
|
|||||||
|
|
||||||
def plot_spectrogram(
|
def plot_spectrogram(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
x_range: RangeType,
|
x_range: PlotRange,
|
||||||
y_range: RangeType,
|
y_range: PlotRange,
|
||||||
params: Parameters,
|
x_axis: np.ndarray = None,
|
||||||
t_res: int = None,
|
y_axis: np.ndarray = None,
|
||||||
gate_width: float = None,
|
t_res: int = 512,
|
||||||
|
w_res: int = 512,
|
||||||
|
gate_width: float = 200e-15,
|
||||||
|
params: Parameters = None,
|
||||||
log: bool = "2D",
|
log: bool = "2D",
|
||||||
vmin: float = None,
|
vmin: float = -50,
|
||||||
vmax: float = None,
|
vmax: float = 0,
|
||||||
cbar_label: str = "normalized intensity (dB)",
|
cbar_label: str = "normalized intensity (dB)",
|
||||||
cmap: str = None,
|
cmap: str = None,
|
||||||
ax: plt.Axes = None,
|
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
|
||||||
):
|
):
|
||||||
"""Plots a spectrogram given a complex field in the time domain
|
"""Plots a spectrogram given a complex field in the time domain
|
||||||
Parameters
|
Parameters
|
||||||
@@ -781,7 +818,7 @@ def plot_spectrogram(
|
|||||||
values : 2D array
|
values : 2D array
|
||||||
axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl
|
axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl
|
||||||
example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber
|
example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber
|
||||||
x_range, y_range : tupple (min, max, units)
|
x_range, y_range : PlotRange
|
||||||
one of them must be time, the other one must be wl/freq
|
one of them must be time, the other one must be wl/freq
|
||||||
min, max : int or float
|
min, max : int or float
|
||||||
minimum and maximum values given in the desired units
|
minimum and maximum values given in the desired units
|
||||||
@@ -798,10 +835,6 @@ def plot_spectrogram(
|
|||||||
max value of the colorbar
|
max value of the colorbar
|
||||||
cbar_label : str or None
|
cbar_label : str or None
|
||||||
label of the colorbar. Will not draw colorbar if None
|
label of the colorbar. Will not draw colorbar if None
|
||||||
file_type : str, optional
|
|
||||||
usually pdf or png
|
|
||||||
plt_name : str, optional
|
|
||||||
special name to give to the plot. A name is automatically assigned anyway
|
|
||||||
cmap : str, optional
|
cmap : str, optional
|
||||||
colormap to be used in matplotlib.pyplot.imshow
|
colormap to be used in matplotlib.pyplot.imshow
|
||||||
ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional
|
ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional
|
||||||
@@ -812,35 +845,49 @@ def plot_spectrogram(
|
|||||||
if values.ndim != 1:
|
if values.ndim != 1:
|
||||||
print("plot_spectrogram can only plot 1D arrays")
|
print("plot_spectrogram can only plot 1D arrays")
|
||||||
return
|
return
|
||||||
x_range: PlotRange
|
x_axis = get_x_axis(x_range, x_axis, params)
|
||||||
y_range: PlotRange
|
y_axis = get_x_axis(y_range, y_axis, params)
|
||||||
_, x_axis, x_range = prep_plot_axis(values, x_range, params)
|
_, x_range = prep_plot_axis(values, x_range)
|
||||||
_, y_axis, y_range = prep_plot_axis(values, y_range, params)
|
_, y_range = prep_plot_axis(values, y_range)
|
||||||
|
|
||||||
if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"):
|
if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"):
|
||||||
print("exactly one range must be a time range")
|
print("exactly one range must be a time range")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 0 axis means x-axis -> determine final orientation of spectrogram
|
# 0 axis means x-axis -> determine final orientation of spectrogram
|
||||||
time_axis = 0 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1
|
time_axis = 1 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1
|
||||||
if time_axis == 0:
|
if time_axis == 1:
|
||||||
|
t_axis = x_axis
|
||||||
t_range = x_range
|
t_range = x_range
|
||||||
|
w_range = y_range
|
||||||
|
w_axis = y_axis
|
||||||
else:
|
else:
|
||||||
|
t_axis = y_axis
|
||||||
t_range = y_range
|
t_range = y_range
|
||||||
|
w_range = x_range
|
||||||
|
w_axis = x_axis
|
||||||
|
|
||||||
|
old_w, w_ind, _ = w_range.sort_axis(w_axis)
|
||||||
|
new_w = np.linspace(w_range.left, w_range.right, w_res)
|
||||||
|
|
||||||
# Actually compute the spectrogram
|
# Actually compute the spectrogram
|
||||||
t_win = 2 * np.max(t_range.unit(np.abs((t_range.left, t_range.right))))
|
delays = np.linspace(t_range.unit(t_range.left), t_range.unit(t_range.right), t_res)
|
||||||
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
|
spec, new_t = pulse.spectrogram_interp(
|
||||||
spec, new_t = pulse.spectrogram(
|
t_axis, delays, values, old_w, w_ind, new_w, gate_width=gate_width
|
||||||
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
|
|
||||||
)
|
)
|
||||||
if time_axis == 0:
|
|
||||||
x_axis = new_t
|
|
||||||
else:
|
|
||||||
y_axis = new_t
|
|
||||||
|
|
||||||
x_axis, spec = uniform_axis(x_axis, spec, x_range)
|
new_t = t_range.unit.inv(new_t)
|
||||||
y_axis, spec.T[:] = uniform_axis(y_axis, spec.T, y_range)
|
|
||||||
|
if w_range.must_correct_wl:
|
||||||
|
spec = np.apply_along_axis(units.to_WL, 1, spec, new_w)
|
||||||
|
|
||||||
|
if time_axis == 1:
|
||||||
|
spec = spec.T
|
||||||
|
x_axis = new_t
|
||||||
|
y_axis = new_w
|
||||||
|
else:
|
||||||
|
x_axis = new_w
|
||||||
|
y_axis = new_t
|
||||||
|
|
||||||
values = apply_log(spec, log)
|
values = apply_log(spec, log)
|
||||||
|
|
||||||
@@ -849,8 +896,8 @@ def plot_spectrogram(
|
|||||||
x_axis,
|
x_axis,
|
||||||
y_axis,
|
y_axis,
|
||||||
ax,
|
ax,
|
||||||
x_range.unit.label,
|
None,
|
||||||
y_range.unit.label,
|
None,
|
||||||
vmin,
|
vmin,
|
||||||
vmax,
|
vmax,
|
||||||
False,
|
False,
|
||||||
@@ -907,7 +954,7 @@ def uniform_axis(
|
|||||||
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
|
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
|
||||||
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
|
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
|
||||||
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
||||||
values = np.array([interp1d(tmp_axis, v[ind])(new_axis) for v in values])
|
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
|
||||||
return new_axis, values.squeeze()
|
return new_axis, values.squeeze()
|
||||||
|
|
||||||
|
|
||||||
@@ -954,16 +1001,12 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
|
|||||||
|
|
||||||
|
|
||||||
def prep_plot_axis(
|
def prep_plot_axis(
|
||||||
values: np.ndarray, plt_range: Union[PlotRange, RangeType], params: Parameters
|
values: np.ndarray, plt_range: Union[PlotRange, RangeType]
|
||||||
) -> tuple[bool, np.ndarray, PlotRange]:
|
) -> tuple[bool, PlotRange]:
|
||||||
is_spectrum = values.dtype == "complex"
|
is_spectrum = values.dtype == "complex"
|
||||||
if not isinstance(plt_range, PlotRange):
|
if not isinstance(plt_range, PlotRange):
|
||||||
plt_range = PlotRange(*plt_range)
|
plt_range = PlotRange(*plt_range)
|
||||||
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
|
return is_spectrum, plt_range
|
||||||
x_axis = params.w.copy()
|
|
||||||
else:
|
|
||||||
x_axis = params.t.copy()
|
|
||||||
return is_spectrum, x_axis, plt_range
|
|
||||||
|
|
||||||
|
|
||||||
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
|
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
|
||||||
|
|||||||
Reference in New Issue
Block a user