better parser and params translation

This commit is contained in:
Benoît Sierro
2021-10-20 16:50:07 +02:00
parent 717a7db84f
commit 7345eb8fea
10 changed files with 152 additions and 184 deletions

View File

@@ -260,14 +260,8 @@ time_window: float
total length of the temporal grid in s
### optional
behaviors: list of str {"spm", "raman", "ss"}
spm is self-phase modulation
raman is raman effect
ss is self-steepening
default : ["spm", "ss"]
raman_type: str {"measured", "stolen", "agrawal"}
type of Raman effect. Default is "agrawal".
raman_type: str {"measured", "stolen", "agrawal"}, optional
type of Raman effect. Specifying this parameter has the effect of turning on Raman effect
ideal_gas: bool
if True, use the ideal gas law. Otherwise, use van der Waals equation. default : False
@@ -285,7 +279,7 @@ step_size: float
if given, sets a constant step size rather than adapting it.
parallel: bool
whether to run simulations in parallel with the available ressources. default : false
whether to run simulations in parallel with the available resources. default : false
repeat: int
how many simulations to run per parameter set. default : 1

View File

@@ -1,3 +1,4 @@
# flake8: noqa
from . import math
from .legacy import convert_sim_folder
from .math import abs2, argclosest, span

View File

@@ -64,7 +64,6 @@ VALID_VARIABLE = {
"width",
"t0",
"soliton_num",
"behaviors",
"raman_type",
"tolerated_error",
"step_size",
@@ -85,15 +84,23 @@ MANDATORY_PARAMETERS = [
"input_transmission",
"z_targets",
"length",
"beta2_coefficients",
"gamma_arr",
"behaviors",
"adapt_step_size",
"tolerated_error",
"dynamic_dispersion",
"recovery_last_stored",
"output_path",
"repeat",
"linear_operator",
"nonlinear_operator",
]
ROOT_PARAMETERS = [
"repeat",
"num",
"dt",
"t_num",
"time_window",
"step_size",
"tolerated_error",
"width",
"shape",
]

View File

@@ -349,7 +349,7 @@ default_rules: list[Rule] = [
Rule("A_eff", fiber.A_eff_marcatili, priorities=-2),
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
Rule("A_eff_arr", fiber.load_custom_A_eff),
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
# Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
Rule(
"V_eff",
fiber.V_parameter_koshiba,
@@ -364,6 +364,7 @@ default_rules: list[Rule] = [
["l", "core_radius", "numerical_aperture", "interpolation_range"],
),
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
Rule("gamma", fiber.gamma_parameter),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
Rule("n2", materials.gas_n2),
Rule("n2", lambda: 2.2e-20, priorities=-1),

View File

@@ -1,6 +1,6 @@
from genericpath import exists
import os
import sys
from collections import MutableMapping
from pathlib import Path
from typing import Any, Set
@@ -9,8 +9,8 @@ import toml
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .parameter import Configuration, Parameters
from .utils import save_parameters
from .pbar import PBars
from .utils import save_parameters
from .variationer import VariationDescriptor
@@ -87,6 +87,45 @@ def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int,
pbar.update()
def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
"""translate parameters name and value from older versions of the program
Parameters
----------
d : dict[str, Any]
[description]
Returns
-------
dict[str, Any]
[description]
"""
old_names = dict(
interp_degree="interpolation_degree",
beta="beta2_coefficients",
interp_range="interpolation_range",
)
wl_limits_old = ["lower_wavelength_interp_limit", "upper_wavelength_interp_limit"]
defaults_to_add = dict(repeat=1)
new = {}
if len(set(wl_limits_old) & d.keys()) == 2:
new["interpolation_range"] = (d[wl_limits_old[0]], d[wl_limits_old[1]])
for k, v in d.items():
if k == "error_ok":
new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
elif k == "behaviors":
beh = d["behaviors"]
if "raman" in beh:
new["raman_type"] = d["raman_type"]
new["spm"] = "spm" in beh
new["self_steepening"] = "ss" in beh
elif isinstance(v, MutableMapping):
new[k] = translate_parameters(v)
else:
new[old_names.get(k, k)] = v
return defaults_to_add | new
def main():
convert_sim_folder(sys.argv[1])

View File

@@ -1,7 +1,6 @@
from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from .cache import np_cache
@@ -172,50 +171,6 @@ def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
return indft_matrix(t, f) @ a
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
"""Interpolates a 2D array with the help of griddata
Parameters
----------
values : 2D array of real values
x_axis : x-coordinates of values
y_axis : y-coordinates of values
method : method of interpolation to be passed to griddata
Returns
----------
array of shape n
"""
xx, yy = np.meshgrid(x_axis, y_axis)
xx = xx.flatten()
yy = yy.flatten()
if not isinstance(n, tuple):
n = (n, n)
# old_points = np.array([gridx.ravel(), gridy.ravel()])
newx, newy = np.meshgrid(np.linspace(*span(x_axis), n[0]), np.linspace(*span(y_axis), n[1]))
print("interpolating")
out = griddata((xx, yy), values.flatten(), (newx, newy), method=method, fill_value=0)
print("interpolating done!")
return out.reshape(n[1], n[0])
def make_uniform_1D(values, x_axis, n=1024, method="linear"):
"""Interpolates a 2D array with the help of interp1d
Parameters
----------
values : 1D array of real values
x_axis : x-coordinates of values
method : method of interpolation to be passed to interp1d
Returns
----------
array of length n
"""
xx = np.linspace(*span(x_axis), len(x_axis))
return interp1d(x_axis, values, kind=method)(xx)
def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""find all the x values such that y(x)=0 with linear interpolation"""
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]

View File

@@ -17,7 +17,7 @@ from .physics import fiber, pulse
class SpectrumDescriptor:
name: str
value: np.ndarray
value: np.ndarray = None
def __set__(self, instance, value):
instance.spec2 = math.abs2(value)

View File

@@ -12,13 +12,13 @@ from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
import numpy as np
from . import env, utils
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .evaluator import Evaluator
from .logger import get_logger
from .operators import LinearOperator, NonLinearOperator
from .utils import fiber_folder, update_path_name
from .variationer import VariationDescriptor, Variationer
from .evaluator import Evaluator
from .operators import NonLinearOperator, LinearOperator
T = TypeVar("T")
@@ -312,11 +312,9 @@ class Parameters:
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# simulation
behaviors: tuple[str] = Parameter(
validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss")
)
parallel: bool = Parameter(boolean, default=True)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
self_steepening: bool = Parameter(boolean, default=True)
spm: bool = Parameter(boolean, default=True)
ideal_gas: bool = Parameter(boolean, default=False)
repeat: int = Parameter(positive(int), default=1)
t_num: int = Parameter(positive(int))
@@ -329,6 +327,7 @@ class Parameters:
interpolation_degree: int = Parameter(positive(int), default=8)
prev_sim_dir: str = Parameter(string)
recovery_last_stored: int = Parameter(non_negative(int), default=0)
parallel: bool = Parameter(boolean, default=True)
worker_num: int = Parameter(positive(int))
# computed
@@ -459,9 +458,9 @@ class Configuration:
obj with the output path of the simulation saved in its output_path attribute.
"""
fiber_configs: list[dict[str, Any]]
fiber_configs: list[utils.SubConfig]
vary_dicts: list[dict[str, list]]
master_config: dict[str, Any]
master_config_dict: dict[str, Any]
fiber_paths: list[Path]
num_sim: int
num_fibers: int
@@ -515,51 +514,47 @@ class Configuration:
mkdir=False,
prevent_overwrite=not self.overwrite,
)
self.master_config = self.fiber_configs[0].copy()
self.master_config_dict = self.fiber_configs[0].fixed | {
k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items()
}
self.name = self.final_path.name
self.z_num = 0
self.total_num_steps = 0
self.fiber_paths = []
self.all_configs = {}
self.skip_callback = skip_callback
self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_config.get("repeat", 1)
self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_config_dict.get("repeat", 1)
self.variationer = Variationer()
fiber_names = set()
self.num_fibers = 0
for i, config in enumerate(self.fiber_configs):
config.setdefault("name", Parameters.name.default)
self.z_num += config["z_num"]
fiber_names.add(config["name"])
vary_dict_list: list[dict[str, list]] = config.pop("variable")
self.variationer.append(vary_dict_list)
config.fixed.setdefault("name", Parameters.name.default)
self.z_num += config.fixed["z_num"]
fiber_names.add(config.fixed["name"])
self.variationer.append(config.variable)
self.fiber_paths.append(
utils.ensure_folder(
self.final_path / fiber_folder(i, self.name, config["name"]),
self.final_path / fiber_folder(i, self.name, config.fixed["name"]),
mkdir=False,
prevent_overwrite=not self.overwrite,
)
)
self.__validate_variable(vary_dict_list)
self.__validate_variable(config.variable)
self.num_fibers += 1
Evaluator.evaluate_default(
self.__build_base_config()
| config
| {k: v[0] for vary_dict in vary_dict_list for k, v in vary_dict.items()},
self.master_config_dict
| config.fixed
| {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()},
True,
)
self.num_sim = self.variationer.var_num()
self.total_num_steps = sum(
config["z_num"] * self.variationer.var_num(i)
config.fixed["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.fiber_configs)
)
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
def __build_base_config(self):
cfg = self.master_config.copy()
vary: list[dict[str, list]] = cfg.pop("variable")
return cfg | {k: v[0] for vary_dict in vary for k, v in vary_dict.items()}
self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default)
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list:
@@ -593,7 +588,7 @@ class Configuration:
index = self.num_fibers + index
sim_dict: dict[Path, Configuration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index])
cfg = descriptor.update_config(self.fiber_configs[index].fixed)
if index > 0:
cfg["prev_data_dir"] = str(
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
@@ -611,7 +606,8 @@ class Configuration:
task, config_dict = self.__decide(sim_config)
if task == self.Action.RUN:
sim_dict.pop(data_dir)
yield sim_config.descriptor, Parameters(**sim_config.config)
param_dict = legacy.translate_parameters(sim_config.config)
yield sim_config.descriptor, Parameters(**param_dict)
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"])
break
@@ -695,10 +691,7 @@ class Configuration:
def save_parameters(self):
os.makedirs(self.final_path, exist_ok=True)
cfgs = [
cfg | dict(variable=self.variationer.all_dicts[i])
for i, cfg in enumerate(self.fiber_configs)
]
cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs]
utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs))
@property

