parameter computation working
This commit is contained in:
12
play.py
12
play.py
@@ -1,8 +1,11 @@
|
||||
from dataclasses import fields
|
||||
from scgenerator import Parameters
|
||||
from scgenerator.physics.simulate import RK4IP
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def main():
|
||||
cwd = os.getcwd()
|
||||
@@ -10,12 +13,11 @@ def main():
|
||||
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
|
||||
|
||||
pa = Parameters.load(
|
||||
"/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM1550_RIN.toml"
|
||||
"/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM2000D.toml"
|
||||
)
|
||||
|
||||
plt.plot(pa.t, pa.field_0.imag)
|
||||
plt.plot(pa.t, pa.field_0.real)
|
||||
plt.show()
|
||||
x = 1, 2
|
||||
print(pa.input_transmission)
|
||||
print(x)
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
@@ -12,4 +12,4 @@ from .physics.simulate import RK4IP, new_simulation, resume_simulations
|
||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||
from .spectra import Pulse, Spectrum
|
||||
from .utils import Paths, load_toml
|
||||
from .utils.parameter import BareConfig, Parameters, PlotRange
|
||||
from .utils.parameter import Config, Parameters, PlotRange
|
||||
|
||||
@@ -11,183 +11,16 @@ from . import utils
|
||||
from .errors import *
|
||||
from .logger import get_logger
|
||||
from .utils.parameter import (
|
||||
BareConfig,
|
||||
Config,
|
||||
Parameters,
|
||||
hc_model_specific_parameters,
|
||||
override_config,
|
||||
required_simulations,
|
||||
)
|
||||
from scgenerator.utils import parameter
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(BareConfig):
|
||||
@classmethod
|
||||
def from_bare(cls, bare: BareConfig):
|
||||
return cls(**asdict(bare))
|
||||
|
||||
def __post_init__(self):
|
||||
for p_name, value in self.__dict__.items():
|
||||
if value is not None and p_name in self.variable:
|
||||
raise DuplicateParameterError(f"got multiple values for parameter {p_name!r}")
|
||||
self.setdefault("name", "no name")
|
||||
self.fiber_consistency()
|
||||
if self.model in hc_model_specific_parameters:
|
||||
self.gas_consistency()
|
||||
self.pulse_consistency()
|
||||
self.simulation_consistency()
|
||||
|
||||
def fiber_consistency(self):
|
||||
|
||||
if self.contains("dispersion_file") or self.contains("beta2_coefficients"):
|
||||
if not (
|
||||
self.contains("A_eff")
|
||||
or self.contains("A_eff_file")
|
||||
or self.contains("effective_mode_diameter")
|
||||
):
|
||||
self.get("gamma", specified_parameters=["custom fiber model"])
|
||||
self.get("n2", specified_parameters=["custom fiber model"])
|
||||
self.setdefault("model", "custom")
|
||||
|
||||
else:
|
||||
self.get("model")
|
||||
|
||||
if self.model == "pcf":
|
||||
self.get_fiber("pitch")
|
||||
self.get_fiber("pitch_ratio")
|
||||
|
||||
elif self.model == "hasan":
|
||||
self.get_multiple(
|
||||
["capillary_spacing", "capillary_outer_d"], 1, fiber_model="hasan"
|
||||
)
|
||||
for param in [
|
||||
"core_radius",
|
||||
"capillary_num",
|
||||
"capillary_thickness",
|
||||
"capillary_resonance_strengths",
|
||||
"capillary_nested",
|
||||
]:
|
||||
self.get_fiber(param)
|
||||
else:
|
||||
for param in hc_model_specific_parameters[self.model]:
|
||||
self.get_fiber(param)
|
||||
if self.contains("loss"):
|
||||
if self.loss == "capillary":
|
||||
for param in ["core_radius", "he_mode"]:
|
||||
self.get_fiber(param)
|
||||
for param in ["length", "input_transmission"]:
|
||||
self.get(param)
|
||||
|
||||
def gas_consistency(self):
|
||||
for param in ["gas_name", "temperature", "pressure", "plasma_density"]:
|
||||
self.get(param, specified_params=["gas"])
|
||||
|
||||
def pulse_consistency(self):
|
||||
for param in ["wavelength", "quantum_noise", "intensity_noise"]:
|
||||
self.get(param)
|
||||
|
||||
if not self.contains("field_file"):
|
||||
self.get("shape")
|
||||
|
||||
if self.contains("soliton_num"):
|
||||
self.get_multiple(
|
||||
["peak_power", "mean_power", "energy", "width", "t0"],
|
||||
1,
|
||||
specified_parameters=["soliton_num"],
|
||||
)
|
||||
|
||||
else:
|
||||
self.get_multiple(["t0", "width"], 1)
|
||||
self.get_multiple(["peak_power", "energy", "mean_power"], 1)
|
||||
if self.contains("mean_power"):
|
||||
self.get("repetition_rate", specified_parameters=["mean_power"])
|
||||
|
||||
def simulation_consistency(self):
|
||||
self.get_multiple(["dt", "t_num", "time_window"], 2)
|
||||
|
||||
for param in [
|
||||
"behaviors",
|
||||
"z_num",
|
||||
"tolerated_error",
|
||||
"parallel",
|
||||
"repeat",
|
||||
"interpolation_range",
|
||||
"interpolation_degree",
|
||||
"ideal_gas",
|
||||
"recovery_last_stored",
|
||||
]:
|
||||
self.get(param)
|
||||
|
||||
if (
|
||||
any(["raman" in l for l in self.variable.get("behaviors", [])])
|
||||
or "raman" in self.behaviors
|
||||
):
|
||||
self.get("raman_type", specified_parameters=["raman"])
|
||||
|
||||
def contains(self, key):
|
||||
return self.variable.get(key) is not None or getattr(self, key) is not None
|
||||
|
||||
def get(self, param, **kwargs) -> Any:
|
||||
"""checks if param is in the parameter section dict and attempts to fill in a default value
|
||||
|
||||
Parameters
|
||||
----------
|
||||
param : str
|
||||
the name of the parameter (dict key)
|
||||
kwargs : any
|
||||
key word arguments passed to the MissingParameterError constructor
|
||||
|
||||
Raises
|
||||
------
|
||||
MissingFiberParameterError
|
||||
raised when a parameter is missing and no default exists
|
||||
"""
|
||||
|
||||
# whether the parameter is in the right place and valid is checked elsewhere,
|
||||
# here, we just make sure it is present.
|
||||
if not self.contains(param):
|
||||
try:
|
||||
setattr(self, param, default_parameters[param])
|
||||
except KeyError:
|
||||
raise MissingParameterError(param, **kwargs)
|
||||
|
||||
def get_fiber(self, param, **kwargs):
|
||||
"""wrapper for fiber parameters that depend on fiber model"""
|
||||
self.get(param, fiber_model=self.model, **kwargs)
|
||||
|
||||
def get_multiple(self, params, num, **kwargs):
|
||||
"""similar to the get method but works with several parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : list of str
|
||||
names of the required parameters
|
||||
num : int
|
||||
how many of the parameters in params are required
|
||||
|
||||
Raises
|
||||
------
|
||||
MissingParameterError
|
||||
raised when not enough parameters are provided and no defaults exist
|
||||
"""
|
||||
gotten = 0
|
||||
for param in params:
|
||||
try:
|
||||
self.get(param, **kwargs)
|
||||
gotten += 1
|
||||
except MissingParameterError:
|
||||
pass
|
||||
if gotten >= num:
|
||||
return
|
||||
raise MissingParameterError(params, num_required=num, **kwargs)
|
||||
|
||||
def setdefault(self, param, value):
|
||||
if getattr(self, param) is None:
|
||||
setattr(self, param, value)
|
||||
|
||||
|
||||
class ParamSequence:
|
||||
def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, BareConfig]):
|
||||
def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, Config]):
|
||||
"""creates a param sequence from a base config
|
||||
|
||||
Parameters
|
||||
@@ -197,7 +30,7 @@ class ParamSequence:
|
||||
"""
|
||||
if isinstance(config_dict, Config):
|
||||
self.config = config_dict
|
||||
elif isinstance(config_dict, BareConfig):
|
||||
elif isinstance(config_dict, Config):
|
||||
self.config = Config.from_bare(config_dict)
|
||||
else:
|
||||
if not isinstance(config_dict, Mapping):
|
||||
@@ -231,7 +64,7 @@ class ParamSequence:
|
||||
|
||||
|
||||
class ContinuationParamSequence(ParamSequence):
|
||||
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig):
|
||||
def __init__(self, prev_sim_dir: os.PathLike, new_config: Config):
|
||||
"""Parameter sequence that builds on a previous simulation but with a new configuration
|
||||
It is recommended that only the fiber and the number of points stored may be changed and
|
||||
changing other parameters could results in unexpected behaviors. The new config doesn't have to
|
||||
@@ -245,9 +78,9 @@ class ContinuationParamSequence(ParamSequence):
|
||||
new config
|
||||
"""
|
||||
self.prev_sim_dir = Path(prev_sim_dir)
|
||||
self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file)
|
||||
self.bare_configs = Config.load_sequence(new_config.previous_config_file)
|
||||
self.bare_configs.append(new_config)
|
||||
self.bare_configs[0] = Config.from_bare(self.bare_configs[0])
|
||||
self.bare_configs[0].check_validity()
|
||||
final_config = parameter.final_config_from_sequence(*self.bare_configs)
|
||||
super().__init__(final_config)
|
||||
|
||||
@@ -293,7 +126,7 @@ class ContinuationParamSequence(ParamSequence):
|
||||
return count_variations(*self.bare_configs)
|
||||
|
||||
|
||||
def count_variations(*bare_configs: BareConfig) -> int:
|
||||
def count_variations(*bare_configs: Config) -> int:
|
||||
sim_num = 1
|
||||
for conf in bare_configs:
|
||||
for l in conf.variable.values():
|
||||
@@ -310,7 +143,7 @@ class RecoveryParamSequence(ParamSequence):
|
||||
self.prev_sim_dir = None
|
||||
if self.config.prev_sim_dir is not None:
|
||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||
init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml")
|
||||
init_config = Config.load(self.prev_sim_dir / "initial_config.toml")
|
||||
self.prev_variable_lists = [
|
||||
(
|
||||
set(variable_list[1:]),
|
||||
@@ -403,7 +236,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
|
||||
"""
|
||||
|
||||
previous = None
|
||||
configs = BareConfig.load_sequence(*configs)
|
||||
configs = Config.load_sequence(*configs)
|
||||
for config in configs:
|
||||
# if (p := Path(config)).is_dir():
|
||||
# config = p / "initial_config.toml"
|
||||
|
||||
@@ -326,7 +326,12 @@ def load_and_adjust_field_file(
|
||||
field_0 = load_field_file(field_file, t)
|
||||
if energy is not None:
|
||||
curr_energy = np.trapz(abs2(field_0), t)
|
||||
field_0 *=
|
||||
field_0 = field_0 * np.sqrt(energy / curr_energy)
|
||||
elif peak_power is not None:
|
||||
ratio = np.sqrt(peak_power / abs2(field_0).max())
|
||||
field_0 = field_0 * ratio
|
||||
else:
|
||||
raise ValueError(f"Not enough parameters specified to load {field_file} correctly")
|
||||
|
||||
field_0 = field_0 * modify_field_ratio(
|
||||
t, field_0, peak_power, energy, intensity_noise, noise_correlation
|
||||
@@ -1056,10 +1061,6 @@ def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float
|
||||
return fwhm, peak_power, energy
|
||||
|
||||
|
||||
def measure_custom_field(field_file: str, t: np.ndarray) -> float:
|
||||
return measure_field(t, load_field_file(field_file, t))[0]
|
||||
|
||||
|
||||
def remove_2nd_order_dispersion(
|
||||
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -100.0
|
||||
) -> tuple[T, OptimizeResult]:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Type
|
||||
import numpy as np
|
||||
|
||||
from .. import env, initialize, utils
|
||||
from ..utils.parameter import Parameters, BareConfig
|
||||
from ..utils.parameter import Parameters, Config, format_variable_list
|
||||
from ..const import PARAM_SEPARATOR
|
||||
from ..errors import IncompleteDataFolderError
|
||||
from ..logger import get_logger
|
||||
@@ -471,7 +471,7 @@ class Simulations:
|
||||
|
||||
def _run_available(self):
|
||||
for variable, params in self.param_seq:
|
||||
v_list_str = utils.format_variable_list(variable)
|
||||
v_list_str = format_variable_list(variable)
|
||||
utils.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
|
||||
|
||||
self.new_sim(v_list_str, params)
|
||||
@@ -690,7 +690,7 @@ def run_simulation_sequence(
|
||||
method=None,
|
||||
prev_sim_dir: os.PathLike = None,
|
||||
):
|
||||
configs = BareConfig.load_sequence(*config_files)
|
||||
configs = Config.load_sequence(*config_files)
|
||||
|
||||
prev = prev_sim_dir
|
||||
for config in configs:
|
||||
@@ -707,7 +707,7 @@ def run_simulation_sequence(
|
||||
|
||||
|
||||
def new_simulation(
|
||||
config: BareConfig,
|
||||
config: Config,
|
||||
prev_sim_dir=None,
|
||||
method: Type[Simulations] = None,
|
||||
) -> Simulations:
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
|
||||
from ..initialize import validate_config_sequence
|
||||
from ..utils import Paths
|
||||
from ..utils.parameter import BareConfig
|
||||
from ..utils.parameter import Config
|
||||
|
||||
|
||||
def primes(n):
|
||||
@@ -128,7 +128,7 @@ def main():
|
||||
)
|
||||
|
||||
if args.command == "merge":
|
||||
final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name
|
||||
final_name = Config.load(Path(args.configs[0]) / "initial_config.toml").name
|
||||
sim_num = "many"
|
||||
args.nodes = 1
|
||||
args.cpus_per_node = 1
|
||||
|
||||
@@ -146,8 +146,6 @@ class Pulse(Sequence):
|
||||
|
||||
self.params = Parameters.load(self.path / "params.toml")
|
||||
|
||||
initialize.build_sim_grid_in_place(self.params)
|
||||
|
||||
try:
|
||||
self.z = np.load(os.path.join(path, "z.npy"))
|
||||
except FileNotFoundError:
|
||||
@@ -161,7 +159,7 @@ class Pulse(Sequence):
|
||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||
|
||||
self.t = self.params.t
|
||||
w = initialize.wspace(self.t) + units.m(self.params.wavelength)
|
||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||
self.w_order = np.argsort(w)
|
||||
self.w = w
|
||||
self.wl = units.m.inv(self.w)
|
||||
|
||||
@@ -84,6 +84,7 @@ class Paths:
|
||||
|
||||
|
||||
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
|
||||
prev_data_dir = Path(prev_data_dir)
|
||||
num = find_last_spectrum_num(prev_data_dir)
|
||||
return np.load(prev_data_dir / SPEC1_FN.format(num))
|
||||
|
||||
|
||||
@@ -336,21 +336,6 @@ valid_variable = {
|
||||
"ideal_gas",
|
||||
}
|
||||
|
||||
hc_model_specific_parameters = dict(
|
||||
marcatili=["core_radius", "he_mode"],
|
||||
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
|
||||
hasan=[
|
||||
"core_radius",
|
||||
"capillary_num",
|
||||
"capillary_thickness",
|
||||
"capillary_resonance_strengths",
|
||||
"capillary_nested",
|
||||
"capillary_spacing",
|
||||
"capillary_outer_d",
|
||||
],
|
||||
)
|
||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
||||
|
||||
mandatory_parameters = [
|
||||
"name",
|
||||
"w_c",
|
||||
@@ -360,6 +345,7 @@ mandatory_parameters = [
|
||||
"alpha",
|
||||
"spec_0",
|
||||
"field_0",
|
||||
"input_transmission",
|
||||
"z_targets",
|
||||
"length",
|
||||
"beta2_coefficients",
|
||||
@@ -451,7 +437,7 @@ class Parameters:
|
||||
time_window: float = Parameter(positive(float, int))
|
||||
dt: float = Parameter(in_range_excl(0, 5e-15))
|
||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||
step_size: float = Parameter(positive(float, int), default=0)
|
||||
step_size: float = Parameter(non_negative(float, int), default=0)
|
||||
interpolation_range: tuple[float, float] = Parameter(float_pair)
|
||||
interpolation_degree: int = Parameter(positive(int), default=8)
|
||||
prev_sim_dir: str = Parameter(string)
|
||||
@@ -463,7 +449,7 @@ class Parameters:
|
||||
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
beta2: float = Parameter(type_checker(int, float))
|
||||
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
alpha: float = Parameter(positive(float, int), default=0)
|
||||
alpha: float = Parameter(non_negative(float, int), default=0)
|
||||
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
@@ -588,6 +574,7 @@ class Rule:
|
||||
kwarg_names: list[str],
|
||||
n_var: int,
|
||||
args_const: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
) -> list["Rule"]:
|
||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
@@ -625,7 +612,7 @@ class Rule:
|
||||
|
||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||
|
||||
rules.append(cls(target, new_func))
|
||||
rules.append(cls(target, new_func, priorities=priorities))
|
||||
return rules
|
||||
|
||||
|
||||
@@ -732,8 +719,7 @@ class Evaluator:
|
||||
prefix
|
||||
+ f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
|
||||
)
|
||||
self.params[param_name] = returned_value
|
||||
self.eval_stats[param_name].priority = param_priority
|
||||
self.set_value(param_name, returned_value, param_priority)
|
||||
if param_name == target:
|
||||
value = returned_value
|
||||
break
|
||||
@@ -749,6 +735,7 @@ class Evaluator:
|
||||
error = NoDefaultError(prefix + f"No default provided for {target}")
|
||||
else:
|
||||
value = default
|
||||
self.set_value(target, value, 0)
|
||||
|
||||
if value is None and error is not None:
|
||||
raise error
|
||||
@@ -756,6 +743,13 @@ class Evaluator:
|
||||
self.__curent_lookup.remove(target)
|
||||
return value
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.params[key]
|
||||
|
||||
def set_value(self, key: str, value: Any, priority: int):
|
||||
self.params[key] = value
|
||||
self.eval_stats[key].priority = priority
|
||||
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||
|
||||
@@ -779,18 +773,25 @@ class Evaluator:
|
||||
|
||||
|
||||
@dataclass
|
||||
class BareConfig(Parameters):
|
||||
class Config(Parameters):
|
||||
variable: dict = VariableParameter(Parameters)
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
def check_validity(self):
|
||||
conf_dict = asdict(self)
|
||||
variable = conf_dict.pop("variable", {})
|
||||
for k, v in variable.items():
|
||||
conf_dict[k] = v[0]
|
||||
Parameters(**conf_dict)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> "BareConfig":
|
||||
def load(cls, path: os.PathLike) -> "Config":
|
||||
return cls(**utils.load_toml(path))
|
||||
|
||||
@classmethod
|
||||
def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]:
|
||||
def load_sequence(cls, *config_paths: os.PathLike) -> list["Config"]:
|
||||
"""Loads a sequence of
|
||||
|
||||
Parameters
|
||||
@@ -830,8 +831,13 @@ class PlotRange:
|
||||
def __str__(self):
|
||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||
|
||||
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
return sort_axis(axis, self)
|
||||
|
||||
def sort_axis(axis, plt_range: PlotRange) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
|
||||
def sort_axis(
|
||||
axis: np.ndarray, plt_range: PlotRange
|
||||
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
"""
|
||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||
|
||||
@@ -893,7 +899,7 @@ def validate_arg_names(names: list[str]):
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
@@ -972,7 +978,7 @@ def pretty_format_from_sim_name(name: str) -> str:
|
||||
return PARAM_SEPARATOR.join(out)
|
||||
|
||||
|
||||
def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
|
||||
def variable_iterator(config: Config) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
|
||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||
|
||||
@@ -1011,7 +1017,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
|
||||
|
||||
|
||||
def required_simulations(
|
||||
*configs: BareConfig,
|
||||
*configs: Config,
|
||||
) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
|
||||
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
||||
parameter set and iterates through every single necessary simulation
|
||||
@@ -1045,11 +1051,11 @@ def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple
|
||||
return out
|
||||
|
||||
|
||||
def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
|
||||
def override_config(new: Config, old: Config = None) -> Config:
|
||||
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
|
||||
new_dict = asdict(new)
|
||||
if old is None:
|
||||
return BareConfig(**new_dict)
|
||||
return Config(**new_dict)
|
||||
variable = deepcopy(old.variable)
|
||||
new_dict = {k: v for k, v in new_dict.items() if v is not None}
|
||||
|
||||
@@ -1060,7 +1066,7 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
|
||||
return replace(old, variable=variable, **new_dict)
|
||||
|
||||
|
||||
def final_config_from_sequence(*configs: BareConfig) -> BareConfig:
|
||||
def final_config_from_sequence(*configs: Config) -> Config:
|
||||
if len(configs) == 0:
|
||||
raise ValueError("Must provide at least one config")
|
||||
if len(configs) == 1:
|
||||
@@ -1085,10 +1091,11 @@ default_rules: list[Rule] = [
|
||||
Rule("spec_0", np.fft.fft, ["field_0"]),
|
||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
|
||||
Rule(
|
||||
*Rule.deduce(
|
||||
["pre_field_0", "peak_power", "energy", "width"],
|
||||
pulse.load_and_adjust_field_file,
|
||||
["field_file", "t", "peak_power", "energy", "intensity_noise", "noise_correlation"],
|
||||
["energy", "peak_power"],
|
||||
1,
|
||||
priorities=[2, 1, 1, 1],
|
||||
),
|
||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||
@@ -1099,7 +1106,6 @@ default_rules: list[Rule] = [
|
||||
),
|
||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||
Rule(["width", "peak_power", "energy"], pulse.measure_custom_field),
|
||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||
Rule("energy", pulse.mean_power_to_energy),
|
||||
Rule("t0", pulse.width_to_t0),
|
||||
|
||||
@@ -7,7 +7,7 @@ import toml
|
||||
from scgenerator import defaults, utils, math
|
||||
from scgenerator.errors import *
|
||||
from scgenerator.physics import pulse, units
|
||||
from scgenerator.utils.parameter import BareConfig, Parameters
|
||||
from scgenerator.utils.parameter import Config, Parameters
|
||||
|
||||
|
||||
def load_conf(name):
|
||||
|
||||
@@ -16,7 +16,7 @@ def conf_maker(folder, val=True):
|
||||
if val:
|
||||
return initialize.Config(**load_conf(folder + "/" + name))
|
||||
else:
|
||||
return initialize.BareConfig(**load_conf(folder + "/" + name))
|
||||
return initialize.Config(**load_conf(folder + "/" + name))
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
Reference in New Issue
Block a user