diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 21be5fa..ab9860d 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -359,6 +359,32 @@ def _polynom_extrapolation_in_place(y: np.ndarray, left_ind: int, right_ind: int return y +@numba.jit(nopython=True) +def linear_interp_2d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray): + new_vals = np.zeros((len(old_y), len(new_x))) + interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1]) + equal = new_x == old_x[0] + inds = np.searchsorted(old_x, new_x[interpolable]) + for i, val in enumerate(old_y): + new_vals[i][interpolable] = val[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * ( + val[inds] - val[inds - 1] + ) / (old_x[inds] - old_x[inds - 1]) + new_vals[i][equal] = val[0] + return new_vals + + +@numba.jit(nopython=True) +def linear_interp_1d(old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray): + new_vals = np.zeros(len(new_x)) + interpolable = (new_x > old_x[0]) & (new_x <= old_x[-1]) + inds = np.searchsorted(old_x, new_x[interpolable]) + new_vals[interpolable] = old_y[inds - 1] + (new_x[interpolable] - old_x[inds - 1]) * ( + old_y[inds] - old_y[inds - 1] + ) / (old_x[inds] - old_x[inds - 1]) + new_vals[new_x == old_x[0]] = old_y[0] + return new_vals + + def envelope_ind( signal: np.ndarray, dmin: int = 1, dmax: int = 1, split: bool = False ) -> tuple[np.ndarray, np.ndarray]: diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 2cdccbc..4734f69 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -374,7 +374,7 @@ class Parameters: t_num: int = Parameter(positive(int)) z_num: int = Parameter(positive(int)) time_window: float = Parameter(positive(float, int)) - dt: float = Parameter(in_range_excl(0, 5e-15)) + dt: float = Parameter(in_range_excl(0, 10e-15)) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) step_size: float = Parameter(non_negative(float, int), default=0) interpolation_range: tuple[float, float] = Parameter( diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index d3eff47..778cfd8 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass from pathlib import Path from typing import Literal, Tuple, TypeVar +import numba import matplotlib.pyplot as plt import numpy as np from numpy import pi @@ -570,7 +571,9 @@ def ideal_compressed_pulse(spectra): return math.abs2(fftshift(ifft(np.sqrt(np.mean(math.abs2(spectra), axis=0))))) -def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False): +def spectrogram( + time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift=False, w_ind: np.ndarray = None +): """ returns the spectorgram of the field given in values @@ -594,14 +597,62 @@ def spectrogram(time, values, t_res=256, t_win=24e-12, gate_width=200e-15, shift delays : 1D array of size t_res new time axis """ - t_lim = t_win / 2 - delays = np.linspace(-t_lim, t_lim, t_res) - spec = np.zeros((t_res, len(time))) + if isinstance(t_win, tuple): + left, right = t_win + else: + t_lim = t_win / 2 + left, right = -t_lim, t_lim + + delays = np.linspace(left, right, t_res) + + spec = np.zeros((t_res, len(time) if w_ind is None else len(w_ind))) for i, delay in enumerate(delays): masked = values * np.exp(-(((time - delay) / gate_width) ** 2)) spec[i] = math.abs2(fft(masked)) - if shift: - spec[i] = fftshift(spec[i]) + + if shift: + spec = fftshift(spec, axes=1) + return spec, delays + + +def spectrogram_interp( + time: np.ndarray, + delays: np.ndarray, + values: np.ndarray, + old_w: np.ndarray, + w_ind: np.ndarray, + new_w: np.ndarray, + gate_width=200e-15, +): + """ + returns the spectorgram of the field already interpolated along the frequency axis + + Parameters + ---------- + time : 1D array-like + time in the co-moving frame of reference + values : 1D array-like + field array that matches the time array + t_res : int, optional + how many "bins" the time array is subdivided into. Default : 256 + t_win : float, optional + total time window (=length of time) over which the spectrogram is computed. Default : 24e-12 + gate_width : float, optional + width of the gaussian gate function (=sqrt(2 log(2)) * FWHM). Default : 200e-15 + + Returns + ---------- + spec : 2D array + real 2D spectrogram + delays : 1D array of size t_res + new time axis + """ + + spec = np.zeros((len(delays), len(new_w))) + for i, delay in enumerate(delays): + masked = values * np.exp(-(((time - delay) / gate_width) ** 2)) + spec[i] = math.linear_interp_1d(old_w, math.abs2(fft(masked)[w_ind]), new_w) + return spec, delays diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 17b7079..c607459 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -180,6 +180,7 @@ class RK4IP: state, self.params.linear_operator, self.params.nonlinear_operator ) with warnings.catch_warnings(): + # catch overflows as errors warnings.filterwarnings("error", category=RuntimeWarning) for state in integrator: diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index f1c8348..01fa75e 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -47,6 +47,7 @@ class UnitMap(dict): def chained_function(x: _T) -> _T: return c2.inv(c1(x)) + chained_function.name = f"{name_1}_to_{name_2}" chained_function.__name__ = f"{name_1}_to_{name_2}" chained_function.__doc__ = f"converts x from {name_1} to {name_2}" @@ -124,17 +125,17 @@ class unit: return func -@unit("WL", r"Wavelength $\lambda$ (m)") +@unit("WL", r"Wavelength λ (m)") def m(l: _T) -> _T: return 2 * pi * c / l -@unit("WL", r"Wavelength $\lambda$ (nm)") +@unit("WL", r"Wavelength λ (nm)") def nm(l: _T) -> _T: return 2 * pi * c / (l * 1e-9) -@unit("WL", r"Wavelength $\lambda$ (μm)") +@unit("WL", r"Wavelength λ (μm)") def um(l: _T) -> _T: return 2 * pi * c / (l * 1e-6) @@ -418,6 +419,10 @@ class PlotRange(tuple): conserved_quantity: bool = property(itemgetter(3)) __slots__ = [] + @property + def must_correct_wl(self) -> bool: + return self.unit.type == "WL" and self.conserved_quantity + def __new__(cls, left, right, unit, conserved_quantity=True): return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity)) @@ -461,8 +466,8 @@ def sort_axis( rw = (-4, 4, s) rt = (-2, 6, s) - w, cw = sort_axis(w, rw) - t, ct = sort_axis(t, rt) + w, cw, _ = sort_axis(w, rw) + t, ct, _ = sort_axis(t, rt) # slice y according to the given ranges y = y[ct][:, cw] diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 6114fd5..9290424 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Literal, Optional, Union import matplotlib.gridspec as gs import matplotlib.pyplot as plt +import numba import numpy as np from matplotlib.colors import ListedColormap from scipy.interpolate import UnivariateSpline @@ -12,7 +13,7 @@ from scipy.interpolate.interpolate import interp1d from . import math from .const import PARAM_SEPARATOR from .defaults import default_plotting as defaults -from .math import abs2, span +from .math import abs2, span, linear_interp_2d from .parameter import Parameters from .physics import pulse, units from .physics.units import PlotRange, sort_axis @@ -256,8 +257,10 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0 def propagation_plot( values: np.ndarray, plt_range: Union[PlotRange, RangeType], - params: Parameters, - ax: plt.Axes, + x_axis: np.ndarray = None, + y_axis: np.ndarray = None, + params: Parameters = None, + ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, log: Union[int, float, bool, str] = "1D", renormalize: bool = False, vmin: float = None, @@ -295,10 +298,12 @@ def propagation_plot( Axes obj on which to draw, by default None """ - x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log, skip) - if renormalize and log is False: + x_axis, y_axis, values = transform_2D_propagation( + values, plt_range, x_axis, y_axis, log, skip, params + ) + if renormalize and not log: values = values / values.max() - if log is not False: + if log: vmax = defaults["vmax"] if vmax is None else vmax vmin = defaults["vmin"] if vmin is None else vmin plot_2D( @@ -320,7 +325,7 @@ def plot_2D( values: np.ndarray, x_axis: np.ndarray, y_axis: np.ndarray, - ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]], + ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, x_label: str = None, y_label: str = None, vmin: float = None, @@ -367,6 +372,8 @@ def plot_2D( cbar_ax = None if isinstance(ax, tuple) and len(ax) > 1: ax, cbar_ax = ax[0], ax[1] + elif ax is None: + ax = plt.gca() fig = ax.get_figure() @@ -414,10 +421,11 @@ def plot_2D( def transform_2D_propagation( values: np.ndarray, plt_range: Union[PlotRange, RangeType], - params: Parameters, + x_axis: np.ndarray = None, + y_axis: np.ndarray = None, log: Union[int, float, bool, str] = "1D", skip: int = 1, - y_axis=None, + params: Parameters = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """transforms raws values into plottable values @@ -427,7 +435,11 @@ def transform_2D_propagation( values to transform plt_range : Union[PlotRange, RangeType] range - params : Parameters + x_axis : np.ndarray + corresponding x values in SI units + y_axis : np.ndarray + corresponding y values in SI units + params : Parameters, optional parameters of the simulation log : Union[int, float, bool, str], optional see apply_log, by default "1D" @@ -448,16 +460,17 @@ def transform_2D_propagation( ValueError incorrect shape """ + x_axis = get_x_axis(plt_range, x_axis, params) + if y_axis is None and params is not None: + y_axis = params.z_targets if values.ndim != 2: raise ValueError(f"shape was {values.shape}. Can only plot 2D array") - is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + is_complex, plt_range = prep_plot_axis(values, plt_range) if is_complex or any(values.ravel() < 0): values = abs2(values) # if params.full_field and plt_range.unit.type == "TIME": # values = envelope_2d(x_axis, values) - if y_axis is None: - y_axis = params.z_targets x_axis, values = uniform_axis(x_axis, values, plt_range) y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) @@ -465,6 +478,17 @@ def transform_2D_propagation( return x_axis[::skip], y_axis, values[:, ::skip] +def get_x_axis(plt_range, x_axis, params) -> np.ndarray: + if x_axis is None and params is not None: + if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}: + x_axis = params.w.copy() + else: + x_axis = params.t.copy() + if x_axis is None: + raise ValueError("No x axis specified") + return x_axis + + def mean_values_plot( values: np.ndarray, plt_range: Union[PlotRange, RangeType], @@ -538,10 +562,11 @@ def transform_mean_values( np.ndarray, shape (m, n) all the values """ + AAA if values.ndim != 2: print(f"Shape was {values.shape}. Can only plot 2D arrays") return - is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + is_complex, plt_range = prep_plot_axis(values, plt_range, params) if is_complex: values = abs2(values) new_axis, ind, ext = sort_axis(x_axis, plt_range) @@ -636,8 +661,9 @@ def plot_mean( def single_position_plot( values: np.ndarray, plt_range: Union[PlotRange, RangeType], - params: Parameters, - ax: plt.Axes, + x_axis: np.ndarray = None, + ax: plt.Axes = None, + params: Parameters = None, log: Union[str, int, float, bool] = False, vmin: float = None, vmax: float = None, @@ -647,8 +673,7 @@ def single_position_plot( y_label: str = None, **line_kwargs, ) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: - - x_axis, values = transform_1D_values(values, plt_range, params, log, spacing) + x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing) if renormalize: values = values / values.max() @@ -664,7 +689,7 @@ def single_position_plot( def plot_1D( values: np.ndarray, x_axis: np.ndarray, - ax: plt.Axes, + ax: Optional[plt.Axes], x_label: str = None, y_label: str = None, vmin: float = None, @@ -693,11 +718,18 @@ def plot_1D( transpose : bool, optional rotate the plot 90° counterclockwise, by default False """ - if transpose: + if ax is None: + ax = plt.gca() + if transpose == -1: + (line,) = ax.plot(values, x_axis, **line_kwargs) + ax.set_xlim(vmax, vmin) + ax.set_xlabel(y_label) + ax.set_ylabel(x_label) + elif transpose == 1: (line,) = ax.plot(values, x_axis, **line_kwargs) ax.yaxis.tick_right() ax.yaxis.set_label_position("right") - ax.set_xlim(vmax, vmin) + ax.set_xlim(vmin, vmax) ax.set_xlabel(y_label) ax.set_ylabel(x_label) else: @@ -711,7 +743,8 @@ def plot_1D( def transform_1D_values( values: np.ndarray, plt_range: Union[PlotRange, RangeType], - params: Parameters, + x_axis: np.ndarray = None, + params: Parameters = None, log: Union[int, float, bool] = False, spacing: Union[int, float] = 1, ) -> tuple[np.ndarray, np.ndarray]: @@ -737,9 +770,10 @@ def transform_1D_values( tuple[np.ndarray, np.ndarray] x axis and values """ + x_axis = get_x_axis(plt_range, x_axis, params) if len(values.shape) != 1: raise ValueError("Can only plot 1D values") - is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + is_complex, plt_range = prep_plot_axis(values, plt_range) if is_complex: values = abs2(values) new_axis, ind, ext = sort_axis(x_axis, plt_range) @@ -763,17 +797,20 @@ def transform_1D_values( def plot_spectrogram( values: np.ndarray, - x_range: RangeType, - y_range: RangeType, - params: Parameters, - t_res: int = None, - gate_width: float = None, + x_range: PlotRange, + y_range: PlotRange, + x_axis: np.ndarray = None, + y_axis: np.ndarray = None, + t_res: int = 512, + w_res: int = 512, + gate_width: float = 200e-15, + params: Parameters = None, log: bool = "2D", - vmin: float = None, - vmax: float = None, + vmin: float = -50, + vmax: float = 0, cbar_label: str = "normalized intensity (dB)", cmap: str = None, - ax: plt.Axes = None, + ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, ): """Plots a spectrogram given a complex field in the time domain Parameters @@ -781,7 +818,7 @@ def plot_spectrogram( values : 2D array axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber - x_range, y_range : tupple (min, max, units) + x_range, y_range : PlotRange one of them must be time, the other one must be wl/freq min, max : int or float minimum and maximum values given in the desired units @@ -798,10 +835,6 @@ def plot_spectrogram( max value of the colorbar cbar_label : str or None label of the colorbar. Will not draw colorbar if None - file_type : str, optional - usually pdf or png - plt_name : str, optional - special name to give to the plot. A name is automatically assigned anyway cmap : str, optional colormap to be used in matplotlib.pyplot.imshow ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional @@ -812,35 +845,49 @@ def plot_spectrogram( if values.ndim != 1: print("plot_spectrogram can only plot 1D arrays") return - x_range: PlotRange - y_range: PlotRange - _, x_axis, x_range = prep_plot_axis(values, x_range, params) - _, y_axis, y_range = prep_plot_axis(values, y_range, params) + x_axis = get_x_axis(x_range, x_axis, params) + y_axis = get_x_axis(y_range, y_axis, params) + _, x_range = prep_plot_axis(values, x_range) + _, y_range = prep_plot_axis(values, y_range) if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"): print("exactly one range must be a time range") return # 0 axis means x-axis -> determine final orientation of spectrogram - time_axis = 0 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1 - if time_axis == 0: + time_axis = 1 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1 + if time_axis == 1: + t_axis = x_axis t_range = x_range + w_range = y_range + w_axis = y_axis else: + t_axis = y_axis t_range = y_range + w_range = x_range + w_axis = x_axis + + old_w, w_ind, _ = w_range.sort_axis(w_axis) + new_w = np.linspace(w_range.left, w_range.right, w_res) # Actually compute the spectrogram - t_win = 2 * np.max(t_range.unit(np.abs((t_range.left, t_range.right)))) - spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False) - spec, new_t = pulse.spectrogram( - params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None} + delays = np.linspace(t_range.unit(t_range.left), t_range.unit(t_range.right), t_res) + spec, new_t = pulse.spectrogram_interp( + t_axis, delays, values, old_w, w_ind, new_w, gate_width=gate_width ) - if time_axis == 0: - x_axis = new_t - else: - y_axis = new_t - x_axis, spec = uniform_axis(x_axis, spec, x_range) - y_axis, spec.T[:] = uniform_axis(y_axis, spec.T, y_range) + new_t = t_range.unit.inv(new_t) + + if w_range.must_correct_wl: + spec = np.apply_along_axis(units.to_WL, 1, spec, new_w) + + if time_axis == 1: + spec = spec.T + x_axis = new_t + y_axis = new_w + else: + x_axis = new_w + y_axis = new_t values = apply_log(spec, log) @@ -849,8 +896,8 @@ def plot_spectrogram( x_axis, y_axis, ax, - x_range.unit.label, - y_range.unit.label, + None, + None, vmin, vmax, False, @@ -907,7 +954,7 @@ def uniform_axis( if plt_range.unit.type == "WL" and plt_range.conserved_quantity: values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) - values = np.array([interp1d(tmp_axis, v[ind])(new_axis) for v in values]) + values = linear_interp_2d(tmp_axis, values[:, ind], new_axis) return new_axis, values.squeeze() @@ -954,16 +1001,12 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr def prep_plot_axis( - values: np.ndarray, plt_range: Union[PlotRange, RangeType], params: Parameters -) -> tuple[bool, np.ndarray, PlotRange]: + values: np.ndarray, plt_range: Union[PlotRange, RangeType] +) -> tuple[bool, PlotRange]: is_spectrum = values.dtype == "complex" if not isinstance(plt_range, PlotRange): plt_range = PlotRange(*plt_range) - if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]: - x_axis = params.w.copy() - else: - x_axis = params.t.copy() - return is_spectrum, x_axis, plt_range + return is_spectrum, plt_range def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):