View File

@@ -14,14 +14,15 @@ from dataclasses import dataclass
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set
import numpy as np
import pkg_resources as pkg
import toml
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN, ROOT_PARAMETERS
from .logger import get_logger
from .errors import DuplicateParameterError
T_ = TypeVar("T_")
@@ -74,39 +75,51 @@ class Paths:
return os.path.join(cls.get("plots"), name)
class ConfigFileParser:
path: Path
repeat: int
master: ConfigFileParser.SubConfig
configs: list[ConfigFileParser.SubConfig]
@dataclass
@dataclass(init=False)
class SubConfig:
fixed: dict[str, Any]
variable: dict[str, list]
variable: list[dict[str, list]]
fixed_keys: Set[str]
variable_keys: Set[str]
def __init__(self, path: os.PathLike):
self.path = Path(path)
fiber_list: list[dict[str, Any]]
if self.path.name.lower().endswith(".toml"):
loaded_config = _open_config(self.path)
fiber_list = loaded_config.pop("Fiber")
else:
loaded_config = dict(name=self.path.name)
fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))]
def __init__(self, dico: dict[str, Any]):
dico = dico.copy()
self.variable = conform_variable_entry(dico.pop("variable", []))
self.fixed = dico
self.__update
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {self.path}")
configs = []
for i, params in enumerate(fiber_list):
configs.append(loaded_config | params)
for root_vary, first_vary in itertools.product(
loaded_config["variable"], configs[0]["variable"]
):
if len(common := root_vary.keys() & first_vary.keys()) != 0:
raise ValueError(f"These variable keys are specified twice : {common!r}")
configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
def __update(self):
self.variable_keys = set()
self.fixed_keys = set()
for dico in self.variable:
for key in dico:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.variable_keys.add(key)
for key in self.fixed:
if key in self.variable_keys:
raise DuplicateParameterError(f"{key} is specified twice")
self.fixed_keys.add(key)
def weak_update(self, other: SubConfig = None, **kwargs):
"""similar to a dict update method put prioritizes existing values
Parameters
----------
other : SubConfig
other obj
"""
if other is None:
other = SubConfig(kwargs)
self.fixed = other.fixed | self.fixed
self.variable = other.variable + self.variable
self.__update()
def conform_variable_entry(d) -> list[dict[str, list]]:
if isinstance(d, MutableMapping):
d = [{k: v} for k, v in d.items()]
return d
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
@@ -141,23 +154,11 @@ def _open_config(path: os.PathLike):
path = conform_toml_path(path)
dico = resolve_loadfile_arg(load_toml(path))
dico = standardize_variable_dicts(dico)
if "Fiber" not in dico:
dico = dict(name=path.name, Fiber=[dico])
return dico
def standardize_variable_dicts(dico: dict[str, Any]):
if "Fiber" in dico:
dico["Fiber"] = [standardize_variable_dicts(fiber) for fiber in dico["Fiber"]]
if (var := dico.get("variable")) is not None:
if isinstance(var, MutableMapping):
dico["variable"] = [var]
else:
dico["variable"] = [{}]
return dico
def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
if (f_list := dico.pop("INCLUDE", None)) is not None:
if isinstance(f_list, str):
@@ -196,7 +197,7 @@ def save_toml(path: os.PathLike, dico):
return dico
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
"""loads a configuration file
Parameters
@@ -213,28 +214,26 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]
"""
path = Path(path)
fiber_list: list[dict[str, Any]]
if path.name.lower().endswith(".toml"):
loaded_config = _open_config(path)
fiber_list = loaded_config.pop("Fiber")
master_config_dict = _open_config(path)
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
master_config = SubConfig(master_config_dict)
else:
loaded_config = dict(name=path.name)
fiber_list = [_open_config(p) for p in sorted(path.glob("initial_config*.toml"))]
master_config = SubConfig(dict(name=path.name))
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {path}")
final_path = loaded_config.get("name")
configs = []
for i, params in enumerate(fiber_list):
configs.append(loaded_config | params)
for root_vary, first_vary in itertools.product(
loaded_config["variable"], configs[0]["variable"]
):
if len(common := root_vary.keys() & first_vary.keys()) != 0:
raise ValueError(f"These variable keys are specified twice : {common!r}")
configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
return Path(final_path), configs
for fiber in fiber_list:
fiber.weak_update(master_config)
if "num" not in fiber_list[0].variable_keys:
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
for p_name in ROOT_PARAMETERS:
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
configs = fiber_list
return Path(master_config.fixed["name"]), configs
@cache
@@ -340,27 +339,6 @@ def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray
)
def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
old_names = dict(
interp_degree="interpolation_degree",
beta="beta2_coefficients",
interp_range="interpolation_range",
)
deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
defaults_to_add = dict(repeat=1)
new = {}
for k, v in d.items():
if k == "error_ok":
new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
elif k in deleted_names:
continue
elif isinstance(v, MutableMapping):
new[k] = translate_parameters(v)
else:
new[old_names.get(k, k)] = v
return defaults_to_add | new
def to_62(i: int) -> str:
arr = []
if i == 0:
@@ -445,7 +423,7 @@ def combine_simulations(path: Path, dest: Path = None):
for l in paths.values():
try:
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
except ValueError:
except TypeError:
pass
for pulses in paths.values():
new_path = dest / update_path_name(pulses[0].name)

View File

@@ -79,7 +79,7 @@ class Variationer:
len_to_test = len(values[0])
if not all(len(v) == len_to_test for v in values[1:]):
raise VariationSpecsError(
f"variable items should all have the same number of parameters"
"variable items should all have the same number of parameters"
)
num_vars.append(len_to_test)
if len(num_vars) == 0: