From a2ec986d0cf623b83406924caad98fe4291f6462 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 20 Jun 2024 11:13:15 +0200 Subject: [PATCH] plotting fixes --- src/scgenerator/plotting.py | 52 +++++++++++++++++++++++++++++++------ src/scgenerator/spectra.py | 8 +++--- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 9a08b16..50b4a6a 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -7,6 +7,10 @@ from typing import Any, Callable, Literal, Optional, Sequence, Union import matplotlib.colors import matplotlib.gridspec as gs import matplotlib.pyplot as plt +from matplotlib.ticker import ( + LinearLocator, + mtransforms, +) import numpy as np from matplotlib.axes import Axes from matplotlib.colors import ColorConverter, ListedColormap @@ -27,6 +31,22 @@ RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() +class InverseLocator(LinearLocator): + def tick_values(self, vmin, vmax): + vmin, vmax = mtransforms.nonsingular(vmin, vmax, expander=0.05) + + if (vmin, vmax) in self.presets: + return self.presets[(vmin, vmax)] + + if self.numticks == 0: + return [] + + ticklocs = 1 / np.linspace(1 / vmax, 1 / vmin, self.numticks) + ticklocs = 20 * np.unique(np.round(ticklocs / 20)) + + return self.raise_if_exceeds(ticklocs) + + @dataclass class ImageData: x: np.ndarray @@ -1210,16 +1230,23 @@ def auto_wl_range(wl: np.ndarray, wl_int: np.ndarray, threshold: float) -> PlotR def summary_plot( specs: Spectrum, z: Sequence[float] | None = None, - wl_range: PlotRange | None = None, + spec_range: PlotRange | None = None, time_range: PlotRange | None = None, db_min: float = -50.0, lin_min: float = 1e-3, axes: tuple[Axes, Axes] | None = None, - wl_db="1D", + spec_db="1D", time_db=False, cmap: str | matplotlib.colors.LinearSegmentedColormap = None, ) -> tuple[ImageData, ImageData]: - wl_int = specs.wl_int + if spec_range is not None and spec_range.unit.type != "WL": + spec_int = specs.afreq_int + x_axis = specs.w_disp + freq_ax = True + else: + spec_int = specs.wl_int + x_axis = specs.wl_disp + freq_ax = False time_int = specs.time_int if z is None: @@ -1229,11 +1256,11 @@ def summary_plot( z = np.asarray(z) calc_limit, wl_disp_limit = ( - (10 ** (0.1 * db_min - 1), db_min) if wl_db and wl_db != "linear 1D" else (lin_min, 0) + (10 ** (0.1 * db_min - 1), db_min) if spec_db and spec_db != "linear 1D" else (lin_min, 0) ) - if wl_range is None: - imin, imax = math.span_above(wl_int, wl_int.max() * calc_limit) - wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm") + if spec_range is None: + imin, imax = math.span_above(spec_int, spec_int.max() * calc_limit) + spec_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm") calc_limit, time_disp_limit = ( (10 ** (0.1 * db_min - 1), db_min) if time_db and time_db != "linear 1D" else (lin_min, 0) @@ -1247,7 +1274,7 @@ def summary_plot( else: left, right = axes - image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db) + image_spec = transform_2D_data(spec_int, spec_range, x_axis, z, log=spec_db) left.imshow( image_spec.data, extent=image_spec.extent, @@ -1266,4 +1293,13 @@ def summary_plot( vmin=time_disp_limit, cmap=cmap, ) + if freq_ax: + secax = left.secondary_xaxis( + "top", + functions=( + lambda x: units.nm_rads(spec_range.unit(x)), + lambda x: spec_range.unit.inv(units.nm_rads(x)), + ), + ) + secax.xaxis.set_major_locator(InverseLocator(numticks=5)) return image_spec, image_time diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 6ce49ae..f286a7c 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -267,6 +267,7 @@ class Spectrum(np.ndarray): gate_width: float = 100e-15, wavelength: bool = True, autocrop: bool | float = 1e-5, + step_size: int = 4, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: dt = self.t[1] - self.t[0] sigma = gate_width / (2 * np.sqrt(2 * np.log(2))) / dt @@ -276,7 +277,7 @@ class Spectrum(np.ndarray): 1 / dt, window=("gaussian", sigma), nperseg=nperseg, - noverlap=nperseg - 4, + noverlap=nperseg - step_size, detrend=False, scaling="psd", boundary=None, @@ -291,8 +292,9 @@ class Spectrum(np.ndarray): if wavelength: f = units.m_hz(f) s = units.to_WL(s.T, f).T - f = f[::-1] - s = s[::-1] + o = f.argsort() + f = f[o] + s = s[o] if autocrop: thr = s.max()