diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index c42e7e2..4b2dc1f 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -9,7 +9,7 @@ from matplotlib.axes import Axes from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from matplotlib.transforms import offset_copy -from scipy.interpolate import UnivariateSpline +from scipy.interpolate import UnivariateSpline, interp1d from scgenerator import math from scgenerator.const import PARAM_SEPARATOR @@ -63,9 +63,7 @@ def plot_setup( # ensure no overwrite ind = 0 - while ( - full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type)) - ).exists(): + while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists(): ind += 1 if mode == "default": @@ -215,9 +213,7 @@ def create_zoom_axis( return inset -def corner_annotation( - text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs -): +def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs): """puts an annotatin in a corner of an ax Parameters ---------- @@ -487,6 +483,29 @@ def transform_2D_propagation( return x_axis[::skip], y_axis, values[:, ::skip] +def uniform_2d( + old_x: np.ndarray, old_y: np.ndarray, new_x: np.ndarray, new_y: np.ndarray, values: np.ndarray +) -> np.ndarray: + """ + interpolates a 2d array according to the provides old and new axis + + Parameters + ---------- + old_x : np.ndarray, shape (n,) + old_y : np.ndarray, shape (m,) + new_x : np.ndarray, shape (N,) + new_y : np.ndarray, shape (M,) + values : np.ndarray, shape (m, n) + + Returns + ------- + np.ndarray, shape (M, N) + """ + values = interp1d(old_x, values, fill_value=0, bounds_error=False, axis=1)(new_x) + values = interp1d(old_y, values, fill_value=0, bounds_error=False, axis=0)(new_y) + return values + + def get_x_axis(plt_range, x_axis, params) -> np.ndarray: if x_axis is None and params is not None: if plt_range.unit.type in {"WL", "FREQ", "AFREQ"}: @@ -514,9 +533,7 @@ def mean_values_plot( mean_style: dict[str, Any] = None, individual_style: dict[str, Any] = None, ) -> tuple[plt.Line2D, list[plt.Line2D]]: - x_axis, mean_values, values = transform_mean_values( - values, plt_range, params, log, spacing - ) + x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing) if renormalize and log is False: maxi = mean_values.max() mean_values = mean_values / maxi @@ -585,9 +602,7 @@ def transform_mean_values( if isinstance(spacing, (float, np.floating)): tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing)) - values = np.array( - [UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values] - ) + values = np.array([UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values]) new_axis = tmp_axis elif isinstance(spacing, (int, np.integer)) and spacing > 1: values = values[:, ::spacing] @@ -644,9 +659,7 @@ def plot_mean( transpose : bool, optional rotate the plot 90° counterclockwise, by default False """ - individual_style = ( - defaults["muted_style"] if individual_style is None else individual_style - ) + individual_style = defaults["muted_style"] if individual_style is None else individual_style mean_style = defaults["highlighted_style"] if mean_style is None else mean_style labels = defaults["avg_line_labels"] if line_labels is None else line_labels lines = [] @@ -686,9 +699,7 @@ def single_position_plot( y_label: str = None, **line_kwargs, ) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]: - x_axis, values = transform_1D_values( - values, plt_range, x_axis, params, log, spacing - ) + x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing) if renormalize: values = values / values.max() @@ -977,9 +988,7 @@ def uniform_axis( values = values[:, ind] else: if plt_range.unit.type == "WL" and plt_range.conserved_quantity: - values[:, ind] = np.apply_along_axis( - units.to_WL, 1, values[:, ind], tmp_axis - ) + values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) values = linear_interp_2d(tmp_axis, values[:, ind], new_axis) return new_axis, values.squeeze() @@ -1036,18 +1045,14 @@ def prep_plot_axis( return is_spectrum, plt_range -def white_bottom_cmap( - name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1) -): +def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)): """returns a new colormap based on "name" but that has a solid bacground (default=white)""" top = plt.get_cmap(name, 1024) n_bottom = 80 bottom = np.ones((n_bottom, 4)) for i in range(4): bottom[:, i] = np.linspace(c_back[i], top(start)[i], n_bottom) - return ListedColormap( - np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name - ) + return ListedColormap(np.vstack((bottom, top(np.linspace(start, end, 1024)))), name=new_name) def default_marker_style(k): @@ -1074,9 +1079,7 @@ def default_marker_style(k): def arrowstyle(direction=1, color="white"): return dict( - arrowprops=dict( - arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color - ), + arrowprops=dict(arrowstyle="->", connectionstyle=f"arc3,rad={direction*0.3}", color=color), color=color, backgroundcolor=(0.5, 0.5, 0.5, 0.5), ) @@ -1121,9 +1124,7 @@ def measure_and_annotate_fwhm( _, (left, right), *_ = pulse.find_lobe_limits(unit.inv(t), field) arrow_label = f"{right - left:.1f} {unit.name}" - annotate_fwhm( - ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props - ) + annotate_fwhm(ax, left, right, arrow_label, field.max(), side, arrow_length_pts, arrow_props) return right - left @@ -1143,9 +1144,7 @@ def annotate_fwhm( if color: arrow_dict |= dict(color=color) annotate_kwargs |= dict(color=color) - text_kwargs = ( - dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs - ) + text_kwargs = dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs if arrow_props is not None: arrow_dict |= arrow_props txt = {}