image data helper in plotting
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Literal, Optional, Sequence, Union
|
from typing import Any, Callable, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
@@ -25,6 +26,17 @@ RangeType = tuple[float, float, Union[str, Callable]]
|
|||||||
NO_LIM = object()
|
NO_LIM = object()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageData:
|
||||||
|
x: np.ndarray
|
||||||
|
y: np.ndarray
|
||||||
|
data: np.ndarray
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extent(self) -> tuple[float, float, float, float]:
|
||||||
|
return get_extent(self.x, self.y)
|
||||||
|
|
||||||
|
|
||||||
def get_extent(x, y, facx=1, facy=1):
|
def get_extent(x, y, facx=1, facy=1):
|
||||||
"""
|
"""
|
||||||
returns the extent 4-tuple needed for imshow, aligning each pixel
|
returns the extent 4-tuple needed for imshow, aligning each pixel
|
||||||
@@ -310,9 +322,7 @@ def propagation_plot(
|
|||||||
Axes obj on which to draw, by default None
|
Axes obj on which to draw, by default None
|
||||||
|
|
||||||
"""
|
"""
|
||||||
x_axis, y_axis, values = transform_2D_propagation(
|
x_axis, y_axis, values = transform_2D_data(values, plt_range, x_axis, y_axis, log, skip, params)
|
||||||
values, plt_range, x_axis, y_axis, log, skip, params
|
|
||||||
)
|
|
||||||
if renormalize and not log:
|
if renormalize and not log:
|
||||||
values = values / values.max()
|
values = values / values.max()
|
||||||
if log:
|
if log:
|
||||||
@@ -431,13 +441,13 @@ def plot_2D(
|
|||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def transform_2D_propagation(
|
def transform_2D_data(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: Union[PlotRange, RangeType],
|
plt_range: Union[PlotRange, RangeType],
|
||||||
x_axis: np.ndarray,
|
x_axis: np.ndarray,
|
||||||
y_axis: np.ndarray,
|
y_axis: np.ndarray,
|
||||||
log: Union[int, float, bool, str] = "1D",
|
log: Union[int, float, bool, str] = "1D",
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
) -> ImageData:
|
||||||
"""
|
"""
|
||||||
transforms raws values into plottable values
|
transforms raws values into plottable values
|
||||||
|
|
||||||
@@ -456,12 +466,7 @@ def transform_2D_propagation(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
np.ndarray
|
ImageData
|
||||||
x_axis
|
|
||||||
np.ndarray
|
|
||||||
y_axis
|
|
||||||
np.ndarray
|
|
||||||
values
|
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
@@ -477,7 +482,7 @@ def transform_2D_propagation(
|
|||||||
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
x_axis, values = uniform_axis(x_axis, values, plt_range)
|
||||||
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
|
||||||
values = apply_log(values, log)
|
values = apply_log(values, log)
|
||||||
return x_axis, y_axis, values
|
return ImageData(x_axis, y_axis, values)
|
||||||
|
|
||||||
|
|
||||||
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
||||||
@@ -1169,7 +1174,7 @@ def summary_plot(
|
|||||||
wl_db="1D",
|
wl_db="1D",
|
||||||
time_db=False,
|
time_db=False,
|
||||||
cmap: str | matplotlib.colors.LinearSegmentedColormap = None,
|
cmap: str | matplotlib.colors.LinearSegmentedColormap = None,
|
||||||
):
|
) -> tuple[ImageData, ImageData]:
|
||||||
wl_int = specs.wl_int
|
wl_int = specs.wl_int
|
||||||
time_int = specs.time_int
|
time_int = specs.time_int
|
||||||
|
|
||||||
@@ -1198,22 +1203,23 @@ def summary_plot(
|
|||||||
else:
|
else:
|
||||||
left, right = axes
|
left, right = axes
|
||||||
|
|
||||||
x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
|
image_spec = transform_2D_data(wl_int, wl_range, specs.wl_disp, z, log=wl_db)
|
||||||
left.imshow(
|
left.imshow(
|
||||||
values,
|
image_spec.data,
|
||||||
extent=get_extent(x, y),
|
extent=image_spec.extent,
|
||||||
origin="lower",
|
origin="lower",
|
||||||
aspect="auto",
|
aspect="auto",
|
||||||
vmin=wl_disp_limit,
|
vmin=wl_disp_limit,
|
||||||
cmap=cmap,
|
cmap=cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
x, y, values = transform_2D_propagation(time_int, time_range, specs.t, z, log=time_db)
|
image_time = transform_2D_data(time_int, time_range, specs.t, z, log=time_db)
|
||||||
right.imshow(
|
right.imshow(
|
||||||
values,
|
image_time.data,
|
||||||
extent=get_extent(x, y),
|
extent=image_time.extent,
|
||||||
origin="lower",
|
origin="lower",
|
||||||
aspect="auto",
|
aspect="auto",
|
||||||
vmin=time_disp_limit,
|
vmin=time_disp_limit,
|
||||||
cmap=cmap,
|
cmap=cmap,
|
||||||
)
|
)
|
||||||
|
return image_spec, image_time
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def test_rk43_soliton(plot=False):
|
|||||||
|
|
||||||
res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9)
|
res = sc.integrate(spec0, end, lin, non_lin, targets=targets, atol=1e-10, rtol=1e-9)
|
||||||
if plot:
|
if plot:
|
||||||
x, y, z = sc.plotting.transform_2D_propagation(
|
x, y, z = sc.plotting.transform_2D_data(
|
||||||
res.spectra,
|
res.spectra,
|
||||||
sc.PlotRange(500, 1300, "nm"),
|
sc.PlotRange(500, 1300, "nm"),
|
||||||
w_c + w0,
|
w_c + w0,
|
||||||
|
|||||||
Reference in New Issue
Block a user