From 5bd5e3e92139167bdef10ca52aabc90226612e04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 21 Jun 2021 15:11:02 +0200 Subject: [PATCH] loads of stuff --- src/scgenerator/cli/cli.py | 138 +++++++++++++++--- src/scgenerator/const.py | 41 +----- src/scgenerator/env.py | 54 +++++-- src/scgenerator/initialize.py | 2 +- src/scgenerator/io.py | 16 +-- src/scgenerator/math.py | 7 + src/scgenerator/physics/fiber.py | 14 +- src/scgenerator/physics/materials.py | 6 +- src/scgenerator/physics/units.py | 2 + src/scgenerator/scripts/__init__.py | 208 +++++++++++++++++++++++++++ src/scgenerator/utils/__init__.py | 12 +- src/scgenerator/utils/parameter.py | 35 +++-- 12 files changed, 432 insertions(+), 103 deletions(-) diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 087eede..134138f 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -1,20 +1,14 @@ import argparse import os -import random -from pathlib import Path from collections import ChainMap +from pathlib import Path -from ray.worker import get +import numpy as np -from .. import io, env, const +from .. import env, io, scripts from ..logger import get_logger -from ..physics.simulate import ( - SequencialSimulations, - resume_simulations, - run_simulation_sequence, -) from ..physics.fiber import dispersion_coefficients -from pprint import pprint +from ..physics.simulate import SequencialSimulations, resume_simulations, run_simulation_sequence try: import ray @@ -24,8 +18,8 @@ except ImportError: def set_env_variables(cmd_line_args: dict[str, str]): cm = ChainMap(cmd_line_args, os.environ) - for env_key in const.global_config: - k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower() + for env_key in env.global_config: + k = env_key.replace(env.ENVIRON_KEY_BASE, "").lower() v = cm.get(k) if v is not None: os.environ[env_key] = str(v) @@ -35,8 +29,8 @@ def create_parser(): parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator") subparsers = parser.add_subparsers(help="sub-command help") - for key, args in const.global_config.items(): - names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()] + for key, args in env.global_config.items(): + names = ["--" + key.replace(env.ENVIRON_KEY_BASE, "").replace("_", "-").lower()] if "short_name" in args: names.append(args["short_name"]) parser.add_argument( @@ -66,6 +60,69 @@ def create_parser(): ) merge_parser.set_defaults(func=merge) + plot_parser = subparsers.add_parser("plot", help="generate basic plots of a simulation") + plot_parser.add_argument( + "sim_dir", + help="path to the root directory of the simulation (i.e. the " + "directory directly containing 'initial_config0.toml'", + ) + plot_parser.add_argument( + "spectrum_limits", + nargs=argparse.REMAINDER, + help="comma-separated list of left limit, right limit and unit. " + "One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps", + ) + plot_parser.set_defaults(func=plot_all) + + dispersion_parser = subparsers.add_parser( + "dispersion", help="show the dispersion of the given config" + ) + dispersion_parser.add_argument("config", help="path to the config file") + dispersion_parser.add_argument( + "--limits", "-l", default=None, type=float, nargs=2, help="left and right limits in nm" + ) + dispersion_parser.set_defaults(func=plot_dispersion) + + init_pulse_plot_parser = subparsers.add_parser( + "plot-spec-field", help="plot the initial field and spectrum" + ) + init_pulse_plot_parser.add_argument("config", help="path to the config file") + init_pulse_plot_parser.add_argument( + "--wavelength-limits", + "-l", + default=None, + type=float, + nargs=2, + help="left and right limits in nm", + ) + init_pulse_plot_parser.add_argument( + "--time-limit", "-t", default=None, type=float, help="time axis limit in fs" + ) + init_pulse_plot_parser.set_defaults(func=plot_init_field_spec) + + init_plot_parser = subparsers.add_parser("plot-init", help="plot initial values") + init_plot_parser.add_argument("config", help="path to the config file") + init_plot_parser.add_argument( + "--dispersion-limits", + "-s", + default=None, + type=float, + nargs=2, + help="left and right limits for dispersion plots in nm", + ) + init_plot_parser.add_argument( + "--time-limit", "-t", default=None, type=float, help="time axis limit in fs" + ) + init_plot_parser.add_argument( + "--wavelength-limits", + "-l", + default=None, + nargs=2, + type=float, + help="wavelength axis limit in nm", + ) + init_plot_parser.set_defaults(func=plot_init) + return parser @@ -98,9 +155,9 @@ def merge(args): def prep_ray(): logger = get_logger(__name__) if ray: - if env.get(const.START_RAY): + if env.get(env.START_RAY): init_str = ray.init() - elif not env.get(const.NO_RAY): + elif not env.get(env.NO_RAY): try: init_str = ray.init( address="auto", @@ -108,8 +165,8 @@ def prep_ray(): ) logger.info(init_str) except ConnectionError as e: - logger.error(e) - return SequencialSimulations if env.get(const.NO_RAY) else None + logger.warning(e) + return SequencialSimulations if env.get(env.NO_RAY) else None def resume_sim(args): @@ -120,5 +177,50 @@ def resume_sim(args): run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir) +def plot_all(args): + root = Path(args.sim_dir).resolve() + scripts.plot_all(root, args.spectrum_limits) + + +def plot_init_field_spec(args): + if args.wavelength_limits is None: + l = None + else: + l = list(args.wavelength_limits) + + if args.time_limit is None: + t = None + else: + t = [-args.time_limit, args.time_limit] + + scripts.plot_init_field_spec(args.config, t, l) + + +def plot_init(args): + if args.wavelength_limits is None: + l = None + else: + l = list(args.wavelength_limits) + if args.dispersion_limits is None: + d = None + else: + d = list(args.dispersion_limits) + + if args.time_limit is None: + t = None + else: + t = [-args.time_limit, args.time_limit] + + scripts.plot_init(args.config, t, l, d) + + +def plot_dispersion(args): + if args.limits is None: + lims = [None, None] + else: + lims = 1e-9 * np.array(args.limits, dtype=float) + scripts.plot_dispersion(args.config, lims) + + if __name__ == "__main__": main() diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index ead019e..84738fa 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -19,47 +19,8 @@ def pbar_format(worker_id: int): ) -ENVIRON_KEY_BASE = "SCGENERATOR_" -TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_" -PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_" -PARAM_SEPARATOR = " " - -PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY" -LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL" -LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL" -START_RAY = ENVIRON_KEY_BASE + "START_RAY" -NO_RAY = ENVIRON_KEY_BASE + "NO_RAY" -OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH" - - -global_config: dict[str, dict[str, Any]] = { - LOG_FILE_LEVEL: dict( - help="minimum lvl of message to be saved in the log file", - choices=["critical", "error", "warning", "info", "debug"], - default=None, - type=str, - ), - LOG_PRINT_LEVEL: dict( - help="minimum lvl of message to be printed to the standard output", - choices=["critical", "error", "warning", "info", "debug"], - default="error", - type=str, - ), - PBAR_POLICY: dict( - help="what to do with progress pars (print them, make them a txt file or nothing), default is print", - choices=["print", "file", "both", "none"], - default=None, - type=str, - ), - START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool), - NO_RAY: dict(action="store_true", help="force not to use ray", type=bool), - OUTPUT_PATH: dict( - short_name="-o", help="path to the final output folder", default=None, type=str - ), -} - - SPEC1_FN = "spectrum_{}.npy" SPECN_FN = "spectra_{}.npy" Z_FN = "z.npy" PARAM_FN = "params.toml" +PARAM_SEPARATOR = " " diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 6f44b71..46c0d6e 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -1,14 +1,45 @@ import os +from pathlib import Path from typing import Any, Dict, Literal, Optional, Set -from .const import ( - ENVIRON_KEY_BASE, - LOG_FILE_LEVEL, - LOG_PRINT_LEVEL, - PBAR_POLICY, - TMP_FOLDER_KEY_BASE, - global_config, -) + +ENVIRON_KEY_BASE = "SCGENERATOR_" +TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_" +PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_" + +PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY" +LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL" +LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL" +START_RAY = ENVIRON_KEY_BASE + "START_RAY" +NO_RAY = ENVIRON_KEY_BASE + "NO_RAY" +OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH" + + +global_config: dict[str, dict[str, Any]] = { + LOG_FILE_LEVEL: dict( + help="minimum lvl of message to be saved in the log file", + choices=["critical", "error", "warning", "info", "debug"], + default=None, + type=str, + ), + LOG_PRINT_LEVEL: dict( + help="minimum lvl of message to be printed to the standard output", + choices=["critical", "error", "warning", "info", "debug"], + default="error", + type=str, + ), + PBAR_POLICY: dict( + help="what to do with progress pars (print them, make them a txt file or nothing), default is print", + choices=["print", "file", "both", "none"], + default=None, + type=str, + ), + START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool), + NO_RAY: dict(action="store_true", help="force not to use ray", type=bool), + OUTPUT_PATH: dict( + short_name="-o", help="path to the final output folder", default=None, type=str + ), +} def data_folder(task_id: int) -> Optional[str]: @@ -36,6 +67,13 @@ def all_environ() -> Dict[str, str]: return d +def output_path() -> Path: + p = get(OUTPUT_PATH) + if p is not None: + return Path(p).resolve() + return None + + def pbar_policy() -> Set[Literal["print", "file"]]: policy = get(PBAR_POLICY) if policy == "print" or policy is None: diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 006e8dc..2ee47b9 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -97,7 +97,7 @@ class Params(BareParams): self.dynamic_dispersion = False else: self.dynamic_dispersion = fiber.is_dynamic_dispersion(self.pressure) - self.beta, temp_gamma = fiber.compute_dispersion(self) + self.beta, temp_gamma, self.interp_range = fiber.compute_dispersion(self) if self.dynamic_dispersion: self.gamma_func = temp_gamma self.beta_func = self.beta diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 88b01a3..1c13fa6 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -1,8 +1,6 @@ -from dataclasses import asdict import itertools import os import shutil -from datetime import datetime from pathlib import Path from typing import Any, Dict, Generator, List, Sequence, Tuple @@ -11,15 +9,9 @@ import pkg_resources as pkg import toml from . import env, utils -from .const import ( - PARAM_FN, - PARAM_SEPARATOR, - SPEC1_FN, - SPECN_FN, - TMP_FOLDER_KEY_BASE, - Z_FN, - __version__, -) +from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ +from .env import TMP_FOLDER_KEY_BASE + from .errors import IncompleteDataFolderError from .logger import get_logger from .utils.parameter import BareConfig, BareParams @@ -93,6 +85,8 @@ def load_toml(path: os.PathLike): section = dico.pop(key, {}) dico["variable"].update(section.pop("variable", {})) dico.update(section) + if len(dico["variable"]) == 0: + dico.pop("variable") return dico diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 2581966..208f386 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -211,3 +211,10 @@ def make_uniform_1D(values, x_axis, n=1024, method="linear"): """ xx = np.linspace(*span(x_axis), len(x_axis)) return interp1d(x_axis, values, kind=method)(xx) + + +def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """find all the x values such that y(x)=0 with linear interpolation""" + pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0] + m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1]) + return -y[pos] / m + x[pos] diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 08fd85b..be4144f 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Literal, Tuple import numpy as np import toml @@ -276,10 +276,10 @@ def A_eff_hasan(core_radius, capillary_num, capillary_spacing): def HCPCF_find_with_given_ZDW( - variable, - target, - search_range, - material_dico, + variable: Literal["pressure", "temperature"], + target: float, + search_range: tuple[float, float], + material_dico: dict[str, Any], model="marcatili", model_params={}, pressure=None, @@ -673,10 +673,12 @@ def compute_dispersion(params: BareParams): if params.dispersion_file is not None: disp_file = np.load(params.dispersion_file) lambda_ = disp_file["wavelength"] + interp_range = (np.min(lambda_), np.max(lambda_)) D = disp_file["dispersion"] beta2 = D_to_beta2(D, lambda_) gamma = None else: + interp_range = params.interp_range lambda_ = lambda_for_dispersion() beta2 = np.zeros_like(lambda_) @@ -744,7 +746,7 @@ def compute_dispersion(params: BareParams): else: gamma = 0 - return beta2_coef, gamma + return beta2_coef, gamma, interp_range @np_cache diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index d7832a2..f7bd06c 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -2,7 +2,7 @@ import numpy as np from ..logger import get_logger from . import units -from .units import NA, c, kB +from .units import NA, c, kB, me, e def pressure_from_gradient(ratio, p0, p1): @@ -170,3 +170,7 @@ def non_linear_refractive_index(material_dico, pressure=None, temperature=None): ratio = 1 return ratio * n2_ref + + +def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray: + return w * np.sqrt(2 * me * I) / (e * np.abs(field)) diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index f832c15..b00a201 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -14,6 +14,8 @@ NA = 6.02214076e23 R = 8.31446261815324 kB = 1.380649e-23 epsilon0 = 8.85418781e-12 +me = 9.1093837015e-31 +e = -1.602176634e-19 prefix = dict(P=1e12, G=1e9, M=1e6, k=1e3, d=1e-1, c=1e-2, m=1e-3, u=1e-6, n=1e-9, p=1e-12, f=1e-15) diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index e69de29..34d3c9a 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -0,0 +1,208 @@ +from itertools import cycle +import itertools +from pathlib import Path +from typing import Iterable +from cycler import cycler + +import matplotlib.pyplot as plt +import numpy as np + +from ..utils.parameter import BareParams + +from ..initialize import ParamSequence +from ..physics import units, fiber +from ..spectra import Pulse +from ..utils import pretty_format_value +from .. import env, math + + +def plot_all(sim_dir: Path, limits: list[str]): + for p in sim_dir.glob("*"): + if not p.is_dir(): + continue + + pulse = Pulse(p) + for lim in limits: + left, right, unit = lim.split(",") + left = float(left) + right = float(right) + pulse.plot_2D(left, right, unit, file_name=p.parent / f"{p.name}_{left}_{right}_{unit}") + + +def plot_init_field_spec( + config_path: Path, + lim_t: tuple[float, float] = None, + lim_l: tuple[float, float] = None, +): + fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7)) + all_labels = [] + already_plotted = set() + for style, lbl, params in plot_helper(config_path): + if (bbb := hash(params.field_0.tobytes())) not in already_plotted: + already_plotted.add(bbb) + else: + continue + + plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params) + all_labels.append(lbl) + finish_plot(fig, left, right, all_labels, params) + + +def plot_dispersion(config_path: Path, lim: tuple[float, float] = None): + fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7), sharex=True) + left.grid() + right.grid() + all_labels = [] + already_plotted = set() + for style, lbl, params in plot_helper(config_path): + if (bbb := tuple(params.beta)) not in already_plotted: + already_plotted.add(bbb) + else: + continue + + plot_1_dispersion(lim, left, right, style, lbl, params) + all_labels.append(lbl) + finish_plot(fig, left, right, all_labels, params) + + +def plot_init( + config_path: Path, + lim_field: tuple[float, float] = None, + lim_spec: tuple[float, float] = None, + lim_disp: tuple[float, float] = None, +): + fig, ((tl, tr), (bl, br)) = plt.subplots(2, 2, figsize=(14, 10)) + tl.grid() + tr.grid() + all_labels = [] + already_plotted = set() + for style, lbl, params in plot_helper(config_path): + if (bbb := hash(params.field_0.tobytes())) not in already_plotted: + already_plotted.add(bbb) + else: + continue + lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params) + lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params) + all_labels.append(lbl) + finish_plot(fig, tr, all_labels, params) + + +def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params): + field = math.abs2(params.field_0) + spec = math.abs2(params.spec_0) + t = units.fs.inv(params.t) + wl = units.nm.inv(params.w) + + lbl.append(f"max at {wl[spec.argmax()]:.1f} nm") + + mt = np.ones_like(t, dtype=bool) + if lim_t is not None: + mt &= t >= lim_t[0] + mt &= t <= lim_t[1] + else: + mt = find_lim(t, field) + ml = np.ones_like(wl, dtype=bool) + if lim_l is not None: + ml &= t >= lim_l[0] + ml &= t <= lim_l[1] + else: + ml = find_lim(wl, spec) + + left.plot(t[mt], field[mt]) + right.plot(wl[ml], spec[ml], label=" ", **style) + return lbl + + +def plot_1_dispersion(lim, left, right, style, lbl, params): + coef = params.beta / np.cumprod([1] + list(range(1, len(params.beta)))) + w_c = params.w_c + + beta_arr = np.zeros_like(w_c) + for k, beta in reversed(list(enumerate(coef))): + beta_arr = beta_arr + beta * w_c ** k + wl = units.m.inv(params.w) + + zdw = math.all_zeros(wl, beta_arr) + if len(zdw) > 0: + zdw = zdw[np.argmin(abs(zdw - params.wavelength))] + lbl.append(f"ZDW at {zdw*1e9:.1f}nm") + else: + lbl.append("") + + m = np.ones_like(wl, dtype=bool) + if lim is None: + lim = params.interp_range + m &= wl >= lim[0] + m &= wl <= lim[1] + + m = np.argwhere(m)[:, 0] + m = np.array(sorted(m, key=lambda el: wl[el])) + + # plot D + D = fiber.beta2_to_D(beta_arr, wl) * 1e6 + right.plot(1e9 * wl[m], D[m], label=" ", **style) + right.set_ylabel(units.D_ps_nm_km.label) + + # plot beta + left.plot(1e9 * wl[m], units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style) + left.set_ylabel(units.beta2_fs_cm.label) + + left.set_xlabel("wavelength (nm)") + right.set_xlabel("wavelength (nm)") + return lbl + + +def finish_plot(fig, legend_axes, all_labels, params): + fig.suptitle(params.name) + plt.tight_layout() + + handles, _ = legend_axes.get_legend_handles_labels() + lbl_lengths = [[len(l) for l in lbl] for lbl in all_labels] + lengths = np.max(lbl_lengths, axis=0) + labels = [ + " ".join(format(l, f">{int(s)}s") for s, l in zip(lengths, lab)) for lab in all_labels + ] + + legend = legend_axes.legend(handles, labels, prop=dict(size=8, family="monospace")) + + out_path = env.output_path() + + show = out_path is None + if not show: + file_name = out_path.stem + ".pdf" + out_path = out_path.parent / file_name + if ( + out_path.exists() + and input(f"{out_path.name} already exsits, overwrite ? (y/[n])\n > ") != "y" + ): + show = True + else: + fig.savefig(out_path, bbox_inches="tight") + if show: + plt.show() + + +def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams]]: + cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) + pseq = ParamSequence(config_path) + for style, (variables, params) in zip(cc, pseq): + lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] + yield style, lbl, params + + +def find_lim(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> int: + threshold = y.min() + rel_thr * (y.max() - y.min()) + above_threshold = y > threshold + ind = np.argsort(x) + valid_ind = [ + np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k + ] + ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0] + width = len(ind_above) + return np.concatenate( + ( + np.arange(max(ind_above[0] - width, 0), ind_above[0]), + ind_above, + np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)), + ) + ) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 35fa78e..5ed4857 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -4,30 +4,24 @@ scgenerator module but some function may be used in any python program """ -import functools import itertools import multiprocessing import threading -from functools import update_wrapper from collections import abc from copy import deepcopy from dataclasses import asdict, replace from io import StringIO from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union -from copy import copy import numpy as np -from numpy.lib.index_tricks import nd_grid from tqdm import tqdm -from ..logger import get_logger from .. import env from ..const import PARAM_SEPARATOR from ..math import * from .parameter import BareConfig, BareParams -from scgenerator import logger T_ = TypeVar("T_") @@ -204,7 +198,7 @@ def branch_id(branch: Tuple[Path, ...]) -> str: return "".join("".join(b.name.split()[2:-2]) for b in branch) -def format_value(value): +def format_value(value) -> str: if type(value) == type(False): return str(value) elif isinstance(value, (float, int)): @@ -215,6 +209,10 @@ def format_value(value): return str(value) +def pretty_format_value(name: str, value) -> str: + return getattr(BareParams, name).display(value) + + def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: """given a config with "variable" parameters, iterates through every possible combination, yielding a a list of (parameter_name, value) tuples and a full config dictionary. diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index c842639..206110b 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -5,6 +5,7 @@ from functools import lru_cache from typing import Any, Callable, Dict, Iterable, List, Tuple, Union import numpy as np +from numpy.lib.function_base import disp from ..const import __version__ @@ -158,7 +159,7 @@ def func_validator(name, n): class Parameter: - def __init__(self, validator, converter=None, default=None): + def __init__(self, validator, converter=None, default=None, display_info=None): """Single parameter Parameters @@ -178,6 +179,7 @@ class Parameter: self.validator = validator self.converter = converter self.default = default + self.display_info = display_info def __set_name__(self, owner, name): self.name = name @@ -201,6 +203,16 @@ class Parameter: value = self.converter(value) instance.__dict__[self.name] = value + def display(self, num: float): + if self.display_info is None: + return str(num) + else: + fac, unit = self.display_info + num_str = format(num * fac, ".2f") + if num_str.endswith(".00"): + num_str = num_str[:-3] + return f"{num_str} {unit}" + class VariableParameter: def __init__(self, parameterBase): @@ -243,6 +255,7 @@ valid_variable = { "gamma", "pitch", "pitch_ratio", + "effective_mode_diameter", "core_radius", "capillary_num", "capillary_outer_d", @@ -326,26 +339,26 @@ class BareParams: capillary_nested: int = Parameter(non_negative(int)) # gas - gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower) + gas_name: str = Parameter(string, converter=str.lower) pressure: Union[float, Iterable[float]] = Parameter( - validator_or(non_negative(float, int), num_list) + validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar") ) - temperature: float = Parameter(positive(float, int)) + temperature: float = Parameter(positive(float, int), display_info=(1, "K")) plasma_density: float = Parameter(non_negative(float, int)) # pulse field_file: str = Parameter(string) - repetition_rate: float = Parameter(non_negative(float, int)) - peak_power: float = Parameter(positive(float, int)) - mean_power: float = Parameter(positive(float, int)) - energy: float = Parameter(positive(float, int)) + repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz")) + peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW")) + mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) + energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ")) soliton_num: float = Parameter(non_negative(float, int)) quantum_noise: bool = Parameter(boolean) shape: str = Parameter(literal("gaussian", "sech")) - wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9)) + wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) intensity_noise: float = Parameter(in_range_incl(0, 1)) - width: float = Parameter(in_range_excl(0, 1e-9)) - t0: float = Parameter(in_range_excl(0, 1e-9)) + width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) + t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) # simulation behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))