loads of stuff

This commit is contained in:
Benoît Sierro
2021-06-21 15:11:02 +02:00
parent b63a77cdd6
commit 5bd5e3e921
12 changed files with 432 additions and 103 deletions

View File

@@ -1,20 +1,14 @@
import argparse import argparse
import os import os
import random
from pathlib import Path
from collections import ChainMap from collections import ChainMap
from pathlib import Path
from ray.worker import get import numpy as np
from .. import io, env, const from .. import env, io, scripts
from ..logger import get_logger from ..logger import get_logger
from ..physics.simulate import (
SequencialSimulations,
resume_simulations,
run_simulation_sequence,
)
from ..physics.fiber import dispersion_coefficients from ..physics.fiber import dispersion_coefficients
from pprint import pprint from ..physics.simulate import SequencialSimulations, resume_simulations, run_simulation_sequence
try: try:
import ray import ray
@@ -24,8 +18,8 @@ except ImportError:
def set_env_variables(cmd_line_args: dict[str, str]): def set_env_variables(cmd_line_args: dict[str, str]):
cm = ChainMap(cmd_line_args, os.environ) cm = ChainMap(cmd_line_args, os.environ)
for env_key in const.global_config: for env_key in env.global_config:
k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower() k = env_key.replace(env.ENVIRON_KEY_BASE, "").lower()
v = cm.get(k) v = cm.get(k)
if v is not None: if v is not None:
os.environ[env_key] = str(v) os.environ[env_key] = str(v)
@@ -35,8 +29,8 @@ def create_parser():
parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator") parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator")
subparsers = parser.add_subparsers(help="sub-command help") subparsers = parser.add_subparsers(help="sub-command help")
for key, args in const.global_config.items(): for key, args in env.global_config.items():
names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()] names = ["--" + key.replace(env.ENVIRON_KEY_BASE, "").replace("_", "-").lower()]
if "short_name" in args: if "short_name" in args:
names.append(args["short_name"]) names.append(args["short_name"])
parser.add_argument( parser.add_argument(
@@ -66,6 +60,69 @@ def create_parser():
) )
merge_parser.set_defaults(func=merge) merge_parser.set_defaults(func=merge)
plot_parser = subparsers.add_parser("plot", help="generate basic plots of a simulation")
plot_parser.add_argument(
"sim_dir",
help="path to the root directory of the simulation (i.e. the "
"directory directly containing 'initial_config0.toml'",
)
plot_parser.add_argument(
"spectrum_limits",
nargs=argparse.REMAINDER,
help="comma-separated list of left limit, right limit and unit. "
"One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps",
)
plot_parser.set_defaults(func=plot_all)
dispersion_parser = subparsers.add_parser(
"dispersion", help="show the dispersion of the given config"
)
dispersion_parser.add_argument("config", help="path to the config file")
dispersion_parser.add_argument(
"--limits", "-l", default=None, type=float, nargs=2, help="left and right limits in nm"
)
dispersion_parser.set_defaults(func=plot_dispersion)
init_pulse_plot_parser = subparsers.add_parser(
"plot-spec-field", help="plot the initial field and spectrum"
)
init_pulse_plot_parser.add_argument("config", help="path to the config file")
init_pulse_plot_parser.add_argument(
"--wavelength-limits",
"-l",
default=None,
type=float,
nargs=2,
help="left and right limits in nm",
)
init_pulse_plot_parser.add_argument(
"--time-limit", "-t", default=None, type=float, help="time axis limit in fs"
)
init_pulse_plot_parser.set_defaults(func=plot_init_field_spec)
init_plot_parser = subparsers.add_parser("plot-init", help="plot initial values")
init_plot_parser.add_argument("config", help="path to the config file")
init_plot_parser.add_argument(
"--dispersion-limits",
"-s",
default=None,
type=float,
nargs=2,
help="left and right limits for dispersion plots in nm",
)
init_plot_parser.add_argument(
"--time-limit", "-t", default=None, type=float, help="time axis limit in fs"
)
init_plot_parser.add_argument(
"--wavelength-limits",
"-l",
default=None,
nargs=2,
type=float,
help="wavelength axis limit in nm",
)
init_plot_parser.set_defaults(func=plot_init)
return parser return parser
@@ -98,9 +155,9 @@ def merge(args):
def prep_ray(): def prep_ray():
logger = get_logger(__name__) logger = get_logger(__name__)
if ray: if ray:
if env.get(const.START_RAY): if env.get(env.START_RAY):
init_str = ray.init() init_str = ray.init()
elif not env.get(const.NO_RAY): elif not env.get(env.NO_RAY):
try: try:
init_str = ray.init( init_str = ray.init(
address="auto", address="auto",
@@ -108,8 +165,8 @@ def prep_ray():
) )
logger.info(init_str) logger.info(init_str)
except ConnectionError as e: except ConnectionError as e:
logger.error(e) logger.warning(e)
return SequencialSimulations if env.get(const.NO_RAY) else None return SequencialSimulations if env.get(env.NO_RAY) else None
def resume_sim(args): def resume_sim(args):
@@ -120,5 +177,50 @@ def resume_sim(args):
run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir) run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir)
def plot_all(args):
root = Path(args.sim_dir).resolve()
scripts.plot_all(root, args.spectrum_limits)
def plot_init_field_spec(args):
if args.wavelength_limits is None:
l = None
else:
l = list(args.wavelength_limits)
if args.time_limit is None:
t = None
else:
t = [-args.time_limit, args.time_limit]
scripts.plot_init_field_spec(args.config, t, l)
def plot_init(args):
if args.wavelength_limits is None:
l = None
else:
l = list(args.wavelength_limits)
if args.dispersion_limits is None:
d = None
else:
d = list(args.dispersion_limits)
if args.time_limit is None:
t = None
else:
t = [-args.time_limit, args.time_limit]
scripts.plot_init(args.config, t, l, d)
def plot_dispersion(args):
if args.limits is None:
lims = [None, None]
else:
lims = 1e-9 * np.array(args.limits, dtype=float)
scripts.plot_dispersion(args.config, lims)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -19,47 +19,8 @@ def pbar_format(worker_id: int):
) )
ENVIRON_KEY_BASE = "SCGENERATOR_"
TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
PARAM_SEPARATOR = " "
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL"
LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL"
START_RAY = ENVIRON_KEY_BASE + "START_RAY"
NO_RAY = ENVIRON_KEY_BASE + "NO_RAY"
OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH"
global_config: dict[str, dict[str, Any]] = {
LOG_FILE_LEVEL: dict(
help="minimum lvl of message to be saved in the log file",
choices=["critical", "error", "warning", "info", "debug"],
default=None,
type=str,
),
LOG_PRINT_LEVEL: dict(
help="minimum lvl of message to be printed to the standard output",
choices=["critical", "error", "warning", "info", "debug"],
default="error",
type=str,
),
PBAR_POLICY: dict(
help="what to do with progress pars (print them, make them a txt file or nothing), default is print",
choices=["print", "file", "both", "none"],
default=None,
type=str,
),
START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool),
NO_RAY: dict(action="store_true", help="force not to use ray", type=bool),
OUTPUT_PATH: dict(
short_name="-o", help="path to the final output folder", default=None, type=str
),
}
SPEC1_FN = "spectrum_{}.npy" SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy" SPECN_FN = "spectra_{}.npy"
Z_FN = "z.npy" Z_FN = "z.npy"
PARAM_FN = "params.toml" PARAM_FN = "params.toml"
PARAM_SEPARATOR = " "

