From eafb88a899c52e3326149906431a550595a5bc4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 8 Jan 2024 11:30:22 +0100 Subject: [PATCH] image data helper in plotting --- src/scgenerator/plotting.py | 44 +++++++++++++++++++++---------------- tests/test_integrator.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 05e4d4b..dae2b22 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Literal, Optional, Sequence, Union @@ -25,6 +26,17 @@ RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() +@dataclass +class ImageData: + x: np.ndarray + y: np.ndarray + data: np.ndarray + + @property + def extent(self) -> tuple[float, float, float, float]: + return get_extent(self.x, self.y) + + def get_extent(x, y, facx=1, facy=1): """ returns the extent 4-tuple needed for imshow, aligning each pixel @@ -310,9 +322,7 @@ def propagation_plot( Axes obj on which to draw, by default None """ - x_axis, y_axis, values = transform_2D_propagation( - values, plt_range, x_axis, y_axis, log, skip, params - ) + x_axis, y_axis, values = transform_2D_data(values, plt_range, x_axis, y_axis, log, skip, params) if renormalize and not log: values = values / values.max() if log: @@ -431,13 +441,13 @@ def plot_2D( return ax -def transform_2D_propagation( +def transform_2D_data( values: np.ndarray, plt_range: Union[PlotRange, RangeType], x_axis: np.ndarray, y_axis: np.ndarray, log: Union[int, float, bool, str] = "1D", -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: +) -> ImageData: """ transforms raws values into plottable values @@ -456,12 +466,7 @@ def transform_2D_propagation( Returns ------- - np.ndarray - x_axis - np.ndarray - y_axis - np.ndarray - values + ImageData Raises ------ @@ -477,7 +482,7 @@ def transform_2D_propagation( x_axis, values = uniform_axis(x_axis, values, plt_range) y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) values = apply_log(values, log) - return x_axis, y_axis, values + return ImageData(x_axis, y_axis, values) def get_x_axis(plt_range, x_axis, params) -> np.ndarray: @@ -1169,7 +1174,7 @@ def summary_plot( wl_db="1D", time_db=False, cmap: str | matplotlib.colors.LinearSegmentedColormap = None, -): +) -> tuple[ImageData, ImageData]: wl_int = specs.wl_int time_int = specs.time_int @@ -1198,22 +1203,23 @@ def summary_plot( else: left, right = axes - x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log=wl_db) + image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db) left.imshow( - values, - extent=get_extent(x, y), + image_spec.data, + extent=image_spec.extent, origin="lower", aspect="auto", vmin=wl_disp_limit, cmap=cmap, ) - x, y, values = transform_2D_propagation(time_int, time_range, specs.t, z, log=time_db) + image_time = transform_2D_data(time_int, time_range, specs.t, z, log=time_db) right.imshow( - values, - extent=get_extent(x, y), + image_time.data, + extent=image_time.extent, origin="lower", aspect="auto", vmin=time_disp_limit, cmap=cmap, ) + return image_spec, image_time diff --git a/tests/test_integrator.py b/tests/test_integrator.py index 3073839..45688f5 100644 --- a/tests/test_integrator.py +++ b/tests/test_integrator.py @@ -60,7 +60,7 @@ def test_rk43_soliton(plot=False): res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9) if plot: - x, y, z = sc.plotting.transform_2D_propagation( + x, y, z = sc.plotting.transform_2D_data( res.spectra, sc.PlotRange(500, 1300, "nm"), w_c + w0,