better parser and params translation
This commit is contained in:
12
README.md
12
README.md
@@ -260,14 +260,8 @@ time_window: float
|
||||
total length of the temporal grid in s
|
||||
|
||||
### optional
|
||||
behaviors: list of str {"spm", "raman", "ss"}
|
||||
spm is self-phase modulation
|
||||
raman is raman effect
|
||||
ss is self-steepening
|
||||
default : ["spm", "ss"]
|
||||
|
||||
raman_type: str {"measured", "stolen", "agrawal"}
|
||||
type of Raman effect. Default is "agrawal".
|
||||
raman_type: str {"measured", "stolen", "agrawal"}, optional
|
||||
type of Raman effect. Specifying this parameter has the effect of turning on Raman effect
|
||||
|
||||
ideal_gas: bool
|
||||
if True, use the ideal gas law. Otherwise, use van der Waals equation. default : False
|
||||
@@ -285,7 +279,7 @@ step_size: float
|
||||
if given, sets a constant step size rather than adapting it.
|
||||
|
||||
parallel: bool
|
||||
whether to run simulations in parallel with the available ressources. default : false
|
||||
whether to run simulations in parallel with the available resources. default : false
|
||||
|
||||
repeat: int
|
||||
how many simulations to run per parameter set. default : 1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# flake8: noqa
|
||||
from . import math
|
||||
from .legacy import convert_sim_folder
|
||||
from .math import abs2, argclosest, span
|
||||
|
||||
@@ -64,7 +64,6 @@ VALID_VARIABLE = {
|
||||
"width",
|
||||
"t0",
|
||||
"soliton_num",
|
||||
"behaviors",
|
||||
"raman_type",
|
||||
"tolerated_error",
|
||||
"step_size",
|
||||
@@ -85,15 +84,23 @@ MANDATORY_PARAMETERS = [
|
||||
"input_transmission",
|
||||
"z_targets",
|
||||
"length",
|
||||
"beta2_coefficients",
|
||||
"gamma_arr",
|
||||
"behaviors",
|
||||
"adapt_step_size",
|
||||
"tolerated_error",
|
||||
"dynamic_dispersion",
|
||||
"recovery_last_stored",
|
||||
"output_path",
|
||||
"repeat",
|
||||
"linear_operator",
|
||||
"nonlinear_operator",
|
||||
]
|
||||
|
||||
ROOT_PARAMETERS = [
|
||||
"repeat",
|
||||
"num",
|
||||
"dt",
|
||||
"t_num",
|
||||
"time_window",
|
||||
"step_size",
|
||||
"tolerated_error",
|
||||
"width",
|
||||
"shape",
|
||||
]
|
||||
|
||||
@@ -349,7 +349,7 @@ default_rules: list[Rule] = [
|
||||
Rule("A_eff", fiber.A_eff_marcatili, priorities=-2),
|
||||
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
|
||||
Rule("A_eff_arr", fiber.load_custom_A_eff),
|
||||
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
||||
# Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
||||
Rule(
|
||||
"V_eff",
|
||||
fiber.V_parameter_koshiba,
|
||||
@@ -364,6 +364,7 @@ default_rules: list[Rule] = [
|
||||
["l", "core_radius", "numerical_aperture", "interpolation_range"],
|
||||
),
|
||||
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
|
||||
Rule("gamma", fiber.gamma_parameter),
|
||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
|
||||
Rule("n2", materials.gas_n2),
|
||||
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from genericpath import exists
|
||||
import os
|
||||
import sys
|
||||
from collections import MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
@@ -9,8 +9,8 @@ import toml
|
||||
|
||||
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
|
||||
from .parameter import Configuration, Parameters
|
||||
from .utils import save_parameters
|
||||
from .pbar import PBars
|
||||
from .utils import save_parameters
|
||||
from .variationer import VariationDescriptor
|
||||
|
||||
|
||||
@@ -87,6 +87,45 @@ def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int,
|
||||
pbar.update()
|
||||
|
||||
|
||||
def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
|
||||
"""translate parameters name and value from older versions of the program
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d : dict[str, Any]
|
||||
[description]
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]
|
||||
[description]
|
||||
"""
|
||||
old_names = dict(
|
||||
interp_degree="interpolation_degree",
|
||||
beta="beta2_coefficients",
|
||||
interp_range="interpolation_range",
|
||||
)
|
||||
wl_limits_old = ["lower_wavelength_interp_limit", "upper_wavelength_interp_limit"]
|
||||
defaults_to_add = dict(repeat=1)
|
||||
new = {}
|
||||
if len(set(wl_limits_old) & d.keys()) == 2:
|
||||
new["interpolation_range"] = (d[wl_limits_old[0]], d[wl_limits_old[1]])
|
||||
for k, v in d.items():
|
||||
if k == "error_ok":
|
||||
new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
|
||||
elif k == "behaviors":
|
||||
beh = d["behaviors"]
|
||||
if "raman" in beh:
|
||||
new["raman_type"] = d["raman_type"]
|
||||
new["spm"] = "spm" in beh
|
||||
new["self_steepening"] = "ss" in beh
|
||||
elif isinstance(v, MutableMapping):
|
||||
new[k] = translate_parameters(v)
|
||||
else:
|
||||
new[old_names.get(k, k)] = v
|
||||
return defaults_to_add | new
|
||||
|
||||
|
||||
def main():
|
||||
convert_sim_folder(sys.argv[1])
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import griddata, interp1d
|
||||
from scipy.special import jn_zeros
|
||||
from .cache import np_cache
|
||||
|
||||
@@ -172,50 +171,6 @@ def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||
return indft_matrix(t, f) @ a
|
||||
|
||||
|
||||
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
|
||||
"""Interpolates a 2D array with the help of griddata
|
||||
Parameters
|
||||
----------
|
||||
values : 2D array of real values
|
||||
x_axis : x-coordinates of values
|
||||
y_axis : y-coordinates of values
|
||||
method : method of interpolation to be passed to griddata
|
||||
Returns
|
||||
----------
|
||||
array of shape n
|
||||
"""
|
||||
xx, yy = np.meshgrid(x_axis, y_axis)
|
||||
xx = xx.flatten()
|
||||
yy = yy.flatten()
|
||||
|
||||
if not isinstance(n, tuple):
|
||||
n = (n, n)
|
||||
|
||||
# old_points = np.array([gridx.ravel(), gridy.ravel()])
|
||||
|
||||
newx, newy = np.meshgrid(np.linspace(*span(x_axis), n[0]), np.linspace(*span(y_axis), n[1]))
|
||||
|
||||
print("interpolating")
|
||||
out = griddata((xx, yy), values.flatten(), (newx, newy), method=method, fill_value=0)
|
||||
print("interpolating done!")
|
||||
return out.reshape(n[1], n[0])
|
||||
|
||||
|
||||
def make_uniform_1D(values, x_axis, n=1024, method="linear"):
|
||||
"""Interpolates a 2D array with the help of interp1d
|
||||
Parameters
|
||||
----------
|
||||
values : 1D array of real values
|
||||
x_axis : x-coordinates of values
|
||||
method : method of interpolation to be passed to interp1d
|
||||
Returns
|
||||
----------
|
||||
array of length n
|
||||
"""
|
||||
xx = np.linspace(*span(x_axis), len(x_axis))
|
||||
return interp1d(x_axis, values, kind=method)(xx)
|
||||
|
||||
|
||||
def all_zeros(x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""find all the x values such that y(x)=0 with linear interpolation"""
|
||||
pos = np.argwhere(y[1:] * y[:-1] < 0)[:, 0]
|
||||
|
||||
@@ -17,7 +17,7 @@ from .physics import fiber, pulse
|
||||
|
||||
class SpectrumDescriptor:
|
||||
name: str
|
||||
value: np.ndarray
|
||||
value: np.ndarray = None
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.spec2 = math.abs2(value)
|
||||
|
||||
@@ -12,13 +12,13 @@ from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import env, utils
|
||||
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
|
||||
from . import env, legacy, utils
|
||||
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
||||
from .evaluator import Evaluator
|
||||
from .logger import get_logger
|
||||
from .operators import LinearOperator, NonLinearOperator
|
||||
from .utils import fiber_folder, update_path_name
|
||||
from .variationer import VariationDescriptor, Variationer
|
||||
from .evaluator import Evaluator
|
||||
from .operators import NonLinearOperator, LinearOperator
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -312,11 +312,9 @@ class Parameters:
|
||||
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
|
||||
# simulation
|
||||
behaviors: tuple[str] = Parameter(
|
||||
validator_list(literal("spm", "raman", "ss")), converter=tuple, default=("spm", "ss")
|
||||
)
|
||||
parallel: bool = Parameter(boolean, default=True)
|
||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||
self_steepening: bool = Parameter(boolean, default=True)
|
||||
spm: bool = Parameter(boolean, default=True)
|
||||
ideal_gas: bool = Parameter(boolean, default=False)
|
||||
repeat: int = Parameter(positive(int), default=1)
|
||||
t_num: int = Parameter(positive(int))
|
||||
@@ -329,6 +327,7 @@ class Parameters:
|
||||
interpolation_degree: int = Parameter(positive(int), default=8)
|
||||
prev_sim_dir: str = Parameter(string)
|
||||
recovery_last_stored: int = Parameter(non_negative(int), default=0)
|
||||
parallel: bool = Parameter(boolean, default=True)
|
||||
worker_num: int = Parameter(positive(int))
|
||||
|
||||
# computed
|
||||
@@ -459,9 +458,9 @@ class Configuration:
|
||||
obj with the output path of the simulation saved in its output_path attribute.
|
||||
"""
|
||||
|
||||
fiber_configs: list[dict[str, Any]]
|
||||
fiber_configs: list[utils.SubConfig]
|
||||
vary_dicts: list[dict[str, list]]
|
||||
master_config: dict[str, Any]
|
||||
master_config_dict: dict[str, Any]
|
||||
fiber_paths: list[Path]
|
||||
num_sim: int
|
||||
num_fibers: int
|
||||
@@ -515,51 +514,47 @@ class Configuration:
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
self.master_config = self.fiber_configs[0].copy()
|
||||
self.master_config_dict = self.fiber_configs[0].fixed | {
|
||||
k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items()
|
||||
}
|
||||
self.name = self.final_path.name
|
||||
self.z_num = 0
|
||||
self.total_num_steps = 0
|
||||
self.fiber_paths = []
|
||||
self.all_configs = {}
|
||||
self.skip_callback = skip_callback
|
||||
self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
|
||||
self.repeat = self.master_config.get("repeat", 1)
|
||||
self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2))
|
||||
self.repeat = self.master_config_dict.get("repeat", 1)
|
||||
self.variationer = Variationer()
|
||||
|
||||
fiber_names = set()
|
||||
self.num_fibers = 0
|
||||
for i, config in enumerate(self.fiber_configs):
|
||||
config.setdefault("name", Parameters.name.default)
|
||||
self.z_num += config["z_num"]
|
||||
fiber_names.add(config["name"])
|
||||
vary_dict_list: list[dict[str, list]] = config.pop("variable")
|
||||
self.variationer.append(vary_dict_list)
|
||||
config.fixed.setdefault("name", Parameters.name.default)
|
||||
self.z_num += config.fixed["z_num"]
|
||||
fiber_names.add(config.fixed["name"])
|
||||
self.variationer.append(config.variable)
|
||||
self.fiber_paths.append(
|
||||
utils.ensure_folder(
|
||||
self.final_path / fiber_folder(i, self.name, config["name"]),
|
||||
self.final_path / fiber_folder(i, self.name, config.fixed["name"]),
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
)
|
||||
self.__validate_variable(vary_dict_list)
|
||||
self.__validate_variable(config.variable)
|
||||
self.num_fibers += 1
|
||||
Evaluator.evaluate_default(
|
||||
self.__build_base_config()
|
||||
| config
|
||||
| {k: v[0] for vary_dict in vary_dict_list for k, v in vary_dict.items()},
|
||||
self.master_config_dict
|
||||
| config.fixed
|
||||
| {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()},
|
||||
True,
|
||||
)
|
||||
self.num_sim = self.variationer.var_num()
|
||||
self.total_num_steps = sum(
|
||||
config["z_num"] * self.variationer.var_num(i)
|
||||
config.fixed["z_num"] * self.variationer.var_num(i)
|
||||
for i, config in enumerate(self.fiber_configs)
|
||||
)
|
||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||
|
||||
def __build_base_config(self):
|
||||
cfg = self.master_config.copy()
|
||||
vary: list[dict[str, list]] = cfg.pop("variable")
|
||||
return cfg | {k: v[0] for vary_dict in vary for k, v in vary_dict.items()}
|
||||
self.parallel = self.master_config_dict.get("parallel", Parameters.parallel.default)
|
||||
|
||||
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
|
||||
for vary_dict in vary_dict_list:
|
||||
@@ -593,7 +588,7 @@ class Configuration:
|
||||
index = self.num_fibers + index
|
||||
sim_dict: dict[Path, Configuration.__SimConfig] = {}
|
||||
for descriptor in self.variationer.iterate(index):
|
||||
cfg = descriptor.update_config(self.fiber_configs[index])
|
||||
cfg = descriptor.update_config(self.fiber_configs[index].fixed)
|
||||
if index > 0:
|
||||
cfg["prev_data_dir"] = str(
|
||||
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
|
||||
@@ -611,7 +606,8 @@ class Configuration:
|
||||
task, config_dict = self.__decide(sim_config)
|
||||
if task == self.Action.RUN:
|
||||
sim_dict.pop(data_dir)
|
||||
yield sim_config.descriptor, Parameters(**sim_config.config)
|
||||
param_dict = legacy.translate_parameters(sim_config.config)
|
||||
yield sim_config.descriptor, Parameters(**param_dict)
|
||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||
self.skip_callback(config_dict["recovery_last_stored"])
|
||||
break
|
||||
@@ -695,10 +691,7 @@ class Configuration:
|
||||
|
||||
def save_parameters(self):
|
||||
os.makedirs(self.final_path, exist_ok=True)
|
||||
cfgs = [
|
||||
cfg | dict(variable=self.variationer.all_dicts[i])
|
||||
for i, cfg in enumerate(self.fiber_configs)
|
||||
]
|
||||
cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs]
|
||||
utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs))
|
||||
|
||||
@property
|
||||
|
||||
@@ -14,14 +14,15 @@ from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from string import printable as str_printable
|
||||
from typing import Any, Callable, MutableMapping, Sequence, TypeVar
|
||||
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set
|
||||
|
||||
import numpy as np
|
||||
import pkg_resources as pkg
|
||||
import toml
|
||||
|
||||
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
|
||||
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN, ROOT_PARAMETERS
|
||||
from .logger import get_logger
|
||||
from .errors import DuplicateParameterError
|
||||
|
||||
T_ = TypeVar("T_")
|
||||
|
||||
@@ -74,39 +75,51 @@ class Paths:
|
||||
return os.path.join(cls.get("plots"), name)
|
||||
|
||||
|
||||
class ConfigFileParser:
|
||||
path: Path
|
||||
repeat: int
|
||||
master: ConfigFileParser.SubConfig
|
||||
configs: list[ConfigFileParser.SubConfig]
|
||||
|
||||
@dataclass
|
||||
@dataclass(init=False)
|
||||
class SubConfig:
|
||||
fixed: dict[str, Any]
|
||||
variable: dict[str, list]
|
||||
variable: list[dict[str, list]]
|
||||
fixed_keys: Set[str]
|
||||
variable_keys: Set[str]
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
self.path = Path(path)
|
||||
fiber_list: list[dict[str, Any]]
|
||||
if self.path.name.lower().endswith(".toml"):
|
||||
loaded_config = _open_config(self.path)
|
||||
fiber_list = loaded_config.pop("Fiber")
|
||||
else:
|
||||
loaded_config = dict(name=self.path.name)
|
||||
fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))]
|
||||
def __init__(self, dico: dict[str, Any]):
|
||||
dico = dico.copy()
|
||||
self.variable = conform_variable_entry(dico.pop("variable", []))
|
||||
self.fixed = dico
|
||||
self.__update
|
||||
|
||||
if len(fiber_list) == 0:
|
||||
raise ValueError(f"No fiber in config {self.path}")
|
||||
configs = []
|
||||
for i, params in enumerate(fiber_list):
|
||||
configs.append(loaded_config | params)
|
||||
for root_vary, first_vary in itertools.product(
|
||||
loaded_config["variable"], configs[0]["variable"]
|
||||
):
|
||||
if len(common := root_vary.keys() & first_vary.keys()) != 0:
|
||||
raise ValueError(f"These variable keys are specified twice : {common!r}")
|
||||
configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
|
||||
configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
|
||||
def __update(self):
|
||||
self.variable_keys = set()
|
||||
self.fixed_keys = set()
|
||||
for dico in self.variable:
|
||||
for key in dico:
|
||||
if key in self.variable_keys:
|
||||
raise DuplicateParameterError(f"{key} is specified twice")
|
||||
self.variable_keys.add(key)
|
||||
for key in self.fixed:
|
||||
if key in self.variable_keys:
|
||||
raise DuplicateParameterError(f"{key} is specified twice")
|
||||
self.fixed_keys.add(key)
|
||||
|
||||
def weak_update(self, other: SubConfig = None, **kwargs):
|
||||
"""similar to a dict update method put prioritizes existing values
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other : SubConfig
|
||||
other obj
|
||||
"""
|
||||
if other is None:
|
||||
other = SubConfig(kwargs)
|
||||
self.fixed = other.fixed | self.fixed
|
||||
self.variable = other.variable + self.variable
|
||||
self.__update()
|
||||
|
||||
|
||||
def conform_variable_entry(d) -> list[dict[str, list]]:
|
||||
if isinstance(d, MutableMapping):
|
||||
d = [{k: v} for k, v in d.items()]
|
||||
return d
|
||||
|
||||
|
||||
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
|
||||
@@ -141,23 +154,11 @@ def _open_config(path: os.PathLike):
|
||||
path = conform_toml_path(path)
|
||||
dico = resolve_loadfile_arg(load_toml(path))
|
||||
|
||||
dico = standardize_variable_dicts(dico)
|
||||
if "Fiber" not in dico:
|
||||
dico = dict(name=path.name, Fiber=[dico])
|
||||
return dico
|
||||
|
||||
|
||||
def standardize_variable_dicts(dico: dict[str, Any]):
|
||||
if "Fiber" in dico:
|
||||
dico["Fiber"] = [standardize_variable_dicts(fiber) for fiber in dico["Fiber"]]
|
||||
if (var := dico.get("variable")) is not None:
|
||||
if isinstance(var, MutableMapping):
|
||||
dico["variable"] = [var]
|
||||
else:
|
||||
dico["variable"] = [{}]
|
||||
return dico
|
||||
|
||||
|
||||
def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
|
||||
if (f_list := dico.pop("INCLUDE", None)) is not None:
|
||||
if isinstance(f_list, str):
|
||||
@@ -196,7 +197,7 @@ def save_toml(path: os.PathLike, dico):
|
||||
return dico
|
||||
|
||||
|
||||
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
|
||||
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[SubConfig]]:
|
||||
"""loads a configuration file
|
||||
|
||||
Parameters
|
||||
@@ -213,28 +214,26 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]
|
||||
|
||||
"""
|
||||
path = Path(path)
|
||||
fiber_list: list[dict[str, Any]]
|
||||
if path.name.lower().endswith(".toml"):
|
||||
loaded_config = _open_config(path)
|
||||
fiber_list = loaded_config.pop("Fiber")
|
||||
master_config_dict = _open_config(path)
|
||||
fiber_list = [SubConfig(d) for d in master_config_dict.pop("Fiber")]
|
||||
master_config = SubConfig(master_config_dict)
|
||||
else:
|
||||
loaded_config = dict(name=path.name)
|
||||
fiber_list = [_open_config(p) for p in sorted(path.glob("initial_config*.toml"))]
|
||||
master_config = SubConfig(dict(name=path.name))
|
||||
fiber_list = [SubConfig(_open_config(p)) for p in sorted(path.glob("initial_config*.toml"))]
|
||||
|
||||
if len(fiber_list) == 0:
|
||||
raise ValueError(f"No fiber in config {path}")
|
||||
final_path = loaded_config.get("name")
|
||||
configs = []
|
||||
for i, params in enumerate(fiber_list):
|
||||
configs.append(loaded_config | params)
|
||||
for root_vary, first_vary in itertools.product(
|
||||
loaded_config["variable"], configs[0]["variable"]
|
||||
):
|
||||
if len(common := root_vary.keys() & first_vary.keys()) != 0:
|
||||
raise ValueError(f"These variable keys are specified twice : {common!r}")
|
||||
configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
|
||||
configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
|
||||
return Path(final_path), configs
|
||||
for fiber in fiber_list:
|
||||
fiber.weak_update(master_config)
|
||||
if "num" not in fiber_list[0].variable_keys:
|
||||
repeat_arg = list(range(fiber_list[0].fixed.get("repeat", 1)))
|
||||
fiber_list[0].weak_update(variable=dict(num=repeat_arg))
|
||||
for p_name in ROOT_PARAMETERS:
|
||||
if any(p_name in conf.variable_keys for conf in fiber_list[1:]):
|
||||
raise ValueError(f"{p_name} should only be specified in the root or first fiber")
|
||||
configs = fiber_list
|
||||
return Path(master_config.fixed["name"]), configs
|
||||
|
||||
|
||||
@cache
|
||||
@@ -340,27 +339,6 @@ def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray
|
||||
)
|
||||
|
||||
|
||||
def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
|
||||
old_names = dict(
|
||||
interp_degree="interpolation_degree",
|
||||
beta="beta2_coefficients",
|
||||
interp_range="interpolation_range",
|
||||
)
|
||||
deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
|
||||
defaults_to_add = dict(repeat=1)
|
||||
new = {}
|
||||
for k, v in d.items():
|
||||
if k == "error_ok":
|
||||
new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
|
||||
elif k in deleted_names:
|
||||
continue
|
||||
elif isinstance(v, MutableMapping):
|
||||
new[k] = translate_parameters(v)
|
||||
else:
|
||||
new[old_names.get(k, k)] = v
|
||||
return defaults_to_add | new
|
||||
|
||||
|
||||
def to_62(i: int) -> str:
|
||||
arr = []
|
||||
if i == 0:
|
||||
@@ -445,7 +423,7 @@ def combine_simulations(path: Path, dest: Path = None):
|
||||
for l in paths.values():
|
||||
try:
|
||||
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
|
||||
except ValueError:
|
||||
except TypeError:
|
||||
pass
|
||||
for pulses in paths.values():
|
||||
new_path = dest / update_path_name(pulses[0].name)
|
||||
|
||||
@@ -79,7 +79,7 @@ class Variationer:
|
||||
len_to_test = len(values[0])
|
||||
if not all(len(v) == len_to_test for v in values[1:]):
|
||||
raise VariationSpecsError(
|
||||
f"variable items should all have the same number of parameters"
|
||||
"variable items should all have the same number of parameters"
|
||||
)
|
||||
num_vars.append(len_to_test)
|
||||
if len(num_vars) == 0:
|
||||
|
||||
Reference in New Issue
Block a user