diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 75e06de..666a3f8 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -74,6 +74,9 @@ def create_parser(): "One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps", ) plot_parser.add_argument("--options", "-o", default=None) + plot_parser.add_argument( + "--show", action="store_true", help="show the plots instead of saving them" + ) plot_parser.set_defaults(func=plot_all) dispersion_parser = subparsers.add_parser( @@ -186,7 +189,7 @@ def plot_all(args): if args.options is not None: opts |= dict([o.split("=")[:2] for o in re.split("[, ]", args.options)]) root = Path(args.sim_dir).resolve() - scripts.plot_all(root, args.spectrum_limits, **opts) + scripts.plot_all(root, args.spectrum_limits, show=args.show, **opts) def plot_init_field_spec(args): diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 699cde9..b10bfd4 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -13,6 +13,8 @@ from scipy.interpolate import UnivariateSpline from scipy.interpolate.interpolate import interp1d from tqdm import utils +from scgenerator.const import PARAM_SEPARATOR + from .logger import get_logger from . import io, math @@ -65,7 +67,7 @@ def plot_setup( # ensure no overwrite ind = 0 - while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists(): + while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists(): ind += 1 if mode == "default": @@ -864,10 +866,11 @@ def uniform_axis( else: raise TypeError(f"Don't know how to interpret {new_axis_spec}") tmp_axis, ind, ext = units.sort_axis(axis, plt_range) + values = np.atleast_2d(values) if np.allclose((diff := np.diff(tmp_axis))[0], diff): new_axis = tmp_axis + values = values[:, ind] else: - values = np.atleast_2d(values) if plt_range.unit.type == "WL": values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 66a5307..e9dce6b 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -9,11 +9,12 @@ import numpy as np from tqdm import tqdm from ..utils.parameter import BareParams +from ..const import PARAM_SEPARATOR from ..initialize import ParamSequence from ..physics import units, fiber from ..spectra import Pulse -from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop +from ..utils import pretty_format_value, pretty_format_from_sim_name, auto_crop from ..plotting import plot_setup from .. import env, math @@ -24,7 +25,7 @@ def fingerprint(params: BareParams): return h1, h2 -def plot_all(sim_dir: Path, limits: list[str], **opts): +def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): for k, v in opts.items(): if k in ["skip"]: opts[k] = int(v) @@ -39,7 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], **opts): pulse = Pulse(p) for left, right, unit in limits: path, fig, ax = plot_setup( - pulse.path.parent / f"{pulse.path.name}_{left:.1f}_{right:.1f}_{unit}" + pulse.path.parent + / ( + pretty_format_from_sim_name(pulse.path.name) + + PARAM_SEPARATOR + + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}" + ) ) pulse.plot_2D( left, @@ -49,8 +55,11 @@ def plot_all(sim_dir: Path, limits: list[str], **opts): **opts, ) bar.update() - fig.savefig(path, bbox_inches="tight") - plt.close("all") + if show: + plt.show() + else: + fig.savefig(path, bbox_inches="tight") + plt.close(fig) def plot_init_field_spec( diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 6c4e99b..d625991 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -208,15 +208,30 @@ def format_value(value) -> str: def pretty_format_value(name: str, value) -> str: - return getattr(BareParams, name).display(value) + try: + return getattr(BareParams, name).display(value) + except AttributeError: + return name + PARAM_SEPARATOR + str(value) -def pretty_format_from_file_name(name: str) -> str: +def pretty_format_from_sim_name(name: str) -> str: + """formats a pretty version of a simulation directory + + Parameters + ---------- + name : str + name of the simulation (directory name) + + Returns + ------- + str + prettier name + """ s = name.split(PARAM_SEPARATOR) out = [] for key, value in zip(s[::2], s[1::2]): try: - out.append(getattr(BareParams, key).display(float(value))) + out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))] except (AttributeError, ValueError): out.append(key + PARAM_SEPARATOR + value) return PARAM_SEPARATOR.join(out) @@ -307,6 +322,9 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig: variable[k] = v for k in variable: new_dict[k] = None + + new_dict["readjust_wavelength"] = False + return replace(old, variable=variable, **new_dict) diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index e0b8c62..a24bfa6 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -359,7 +359,7 @@ class BareParams: quantum_noise: bool = Parameter(boolean) shape: str = Parameter(literal("gaussian", "sech")) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) - intensity_noise: float = Parameter(in_range_incl(0, 1)) + intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%")) width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) diff --git a/testing/test_new_iterator.py b/testing/test_new_iterator.py index e58e1b8..e15fd90 100644 --- a/testing/test_new_iterator.py +++ b/testing/test_new_iterator.py @@ -1,14 +1,12 @@ import scgenerator as sc from pathlib import Path +import os -p = Path("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PPP") +os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/") -configs = [ - sc.io.load_config(p / c) - for c in ("PM1550.toml", "PMHNLF_appended.toml", "PM2000_appended.toml") -] +root = Path("PM1550+PMHNLF+PM1550+PM2000") -for variable, params in sc.utils.required_simulations(*configs): - print(variable) +confs = sc.io.load_config_sequence(root / "4_PM2000.toml") +final = sc.utils.final_config_from_sequence(*confs) -# sc.initialize.ContinuationParamSequence(configs[-1]) +print(final)