View File

@@ -1,14 +1,45 @@
import os import os
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set from typing import Any, Dict, Literal, Optional, Set
from .const import (
ENVIRON_KEY_BASE, ENVIRON_KEY_BASE = "SCGENERATOR_"
LOG_FILE_LEVEL, TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
LOG_PRINT_LEVEL, PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
PBAR_POLICY,
TMP_FOLDER_KEY_BASE, PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
global_config, LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL"
) LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL"
START_RAY = ENVIRON_KEY_BASE + "START_RAY"
NO_RAY = ENVIRON_KEY_BASE + "NO_RAY"
OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH"
global_config: dict[str, dict[str, Any]] = {
LOG_FILE_LEVEL: dict(
help="minimum lvl of message to be saved in the log file",
choices=["critical", "error", "warning", "info", "debug"],
default=None,
type=str,
),
LOG_PRINT_LEVEL: dict(
help="minimum lvl of message to be printed to the standard output",
choices=["critical", "error", "warning", "info", "debug"],
default="error",
type=str,
),
PBAR_POLICY: dict(
help="what to do with progress pars (print them, make them a txt file or nothing), default is print",
choices=["print", "file", "both", "none"],
default=None,
type=str,
),
START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool),
NO_RAY: dict(action="store_true", help="force not to use ray", type=bool),
OUTPUT_PATH: dict(
short_name="-o", help="path to the final output folder", default=None, type=str
),
}
def data_folder(task_id: int) -> Optional[str]: def data_folder(task_id: int) -> Optional[str]:
@@ -36,6 +67,13 @@ def all_environ() -> Dict[str, str]:
return d return d
def output_path() -> Path:
p = get(OUTPUT_PATH)
if p is not None:
return Path(p).resolve()
return None
def pbar_policy() -> Set[Literal["print", "file"]]: def pbar_policy() -> Set[Literal["print", "file"]]:
policy = get(PBAR_POLICY) policy = get(PBAR_POLICY)
if policy == "print" or policy is None: if policy == "print" or policy is None:

