image data helper in plotting

This commit is contained in:
Benoît Sierro
2024-01-08 11:30:22 +01:00
parent 2d1cf0c9c1
commit eafb88a899
2 changed files with 26 additions and 20 deletions

View File

@@ -1,4 +1,5 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Sequence, Union
@@ -25,6 +26,17 @@ RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()
@dataclass
class ImageData:
x: np.ndarray
y: np.ndarray
data: np.ndarray
@property
def extent(self) -> tuple[float, float, float, float]:
return get_extent(self.x, self.y)
def get_extent(x, y, facx=1, facy=1):
"""
returns the extent 4-tuple needed for imshow, aligning each pixel
@@ -310,9 +322,7 @@ def propagation_plot(
Axes obj on which to draw, by default None
"""
x_axis, y_axis, values = transform_2D_propagation(
values, plt_range, x_axis, y_axis, log, skip, params
)
x_axis, y_axis, values = transform_2D_data(values, plt_range, x_axis, y_axis, log, skip, params)
if renormalize and not log:
values = values / values.max()
if log:
@@ -431,13 +441,13 @@ def plot_2D(
return ax
def transform_2D_propagation(
def transform_2D_data(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
x_axis: np.ndarray,
y_axis: np.ndarray,
log: Union[int, float, bool, str] = "1D",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> ImageData:
"""
transforms raws values into plottable values
@@ -456,12 +466,7 @@ def transform_2D_propagation(
Returns
-------
np.ndarray
x_axis
np.ndarray
y_axis
np.ndarray
values
ImageData
Raises
------
@@ -477,7 +482,7 @@ def transform_2D_propagation(
x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
values = apply_log(values, log)
return x_axis, y_axis, values
return ImageData(x_axis, y_axis, values)
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
@@ -1169,7 +1174,7 @@ def summary_plot(
wl_db="1D",
time_db=False,
cmap: str | matplotlib.colors.LinearSegmentedColormap = None,
):
) -> tuple[ImageData, ImageData]:
wl_int = specs.wl_int
time_int = specs.time_int
@@ -1198,22 +1203,23 @@ def summary_plot(
else:
left, right = axes
x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
left.imshow(
values,
extent=get_extent(x, y),
image_spec.data,
extent=image_spec.extent,
origin="lower",
aspect="auto",
vmin=wl_disp_limit,
cmap=cmap,
)
x, y, values = transform_2D_propagation(time_int, time_range, specs.t, z, log=time_db)
image_time = transform_2D_data(time_int, time_range, specs.t, z, log=time_db)
right.imshow(
values,
extent=get_extent(x, y),
image_time.data,
extent=image_time.extent,
origin="lower",
aspect="auto",
vmin=time_disp_limit,
cmap=cmap,
)
return image_spec, image_time

View File

@@ -60,7 +60,7 @@ def test_rk43_soliton(plot=False):
res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9)
if plot:
x, y, z = sc.plotting.transform_2D_propagation(
x, y, z = sc.plotting.transform_2D_data(
res.spectra,
sc.PlotRange(500, 1300, "nm"),
w_c + w0,