image data helper in plotting
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Sequence, Union
|
||||
|
||||
@@ -25,6 +26,17 @@ RangeType = tuple[float, float, Union[str, Callable]]
|
||||
NO_LIM = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageData:
|
||||
x: np.ndarray
|
||||
y: np.ndarray
|
||||
data: np.ndarray
|
||||
|
||||
@property
|
||||
def extent(self) -> tuple[float, float, float, float]:
|
||||
return get_extent(self.x, self.y)
|
||||
|
||||
|
||||
def get_extent(x, y, facx=1, facy=1):
|
||||
"""
|
||||
returns the extent 4-tuple needed for imshow, aligning each pixel
|
||||
@@ -310,9 +322,7 @@ def propagation_plot(
|
||||
Axes obj on which to draw, by default None
|
||||
|
||||
"""
|
||||
x_axis, y_axis, values = transform_2D_propagation(
|
||||
values, plt_range, x_axis, y_axis, log, skip, params
|
||||
)
|
||||
x_axis, y_axis, values = transform_2D_data(values, plt_range, x_axis, y_axis, log, skip, params)
|
||||
if renormalize and not log:
|
||||
values = values / values.max()
|
||||
if log:
|
||||
@@ -431,13 +441,13 @@ def plot_2D(
|
||||
return ax
|
||||
|
||||
|
||||
def transform_2D_propagation(
|
||||
def transform_2D_data(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[PlotRange, RangeType],
|
||||
x_axis: np.ndarray,
|
||||
y_axis: np.ndarray,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
) -> ImageData:
|
||||
"""
|
||||
transforms raws values into plottable values
|
||||
|
||||
@@ -456,12 +466,7 @@ def transform_2D_propagation(
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
x_axis
|
||||
np.ndarray
|
||||
y_axis
|
||||
np.ndarray
|
||||
values
|
||||
ImageData
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -477,7 +482,7 @@ def transform_2D_propagation(
|
||||
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
||||
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
||||
values = apply_log(values, log)
|
||||
return x_axis, y_axis, values
|
||||
return ImageData(x_axis, y_axis, values)
|
||||
|
||||
|
||||
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
||||
@@ -1169,7 +1174,7 @@ def summary_plot(
|
||||
wl_db="1D",
|
||||
time_db=False,
|
||||
cmap: str | matplotlib.colors.LinearSegmentedColormap = None,
|
||||
):
|
||||
) -> tuple[ImageData, ImageData]:
|
||||
wl_int = specs.wl_int
|
||||
time_int = specs.time_int
|
||||
|
||||
@@ -1198,22 +1203,23 @@ def summary_plot(
|
||||
else:
|
||||
left, right = axes
|
||||
|
||||
x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
|
||||
image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
|
||||
left.imshow(
|
||||
values,
|
||||
extent=get_extent(x, y),
|
||||
image_spec.data,
|
||||
extent=image_spec.extent,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
vmin=wl_disp_limit,
|
||||
cmap=cmap,
|
||||
)
|
||||
|
||||
x, y, values = transform_2D_propagation(time_int, time_range, specs.t, z, log=time_db)
|
||||
image_time = transform_2D_data(time_int, time_range, specs.t, z, log=time_db)
|
||||
right.imshow(
|
||||
values,
|
||||
extent=get_extent(x, y),
|
||||
image_time.data,
|
||||
extent=image_time.extent,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
vmin=time_disp_limit,
|
||||
cmap=cmap,
|
||||
)
|
||||
return image_spec, image_time
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_rk43_soliton(plot=False):
|
||||
|
||||
res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9)
|
||||
if plot:
|
||||
x, y, z = sc.plotting.transform_2D_propagation(
|
||||
x, y, z = sc.plotting.transform_2D_data(
|
||||
res.spectra,
|
||||
sc.PlotRange(500, 1300, "nm"),
|
||||
w_c + w0,
|
||||
|
||||
Reference in New Issue
Block a user