new param system
src/scgenerator/cli/__main__.py (new file)
@@ -0,0 +1,4 @@
from .cli import main

if __name__ == "__main__":
    main()

@@ -105,7 +105,7 @@ def prep_ray(args):
def resume_sim(args):

    method = prep_ray(args)
    sim = resume_simulations(args.sim_dir, method=method)
    sim = resume_simulations(Path(args.sim_dir), method=method)
    sim.run()
    run_simulation_sequence(
        *args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
@@ -1,78 +0,0 @@
from .. import const
import toml

valid_commands = ["finish", "next"]


class Configurator:
    def __init__(self, name):
        self.config = dict(name=name, fiber=dict(), gas=dict(), pulse=dict(), simulation=dict())

    def list_input(self):
        answer = ""
        while answer == "":
            answer = input("Please enter a list of values (one per line)\n")

        out = [self.process_input(answer)]

        while answer != "":
            answer = input()
            out.append(self.process_input(answer))

        return out[:-1]

    def process_input(self, s):
        try:
            return int(s)
        except ValueError:
            pass

        try:
            return float(s)
        except ValueError:
            pass

        return s

    def accept(self, question, default=True):
        question += " ([y]/n)" if default else " (y/[n])"
        question += "\n"
        inp = input(question)

        yes_str = ["y", "yes"]
        if default:
            yes_str.append("")

        return inp.lower() in yes_str

    def print_current(self, config: dict):
        print(toml.dumps(config))

    def get(self, section, param_name):
        question = f"Please enter a value for the parameter '{param_name}'\n"
        valid = const.valid_param_types[section][param_name]

        is_valid = False
        value = None

        while not is_valid:
            answer = input(question)
            if answer == "variable" and param_name in const.valid_variable[section]:
                value = self.list_input()
                print(value)
                is_valid = all(valid(v) for v in value)
            else:
                value = self.process_input(answer)
                is_valid = valid(value)

        return value

    def ask_next_command(self):
        s = ""
        raw_input = input(s).split(" ")
        return raw_input[0], raw_input[1:]

    def main(self):
        editing = True
        while editing:
            command, args = self.ask_next_command()
@@ -1,6 +1,3 @@
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
|
||||
@@ -19,243 +16,6 @@ def pbar_format(worker_id: int):
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
|
||||
|
||||
def in_range_excl(func, r):
|
||||
def _in_range(n):
|
||||
if not func(n):
|
||||
return False
|
||||
return n > r[0] and n < r[1]
|
||||
|
||||
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (exclusive) "
|
||||
return _in_range
|
||||
|
||||
|
||||
def in_range_incl(func, r):
|
||||
def _in_range(n):
|
||||
if not func(n):
|
||||
return False
|
||||
return n >= r[0] and n <= r[1]
|
||||
|
||||
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (inclusive)"
|
||||
return _in_range
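The two factories above wrap a base check with a range test; a minimal usage sketch of how the validators being removed here compose (the variable name is illustrative, not from this commit):

    # hypothetical example of composing the validator factories defined above
    core_radius_check = in_range_excl(num, (0, 1e-3))
    assert core_radius_check(50e-6)        # inside the open interval
    assert not core_radius_check(0)        # bounds are excluded
    assert not core_radius_check("wide")   # num() rejects non-numbers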
|
||||
|
||||
|
||||
def num(n):
|
||||
"""must be a single, real, non-negative number"""
|
||||
return isinstance(n, (float, int)) and n >= 0
|
||||
|
||||
|
||||
def integer(n):
|
||||
"""must be a strictly positive integer"""
|
||||
return isinstance(n, int) and n > 0
|
||||
|
||||
|
||||
def boolean(b):
|
||||
"""must be a boolean"""
|
||||
return type(b) == bool
|
||||
|
||||
|
||||
def behaviors(l):
|
||||
"""must be a valid list of behaviors"""
|
||||
for s in l:
|
||||
if s.lower() not in ["spm", "raman", "ss"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def beta(l):
|
||||
"""must be a valid beta array"""
|
||||
for n in l:
|
||||
if not isinstance(n, (float, int)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def field_0(f):
|
||||
return isinstance(f, (str, tuple, list, np.ndarray))
|
||||
|
||||
|
||||
def he_mode(mode):
|
||||
"""must be a valide HE mode"""
|
||||
if not isinstance(mode, (list, tuple)):
|
||||
return False
|
||||
if not len(mode) == 2:
|
||||
return False
|
||||
for m in mode:
|
||||
if not integer(m):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def fit_parameters(param):
|
||||
"""must be a valide fitting parameter tuple of the mercatili_adjusted model"""
|
||||
if not isinstance(param, (list, tuple)):
|
||||
return False
|
||||
if not len(param) == 2:
|
||||
return False
|
||||
for n in param:
|
||||
if not integer(n):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def string(l=None):
|
||||
if l is None:
|
||||
|
||||
def _string(s):
|
||||
return isinstance(s, str)
|
||||
|
||||
_string.__doc__ = f"must be a str"
|
||||
else:
|
||||
|
||||
def _string(s):
|
||||
return isinstance(s, str) and s.lower() in l
|
||||
|
||||
_string.__doc__ = f"must be a str matching one of {l}"
|
||||
|
||||
return _string
|
||||
|
||||
|
||||
def capillary_resonance_strengths(l):
|
||||
"""must be a list of non-zero, real number"""
|
||||
if not isinstance(l, (list, tuple)):
|
||||
return False
|
||||
for m in l:
|
||||
if not num(m):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def capillary_nested(n):
|
||||
"""must be a non negative integer"""
|
||||
return isinstance(n, int) and n >= 0
|
||||
|
||||
|
||||
valid_param_types = dict(
|
||||
root=dict(
|
||||
name=string(),
|
||||
prev_data_dir=string(),
|
||||
),
|
||||
fiber=dict(
|
||||
input_transmission=in_range_incl(num, (0, 1)),
|
||||
gamma=num,
|
||||
n2=num,
|
||||
effective_mode_diameter=num,
|
||||
A_eff=num,
|
||||
pitch=in_range_excl(num, (0, 1e-3)),
|
||||
pitch_ratio=in_range_excl(num, (0, 1)),
|
||||
core_radius=in_range_excl(num, (0, 1e-3)),
|
||||
he_mode=he_mode,
|
||||
fit_parameters=fit_parameters,
|
||||
beta=beta,
|
||||
dispersion_file=string(),
|
||||
model=string(["pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"]),
|
||||
length=in_range_excl(num, (0, 1e9)),
|
||||
capillary_num=integer,
|
||||
capillary_outer_d=in_range_excl(num, (0, 1e-3)),
|
||||
capillary_thickness=in_range_excl(num, (0, 1e-3)),
|
||||
capillary_spacing=in_range_excl(num, (0, 1e-3)),
|
||||
capillary_resonance_strengths=capillary_resonance_strengths,
|
||||
capillary_nested=capillary_nested,
|
||||
),
|
||||
gas=dict(
|
||||
gas_name=string(["vacuum", "helium", "air"]),
|
||||
pressure=num,
|
||||
temperature=num,
|
||||
plasma_density=num,
|
||||
),
|
||||
pulse=dict(
|
||||
field_0=field_0,
|
||||
field_file=string(),
|
||||
repetition_rate=num,
|
||||
peak_power=num,
|
||||
mean_power=num,
|
||||
energy=num,
|
||||
soliton_num=num,
|
||||
quantum_noise=boolean,
|
||||
shape=string(["gaussian", "sech"]),
|
||||
wavelength=in_range_excl(num, (100e-9, 3000e-9)),
|
||||
intensity_noise=in_range_incl(num, (0, 1)),
|
||||
width=in_range_excl(num, (0, 1e-9)),
|
||||
t0=in_range_excl(num, (0, 1e-9)),
|
||||
),
|
||||
simulation=dict(
|
||||
behaviors=behaviors,
|
||||
parallel=boolean,
|
||||
raman_type=string(["measured", "agrawal", "stolen"]),
|
||||
ideal_gas=boolean,
|
||||
repeat=integer,
|
||||
t_num=integer,
|
||||
z_num=integer,
|
||||
time_window=num,
|
||||
dt=in_range_excl(num, (0, 5e-15)),
|
||||
tolerated_error=in_range_excl(num, (1e-15, 1e-5)),
|
||||
step_size=num,
|
||||
lower_wavelength_interp_limit=in_range_excl(num, (100e-9, 3000e-9)),
|
||||
upper_wavelength_interp_limit=in_range_excl(num, (100e-9, 5000e-9)),
|
||||
frep=num,
|
||||
prev_sim_dir=string(),
|
||||
readjust_wavelength=boolean,
|
||||
),
|
||||
)
|
||||
|
||||
hc_model_specific_parameters = dict(
|
||||
marcatili=["core_radius", "he_mode"],
|
||||
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
|
||||
hasan=[
|
||||
"core_radius",
|
||||
"capillary_num",
|
||||
"capillary_thickness",
|
||||
"capillary_resonance_strengths",
|
||||
"capillary_nested",
|
||||
"capillary_spacing",
|
||||
"capillary_outer_d",
|
||||
],
|
||||
)
|
||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
||||
|
||||
valid_variable = dict(
|
||||
fiber=[
|
||||
"beta",
|
||||
"gamma",
|
||||
"pitch",
|
||||
"pitch_ratio",
|
||||
"core_radius",
|
||||
"capillary_num",
|
||||
"capillary_outer_d",
|
||||
"capillary_thickness",
|
||||
"capillary_spacing",
|
||||
"capillary_resonance_strengths",
|
||||
"capillary_nested",
|
||||
"he_mode",
|
||||
"fit_parameters",
|
||||
"input_transmission",
|
||||
"n2",
|
||||
],
|
||||
gas=["pressure", "temperature", "gas_name", "plasma_density"],
|
||||
pulse=[
|
||||
"peak_power",
|
||||
"mean_power",
|
||||
"energy",
|
||||
"quantum_noise",
|
||||
"shape",
|
||||
"wavelength",
|
||||
"intensity_noise",
|
||||
"width",
|
||||
"soliton_num",
|
||||
],
|
||||
simulation=[
|
||||
"behaviors",
|
||||
"raman_type",
|
||||
"tolerated_error",
|
||||
"step_size",
|
||||
"ideal_gas",
|
||||
"readjust_wavelength",
|
||||
],
|
||||
)
|
||||
|
||||
ENVIRON_KEY_BASE = "SCGENERATOR_"
|
||||
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
|
||||
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
|
||||
|
||||
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt

from .errors import MissingParameterError
from pathlib import Path

default_parameters = dict(
    input_transmission=1.0,

@@ -28,6 +27,7 @@ default_parameters = dict(
    upper_wavelength_interp_limit=1900e-9,
    ideal_gas=False,
    readjust_wavelength=False,
    recovery_last_stored=0,
)

default_plotting = dict(

@@ -36,7 +36,7 @@ default_plotting = dict(
    vmin=-40,
    vmax=0,
    vmax_with_headroom=2,
    name="plot",
    out_path=Path("plot"),
    avg_main_to_coherence_ratio=4,
    avg_line_labels=["individual values", "mean"],
    muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)),

@@ -57,76 +57,3 @@ default_plotting = dict(
    text_topright_style=dict(verticalalignment="top", horizontalalignment="right"),
    text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"),
)


def get(section_dict, param, **kwargs):
    """checks if param is in the parameter section dict and attempts to fill in a default value

    Parameters
    ----------
    section_dict : dict
        the parameters section {fiber, pulse, simulation, root} sub-dictionary
    param : str
        the name of the parameter (dict key)
    kwargs : any
        keyword arguments passed to the MissingParameterError constructor

    Returns
    -------
    dict
        the updated section_dict dictionary

    Raises
    ------
    MissingParameterError
        raised when a parameter is missing and no default exists
    """

    # whether the parameter is in the right place and valid is checked elsewhere,
    # here, we just make sure it is present.
    if param not in section_dict and param not in section_dict.get("variable", {}):
        try:
            section_dict[param] = default_parameters[param]
            # LOG
        except KeyError:
            raise MissingParameterError(param, **kwargs)
    return section_dict


def get_fiber(section_dict, param, **kwargs):
    """wrapper for fiber parameters that depend on fiber model"""
    return get(section_dict, param, fiber_model=section_dict["model"], **kwargs)


def get_multiple(section_dict, params, num, **kwargs):
    """similar to the get method but works with several parameters

    Parameters
    ----------
    section_dict : dict
        the parameters section {fiber, pulse, simulation, root} sub-dictionary
    params : list of str
        names of the required parameters
    num : int
        how many of the parameters in params are required

    Returns
    -------
    dict
        the updated section_dict

    Raises
    ------
    MissingParameterError
        raised when not enough parameters are provided and no defaults exist
    """
    gotten = 0
    for param in params:
        try:
            section_dict = get(section_dict, param, **kwargs)
            gotten += 1
        except MissingParameterError:
            pass
        if gotten >= num:
            return section_dict
    raise MissingParameterError(params, num_required=num, **kwargs)
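For reference, the default-filling helpers removed here were used roughly like this (dictionary contents are made up for illustration):

    # hypothetical example of the old helpers
    simulation = dict(z_num=128, behaviors=["spm", "ss", "raman"])
    simulation = get(simulation, "ideal_gas")  # absent, so filled from default_parameters
    pulse = dict(peak_power=1e4, width=100e-15)
    pulse = get_multiple(pulse, ["peak_power", "energy", "width", "t0"], 2)  # 2 of the 4 suffice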
@@ -1,11 +1,10 @@
import os
from pathlib import Path
from typing import Dict, Literal, Optional, Set

from .const import ENVIRON_KEY_BASE, PBAR_POLICY, LOG_POLICY, TMP_FOLDER_KEY_BASE
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE


def data_folder(task_id: int) -> Optional[Path]:
def data_folder(task_id: int) -> Optional[str]:
    idstr = str(int(task_id))
    tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
    return tmp


@@ -34,18 +34,3 @@ class DuplicateParameterError(Exception):

class IncompleteDataFolderError(FileNotFoundError):
    pass


# class MissingFiberParameterError(MissingParameterError):
#     def __init__(self, param, model):
#         self.param = param
#         self.model = model
#         super().__init__(
#             f"'{self.param}' is a required parameter for fiber model '{self.model}' and no default value is set"
#         )


# class MissingPulseParameterError(MissingParameterError):
#     def __init__(self, param):
#         self.param = param
#         super().__init__(f"'{self.param}' is a required pulse parameter and no default value is set")
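The lookup now returns the raw environment value; a small sketch (the folder path is made up):

    os.environ[TMP_FOLDER_KEY_BASE + "42"] = "/tmp/scgenerator_42"
    data_folder(42)  # -> "/tmp/scgenerator_42"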
File diff suppressed because it is too large
@@ -1,3 +1,4 @@
|
||||
from dataclasses import asdict
|
||||
import itertools
|
||||
import os
|
||||
import shutil
|
||||
@@ -11,18 +12,17 @@ import toml
|
||||
|
||||
from . import env, utils
|
||||
from .const import (
|
||||
__version__,
|
||||
ENVIRON_KEY_BASE,
|
||||
PARAM_FN,
|
||||
PARAM_SEPARATOR,
|
||||
PBAR_POLICY,
|
||||
SPEC1_FN,
|
||||
SPECN_FN,
|
||||
TMP_FOLDER_KEY_BASE,
|
||||
Z_FN,
|
||||
__version__,
|
||||
)
|
||||
from .errors import IncompleteDataFolderError
|
||||
from .logger import get_logger
|
||||
from .utils.parameter import BareConfig, BareParams
|
||||
|
||||
PathTree = List[Tuple[Path, ...]]
|
||||
|
||||
@@ -88,6 +88,10 @@ def load_toml(path: os.PathLike):
|
||||
path = conform_toml_path(path)
|
||||
with open(path, mode="r") as file:
|
||||
dico = toml.load(file)
|
||||
|
||||
for section in ["simulation", "fiber", "pulse", "gas"]:
|
||||
dico.update(dico.pop(section, {}))
|
||||
|
||||
return dico
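The new flattening step merges the TOML sections into a single flat dictionary; a short sketch of what the loop does (config content is made up):

    dico = {"name": "test", "fiber": {"gamma": 0.1}, "pulse": {"wavelength": 1030e-9}}
    for section in ["simulation", "fiber", "pulse", "gas"]:
        dico.update(dico.pop(section, {}))
    # dico == {"name": "test", "gamma": 0.1, "wavelength": 1030e-9}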
|
||||
|
||||
|
||||
@@ -99,52 +103,15 @@ def save_toml(path: os.PathLike, dico):
|
||||
return dico
|
||||
|
||||
|
||||
def serializable(val):
|
||||
"""returns True if val is serializable into a Json file"""
|
||||
types = (np.ndarray, float, int, str, list, tuple)
|
||||
|
||||
out = isinstance(val, types)
|
||||
if isinstance(val, np.ndarray):
|
||||
out &= val.dtype != "complex"
|
||||
return out
|
||||
|
||||
|
||||
def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||
(dropped due to no conversion available)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dico : dict
|
||||
dictionary
|
||||
"""
|
||||
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
|
||||
types = (np.ndarray, float, int, str, list, tuple, dict)
|
||||
out = {}
|
||||
for key, value in dico.items():
|
||||
if key in forbiden_keys:
|
||||
continue
|
||||
if not isinstance(value, types):
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
out[key] = prepare_for_serialization(value)
|
||||
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
||||
continue
|
||||
else:
|
||||
out[key] = value
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
||||
def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path:
|
||||
"""saves a parameter dictionary. Note that is does remove some entries, particularly
|
||||
those that take a lot of space ("t", "w", ...)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param_dict : Dict[str, Any]
|
||||
params : Dict[str, Any]
|
||||
dictionary to save
|
||||
data_dir : Path
|
||||
destination_dir : Path
|
||||
destination directory
|
||||
|
||||
Returns
|
||||
@@ -152,12 +119,8 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
||||
Path
|
||||
path to the newly created parameter file
|
||||
"""
|
||||
param = param_dict.copy()
|
||||
file_path = destination_dir / "params.toml"
|
||||
|
||||
param = prepare_for_serialization(param)
|
||||
param["datetime"] = datetime.now()
|
||||
param["version"] = __version__
|
||||
param = params.prepare_for_dump()
|
||||
file_path = destination_dir / file_name
|
||||
|
||||
file_path.parent.mkdir(exist_ok=True)
|
||||
|
||||
@@ -168,7 +131,7 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
||||
return file_path
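A hedged usage sketch of the new signature (assumes an existing BareParams instance; the path is made up):

    file_path = save_parameters(params, Path("data/fiber_0"), file_name="params.toml")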
|
||||
|
||||
|
||||
def load_previous_parameters(path: os.PathLike):
|
||||
def load_params(path: os.PathLike) -> BareParams:
|
||||
"""loads a parameters toml files and converts data to appropriate type
|
||||
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||
|
||||
@@ -179,15 +142,29 @@ def load_previous_parameters(path: os.PathLike):
|
||||
|
||||
Returns
|
||||
----------
|
||||
dict
|
||||
flattened parameters dictionary
|
||||
BareParams
|
||||
params obj
|
||||
"""
|
||||
params = load_toml(path)
|
||||
return BareParams(**params)
|
||||
|
||||
for k, v in params.items():
|
||||
if isinstance(v, list) and isinstance(v[0], (float, int)):
|
||||
params[k] = np.array(v)
|
||||
return params
|
||||
|
||||
def load_config(path: os.PathLike) -> BareConfig:
|
||||
"""loads a parameters toml files and converts data to appropriate type
|
||||
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
path to the toml
|
||||
|
||||
Returns
|
||||
----------
|
||||
BareParams
|
||||
config obj
|
||||
"""
|
||||
config = load_toml(path)
|
||||
return BareConfig(**config)
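A hedged usage sketch of the two loaders (paths are made up; BareParams and BareConfig come from utils.parameter):

    params = load_params("sim_folder/data_0/params.toml")    # -> BareParams
    config = load_config("configs/initial_config.toml")      # -> BareConfig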
|
||||
|
||||
|
||||
def load_material_dico(name):
|
||||
|
||||
@@ -1,7 +1,6 @@
import logging
from typing import Optional
from .env import log_policy

from .env import log_policy

# class DebugOnlyFileHandler(logging.FileHandler):
#     def __init__(
@@ -1,8 +1,8 @@
from typing import Type, Union
from typing import Union

import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from scipy.interpolate import interp1d, griddata
from numba import jit


def span(*vec):

@@ -54,7 +54,6 @@ def power_fact(x, n):
        raise TypeError(f"type {type(x)} of x not supported.")


@jit(nopython=True)
def _power_fact_single(x, n):
    result = 1.0
    for k in range(n):

@@ -62,7 +61,6 @@ def _power_fact_single(x, n):
    return result


@jit(nopython=True)
def _power_fact_array(x, n):
    result = np.ones(len(x), dtype=np.float64)
    for k in range(n):

@@ -70,7 +68,6 @@ def _power_fact_array(x, n):
    return result


@jit(nopython=True)
def abs2(z: np.ndarray) -> np.ndarray:
    return z.real ** 2 + z.imag ** 2
@@ -1,36 +0,0 @@
class Parameter:
    """base class for parameters"""

    all = dict(fiber=dict(), pulse=dict(), gas=dict(), simulation=dict())
    help_message = "no help message lol"

    def __init_subclass__(cls, section):
        Parameter.all[section][cls.__name__.lower()] = cls

    def __init__(self, s):
        self.s = s
        valid = True
        try:
            self.value = self._convert()
            valid = self.valid()
        except ValueError:
            valid = False

        if not valid:
            raise ValueError(
                f"{self.__class__.__name__} {self.__class__.help_message}. input : {self.s}"
            )

    def _convert(self):
        value = self.conversion_func(self.s)
        return value


class Wavelength(Parameter, section="pulse"):
    help_message = "must be a strictly positive real number"

    def valid(self):
        return self.value > 0

    def conversion_func(self, s: str) -> float:
        return float(s)
@@ -1,15 +1,14 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib import disp
|
||||
from numpy.lib.arraysetops import isin
|
||||
import toml
|
||||
from numba import jit
|
||||
from numpy.fft import fft, ifft
|
||||
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from .. import io
|
||||
from ..const import hc_model_specific_parameters
|
||||
from ..math import abs2, argclosest, power_fact, u_nm
|
||||
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||
from . import materials as mat
|
||||
from . import units
|
||||
from .units import c, pi
|
||||
@@ -25,7 +24,7 @@ def lambda_for_dispersion():
|
||||
return np.linspace(190e-9, 3000e-9, 4000)
|
||||
|
||||
|
||||
def is_dynamic_dispersion(params):
|
||||
def is_dynamic_dispersion(pressure=None):
|
||||
"""tests if the parameter dictionary implies that the dispersion profile of the fiber changes with z
|
||||
|
||||
Parameters
|
||||
@@ -38,8 +37,8 @@ def is_dynamic_dispersion(params):
|
||||
bool : True if dispersion is supposed to change with z
|
||||
"""
|
||||
out = False
|
||||
if "pressure" in params:
|
||||
out |= isinstance(params["pressure"], (tuple, list)) and len(params["pressure"]) == 2
|
||||
if pressure is not None:
|
||||
out |= isinstance(pressure, (tuple, list)) and len(pressure) == 2
|
||||
|
||||
return out
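With the new signature the check operates on the pressure value alone; for instance (values made up):

    is_dynamic_dispersion(pressure=(5e5, 1e5))  # True: a 2-value gradient implies z-dependent dispersion
    is_dynamic_dispersion(pressure=2e5)         # False: a single static pressure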
|
||||
|
||||
@@ -483,7 +482,19 @@ def HCPCF_dispersion(
|
||||
return beta2(w, n_eff)
|
||||
|
||||
|
||||
def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
|
||||
def dynamic_HCPCF_dispersion(
|
||||
lambda_: np.ndarray,
|
||||
pressure_values: List[float],
|
||||
core_radius: float,
|
||||
fiber_model: str,
|
||||
model_params: Dict[str, Any],
|
||||
temperature: float,
|
||||
ideal_gas: bool,
|
||||
w0: float,
|
||||
interp_range: Tuple[float, float],
|
||||
material_dico: Dict[str, Any],
|
||||
deg,
|
||||
):
|
||||
"""returns functions for beta2 coefficients and gamma instead of static values
|
||||
|
||||
Parameters
|
||||
@@ -504,25 +515,22 @@ def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
|
||||
in the fiber
|
||||
"""
|
||||
|
||||
# store values because storing functions acts weird with dict
|
||||
pressure_values = params["pressure"]
|
||||
a = params["core_radius"]
|
||||
fiber_model = params["fiber_model"]
|
||||
model_params = {k: params[k] for k in hc_model_specific_parameters[fiber_model]}
|
||||
temp = params["temperature"]
|
||||
ideal_gas = params["ideal_gas"]
|
||||
w0 = params["w0"]
|
||||
interp_range = params["interp_range"]
|
||||
|
||||
A_eff = 1.5 * a ** 2
|
||||
A_eff = 1.5 * core_radius ** 2
|
||||
|
||||
# defining functions instead of storing every possible value
|
||||
pressure = lambda r: mat.pressure_from_gradient(r, *pressure_values)
|
||||
beta2 = lambda r: HCPCF_dispersion(
|
||||
lambda_, a, material_dico, fiber_model, model_params, pressure(r), temp, ideal_gas
|
||||
lambda_,
|
||||
core_radius,
|
||||
material_dico,
|
||||
fiber_model,
|
||||
model_params,
|
||||
pressure(r),
|
||||
temperature,
|
||||
ideal_gas,
|
||||
)
|
||||
|
||||
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temp)
|
||||
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temperature)
|
||||
ratio_range = np.linspace(0, 1, 256)
|
||||
|
||||
gamma_grid = np.array([gamma_parameter(n2(r), w0, A_eff) for r in ratio_range])
|
||||
@@ -640,7 +648,7 @@ def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
|
||||
return beta2, gamma
|
||||
|
||||
|
||||
def dispersion_central(fiber_model, params, deg=8):
|
||||
def compute_dispersion(params: BareParams, deg=8):
|
||||
"""dispatch function depending on what type of fiber is used
|
||||
|
||||
Parameters
|
||||
@@ -660,8 +668,8 @@ def dispersion_central(fiber_model, params, deg=8):
|
||||
nonlinear parameter
|
||||
"""
|
||||
|
||||
if "dispersion_file" in params:
|
||||
disp_file = np.load(params["dispersion_file"])
|
||||
if params.dispersion_file is not None:
|
||||
disp_file = np.load(params.dispersion_file)
|
||||
lambda_ = disp_file["wavelength"]
|
||||
D = disp_file["dispersion"]
|
||||
beta2 = D_to_beta2(D, lambda_)
|
||||
@@ -669,21 +677,20 @@ def dispersion_central(fiber_model, params, deg=8):
|
||||
else:
|
||||
lambda_ = lambda_for_dispersion()
|
||||
beta2 = np.zeros_like(lambda_)
|
||||
fiber_model = fiber_model.lower()
|
||||
|
||||
if fiber_model == "pcf":
|
||||
if params.model == "pcf":
|
||||
beta2, gamma = PCF_dispersion(
|
||||
lambda_,
|
||||
params["pitch"],
|
||||
params["pitch_ratio"],
|
||||
w0=params["w0"],
|
||||
n2=params.get("n2"),
|
||||
A_eff=params.get("A_eff"),
|
||||
params.pitch,
|
||||
params.pitch_ratio,
|
||||
w0=params.w0,
|
||||
n2=params.n2,
|
||||
A_eff=params.A_eff,
|
||||
)
|
||||
|
||||
else:
|
||||
# Load material info
|
||||
gas_name = params["gas_name"]
|
||||
gas_name = params.gas_name
|
||||
|
||||
if gas_name == "vacuum":
|
||||
material_dico = None
|
||||
@@ -691,8 +698,20 @@ def dispersion_central(fiber_model, params, deg=8):
|
||||
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
|
||||
|
||||
# compute dispersion
|
||||
if params.get("dynamic_dispersion", False):
|
||||
return dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg)
|
||||
if params.dynamic_dispersion:
|
||||
return dynamic_HCPCF_dispersion(
|
||||
lambda_,
|
||||
params.pressure,
|
||||
params.core_radius,
|
||||
params.model,
|
||||
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
|
||||
params.temperature,
|
||||
params.ideal_gas,
|
||||
params.w0,
|
||||
params.interp_range,
|
||||
material_dico,
|
||||
deg,
|
||||
)
|
||||
else:
|
||||
|
||||
# actually compute the dispersion
|
||||
@@ -700,31 +719,31 @@ def dispersion_central(fiber_model, params, deg=8):
|
||||
beta2 = HCPCF_dispersion(
|
||||
lambda_,
|
||||
material_dico,
|
||||
fiber_model,
|
||||
{k: params[k] for k in hc_model_specific_parameters[fiber_model]},
|
||||
params["pressure"],
|
||||
params["temperature"],
|
||||
params["ideal_gas"],
|
||||
params.model,
|
||||
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
|
||||
params.pressure,
|
||||
params.temperature,
|
||||
params.ideal_gas,
|
||||
)
|
||||
|
||||
if material_dico is not None:
|
||||
A_eff = 1.5 * params["core_radius"] ** 2
|
||||
A_eff = 1.5 * params.core_radius ** 2
|
||||
n2 = mat.non_linear_refractive_index(
|
||||
material_dico, params["pressure"], params["temperature"]
|
||||
material_dico, params.pressure, params.temperature
|
||||
)
|
||||
gamma = gamma_parameter(n2, params["w0"], A_eff)
|
||||
gamma = gamma_parameter(n2, params.w0, A_eff)
|
||||
else:
|
||||
gamma = None
|
||||
|
||||
# add plasma if wanted
|
||||
if params["plasma_density"] > 0:
|
||||
beta2 += plasma_dispersion(lambda_, params["plasma_density"])
|
||||
if params.plasma_density > 0:
|
||||
beta2 += plasma_dispersion(lambda_, params.plasma_density)
|
||||
|
||||
beta2_coef = dispersion_coefficients(lambda_, beta2, params["w0"], params["interp_range"], deg)
|
||||
beta2_coef = dispersion_coefficients(lambda_, beta2, params.w0, params.interp_range, deg)
|
||||
|
||||
if gamma is None:
|
||||
if "A_eff" in params:
|
||||
gamma = gamma_parameter(params.get("n2", 2.6e-20), params["w0"], params["A_eff"])
|
||||
if params.A_eff is not None:
|
||||
gamma = gamma_parameter(params.n2, params.w0, params.A_eff)
|
||||
else:
|
||||
gamma = 0
|
||||
|
||||
@@ -855,15 +874,13 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
|
||||
"""
|
||||
|
||||
# Compute raman response function if necessary
|
||||
f_r = 0.18
|
||||
if "raman" in behaviors:
|
||||
if "hr_w" == None:
|
||||
raise TypeError("freq-dependent Raman response must be give")
|
||||
else:
|
||||
if f_r is None:
|
||||
if raman_type in ["stolen", "measured"]:
|
||||
f_r = 0.18
|
||||
elif raman_type == "agrawal":
|
||||
f_r = 0.245
|
||||
if hr_w is None:
|
||||
raise ValueError("freq-dependent Raman response must be give")
|
||||
if f_r is None:
|
||||
if raman_type == "agrawal":
|
||||
f_r = 0.245
|
||||
|
||||
if "spm" in behaviors:
|
||||
spm_part = lambda fi: (1 - f_r) * abs2(fi)
|
||||
@@ -875,7 +892,6 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
|
||||
else:
|
||||
raman_part = lambda fi: 0
|
||||
|
||||
spm_part = jit(spm_part, nopython=True)
|
||||
ss_part = w_c / w0 if "ss" in behaviors else 0
|
||||
|
||||
if isinstance(gamma, (float, int)):
|
||||
@@ -924,7 +940,6 @@ def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
|
||||
return -1j * out
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr):
|
||||
for k in range(len(beta_arr) - 1, -1, -1):
|
||||
dispersion = dispersion + beta_arr[k] * power_fact_arr[k]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
from . import units
|
||||
from .units import NA, c, kB
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time
|
||||
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -18,13 +19,13 @@ import numpy as np
|
||||
from numpy import pi
|
||||
from numpy.fft import fft, fftshift, ifft
|
||||
from scipy.interpolate import UnivariateSpline
|
||||
from numba import jit
|
||||
|
||||
from .. import io
|
||||
from ..defaults import default_plotting
|
||||
|
||||
from ..logger import get_logger
|
||||
from ..plotting import plot_setup
|
||||
from ..math import *
|
||||
from ..plotting import plot_setup
|
||||
from ..utils.parameter import BareParams
|
||||
|
||||
c = 299792458.0
|
||||
hbar = 1.05457148e-34
|
||||
@@ -205,6 +206,48 @@ def conform_pulse_params(
|
||||
return width, t0, peak_power, energy, soliton_num
|
||||
|
||||
|
||||
def setup_custom_field(params: BareParams) -> bool:
|
||||
"""sets up a custom field function if necessary and returns
|
||||
True if it did so, False otherwise
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : BareParams
|
||||
params dictionary
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
(did_set, width, peak_power, energy, field_0)
|
||||
"""
|
||||
field_0 = width = peak_power = energy = None
|
||||
|
||||
did_set = True
|
||||
|
||||
if params.prev_data_dir is not None:
|
||||
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
|
||||
field_0 = np.fft.ifft(spec) * np.sqrt(params.input_transmission)
|
||||
elif params.field_file is not None:
|
||||
field_data = np.load(params.field_file)
|
||||
field_interp = interp1d(
|
||||
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
||||
)
|
||||
field_0 = field_interp(params.t)
|
||||
|
||||
field_0 = field_0 * modify_field_ratio(
|
||||
params.t,
|
||||
field_0,
|
||||
params.peak_power,
|
||||
params.energy,
|
||||
params.intensity_noise,
|
||||
)
|
||||
width, peak_power, energy = measure_field(params.t, field_0)
|
||||
else:
|
||||
did_set = False
|
||||
|
||||
return did_set, width, peak_power, energy, field_0
|
||||
|
||||
|
||||
def E0_to_P0(E0, t0, shape="gaussian"):
|
||||
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
||||
return E0 / (t0 * P0T0_to_E0_fac[shape])
|
||||
@@ -223,12 +266,10 @@ def gauss_pulse(t, t0, P0, offset=0):
|
||||
return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2))
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def photon_number(spectrum, w, dw, gamma):
|
||||
return np.sum(1 / gamma * abs2(spectrum) / w * dw)
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def pulse_energy(spectrum, w, dw, _):
|
||||
return np.sum(abs2(spectrum) * dw)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -18,7 +19,14 @@ except ModuleNotFoundError:
|
||||
|
||||
|
||||
class RK4IP:
|
||||
def __init__(self, sim_params, save_data=False, job_identifier="", task_id=0, n_percent=10):
|
||||
def __init__(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
n_percent=10,
|
||||
):
|
||||
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
||||
|
||||
Parameters
|
||||
@@ -76,31 +84,29 @@ class RK4IP:
|
||||
self.logger = get_logger(self.job_identifier)
|
||||
self.resuming = False
|
||||
self.save_data = save_data
|
||||
self._extract_params(sim_params)
|
||||
self._setup_functions()
|
||||
self.starting_num = sim_params.get("recovery_last_stored", 0)
|
||||
self._setup_sim_parameters()
|
||||
|
||||
def _extract_params(self, params):
|
||||
self.w_c = params.pop("w_c")
|
||||
self.w0 = params.pop("w0")
|
||||
self.w_power_fact = params.pop("w_power_fact")
|
||||
self.spec_0 = params.pop("spec_0")
|
||||
self.z_targets = params.pop("z_targets")
|
||||
self.z_final = params.pop("length")
|
||||
self.beta = params.pop("beta_func", params.pop("beta"))
|
||||
self.gamma = params.pop("gamma_func", params.pop("gamma"))
|
||||
self.behaviors = params.pop("behaviors")
|
||||
self.raman_type = params.pop("raman_type", "stolen")
|
||||
self.f_r = params.pop("f_r", 0)
|
||||
self.hr_w = params.pop("hr_w", None)
|
||||
self.adapt_step_size = params.pop("adapt_step_size", True)
|
||||
self.error_ok = params.pop("error_ok")
|
||||
self.dynamic_dispersion = params.pop("dynamic_dispersion", False)
|
||||
self.w_c = params.w_c
|
||||
self.w0 = params.w0
|
||||
self.w_power_fact = params.w_power_fact
|
||||
self.spec_0 = params.spec_0
|
||||
self.z_targets = params.z_targets
|
||||
self.z_final = params.length
|
||||
self.beta = params.beta_func if params.beta_func is not None else params.beta
|
||||
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma
|
||||
self.behaviors = params.behaviors
|
||||
self.raman_type = params.raman_type
|
||||
self.hr_w = params.hr_w
|
||||
self.adapt_step_size = params.adapt_step_size
|
||||
self.error_ok = params.error_ok
|
||||
self.dynamic_dispersion = params.dynamic_dispersion
|
||||
self.starting_num = params.recovery_last_stored
|
||||
|
||||
self._setup_functions()
|
||||
self._setup_sim_parameters()
|
||||
|
||||
def _setup_functions(self):
|
||||
self.N_func = create_non_linear_op(
|
||||
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, self.f_r, self.hr_w
|
||||
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, hr_w=self.hr_w
|
||||
)
|
||||
|
||||
if self.dynamic_dispersion:
|
||||
@@ -303,7 +309,7 @@ class RK4IP:
|
||||
class SequentialRK4IP(RK4IP):
|
||||
def __init__(
|
||||
self,
|
||||
sim_params,
|
||||
params: initialize.Params,
|
||||
pbars: utils.PBars,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
@@ -312,7 +318,7 @@ class SequentialRK4IP(RK4IP):
|
||||
):
|
||||
self.pbars = pbars
|
||||
super().__init__(
|
||||
sim_params,
|
||||
params,
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
@@ -326,7 +332,7 @@ class SequentialRK4IP(RK4IP):
|
||||
class MutliProcRK4IP(RK4IP):
|
||||
def __init__(
|
||||
self,
|
||||
sim_params,
|
||||
params: initialize.Params,
|
||||
p_queue: multiprocessing.Queue,
|
||||
worker_id: int,
|
||||
save_data=False,
|
||||
@@ -337,7 +343,7 @@ class MutliProcRK4IP(RK4IP):
|
||||
self.worker_id = worker_id
|
||||
self.p_queue = p_queue
|
||||
super().__init__(
|
||||
sim_params,
|
||||
params,
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
@@ -351,7 +357,7 @@ class MutliProcRK4IP(RK4IP):
|
||||
class RayRK4IP(RK4IP):
|
||||
def __init__(
|
||||
self,
|
||||
sim_params,
|
||||
params: initialize.Params,
|
||||
p_actor,
|
||||
worker_id: int,
|
||||
save_data=False,
|
||||
@@ -362,7 +368,7 @@ class RayRK4IP(RK4IP):
|
||||
self.worker_id = worker_id
|
||||
self.p_actor = p_actor
|
||||
super().__init__(
|
||||
sim_params,
|
||||
params,
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
@@ -414,7 +420,7 @@ class Simulations:
|
||||
if isinstance(method, str):
|
||||
method = Simulations.simulation_methods_dict[method]
|
||||
return method(param_seq, task_id)
|
||||
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
|
||||
elif param_seq.num_sim > 1 and param_seq.config.parallel:
|
||||
return Simulations.get_best_method()(param_seq, task_id)
|
||||
else:
|
||||
return SequencialSimulations(param_seq, task_id)
|
||||
@@ -439,7 +445,7 @@ class Simulations:
|
||||
|
||||
self.name = self.param_seq.name
|
||||
self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
|
||||
io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config)
|
||||
io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml")
|
||||
|
||||
self.sim_jobs_per_node = 1
|
||||
self.max_concurrent_jobs = np.inf
|
||||
@@ -447,9 +453,7 @@ class Simulations:
|
||||
@property
|
||||
def finished_and_complete(self):
|
||||
try:
|
||||
io.check_data_integrity(
|
||||
io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"]
|
||||
)
|
||||
io.check_data_integrity(io.get_data_dirs(self.sim_dir), self.param_seq.config.z_num)
|
||||
return True
|
||||
except IncompleteDataFolderError:
|
||||
return False
|
||||
@@ -472,15 +476,15 @@ class Simulations:
|
||||
self.new_sim(v_list_str, params)
|
||||
self.finish()
|
||||
|
||||
def new_sim(self, v_list_str: str, params: dict):
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
"""responsible to launch a new simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v_list_str : str
|
||||
string that uniquely identifies the simulation as returned by utils.format_variable_list
|
||||
params : dict
|
||||
a flattened parameter dictionary, as returned by initialize.compute_init_parameters
|
||||
params : initialize.Params
|
||||
computed parameters
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -507,7 +511,7 @@ class SequencialSimulations(Simulations, priority=0):
|
||||
super().__init__(param_seq, task_id=task_id)
|
||||
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
|
||||
|
||||
def new_sim(self, v_list_str: str, params: Dict[str, Any]):
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
||||
SequentialRK4IP(
|
||||
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
|
||||
@@ -517,7 +521,7 @@ class SequencialSimulations(Simulations, priority=0):
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
pass
|
||||
self.pbars.close()
|
||||
|
||||
|
||||
class MultiProcSimulations(Simulations, priority=1):
|
||||
@@ -553,7 +557,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
worker.start()
|
||||
super().run()
|
||||
|
||||
def new_sim(self, v_list_str: str, params: dict):
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
self.queue.put((v_list_str, params), block=True, timeout=None)
|
||||
|
||||
def finish(self):
|
||||
@@ -576,7 +580,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
p_queue: multiprocessing.Queue,
|
||||
):
|
||||
while True:
|
||||
raw_data: Tuple[List[tuple], Dict[str, Any]] = queue.get()
|
||||
raw_data: Tuple[List[tuple], initialize.Params] = queue.get()
|
||||
if raw_data == 0:
|
||||
queue.task_done()
|
||||
return
|
||||
@@ -635,7 +639,7 @@ class RaySimulations(Simulations, priority=2):
|
||||
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
||||
)
|
||||
|
||||
def new_sim(self, v_list_str: str, params: dict):
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
while len(self.jobs) >= self.sim_jobs_total:
|
||||
self._collect_1_job()
|
||||
|
||||
@@ -707,28 +711,27 @@ def new_simulation(
|
||||
method: Type[Simulations] = None,
|
||||
) -> Simulations:
|
||||
|
||||
config = io.load_toml(config_file)
|
||||
config_dict = io.load_toml(config_file)
|
||||
|
||||
if prev_sim_dir is not None:
|
||||
config.setdefault("simulation", {})
|
||||
config["simulation"]["prev_sim_dir"] = str(prev_sim_dir)
|
||||
config_dict["prev_sim_dir"] = str(prev_sim_dir)
|
||||
|
||||
task_id = np.random.randint(1e9, 1e12)
|
||||
|
||||
if prev_sim_dir is None:
|
||||
param_seq = initialize.ParamSequence(config)
|
||||
param_seq = initialize.ParamSequence(config_dict)
|
||||
else:
|
||||
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config)
|
||||
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
|
||||
|
||||
print(f"{param_seq.name=}")
|
||||
|
||||
return Simulations.new(param_seq, task_id, method)
|
||||
|
||||
|
||||
def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simulations:
|
||||
def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simulations:
|
||||
|
||||
task_id = np.random.randint(1e9, 1e12)
|
||||
config = io.load_toml(os.path.join(sim_dir, "initial_config.toml"))
|
||||
config = io.load_toml(sim_dir / "initial_config.toml")
|
||||
io.set_data_folder(task_id, sim_dir)
|
||||
param_seq = initialize.RecoveryParamSequence(config, task_id)
|
||||
|
||||
|
||||
@@ -2,9 +2,10 @@
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...

from numba.core.types.misc import Phantom
from typing import Callable, Union

import numpy as np
from numpy import isin, pi
from numpy import pi

c = 299792458.0
hbar = 1.05457148e-34

@@ -217,7 +218,7 @@ units_map = dict(
)


def get_unit(unit):
def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
    if isinstance(unit, str):
        return units_map[unit]
    return unit
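A small illustration of the typed helper (assuming "nm" is one of the keys registered in units_map):

    to_w = get_unit("nm")          # named units are looked up in units_map
    same = get_unit(lambda x: x)   # callables are returned unchanged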
@@ -1,55 +1,47 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
import matplotlib.gridspec as gs
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.colors import ListedColormap
|
||||
from scgenerator.utils import variable_iterator
|
||||
from scipy.interpolate import UnivariateSpline
|
||||
|
||||
from . import io, math
|
||||
from .math import abs2, length, make_uniform_1D, span
|
||||
from .physics import pulse, units
|
||||
from .defaults import default_plotting as defaults
|
||||
from .math import abs2, make_uniform_1D, span
|
||||
from .physics import pulse, units
|
||||
from .utils.parameter import BareParams
|
||||
|
||||
RangeType = Tuple[float, float, Union[str, Callable]]
|
||||
|
||||
|
||||
def plot_setup(
|
||||
folder_name=None,
|
||||
file_name=None,
|
||||
file_type="png",
|
||||
figsize=defaults["figsize"],
|
||||
params=None,
|
||||
mode="default",
|
||||
):
|
||||
out_path: Path,
|
||||
file_type: str = "png",
|
||||
figsize: Tuple[float, float] = defaults["figsize"],
|
||||
mode: Literal["default", "coherence", "coherence_T"] = "default",
|
||||
) -> Tuple[Path, plt.Figure, Union[plt.Axes, Tuple[plt.Axes]]]:
|
||||
"""It should return :
|
||||
- an output path
- a fig
- an axis
|
||||
"""
|
||||
file_name = defaults["name"] if file_name is None else file_name
|
||||
out_path = defaults["name"] if out_path is None else out_path
|
||||
plot_name = out_path.stem
|
||||
out_dir = out_path.resolve().parent
|
||||
|
||||
if params is not None:
|
||||
folder_name = params.get("plot.folder_name", folder_name)
|
||||
file_name = params.get("plot.file_name", file_name)
|
||||
file_type = params.get("plot.file_type", file_type)
|
||||
figsize = params.get("plot.figsize", figsize)
|
||||
file_name = plot_name + "." + file_type
|
||||
out_path = out_dir / file_name
|
||||
|
||||
# ensure output folder_name exists
|
||||
folder_name, file_name = (
|
||||
os.path.split(file_name)
|
||||
if folder_name is None
|
||||
else (folder_name, os.path.split(file_name)[1])
|
||||
)
|
||||
folder_name = os.path.join(io.Paths.get("plots"), folder_name)
|
||||
if not os.path.exists(os.path.abspath(folder_name)):
|
||||
os.makedirs(os.path.abspath(folder_name))
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# ensure no overwrite
|
||||
ind = 0
|
||||
while os.path.exists(os.path.join(folder_name, file_name + "_" + str(ind) + "." + file_type)):
|
||||
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists():
|
||||
ind += 1
|
||||
file_name = file_name + "_" + str(ind) + "." + file_type
|
||||
|
||||
if mode == "default":
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
@@ -78,7 +70,7 @@ def plot_setup(
|
||||
else:
|
||||
raise ValueError(f"mode {mode} not understood")
|
||||
|
||||
return folder_name, file_name, fig, ax
|
||||
return full_path, fig, ax
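A hedged sketch of a call with the new signature (the path is made up):

    out_path, fig, ax = plot_setup(Path("plots/spectrum.png"), file_type="png", mode="default")
    fig.savefig(out_path, bbox_inches="tight", dpi=200)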
|
||||
|
||||
|
||||
def draw_across(ax1, xy1, ax2, xy2, clip_on=False, **kwargs):
|
||||
@@ -297,9 +289,7 @@ def _finish_plot_2D(
|
||||
|
||||
folder_name = ""
|
||||
if is_new_plot:
|
||||
folder_name, file_name, fig, ax = plot_setup(
|
||||
file_name=file_name, file_type=file_type, params=params
|
||||
)
|
||||
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
|
||||
@@ -345,8 +335,8 @@ def _finish_plot_2D(
|
||||
cbar.ax.set_ylabel(cbar_label)
|
||||
|
||||
if is_new_plot:
|
||||
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {os.path.join(folder_name, file_name)}")
|
||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {out_path}")
|
||||
if cbar_label is not None:
|
||||
return fig, ax, cbar.ax
|
||||
else:
|
||||
@@ -354,20 +344,20 @@ def _finish_plot_2D(
|
||||
|
||||
|
||||
def plot_spectrogram(
|
||||
values,
|
||||
x_range,
|
||||
y_range,
|
||||
params,
|
||||
t_res=None,
|
||||
gate_width=None,
|
||||
log=True,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
cbar_label="normalized intensity (dB)",
|
||||
file_type="png",
|
||||
file_name=None,
|
||||
cmap=None,
|
||||
ax=None,
|
||||
values: np.ndarray,
|
||||
x_range: RangeType,
|
||||
y_range: RangeType,
|
||||
params: BareParams,
|
||||
t_res: int = None,
|
||||
gate_width: float = None,
|
||||
log: bool = True,
|
||||
vmin: float = None,
|
||||
vmax: float = None,
|
||||
cbar_label: str = "normalized intensity (dB)",
|
||||
file_type: str = "png",
|
||||
file_name: str = None,
|
||||
cmap: str = None,
|
||||
ax: plt.Axes = None,
|
||||
):
|
||||
"""Plots a spectrogram given a complex field in the time domain
|
||||
Parameters
|
||||
@@ -382,7 +372,7 @@ def plot_spectrogram(
|
||||
units : function to convert from the desired units to rad/s or to time.
|
||||
common functions are already defined in scgenerator.physics.units
|
||||
look there for more details
|
||||
params : dict
|
||||
params : BareParams
|
||||
parameters of the simulations
|
||||
log : bool, optional
|
||||
whether to compute the logarithm of the spectrogram
|
||||
@@ -424,16 +414,16 @@ def plot_spectrogram(
|
||||
t_win = 2 * np.max(t_range[2](np.abs(t_range[:2])))
|
||||
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
|
||||
spec, new_t = pulse.spectrogram(
|
||||
params["t"].copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
|
||||
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
|
||||
)
|
||||
|
||||
# Crop and reorder axes
|
||||
new_t, ind_t, _ = units.sort_axis(new_t, t_range)
|
||||
new_f, ind_f, _ = units.sort_axis(params["w"], f_range)
|
||||
new_f, ind_f, _ = units.sort_axis(params.w, f_range)
|
||||
values = spec[ind_t][:, ind_f]
|
||||
if f_range[2].type == "WL":
|
||||
values = np.apply_along_axis(
|
||||
units.to_WL, 1, values, params["frep"], units.m(f_range[2].inv(new_f))
|
||||
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f))
|
||||
)
|
||||
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
|
||||
|
||||
@@ -463,19 +453,19 @@ def plot_spectrogram(
|
||||
|
||||
|
||||
def plot_results_2D(
|
||||
values,
|
||||
plt_range,
|
||||
params,
|
||||
log="1D",
|
||||
skip=16,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
transpose=False,
|
||||
cbar_label="normalized intensity (dB)",
|
||||
file_type="png",
|
||||
file_name=None,
|
||||
cmap=None,
|
||||
ax=None,
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
params: BareParams,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
skip: int = 16,
|
||||
vmin: float = None,
|
||||
vmax: float = None,
|
||||
transpose: bool = False,
|
||||
cbar_label: Optional[str] = "normalized intensity (dB)",
|
||||
file_type: str = "png",
|
||||
file_name: str = None,
|
||||
cmap: str = None,
|
||||
ax: plt.Axes = None,
|
||||
):
|
||||
"""
|
||||
plots 2D arrays, automatically saves the plot and returns it
|
||||
@@ -540,27 +530,32 @@ def plot_results_2D(
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range[2].type == "WL":
|
||||
if is_spectrum:
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.get("frep", 1), x_axis)
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||
values = np.array(
|
||||
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
|
||||
)
|
||||
|
||||
z = params["z_targets"]
|
||||
lim_diff = 1e-5 * np.max(z)
|
||||
dz_s = np.diff(z)
|
||||
lim_diff = 1e-5 * np.max(params.z_targets)
|
||||
dz_s = np.diff(params.z_targets)
|
||||
if not np.all(np.diff(dz_s) < lim_diff):
|
||||
new_z = np.linspace(
|
||||
*span(z), int(np.floor((np.max(z) - np.min(z)) / np.min(dz_s[dz_s > lim_diff])))
|
||||
*span(params.z_targets),
|
||||
int(
|
||||
np.floor(
|
||||
(np.max(params.z_targets) - np.min(params.z_targets))
|
||||
/ np.min(dz_s[dz_s > lim_diff])
|
||||
)
|
||||
),
|
||||
)
|
||||
values = np.array(
|
||||
[make_uniform_1D(v, z, n=len(new_z), method="linear") for v in values.T]
|
||||
[make_uniform_1D(v, params.z_targets, n=len(new_z), method="linear") for v in values.T]
|
||||
).T
|
||||
z = new_z
|
||||
params.z_targets = new_z
|
||||
return _finish_plot_2D(
|
||||
values,
|
||||
x_axis,
|
||||
plt_range[2].label,
|
||||
z,
|
||||
params.z_targets,
|
||||
"propagation distance (m)",
|
||||
log,
|
||||
vmin,
|
||||
@@ -576,20 +571,20 @@ def plot_results_2D(
|
||||
|
||||
|
||||
def plot_results_1D(
|
||||
values,
|
||||
plt_range,
|
||||
params,
|
||||
log=False,
|
||||
spacing=1,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
ylabel=None,
|
||||
yscaling=1,
|
||||
file_type="pdf",
|
||||
file_name=None,
|
||||
ax=None,
|
||||
line_label=None,
|
||||
transpose=False,
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
params: BareParams,
|
||||
log: Union[str, int, float, bool] = False,
|
||||
spacing: Union[int, float] = 1,
|
||||
vmin: float = None,
|
||||
vmax: float = None,
|
||||
ylabel: str = None,
|
||||
yscaling: float = 1,
|
||||
file_type: str = "pdf",
|
||||
file_name: str = None,
|
||||
ax: plt.Axes = None,
|
||||
line_label: str = None,
|
||||
transpose: bool = False,
|
||||
**line_kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -656,7 +651,7 @@ def plot_results_1D(
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range[2].type == "WL":
|
||||
if is_spectrum:
|
||||
values = units.to_WL(values, params["frep"], units.m.inv(params["w"][ind]))
|
||||
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
||||
|
||||
# change the resolution
|
||||
if isinstance(spacing, float):
|
||||
@@ -683,9 +678,7 @@ def plot_results_1D(
|
||||
|
||||
folder_name = ""
|
||||
if is_new_plot:
|
||||
folder_name, file_name, fig, ax = plot_setup(
|
||||
file_name=file_name, file_type=file_type, params=params
|
||||
)
|
||||
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
if transpose:
|
||||
@@ -702,40 +695,40 @@ def plot_results_1D(
|
||||
ax.set_xlabel(plt_range[2].label)
|
||||
|
||||
if is_new_plot:
|
||||
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {os.path.join(folder_name, file_name)}")
|
||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {out_path}")
|
||||
return fig, ax, x_axis, values
|
||||
|
||||
|
||||
def _prep_plot(values, plt_range, params):
|
||||
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
||||
is_spectrum = values.dtype == "complex"
|
||||
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
|
||||
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = params["w"].copy()
|
||||
x_axis = params.w.copy()
|
||||
else:
|
||||
x_axis = params["t"].copy()
|
||||
x_axis = params.t.copy()
|
||||
return is_spectrum, x_axis, plt_range
|
||||
|
||||
|
||||
def plot_avg(
|
||||
values,
|
||||
plt_range,
|
||||
params,
|
||||
log=False,
|
||||
spacing=1,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
ylabel=None,
|
||||
yscaling=1,
|
||||
renormalize=True,
|
||||
add_coherence=False,
|
||||
file_type="png",
|
||||
file_name=None,
|
||||
ax=None,
|
||||
line_labels=None,
|
||||
legend=True,
|
||||
legend_kwargs={},
|
||||
transpose=False,
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
params: BareParams,
|
||||
log: Union[float, int, str, bool] = False,
|
||||
spacing: Union[float, int] = 1,
|
||||
vmin: float = None,
|
||||
vmax: float = None,
|
||||
ylabel: str = None,
|
||||
yscaling: float = 1,
|
||||
renormalize: bool = True,
|
||||
add_coherence: bool = False,
|
||||
file_type: str = "png",
|
||||
file_name: str = None,
|
||||
ax: plt.Axes = None,
|
||||
line_labels: Tuple[str, str] = None,
|
||||
legend: bool = True,
|
||||
legend_kwargs: Dict[str, Any] = {},
|
||||
transpose: bool = False,
|
||||
):
|
||||
"""
|
||||
plots 1D arrays and their mean, automatically saves the plot and returns it
|
||||
@@ -817,8 +810,8 @@ def plot_avg(
    values *= yscaling
    mean_values = np.mean(values, axis=0)
    if plt_range[2].type == "WL" and renormalize:
        values = np.apply_along_axis(units.to_WL, 1, values, params["frep"], x_axis)
        mean_values = units.to_WL(mean_values, params["frep"], x_axis)
        values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
        mean_values = units.to_WL(mean_values, params.frep, x_axis)

    # change the resolution
    if isinstance(spacing, float):
@@ -852,12 +845,12 @@ def plot_avg(
    if is_new_plot:
        if add_coherence:
            mode = "coherence_T" if transpose else "coherence"
            folder_name, file_name, fig, (top, bot) = plot_setup(
                file_name=file_name, file_type=file_type, params=params, mode=mode
            out_path, fig, (top, bot) = plot_setup(
                out_path=Path(folder_name) / file_name, file_type=file_type, mode=mode
            )
        else:
            folder_name, file_name, fig, top = plot_setup(
                file_name=file_name, file_type=file_type, params=params
            out_path, fig, top = plot_setup(
                out_path=Path(folder_name) / file_name, file_type=file_type
            )
            bot = top
    else:
@@ -923,8 +916,8 @@ def plot_avg(
        top.legend(custom_lines, line_labels, **legend_kwargs)

    if is_new_plot:
        fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
        print(f"plot saved in {os.path.join(folder_name, file_name)}")
        fig.savefig(out_path, bbox_inches="tight", dpi=200)
        print(f"plot saved in {out_path}")

    if top is bot:
        return fig, top
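For orientation, a minimal sketch (not part of the commit) of the reworked saving flow; it assumes, per the hunks above, that plot_setup now builds the full output path itself and returns it first:

out_path, fig, ax = plot_setup(out_path=Path("plots") / "spectrum", file_type="png")
# ... draw on ax ...
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")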
@@ -984,46 +977,6 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
    return x_axis, np.squeeze(values)


def plot_dispersion_parameter(params, plt_range):
    """
    Plots the dispersion parameter D as well as the beta2 parameter over the given range
    """
    # TODO allow several curves, with legends, to be plotted

    x_axis = np.linspace(*plt_range[:2], 1000)
    w_axis = plt_range[2](x_axis)

    if "disp_obj" in params:
        D = params["disp_obj"].D_w(w_axis)
        beta2 = params["disp_obj"].beta2_w(w_axis)
    else:
        print("no dispersion information given")
        return

    fig, (ax_D, ax_beta2) = plt.subplots(1, 2)

    ax_D.plot(x_axis, 1e6 * D)
    ax_D.plot(
        x_axis,
        0 * x_axis,
        ":",
        c="k",
    )
    ax_D.set_xlabel(plt_range[2].label)
    ax_D.set_ylabel(r"Dispersion parameter $D$ ($\frac{\mathrm{ps}}{\mathrm{nm\ km}}$)")

    ax_beta2.plot(x_axis, 1e27 * beta2)
    ax_beta2.plot(
        x_axis,
        0 * x_axis,
        ":",
        c="k",
    )
    ax_beta2.set_xlabel(plt_range[2].label)
    ax_beta2.set_ylabel(r"$\beta_2$ parameter ($\frac{\mathrm{ps}^2}{\mathrm{km}}$)")
    plt.show()


def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
    """returns a new colormap based on "name" but that has a solid background (default=white)"""
    top = plt.get_cmap(name, 1024)

@@ -128,11 +128,11 @@ def main():
    args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)

    submit_path = Path(
        "submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
        "submit " + final_config.name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
    )
    tmp_path = Path("submit tmp.sh")

    job_name = f"supercontinuum {final_config['name']}"
    job_name = f"supercontinuum {final_config.name}"
    submit_sh = template.format(
        job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args)
    )

@@ -1,16 +1,14 @@
import os
from collections.abc import Mapping, Sequence
from glob import glob
from typing import Any, Dict, List, Tuple
from collections.abc import Sequence
from pathlib import Path
from typing import Dict

import numpy as np

from scgenerator.const import SPECN_FN

from . import io, initialize, math
from .plotting import units
from . import initialize, io, math
from .const import SPECN_FN
from .logger import get_logger
from .plotting import units


class Spectrum(np.ndarray):
@@ -43,7 +41,7 @@ class Pulse(Sequence):

        self.params = None
        try:
            self.params = io.load_previous_parameters(self.path / "params.toml")
            self.params = io.load_params(self.path / "params.toml")
        except FileNotFoundError:
            self.logger.info(f"parameters corresponding to {self.path} not found")

@@ -4,24 +4,23 @@ scgenerator module but some function may be used in any python program

"""


import collections
import itertools
import multiprocessing
import threading
import time
from collections import abc
from copy import deepcopy
from dataclasses import asdict, replace
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar, Union
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union

import numpy as np
from tqdm import tqdm

from . import env
from .const import PARAM_SEPARATOR, valid_variable
from .math import *
from .. import env
from ..const import PARAM_SEPARATOR
from ..math import *
from .parameter import BareConfig, BareParams

T_ = TypeVar("T_")

@@ -177,18 +176,11 @@ def progress_worker(
        pbars[0].update()


def count_variations(config: dict) -> Tuple[int, int]:
def count_variations(config: BareConfig) -> Tuple[int, int]:
    """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
    variable_params_num is the number of distinct parameters that will vary."""
    sim_num = 1
    variable_params_num = 0

    for section_name in valid_variable:
        for array in config.get(section_name, {}).get("variable", {}).values():
            sim_num *= len(array)
            variable_params_num += 1

    sim_num *= config["simulation"].get("repeat", 1)
    variable_params_num = len(config.variable)
    sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
    return sim_num, variable_params_num

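As a quick illustration (not part of the commit, and assuming the new module is importable as scgenerator.utils.parameter), two variable parameters with three and two values plus repeat=2 yield 12 simulations over 2 varying parameters:

from scgenerator.utils.parameter import BareConfig

cfg = BareConfig(
    name="example",
    repeat=2,
    variable={"width": [50e-15, 100e-15, 200e-15], "peak_power": [1e3, 2e3]},
)
sim_num, variable_params_num = count_variations(cfg)
print(sim_num, variable_params_num)  # 12 2 (3 * 2 value combinations, repeated twice)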
@@ -217,49 +209,45 @@ def format_value(value):
    return str(value)


def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
    """given a config with "variable" parameters, iterates through every possible combination,
    yielding a list of (parameter_name, value) tuples and a full config dictionary.

    Parameters
    ----------
    config : dict
        initial config dictionary
    config : BareConfig
        initial config obj

    Yields
    -------
    Iterator[Tuple[List[Tuple[str, Any]], dict]]
    Iterator[Tuple[List[Tuple[str, Any]], BareParams]]
        variable_list : a list of (name, value) tuple of parameter name and value that are variable.

        dict : a config dictionary for one simulation
        params : a BareParams obj for one simulation
    """
    indiv_config = deepcopy(config)
    variable_dict = {
        section_name: indiv_config.get(section_name, {}).pop("variable", {})
        for section_name in valid_variable
    }

    possible_keys = []
    possible_ranges = []

    for section_name, section in variable_dict.items():
        for key in section:
            arr = variable_dict[section_name][key]
            possible_keys.append((section_name, key))
            possible_ranges.append(range(len(arr)))
    for key, values in config.variable.items():
        possible_keys.append(key)
        possible_ranges.append(range(len(values)))

    combinations = itertools.product(*possible_ranges)

    for combination in combinations:
        indiv_config = {}
        variable_list = []
        for i, key in enumerate(possible_keys):
            parameter_value = variable_dict[key[0]][key[1]][combination[i]]
            indiv_config[key[0]][key[1]] = parameter_value
            variable_list.append((key[1], parameter_value))
        yield variable_list, indiv_config
            parameter_value = config.variable[key][combination[i]]
            indiv_config[key] = parameter_value
            variable_list.append((key, parameter_value))
        param_dict = asdict(config)
        param_dict.pop("variable")
        param_dict.update(indiv_config)
        yield variable_list, BareParams(**param_dict)
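A sketch of the intended iteration (reusing the cfg built in the sketch above); each yield carries the variable values for that combination plus a fully populated BareParams:

for variable_list, params in variable_iterator(cfg):
    print(variable_list, params.width, params.peak_power)
# e.g. [('width', 5e-14), ('peak_power', 1000.0)] 5e-14 1000.0
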
def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
    """takes the output of `scgenerator.utils.variable_iterator`, which yields a new parameter set
    per combination, and iterates through every single necessary simulation

@@ -273,48 +261,19 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]
        dict : a config dictionary for one simulation
    """
    i = 0  # unique sim id
    for variable_only, full_config in variable_iterator(config):
        for j in range(config["simulation"]["repeat"]):
    for variable_only, bare_params in variable_iterator(config):
        for j in range(config.repeat):
            variable_ind = [("id", i)] + variable_only + [("num", j)]
            i += 1
            yield variable_ind, full_config
            yield variable_ind, bare_params


def deep_update(d: Mapping, u: Mapping) -> dict:
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = deep_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str, Any]:
def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
    """makes sure all the parameters set in new are there, leaves untouched parameters in old"""
    if old is None:
        return new
    out = deepcopy(old)
    for section_name, section in new.items():
        if isinstance(section, Mapping):
            for param_name, value in section.items():
                if param_name == "variable" and isinstance(value, Mapping):
                    out[section_name].setdefault("variable", {})
                    for p, v in value.items():
                        # override previously unvariable param
                        if p in old[section_name]:
                            del out[section_name][p]
                        out[section_name]["variable"][p] = v
                else:
                    # override previously variable param
                    if (
                        "variable" in old[section_name]
                        and isinstance(old[section_name]["variable"], Mapping)
                        and param_name in old[section_name]["variable"]
                    ):
                        del out[section_name]["variable"][param_name]
                        if len(out[section_name]["variable"]) == 0:
                            del out[section_name["variable"]]
                    out[section_name][param_name] = value
        else:
            out[section_name] = section
    return out
        return BareConfig(**new)
    variable = deepcopy(old.variable)
    variable.update(new.pop("variable", {}))  # add new variable
    for k in new:
        variable.pop(k, None)  # remove old ones
    return replace(old, variable=variable, **{k: None for k in variable}, **new)
442
src/scgenerator/utils/parameter.py
Normal file
@@ -0,0 +1,442 @@
import datetime
from copy import copy
from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import numpy as np

from ..const import __version__


@lru_cache
def type_checker(*types):
    def _type_checker_wrapper(validator, n=None):
        if isinstance(validator, str) and n is not None:
            _name = validator
            validator = lambda *args: None

        def _type_checker_wrapped(name, n):
            if not isinstance(n, types):
                raise TypeError(
                    f"{name!r} value must be of type {' or '.join(format(t) for t in types)} "
                    f"instead of {type(n)}"
                )
            validator(name, n)

        if n is None:
            return _type_checker_wrapped
        else:
            _type_checker_wrapped(_name, n)

    return _type_checker_wrapper


@type_checker(str)
def string(name, n):
    if len(n) == 0:
        raise ValueError(f"{name!r} must not be empty")


def in_range_excl(_min, _max):
    @type_checker(float, int)
    def _in_range(name, n):
        if n <= _min or n >= _max:
            raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")

    return _in_range


def in_range_incl(_min, _max):
    @type_checker(float, int)
    def _in_range(name, n):
        if n < _min or n > _max:
            raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")

    return _in_range


def boolean(name, n):
    if n is not True and n is not False:
        raise ValueError(f"{name!r} must be True or False")


@lru_cache
def non_negative(*types):
    @type_checker(*types)
    def _non_negative(name, n):
        if n < 0:
            raise ValueError(f"{name!r} must be non negative")

    return _non_negative


@lru_cache
def positive(*types):
    @type_checker(*types)
    def _positive(name, n):
        if n <= 0:
            raise ValueError(f"{name!r} must be positive")

    return _positive


@type_checker(tuple, list)
def int_pair(name, t):
    invalid = len(t) != 2
    for m in t:
        if invalid or not isinstance(m, int):
            raise ValueError(f"{name!r} must be a list or a tuple of 2 int")


def literal(*l):
    l = set(l)

    @type_checker(str)
    def _string(name, s):
        if s not in l:
            raise ValueError(f"{name!r} must be a str in {l}")

    return _string


def validator_list(validator):
    """returns a new validator that applies validator to each el of an iterable"""

    @type_checker(list, tuple)
    def _list_validator(name, l):
        for i, el in enumerate(l):
            validator(name + f"[{i}]", el)

    return _list_validator


def validator_or(*validators):
    """combines many validators and raises an exception only if all of them raise an exception"""

    n = len(validators)

    def _or_validator(name, value):
        errors = []
        for validator in validators:
            try:
                validator(name, value)
            except (ValueError, TypeError) as e:
                errors.append(e)
        errors.sort(key=lambda el: isinstance(el, ValueError))
        if len(errors) == n:
            raise errors[-1]

    return _or_validator


def validator_and(*validators):
    def _and_validator(name, n):
        for v in validators:
            v(name, n)

    return _and_validator


@type_checker(list, tuple, np.ndarray)
def num_list(name, l):
    for i, el in enumerate(l):
        type_checker(int, float)(name + f"[{i}]", el)


def func_validator(name, n):
    if not callable(n):
        raise TypeError(f"{name!r} must be callable")


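As a quick illustration (not part of the file), this is how the validators above compose; the example mirrors the pressure field defined further down:

pressure_check = validator_or(non_negative(float, int), num_list)
pressure_check("pressure", 2.5e5)            # a plain number passes
pressure_check("pressure", [1e5, 2e5, 3e5])  # a list of numbers passes too
try:
    pressure_check("pressure", "high")
except (TypeError, ValueError) as e:
    print(e)  # both branches rejected the value, so the last error is re-raised
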
class Parameter:
    def __init__(self, validator, converter=None, default=None):
        """Single parameter

        Parameters
        ----------
        validator : Callable[[str, Any], None]
            signature : validator(name, value)
            must raise a ValueError when value doesn't fit the criteria checked by
            validator. name is passed to validator to be included in the error message
        converter : Callable, optional
            converts a valid value (for example, str.lower), by default None
        default : callable, optional
            factory function for a default value (for example, list), by default None
        """

        self.validator = validator
        self.converter = converter
        self.default = default

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner):
        if not instance:
            return self
        return instance.__dict__[self.name]

    def __delete__(self, instance):
        del instance.__dict__[self.name]

    def __set__(self, instance, value):
        if isinstance(value, Parameter):
            default = None if self.default is None else copy(self.default)
            instance.__dict__[self.name] = default
        else:
            if value is not None:
                self.validator(self.name, value)
                if self.converter is not None:
                    value = self.converter(value)
            instance.__dict__[self.name] = value


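A minimal sketch (not part of the file) of how the descriptor behaves on a dataclass field; Demo is a made-up class for illustration, reusing the names defined in this module:

@dataclass
class Demo:
    wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))

print(Demo(wavelength=800e-9).wavelength)  # 8e-07, value accepted by the validator
print(Demo().wavelength)                   # None, the descriptor substitutes the default
try:
    Demo(wavelength=5.0)
except ValueError as e:
    print(e)  # 'wavelength' must be between 1e-07 and 3e-06 (inclusive)
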
class VariableParameter:
    def __init__(self, parameterBase):
        self.pbase = parameterBase

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner):
        if not instance:
            return self
        return instance.__dict__[self.name]

    def __delete__(self, instance):
        del instance.__dict__[self.name]

    def __set__(self, instance, value: dict):
        if isinstance(value, VariableParameter):
            value = {}
        else:
            for k, v in value.items():
                if k not in valid_variable:
                    raise TypeError(f"{k!r} is not a valid variable parameter")
                if len(v) == 0:
                    raise ValueError(f"variable parameter {k!r} must not be empty")

                p = getattr(self.pbase, k)

                for el in v:
                    p.validator(k, el)
        instance.__dict__[self.name] = value


valid_variable = {
    "beta",
    "gamma",
    "pitch",
    "pitch_ratio",
    "core_radius",
    "capillary_num",
    "capillary_outer_d",
    "capillary_thickness",
    "capillary_spacing",
    "capillary_resonance_strengths",
    "capillary_nested",
    "he_mode",
    "fit_parameters",
    "input_transmission",
    "n2",
    "pressure",
    "temperature",
    "gas_name",
    "plasma_density",
    "mean_power",
    "peak_power",
    "energy",
    "quantum_noise",
    "shape",
    "wavelength",
    "intensity_noise",
    "width",
    "soliton_num",
    "behaviors",
    "raman_type",
    "tolerated_error",
    "step_size",
    "ideal_gas",
    "readjust_wavelength",
}

hc_model_specific_parameters = dict(
    marcatili=["core_radius", "he_mode"],
    marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
    hasan=[
        "core_radius",
        "capillary_num",
        "capillary_thickness",
        "capillary_resonance_strengths",
        "capillary_nested",
        "capillary_spacing",
        "capillary_outer_d",
    ],
)
"""dependency map that only includes actual fiber parameters and excludes gas parameters"""


@dataclass
class BareParams:
    """
    This class defines each valid parameter's name, type and valid value but doesn't provide
    any method to act on those. For that, use initialize.Params
    """

    # root
    name: str = Parameter(string)
    prev_data_dir: str = Parameter(string)

    # fiber
    input_transmission: float = Parameter(in_range_incl(0, 1))
    gamma: float = Parameter(non_negative(float, int))
    n2: float = Parameter(non_negative(float, int))
    effective_mode_diameter: float = Parameter(positive(float, int))
    A_eff: float = Parameter(non_negative(float, int))
    pitch: float = Parameter(in_range_excl(0, 1e-3))
    pitch_ratio: float = Parameter(in_range_excl(0, 1))
    core_radius: float = Parameter(in_range_excl(0, 1e-3))
    he_mode: Tuple[int, int] = Parameter(int_pair)
    fit_parameters: Tuple[int, int] = Parameter(int_pair)
    beta: Iterable[float] = Parameter(num_list)
    dispersion_file: str = Parameter(string)
    model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
    length: float = Parameter(non_negative(float, int))
    capillary_num: int = Parameter(positive(int))
    capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
    capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
    capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
    capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
    capillary_nested: int = Parameter(non_negative(int))

    # gas
    gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower)
    pressure: Union[float, Iterable[float]] = Parameter(
        validator_or(non_negative(float, int), num_list)
    )
    temperature: float = Parameter(positive(float, int))
    plasma_density: float = Parameter(non_negative(float, int))

    # pulse
    field_file: str = Parameter(string)
    repetition_rate: float = Parameter(non_negative(float, int))
    peak_power: float = Parameter(positive(float, int))
    mean_power: float = Parameter(positive(float, int))
    energy: float = Parameter(positive(float, int))
    soliton_num: float = Parameter(positive(float, int))
    quantum_noise: bool = Parameter(boolean)
    shape: str = Parameter(literal("gaussian", "sech"))
    wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
    intensity_noise: float = Parameter(in_range_incl(0, 1))
    width: float = Parameter(in_range_excl(0, 1e-9))
    t0: float = Parameter(in_range_excl(0, 1e-9))

    # simulation
    behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
    parallel: bool = Parameter(boolean)
    raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
    ideal_gas: bool = Parameter(boolean)
    repeat: int = Parameter(positive(int))
    t_num: int = Parameter(positive(int))
    z_num: int = Parameter(positive(int))
    time_window: float = Parameter(positive(float, int))
    dt: float = Parameter(in_range_excl(0, 5e-15))
    tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-5))
    step_size: float = Parameter(positive(float, int))
    lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
    upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
    frep: float = Parameter(positive(float, int))
    prev_sim_dir: str = Parameter(string)
    readjust_wavelength: bool = Parameter(boolean)
    recovery_last_stored: int = Parameter(non_negative(int))

    # computed
    field_0: np.ndarray = Parameter(type_checker(np.ndarray))
    spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
    w: np.ndarray = Parameter(type_checker(np.ndarray))
    w_c: np.ndarray = Parameter(type_checker(np.ndarray))
    t: np.ndarray = Parameter(type_checker(np.ndarray))
    L_D: float = Parameter(non_negative(float, int))
    L_NL: float = Parameter(non_negative(float, int))
    L_sol: float = Parameter(non_negative(float, int))
    dynamic_dispersion: bool = Parameter(boolean)
    adapt_step_size: bool = Parameter(boolean)
    error_ok: float = Parameter(positive(float))
    hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
    z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
    const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
    beta_func: Callable[[float], List[float]] = Parameter(func_validator)
    gamma_func: Callable[[float], float] = Parameter(func_validator)

    def prepare_for_dump(self) -> Dict[str, Any]:
        param = asdict(self)
        param = BareParams.strip_params_dict(param)
        param["datetime"] = datetime.datetime.now()
        param["version"] = __version__
        return param

    @staticmethod
    def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
        """prepares a dictionary for serialization. Some keys may not be preserved
        (dropped because they take a lot of space and can be exactly reconstructed)

        Parameters
        ----------
        dico : dict
            dictionary
        """
        forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
        types = (np.ndarray, float, int, str, list, tuple, dict)
        out = {}
        for key, value in dico.items():
            if key in forbiden_keys:
                continue
            if not isinstance(value, types):
                continue
            if isinstance(value, dict):
                out[key] = BareParams.strip_params_dict(value)
            elif isinstance(value, np.ndarray) and value.dtype == complex:
                continue
            else:
                out[key] = value

        if "variable" in out and len(out["variable"]) == 0:
            del out["variable"]

        return out

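A short sketch (not part of the file) of what these serialization helpers do: large arrays and unset values are dropped, and a timestamp plus the package version are added:

params = BareParams(name="demo", w=np.linspace(1e15, 3e15, 1024))
dumped = params.prepare_for_dump()
print("w" in dumped)       # False, arrays listed in forbiden_keys are stripped
print(dumped["version"])   # the scgenerator __version__ string
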
@dataclass
class BareConfig(BareParams):
    variable: dict = VariableParameter(BareParams)

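For illustration only (not part of the file): a BareConfig takes flat keyword arguments plus a variable dict, and each variable value is checked against the matching BareParams descriptor:

cfg = BareConfig(name="demo", repeat=2, variable={"width": [50e-15, 100e-15]})
print(cfg.variable)  # {'width': [5e-14, 1e-13]}
try:
    BareConfig(variable={"not_a_param": [1]})
except TypeError as e:
    print(e)  # 'not_a_param' is not a valid variable parameter
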
if __name__ == "__main__":

    numero = type_checker(int)

    @numero
    def natural_number(name, n):
        if n < 0:
            raise ValueError(f"{name!r} must be positive")

    try:
        numero("a", np.arange(45))
    except Exception as e:
        print(e)
    try:
        natural_number("b", -1)
    except Exception as e:
        print(e)
    try:
        natural_number("c", 1.0)
    except Exception as e:
        print(e)
    try:
        natural_number("d", 1)
        print("success !")
    except Exception as e:
        print(e)

37
testing/test_new_params.py
Normal file
@@ -0,0 +1,37 @@
from numba.core import config
from scgenerator.initialize import Config, Params, BareParams
from scgenerator.utils import variable_iterator, override_config
from scgenerator.io import load_toml
from pprint import pprint
from dataclasses import asdict

dico = load_toml("testing/configs/ensure_consistency/good2.toml")
out = dict(variable=dict())
for k, v in dico.items():
    if isinstance(v, dict):
        for kk, vv in v.items():
            if kk == "variable":
                for kkk, vvv in vv.items():
                    out["variable"][kkk] = vvv
            else:
                out[kk] = vv

pprint(out)
p = Config(**out)
print(p)

for l, c in variable_iterator(p):
    print(l, c.width, c.intensity_noise)
    print()

config2 = override_config(dict(width=1.2e-13, variable=dict(peak_power=[1e5, 2e5])), p)
print(
    f"{config2.variable=}",
    f"{config2.intensity_noise=}",
    f"{config2.width=}",
    f"{config2.peak_power=}",
)

par = BareParams()

print(all(v is None for v in vars(par).values()))