diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index a2514d8..a23b87f 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -186,6 +186,9 @@ class PlotRange: right: float = Parameter(type_checker(int, float)) unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit) + def __str__(self): + return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" + def beta2_coef(beta): fac = 1e27 diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 4ca3d76..322201d 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -580,10 +580,9 @@ def plot_results_1D( file_type: str = "pdf", file_name: str = None, ax: plt.Axes = None, - line_label: str = None, transpose: bool = False, **line_kwargs, -) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]: +) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: """ Parameters @@ -620,13 +619,15 @@ def plot_results_1D( special name to give to the plot. A name is automatically assigned anyway ax : matplotlib.axes._subplots.AxesSubplot object, optional axis on which to draw the plot - line_label : str, optional - label of the line transpose : bool, optional transpose the plot line_kwargs : to be passed to plt.plot returns - fig, ax : matplotlib objects containing the plots + fig, ax, line : matplotlib objects containing the plots + x_axis : np.ndarray + new x axis array + ind : np.ndarray + corresponding indices on the old axis example: if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such: >>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017") @@ -674,19 +675,21 @@ def plot_results_1D( is_new_plot = ax is None folder_name = "" + if file_name is None: + file_name = params.name + str(plt_range) if is_new_plot: out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) else: fig = ax.get_figure() if transpose: - ax.plot(values, x_axis, label=line_label, **line_kwargs) + (line,) = ax.plot(values, x_axis, **line_kwargs) ax.yaxis.tick_right() ax.yaxis.set_label_position("right") ax.set_xlim(vmax, vmin) ax.set_xlabel(ylabel) ax.set_ylabel(plt_range.unit.label) else: - ax.plot(x_axis, values, label=line_label, **line_kwargs) + (line,) = ax.plot(x_axis, values, **line_kwargs) ax.set_ylim(vmin, vmax) ax.set_ylabel(ylabel) ax.set_xlabel(plt_range.unit.label) @@ -694,7 +697,7 @@ def plot_results_1D( if is_new_plot: fig.savefig(out_path, bbox_inches="tight", dpi=200) print(f"plot saved in {out_path}") - return fig, ax, x_axis, values + return fig, ax, line, x_axis, values def _prep_plot( diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 1f7d45a..64a59a0 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -7,6 +7,8 @@ from matplotlib.pyplot import subplot from dataclasses import replace import numpy as np +from numpy.lib import utils +from numpy.lib.arraysetops import isin from tqdm.std import Bar from . import initialize, io, math @@ -14,7 +16,7 @@ from .physics import units, pulse from .const import SPECN_FN from .logger import get_logger from .plotting import plot_avg, plot_results_1D, plot_results_2D -from .utils.parameter import BareParams +from .utils.parameter import BareParams, validator_and class Spectrum(np.ndarray): @@ -184,7 +186,7 @@ class Pulse(Sequence): return self.nmax def __getitem__(self, key) -> Spectrum: - return self.all_spectra(ind=range(self.nmax)[key]).squeeze() + return self.all_spectra(key) def intensity(self, unit): if unit.type in ["WL", "FREQ", "AFREQ"]: @@ -253,7 +255,7 @@ class Pulse(Sequence): def _to_time_amp(self, spectrum): return np.fft.ifft(spectrum) - def all_spectra(self, ind=None) -> Spectrum: + def all_spectra(self, ind) -> Spectrum: """ loads the data already simulated. defauft shape is (z_targets, n, nt) @@ -280,6 +282,10 @@ class Pulse(Sequence): ind = self.default_ind if isinstance(ind, (int, np.integer)): ind = [ind] + elif isinstance(ind, (float, np.floating)): + ind = [self.z_ind(ind)] + elif isinstance(ind[0], (float, np.floating)): + ind = [self.z_ind(ii) for ii in ind] # Load the spectra spectra = [] @@ -312,11 +318,11 @@ class Pulse(Sequence): left: float, right: float, unit: Union[Callable[[float], float], str], - z_ind: Union[int, Iterable[int]] = None, + z_pos: Union[int, Iterable[int]] = None, sim_ind: int = 0, **kwargs, ): - plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind) return plot_results_2D(vals, plt_range, self.params, **kwargs) def plot_1D( @@ -324,28 +330,47 @@ class Pulse(Sequence): left: float, right: float, unit: Union[Callable[[float], float], str], - z_ind: int, + z_pos: int, sim_ind: int = 0, **kwargs, ): - plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) - return plot_results_1D(vals[0], plt_range, self.params, **kwargs) + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind) + return plot_results_1D(vals, plt_range, self.params, **kwargs) def plot_avg( self, left: float, right: float, unit: Union[Callable[[float], float], str], - z_ind: int, + z_pos: int, **kwargs, ): - plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None)) + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None)) return plot_avg(vals, plt_range, self.params, **kwargs) - def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind): + def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind): plt_range = units.PlotRange(left, right, unit) if plt_range.unit.type == "TIME": - vals = self.all_fields(ind=z_ind)[:, sim_ind] + vals = self.all_fields(ind=z_pos) else: - vals = self.all_spectra(ind=z_ind)[:, sim_ind] + vals = self.all_spectra(ind=z_pos) + if vals.ndim == 3: + vals = vals[:, sim_ind] + else: + vals = vals[sim_ind] return plt_range, vals + + def z_ind(self, z: float) -> int: + """return the closest z index to the given target + + Parameters + ---------- + z : float + target + + Returns + ------- + int + index + """ + return math.argclosest(self.z, z)