View File

@@ -97,7 +97,7 @@ class Params(BareParams):
self.dynamic_dispersion = False self.dynamic_dispersion = False
else: else:
self.dynamic_dispersion = fiber.is_dynamic_dispersion(self.pressure) self.dynamic_dispersion = fiber.is_dynamic_dispersion(self.pressure)
self.beta, temp_gamma = fiber.compute_dispersion(self) self.beta, temp_gamma, self.interp_range = fiber.compute_dispersion(self)
if self.dynamic_dispersion: if self.dynamic_dispersion:
self.gamma_func = temp_gamma self.gamma_func = temp_gamma
self.beta_func = self.beta self.beta_func = self.beta

View File

@@ -1,8 +1,6 @@
from dataclasses import asdict
import itertools import itertools
import os import os
import shutil import shutil
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Generator, List, Sequence, Tuple from typing import Any, Dict, Generator, List, Sequence, Tuple
@@ -11,15 +9,9 @@ import pkg_resources as pkg
import toml import toml
from . import env, utils from . import env, utils
from .const import ( from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
PARAM_FN, from .env import TMP_FOLDER_KEY_BASE
PARAM_SEPARATOR,
SPEC1_FN,
SPECN_FN,
TMP_FOLDER_KEY_BASE,
Z_FN,
__version__,
)
from .errors import IncompleteDataFolderError from .errors import IncompleteDataFolderError
from .logger import get_logger from .logger import get_logger
from .utils.parameter import BareConfig, BareParams from .utils.parameter import BareConfig, BareParams
@@ -93,6 +85,8 @@ def load_toml(path: os.PathLike):
section = dico.pop(key, {}) section = dico.pop(key, {})
dico["variable"].update(section.pop("variable", {})) dico["variable"].update(section.pop("variable", {}))
dico.update(section) dico.update(section)
if len(dico["variable"]) == 0:
dico.pop("variable")
return dico return dico

View File

