diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 097aeb5..4a9a18d 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -9,7 +9,7 @@ from .io import Paths, load_toml, load_params from .math import abs2, argclosest, span from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, new_simulation, resume_simulations -from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram +from .plotting import mean_values_plot, single_position_plot, propagation_plot, plot_spectrogram from .spectra import Pulse, Spectrum from .utils.parameter import BareParams, BareConfig from . import utils, io, initialize, math diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 8924af1..6ddb06f 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -14,7 +14,7 @@ from .errors import * from .logger import get_logger from .math import power_fact from .physics import fiber, pulse, units -from .utils import count_variations, override_config, pretty_format_value, required_simulations +from .utils import override_config, required_simulations from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters @@ -320,8 +320,10 @@ class ParamSequence: config_dict : Union[Dict[str, Any], os.PathLike, BareConfig] Can be either a dictionary, a path to a config toml file or BareConfig obj """ - if isinstance(config_dict, BareConfig): + if isinstance(config_dict, Config): self.config = config_dict + elif isinstance(config_dict, BareConfig): + self.config = Config.from_bare(config_dict) else: if not isinstance(config_dict, Mapping): config_dict = io.load_toml(config_dict) @@ -329,7 +331,7 @@ class ParamSequence: self.name = self.config.name self.logger = get_logger(__name__) - self.update_num_sim(count_variations(self.config)) + self.update_num_sim() def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened @@ -343,14 +345,18 @@ class ParamSequence: def __repr__(self) -> str: return f"dispatcher generated from config {self.name}" - def update_num_sim(self, num_sim): + def update_num_sim(self): + num_sim = self.count_variations() self.num_sim = num_sim self.num_steps = self.num_sim * self.config.z_num self.single_sim = self.num_sim == 1 + def count_variations(self) -> int: + return count_variations(self.config) + class ContinuationParamSequence(ParamSequence): - def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]): + def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig): """Parameter sequence that builds on a previous simulation but with a new configuration It is recommended that only the fiber and the number of points stored may be changed and changing other parameters could results in unexpected behaviors. The new config doesn't have to @@ -364,30 +370,19 @@ class ContinuationParamSequence(ParamSequence): new config """ self.prev_sim_dir = Path(prev_sim_dir) - init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") - - new_variable_keys = set(new_config_dict.get("variable", {}).keys()) - new_config = utils.override_config(new_config_dict, init_config) - super().__init__(new_config) - additional_sims_factor = int( - np.prod( - [ - len(init_config.variable[k]) - for k in (new_variable_keys & init_config.variable.keys()) - ] - ) - ) - self.update_num_sim(self.num_sim * additional_sims_factor) + self.bare_configs = io.load_config_sequence(new_config.previous_config_file) + self.bare_configs.append(new_config) + self.bare_configs[0] = Config.from_bare(self.bare_configs[0]) + final_config = utils.final_config_from_sequence(*self.bare_configs) + super().__init__(final_config) def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" - for variable_list, bare_params in required_simulations(self.config): - variable_list.insert(1, ("prev_data_dir", None)) - for prev_data_dir in self.find_prev_data_dirs(variable_list): - variable_list[1] = ("prev_data_dir", str(prev_data_dir.name)) - bare_params.prev_data_dir = str(prev_data_dir.resolve()) - yield variable_list, Params.from_bare(bare_params) + for variable_list, bare_params in required_simulations(*self.bare_configs): + prev_data_dir = self.find_prev_data_dirs(variable_list)[0] + bare_params.prev_data_dir = str(prev_data_dir.resolve()) + yield variable_list, Params.from_bare(bare_params) def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]: """finds the previous simulation data that this new config should start from @@ -419,6 +414,17 @@ class ContinuationParamSequence(ParamSequence): return path_dic[max_in_common] + def count_variations(self) -> int: + return count_variations(*self.bare_configs) + + +def count_variations(*bare_configs: BareConfig) -> int: + sim_num = 1 + for conf in bare_configs: + for l in conf.variable.values(): + sim_num *= len(l) + return sim_num * (bare_configs[0].repeat or 1) + class RecoveryParamSequence(ParamSequence): def __init__(self, config_dict, task_id): @@ -506,7 +512,7 @@ class RecoveryParamSequence(ParamSequence): return path_dic[max_in_common] -def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]: +def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]: """validates a sequence of configs where all but the first one may have parameters missing @@ -517,19 +523,17 @@ def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]: Returns ------- - Dict[str, Any] - the final config as would be simulated, but of course missing input fields in the middle + int + total number of simulations """ + previous = None - variables = set() for config in configs: if (p := Path(config)).is_dir(): config = p / "initial_config.toml" - dico = io.load_toml(config) - previous = Config.from_bare(override_config(dico, previous)) - variables |= {(k, tuple(v)) for k, v in previous.variable.items()} - variables.add(("repeat", range(previous.repeat))) - return previous, int(np.product([len(v) for _, v in variables if len(v) > 0])) + new_conf = io.load_config(config) + previous = Config.from_bare(override_config(new_conf, previous)) + return previous.name, count_variations(*configs) def wspace(t, t_num=0): diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 83f4b71..33561cc 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -162,6 +162,36 @@ def load_config(path: os.PathLike) -> BareConfig: return BareConfig(**config) +def load_config_sequence(*config_paths: os.PathLike) -> list[BareConfig]: + """Loads a sequence of + + Parameters + ---------- + config_paths : os.PathLike + either one path (the last config containing previous_config_file parameter) + or a list of config path in the order they have to be simulated + + Returns + ------- + list[BareConfig] + all loaded configs + """ + if config_paths[0] is None: + return [] + all_configs = [load_config(config_paths[0])] + if len(config_paths) == 1: + while True: + if all_configs[0].previous_config_file is not None: + all_configs.insert(0, load_config(all_configs[0].previous_config_file)) + else: + break + else: + for i, path in enumerate(config_paths[1:]): + all_configs.append(load_config(path)) + all_configs[i + 1].previous_config_file = config_paths[i] + return all_configs + + def load_material_dico(name: str) -> dict[str, Any]: """loads a material dictionary Parameters diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 2ac98bd..04f7eff 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -679,18 +679,11 @@ def run_simulation_sequence( method=None, prev_sim_dir: os.PathLike = None, ): - config_files = list(config_files) - if len(config_files) == 1: - while True: - conf = io.load_toml(config_files[0]) - if (prev := conf.get("previous_config_file")) is not None: - config_files.insert(0, prev) - else: - break + configs = io.load_config_sequence(*config_files) prev = prev_sim_dir - for config_file in config_files: - sim = new_simulation(config_file, prev, method) + for config in configs: + sim = new_simulation(config, prev, method) sim.run() prev = sim.sim_dir path_trees = io.build_path_trees(sim.sim_dir) @@ -703,25 +696,21 @@ def run_simulation_sequence( def new_simulation( - config: Union[dict, os.PathLike], + config: utils.BareConfig, prev_sim_dir=None, method: Type[Simulations] = None, ) -> Simulations: - if isinstance(config, dict): - config_dict = config - else: - config_dict = io.load_toml(config) logger = get_logger(__name__) if prev_sim_dir is not None: - config_dict["prev_sim_dir"] = str(prev_sim_dir) + config.prev_sim_dir = str(prev_sim_dir) task_id = random.randint(1e9, 1e12) if prev_sim_dir is None: - param_seq = initialize.ParamSequence(config_dict) + param_seq = initialize.ParamSequence(config) else: - param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict) + param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config) logger.info(f"running {param_seq.name}") diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index a23b87f..6e040e2 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -159,6 +159,11 @@ def D_ps_nm_km(D: _T) -> _T: return 1e-6 * D +@unit("OTHER", r"a.u.") +def unity(x: _T) -> _T: + return x + + @unit("TEMPERATURE", r"Temperature (K)") def K(t: _T) -> _T: return t @@ -229,7 +234,7 @@ def standardize_dictionary(dico): return dico -def sort_axis(axis, plt_range: PlotRange): +def sort_axis(axis, plt_range: PlotRange) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: """ given an axis, returns this axis cropped according to the given range, converted and sorted diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 322201d..699cde9 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1,12 +1,17 @@ import os from pathlib import Path +import re from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union +from PIL.Image import new import matplotlib.gridspec as gs import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import ListedColormap +from numpy.core.fromnumeric import mean from scipy.interpolate import UnivariateSpline +from scipy.interpolate.interpolate import interp1d +from tqdm import utils from .logger import get_logger @@ -14,9 +19,32 @@ from . import io, math from .defaults import default_plotting as defaults from .math import abs2, make_uniform_1D, span from .physics import pulse, units -from .utils.parameter import BareParams +from .utils.parameter import BareConfig, BareParams RangeType = Tuple[float, float, Union[str, Callable]] +NO_LIM = object() + + +def get_extent(x, y, facx=1, facy=1): + """ + returns the extent 4-tuple needed for imshow, aligning each pixel + center to the corresponding value assuming uniformly spaced axes + multiplying values by a constant factor is optional + """ + try: + dx = (x[1] - x[0]) / 2 + except IndexError: + dx = 1 + try: + dy = (y[1] - y[0]) / 2 + except IndexError: + dy = 1 + return ( + (np.min(x) - dx) * facx, + (np.max(x) + dx) * facx, + (np.min(y) - dy) * facy, + (np.max(y) + dy) * facy, + ) def plot_setup( @@ -25,12 +53,6 @@ def plot_setup( figsize: Tuple[float, float] = defaults["figsize"], mode: Literal["default", "coherence", "coherence_T"] = "default", ) -> Tuple[Path, plt.Figure, Union[plt.Axes, Tuple[plt.Axes]]]: - """It should return : - - a folder_name - - a file name - - a fig - - an axis - """ out_path = defaults["name"] if out_path is None else out_path out_path = Path(out_path) plot_name = out_path.name.replace(f".{file_type}", "") @@ -236,89 +258,136 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0 return None -def _finish_plot_2D( - values, - x_axis, - x_label, - y_axis, - y_label, - log, - vmin, - vmax, - transpose, - cmap, - cbar_label, - ax, - file_name, - file_type, -): - logger = get_logger(__name__) - # apply log transform if required +def propagation_plot( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + ax: plt.Axes, + log: Union[int, float, bool, str] = "1D", + vmin: float = None, + vmax: float = None, + transpose: bool = False, + cbar_label: Optional[str] = "normalized intensity (dB)", + cmap: str = None, +) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: + """transforms and plots a 2D propagation + + Parameters + ---------- + values : np.ndarray + raw values, either complex fields or complex spectra + plt_range : Union[units.PlotRange, RangeType] + time, wavelength or frequency range + params : BareParams + parameters of the simulation + log : Union[int, float, bool, str], optional + what kind of log to apply, see apply_log for details. by default "1D" + vmin : float, optional + minimum value, by default None + vmax : float, optional + maximum value, by default None + transpose : bool, optional + whether to transpose the plot (rotate the plot 90° counterclockwise), by default False + cbar_label : Optional[str], optional + label of the colorbar. No colorbar is drawn if this is set to None, by default "normalized intensity (dB)" + cmap : str, optional + colormap, by default None + ax : plt.Axes, optional + Axes obj on which to draw, by default None + + """ + x_axis, y_axis, values = transform_2D_propagation(values, plt_range, params, log) if log is not False: vmax = defaults["vmax"] if vmax is None else vmax vmin = defaults["vmin"] if vmin is None else vmin - if isinstance(log, (float, int)) and log is not True: - values = units.to_log(values, ref=log) + plot_2D( + values, + x_axis, + y_axis, + ax, + plt_range.unit.label, + "propagation distance (m)", + vmin, + vmax, + transpose, + cmap, + cbar_label, + ) - elif log == "2D": - values = units.to_log2D(values) - elif log == "1D": - values = np.apply_along_axis(units.to_log, 1, values) +def plot_2D( + values: np.ndarray, + x_axis: np.ndarray, + y_axis: np.ndarray, + ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]], + x_label: str = None, + y_label: str = None, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + cmap: str = None, + cbar_label: str = "", +) -> Union[tuple[plt.Axes, plt.Axes], plt.Axes]: + """plots given 2D values in a standard - elif log == "smooth 1D": - ref = np.max(values, axis=1) - ind = np.argmax((ref[:-1] - ref[1:]) < 0) - values = units.to_log(values, ref=np.max(ref[ind:])) + Parameters + ---------- + values : np.ndarray, shape (m, n) + real values to plot + x_axis : np.ndarray, shape (n,) + x axis + y_axis : np.ndarray, shape (m,) + y axis + ax : Union[plt.Axes, tuple[plt.Axes, plt.Axes]] + the ax on which to draw, or a tuple (ax, cbar_ax) where cbar_ax is the ax for the color bar + x_label : str, optional + x label + y_label : str, optional + y label + vmin : float, optional + minimum value (values below are the same color as the minimum of the colormap) + vmax : float, optional + maximum value (values above are the same color as the maximum of the colormap) + transpose : bool, optional + whether to rotate the plot 90° counterclockwise + cmap : str, optional + color map name + cbar_label : str, optional + label of the color bar axes. No color bar is drawn if cbar_label = None - elif log == "unique 1D": - try: - ref = _finish_plot_2D.ref - logger.info(f"recovered reference value {ref} for log plot") - except AttributeError: - ref = np.max(values, axis=1) - ind = np.argmax((ref[:-1] - ref[1:]) < 0) - ref = np.max(ref[ind:]) - _finish_plot_2D.ref = ref - - values = units.to_log(values, ref=ref) + Returns + ------- + Union[tuple[plt.Axes, plt.Axes], plt.Axes] + ax if no color bar is drawn, a tuple (ax, cbar_ax) otherwise + """ + # apply log transform if required cmap = defaults["cmap"] if cmap is None else cmap - is_new_plot = ax is None cbar_ax = None if isinstance(ax, tuple) and len(ax) > 1: ax, cbar_ax = ax[0], ax[1] - if is_new_plot: - out_path, fig, ax = plot_setup(out_path=Path(file_name), file_type=file_type) - else: - fig = ax.get_figure() + fig = ax.get_figure() # Determine grid extent and spacing to be able to center # each pixel since by default imshow draws values at the lower-left corner if transpose: - dy = x_axis[1] - x_axis[0] - ext_y = span(x_axis) - dx = y_axis[1] - y_axis[0] - ext_x = span(y_axis) + extent = get_extent(y_axis, x_axis) values = values.T ax.set_xlabel(y_label) ax.set_ylabel(x_label) else: - dx = x_axis[1] - x_axis[0] - ext_x = span(x_axis) - dy = y_axis[1] - y_axis[0] - ext_y = span(y_axis) + extent = get_extent(x_axis, y_axis) ax.set_ylabel(y_label) ax.set_xlabel(x_label) - ax.set_xlim(*ext_x) - ax.set_ylim(*ext_y) + ax.set_xlim(*extent[:2]) + ax.set_ylim(*extent[2:]) interpolation = defaults["interpolation_2D"] im = ax.imshow( values, - extent=[ext_x[0] - dx / 2, ext_x[1] + dx / 2, ext_y[0] - dy / 2, ext_y[1] + dy / 2], + extent=extent, cmap=cmap, vmin=vmin, vmax=vmax, @@ -335,13 +404,328 @@ def _finish_plot_2D( cbar = fig.colorbar(im, cax=cbar_ax, orientation="vertical") cbar.ax.set_ylabel(cbar_label) - if is_new_plot: - fig.savefig(out_path, bbox_inches="tight", dpi=200) - logger.info(f"plot saved in {out_path}") if cbar_label is not None: - return fig, (ax, cbar.ax) + return ax, cbar.ax else: - return fig, ax + return ax + + +def transform_2D_propagation( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + log: Union[int, float, bool, str] = "1D", +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + + if values.ndim != 2: + raise ValueError(f"shape was {values.shape}. Can only plot 2D array") + is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + if is_complex: + values = abs2(values) + + y_axis = params.z_targets + + x_axis, values = uniform_axis(x_axis, values, plt_range) + y_axis, values.T[:] = uniform_axis(y_axis, values.T, None) + values = apply_log(values, log) + return x_axis, y_axis, values + + +def mean_values_plot( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + ax: plt.Axes, + log: Union[float, int, str, bool] = False, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + spacing: Union[float, int] = 1, + renormalize: bool = True, + y_label: str = None, + line_labels: Tuple[str, str] = None, + mean_style: dict[str, Any] = None, + individual_style: dict[str, Any] = None, +) -> tuple[plt.Line2D, list[plt.Line2D]]: + + x_axis, mean_values, values = transform_mean_values(values, plt_range, params, log, spacing) + if renormalize and log is False: + maxi = mean_values.max() + mean_values = mean_values / maxi + values = values / maxi + + if log is not False: + vmax = defaults["vmax_with_headroom"] if vmax is None else vmax + vmin = defaults["vmin"] if vmin is None else vmin + return plot_mean( + values, + mean_values, + x_axis, + ax, + plt_range.unit.label, + y_label, + line_labels, + vmin, + vmax, + transpose, + mean_style, + individual_style, + ) + + +def transform_mean_values( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + log: Union[bool, int, float] = False, + spacing: Union[int, float] = 1, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """transforms values similar to transform_1D_values but with a collection of lines, giving also the mean + + Parameters + ---------- + values : np.ndarray, shape (m, n) + values to transform + plt_range : Union[units.PlotRange, RangeType] + x axis specifications + params : BareParams + parameters of the simulation + log : Union[bool, int, float], optional + see transform_1D_values for details, by default False + spacing : Union[int, float], optional + see transform_1D_values for details, by default 1 + + Returns + ------- + np.ndarray, shape (n,) + x axis + np.ndarray, shape (n,) + mean y values + np.ndarray, shape (m, n) + all the values + """ + if values.ndim != 2: + print(f"Shape was {values.shape}. Can only plot 2D arrays") + return + is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + if is_complex: + values = abs2(values) + new_axis, ind, ext = units.sort_axis(x_axis, plt_range) + values = values[:, ind] + if plt_range.unit.type == "WL": + values = np.apply_along_axis(units.to_WL, -1, values, new_axis) + + if isinstance(spacing, (float, np.floating)): + tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing)) + values = np.array([UnivariateSpline(new_axis, v, k=4, s=0)(tmp_axis) for v in values]) + new_axis = tmp_axis + elif isinstance(spacing, (int, np.integer)) and spacing > 1: + values = values[:, ::spacing] + new_axis = new_axis[::spacing] + + mean_values = np.mean(values, axis=0) + + if log is not False: + if log is not True and isinstance(log, (int, float, np.integer, np.floating)): + ref = log + else: + ref = mean_values.max() + values = apply_log(values, ref) + mean_values = apply_log(mean_values, ref) + return new_axis, mean_values, values + + +def plot_mean( + values: np.ndarray, + mean_values: np.ndarray, + x_axis: np.ndarray, + ax: plt.Axes, + x_label: str = None, + y_label: str = None, + line_labels: tuple[str, str] = None, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + mean_style: dict[str, Any] = None, + individual_style: dict[str, Any] = None, +) -> tuple[plt.Line2D, list[plt.Line2D]]: + """plots already transformed 1D values + + Parameters + ---------- + values : np.ndarray, shape (m, n) + values to plot + mean_values : np.ndarray, shape (n,) + values to plot + x_axis : np.ndarray, shape (n,) + corresponding x axis + ax : plt.Axes + ax on which to plot + x_label : str, optional + x label, by default None + y_label : str, optional + y label, by default None + line_labels: tuple[str, str] + label of the mean line and the individual lines, by default None + vmin : float, optional + minimum y limit, by default None + vmax : float, optional + maximum y limit, by default None + transpose : bool, optional + rotate the plot 90° counterclockwise, by default False + """ + individual_style = defaults["muted_style"] if individual_style is None else individual_style + mean_style = defaults["highlighted_style"] if mean_style is None else mean_style + labels = defaults["avg_line_labels"] if line_labels is None else line_labels + lines = [] + if transpose: + for value in values[:-1]: + lines += ax.plot(value, x_axis, **individual_style) + lines += ax.plot(values[-1], x_axis, **individual_style) + (mean_line,) = ax.plot(mean_values, x_axis, **mean_style) + ax.set_xlim(vmax, vmin) + ax.yaxis.tick_right() + ax.set_xlabel(y_label) + ax.set_ylabel(x_label) + else: + for value in values[:-1]: + lines += ax.plot(x_axis, value, **individual_style) + lines += ax.plot(x_axis, values[-1], label=labels[0], **individual_style) + (mean_line,) = ax.plot(x_axis, mean_values, label=labels[1], **mean_style) + ax.set_ylim(vmin, vmax) + ax.set_ylabel(y_label) + ax.set_xlabel(x_label) + + return mean_line, lines + + +def single_position_plot( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + ax: plt.Axes, + log: Union[str, int, float, bool] = False, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + spacing: Union[int, float] = 1, + renormalize: bool = False, + y_label: str = None, + **line_kwargs, +) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: + + x_axis, values = transform_1D_values(values, plt_range, params, log, spacing) + if renormalize: + values = values / values.max() + + if log is not False: + vmax = defaults["vmax_with_headroom"] if vmax is None else vmax + vmin = defaults["vmin"] if vmin is None else vmin + + return plot_1D( + values, x_axis, ax, plt_range.unit.label, y_label, vmin, vmax, transpose, **line_kwargs + ) + + +def plot_1D( + values: np.ndarray, + x_axis: np.ndarray, + ax: plt.Axes, + x_label: str = None, + y_label: str = None, + vmin: float = None, + vmax: float = None, + transpose: bool = False, + **line_kwargs, +) -> plt.Line2D: + """plots already transformed 1D values + + Parameters + ---------- + values : np.ndarray, shape (n,) + values to plot + x_axis : np.ndarray, shape (n,) + corresponding x axis + ax : plt.Axes, + ax on which to plot + x_label : str, optional + x label + y_label : str, optional + y label + vmin : float, optional + minimum y limit, by default None + vmax : float, optional + maximum y limit, by default None + transpose : bool, optional + rotate the plot 90° counterclockwise, by default False + """ + if transpose: + (line,) = ax.plot(values, x_axis, **line_kwargs) + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + ax.set_xlim(vmax, vmin) + ax.set_xlabel(y_label) + ax.set_ylabel(x_label) + else: + (line,) = ax.plot(x_axis, values, **line_kwargs) + ax.set_ylim(vmin, vmax) + ax.set_ylabel(y_label) + ax.set_xlabel(x_label) + return line + + +def transform_1D_values( + values: np.ndarray, + plt_range: Union[units.PlotRange, RangeType], + params: BareParams, + log: Union[int, float, bool] = False, + spacing: Union[int, float] = 1, +) -> tuple[np.ndarray, np.ndarray]: + """transforms raw values to be plotted + + Parameters + ---------- + values : np.ndarray, shape (n,) + values to plot, may be complex + plt_range : Union[units.PlotRange, RangeType] + plot range specification, either (min, max, unit) or a PlotRange obj + params : BareParams + parameters of the simulations + log : Union[int, float, bool], optional + if True, will convert to dB relative to max. If a float or int, whill + convert to dB relative to that number, by default False + spacing : Union[int, float], optional + change the resolution by either taking only 1 every `spacing` value (int) or + multiplying the original spacing between point by `spacing` and interpolating + + Returns + ------- + tuple[np.ndarray, np.ndarray] + x axis and values + """ + if len(values.shape) != 1: + print(f"Shape was {values.shape}. Can only plot 1D arrays") + return + is_complex, x_axis, plt_range = prep_plot_axis(values, plt_range, params) + if is_complex: + values = abs2(values) + new_axis, ind, ext = units.sort_axis(x_axis, plt_range) + values = values[ind] + if plt_range.unit.type == "WL": + values = units.to_WL(values, new_axis) + + if isinstance(spacing, (float, np.floating)): + tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing)) + values = UnivariateSpline(new_axis, values, k=4, s=0)(tmp_axis) + new_axis = tmp_axis + elif isinstance(spacing, (int, np.integer)) and spacing > 1: + values = values[::spacing] + new_axis = new_axis[::spacing] + + if isinstance(log, str): + log = True + values = apply_log(values, log) + return new_axis, values def plot_spectrogram( @@ -351,357 +735,190 @@ def plot_spectrogram( params: BareParams, t_res: int = None, gate_width: float = None, - log: bool = True, + log: bool = "2D", vmin: float = None, vmax: float = None, cbar_label: str = "normalized intensity (dB)", - file_type: str = "png", - file_name: str = "plot", cmap: str = None, ax: plt.Axes = None, ): """Plots a spectrogram given a complex field in the time domain Parameters ---------- - values : 2D array - axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl - example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber - x_range, y_range : tupple (min, max, units) - one of them must be time, the other one must be wl/freq - min, max : int or float - minimum and maximum values given in the desired units - units : function to convert from the desired units to rad/s or to time. - common functions are already defined in scgenerator.physics.units - look there for more details - params : BareParams - parameters of the simulations - log : bool, optional - whether to compute the logarithm of the spectrogram - vmin : float, optional - min value of the colorbar - vmax : float, optional - max value of the colorbar - cbar_label : str or None - label of the colorbar. Will not draw colorbar if None - file_type : str, optional - usually pdf or png - plt_name : str, optional - special name to give to the plot. A name is automatically assigned anyway - cmap : str, optional - colormap to be used in matplotlib.pyplot.imshow - ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional - axis on which to draw the plot - if only one is given, a new one will be created to draw the colorbar + values : 2D array + axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl + example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber + x_range, y_range : tupple (min, max, units) + one of them must be time, the other one must be wl/freq + min, max : int or float + minimum and maximum values given in the desired units + units : function to convert from the desired units to rad/s or to time. + common functions are already defined in scgenerator.physics.units + look there for more details + params : BareParams + parameters of the simulations + log : bool, optional + whether to compute the logarithm of the spectrogram + vmin : float, optional + min value of the colorbar + vmax : float, optional + max value of the colorbar + cbar_label : str or None + label of the colorbar. Will not draw colorbar if None + file_type : str, optional + usually pdf or png + plt_name : str, optional + special name to give to the plot. A name is automatically assigned anyway + cmap : str, optional + colormap to be used in matplotlib.pyplot.imshow + ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional + axis on which to draw the plot + if only one is given, a new one will be created to draw the colorbar """ if values.ndim != 1: print("plot_spectrogram can only plot 1D arrays") return + x_range: units.PlotRange + y_range: units.PlotRange + _, x_axis, x_range = prep_plot_axis(values, x_range, params) + _, y_axis, y_range = prep_plot_axis(values, y_range, params) - if (x_range[2].type == "TIME") == (y_range[2].type == "TIME"): + if (x_range.unit.type == "TIME") == (y_range.unit.type == "TIME"): print("exactly one range must be a time range") return - log = "2D" if log in ["2D", True] else False - # 0 axis means x-axis -> determine final orientation of spectrogram - time_axis = 0 if x_range[2].type not in ["WL", "FREQ", "AFREQ"] else 1 + time_axis = 0 if x_range.unit.type not in ["WL", "FREQ", "AFREQ"] else 1 if time_axis == 0: - t_range, f_range = x_range, y_range + t_range = x_range else: - t_range, f_range = y_range, x_range + t_range = y_range # Actually compute the spectrogram - t_win = 2 * np.max(t_range[2](np.abs(t_range[:2]))) + t_win = 2 * np.max(t_range.unit(np.abs((t_range.left, t_range.right)))) spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False) spec, new_t = pulse.spectrogram( params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None} ) - - # Crop and reoder axis - new_t, ind_t, _ = units.sort_axis(new_t, t_range) - new_f, ind_f, _ = units.sort_axis(params.w, f_range) - values = spec[ind_t][:, ind_f] - if f_range[2].type == "WL": - values = np.apply_along_axis(units.to_WL, 1, values, units.m(f_range[2].inv(new_f))) - values = np.apply_along_axis(make_uniform_1D, 1, values, new_f) - if time_axis == 0: - x_axis, y_axis = new_t, new_f - values = values.T + x_axis = new_t else: - x_axis, y_axis = new_f, new_t + y_axis = new_t - return _finish_plot_2D( + x_axis, spec = uniform_axis(x_axis, spec, x_range) + y_axis, spec.T[:] = uniform_axis(y_axis, spec.T, y_range) + + values = apply_log(spec, log) + + return plot_2D( values, x_axis, - x_range[2].label, y_axis, - y_range[2].label, - log, + ax, + x_range.unit.label, + y_range.unit.label, vmin, vmax, False, cmap, cbar_label, - ax, - file_name, - file_type, ) -def plot_results_2D( - values: np.ndarray, - plt_range: Union[units.PlotRange, tuple], - params: BareParams, - log: Union[int, float, bool, str] = "1D", - skip: int = 1, - vmin: float = None, - vmax: float = None, - transpose: bool = False, - cbar_label: Optional[str] = "normalized intensity (dB)", - file_type: str = "png", - file_name: str = "plot", - cmap: str = None, - ax: plt.Axes = None, -): - """ - plots 2D arrays and automatically saves the plots, as well as returns it +def uniform_axis( + axis: np.ndarray, values: np.ndarray, new_axis_spec: Union[units.PlotRange, RangeType, str] +) -> tuple[np.ndarray, np.ndarray]: + """given some values(axis), creates a new uniformly spaced axis and interpolates + the values over it. Parameters ---------- - values : 2D array - axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl - example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber - plt_range : tupple (min, max, units) - min, max : int or float - minimum and maximum values given in the desired units - units : function to convert from the desired units to rad/s or to time. - common functions are already defined in scgenerator.physics.units - look there for more details - params : dict - parameters of the simulations - log : str {"1D", "2D", "smooth 1D"} or int, float or bool, optional - str : plot in dB - 1D : takes the log for every slice - 2D : takes the log for the whole 2D array - smooth 1D : figures out a smart reference value for the whole 2D array - int, float : plot in dB - reference value - bool : whether to use 1D variant or nothing - skip : int, optional - take 1 every skip values along the -1 axis - vmin : float, optional - min value of the colorbar - vmax : float, optional - max value of the colorbar - cbar_label : str or None - label of the colorbar. Will not draw colorbar if None - file_type : str, optional - usually pdf or png - plt_name : str, optional - special name to give to the plot. A name is automatically assigned anyway - cmap : str, optional - colormap to be used in matplotlib.pyplot.imshow - ax : matplotlib.axes._subplots.AxesSubplot object or tupple of 2 axis objects, optional - axis on which to draw the plot - if only one is given, a new one will be created to draw the colorbar - returns - fig, ax : matplotlib objects containing the plots - example: - if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such: - >>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017") + axis : np.ndarray, shape (n,) + grid points to which values correspond + values : np.ndarray, shape (n,) or (m, n) + values as function of axis + new_axis_spec : Union[PlotRange, RangeType, str] + specifications of the new axis. May be None, a unit as a str, + a tuple (min, max, unit) or a PlotRange obj + + Returns + ------- + tuple[np.ndarray, np.ndarray] + new axis and new values + + Raises + ------ + TypeError + invalid new_axis_spec """ - - if values.ndim != 2: - print(f"Shape was {values.shape}. plot_results_2D can only plot 2D arrays") - return - - is_spectrum, x_axis, plt_range = _prep_plot(values, plt_range, params) - - # crop and convert - x_axis, ind, ext = units.sort_axis(x_axis[::skip], plt_range) - values = values[:, ::skip][:, ind] - if is_spectrum: - values = abs2(values) - - # make uniform if converting to wavelength - if plt_range.unit.type == "WL": - if is_spectrum: - values = np.apply_along_axis(units.to_WL, 1, values, x_axis) - values = np.array( - [make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values] - ) - - lim_diff = 1e-5 * np.max(params.z_targets) - dz_s = np.diff(params.z_targets) - if not np.all(np.diff(dz_s) < lim_diff): - new_z = np.linspace( - *span(params.z_targets), - int( - np.floor( - (np.max(params.z_targets) - np.min(params.z_targets)) - / np.min(dz_s[dz_s > lim_diff]) - ) - ), - ) - values = np.array( - [make_uniform_1D(v, params.z_targets, n=len(new_z), method="linear") for v in values.T] - ).T - params.z_targets = new_z - return _finish_plot_2D( - values, - x_axis, - plt_range.unit.label, - params.z_targets, - "propagation distance (m)", - log, - vmin, - vmax, - transpose, - cmap, - cbar_label, - ax, - file_name, - file_type, - ) - - -def plot_results_1D( - values: np.ndarray, - plt_range: Union[units.PlotRange, tuple], - params: BareParams, - log: Union[str, int, float, bool] = False, - spacing: Union[int, float] = 1, - vmin: float = None, - vmax: float = None, - ylabel: str = None, - yscaling: float = 1, - file_type: str = "pdf", - file_name: str = None, - ax: plt.Axes = None, - transpose: bool = False, - **line_kwargs, -) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: - """ - - Parameters - ---------- - values : 1D array - if values are complex, the abs^2 is computed before plotting - plt_range : tupple (min, max, units) - min, max : int or float - minimum and maximum values given in the desired units - units : function to convert from the desired units to rad/s or to time. - common functions are already defined in scgenerator.physics.units - look there for more details - params : dict - parameters of the simulations - log : str {"1D"} or int, float or bool, optional - str : plot in dB - 1D : takes the log for every slice - int, float : plot in dB - reference value - bool : whether to use 1D variant or nothing - spacing : int, float, optional - tells the function to take one value every `spacing` one available. If a float is given, it will interpolate with a spline. - vmin : float, optional - min value of the colorbar - vmax : float, optional - max value of the colorbar - ylabel : str, optional - label of the y axis (x axis in transposed mode). Default is "normalized intensity (dB)" for log plots - yscaling : float, optional - scale the y values by this amount - file_type : str, optional - usually pdf or png - plt_name : str, optional - special name to give to the plot. A name is automatically assigned anyway - ax : matplotlib.axes._subplots.AxesSubplot object, optional - axis on which to draw the plot - transpose : bool, optional - transpose the plot - line_kwargs : to be passed to plt.plot - returns - fig, ax, line : matplotlib objects containing the plots - x_axis : np.ndarray - new x axis array - ind : np.ndarray - corresponding indices on the old axis - example: - if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such: - >>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017") - """ - - if len(values.shape) != 1: - print(f"Shape was {values.shape}. plot_results_1D can only plot 1D arrays") - return - - is_spectrum, x_axis, plt_range = _prep_plot(values, plt_range, params) - - # crop and convert - x_axis, ind, ext = units.sort_axis(x_axis, plt_range) - values = values[ind] - if is_spectrum: - values = abs2(values) - values *= yscaling - - # make uniform if converting to wavelength - if plt_range.unit.type == "WL": - if is_spectrum: - values = units.to_WL(values, units.m.inv(params.w[ind])) - - # change the resolution - if isinstance(spacing, float): - new_x_axis = np.linspace(*span(x_axis), int(len(x_axis) / spacing)) - values = UnivariateSpline(x_axis, values, k=4, s=0)(new_x_axis) - x_axis = new_x_axis - elif isinstance(spacing, int) and spacing > 1: - values = values[::spacing] - x_axis = x_axis[::spacing] - - # apply log transform if required - if log == False: - pass + if new_axis_spec is None: + new_axis_spec = "unity" + if isinstance(new_axis_spec, str) or callable(new_axis_spec): + unit = units.get_unit(new_axis_spec) + plt_range = units.PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec) + elif isinstance(new_axis_spec, tuple): + plt_range = units.PlotRange(*new_axis_spec) + elif isinstance(new_axis_spec, units.PlotRange): + plt_range = new_axis_spec else: - ylabel = "normalized intensity (dB)" if ylabel is None else ylabel - vmax = defaults["vmax_with_headroom"] if vmax is None else vmax - vmin = defaults["vmin"] if vmin is None else vmin - if isinstance(log, (float, int)) and log != True: + raise TypeError(f"Don't know how to interpret {new_axis_spec}") + tmp_axis, ind, ext = units.sort_axis(axis, plt_range) + if np.allclose((diff := np.diff(tmp_axis))[0], diff): + new_axis = tmp_axis + else: + values = np.atleast_2d(values) + if plt_range.unit.type == "WL": + values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) + new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) + values = np.array([interp1d(tmp_axis, v[ind])(new_axis) for v in values]) + return new_axis, values.squeeze() + + +def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarray: + """apply log transform + + Parameters + ---------- + values : np.ndarray + input array + log : Union[str, bool, float, int] + True -> "1D" + "1D" -> each row has its own reference value + "smooth 1D" -> attempted compromise between 2D and 1D. Will clip the highest values + "2D" -> same reference value for the whole 2D array + float, int -> take this value as the reference + False -> don't apply log + + Returns + ------- + np.ndarray + values with log applied + + Raises + ------ + ValueError + unrecognized log argument + """ + + if log is not False: + if isinstance(log, (float, int, np.floating, np.integer)) and log is not True: values = units.to_log(values, ref=log) + elif log == "2D": + values = units.to_log2D(values) + elif log == "1D" or log is True: + values = np.apply_along_axis(units.to_log, -1, values) + elif log == "smooth 1D": + ref = np.max(values, axis=1) + ind = np.argmax((ref[:-1] - ref[1:]) < 0) + values = units.to_log(values, ref=np.max(ref[ind:])) else: - values = units.to_log(values) - - is_new_plot = ax is None - - folder_name = "" - if file_name is None: - file_name = params.name + str(plt_range) - if is_new_plot: - out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) - else: - fig = ax.get_figure() - if transpose: - (line,) = ax.plot(values, x_axis, **line_kwargs) - ax.yaxis.tick_right() - ax.yaxis.set_label_position("right") - ax.set_xlim(vmax, vmin) - ax.set_xlabel(ylabel) - ax.set_ylabel(plt_range.unit.label) - else: - (line,) = ax.plot(x_axis, values, **line_kwargs) - ax.set_ylim(vmin, vmax) - ax.set_ylabel(ylabel) - ax.set_xlabel(plt_range.unit.label) - - if is_new_plot: - fig.savefig(out_path, bbox_inches="tight", dpi=200) - print(f"plot saved in {out_path}") - return fig, ax, line, x_axis, values + raise ValueError(f"Log argument {log} not recognized") + return values -def _prep_plot( - values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams +def prep_plot_axis( + values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: BareParams ) -> tuple[bool, np.ndarray, units.PlotRange]: is_spectrum = values.dtype == "complex" if not isinstance(plt_range, units.PlotRange): @@ -713,273 +930,6 @@ def _prep_plot( return is_spectrum, x_axis, plt_range -def plot_avg( - values: np.ndarray, - plt_range: Union[units.PlotRange, tuple], - params: BareParams, - log: Union[float, int, str, bool] = False, - spacing: Union[float, int] = 1, - vmin: float = None, - vmax: float = None, - ylabel: str = None, - yscaling: float = 1, - renormalize: bool = True, - add_coherence: bool = False, - file_type: str = "png", - file_name: str = None, - ax: plt.Axes = None, - line_labels: Tuple[str, str] = None, - legend: bool = True, - legend_kwargs: Dict[str, Any] = {}, - transpose: bool = False, -): - """ - plots 1D arrays and there mean and automatically saves the plots, as well as returns it - - Parameters - ---------- - values : 2D array - axis 0 defines the position in the fiber and axis 1 the position in time, frequency or wl - example : [[1, 2, 3], [0, 1, 0]] describes a quantity at 3 different freq/time and at two locations in the fiber - plt_range : tupple (min, max, units) - min, max : int or float - minimum and maximum values given in the desired units - units : function to convert from the desired units to rad/s or to time. - common functions are already defined in scgenerator.physics.units - look there for more details - params : dict - parameters of the simulations - log : str {"1D"} or int, float or bool, optional - str : plot in dB - 1D : takes the log for every slice - int, float : plot in dB - reference value - bool : whether to use 1D variant or nothing - spacing : int, float, optional - tells the function to take one value every `spacing` one available. If a float is given, it will interpolate with a spline. - vmin : float, optional - min value of the colorbar - vmax : float, optional - max value of the colorbar - ylabel : str, optional - label of the y axis (x axis in transposed mode). Default is 'normalized intensity (dB)' for log plots - yscaling : float, optional - scale the y values by this amount - renormalize : bool, optional - if converting to wl scale, renormalize with to_WL function to ensure energy is conserved - add_coherence : bool, optional - whether to add a subplot with coherence - file_type : str, optional - usually pdf or png - plt_name : str, optional - special name to give to the plot. A name is automatically assigned anyway - ax : matplotlib.axes._subplots.AxesSubplot object, optional - axis on which to draw the plot - line_labels : tupple(str), optional - label of the lines. line_labels[0] is the label of the mean and line_labels[1] is the label of the indiv. values - legend : bool, optional - whether to draw the legend - transpose : bool, optional - transpose the plot - returns - fig, ax : matplotlib objects containing the plots - example: - if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such: - >>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017") - """ - - if len(values.shape) != 2: - print(f"Shape was {values.shape}. plot_avg can only plot 2D arrays") - return - - is_spectrum, x_axis, plt_range = _prep_plot(values, plt_range, params) - - # crop and convert - x_axis, ind, ext = units.sort_axis(x_axis, plt_range) - if add_coherence: - coherence = pulse.g12(values) - coherence = coherence[ind] - else: - coherence = None - values = values[:, ind] - - is_new_plot = ax is None - folder_name = "" - original_lines = [] - - # compute the mean spectrum - if is_spectrum: - values = abs2(values) - values *= yscaling - mean_values = np.mean(values, axis=0) - if plt_range.unit.type == "WL" and renormalize: - values = np.apply_along_axis(units.to_WL, 1, values, x_axis) - mean_values = units.to_WL(mean_values, x_axis) - - # change the resolution - if isinstance(spacing, float): - new_x_axis = np.linspace(*span(x_axis), int(len(x_axis) / spacing)) - values = np.array( - [UnivariateSpline(x_axis, value, k=4, s=0)(new_x_axis) for value in values] - ) - if add_coherence: - coherence = UnivariateSpline(x_axis, coherence, k=4, s=0)(new_x_axis) - mean_values = np.mean(values, axis=0) - x_axis = new_x_axis - elif isinstance(spacing, int) and spacing > 1: - values = values[:, ::spacing] - mean_values = mean_values[::spacing] - x_axis = x_axis[::spacing] - if add_coherence: - coherence = coherence[::spacing] - - # apply log transform if required - if log != False: - ylabel = "normalized intensity (dB)" if ylabel is None else ylabel - vmax = defaults["vmax_with_headroom"] if vmax is None else vmax - vmin = defaults["vmin"] if vmin is None else vmin - if isinstance(log, (float, int)) and log != True: - ref = log - else: - ref = np.max(mean_values) - values = units.to_log(values, ref=ref) - mean_values = units.to_log(mean_values, ref=ref) - - if is_new_plot: - if add_coherence: - mode = "coherence_T" if transpose else "coherence" - out_path, fig, (top, bot) = plot_setup( - out_path=Path(folder_name) / file_name, file_type=file_type, mode=mode - ) - else: - out_path, fig, top = plot_setup( - out_path=Path(folder_name) / file_name, file_type=file_type - ) - bot = top - else: - if isinstance(ax, (tuple, list)): - top, bot = ax - if transpose: - bot.set_xlim(1.1, -0.1) - bot.set_xlabel(r"|$g_{12}$|") - else: - bot.set_ylim(-0.1, 1.1) - bot.set_ylabel(r"|$g_{12}$|") - else: - bot, top = ax, ax - - fig = top.get_figure() - original_lines = top.get_lines() - - # Actual Plotting - - gray_style = defaults["muted_style"] - highlighted_style = defaults["highlighted_style"] - - if transpose: - for value in values: - top.plot(value, x_axis, **gray_style) - top.plot(mean_values, x_axis, **highlighted_style) - if add_coherence: - bot.plot(coherence, x_axis, c=defaults["color_cycle"][0]) - - top.set_xlim(left=vmax, right=vmin) - top.yaxis.tick_right() - top.set_xlabel(ylabel) - top.set_ylim(*ext) - bot.yaxis.tick_right() - bot.yaxis.set_label_position("right") - bot.set_ylabel(plt_range.unit.label) - bot.set_ylim(*ext) - else: - for value in values: - top.plot(x_axis, value, **gray_style) - top.plot(x_axis, mean_values, **highlighted_style) - if add_coherence: - bot.plot(x_axis, coherence, c=defaults["color_cycle"][0]) - - top.set_ylim(bottom=vmin, top=vmax) - top.set_ylabel(ylabel) - top.set_xlim(*ext) - bot.set_xlabel(plt_range.unit.label) - bot.set_xlim(*ext) - - custom_lines = [ - plt.Line2D([0], [0], lw=2, c=gray_style["c"]), - plt.Line2D([0], [0], lw=2, c=highlighted_style["c"]), - ] - line_labels = defaults["avg_line_labels"] if line_labels is None else line_labels - line_labels = list(line_labels) - - if not is_new_plot: - custom_lines += original_lines - line_labels += [l.get_label() for l in original_lines] - - if legend: - top.legend(custom_lines, line_labels, **legend_kwargs) - - if is_new_plot: - fig.savefig(out_path, bbox_inches="tight", dpi=200) - print(f"plot saved in {out_path}") - - if top is bot: - return fig, top - else: - return fig, (top, bot) - - -def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6): - """prepares the values for plotting - Parameters - ---------- - values : array - the values to plot. - if complex, will take the abs^2 - if 2D, will consider it a as a list of values, each corresponding to the same x_axis - plt_range : tupple (float, float, fct) - fct as defined in scgenerator.physics.units - x_axis : 1D array - the corresponding x_axis - yscaling : float, optional - scale the y values by this amount - spacing : int, float, optional - tells the function to take one value every `spacing` one available. If a float is given, it will interpolate with a spline. - frep : float - used for conversion between frequency and wavelength if necessary - Returns - ---------- - new_x_axis : array - new_values : array - """ - is_spectrum = values.dtype == "complex" - - unique = len(values.shape) == 1 - values = np.atleast_2d(values) - - x_axis, ind, ext = units.sort_axis(x_axis, plt_range) - - if is_spectrum: - values = abs2(values) - values *= yscaling - - values = values[:, ind] - - if plt_range.unit.type == "WL": - values = np.apply_along_axis(units.to_WL, -1, values, x_axis) - - if isinstance(spacing, float): - new_x_axis = np.linspace(*span(x_axis), int(len(x_axis) / spacing)) - values = np.array( - [UnivariateSpline(x_axis, value, k=4, s=0)(new_x_axis) for value in values] - ) - x_axis = new_x_axis - elif isinstance(spacing, int) and spacing > 1: - values = values[:, ::spacing] - x_axis = x_axis[::spacing] - - return x_axis, np.squeeze(values) - - def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)): """returns a new colormap based on "name" but that has a solid bacground (default=white)""" top = plt.get_cmap(name, 1024) diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index dace2e8..66a5307 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -14,6 +14,7 @@ from ..initialize import ParamSequence from ..physics import units, fiber from ..spectra import Pulse from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop +from ..plotting import plot_setup from .. import env, math @@ -33,20 +34,22 @@ def plot_all(sim_dir: Path, limits: list[str], **opts): limits = [ tuple(func(el) for func, el in zip([float, float, str], lim.split(","))) for lim in limits ] - print(limits) with tqdm(total=len(dir_list) * len(limits)) as bar: for p in dir_list: pulse = Pulse(p) for left, right, unit in limits: + path, fig, ax = plot_setup( + pulse.path.parent / f"{pulse.path.name}_{left:.1f}_{right:.1f}_{unit}" + ) pulse.plot_2D( left, right, unit, - file_name=p.parent - / f"{pretty_format_from_file_name(p.name)} {left} {right} {unit}", + ax, **opts, ) bar.update() + fig.savefig(path, bbox_inches="tight") plt.close("all") diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 5ca8092..eee512f 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -127,22 +127,20 @@ def main(): ) if args.command == "merge": - final_config = load_config(Path(args.configs[0]) / "initial_config.toml") + final_name = load_config(Path(args.configs[0]) / "initial_config.toml").name sim_num = "many" args.nodes = 1 args.cpus_per_node = 1 else: config_paths = args.configs - final_config, sim_num = validate_config_sequence(*config_paths) + final_name, sim_num = validate_config_sequence(*config_paths) args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node) - submit_path = Path( - "submit " + final_config.name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" - ) + submit_path = Path("submit " + final_name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh") tmp_path = Path("submit tmp.sh") - job_name = f"supercontinuum {final_config.name}" + job_name = f"supercontinuum {final_name}" submit_sh = template.format( job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args) ) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 64a59a0..ed1897a 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,22 +1,17 @@ import os from collections.abc import Sequence from pathlib import Path -from re import UNICODE -from typing import Callable, Dict, Iterable, Optional, Union -from matplotlib.pyplot import subplot -from dataclasses import replace +from typing import Callable, Dict, Iterable, Union +import matplotlib.pyplot as plt import numpy as np -from numpy.lib import utils -from numpy.lib.arraysetops import isin -from tqdm.std import Bar from . import initialize, io, math -from .physics import units, pulse from .const import SPECN_FN from .logger import get_logger -from .plotting import plot_avg, plot_results_1D, plot_results_2D -from .utils.parameter import BareParams, validator_and +from .physics import pulse, units +from .plotting import mean_values_plot, propagation_plot, single_position_plot +from .utils.parameter import BareParams class Spectrum(np.ndarray): @@ -255,7 +250,7 @@ class Pulse(Sequence): def _to_time_amp(self, spectrum): return np.fft.ifft(spectrum) - def all_spectra(self, ind) -> Spectrum: + def all_spectra(self, ind=None) -> Spectrum: """ loads the data already simulated. defauft shape is (z_targets, n, nt) @@ -318,35 +313,38 @@ class Pulse(Sequence): left: float, right: float, unit: Union[Callable[[float], float], str], + ax: plt.Axes, z_pos: Union[int, Iterable[int]] = None, sim_ind: int = 0, **kwargs, ): plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind) - return plot_results_2D(vals, plt_range, self.params, **kwargs) + return propagation_plot(vals, plt_range, self.params, ax, **kwargs) def plot_1D( self, left: float, right: float, unit: Union[Callable[[float], float], str], + ax: plt.Axes, z_pos: int, sim_ind: int = 0, **kwargs, ): plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind) - return plot_results_1D(vals, plt_range, self.params, **kwargs) + return single_position_plot(vals, plt_range, self.params, ax, **kwargs) - def plot_avg( + def plot_mean( self, left: float, right: float, unit: Union[Callable[[float], float], str], + ax: plt.Axes, z_pos: int, **kwargs, ): plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None)) - return plot_avg(vals, plt_range, self.params, **kwargs) + return mean_values_plot(vals, plt_range, self.params, ax, **kwargs) def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind): plt_range = units.PlotRange(left, right, unit) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index bbc6c16..6c4e99b 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -182,13 +182,6 @@ def progress_worker( pbars[0].update() -def count_variations(config: BareConfig) -> int: - """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and - variable_params_num is the number of distinct parameters that will vary.""" - sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat - return int(sim_num) - - def format_variable_list(l: List[Tuple[str, Any]]): joints = 2 * PARAM_SEPARATOR str_list = [] @@ -229,7 +222,7 @@ def pretty_format_from_file_name(name: str) -> str: return PARAM_SEPARATOR.join(out) -def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: +def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], dict[str, Any]]]: """given a config with "variable" parameters, iterates through every possible combination, yielding a a list of (parameter_name, value) tuples and a full config dictionary. @@ -240,10 +233,10 @@ def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any] Yields ------- - Iterator[Tuple[List[Tuple[str, Any]], BareParams]] + Iterator[Tuple[List[Tuple[str, Any]], dict[str, Any]]] variable_list : a list of (name, value) tuple of parameter name and value that are variable. - params : a BareParams obj for one simulation + params : a dict[str, Any] to be fed to Params """ possible_keys = [] possible_ranges = [] @@ -264,10 +257,12 @@ def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any] param_dict = asdict(config) param_dict.pop("variable") param_dict.update(indiv_config) - yield variable_list, BareParams(**param_dict) + yield variable_list, param_dict -def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: +def required_simulations( + *configs: BareConfig, +) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: """takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different parameter set and iterates through every single necessary simulation @@ -281,22 +276,49 @@ def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, A dict : a config dictionary for one simulation """ i = 0 # unique sim id - for variable_only, bare_params in variable_iterator(config): - for j in range(config.repeat): + for data in itertools.product(*[variable_iterator(config) for config in configs]): + all_variable_only, all_params_dict = list(zip(*data)) + params_dict = all_params_dict[0] + for p in all_params_dict[1:]: + params_dict.update({k: v for k, v in p.items() if v is not None}) + variable_only = reduce_all_variable(all_variable_only) + for j in range(configs[0].repeat or 1): variable_ind = [("id", i)] + variable_only + [("num", j)] i += 1 - yield variable_ind, bare_params + yield variable_ind, BareParams(**params_dict) -def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig: +def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]: + out = [] + for n, variable_list in enumerate(all_variable): + out += [("fiber", "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)), *variable_list] + return out + + +def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig: """makes sure all the parameters set in new are there, leaves untouched parameters in old""" + new_dict = asdict(new) if old is None: - return BareConfig(**new) + return BareConfig(**new_dict) variable = deepcopy(old.variable) - variable.update(new.pop("variable", {})) # add new variable - for k in new: - variable.pop(k, None) # remove old ones - return replace(old, variable=variable, **{k: None for k in variable}, **new) + new_dict = {k: v for k, v in new_dict.items() if v is not None} + + for k, v in new_dict.pop("variable", {}).items(): + variable[k] = v + for k in variable: + new_dict[k] = None + return replace(old, variable=variable, **new_dict) + + +def final_config_from_sequence(*configs: BareConfig) -> BareConfig: + if len(configs) == 0: + raise ValueError("Must provide at least one config") + if len(configs) == 1: + return configs[0] + elif len(configs) == 2: + return override_config(*configs[::-1]) + else: + return override_config(configs[-1], final_config_from_sequence(*configs[:-1])) def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray: diff --git a/testing/test_new_iterator.py b/testing/test_new_iterator.py new file mode 100644 index 0000000..e58e1b8 --- /dev/null +++ b/testing/test_new_iterator.py @@ -0,0 +1,14 @@ +import scgenerator as sc +from pathlib import Path + +p = Path("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PPP") + +configs = [ + sc.io.load_config(p / c) + for c in ("PM1550.toml", "PMHNLF_appended.toml", "PM2000_appended.toml") +] + +for variable, params in sc.utils.required_simulations(*configs): + print(variable) + +# sc.initialize.ContinuationParamSequence(configs[-1])