added easy uniform_2d method
This new function is made to lighten syntax when
This commit is contained in:
@@ -9,7 +9,7 @@ from matplotlib.axes import Axes
|
|||||||
from matplotlib.colors import ListedColormap
|
from matplotlib.colors import ListedColormap
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from matplotlib.transforms import offset_copy
|
from matplotlib.transforms import offset_copy
|
||||||
from scipy.interpolate import UnivariateSpline
|
from scipy.interpolate import UnivariateSpline, interp1d
|
||||||
|
|
||||||
from scgenerator import math
|
from scgenerator import math
|
||||||
from scgenerator.const import PARAM_SEPARATOR
|
from scgenerator.const import PARAM_SEPARATOR
|
||||||
@@ -63,9 +63,7 @@ def plot_setup(
|
|||||||
|
|
||||||
# ensure no overwrite
|
# ensure no overwrite
|
||||||
ind = 0
|
ind = 0
|
||||||
while (
|
while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists():
|
||||||
full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))
|
|
||||||
).exists():
|
|
||||||
ind += 1
|
ind += 1
|
||||||
|
|
||||||
if mode == "default":
|
if mode == "default":
|
||||||
@@ -215,9 +213,7 @@ def create_zoom_axis(
|
|||||||
return inset
|
return inset
|
||||||
|
|
||||||
|
|
||||||
def corner_annotation(
|
def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs):
|
||||||
text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs
|
|
||||||
):
|
|
||||||
"""puts an annotatin in a corner of an ax
|
"""puts an annotatin in a corner of an ax
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -487,6 +483,29 @@ def transform_2D_propagation(
|
|||||||
return x_axis[::skip], y_axis, values[:, ::skip]
|
return x_axis[::skip], y_axis, values[:, ::skip]
|
||||||
|
|
||||||
|
|
||||||
|
def uniform_2d(
|
||||||
|
old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
interpolates a 2d array according to the provides old and new axis
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
old_x : np.ndarray, shape (n,)
|
||||||
|
old_y : np.ndarray, shape (m,)
|
||||||
|
new_x : np.ndarray, shape (N,)
|
||||||
|
new_y : np.ndarray, shape (M,)
|
||||||
|
values : np.ndarray, shape (m, n)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray, shape (M, N)
|
||||||
|
"""
|
||||||
|
values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x)
|
||||||
|
values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
|
||||||
if x_axis is None and params is not None:
|
if x_axis is None and params is not None:
|
||||||
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
|
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
|
||||||
@@ -514,9 +533,7 @@ def mean_values_plot(
|
|||||||
mean_style: dict[str, Any] = None,
|
mean_style: dict[str, Any] = None,
|
||||||
individual_style: dict[str, Any] = None,
|
individual_style: dict[str, Any] = None,
|
||||||
) -> tuple[plt.Line2D, list[plt.Line2D]]:
|
) -> tuple[plt.Line2D, list[plt.Line2D]]:
|
||||||
x_axis, mean_values, values = transform_mean_values(
|
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing)
|
||||||
values, plt_range, params, log, spacing
|
|
||||||
)
|
|
||||||
if renormalize and log is False:
|
if renormalize and log is False:
|
||||||
maxi = mean_values.max()
|
maxi = mean_values.max()
|
||||||
mean_values = mean_values / maxi
|
mean_values = mean_values / maxi
|
||||||
@@ -585,9 +602,7 @@ def transform_mean_values(
|
|||||||
|
|
||||||
if isinstance(spacing, (float, np.floating)):
|
if isinstance(spacing, (float, np.floating)):
|
||||||
tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing))
|
tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing))
|
||||||
values = np.array(
|
values = np.array([UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values])
|
||||||
[UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values]
|
|
||||||
)
|
|
||||||
new_axis = tmp_axis
|
new_axis = tmp_axis
|
||||||
elif isinstance(spacing, (int, np.integer)) and spacing > 1:
|
elif isinstance(spacing, (int, np.integer)) and spacing > 1:
|
||||||
values = values[:, ::spacing]
|
values = values[:, ::spacing]
|
||||||
@@ -644,9 +659,7 @@ def plot_mean(
|
|||||||
transpose : bool, optional
|
transpose : bool, optional
|
||||||
rotate the plot 90° counterclockwise, by default False
|
rotate the plot 90° counterclockwise, by default False
|
||||||
"""
|
"""
|
||||||
individual_style = (
|
individual_style = defaults["muted_style"] if individual_style is None else individual_style
|
||||||
defaults["muted_style"] if individual_style is None else individual_style
|
|
||||||
)
|
|
||||||
mean_style = defaults["highlighted_style"] if mean_style is None else mean_style
|
mean_style = defaults["highlighted_style"] if mean_style is None else mean_style
|
||||||
labels = defaults["avg_line_labels"] if line_labels is None else line_labels
|
labels = defaults["avg_line_labels"] if line_labels is None else line_labels
|
||||||
lines = []
|
lines = []
|
||||||
@@ -686,9 +699,7 @@ def single_position_plot(
|
|||||||
y_label: str = None,
|
y_label: str = None,
|
||||||
**line_kwargs,
|
**line_kwargs,
|
||||||
) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
||||||
x_axis, values = transform_1D_values(
|
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
|
||||||
values, plt_range, x_axis, params, log, spacing
|
|
||||||
)
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
values = values / values.max()
|
values = values / values.max()
|
||||||
|
|
||||||
@@ -977,9 +988,7 @@ def uniform_axis(
|
|||||||
values = values[:, ind]
|
values = values[:, ind]
|
||||||
else:
|
else:
|
||||||
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
|
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
|
||||||
values[:, ind] = np.apply_along_axis(
|
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
|
||||||
units.to_WL, 1, values[:, ind], tmp_axis
|
|
||||||
)
|
|
||||||
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
||||||
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
|
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
|
||||||
return new_axis, values.squeeze()
|
return new_axis, values.squeeze()
|
||||||
@@ -1036,18 +1045,14 @@ def prep_plot_axis(
|
|||||||
return is_spectrum, plt_range
|
return is_spectrum, plt_range
|
||||||
|
|
||||||
|
|
||||||
def white_bottom_cmap(
|
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
|
||||||
name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)
|
|
||||||
):
|
|
||||||
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
|
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
|
||||||
top = plt.get_cmap(name, 1024)
|
top = plt.get_cmap(name, 1024)
|
||||||
n_bottom = 80
|
n_bottom = 80
|
||||||
bottom = np.ones((n_bottom, 4))
|
bottom = np.ones((n_bottom, 4))
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
bottom[:, i] = np.linspace(c_back[i], top(start)[i], n_bottom)
|
bottom[:, i] = np.linspace(c_back[i], top(start)[i], n_bottom)
|
||||||
return ListedColormap(
|
return ListedColormap(np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name)
|
||||||
np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_marker_style(k):
|
def default_marker_style(k):
|
||||||
@@ -1074,9 +1079,7 @@ def default_marker_style(k):
|
|||||||
|
|
||||||
def arrowstyle(direction=1, color="white"):
|
def arrowstyle(direction=1, color="white"):
|
||||||
return dict(
|
return dict(
|
||||||
arrowprops=dict(
|
arrowprops=dict(arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color),
|
||||||
arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color
|
|
||||||
),
|
|
||||||
color=color,
|
color=color,
|
||||||
backgroundcolor=(0.5, 0.5, 0.5, 0.5),
|
backgroundcolor=(0.5, 0.5, 0.5, 0.5),
|
||||||
)
|
)
|
||||||
@@ -1121,9 +1124,7 @@ def measure_and_annotate_fwhm(
|
|||||||
_, (left, right), *_ = pulse.find_lobe_limits(unit.inv(t), field)
|
_, (left, right), *_ = pulse.find_lobe_limits(unit.inv(t), field)
|
||||||
arrow_label = f"{right - left:.1f} {unit.name}"
|
arrow_label = f"{right - left:.1f} {unit.name}"
|
||||||
|
|
||||||
annotate_fwhm(
|
annotate_fwhm(ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props)
|
||||||
ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props
|
|
||||||
)
|
|
||||||
return right - left
|
return right - left
|
||||||
|
|
||||||
|
|
||||||
@@ -1143,9 +1144,7 @@ def annotate_fwhm(
|
|||||||
if color:
|
if color:
|
||||||
arrow_dict |= dict(color=color)
|
arrow_dict |= dict(color=color)
|
||||||
annotate_kwargs |= dict(color=color)
|
annotate_kwargs |= dict(color=color)
|
||||||
text_kwargs = (
|
text_kwargs = dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
|
||||||
dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
|
|
||||||
)
|
|
||||||
if arrow_props is not None:
|
if arrow_props is not None:
|
||||||
arrow_dict |= arrow_props
|
arrow_dict |= arrow_props
|
||||||
txt = {}
|
txt = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user