new param system
This commit is contained in:
4
src/scgenerator/cli/__main__.py
Normal file
4
src/scgenerator/cli/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .cli import main
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -105,7 +105,7 @@ def prep_ray(args):
|
|||||||
def resume_sim(args):
|
def resume_sim(args):
|
||||||
|
|
||||||
method = prep_ray(args)
|
method = prep_ray(args)
|
||||||
sim = resume_simulations(args.sim_dir, method=method)
|
sim = resume_simulations(Path(args.sim_dir), method=method)
|
||||||
sim.run()
|
sim.run()
|
||||||
run_simulation_sequence(
|
run_simulation_sequence(
|
||||||
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
|
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
from .. import const
|
|
||||||
import toml
|
|
||||||
|
|
||||||
valid_commands = ["finish", "next"]
|
|
||||||
|
|
||||||
|
|
||||||
class Configurator:
|
|
||||||
def __init__(self, name):
|
|
||||||
self.config = dict(name=name, fiber=dict(), gas=dict(), pulse=dict(), simulation=dict())
|
|
||||||
|
|
||||||
def list_input(self):
|
|
||||||
answer = ""
|
|
||||||
while answer == "":
|
|
||||||
answer = input("Please enter a list of values (one per line)\n")
|
|
||||||
|
|
||||||
out = [self.process_input(answer)]
|
|
||||||
|
|
||||||
while answer != "":
|
|
||||||
answer = input()
|
|
||||||
out.append(self.process_input(answer))
|
|
||||||
|
|
||||||
return out[:-1]
|
|
||||||
|
|
||||||
def process_input(self, s):
|
|
||||||
try:
|
|
||||||
return int(s)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
return float(s)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return s
|
|
||||||
|
|
||||||
def accept(self, question, default=True):
|
|
||||||
question += " ([y]/n)" if default else " (y/[n])"
|
|
||||||
question += "\n"
|
|
||||||
inp = input(question)
|
|
||||||
|
|
||||||
yes_str = ["y", "yes"]
|
|
||||||
if default:
|
|
||||||
yes_str.append("")
|
|
||||||
|
|
||||||
return inp.lower() in yes_str
|
|
||||||
|
|
||||||
def print_current(self, config: dict):
|
|
||||||
print(toml.dumps(config))
|
|
||||||
|
|
||||||
def get(self, section, param_name):
|
|
||||||
question = f"Please enter a value for the parameter '{param_name}'\n"
|
|
||||||
valid = const.valid_param_types[section][param_name]
|
|
||||||
|
|
||||||
is_valid = False
|
|
||||||
value = None
|
|
||||||
|
|
||||||
while not is_valid:
|
|
||||||
answer = input(question)
|
|
||||||
if answer == "variable" and param_name in const.valid_variable[section]:
|
|
||||||
value = self.list_input()
|
|
||||||
print(value)
|
|
||||||
is_valid = all(valid(v) for v in value)
|
|
||||||
else:
|
|
||||||
value = self.process_input(answer)
|
|
||||||
is_valid = valid(value)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def ask_next_command(self):
|
|
||||||
s = ""
|
|
||||||
raw_input = input(s).split(" ")
|
|
||||||
return raw_input[0], raw_input[1:]
|
|
||||||
|
|
||||||
def main(self):
|
|
||||||
editing = True
|
|
||||||
while editing:
|
|
||||||
command, args = self.ask_next_command()
|
|
||||||
@@ -1,6 +1,3 @@
|
|||||||
import numpy as np
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
|
||||||
@@ -19,243 +16,6 @@ def pbar_format(worker_id: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#####
|
|
||||||
|
|
||||||
|
|
||||||
def in_range_excl(func, r):
|
|
||||||
def _in_range(n):
|
|
||||||
if not func(n):
|
|
||||||
return False
|
|
||||||
return n > r[0] and n < r[1]
|
|
||||||
|
|
||||||
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (exclusive) "
|
|
||||||
return _in_range
|
|
||||||
|
|
||||||
|
|
||||||
def in_range_incl(func, r):
|
|
||||||
def _in_range(n):
|
|
||||||
if not func(n):
|
|
||||||
return False
|
|
||||||
return n >= r[0] and n <= r[1]
|
|
||||||
|
|
||||||
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (inclusive)"
|
|
||||||
return _in_range
|
|
||||||
|
|
||||||
|
|
||||||
def num(n):
|
|
||||||
"""must be a single, real, non-negative number"""
|
|
||||||
return isinstance(n, (float, int)) and n >= 0
|
|
||||||
|
|
||||||
|
|
||||||
def integer(n):
|
|
||||||
"""must be a strictly positive integer"""
|
|
||||||
return isinstance(n, int) and n > 0
|
|
||||||
|
|
||||||
|
|
||||||
def boolean(b):
|
|
||||||
"""must be a boolean"""
|
|
||||||
return type(b) == bool
|
|
||||||
|
|
||||||
|
|
||||||
def behaviors(l):
|
|
||||||
"""must be a valid list of behaviors"""
|
|
||||||
for s in l:
|
|
||||||
if s.lower() not in ["spm", "raman", "ss"]:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def beta(l):
|
|
||||||
"""must be a valid beta array"""
|
|
||||||
for n in l:
|
|
||||||
if not isinstance(n, (float, int)):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def field_0(f):
|
|
||||||
return isinstance(f, (str, tuple, list, np.ndarray))
|
|
||||||
|
|
||||||
|
|
||||||
def he_mode(mode):
|
|
||||||
"""must be a valide HE mode"""
|
|
||||||
if not isinstance(mode, (list, tuple)):
|
|
||||||
return False
|
|
||||||
if not len(mode) == 2:
|
|
||||||
return False
|
|
||||||
for m in mode:
|
|
||||||
if not integer(m):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def fit_parameters(param):
|
|
||||||
"""must be a valide fitting parameter tuple of the mercatili_adjusted model"""
|
|
||||||
if not isinstance(param, (list, tuple)):
|
|
||||||
return False
|
|
||||||
if not len(param) == 2:
|
|
||||||
return False
|
|
||||||
for n in param:
|
|
||||||
if not integer(n):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def string(l=None):
|
|
||||||
if l is None:
|
|
||||||
|
|
||||||
def _string(s):
|
|
||||||
return isinstance(s, str)
|
|
||||||
|
|
||||||
_string.__doc__ = f"must be a str"
|
|
||||||
else:
|
|
||||||
|
|
||||||
def _string(s):
|
|
||||||
return isinstance(s, str) and s.lower() in l
|
|
||||||
|
|
||||||
_string.__doc__ = f"must be a str matching one of {l}"
|
|
||||||
|
|
||||||
return _string
|
|
||||||
|
|
||||||
|
|
||||||
def capillary_resonance_strengths(l):
|
|
||||||
"""must be a list of non-zero, real number"""
|
|
||||||
if not isinstance(l, (list, tuple)):
|
|
||||||
return False
|
|
||||||
for m in l:
|
|
||||||
if not num(m):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def capillary_nested(n):
|
|
||||||
"""must be a non negative integer"""
|
|
||||||
return isinstance(n, int) and n >= 0
|
|
||||||
|
|
||||||
|
|
||||||
valid_param_types = dict(
|
|
||||||
root=dict(
|
|
||||||
name=string(),
|
|
||||||
prev_data_dir=string(),
|
|
||||||
),
|
|
||||||
fiber=dict(
|
|
||||||
input_transmission=in_range_incl(num, (0, 1)),
|
|
||||||
gamma=num,
|
|
||||||
n2=num,
|
|
||||||
effective_mode_diameter=num,
|
|
||||||
A_eff=num,
|
|
||||||
pitch=in_range_excl(num, (0, 1e-3)),
|
|
||||||
pitch_ratio=in_range_excl(num, (0, 1)),
|
|
||||||
core_radius=in_range_excl(num, (0, 1e-3)),
|
|
||||||
he_mode=he_mode,
|
|
||||||
fit_parameters=fit_parameters,
|
|
||||||
beta=beta,
|
|
||||||
dispersion_file=string(),
|
|
||||||
model=string(["pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"]),
|
|
||||||
length=in_range_excl(num, (0, 1e9)),
|
|
||||||
capillary_num=integer,
|
|
||||||
capillary_outer_d=in_range_excl(num, (0, 1e-3)),
|
|
||||||
capillary_thickness=in_range_excl(num, (0, 1e-3)),
|
|
||||||
capillary_spacing=in_range_excl(num, (0, 1e-3)),
|
|
||||||
capillary_resonance_strengths=capillary_resonance_strengths,
|
|
||||||
capillary_nested=capillary_nested,
|
|
||||||
),
|
|
||||||
gas=dict(
|
|
||||||
gas_name=string(["vacuum", "helium", "air"]),
|
|
||||||
pressure=num,
|
|
||||||
temperature=num,
|
|
||||||
plasma_density=num,
|
|
||||||
),
|
|
||||||
pulse=dict(
|
|
||||||
field_0=field_0,
|
|
||||||
field_file=string(),
|
|
||||||
repetition_rate=num,
|
|
||||||
peak_power=num,
|
|
||||||
mean_power=num,
|
|
||||||
energy=num,
|
|
||||||
soliton_num=num,
|
|
||||||
quantum_noise=boolean,
|
|
||||||
shape=string(["gaussian", "sech"]),
|
|
||||||
wavelength=in_range_excl(num, (100e-9, 3000e-9)),
|
|
||||||
intensity_noise=in_range_incl(num, (0, 1)),
|
|
||||||
width=in_range_excl(num, (0, 1e-9)),
|
|
||||||
t0=in_range_excl(num, (0, 1e-9)),
|
|
||||||
),
|
|
||||||
simulation=dict(
|
|
||||||
behaviors=behaviors,
|
|
||||||
parallel=boolean,
|
|
||||||
raman_type=string(["measured", "agrawal", "stolen"]),
|
|
||||||
ideal_gas=boolean,
|
|
||||||
repeat=integer,
|
|
||||||
t_num=integer,
|
|
||||||
z_num=integer,
|
|
||||||
time_window=num,
|
|
||||||
dt=in_range_excl(num, (0, 5e-15)),
|
|
||||||
tolerated_error=in_range_excl(num, (1e-15, 1e-5)),
|
|
||||||
step_size=num,
|
|
||||||
lower_wavelength_interp_limit=in_range_excl(num, (100e-9, 3000e-9)),
|
|
||||||
upper_wavelength_interp_limit=in_range_excl(num, (100e-9, 5000e-9)),
|
|
||||||
frep=num,
|
|
||||||
prev_sim_dir=string(),
|
|
||||||
readjust_wavelength=boolean,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
hc_model_specific_parameters = dict(
|
|
||||||
marcatili=["core_radius", "he_mode"],
|
|
||||||
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
|
|
||||||
hasan=[
|
|
||||||
"core_radius",
|
|
||||||
"capillary_num",
|
|
||||||
"capillary_thickness",
|
|
||||||
"capillary_resonance_strengths",
|
|
||||||
"capillary_nested",
|
|
||||||
"capillary_spacing",
|
|
||||||
"capillary_outer_d",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
|
||||||
|
|
||||||
valid_variable = dict(
|
|
||||||
fiber=[
|
|
||||||
"beta",
|
|
||||||
"gamma",
|
|
||||||
"pitch",
|
|
||||||
"pitch_ratio",
|
|
||||||
"core_radius",
|
|
||||||
"capillary_num",
|
|
||||||
"capillary_outer_d",
|
|
||||||
"capillary_thickness",
|
|
||||||
"capillary_spacing",
|
|
||||||
"capillary_resonance_strengths",
|
|
||||||
"capillary_nested",
|
|
||||||
"he_mode",
|
|
||||||
"fit_parameters",
|
|
||||||
"input_transmission",
|
|
||||||
"n2",
|
|
||||||
],
|
|
||||||
gas=["pressure", "temperature", "gas_name", "plasma_density"],
|
|
||||||
pulse=[
|
|
||||||
"peak_power",
|
|
||||||
"mean_power",
|
|
||||||
"energy",
|
|
||||||
"quantum_noise",
|
|
||||||
"shape",
|
|
||||||
"wavelength",
|
|
||||||
"intensity_noise",
|
|
||||||
"width",
|
|
||||||
"soliton_num",
|
|
||||||
],
|
|
||||||
simulation=[
|
|
||||||
"behaviors",
|
|
||||||
"raman_type",
|
|
||||||
"tolerated_error",
|
|
||||||
"step_size",
|
|
||||||
"ideal_gas",
|
|
||||||
"readjust_wavelength",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
ENVIRON_KEY_BASE = "SCGENERATOR_"
|
ENVIRON_KEY_BASE = "SCGENERATOR_"
|
||||||
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
|
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
|
||||||
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
|
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
from .errors import MissingParameterError
|
|
||||||
|
|
||||||
default_parameters = dict(
|
default_parameters = dict(
|
||||||
input_transmission=1.0,
|
input_transmission=1.0,
|
||||||
@@ -28,6 +27,7 @@ default_parameters = dict(
|
|||||||
upper_wavelength_interp_limit=1900e-9,
|
upper_wavelength_interp_limit=1900e-9,
|
||||||
ideal_gas=False,
|
ideal_gas=False,
|
||||||
readjust_wavelength=False,
|
readjust_wavelength=False,
|
||||||
|
recovery_last_stored=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
default_plotting = dict(
|
default_plotting = dict(
|
||||||
@@ -36,7 +36,7 @@ default_plotting = dict(
|
|||||||
vmin=-40,
|
vmin=-40,
|
||||||
vmax=0,
|
vmax=0,
|
||||||
vmax_with_headroom=2,
|
vmax_with_headroom=2,
|
||||||
name="plot",
|
out_path=Path("plot"),
|
||||||
avg_main_to_coherence_ratio=4,
|
avg_main_to_coherence_ratio=4,
|
||||||
avg_line_labels=["individual values", "mean"],
|
avg_line_labels=["individual values", "mean"],
|
||||||
muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)),
|
muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)),
|
||||||
@@ -57,76 +57,3 @@ default_plotting = dict(
|
|||||||
text_topright_style=dict(verticalalignment="top", horizontalalignment="right"),
|
text_topright_style=dict(verticalalignment="top", horizontalalignment="right"),
|
||||||
text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"),
|
text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get(section_dict, param, **kwargs):
|
|
||||||
"""checks if param is in the parameter section dict and attempts to fill in a default value
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
section_dict : dict
|
|
||||||
the parameters section {fiber, pulse, simulation, root} sub-dictionary
|
|
||||||
param : str
|
|
||||||
the name of the parameter (dict key)
|
|
||||||
kwargs : any
|
|
||||||
key word arguments passed to the MissingParameterError constructor
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
the updated section_dict dictionary
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
MissingFiberParameterError
|
|
||||||
raised when a parameter is missing and no default exists
|
|
||||||
"""
|
|
||||||
|
|
||||||
# whether the parameter is in the right place and valid is checked elsewhere,
|
|
||||||
# here, we just make sure it is present.
|
|
||||||
if param not in section_dict and param not in section_dict.get("variable", {}):
|
|
||||||
try:
|
|
||||||
section_dict[param] = default_parameters[param]
|
|
||||||
# LOG
|
|
||||||
except KeyError:
|
|
||||||
raise MissingParameterError(param, **kwargs)
|
|
||||||
return section_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_fiber(section_dict, param, **kwargs):
|
|
||||||
"""wrapper for fiber parameters that depend on fiber model"""
|
|
||||||
return get(section_dict, param, fiber_model=section_dict["model"], **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_multiple(section_dict, params, num, **kwargs):
|
|
||||||
"""similar to th get method but works with several parameters
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
section_dict : dict
|
|
||||||
the parameters section {fiber, pulse, simulation, root}, sub-dictionary
|
|
||||||
params : list of str
|
|
||||||
names of the required parameters
|
|
||||||
num : int
|
|
||||||
how many of the parameters in params are required
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict
|
|
||||||
the updated section_dict
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
MissingParameterError
|
|
||||||
raised when not enough parameters are provided and no defaults exist
|
|
||||||
"""
|
|
||||||
gotten = 0
|
|
||||||
for param in params:
|
|
||||||
try:
|
|
||||||
section_dict = get(section_dict, param, **kwargs)
|
|
||||||
gotten += 1
|
|
||||||
except MissingParameterError:
|
|
||||||
pass
|
|
||||||
if gotten >= num:
|
|
||||||
return section_dict
|
|
||||||
raise MissingParameterError(params, num_required=num, **kwargs)
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Literal, Optional, Set
|
from typing import Dict, Literal, Optional, Set
|
||||||
|
|
||||||
from .const import ENVIRON_KEY_BASE, PBAR_POLICY, LOG_POLICY, TMP_FOLDER_KEY_BASE
|
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE
|
||||||
|
|
||||||
|
|
||||||
def data_folder(task_id: int) -> Optional[Path]:
|
def data_folder(task_id: int) -> Optional[str]:
|
||||||
idstr = str(int(task_id))
|
idstr = str(int(task_id))
|
||||||
tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
|
tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
|
||||||
return tmp
|
return tmp
|
||||||
|
|||||||
@@ -34,18 +34,3 @@ class DuplicateParameterError(Exception):
|
|||||||
|
|
||||||
class IncompleteDataFolderError(FileNotFoundError):
|
class IncompleteDataFolderError(FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# class MissingFiberParameterError(MissingParameterError):
|
|
||||||
# def __init__(self, param, model):
|
|
||||||
# self.param = param
|
|
||||||
# self.model = model
|
|
||||||
# super().__init__(
|
|
||||||
# f"'{self.param}' is a required parameter for fiber model '{self.model}' and no default value is set"
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
# class MissingPulseParameterError(MissingParameterError):
|
|
||||||
# def __init__(self, param):
|
|
||||||
# self.param = param
|
|
||||||
# super().__init__(f"'{self.param}' is a required pulse parameter and no default value is set")
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import asdict
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@@ -11,18 +12,17 @@ import toml
|
|||||||
|
|
||||||
from . import env, utils
|
from . import env, utils
|
||||||
from .const import (
|
from .const import (
|
||||||
__version__,
|
|
||||||
ENVIRON_KEY_BASE,
|
|
||||||
PARAM_FN,
|
PARAM_FN,
|
||||||
PARAM_SEPARATOR,
|
PARAM_SEPARATOR,
|
||||||
PBAR_POLICY,
|
|
||||||
SPEC1_FN,
|
SPEC1_FN,
|
||||||
SPECN_FN,
|
SPECN_FN,
|
||||||
TMP_FOLDER_KEY_BASE,
|
TMP_FOLDER_KEY_BASE,
|
||||||
Z_FN,
|
Z_FN,
|
||||||
|
__version__,
|
||||||
)
|
)
|
||||||
from .errors import IncompleteDataFolderError
|
from .errors import IncompleteDataFolderError
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .utils.parameter import BareConfig, BareParams
|
||||||
|
|
||||||
PathTree = List[Tuple[Path, ...]]
|
PathTree = List[Tuple[Path, ...]]
|
||||||
|
|
||||||
@@ -88,6 +88,10 @@ def load_toml(path: os.PathLike):
|
|||||||
path = conform_toml_path(path)
|
path = conform_toml_path(path)
|
||||||
with open(path, mode="r") as file:
|
with open(path, mode="r") as file:
|
||||||
dico = toml.load(file)
|
dico = toml.load(file)
|
||||||
|
|
||||||
|
for section in ["simulation", "fiber", "pulse", "gas"]:
|
||||||
|
dico.update(dico.pop(section, {}))
|
||||||
|
|
||||||
return dico
|
return dico
|
||||||
|
|
||||||
|
|
||||||
@@ -99,52 +103,15 @@ def save_toml(path: os.PathLike, dico):
|
|||||||
return dico
|
return dico
|
||||||
|
|
||||||
|
|
||||||
def serializable(val):
|
def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path:
|
||||||
"""returns True if val is serializable into a Json file"""
|
|
||||||
types = (np.ndarray, float, int, str, list, tuple)
|
|
||||||
|
|
||||||
out = isinstance(val, types)
|
|
||||||
if isinstance(val, np.ndarray):
|
|
||||||
out &= val.dtype != "complex"
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
|
||||||
(dropped due to no conversion available)
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
dico : dict
|
|
||||||
dictionary
|
|
||||||
"""
|
|
||||||
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
|
|
||||||
types = (np.ndarray, float, int, str, list, tuple, dict)
|
|
||||||
out = {}
|
|
||||||
for key, value in dico.items():
|
|
||||||
if key in forbiden_keys:
|
|
||||||
continue
|
|
||||||
if not isinstance(value, types):
|
|
||||||
continue
|
|
||||||
if isinstance(value, dict):
|
|
||||||
out[key] = prepare_for_serialization(value)
|
|
||||||
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
out[key] = value
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
|
||||||
"""saves a parameter dictionary. Note that is does remove some entries, particularly
|
"""saves a parameter dictionary. Note that is does remove some entries, particularly
|
||||||
those that take a lot of space ("t", "w", ...)
|
those that take a lot of space ("t", "w", ...)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
param_dict : Dict[str, Any]
|
params : Dict[str, Any]
|
||||||
dictionary to save
|
dictionary to save
|
||||||
data_dir : Path
|
destination_dir : Path
|
||||||
destination directory
|
destination directory
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -152,12 +119,8 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
|||||||
Path
|
Path
|
||||||
path to newly created the paramter file
|
path to newly created the paramter file
|
||||||
"""
|
"""
|
||||||
param = param_dict.copy()
|
param = params.prepare_for_dump()
|
||||||
file_path = destination_dir / "params.toml"
|
file_path = destination_dir / file_name
|
||||||
|
|
||||||
param = prepare_for_serialization(param)
|
|
||||||
param["datetime"] = datetime.now()
|
|
||||||
param["version"] = __version__
|
|
||||||
|
|
||||||
file_path.parent.mkdir(exist_ok=True)
|
file_path.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
@@ -168,7 +131,7 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
|
|||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
|
|
||||||
def load_previous_parameters(path: os.PathLike):
|
def load_params(path: os.PathLike) -> BareParams:
|
||||||
"""loads a parameters toml files and converts data to appropriate type
|
"""loads a parameters toml files and converts data to appropriate type
|
||||||
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||||
|
|
||||||
@@ -179,15 +142,29 @@ def load_previous_parameters(path: os.PathLike):
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
dict
|
BareParams
|
||||||
flattened parameters dictionary
|
params obj
|
||||||
"""
|
"""
|
||||||
params = load_toml(path)
|
params = load_toml(path)
|
||||||
|
return BareParams(**params)
|
||||||
|
|
||||||
for k, v in params.items():
|
|
||||||
if isinstance(v, list) and isinstance(v[0], (float, int)):
|
def load_config(path: os.PathLike) -> BareConfig:
|
||||||
params[k] = np.array(v)
|
"""loads a parameters toml files and converts data to appropriate type
|
||||||
return params
|
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
path to the toml
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
BareParams
|
||||||
|
config obj
|
||||||
|
"""
|
||||||
|
config = load_toml(path)
|
||||||
|
return BareConfig(**config)
|
||||||
|
|
||||||
|
|
||||||
def load_material_dico(name):
|
def load_material_dico(name):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
from .env import log_policy
|
|
||||||
|
|
||||||
|
from .env import log_policy
|
||||||
|
|
||||||
# class DebugOnlyFileHandler(logging.FileHandler):
|
# class DebugOnlyFileHandler(logging.FileHandler):
|
||||||
# def __init__(
|
# def __init__(
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from typing import Type, Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy.interpolate import griddata, interp1d
|
||||||
from scipy.special import jn_zeros
|
from scipy.special import jn_zeros
|
||||||
from scipy.interpolate import interp1d, griddata
|
|
||||||
from numba import jit
|
|
||||||
|
|
||||||
|
|
||||||
def span(*vec):
|
def span(*vec):
|
||||||
@@ -54,7 +54,6 @@ def power_fact(x, n):
|
|||||||
raise TypeError(f"type {type(x)} of x not supported.")
|
raise TypeError(f"type {type(x)} of x not supported.")
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def _power_fact_single(x, n):
|
def _power_fact_single(x, n):
|
||||||
result = 1.0
|
result = 1.0
|
||||||
for k in range(n):
|
for k in range(n):
|
||||||
@@ -62,7 +61,6 @@ def _power_fact_single(x, n):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def _power_fact_array(x, n):
|
def _power_fact_array(x, n):
|
||||||
result = np.ones(len(x), dtype=np.float64)
|
result = np.ones(len(x), dtype=np.float64)
|
||||||
for k in range(n):
|
for k in range(n):
|
||||||
@@ -70,7 +68,6 @@ def _power_fact_array(x, n):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def abs2(z: np.ndarray) -> np.ndarray:
|
def abs2(z: np.ndarray) -> np.ndarray:
|
||||||
return z.real ** 2 + z.imag ** 2
|
return z.real ** 2 + z.imag ** 2
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
class Parameter:
|
|
||||||
"""base class for parameters"""
|
|
||||||
|
|
||||||
all = dict(fiber=dict(), pulse=dict(), gas=dict(), simulation=dict())
|
|
||||||
help_message = "no help message lol"
|
|
||||||
|
|
||||||
def __init_subclass__(cls, section):
|
|
||||||
Parameter.all[section][cls.__name__.lower()] = cls
|
|
||||||
|
|
||||||
def __init__(self, s):
|
|
||||||
self.s = s
|
|
||||||
valid = True
|
|
||||||
try:
|
|
||||||
self.value = self._convert()
|
|
||||||
valid = self.valid()
|
|
||||||
except ValueError:
|
|
||||||
valid = False
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
raise ValueError(
|
|
||||||
f"{self.__class__.__name__} {self.__class__.help_message}. input : {self.s}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _convert(self):
|
|
||||||
value = self.conversion_func(self.s)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class Wavelength(Parameter, section="pulse"):
|
|
||||||
help_message = "must be a strictly positive real number"
|
|
||||||
|
|
||||||
def valid(self):
|
|
||||||
return self.value > 0
|
|
||||||
|
|
||||||
def conversion_func(self, s: str) -> float:
|
|
||||||
return float(s)
|
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.lib import disp
|
|
||||||
from numpy.lib.arraysetops import isin
|
|
||||||
import toml
|
import toml
|
||||||
from numba import jit
|
|
||||||
from numpy.fft import fft, ifft
|
from numpy.fft import fft, ifft
|
||||||
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
from .. import io
|
from .. import io
|
||||||
from ..const import hc_model_specific_parameters
|
|
||||||
from ..math import abs2, argclosest, power_fact, u_nm
|
from ..math import abs2, argclosest, power_fact, u_nm
|
||||||
|
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||||
from . import materials as mat
|
from . import materials as mat
|
||||||
from . import units
|
from . import units
|
||||||
from .units import c, pi
|
from .units import c, pi
|
||||||
@@ -25,7 +24,7 @@ def lambda_for_dispersion():
|
|||||||
return np.linspace(190e-9, 3000e-9, 4000)
|
return np.linspace(190e-9, 3000e-9, 4000)
|
||||||
|
|
||||||
|
|
||||||
def is_dynamic_dispersion(params):
|
def is_dynamic_dispersion(pressure=None):
|
||||||
"""tests if the parameter dictionary implies that the dispersion profile of the fiber changes with z
|
"""tests if the parameter dictionary implies that the dispersion profile of the fiber changes with z
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -38,8 +37,8 @@ def is_dynamic_dispersion(params):
|
|||||||
bool : True if dispersion is supposed to change with z
|
bool : True if dispersion is supposed to change with z
|
||||||
"""
|
"""
|
||||||
out = False
|
out = False
|
||||||
if "pressure" in params:
|
if pressure is not None:
|
||||||
out |= isinstance(params["pressure"], (tuple, list)) and len(params["pressure"]) == 2
|
out |= isinstance(pressure, (tuple, list)) and len(pressure) == 2
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@@ -483,7 +482,19 @@ def HCPCF_dispersion(
|
|||||||
return beta2(w, n_eff)
|
return beta2(w, n_eff)
|
||||||
|
|
||||||
|
|
||||||
def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
|
def dynamic_HCPCF_dispersion(
|
||||||
|
lambda_: np.ndarray,
|
||||||
|
pressure_values: List[float],
|
||||||
|
core_radius: float,
|
||||||
|
fiber_model: str,
|
||||||
|
model_params: Dict[str, Any],
|
||||||
|
temperature: float,
|
||||||
|
ideal_gas: bool,
|
||||||
|
w0: float,
|
||||||
|
interp_range: Tuple[float, float],
|
||||||
|
material_dico: Dict[str, Any],
|
||||||
|
deg,
|
||||||
|
):
|
||||||
"""returns functions for beta2 coefficients and gamma instead of static values
|
"""returns functions for beta2 coefficients and gamma instead of static values
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -504,25 +515,22 @@ def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
|
|||||||
in the fiber
|
in the fiber
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# store values because storing functions acts weird with dict
|
A_eff = 1.5 * core_radius ** 2
|
||||||
pressure_values = params["pressure"]
|
|
||||||
a = params["core_radius"]
|
|
||||||
fiber_model = params["fiber_model"]
|
|
||||||
model_params = {k: params[k] for k in hc_model_specific_parameters[fiber_model]}
|
|
||||||
temp = params["temperature"]
|
|
||||||
ideal_gas = params["ideal_gas"]
|
|
||||||
w0 = params["w0"]
|
|
||||||
interp_range = params["interp_range"]
|
|
||||||
|
|
||||||
A_eff = 1.5 * a ** 2
|
|
||||||
|
|
||||||
# defining function instead of storing every possilble value
|
# defining function instead of storing every possilble value
|
||||||
pressure = lambda r: mat.pressure_from_gradient(r, *pressure_values)
|
pressure = lambda r: mat.pressure_from_gradient(r, *pressure_values)
|
||||||
beta2 = lambda r: HCPCF_dispersion(
|
beta2 = lambda r: HCPCF_dispersion(
|
||||||
lambda_, a, material_dico, fiber_model, model_params, pressure(r), temp, ideal_gas
|
lambda_,
|
||||||
|
core_radius,
|
||||||
|
material_dico,
|
||||||
|
fiber_model,
|
||||||
|
model_params,
|
||||||
|
pressure(r),
|
||||||
|
temperature,
|
||||||
|
ideal_gas,
|
||||||
)
|
)
|
||||||
|
|
||||||
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temp)
|
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temperature)
|
||||||
ratio_range = np.linspace(0, 1, 256)
|
ratio_range = np.linspace(0, 1, 256)
|
||||||
|
|
||||||
gamma_grid = np.array([gamma_parameter(n2(r), w0, A_eff) for r in ratio_range])
|
gamma_grid = np.array([gamma_parameter(n2(r), w0, A_eff) for r in ratio_range])
|
||||||
@@ -640,7 +648,7 @@ def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
|
|||||||
return beta2, gamma
|
return beta2, gamma
|
||||||
|
|
||||||
|
|
||||||
def dispersion_central(fiber_model, params, deg=8):
|
def compute_dispersion(params: BareParams, deg=8):
|
||||||
"""dispatch function depending on what type of fiber is used
|
"""dispatch function depending on what type of fiber is used
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -660,8 +668,8 @@ def dispersion_central(fiber_model, params, deg=8):
|
|||||||
nonlinear parameter
|
nonlinear parameter
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "dispersion_file" in params:
|
if params.dispersion_file is not None:
|
||||||
disp_file = np.load(params["dispersion_file"])
|
disp_file = np.load(params.dispersion_file)
|
||||||
lambda_ = disp_file["wavelength"]
|
lambda_ = disp_file["wavelength"]
|
||||||
D = disp_file["dispersion"]
|
D = disp_file["dispersion"]
|
||||||
beta2 = D_to_beta2(D, lambda_)
|
beta2 = D_to_beta2(D, lambda_)
|
||||||
@@ -669,21 +677,20 @@ def dispersion_central(fiber_model, params, deg=8):
|
|||||||
else:
|
else:
|
||||||
lambda_ = lambda_for_dispersion()
|
lambda_ = lambda_for_dispersion()
|
||||||
beta2 = np.zeros_like(lambda_)
|
beta2 = np.zeros_like(lambda_)
|
||||||
fiber_model = fiber_model.lower()
|
|
||||||
|
|
||||||
if fiber_model == "pcf":
|
if params.model == "pcf":
|
||||||
beta2, gamma = PCF_dispersion(
|
beta2, gamma = PCF_dispersion(
|
||||||
lambda_,
|
lambda_,
|
||||||
params["pitch"],
|
params.pitch,
|
||||||
params["pitch_ratio"],
|
params.pitch_ratio,
|
||||||
w0=params["w0"],
|
w0=params.w0,
|
||||||
n2=params.get("n2"),
|
n2=params.n2,
|
||||||
A_eff=params.get("A_eff"),
|
A_eff=params.A_eff,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Load material info
|
# Load material info
|
||||||
gas_name = params["gas_name"]
|
gas_name = params.gas_name
|
||||||
|
|
||||||
if gas_name == "vacuum":
|
if gas_name == "vacuum":
|
||||||
material_dico = None
|
material_dico = None
|
||||||
@@ -691,8 +698,20 @@ def dispersion_central(fiber_model, params, deg=8):
|
|||||||
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
|
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
|
||||||
|
|
||||||
# compute dispersion
|
# compute dispersion
|
||||||
if params.get("dynamic_dispersion", False):
|
if params.dynamic_dispersion:
|
||||||
return dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg)
|
return dynamic_HCPCF_dispersion(
|
||||||
|
lambda_,
|
||||||
|
params.pressure,
|
||||||
|
params.core_radius,
|
||||||
|
params.model,
|
||||||
|
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
|
||||||
|
params.temperature,
|
||||||
|
params.ideal_gas,
|
||||||
|
params.w0,
|
||||||
|
params.interp_range,
|
||||||
|
material_dico,
|
||||||
|
deg,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# actually compute the dispersion
|
# actually compute the dispersion
|
||||||
@@ -700,31 +719,31 @@ def dispersion_central(fiber_model, params, deg=8):
|
|||||||
beta2 = HCPCF_dispersion(
|
beta2 = HCPCF_dispersion(
|
||||||
lambda_,
|
lambda_,
|
||||||
material_dico,
|
material_dico,
|
||||||
fiber_model,
|
params.model,
|
||||||
{k: params[k] for k in hc_model_specific_parameters[fiber_model]},
|
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
|
||||||
params["pressure"],
|
params.pressure,
|
||||||
params["temperature"],
|
params.temperature,
|
||||||
params["ideal_gas"],
|
params.ideal_gas,
|
||||||
)
|
)
|
||||||
|
|
||||||
if material_dico is not None:
|
if material_dico is not None:
|
||||||
A_eff = 1.5 * params["core_radius"] ** 2
|
A_eff = 1.5 * params.core_radius ** 2
|
||||||
n2 = mat.non_linear_refractive_index(
|
n2 = mat.non_linear_refractive_index(
|
||||||
material_dico, params["pressure"], params["temperature"]
|
material_dico, params.pressure, params.temperature
|
||||||
)
|
)
|
||||||
gamma = gamma_parameter(n2, params["w0"], A_eff)
|
gamma = gamma_parameter(n2, params.w0, A_eff)
|
||||||
else:
|
else:
|
||||||
gamma = None
|
gamma = None
|
||||||
|
|
||||||
# add plasma if wanted
|
# add plasma if wanted
|
||||||
if params["plasma_density"] > 0:
|
if params.plasma_density > 0:
|
||||||
beta2 += plasma_dispersion(lambda_, params["plasma_density"])
|
beta2 += plasma_dispersion(lambda_, params.plasma_density)
|
||||||
|
|
||||||
beta2_coef = dispersion_coefficients(lambda_, beta2, params["w0"], params["interp_range"], deg)
|
beta2_coef = dispersion_coefficients(lambda_, beta2, params.w0, params.interp_range, deg)
|
||||||
|
|
||||||
if gamma is None:
|
if gamma is None:
|
||||||
if "A_eff" in params:
|
if params.A_eff is not None:
|
||||||
gamma = gamma_parameter(params.get("n2", 2.6e-20), params["w0"], params["A_eff"])
|
gamma = gamma_parameter(params.n2, params.w0, params.A_eff)
|
||||||
else:
|
else:
|
||||||
gamma = 0
|
gamma = 0
|
||||||
|
|
||||||
@@ -855,15 +874,13 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Compute raman response function if necessary
|
# Compute raman response function if necessary
|
||||||
|
f_r = 0.18
|
||||||
if "raman" in behaviors:
|
if "raman" in behaviors:
|
||||||
if "hr_w" == None:
|
if hr_w is None:
|
||||||
raise TypeError("freq-dependent Raman response must be give")
|
raise ValueError("freq-dependent Raman response must be give")
|
||||||
else:
|
if f_r is None:
|
||||||
if f_r is None:
|
if raman_type == "agrawal":
|
||||||
if raman_type in ["stolen", "measured"]:
|
f_r = 0.245
|
||||||
f_r = 0.18
|
|
||||||
elif raman_type == "agrawal":
|
|
||||||
f_r = 0.245
|
|
||||||
|
|
||||||
if "spm" in behaviors:
|
if "spm" in behaviors:
|
||||||
spm_part = lambda fi: (1 - f_r) * abs2(fi)
|
spm_part = lambda fi: (1 - f_r) * abs2(fi)
|
||||||
@@ -875,7 +892,6 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
|
|||||||
else:
|
else:
|
||||||
raman_part = lambda fi: 0
|
raman_part = lambda fi: 0
|
||||||
|
|
||||||
spm_part = jit(spm_part, nopython=True)
|
|
||||||
ss_part = w_c / w0 if "ss" in behaviors else 0
|
ss_part = w_c / w0 if "ss" in behaviors else 0
|
||||||
|
|
||||||
if isinstance(gamma, (float, int)):
|
if isinstance(gamma, (float, int)):
|
||||||
@@ -924,7 +940,6 @@ def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
|
|||||||
return -1j * out
|
return -1j * out
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr):
|
def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr):
|
||||||
for k in range(len(beta_arr) - 1, -1, -1):
|
for k in range(len(beta_arr) - 1, -1, -1):
|
||||||
dispersion = dispersion + beta_arr[k] * power_fact_arr[k]
|
dispersion = dispersion + beta_arr[k] * power_fact_arr[k]
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
|
|
||||||
from . import units
|
from . import units
|
||||||
from .units import NA, c, kB
|
from .units import NA, c, kB
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Literal, Tuple
|
from typing import Literal, Tuple
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -18,13 +19,13 @@ import numpy as np
|
|||||||
from numpy import pi
|
from numpy import pi
|
||||||
from numpy.fft import fft, fftshift, ifft
|
from numpy.fft import fft, fftshift, ifft
|
||||||
from scipy.interpolate import UnivariateSpline
|
from scipy.interpolate import UnivariateSpline
|
||||||
from numba import jit
|
|
||||||
|
|
||||||
|
from .. import io
|
||||||
from ..defaults import default_plotting
|
from ..defaults import default_plotting
|
||||||
|
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..plotting import plot_setup
|
|
||||||
from ..math import *
|
from ..math import *
|
||||||
|
from ..plotting import plot_setup
|
||||||
|
from ..utils.parameter import BareParams
|
||||||
|
|
||||||
c = 299792458.0
|
c = 299792458.0
|
||||||
hbar = 1.05457148e-34
|
hbar = 1.05457148e-34
|
||||||
@@ -205,6 +206,48 @@ def conform_pulse_params(
|
|||||||
return width, t0, peak_power, energy, soliton_num
|
return width, t0, peak_power, energy, soliton_num
|
||||||
|
|
||||||
|
|
||||||
|
def setup_custom_field(params: BareParams) -> bool:
|
||||||
|
"""sets up a custom field function if necessary and returns
|
||||||
|
True if it did so, False otherwise
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
params : Dict[str, Any]
|
||||||
|
params dictionary
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the field has been modified
|
||||||
|
"""
|
||||||
|
field_0 = width = peak_power = energy = None
|
||||||
|
|
||||||
|
did_set = True
|
||||||
|
|
||||||
|
if params.prev_data_dir is not None:
|
||||||
|
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
|
||||||
|
field_0 = np.fft.ifft(spec) * np.sqrt(params.input_transmission)
|
||||||
|
elif params.field_file is not None:
|
||||||
|
field_data = np.load(params.field_file)
|
||||||
|
field_interp = interp1d(
|
||||||
|
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
||||||
|
)
|
||||||
|
field_0 = field_interp(params.t)
|
||||||
|
|
||||||
|
field_0 = field_0 * modify_field_ratio(
|
||||||
|
params.t,
|
||||||
|
field_0,
|
||||||
|
params.peak_power,
|
||||||
|
params.energy,
|
||||||
|
params.intensity_noise,
|
||||||
|
)
|
||||||
|
width, peak_power, energy = measure_field(params.t, field_0)
|
||||||
|
else:
|
||||||
|
did_set = False
|
||||||
|
|
||||||
|
return did_set, width, peak_power, energy, field_0
|
||||||
|
|
||||||
|
|
||||||
def E0_to_P0(E0, t0, shape="gaussian"):
|
def E0_to_P0(E0, t0, shape="gaussian"):
|
||||||
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
"""convert an initial total pulse energy to a pulse peak peak_power"""
|
||||||
return E0 / (t0 * P0T0_to_E0_fac[shape])
|
return E0 / (t0 * P0T0_to_E0_fac[shape])
|
||||||
@@ -223,12 +266,10 @@ def gauss_pulse(t, t0, P0, offset=0):
|
|||||||
return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2))
|
return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2))
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def photon_number(spectrum, w, dw, gamma):
|
def photon_number(spectrum, w, dw, gamma):
|
||||||
return np.sum(1 / gamma * abs2(spectrum) / w * dw)
|
return np.sum(1 / gamma * abs2(spectrum) / w * dw)
|
||||||
|
|
||||||
|
|
||||||
@jit(nopython=True)
|
|
||||||
def pulse_energy(spectrum, w, dw, _):
|
def pulse_energy(spectrum, w, dw, _):
|
||||||
return np.sum(abs2(spectrum) * dw)
|
return np.sum(abs2(spectrum) * dw)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Tuple, Type
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -18,7 +19,14 @@ except ModuleNotFoundError:
|
|||||||
|
|
||||||
|
|
||||||
class RK4IP:
|
class RK4IP:
|
||||||
def __init__(self, sim_params, save_data=False, job_identifier="", task_id=0, n_percent=10):
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: initialize.Params,
|
||||||
|
save_data=False,
|
||||||
|
job_identifier="",
|
||||||
|
task_id=0,
|
||||||
|
n_percent=10,
|
||||||
|
):
|
||||||
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -76,31 +84,29 @@ class RK4IP:
|
|||||||
self.logger = get_logger(self.job_identifier)
|
self.logger = get_logger(self.job_identifier)
|
||||||
self.resuming = False
|
self.resuming = False
|
||||||
self.save_data = save_data
|
self.save_data = save_data
|
||||||
self._extract_params(sim_params)
|
|
||||||
self._setup_functions()
|
|
||||||
self.starting_num = sim_params.get("recovery_last_stored", 0)
|
|
||||||
self._setup_sim_parameters()
|
|
||||||
|
|
||||||
def _extract_params(self, params):
|
self.w_c = params.w_c
|
||||||
self.w_c = params.pop("w_c")
|
self.w0 = params.w0
|
||||||
self.w0 = params.pop("w0")
|
self.w_power_fact = params.w_power_fact
|
||||||
self.w_power_fact = params.pop("w_power_fact")
|
self.spec_0 = params.spec_0
|
||||||
self.spec_0 = params.pop("spec_0")
|
self.z_targets = params.z_targets
|
||||||
self.z_targets = params.pop("z_targets")
|
self.z_final = params.length
|
||||||
self.z_final = params.pop("length")
|
self.beta = params.beta_func if params.beta_func is not None else params.beta
|
||||||
self.beta = params.pop("beta_func", params.pop("beta"))
|
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma
|
||||||
self.gamma = params.pop("gamma_func", params.pop("gamma"))
|
self.behaviors = params.behaviors
|
||||||
self.behaviors = params.pop("behaviors")
|
self.raman_type = params.raman_type
|
||||||
self.raman_type = params.pop("raman_type", "stolen")
|
self.hr_w = params.hr_w
|
||||||
self.f_r = params.pop("f_r", 0)
|
self.adapt_step_size = params.adapt_step_size
|
||||||
self.hr_w = params.pop("hr_w", None)
|
self.error_ok = params.error_ok
|
||||||
self.adapt_step_size = params.pop("adapt_step_size", True)
|
self.dynamic_dispersion = params.dynamic_dispersion
|
||||||
self.error_ok = params.pop("error_ok")
|
self.starting_num = params.recovery_last_stored
|
||||||
self.dynamic_dispersion = params.pop("dynamic_dispersion", False)
|
|
||||||
|
self._setup_functions()
|
||||||
|
self._setup_sim_parameters()
|
||||||
|
|
||||||
def _setup_functions(self):
|
def _setup_functions(self):
|
||||||
self.N_func = create_non_linear_op(
|
self.N_func = create_non_linear_op(
|
||||||
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, self.f_r, self.hr_w
|
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, hr_w=self.hr_w
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.dynamic_dispersion:
|
if self.dynamic_dispersion:
|
||||||
@@ -303,7 +309,7 @@ class RK4IP:
|
|||||||
class SequentialRK4IP(RK4IP):
|
class SequentialRK4IP(RK4IP):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sim_params,
|
params: initialize.Params,
|
||||||
pbars: utils.PBars,
|
pbars: utils.PBars,
|
||||||
save_data=False,
|
save_data=False,
|
||||||
job_identifier="",
|
job_identifier="",
|
||||||
@@ -312,7 +318,7 @@ class SequentialRK4IP(RK4IP):
|
|||||||
):
|
):
|
||||||
self.pbars = pbars
|
self.pbars = pbars
|
||||||
super().__init__(
|
super().__init__(
|
||||||
sim_params,
|
params,
|
||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
job_identifier=job_identifier,
|
job_identifier=job_identifier,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -326,7 +332,7 @@ class SequentialRK4IP(RK4IP):
|
|||||||
class MutliProcRK4IP(RK4IP):
|
class MutliProcRK4IP(RK4IP):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sim_params,
|
params: initialize.Params,
|
||||||
p_queue: multiprocessing.Queue,
|
p_queue: multiprocessing.Queue,
|
||||||
worker_id: int,
|
worker_id: int,
|
||||||
save_data=False,
|
save_data=False,
|
||||||
@@ -337,7 +343,7 @@ class MutliProcRK4IP(RK4IP):
|
|||||||
self.worker_id = worker_id
|
self.worker_id = worker_id
|
||||||
self.p_queue = p_queue
|
self.p_queue = p_queue
|
||||||
super().__init__(
|
super().__init__(
|
||||||
sim_params,
|
params,
|
||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
job_identifier=job_identifier,
|
job_identifier=job_identifier,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -351,7 +357,7 @@ class MutliProcRK4IP(RK4IP):
|
|||||||
class RayRK4IP(RK4IP):
|
class RayRK4IP(RK4IP):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sim_params,
|
params: initialize.Params,
|
||||||
p_actor,
|
p_actor,
|
||||||
worker_id: int,
|
worker_id: int,
|
||||||
save_data=False,
|
save_data=False,
|
||||||
@@ -362,7 +368,7 @@ class RayRK4IP(RK4IP):
|
|||||||
self.worker_id = worker_id
|
self.worker_id = worker_id
|
||||||
self.p_actor = p_actor
|
self.p_actor = p_actor
|
||||||
super().__init__(
|
super().__init__(
|
||||||
sim_params,
|
params,
|
||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
job_identifier=job_identifier,
|
job_identifier=job_identifier,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@@ -414,7 +420,7 @@ class Simulations:
|
|||||||
if isinstance(method, str):
|
if isinstance(method, str):
|
||||||
method = Simulations.simulation_methods_dict[method]
|
method = Simulations.simulation_methods_dict[method]
|
||||||
return method(param_seq, task_id)
|
return method(param_seq, task_id)
|
||||||
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
|
elif param_seq.num_sim > 1 and param_seq.config.parallel:
|
||||||
return Simulations.get_best_method()(param_seq, task_id)
|
return Simulations.get_best_method()(param_seq, task_id)
|
||||||
else:
|
else:
|
||||||
return SequencialSimulations(param_seq, task_id)
|
return SequencialSimulations(param_seq, task_id)
|
||||||
@@ -439,7 +445,7 @@ class Simulations:
|
|||||||
|
|
||||||
self.name = self.param_seq.name
|
self.name = self.param_seq.name
|
||||||
self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
|
self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
|
||||||
io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config)
|
io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml")
|
||||||
|
|
||||||
self.sim_jobs_per_node = 1
|
self.sim_jobs_per_node = 1
|
||||||
self.max_concurrent_jobs = np.inf
|
self.max_concurrent_jobs = np.inf
|
||||||
@@ -447,9 +453,7 @@ class Simulations:
|
|||||||
@property
|
@property
|
||||||
def finished_and_complete(self):
|
def finished_and_complete(self):
|
||||||
try:
|
try:
|
||||||
io.check_data_integrity(
|
io.check_data_integrity(io.get_data_dirs(self.sim_dir), self.param_seq.config.z_num)
|
||||||
io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"]
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
except IncompleteDataFolderError:
|
except IncompleteDataFolderError:
|
||||||
return False
|
return False
|
||||||
@@ -472,15 +476,15 @@ class Simulations:
|
|||||||
self.new_sim(v_list_str, params)
|
self.new_sim(v_list_str, params)
|
||||||
self.finish()
|
self.finish()
|
||||||
|
|
||||||
def new_sim(self, v_list_str: str, params: dict):
|
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||||
"""responsible to launch a new simulation
|
"""responsible to launch a new simulation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
v_list_str : str
|
v_list_str : str
|
||||||
string that uniquely identifies the simulation as returned by utils.format_variable_list
|
string that uniquely identifies the simulation as returned by utils.format_variable_list
|
||||||
params : dict
|
params : initialize.Params
|
||||||
a flattened parameter dictionary, as returned by initialize.compute_init_parameters
|
computed parameters
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@@ -507,7 +511,7 @@ class SequencialSimulations(Simulations, priority=0):
|
|||||||
super().__init__(param_seq, task_id=task_id)
|
super().__init__(param_seq, task_id=task_id)
|
||||||
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
|
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
|
||||||
|
|
||||||
def new_sim(self, v_list_str: str, params: Dict[str, Any]):
|
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||||
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
||||||
SequentialRK4IP(
|
SequentialRK4IP(
|
||||||
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
|
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
|
||||||
@@ -517,7 +521,7 @@ class SequencialSimulations(Simulations, priority=0):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
pass
|
self.pbars.close()
|
||||||
|
|
||||||
|
|
||||||
class MultiProcSimulations(Simulations, priority=1):
|
class MultiProcSimulations(Simulations, priority=1):
|
||||||
@@ -553,7 +557,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
|||||||
worker.start()
|
worker.start()
|
||||||
super().run()
|
super().run()
|
||||||
|
|
||||||
def new_sim(self, v_list_str: str, params: dict):
|
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||||
self.queue.put((v_list_str, params), block=True, timeout=None)
|
self.queue.put((v_list_str, params), block=True, timeout=None)
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
@@ -576,7 +580,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
|||||||
p_queue: multiprocessing.Queue,
|
p_queue: multiprocessing.Queue,
|
||||||
):
|
):
|
||||||
while True:
|
while True:
|
||||||
raw_data: Tuple[List[tuple], Dict[str, Any]] = queue.get()
|
raw_data: Tuple[List[tuple], initialize.Params] = queue.get()
|
||||||
if raw_data == 0:
|
if raw_data == 0:
|
||||||
queue.task_done()
|
queue.task_done()
|
||||||
return
|
return
|
||||||
@@ -635,7 +639,7 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_sim(self, v_list_str: str, params: dict):
|
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||||
while len(self.jobs) >= self.sim_jobs_total:
|
while len(self.jobs) >= self.sim_jobs_total:
|
||||||
self._collect_1_job()
|
self._collect_1_job()
|
||||||
|
|
||||||
@@ -707,28 +711,27 @@ def new_simulation(
|
|||||||
method: Type[Simulations] = None,
|
method: Type[Simulations] = None,
|
||||||
) -> Simulations:
|
) -> Simulations:
|
||||||
|
|
||||||
config = io.load_toml(config_file)
|
config_dict = io.load_toml(config_file)
|
||||||
|
|
||||||
if prev_sim_dir is not None:
|
if prev_sim_dir is not None:
|
||||||
config.setdefault("simulation", {})
|
config_dict["prev_sim_dir"] = str(prev_sim_dir)
|
||||||
config["simulation"]["prev_sim_dir"] = str(prev_sim_dir)
|
|
||||||
|
|
||||||
task_id = np.random.randint(1e9, 1e12)
|
task_id = np.random.randint(1e9, 1e12)
|
||||||
|
|
||||||
if prev_sim_dir is None:
|
if prev_sim_dir is None:
|
||||||
param_seq = initialize.ParamSequence(config)
|
param_seq = initialize.ParamSequence(config_dict)
|
||||||
else:
|
else:
|
||||||
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config)
|
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
|
||||||
|
|
||||||
print(f"{param_seq.name=}")
|
print(f"{param_seq.name=}")
|
||||||
|
|
||||||
return Simulations.new(param_seq, task_id, method)
|
return Simulations.new(param_seq, task_id, method)
|
||||||
|
|
||||||
|
|
||||||
def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simulations:
|
def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simulations:
|
||||||
|
|
||||||
task_id = np.random.randint(1e9, 1e12)
|
task_id = np.random.randint(1e9, 1e12)
|
||||||
config = io.load_toml(os.path.join(sim_dir, "initial_config.toml"))
|
config = io.load_toml(sim_dir / "initial_config.toml")
|
||||||
io.set_data_folder(task_id, sim_dir)
|
io.set_data_folder(task_id, sim_dir)
|
||||||
param_seq = initialize.RecoveryParamSequence(config, task_id)
|
param_seq = initialize.RecoveryParamSequence(config, task_id)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
|
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
|
||||||
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
|
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
|
||||||
|
|
||||||
from numba.core.types.misc import Phantom
|
from typing import Callable, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import isin, pi
|
from numpy import pi
|
||||||
|
|
||||||
c = 299792458.0
|
c = 299792458.0
|
||||||
hbar = 1.05457148e-34
|
hbar = 1.05457148e-34
|
||||||
@@ -217,7 +218,7 @@ units_map = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_unit(unit):
|
def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
|
||||||
if isinstance(unit, str):
|
if isinstance(unit, str):
|
||||||
return units_map[unit]
|
return units_map[unit]
|
||||||
return unit
|
return unit
|
||||||
|
|||||||
@@ -1,55 +1,47 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import matplotlib.gridspec as gs
|
import matplotlib.gridspec as gs
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib.colors import ListedColormap
|
from matplotlib.colors import ListedColormap
|
||||||
from scgenerator.utils import variable_iterator
|
|
||||||
from scipy.interpolate import UnivariateSpline
|
from scipy.interpolate import UnivariateSpline
|
||||||
|
|
||||||
from . import io, math
|
from . import io, math
|
||||||
from .math import abs2, length, make_uniform_1D, span
|
|
||||||
from .physics import pulse, units
|
|
||||||
from .defaults import default_plotting as defaults
|
from .defaults import default_plotting as defaults
|
||||||
|
from .math import abs2, make_uniform_1D, span
|
||||||
|
from .physics import pulse, units
|
||||||
|
from .utils.parameter import BareParams
|
||||||
|
|
||||||
|
RangeType = Tuple[float, float, Union[str, Callable]]
|
||||||
|
|
||||||
|
|
||||||
def plot_setup(
|
def plot_setup(
|
||||||
folder_name=None,
|
out_path: Path,
|
||||||
file_name=None,
|
file_type: str = "png",
|
||||||
file_type="png",
|
figsize: Tuple[float, float] = defaults["figsize"],
|
||||||
figsize=defaults["figsize"],
|
mode: Literal["default", "coherence", "coherence_T"] = "default",
|
||||||
params=None,
|
) -> Tuple[Path, plt.Figure, Union[plt.Axes, Tuple[plt.Axes]]]:
|
||||||
mode="default",
|
|
||||||
):
|
|
||||||
"""It should return :
|
"""It should return :
|
||||||
- a folder_name
|
- a folder_name
|
||||||
- a file name
|
- a file name
|
||||||
- a fig
|
- a fig
|
||||||
- an axis
|
- an axis
|
||||||
"""
|
"""
|
||||||
file_name = defaults["name"] if file_name is None else file_name
|
out_path = defaults["name"] if out_path is None else out_path
|
||||||
|
plot_name = out_path.stem
|
||||||
|
out_dir = out_path.resolve().parent
|
||||||
|
|
||||||
if params is not None:
|
file_name = plot_name + "." + file_type
|
||||||
folder_name = params.get("plot.folder_name", folder_name)
|
out_path = out_dir / file_name
|
||||||
file_name = params.get("plot.file_name", file_name)
|
|
||||||
file_type = params.get("plot.file_type", file_type)
|
|
||||||
figsize = params.get("plot.figsize", figsize)
|
|
||||||
|
|
||||||
# ensure output folder_name exists
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
folder_name, file_name = (
|
|
||||||
os.path.split(file_name)
|
|
||||||
if folder_name is None
|
|
||||||
else (folder_name, os.path.split(file_name)[1])
|
|
||||||
)
|
|
||||||
folder_name = os.path.join(io.Paths.get("plots"), folder_name)
|
|
||||||
if not os.path.exists(os.path.abspath(folder_name)):
|
|
||||||
os.makedirs(os.path.abspath(folder_name))
|
|
||||||
|
|
||||||
# ensure no overwrite
|
# ensure no overwrite
|
||||||
ind = 0
|
ind = 0
|
||||||
while os.path.exists(os.path.join(folder_name, file_name + "_" + str(ind) + "." + file_type)):
|
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists():
|
||||||
ind += 1
|
ind += 1
|
||||||
file_name = file_name + "_" + str(ind) + "." + file_type
|
|
||||||
|
|
||||||
if mode == "default":
|
if mode == "default":
|
||||||
fig, ax = plt.subplots(figsize=figsize)
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
@@ -78,7 +70,7 @@ def plot_setup(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"mode {mode} not understood")
|
raise ValueError(f"mode {mode} not understood")
|
||||||
|
|
||||||
return folder_name, file_name, fig, ax
|
return full_path, fig, ax
|
||||||
|
|
||||||
|
|
||||||
def draw_across(ax1, xy1, ax2, xy2, clip_on=False, **kwargs):
|
def draw_across(ax1, xy1, ax2, xy2, clip_on=False, **kwargs):
|
||||||
@@ -297,9 +289,7 @@ def _finish_plot_2D(
|
|||||||
|
|
||||||
folder_name = ""
|
folder_name = ""
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
folder_name, file_name, fig, ax = plot_setup(
|
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||||
file_name=file_name, file_type=file_type, params=params
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
fig = ax.get_figure()
|
fig = ax.get_figure()
|
||||||
|
|
||||||
@@ -345,8 +335,8 @@ def _finish_plot_2D(
|
|||||||
cbar.ax.set_ylabel(cbar_label)
|
cbar.ax.set_ylabel(cbar_label)
|
||||||
|
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
print(f"plot saved in {os.path.join(folder_name, file_name)}")
|
print(f"plot saved in {out_path}")
|
||||||
if cbar_label is not None:
|
if cbar_label is not None:
|
||||||
return fig, ax, cbar.ax
|
return fig, ax, cbar.ax
|
||||||
else:
|
else:
|
||||||
@@ -354,20 +344,20 @@ def _finish_plot_2D(
|
|||||||
|
|
||||||
|
|
||||||
def plot_spectrogram(
|
def plot_spectrogram(
|
||||||
values,
|
values: np.ndarray,
|
||||||
x_range,
|
x_range: RangeType,
|
||||||
y_range,
|
y_range: RangeType,
|
||||||
params,
|
params: BareParams,
|
||||||
t_res=None,
|
t_res: int = None,
|
||||||
gate_width=None,
|
gate_width: float = None,
|
||||||
log=True,
|
log: bool = True,
|
||||||
vmin=None,
|
vmin: float = None,
|
||||||
vmax=None,
|
vmax: float = None,
|
||||||
cbar_label="normalized intensity (dB)",
|
cbar_label: str = "normalized intensity (dB)",
|
||||||
file_type="png",
|
file_type: str = "png",
|
||||||
file_name=None,
|
file_name: str = None,
|
||||||
cmap=None,
|
cmap: str = None,
|
||||||
ax=None,
|
ax: plt.Axes = None,
|
||||||
):
|
):
|
||||||
"""Plots a spectrogram given a complex field in the time domain
|
"""Plots a spectrogram given a complex field in the time domain
|
||||||
Parameters
|
Parameters
|
||||||
@@ -382,7 +372,7 @@ def plot_spectrogram(
|
|||||||
units : function to convert from the desired units to rad/s or to time.
|
units : function to convert from the desired units to rad/s or to time.
|
||||||
common functions are already defined in scgenerator.physics.units
|
common functions are already defined in scgenerator.physics.units
|
||||||
look there for more details
|
look there for more details
|
||||||
params : dict
|
params : BareParams
|
||||||
parameters of the simulations
|
parameters of the simulations
|
||||||
log : bool, optional
|
log : bool, optional
|
||||||
whether to compute the logarithm of the spectrogram
|
whether to compute the logarithm of the spectrogram
|
||||||
@@ -424,16 +414,16 @@ def plot_spectrogram(
|
|||||||
t_win = 2 * np.max(t_range[2](np.abs(t_range[:2])))
|
t_win = 2 * np.max(t_range[2](np.abs(t_range[:2])))
|
||||||
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
|
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
|
||||||
spec, new_t = pulse.spectrogram(
|
spec, new_t = pulse.spectrogram(
|
||||||
params["t"].copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
|
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Crop and reoder axis
|
# Crop and reoder axis
|
||||||
new_t, ind_t, _ = units.sort_axis(new_t, t_range)
|
new_t, ind_t, _ = units.sort_axis(new_t, t_range)
|
||||||
new_f, ind_f, _ = units.sort_axis(params["w"], f_range)
|
new_f, ind_f, _ = units.sort_axis(params.w, f_range)
|
||||||
values = spec[ind_t][:, ind_f]
|
values = spec[ind_t][:, ind_f]
|
||||||
if f_range[2].type == "WL":
|
if f_range[2].type == "WL":
|
||||||
values = np.apply_along_axis(
|
values = np.apply_along_axis(
|
||||||
units.to_WL, 1, values, params["frep"], units.m(f_range[2].inv(new_f))
|
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f))
|
||||||
)
|
)
|
||||||
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
|
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
|
||||||
|
|
||||||
@@ -463,19 +453,19 @@ def plot_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
def plot_results_2D(
|
def plot_results_2D(
|
||||||
values,
|
values: np.ndarray,
|
||||||
plt_range,
|
plt_range: RangeType,
|
||||||
params,
|
params: BareParams,
|
||||||
log="1D",
|
log: Union[int, float, bool, str] = "1D",
|
||||||
skip=16,
|
skip: int = 16,
|
||||||
vmin=None,
|
vmin: float = None,
|
||||||
vmax=None,
|
vmax: float = None,
|
||||||
transpose=False,
|
transpose: bool = False,
|
||||||
cbar_label="normalized intensity (dB)",
|
cbar_label: Optional[str] = "normalized intensity (dB)",
|
||||||
file_type="png",
|
file_type: str = "png",
|
||||||
file_name=None,
|
file_name: str = None,
|
||||||
cmap=None,
|
cmap: str = None,
|
||||||
ax=None,
|
ax: plt.Axes = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
plots 2D arrays and automatically saves the plots, as well as returns it
|
plots 2D arrays and automatically saves the plots, as well as returns it
|
||||||
@@ -540,27 +530,32 @@ def plot_results_2D(
|
|||||||
# make uniform if converting to wavelength
|
# make uniform if converting to wavelength
|
||||||
if plt_range[2].type == "WL":
|
if plt_range[2].type == "WL":
|
||||||
if is_spectrum:
|
if is_spectrum:
|
||||||
values = np.apply_along_axis(units.to_WL, 1, values, params.get("frep", 1), x_axis)
|
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||||
values = np.array(
|
values = np.array(
|
||||||
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
|
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
|
||||||
)
|
)
|
||||||
|
|
||||||
z = params["z_targets"]
|
lim_diff = 1e-5 * np.max(params.z_targets)
|
||||||
lim_diff = 1e-5 * np.max(z)
|
dz_s = np.diff(params.z_targets)
|
||||||
dz_s = np.diff(z)
|
|
||||||
if not np.all(np.diff(dz_s) < lim_diff):
|
if not np.all(np.diff(dz_s) < lim_diff):
|
||||||
new_z = np.linspace(
|
new_z = np.linspace(
|
||||||
*span(z), int(np.floor((np.max(z) - np.min(z)) / np.min(dz_s[dz_s > lim_diff])))
|
*span(params.z_targets),
|
||||||
|
int(
|
||||||
|
np.floor(
|
||||||
|
(np.max(params.z_targets) - np.min(params.z_targets))
|
||||||
|
/ np.min(dz_s[dz_s > lim_diff])
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
values = np.array(
|
values = np.array(
|
||||||
[make_uniform_1D(v, z, n=len(new_z), method="linear") for v in values.T]
|
[make_uniform_1D(v, params.z_targets, n=len(new_z), method="linear") for v in values.T]
|
||||||
).T
|
).T
|
||||||
z = new_z
|
params.z_targets = new_z
|
||||||
return _finish_plot_2D(
|
return _finish_plot_2D(
|
||||||
values,
|
values,
|
||||||
x_axis,
|
x_axis,
|
||||||
plt_range[2].label,
|
plt_range[2].label,
|
||||||
z,
|
params.z_targets,
|
||||||
"propagation distance (m)",
|
"propagation distance (m)",
|
||||||
log,
|
log,
|
||||||
vmin,
|
vmin,
|
||||||
@@ -576,20 +571,20 @@ def plot_results_2D(
|
|||||||
|
|
||||||
|
|
||||||
def plot_results_1D(
|
def plot_results_1D(
|
||||||
values,
|
values: np.ndarray,
|
||||||
plt_range,
|
plt_range: RangeType,
|
||||||
params,
|
params: BareParams,
|
||||||
log=False,
|
log: Union[str, int, float, bool] = False,
|
||||||
spacing=1,
|
spacing: Union[int, float] = 1,
|
||||||
vmin=None,
|
vmin: float = None,
|
||||||
vmax=None,
|
vmax: float = None,
|
||||||
ylabel=None,
|
ylabel: str = None,
|
||||||
yscaling=1,
|
yscaling: float = 1,
|
||||||
file_type="pdf",
|
file_type: str = "pdf",
|
||||||
file_name=None,
|
file_name: str = None,
|
||||||
ax=None,
|
ax: plt.Axes = None,
|
||||||
line_label=None,
|
line_label: str = None,
|
||||||
transpose=False,
|
transpose: bool = False,
|
||||||
**line_kwargs,
|
**line_kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -656,7 +651,7 @@ def plot_results_1D(
|
|||||||
# make uniform if converting to wavelength
|
# make uniform if converting to wavelength
|
||||||
if plt_range[2].type == "WL":
|
if plt_range[2].type == "WL":
|
||||||
if is_spectrum:
|
if is_spectrum:
|
||||||
values = units.to_WL(values, params["frep"], units.m.inv(params["w"][ind]))
|
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
||||||
|
|
||||||
# change the resolution
|
# change the resolution
|
||||||
if isinstance(spacing, float):
|
if isinstance(spacing, float):
|
||||||
@@ -683,9 +678,7 @@ def plot_results_1D(
|
|||||||
|
|
||||||
folder_name = ""
|
folder_name = ""
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
folder_name, file_name, fig, ax = plot_setup(
|
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||||
file_name=file_name, file_type=file_type, params=params
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
fig = ax.get_figure()
|
fig = ax.get_figure()
|
||||||
if transpose:
|
if transpose:
|
||||||
@@ -702,40 +695,40 @@ def plot_results_1D(
|
|||||||
ax.set_xlabel(plt_range[2].label)
|
ax.set_xlabel(plt_range[2].label)
|
||||||
|
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
print(f"plot saved in {os.path.join(folder_name, file_name)}")
|
print(f"plot saved in {out_path}")
|
||||||
return fig, ax, x_axis, values
|
return fig, ax, x_axis, values
|
||||||
|
|
||||||
|
|
||||||
def _prep_plot(values, plt_range, params):
|
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
||||||
is_spectrum = values.dtype == "complex"
|
is_spectrum = values.dtype == "complex"
|
||||||
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
|
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
|
||||||
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
|
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
|
||||||
x_axis = params["w"].copy()
|
x_axis = params.w.copy()
|
||||||
else:
|
else:
|
||||||
x_axis = params["t"].copy()
|
x_axis = params.t.copy()
|
||||||
return is_spectrum, x_axis, plt_range
|
return is_spectrum, x_axis, plt_range
|
||||||
|
|
||||||
|
|
||||||
def plot_avg(
|
def plot_avg(
|
||||||
values,
|
values: np.ndarray,
|
||||||
plt_range,
|
plt_range: RangeType,
|
||||||
params,
|
params: BareParams,
|
||||||
log=False,
|
log: Union[float, int, str, bool] = False,
|
||||||
spacing=1,
|
spacing: Union[float, int] = 1,
|
||||||
vmin=None,
|
vmin: float = None,
|
||||||
vmax=None,
|
vmax: float = None,
|
||||||
ylabel=None,
|
ylabel: str = None,
|
||||||
yscaling=1,
|
yscaling: float = 1,
|
||||||
renormalize=True,
|
renormalize: bool = True,
|
||||||
add_coherence=False,
|
add_coherence: bool = False,
|
||||||
file_type="png",
|
file_type: str = "png",
|
||||||
file_name=None,
|
file_name: str = None,
|
||||||
ax=None,
|
ax: plt.Axes = None,
|
||||||
line_labels=None,
|
line_labels: Tuple[str, str] = None,
|
||||||
legend=True,
|
legend: bool = True,
|
||||||
legend_kwargs={},
|
legend_kwargs: Dict[str, Any] = {},
|
||||||
transpose=False,
|
transpose: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
plots 1D arrays and there mean and automatically saves the plots, as well as returns it
|
plots 1D arrays and there mean and automatically saves the plots, as well as returns it
|
||||||
@@ -817,8 +810,8 @@ def plot_avg(
|
|||||||
values *= yscaling
|
values *= yscaling
|
||||||
mean_values = np.mean(values, axis=0)
|
mean_values = np.mean(values, axis=0)
|
||||||
if plt_range[2].type == "WL" and renormalize:
|
if plt_range[2].type == "WL" and renormalize:
|
||||||
values = np.apply_along_axis(units.to_WL, 1, values, params["frep"], x_axis)
|
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||||
mean_values = units.to_WL(mean_values, params["frep"], x_axis)
|
mean_values = units.to_WL(mean_values, params.frep, x_axis)
|
||||||
|
|
||||||
# change the resolution
|
# change the resolution
|
||||||
if isinstance(spacing, float):
|
if isinstance(spacing, float):
|
||||||
@@ -852,12 +845,12 @@ def plot_avg(
|
|||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
if add_coherence:
|
if add_coherence:
|
||||||
mode = "coherence_T" if transpose else "coherence"
|
mode = "coherence_T" if transpose else "coherence"
|
||||||
folder_name, file_name, fig, (top, bot) = plot_setup(
|
out_path, fig, (top, bot) = plot_setup(
|
||||||
file_name=file_name, file_type=file_type, params=params, mode=mode
|
out_path=Path(folder_name) / file_name, file_type=file_type, mode=mode
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
folder_name, file_name, fig, top = plot_setup(
|
out_path, fig, top = plot_setup(
|
||||||
file_name=file_name, file_type=file_type, params=params
|
out_path=Path(folder_name) / file_name, file_type=file_type
|
||||||
)
|
)
|
||||||
bot = top
|
bot = top
|
||||||
else:
|
else:
|
||||||
@@ -923,8 +916,8 @@ def plot_avg(
|
|||||||
top.legend(custom_lines, line_labels, **legend_kwargs)
|
top.legend(custom_lines, line_labels, **legend_kwargs)
|
||||||
|
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
print(f"plot saved in {os.path.join(folder_name, file_name)}")
|
print(f"plot saved in {out_path}")
|
||||||
|
|
||||||
if top is bot:
|
if top is bot:
|
||||||
return fig, top
|
return fig, top
|
||||||
@@ -984,46 +977,6 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
|
|||||||
return x_axis, np.squeeze(values)
|
return x_axis, np.squeeze(values)
|
||||||
|
|
||||||
|
|
||||||
def plot_dispersion_parameter(params, plt_range):
|
|
||||||
"""
|
|
||||||
Plots the dispersion parameter D as well as the beta2 parameter over the given range
|
|
||||||
"""
|
|
||||||
# TODO allow several curves, with legends, to be plotted
|
|
||||||
|
|
||||||
x_axis = np.linspace(*plt_range[:2], 1000)
|
|
||||||
w_axis = plt_range[2](x_axis)
|
|
||||||
|
|
||||||
if "disp_obj" in params:
|
|
||||||
D = params["disp_obj"].D_w(w_axis)
|
|
||||||
beta2 = params["disp_obj"].beta2_w(w_axis)
|
|
||||||
else:
|
|
||||||
print("no dispersion information given")
|
|
||||||
return
|
|
||||||
|
|
||||||
fig, (ax_D, ax_beta2) = plt.subplots(1, 2)
|
|
||||||
|
|
||||||
ax_D.plot(x_axis, 1e6 * D)
|
|
||||||
ax_D.plot(
|
|
||||||
x_axis,
|
|
||||||
0 * x_axis,
|
|
||||||
":",
|
|
||||||
c="k",
|
|
||||||
)
|
|
||||||
ax_D.set_xlabel(plt_range[2].label)
|
|
||||||
ax_D.set_ylabel(r"Dispersion parameter $D$ ($\frac{\mathrm{ps}}{\mathrm{nm\ km}}$)")
|
|
||||||
|
|
||||||
ax_beta2.plot(x_axis, 1e27 * beta2)
|
|
||||||
ax_beta2.plot(
|
|
||||||
x_axis,
|
|
||||||
0 * x_axis,
|
|
||||||
":",
|
|
||||||
c="k",
|
|
||||||
)
|
|
||||||
ax_beta2.set_xlabel(plt_range[2].label)
|
|
||||||
ax_beta2.set_ylabel(r"$\beta_2$ parameter ($\frac{\mathrm{ps}^2}{\mathrm{km}}$)")
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
|
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
|
||||||
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
|
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
|
||||||
top = plt.get_cmap(name, 1024)
|
top = plt.get_cmap(name, 1024)
|
||||||
|
|||||||
@@ -128,11 +128,11 @@ def main():
|
|||||||
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
|
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
|
||||||
|
|
||||||
submit_path = Path(
|
submit_path = Path(
|
||||||
"submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
|
"submit " + final_config.name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
|
||||||
)
|
)
|
||||||
tmp_path = Path("submit tmp.sh")
|
tmp_path = Path("submit tmp.sh")
|
||||||
|
|
||||||
job_name = f"supercontinuum {final_config['name']}"
|
job_name = f"supercontinuum {final_config.name}"
|
||||||
submit_sh = template.format(
|
submit_sh = template.format(
|
||||||
job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args)
|
job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,16 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Sequence
|
||||||
from glob import glob
|
|
||||||
from typing import Any, Dict, List, Tuple
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator.const import SPECN_FN
|
from . import initialize, io, math
|
||||||
|
from .const import SPECN_FN
|
||||||
from . import io, initialize, math
|
|
||||||
from .plotting import units
|
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .plotting import units
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
@@ -43,7 +41,7 @@ class Pulse(Sequence):
|
|||||||
|
|
||||||
self.params = None
|
self.params = None
|
||||||
try:
|
try:
|
||||||
self.params = io.load_previous_parameters(self.path / "params.toml")
|
self.params = io.load_params(self.path / "params.toml")
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
self.logger.info(f"parameters corresponding to {self.path} not found")
|
self.logger.info(f"parameters corresponding to {self.path} not found")
|
||||||
|
|
||||||
|
|||||||
@@ -4,24 +4,23 @@ scgenerator module but some function may be used in any python program
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import collections
|
|
||||||
import itertools
|
import itertools
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import asdict, replace
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar, Union
|
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import env
|
from .. import env
|
||||||
from .const import PARAM_SEPARATOR, valid_variable
|
from ..const import PARAM_SEPARATOR
|
||||||
from .math import *
|
from ..math import *
|
||||||
|
from .parameter import BareConfig, BareParams
|
||||||
|
|
||||||
T_ = TypeVar("T_")
|
T_ = TypeVar("T_")
|
||||||
|
|
||||||
@@ -177,18 +176,11 @@ def progress_worker(
|
|||||||
pbars[0].update()
|
pbars[0].update()
|
||||||
|
|
||||||
|
|
||||||
def count_variations(config: dict) -> Tuple[int, int]:
|
def count_variations(config: BareConfig) -> Tuple[int, int]:
|
||||||
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
||||||
variable_params_num is the number of distinct parameters that will vary."""
|
variable_params_num is the number of distinct parameters that will vary."""
|
||||||
sim_num = 1
|
variable_params_num = len(config.variable)
|
||||||
variable_params_num = 0
|
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
||||||
|
|
||||||
for section_name in valid_variable:
|
|
||||||
for array in config.get(section_name, {}).get("variable", {}).values():
|
|
||||||
sim_num *= len(array)
|
|
||||||
variable_params_num += 1
|
|
||||||
|
|
||||||
sim_num *= config["simulation"].get("repeat", 1)
|
|
||||||
return sim_num, variable_params_num
|
return sim_num, variable_params_num
|
||||||
|
|
||||||
|
|
||||||
@@ -217,49 +209,45 @@ def format_value(value):
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
|
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
|
||||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config : dict
|
config : BareConfig
|
||||||
initial config dictionary
|
initial config obj
|
||||||
|
|
||||||
Yields
|
Yields
|
||||||
-------
|
-------
|
||||||
Iterator[Tuple[List[Tuple[str, Any]], dict]]
|
Iterator[Tuple[List[Tuple[str, Any]], BareParams]]
|
||||||
variable_list : a list of (name, value) tuple of parameter name and value that are variable.
|
variable_list : a list of (name, value) tuple of parameter name and value that are variable.
|
||||||
|
|
||||||
dict : a config dictionary for one simulation
|
params : a BareParams obj for one simulation
|
||||||
"""
|
"""
|
||||||
indiv_config = deepcopy(config)
|
|
||||||
variable_dict = {
|
|
||||||
section_name: indiv_config.get(section_name, {}).pop("variable", {})
|
|
||||||
for section_name in valid_variable
|
|
||||||
}
|
|
||||||
|
|
||||||
possible_keys = []
|
possible_keys = []
|
||||||
possible_ranges = []
|
possible_ranges = []
|
||||||
|
|
||||||
for section_name, section in variable_dict.items():
|
for key, values in config.variable.items():
|
||||||
for key in section:
|
possible_keys.append(key)
|
||||||
arr = variable_dict[section_name][key]
|
possible_ranges.append(range(len(values)))
|
||||||
possible_keys.append((section_name, key))
|
|
||||||
possible_ranges.append(range(len(arr)))
|
|
||||||
|
|
||||||
combinations = itertools.product(*possible_ranges)
|
combinations = itertools.product(*possible_ranges)
|
||||||
|
|
||||||
for combination in combinations:
|
for combination in combinations:
|
||||||
|
indiv_config = {}
|
||||||
variable_list = []
|
variable_list = []
|
||||||
for i, key in enumerate(possible_keys):
|
for i, key in enumerate(possible_keys):
|
||||||
parameter_value = variable_dict[key[0]][key[1]][combination[i]]
|
parameter_value = config.variable[key][combination[i]]
|
||||||
indiv_config[key[0]][key[1]] = parameter_value
|
indiv_config[key] = parameter_value
|
||||||
variable_list.append((key[1], parameter_value))
|
variable_list.append((key, parameter_value))
|
||||||
yield variable_list, indiv_config
|
param_dict = asdict(config)
|
||||||
|
param_dict.pop("variable")
|
||||||
|
param_dict.update(indiv_config)
|
||||||
|
yield variable_list, BareParams(**param_dict)
|
||||||
|
|
||||||
|
|
||||||
def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
|
def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
|
||||||
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
||||||
parameter set and iterates through every single necessary simulation
|
parameter set and iterates through every single necessary simulation
|
||||||
|
|
||||||
@@ -273,48 +261,19 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]
|
|||||||
dict : a config dictionary for one simulation
|
dict : a config dictionary for one simulation
|
||||||
"""
|
"""
|
||||||
i = 0 # unique sim id
|
i = 0 # unique sim id
|
||||||
for variable_only, full_config in variable_iterator(config):
|
for variable_only, bare_params in variable_iterator(config):
|
||||||
for j in range(config["simulation"]["repeat"]):
|
for j in range(config.repeat):
|
||||||
variable_ind = [("id", i)] + variable_only + [("num", j)]
|
variable_ind = [("id", i)] + variable_only + [("num", j)]
|
||||||
i += 1
|
i += 1
|
||||||
yield variable_ind, full_config
|
yield variable_ind, bare_params
|
||||||
|
|
||||||
|
|
||||||
def deep_update(d: Mapping, u: Mapping) -> dict:
|
def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
|
||||||
for k, v in u.items():
|
|
||||||
if isinstance(v, collections.abc.Mapping):
|
|
||||||
d[k] = deep_update(d.get(k, {}), v)
|
|
||||||
else:
|
|
||||||
d[k] = v
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str, Any]:
|
|
||||||
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
|
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
|
||||||
if old is None:
|
if old is None:
|
||||||
return new
|
return BareConfig(**new)
|
||||||
out = deepcopy(old)
|
variable = deepcopy(old.variable)
|
||||||
for section_name, section in new.items():
|
variable.update(new.pop("variable", {})) # add new variable
|
||||||
if isinstance(section, Mapping):
|
for k in new:
|
||||||
for param_name, value in section.items():
|
variable.pop(k) # remove old ones
|
||||||
if param_name == "variable" and isinstance(value, Mapping):
|
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
||||||
out[section_name].setdefault("variable", {})
|
|
||||||
for p, v in value.items():
|
|
||||||
# override previously unvariable param
|
|
||||||
if p in old[section_name]:
|
|
||||||
del out[section_name][p]
|
|
||||||
out[section_name]["variable"][p] = v
|
|
||||||
else:
|
|
||||||
# override previously variable param
|
|
||||||
if (
|
|
||||||
"variable" in old[section_name]
|
|
||||||
and isinstance(old[section_name]["variable"], Mapping)
|
|
||||||
and param_name in old[section_name]["variable"]
|
|
||||||
):
|
|
||||||
del out[section_name]["variable"][param_name]
|
|
||||||
if len(out[section_name]["variable"]) == 0:
|
|
||||||
del out[section_name["variable"]]
|
|
||||||
out[section_name][param_name] = value
|
|
||||||
else:
|
|
||||||
out[section_name] = section
|
|
||||||
return out
|
|
||||||
442
src/scgenerator/utils/parameter.py
Normal file
442
src/scgenerator/utils/parameter.py
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
import datetime
|
||||||
|
from copy import copy
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..const import __version__
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def type_checker(*types):
|
||||||
|
def _type_checker_wrapper(validator, n=None):
|
||||||
|
if isinstance(validator, str) and n is not None:
|
||||||
|
_name = validator
|
||||||
|
validator = lambda *args: None
|
||||||
|
|
||||||
|
def _type_checker_wrapped(name, n):
|
||||||
|
if not isinstance(n, types):
|
||||||
|
raise TypeError(
|
||||||
|
f"{name!r} value must be of type {' or '.join(format(t) for t in types)} "
|
||||||
|
f"instead of {type(n)}"
|
||||||
|
)
|
||||||
|
validator(name, n)
|
||||||
|
|
||||||
|
if n is None:
|
||||||
|
return _type_checker_wrapped
|
||||||
|
else:
|
||||||
|
_type_checker_wrapped(_name, n)
|
||||||
|
|
||||||
|
return _type_checker_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@type_checker(str)
|
||||||
|
def string(name, n):
|
||||||
|
if len(n) == 0:
|
||||||
|
raise ValueError(f"{name!r} must not be empty")
|
||||||
|
|
||||||
|
|
||||||
|
def in_range_excl(_min, _max):
|
||||||
|
@type_checker(float, int)
|
||||||
|
def _in_range(name, n):
|
||||||
|
if n <= _min or n >= _max:
|
||||||
|
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
|
||||||
|
|
||||||
|
return _in_range
|
||||||
|
|
||||||
|
|
||||||
|
def in_range_incl(_min, _max):
|
||||||
|
@type_checker(float, int)
|
||||||
|
def _in_range(name, n):
|
||||||
|
if n < _min or n > _max:
|
||||||
|
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
|
||||||
|
|
||||||
|
return _in_range
|
||||||
|
|
||||||
|
|
||||||
|
def boolean(name, n):
|
||||||
|
if not n is True and not n is False:
|
||||||
|
raise ValueError(f"{name!r} must be True or False")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def non_negative(*types):
|
||||||
|
@type_checker(*types)
|
||||||
|
def _non_negative(name, n):
|
||||||
|
if n < 0:
|
||||||
|
raise ValueError(f"{name!r} must be non negative")
|
||||||
|
|
||||||
|
return _non_negative
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def positive(*types):
|
||||||
|
@type_checker(*types)
|
||||||
|
def _positive(name, n):
|
||||||
|
if n <= 0:
|
||||||
|
raise ValueError(f"{name!r} must be positive")
|
||||||
|
|
||||||
|
return _positive
|
||||||
|
|
||||||
|
|
||||||
|
@type_checker(tuple, list)
|
||||||
|
def int_pair(name, t):
|
||||||
|
invalid = len(t) != 2
|
||||||
|
for m in t:
|
||||||
|
if invalid or not isinstance(m, int):
|
||||||
|
raise ValueError(f"{name!r} must be a list or a tuple of 2 int")
|
||||||
|
|
||||||
|
|
||||||
|
def literal(*l):
|
||||||
|
l = set(l)
|
||||||
|
|
||||||
|
@type_checker(str)
|
||||||
|
def _string(name, s):
|
||||||
|
if not s in l:
|
||||||
|
raise ValueError(f"{name!r} must be a str in {l}")
|
||||||
|
|
||||||
|
return _string
|
||||||
|
|
||||||
|
|
||||||
|
def validator_list(validator):
|
||||||
|
"""returns a new validator that applies validator to each el of an iterable"""
|
||||||
|
|
||||||
|
@type_checker(list, tuple)
|
||||||
|
def _list_validator(name, l):
|
||||||
|
for i, el in enumerate(l):
|
||||||
|
validator(name + f"[{i}]", el)
|
||||||
|
|
||||||
|
return _list_validator
|
||||||
|
|
||||||
|
|
||||||
|
def validator_or(*validators):
|
||||||
|
"""combines many validators and raises an exception only if all of them raise an exception"""
|
||||||
|
|
||||||
|
n = len(validators)
|
||||||
|
|
||||||
|
def _or_validator(name, value):
|
||||||
|
errors = []
|
||||||
|
for validator in validators:
|
||||||
|
try:
|
||||||
|
validator(name, value)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
errors.append(e)
|
||||||
|
errors.sort(key=lambda el: isinstance(el, ValueError))
|
||||||
|
if len(errors) == n:
|
||||||
|
raise errors[-1]
|
||||||
|
|
||||||
|
return _or_validator
|
||||||
|
|
||||||
|
|
||||||
|
def validator_and(*validators):
|
||||||
|
def _and_validator(name, n):
|
||||||
|
for v in validators:
|
||||||
|
v(name, n)
|
||||||
|
|
||||||
|
return _and_validator
|
||||||
|
|
||||||
|
|
||||||
|
@type_checker(list, tuple, np.ndarray)
|
||||||
|
def num_list(name, l):
|
||||||
|
for i, el in enumerate(l):
|
||||||
|
type_checker(int, float)(name + f"[{i}]", el)
|
||||||
|
|
||||||
|
|
||||||
|
def func_validator(name, n):
|
||||||
|
if not callable(n):
|
||||||
|
raise TypeError(f"{name!r} must be callable")
|
||||||
|
|
||||||
|
|
||||||
|
class Parameter:
|
||||||
|
def __init__(self, validator, converter=None, default=None):
|
||||||
|
"""Single parameter
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tpe : type
|
||||||
|
type of the paramter
|
||||||
|
validators : Callable[[str, Any], None]
|
||||||
|
signature : validator(name, value)
|
||||||
|
must raise a ValueError when value doesn't fit the criteria checked by
|
||||||
|
validator. name is passed to validator to be included in the error message
|
||||||
|
converter : Callable, optional
|
||||||
|
converts a valid value (for example, str.lower), by default None
|
||||||
|
default : callable, optional
|
||||||
|
factory function for a default value (for example, list), by default None
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.validator = validator
|
||||||
|
self.converter = converter
|
||||||
|
self.default = default
|
||||||
|
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
if not instance:
|
||||||
|
return self
|
||||||
|
return instance.__dict__[self.name]
|
||||||
|
|
||||||
|
def __delete__(self, instance):
|
||||||
|
del instance.__dict__[self.name]
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
if isinstance(value, Parameter):
|
||||||
|
defaut = None if self.default is None else copy(self.default)
|
||||||
|
instance.__dict__[self.name] = defaut
|
||||||
|
else:
|
||||||
|
if value is not None:
|
||||||
|
self.validator(self.name, value)
|
||||||
|
if self.converter is not None:
|
||||||
|
value = self.converter(value)
|
||||||
|
instance.__dict__[self.name] = value
|
||||||
|
|
||||||
|
|
||||||
|
class VariableParameter:
|
||||||
|
def __init__(self, parameterBase):
|
||||||
|
self.pbase = parameterBase
|
||||||
|
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
if not instance:
|
||||||
|
return self
|
||||||
|
return instance.__dict__[self.name]
|
||||||
|
|
||||||
|
def __delete__(self, instance):
|
||||||
|
del instance.__dict__[self.name]
|
||||||
|
|
||||||
|
def __set__(self, instance, value: dict):
|
||||||
|
if isinstance(value, VariableParameter):
|
||||||
|
value = {}
|
||||||
|
else:
|
||||||
|
for k, v in value.items():
|
||||||
|
if k not in valid_variable:
|
||||||
|
raise TypeError(f"{k!r} is not a valide variable parameter")
|
||||||
|
if len(v) == 0:
|
||||||
|
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||||
|
|
||||||
|
p = getattr(self.pbase, k)
|
||||||
|
|
||||||
|
for el in v:
|
||||||
|
p.validator(k, el)
|
||||||
|
instance.__dict__[self.name] = value
|
||||||
|
|
||||||
|
|
||||||
|
valid_variable = {
|
||||||
|
"beta",
|
||||||
|
"gamma",
|
||||||
|
"pitch",
|
||||||
|
"pitch_ratio",
|
||||||
|
"core_radius",
|
||||||
|
"capillary_num",
|
||||||
|
"capillary_outer_d",
|
||||||
|
"capillary_thickness",
|
||||||
|
"capillary_spacing",
|
||||||
|
"capillary_resonance_strengths",
|
||||||
|
"capillary_nested",
|
||||||
|
"he_mode",
|
||||||
|
"fit_parameters",
|
||||||
|
"input_transmission",
|
||||||
|
"n2",
|
||||||
|
"pressure",
|
||||||
|
"temperature",
|
||||||
|
"gas_name",
|
||||||
|
"plasma_density" "peak_power",
|
||||||
|
"mean_power",
|
||||||
|
"peak_power",
|
||||||
|
"energy",
|
||||||
|
"quantum_noise",
|
||||||
|
"shape",
|
||||||
|
"wavelength",
|
||||||
|
"intensity_noise",
|
||||||
|
"width",
|
||||||
|
"soliton_num",
|
||||||
|
"behaviors",
|
||||||
|
"raman_type",
|
||||||
|
"tolerated_error",
|
||||||
|
"step_size",
|
||||||
|
"ideal_gas",
|
||||||
|
"readjust_wavelength",
|
||||||
|
}
|
||||||
|
|
||||||
|
hc_model_specific_parameters = dict(
|
||||||
|
marcatili=["core_radius", "he_mode"],
|
||||||
|
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
|
||||||
|
hasan=[
|
||||||
|
"core_radius",
|
||||||
|
"capillary_num",
|
||||||
|
"capillary_thickness",
|
||||||
|
"capillary_resonance_strengths",
|
||||||
|
"capillary_nested",
|
||||||
|
"capillary_spacing",
|
||||||
|
"capillary_outer_d",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BareParams:
|
||||||
|
"""
|
||||||
|
This class defines each valid parameter's name, type and valid value but doesn't provide
|
||||||
|
any method to act on those. For that, use initialize.Params
|
||||||
|
"""
|
||||||
|
|
||||||
|
# root
|
||||||
|
name: str = Parameter(string)
|
||||||
|
prev_data_dir: str = Parameter(string)
|
||||||
|
|
||||||
|
# # fiber
|
||||||
|
input_transmission: float = Parameter(in_range_incl(0, 1))
|
||||||
|
gamma: float = Parameter(non_negative(float, int))
|
||||||
|
n2: float = Parameter(non_negative(float, int))
|
||||||
|
effective_mode_diameter: float = Parameter(positive(float, int))
|
||||||
|
A_eff: float = Parameter(non_negative(float, int))
|
||||||
|
pitch: float = Parameter(in_range_excl(0, 1e-3))
|
||||||
|
pitch_ratio: float = Parameter(in_range_excl(0, 1))
|
||||||
|
core_radius: float = Parameter(in_range_excl(0, 1e-3))
|
||||||
|
he_mode: Tuple[int, int] = Parameter(int_pair)
|
||||||
|
fit_parameters: Tuple[int, int] = Parameter(int_pair)
|
||||||
|
beta: Iterable[float] = Parameter(num_list)
|
||||||
|
dispersion_file: str = Parameter(string)
|
||||||
|
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
|
||||||
|
length: float = Parameter(non_negative(float, int))
|
||||||
|
capillary_num: int = Parameter(positive(int))
|
||||||
|
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
|
||||||
|
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
|
||||||
|
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
|
||||||
|
capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
|
||||||
|
capillary_nested: int = Parameter(non_negative(int))
|
||||||
|
|
||||||
|
# gas
|
||||||
|
gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower)
|
||||||
|
pressure: Union[float, Iterable[float]] = Parameter(
|
||||||
|
validator_or(non_negative(float, int), num_list)
|
||||||
|
)
|
||||||
|
temperature: float = Parameter(positive(float, int))
|
||||||
|
plasma_density: float = Parameter(non_negative(float, int))
|
||||||
|
|
||||||
|
# pulse
|
||||||
|
field_file: str = Parameter(string)
|
||||||
|
repetition_rate: float = Parameter(non_negative(float, int))
|
||||||
|
peak_power: float = Parameter(positive(float, int))
|
||||||
|
mean_power: float = Parameter(positive(float, int))
|
||||||
|
energy: float = Parameter(positive(float, int))
|
||||||
|
soliton_num: float = Parameter(positive(float, int))
|
||||||
|
quantum_noise: bool = Parameter(boolean)
|
||||||
|
shape: str = Parameter(literal("gaussian", "sech"))
|
||||||
|
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
|
||||||
|
intensity_noise: float = Parameter(in_range_incl(0, 1))
|
||||||
|
width: float = Parameter(in_range_excl(0, 1e-9))
|
||||||
|
t0: float = Parameter(in_range_excl(0, 1e-9))
|
||||||
|
|
||||||
|
# simulation
|
||||||
|
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
|
||||||
|
parallel: bool = Parameter(boolean)
|
||||||
|
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||||
|
ideal_gas: bool = Parameter(boolean)
|
||||||
|
repeat: int = Parameter(positive(int))
|
||||||
|
t_num: int = Parameter(positive(int))
|
||||||
|
z_num: int = Parameter(positive(int))
|
||||||
|
time_window: float = Parameter(positive(float, int))
|
||||||
|
dt: float = Parameter(in_range_excl(0, 5e-15))
|
||||||
|
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-5))
|
||||||
|
step_size: float = Parameter(positive(float, int))
|
||||||
|
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
|
||||||
|
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
|
||||||
|
frep: float = Parameter(positive(float, int))
|
||||||
|
prev_sim_dir: str = Parameter(string)
|
||||||
|
readjust_wavelength: bool = Parameter(boolean)
|
||||||
|
recovery_last_stored: int = Parameter(non_negative(int))
|
||||||
|
|
||||||
|
# computed
|
||||||
|
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
t: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
L_D: float = Parameter(non_negative(float, int))
|
||||||
|
L_NL: float = Parameter(non_negative(float, int))
|
||||||
|
L_sol: float = Parameter(non_negative(float, int))
|
||||||
|
dynamic_dispersion: bool = Parameter(boolean)
|
||||||
|
adapt_step_size: bool = Parameter(boolean)
|
||||||
|
error_ok: float = Parameter(positive(float))
|
||||||
|
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
beta_func: Callable[[float], List[float]] = Parameter(func_validator)
|
||||||
|
gamma_func: Callable[[float], float] = Parameter(func_validator)
|
||||||
|
|
||||||
|
def prepare_for_dump(self) -> Dict[str, Any]:
|
||||||
|
param = asdict(self)
|
||||||
|
param = BareParams.strip_params_dict(param)
|
||||||
|
param["datetime"] = datetime.datetime.now()
|
||||||
|
param["version"] = __version__
|
||||||
|
return param
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||||
|
(dropped because they take a lot of space and can be exactly reconstructed)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dico : dict
|
||||||
|
dictionary
|
||||||
|
"""
|
||||||
|
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
|
||||||
|
types = (np.ndarray, float, int, str, list, tuple, dict)
|
||||||
|
out = {}
|
||||||
|
for key, value in dico.items():
|
||||||
|
if key in forbiden_keys:
|
||||||
|
continue
|
||||||
|
if not isinstance(value, types):
|
||||||
|
continue
|
||||||
|
if isinstance(value, dict):
|
||||||
|
out[key] = BareParams.strip_params_dict(value)
|
||||||
|
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
out[key] = value
|
||||||
|
|
||||||
|
if "variable" in out and len(out["variable"]) == 0:
|
||||||
|
del out["variable"]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BareConfig(BareParams):
|
||||||
|
variable: dict = VariableParameter(BareParams)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
numero = type_checker(int)
|
||||||
|
|
||||||
|
@numero
|
||||||
|
def natural_number(name, n):
|
||||||
|
if n < 0:
|
||||||
|
raise ValueError(f"{name!r} must be positive")
|
||||||
|
|
||||||
|
try:
|
||||||
|
numero("a", np.arange(45))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
try:
|
||||||
|
natural_number("b", -1)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
try:
|
||||||
|
natural_number("c", 1.0)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
try:
|
||||||
|
natural_number("d", 1)
|
||||||
|
print("success !")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
37
testing/test_new_params.py
Normal file
37
testing/test_new_params.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from numba.core import config
|
||||||
|
from scgenerator.initialize import Config, Params, BareParams
|
||||||
|
from scgenerator.utils import variable_iterator, override_config
|
||||||
|
from scgenerator.io import load_toml
|
||||||
|
from pprint import pprint
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
dico = load_toml("testing/configs/ensure_consistency/good2.toml")
|
||||||
|
out = dict(variable=dict())
|
||||||
|
for k, v in dico.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for kk, vv in v.items():
|
||||||
|
if kk == "variable":
|
||||||
|
for kkk, vvv in vv.items():
|
||||||
|
out["variable"][kkk] = vvv
|
||||||
|
else:
|
||||||
|
out[kk] = vv
|
||||||
|
|
||||||
|
pprint(out)
|
||||||
|
p = Config(**out)
|
||||||
|
print(p)
|
||||||
|
|
||||||
|
for l, c in variable_iterator(p):
|
||||||
|
print(l, c.width, c.intensity_noise)
|
||||||
|
print()
|
||||||
|
|
||||||
|
config2 = override_config(dict(width=1.2e-13, variable=dict(peak_power=[1e5, 2e5])), p)
|
||||||
|
print(
|
||||||
|
f"{config2.variable=}",
|
||||||
|
f"{config2.intensity_noise=}",
|
||||||
|
f"{config2.width=}",
|
||||||
|
f"{config2.peak_power=}",
|
||||||
|
)
|
||||||
|
|
||||||
|
par = BareParams()
|
||||||
|
|
||||||
|
print(all(v is None for v in vars(par).values()))
|
||||||
Reference in New Issue
Block a user