added easy uniform_2d method

This new function is made to lighten syntax when
This commit is contained in:
Benoît Sierro
2023-06-28 15:34:34 +02:00
parent 514dc4d99f
commit ce37a4f292

View File

@@ -9,7 +9,7 @@ from matplotlib.axes import Axes
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from matplotlib.transforms import offset_copy
from scipy.interpolate import UnivariateSpline
from scipy.interpolate import UnivariateSpline, interp1d
from scgenerator import math
from scgenerator.const import PARAM_SEPARATOR
@@ -63,9 +63,7 @@ def plot_setup(
# ensure no overwrite
ind = 0
while (
full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))
).exists():
while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists():
ind += 1
if mode == "default":
@@ -215,9 +213,7 @@ def create_zoom_axis(
return inset
def corner_annotation(
text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs
):
def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs):
"""puts an annotatin in a corner of an ax
Parameters
----------
@@ -487,6 +483,29 @@ def transform_2D_propagation(
return x_axis[::skip], y_axis, values[:, ::skip]
def uniform_2d(
old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray
) -> np.ndarray:
"""
interpolates a 2d array according to the provides old and new axis
Parameters
----------
old_x : np.ndarray, shape (n,)
old_y : np.ndarray, shape (m,)
new_x : np.ndarray, shape (N,)
new_y : np.ndarray, shape (M,)
values : np.ndarray, shape (m, n)
Returns
-------
np.ndarray, shape (M, N)
"""
values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x)
values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y)
return values
def get_x_axis(plt_range, x_axis, params) -> np.ndarray:
if x_axis is None and params is not None:
if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}:
@@ -514,9 +533,7 @@ def mean_values_plot(
mean_style: dict[str, Any] = None,
individual_style: dict[str, Any] = None,
) -> tuple[plt.Line2D, list[plt.Line2D]]:
x_axis, mean_values, values = transform_mean_values(
values, plt_range, params, log, spacing
)
x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing)
if renormalize and log is False:
maxi = mean_values.max()
mean_values = mean_values / maxi
@@ -585,9 +602,7 @@ def transform_mean_values(
if isinstance(spacing, (float, np.floating)):
tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing))
values = np.array(
[UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values]
)
values = np.array([UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values])
new_axis = tmp_axis
elif isinstance(spacing, (int, np.integer)) and spacing > 1:
values = values[:, ::spacing]
@@ -644,9 +659,7 @@ def plot_mean(
transpose : bool, optional
rotate the plot 90° counterclockwise, by default False
"""
individual_style = (
defaults["muted_style"] if individual_style is None else individual_style
)
individual_style = defaults["muted_style"] if individual_style is None else individual_style
mean_style = defaults["highlighted_style"] if mean_style is None else mean_style
labels = defaults["avg_line_labels"] if line_labels is None else line_labels
lines = []
@@ -686,9 +699,7 @@ def single_position_plot(
y_label: str = None,
**line_kwargs,
) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
x_axis, values = transform_1D_values(
values, plt_range, x_axis, params, log, spacing
)
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
if renormalize:
values = values / values.max()
@@ -977,9 +988,7 @@ def uniform_axis(
values = values[:, ind]
else:
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
values[:, ind] = np.apply_along_axis(
units.to_WL, 1, values[:, ind], tmp_axis
)
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
values = linear_interp_2d(tmp_axis, values[:, ind], new_axis)
return new_axis, values.squeeze()
@@ -1036,18 +1045,14 @@ def prep_plot_axis(
return is_spectrum, plt_range
def white_bottom_cmap(
name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)
):
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
top = plt.get_cmap(name, 1024)
n_bottom = 80
bottom = np.ones((n_bottom, 4))
for i in range(4):
bottom[:, i] = np.linspace(c_back[i], top(start)[i], n_bottom)
return ListedColormap(
np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name
)
return ListedColormap(np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name)
def default_marker_style(k):
@@ -1074,9 +1079,7 @@ def default_marker_style(k):
def arrowstyle(direction=1, color="white"):
return dict(
arrowprops=dict(
arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color
),
arrowprops=dict(arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color),
color=color,
backgroundcolor=(0.5, 0.5, 0.5, 0.5),
)
@@ -1121,9 +1124,7 @@ def measure_and_annotate_fwhm(
_, (left, right), *_ = pulse.find_lobe_limits(unit.inv(t), field)
arrow_label = f"{right - left:.1f} {unit.name}"
annotate_fwhm(
ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props
)
annotate_fwhm(ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props)
return right - left
@@ -1143,9 +1144,7 @@ def annotate_fwhm(
if color:
arrow_dict |= dict(color=color)
annotate_kwargs |= dict(color=color)
text_kwargs = (
dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
)
text_kwargs = dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
if arrow_props is not None:
arrow_dict |= arrow_props
txt = {}