@@ -211,3 +211,10 @@ def make_uniform_1D(values, x_axis, n=1024, method="linear"):
""" """
xx = np.linspace(*span(x_axis), len(x_axis)) xx = np.linspace(*span(x_axis), len(x_axis))
return interp1d(x_axis, values, kind=method)(xx) return interp1d(x_axis, values, kind=method)(xx)
def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""find all the x values such that y(x)=0 with linear interpolation"""
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]
m = (y[pos] - y[pos - 1]) / (x[pos] - x[pos - 1])
return -y[pos] / m + x[pos]

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Literal, Tuple
import numpy as np import numpy as np
import toml import toml
@@ -276,10 +276,10 @@ def A_eff_hasan(core_radius, capillary_num, capillary_spacing):
def HCPCF_find_with_given_ZDW( def HCPCF_find_with_given_ZDW(
variable, variable: Literal["pressure", "temperature"],
target, target: float,
search_range, search_range: tuple[float, float],
material_dico, material_dico: dict[str, Any],
model="marcatili", model="marcatili",
model_params={}, model_params={},
pressure=None, pressure=None,
@@ -673,10 +673,12 @@ def compute_dispersion(params: BareParams):
if params.dispersion_file is not None: if params.dispersion_file is not None:
disp_file = np.load(params.dispersion_file) disp_file = np.load(params.dispersion_file)
lambda_ = disp_file["wavelength"] lambda_ = disp_file["wavelength"]
interp_range = (np.min(lambda_), np.max(lambda_))
D = disp_file["dispersion"] D = disp_file["dispersion"]
beta2 = D_to_beta2(D, lambda_) beta2 = D_to_beta2(D, lambda_)
gamma = None gamma = None
else: else:
interp_range = params.interp_range
lambda_ = lambda_for_dispersion() lambda_ = lambda_for_dispersion()
beta2 = np.zeros_like(lambda_) beta2 = np.zeros_like(lambda_)
@@ -744,7 +746,7 @@ def compute_dispersion(params: BareParams):
else: else:
gamma = 0 gamma = 0
return beta2_coef, gamma return beta2_coef, gamma, interp_range
@np_cache @np_cache

View File

@@ -2,7 +2,7 @@ import numpy as np
from ..logger import get_logger from ..logger import get_logger
from . import units from . import units
from .units import NA, c, kB from .units import NA, c, kB, me, e
def pressure_from_gradient(ratio, p0, p1): def pressure_from_gradient(ratio, p0, p1):
@@ -170,3 +170,7 @@ def non_linear_refractive_index(material_dico, pressure=None, temperature=None):
ratio = 1 ratio = 1
return ratio * n2_ref return ratio * n2_ref
def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray:
return w * np.sqrt(2 * me * I) / (e * np.abs(field))

View File

@@ -14,6 +14,8 @@ NA = 6.02214076e23
R = 8.31446261815324 R = 8.31446261815324
kB = 1.380649e-23 kB = 1.380649e-23
epsilon0 = 8.85418781e-12 epsilon0 = 8.85418781e-12
me = 9.1093837015e-31
e = -1.602176634e-19
prefix = dict(P=1e12, G=1e9, M=1e6, k=1e3, d=1e-1, c=1e-2, m=1e-3, u=1e-6, n=1e-9, p=1e-12, f=1e-15) prefix = dict(P=1e12, G=1e9, M=1e6, k=1e3, d=1e-1, c=1e-2, m=1e-3, u=1e-6, n=1e-9, p=1e-12, f=1e-15)

View File

@@ -0,0 +1,208 @@
from itertools import cycle
import itertools
from pathlib import Path
from typing import Iterable
from cycler import cycler
import matplotlib.pyplot as plt
import numpy as np
from ..utils.parameter import BareParams
from ..initialize import ParamSequence
from ..physics import units, fiber
from ..spectra import Pulse
from ..utils import pretty_format_value
from .. import env, math
def plot_all(sim_dir: Path, limits: list[str]):
for p in sim_dir.glob("*"):
if not p.is_dir():
continue
pulse = Pulse(p)
for lim in limits:
left, right, unit = lim.split(",")
left = float(left)
right = float(right)
pulse.plot_2D(left, right, unit, file_name=p.parent / f"{p.name}_{left}_{right}_{unit}")
def plot_init_field_spec(
config_path: Path,
lim_t: tuple[float, float] = None,
lim_l: tuple[float, float] = None,
):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
all_labels = []
already_plotted = set()
for style, lbl, params in plot_helper(config_path):
if (bbb := hash(params.field_0.tobytes())) not in already_plotted:
already_plotted.add(bbb)
else:
continue
plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params)
all_labels.append(lbl)
finish_plot(fig, left, right, all_labels, params)
def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7), sharex=True)
left.grid()
right.grid()
all_labels = []
already_plotted = set()
for style, lbl, params in plot_helper(config_path):
if (bbb := tuple(params.beta)) not in already_plotted:
already_plotted.add(bbb)
else:
continue
plot_1_dispersion(lim, left, right, style, lbl, params)
all_labels.append(lbl)
finish_plot(fig, left, right, all_labels, params)
def plot_init(
config_path: Path,
lim_field: tuple[float, float] = None,
lim_spec: tuple[float, float] = None,
lim_disp: tuple[float, float] = None,
):
fig, ((tl, tr), (bl, br)) = plt.subplots(2, 2, figsize=(14, 10))
tl.grid()
tr.grid()
all_labels = []
already_plotted = set()
for style, lbl, params in plot_helper(config_path):
if (bbb := hash(params.field_0.tobytes())) not in already_plotted:
already_plotted.add(bbb)
else:
continue
lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params)
lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params)
all_labels.append(lbl)
finish_plot(fig, tr, all_labels, params)
def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params):
field = math.abs2(params.field_0)
spec = math.abs2(params.spec_0)
t = units.fs.inv(params.t)
wl = units.nm.inv(params.w)
lbl.append(f"max at {wl[spec.argmax()]:.1f} nm")
mt = np.ones_like(t, dtype=bool)
if lim_t is not None:
mt &= t >= lim_t[0]
mt &= t <= lim_t[1]
else:
mt = find_lim(t, field)
ml = np.ones_like(wl, dtype=bool)
if lim_l is not None:
ml &= t >= lim_l[0]
ml &= t <= lim_l[1]
else:
ml = find_lim(wl, spec)
left.plot(t[mt], field[mt])
right.plot(wl[ml], spec[ml], label=" ", **style)
return lbl
def plot_1_dispersion(lim, left, right, style, lbl, params):
coef = params.beta / np.cumprod([1] + list(range(1, len(params.beta))))
w_c = params.w_c
beta_arr = np.zeros_like(w_c)
for k, beta in reversed(list(enumerate(coef))):
beta_arr = beta_arr + beta * w_c ** k
wl = units.m.inv(params.w)
zdw = math.all_zeros(wl, beta_arr)
if len(zdw) > 0:
zdw = zdw[np.argmin(abs(zdw - params.wavelength))]
lbl.append(f"ZDW at {zdw*1e9:.1f}nm")
else:
lbl.append("")
m = np.ones_like(wl, dtype=bool)
if lim is None:
lim = params.interp_range
m &= wl >= lim[0]
m &= wl <= lim[1]
m = np.argwhere(m)[:, 0]
m = np.array(sorted(m, key=lambda el: wl[el]))
# plot D
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
right.plot(1e9 * wl[m], D[m], label=" ", **style)
right.set_ylabel(units.D_ps_nm_km.label)
# plot beta
left.plot(1e9 * wl[m], units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
left.set_ylabel(units.beta2_fs_cm.label)
left.set_xlabel("wavelength (nm)")
right.set_xlabel("wavelength (nm)")
return lbl
def finish_plot(fig, legend_axes, all_labels, params):
fig.suptitle(params.name)
plt.tight_layout()
handles, _ = legend_axes.get_legend_handles_labels()
lbl_lengths = [[len(l) for l in lbl] for lbl in all_labels]
lengths = np.max(lbl_lengths, axis=0)
labels = [
" ".join(format(l, f">{int(s)}s") for s, l in zip(lengths, lab)) for lab in all_labels
]
legend = legend_axes.legend(handles, labels, prop=dict(size=8, family="monospace"))
out_path = env.output_path()
show = out_path is None
if not show:
file_name = out_path.stem + ".pdf"
out_path = out_path.parent / file_name
if (
out_path.exists()
and input(f"{out_path.name} already exsits, overwrite ? (y/[n])\n > ") != "y"
):
show = True
else:
fig.savefig(out_path, bbox_inches="tight")
if show:
plt.show()
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
pseq = ParamSequence(config_path)
for style, (variables, params) in zip(cc, pseq):
lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]]
yield style, lbl, params
def find_lim(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> int:
threshold = y.min() + rel_thr * (y.max() - y.min())
above_threshold = y > threshold
ind = np.argsort(x)
valid_ind = [
np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
]
ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
width = len(ind_above)
return np.concatenate(
(
np.arange(max(ind_above[0] - width, 0), ind_above[0]),
ind_above,
np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
)
)

View File

@@ -4,30 +4,24 @@ scgenerator module but some function may be used in any python program
""" """
import functools
import itertools import itertools
import multiprocessing import multiprocessing
import threading import threading
from functools import update_wrapper
from collections import abc from collections import abc
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, replace from dataclasses import asdict, replace
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
from copy import copy
import numpy as np import numpy as np
from numpy.lib.index_tricks import nd_grid
from tqdm import tqdm from tqdm import tqdm
from ..logger import get_logger
from .. import env from .. import env
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from ..math import * from ..math import *
from .parameter import BareConfig, BareParams from .parameter import BareConfig, BareParams
from scgenerator import logger
T_ = TypeVar("T_") T_ = TypeVar("T_")
@@ -204,7 +198,7 @@ def branch_id(branch: Tuple[Path, ...]) -> str:
return "".join("".join(b.name.split()[2:-2]) for b in branch) return "".join("".join(b.name.split()[2:-2]) for b in branch)
def format_value(value): def format_value(value) -> str:
if type(value) == type(False): if type(value) == type(False):
return str(value) return str(value)
elif isinstance(value, (float, int)): elif isinstance(value, (float, int)):
@@ -215,6 +209,10 @@ def format_value(value):
return str(value) return str(value)
def pretty_format_value(name: str, value) -> str:
return getattr(BareParams, name).display(value)
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
"""given a config with "variable" parameters, iterates through every possible combination, """given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary. yielding a a list of (parameter_name, value) tuples and a full config dictionary.

View File

@@ -5,6 +5,7 @@ from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import numpy as np import numpy as np
from numpy.lib.function_base import disp
from ..const import __version__ from ..const import __version__
@@ -158,7 +159,7 @@ def func_validator(name, n):
class Parameter: class Parameter:
def __init__(self, validator, converter=None, default=None): def __init__(self, validator, converter=None, default=None, display_info=None):
"""Single parameter """Single parameter
Parameters Parameters
@@ -178,6 +179,7 @@ class Parameter:
self.validator = validator self.validator = validator
self.converter = converter self.converter = converter
self.default = default self.default = default
self.display_info = display_info
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
self.name = name self.name = name
@@ -201,6 +203,16 @@ class Parameter:
value = self.converter(value) value = self.converter(value)
instance.__dict__[self.name] = value instance.__dict__[self.name] = value
def display(self, num: float):
if self.display_info is None:
return str(num)
else:
fac, unit = self.display_info
num_str = format(num * fac, ".2f")
if num_str.endswith(".00"):
num_str = num_str[:-3]
return f"{num_str} {unit}"
class VariableParameter: class VariableParameter:
def __init__(self, parameterBase): def __init__(self, parameterBase):
@@ -243,6 +255,7 @@ valid_variable = {
"gamma", "gamma",
"pitch", "pitch",
"pitch_ratio", "pitch_ratio",
"effective_mode_diameter",
"core_radius", "core_radius",
"capillary_num", "capillary_num",
"capillary_outer_d", "capillary_outer_d",
@@ -326,26 +339,26 @@ class BareParams:
capillary_nested: int = Parameter(non_negative(int)) capillary_nested: int = Parameter(non_negative(int))
# gas # gas
gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower) gas_name: str = Parameter(string, converter=str.lower)
pressure: Union[float, Iterable[float]] = Parameter( pressure: Union[float, Iterable[float]] = Parameter(
validator_or(non_negative(float, int), num_list) validator_or(non_negative(float, int), num_list), display_info=(1e-5, "bar")
) )
temperature: float = Parameter(positive(float, int)) temperature: float = Parameter(positive(float, int), display_info=(1, "K"))
plasma_density: float = Parameter(non_negative(float, int)) plasma_density: float = Parameter(non_negative(float, int))
# pulse # pulse
field_file: str = Parameter(string) field_file: str = Parameter(string)
repetition_rate: float = Parameter(non_negative(float, int)) repetition_rate: float = Parameter(non_negative(float, int), display_info=(1e-6, "MHz"))
peak_power: float = Parameter(positive(float, int)) peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW"))
mean_power: float = Parameter(positive(float, int)) mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
energy: float = Parameter(positive(float, int)) energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
soliton_num: float = Parameter(non_negative(float, int)) soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean) quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech")) shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9)) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1)) intensity_noise: float = Parameter(in_range_incl(0, 1))
width: float = Parameter(in_range_excl(0, 1e-9)) width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9)) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# simulation # simulation
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss"))) behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))