image data helper in plotting

This commit is contained in:
Benoît Sierro
2024-01-08 11:30:22 +01:00
parent 2d1cf0c9c1
commit eafb88a899
2 changed files with 26 additions and 20 deletions

View File

@@ -1,4 +1,5 @@
import os import os
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Optional, Sequence, Union from typing import Any, Callable, Literal, Optional, Sequence, Union
@@ -25,6 +26,17 @@ RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object() NO_LIM = object()
@dataclass
class ImageData:
x: np.ndarray
y: np.ndarray
data: np.ndarray
@property
def extent(self) -> tuple[float, float, float, float]:
return get_extent(self.x, self.y)
def get_extent(x, y, facx=1, facy=1): def get_extent(x, y, facx=1, facy=1):
""" """
returns the extent 4-tuple needed for imshow, aligning each pixel returns the extent 4-tuple needed for imshow, aligning each pixel
@@ -310,9 +322,7 @@ def propagation_plot(
Axes obj on which to draw, by default None Axes obj on which to draw, by default None
""" """
x_axis, y_axis, values = transform_2D_propagation( x_axis, y_axis, values = transform_2D_data(values, plt_range, x_axis, y_axis, log, skip, params)
values, plt_range, x_axis, y_axis, log, skip, params
)
if renormalize and not log: if renormalize and not log:
values = values / values.max() values = values / values.max()
if log: if log:
@@ -431,13 +441,13 @@ def plot_2D(
return ax return ax
def transform_2D_propagation( def transform_2D_data(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
x_axis: np.ndarray, x_axis: np.ndarray,
y_axis: np.ndarray, y_axis: np.ndarray,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> ImageData:
""" """
transforms raws values into plottable values transforms raws values into plottable values
@@ -456,12 +466,7 @@ def transform_2D_propagation(
Returns Returns
------- -------
np.ndarray ImageData
x_axis
np.ndarray
y_axis
np.ndarray
values
Raises Raises
------ ------
@@ -477,7 +482,7 @@ def transform_2D_propagation(
x_axis, values = uniform_axis(x_axis, values, plt_range) x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
values = apply_log(values, log) values = apply_log(values, log)
return x_axis, y_axis, values return ImageData(x_axis, y_axis, values)
def get_x_axis(plt_range, x_axis, params) -> np.ndarray: def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
@@ -1169,7 +1174,7 @@ def summary_plot(
wl_db="1D", wl_db="1D",
time_db=False, time_db=False,
cmap: str | matplotlib.colors.LinearSegmentedColormap = None, cmap: str | matplotlib.colors.LinearSegmentedColormap = None,
): ) -> tuple[ImageData, ImageData]:
wl_int = specs.wl_int wl_int = specs.wl_int
time_int = specs.time_int time_int = specs.time_int
@@ -1198,22 +1203,23 @@ def summary_plot(
else: else:
left, right = axes left, right = axes
x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log=wl_db) image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
left.imshow( left.imshow(
values, image_spec.data,
extent=get_extent(x, y), extent=image_spec.extent,
origin="lower", origin="lower",
aspect="auto", aspect="auto",
vmin=wl_disp_limit, vmin=wl_disp_limit,
cmap=cmap, cmap=cmap,
) )
x, y, values = transform_2D_propagation(time_int, time_range, specs.t, z, log=time_db) image_time = transform_2D_data(time_int, time_range, specs.t, z, log=time_db)
right.imshow( right.imshow(
values, image_time.data,
extent=get_extent(x, y), extent=image_time.extent,
origin="lower", origin="lower",
aspect="auto", aspect="auto",
vmin=time_disp_limit, vmin=time_disp_limit,
cmap=cmap, cmap=cmap,
) )
return image_spec, image_time

View File

@@ -60,7 +60,7 @@ def test_rk43_soliton(plot=False):
res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9) res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9)
if plot: if plot:
x, y, z = sc.plotting.transform_2D_propagation( x, y, z = sc.plotting.transform_2D_data(
res.spectra, res.spectra,
sc.PlotRange(500, 1300, "nm"), sc.PlotRange(500, 1300, "nm"),
w_c + w0, w_c + w0,