parameter computation working

This commit is contained in:
Benoît Sierro
2021-08-28 17:25:04 +02:00
parent e0951662a3
commit 39ae02ddb3
11 changed files with 72 additions and 231 deletions

12
play.py
View File

@@ -1,8 +1,11 @@
from dataclasses import fields
from scgenerator import Parameters from scgenerator import Parameters
from scgenerator.physics.simulate import RK4IP from scgenerator.physics.simulate import RK4IP
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pprint import pprint
def main(): def main():
cwd = os.getcwd() cwd = os.getcwd()
@@ -10,12 +13,11 @@ def main():
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations") os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
pa = Parameters.load( pa = Parameters.load(
"/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM1550_RIN.toml" "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM2000D.toml"
) )
x = 1, 2
plt.plot(pa.t, pa.field_0.imag) print(pa.input_transmission)
plt.plot(pa.t, pa.field_0.real) print(x)
plt.show()
finally: finally:
os.chdir(cwd) os.chdir(cwd)

View File

@@ -12,4 +12,4 @@ from .physics.simulate import RK4IP, new_simulation, resume_simulations
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum from .spectra import Pulse, Spectrum
from .utils import Paths, load_toml from .utils import Paths, load_toml
from .utils.parameter import BareConfig, Parameters, PlotRange from .utils.parameter import Config, Parameters, PlotRange

View File

@@ -11,183 +11,16 @@ from . import utils
from .errors import * from .errors import *
from .logger import get_logger from .logger import get_logger
from .utils.parameter import ( from .utils.parameter import (
BareConfig, Config,
Parameters, Parameters,
hc_model_specific_parameters,
override_config, override_config,
required_simulations, required_simulations,
) )
from scgenerator.utils import parameter from scgenerator.utils import parameter
@dataclass
class Config(BareConfig):
@classmethod
def from_bare(cls, bare: BareConfig):
return cls(**asdict(bare))
def __post_init__(self):
for p_name, value in self.__dict__.items():
if value is not None and p_name in self.variable:
raise DuplicateParameterError(f"got multiple values for parameter {p_name!r}")
self.setdefault("name", "no name")
self.fiber_consistency()
if self.model in hc_model_specific_parameters:
self.gas_consistency()
self.pulse_consistency()
self.simulation_consistency()
def fiber_consistency(self):
if self.contains("dispersion_file") or self.contains("beta2_coefficients"):
if not (
self.contains("A_eff")
or self.contains("A_eff_file")
or self.contains("effective_mode_diameter")
):
self.get("gamma", specified_parameters=["custom fiber model"])
self.get("n2", specified_parameters=["custom fiber model"])
self.setdefault("model", "custom")
else:
self.get("model")
if self.model == "pcf":
self.get_fiber("pitch")
self.get_fiber("pitch_ratio")
elif self.model == "hasan":
self.get_multiple(
["capillary_spacing", "capillary_outer_d"], 1, fiber_model="hasan"
)
for param in [
"core_radius",
"capillary_num",
"capillary_thickness",
"capillary_resonance_strengths",
"capillary_nested",
]:
self.get_fiber(param)
else:
for param in hc_model_specific_parameters[self.model]:
self.get_fiber(param)
if self.contains("loss"):
if self.loss == "capillary":
for param in ["core_radius", "he_mode"]:
self.get_fiber(param)
for param in ["length", "input_transmission"]:
self.get(param)
def gas_consistency(self):
for param in ["gas_name", "temperature", "pressure", "plasma_density"]:
self.get(param, specified_params=["gas"])
def pulse_consistency(self):
for param in ["wavelength", "quantum_noise", "intensity_noise"]:
self.get(param)
if not self.contains("field_file"):
self.get("shape")
if self.contains("soliton_num"):
self.get_multiple(
["peak_power", "mean_power", "energy", "width", "t0"],
1,
specified_parameters=["soliton_num"],
)
else:
self.get_multiple(["t0", "width"], 1)
self.get_multiple(["peak_power", "energy", "mean_power"], 1)
if self.contains("mean_power"):
self.get("repetition_rate", specified_parameters=["mean_power"])
def simulation_consistency(self):
self.get_multiple(["dt", "t_num", "time_window"], 2)
for param in [
"behaviors",
"z_num",
"tolerated_error",
"parallel",
"repeat",
"interpolation_range",
"interpolation_degree",
"ideal_gas",
"recovery_last_stored",
]:
self.get(param)
if (
any(["raman" in l for l in self.variable.get("behaviors", [])])
or "raman" in self.behaviors
):
self.get("raman_type", specified_parameters=["raman"])
def contains(self, key):
return self.variable.get(key) is not None or getattr(self, key) is not None
def get(self, param, **kwargs) -> Any:
"""checks if param is in the parameter section dict and attempts to fill in a default value
Parameters
----------
param : str
the name of the parameter (dict key)
kwargs : any
key word arguments passed to the MissingParameterError constructor
Raises
------
MissingFiberParameterError
raised when a parameter is missing and no default exists
"""
# whether the parameter is in the right place and valid is checked elsewhere,
# here, we just make sure it is present.
if not self.contains(param):
try:
setattr(self, param, default_parameters[param])
except KeyError:
raise MissingParameterError(param, **kwargs)
def get_fiber(self, param, **kwargs):
"""wrapper for fiber parameters that depend on fiber model"""
self.get(param, fiber_model=self.model, **kwargs)
def get_multiple(self, params, num, **kwargs):
"""similar to the get method but works with several parameters
Parameters
----------
params : list of str
names of the required parameters
num : int
how many of the parameters in params are required
Raises
------
MissingParameterError
raised when not enough parameters are provided and no defaults exist
"""
gotten = 0
for param in params:
try:
self.get(param, **kwargs)
gotten += 1
except MissingParameterError:
pass
if gotten >= num:
return
raise MissingParameterError(params, num_required=num, **kwargs)
def setdefault(self, param, value):
if getattr(self, param) is None:
setattr(self, param, value)
class ParamSequence: class ParamSequence:
def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, BareConfig]): def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, Config]):
"""creates a param sequence from a base config """creates a param sequence from a base config
Parameters Parameters
@@ -197,7 +30,7 @@ class ParamSequence:
""" """
if isinstance(config_dict, Config): if isinstance(config_dict, Config):
self.config = config_dict self.config = config_dict
elif isinstance(config_dict, BareConfig): elif isinstance(config_dict, Config):
self.config = Config.from_bare(config_dict) self.config = Config.from_bare(config_dict)
else: else:
if not isinstance(config_dict, Mapping): if not isinstance(config_dict, Mapping):
@@ -231,7 +64,7 @@ class ParamSequence:
class ContinuationParamSequence(ParamSequence): class ContinuationParamSequence(ParamSequence):
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig): def __init__(self, prev_sim_dir: os.PathLike, new_config: Config):
"""Parameter sequence that builds on a previous simulation but with a new configuration """Parameter sequence that builds on a previous simulation but with a new configuration
It is recommended that only the fiber and the number of points stored may be changed and It is recommended that only the fiber and the number of points stored may be changed and
changing other parameters could results in unexpected behaviors. The new config doesn't have to changing other parameters could results in unexpected behaviors. The new config doesn't have to
@@ -245,9 +78,9 @@ class ContinuationParamSequence(ParamSequence):
new config new config
""" """
self.prev_sim_dir = Path(prev_sim_dir) self.prev_sim_dir = Path(prev_sim_dir)
self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file) self.bare_configs = Config.load_sequence(new_config.previous_config_file)
self.bare_configs.append(new_config) self.bare_configs.append(new_config)
self.bare_configs[0] = Config.from_bare(self.bare_configs[0]) self.bare_configs[0].check_validity()
final_config = parameter.final_config_from_sequence(*self.bare_configs) final_config = parameter.final_config_from_sequence(*self.bare_configs)
super().__init__(final_config) super().__init__(final_config)
@@ -293,7 +126,7 @@ class ContinuationParamSequence(ParamSequence):
return count_variations(*self.bare_configs) return count_variations(*self.bare_configs)
def count_variations(*bare_configs: BareConfig) -> int: def count_variations(*bare_configs: Config) -> int:
sim_num = 1 sim_num = 1
for conf in bare_configs: for conf in bare_configs:
for l in conf.variable.values(): for l in conf.variable.values():
@@ -310,7 +143,7 @@ class RecoveryParamSequence(ParamSequence):
self.prev_sim_dir = None self.prev_sim_dir = None
if self.config.prev_sim_dir is not None: if self.config.prev_sim_dir is not None:
self.prev_sim_dir = Path(self.config.prev_sim_dir) self.prev_sim_dir = Path(self.config.prev_sim_dir)
init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml") init_config = Config.load(self.prev_sim_dir / "initial_config.toml")
self.prev_variable_lists = [ self.prev_variable_lists = [
( (
set(variable_list[1:]), set(variable_list[1:]),
@@ -403,7 +236,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
""" """
previous = None previous = None
configs = BareConfig.load_sequence(*configs) configs = Config.load_sequence(*configs)
for config in configs: for config in configs:
# if (p := Path(config)).is_dir(): # if (p := Path(config)).is_dir():
# config = p / "initial_config.toml" # config = p / "initial_config.toml"

View File

@@ -326,7 +326,12 @@ def load_and_adjust_field_file(
field_0 = load_field_file(field_file, t) field_0 = load_field_file(field_file, t)
if energy is not None: if energy is not None:
curr_energy = np.trapz(abs2(field_0), t) curr_energy = np.trapz(abs2(field_0), t)
field_0 *= field_0 = field_0 * np.sqrt(energy / curr_energy)
elif peak_power is not None:
ratio = np.sqrt(peak_power / abs2(field_0).max())
field_0 = field_0 * ratio
else:
raise ValueError(f"Not enough parameters specified to load {field_file} correctly")
field_0 = field_0 * modify_field_ratio( field_0 = field_0 * modify_field_ratio(
t, field_0, peak_power, energy, intensity_noise, noise_correlation t, field_0, peak_power, energy, intensity_noise, noise_correlation
@@ -1056,10 +1061,6 @@ def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float
return fwhm, peak_power, energy return fwhm, peak_power, energy
def measure_custom_field(field_file: str, t: np.ndarray) -> float:
return measure_field(t, load_field_file(field_file, t))[0]
def remove_2nd_order_dispersion( def remove_2nd_order_dispersion(
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -100.0 spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -100.0
) -> tuple[T, OptimizeResult]: ) -> tuple[T, OptimizeResult]:

View File

@@ -8,7 +8,7 @@ from typing import Type
import numpy as np import numpy as np
from .. import env, initialize, utils from .. import env, initialize, utils
from ..utils.parameter import Parameters, BareConfig from ..utils.parameter import Parameters, Config, format_variable_list
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
@@ -471,7 +471,7 @@ class Simulations:
def _run_available(self): def _run_available(self):
for variable, params in self.param_seq: for variable, params in self.param_seq:
v_list_str = utils.format_variable_list(variable) v_list_str = format_variable_list(variable)
utils.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str) utils.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
self.new_sim(v_list_str, params) self.new_sim(v_list_str, params)
@@ -690,7 +690,7 @@ def run_simulation_sequence(
method=None, method=None,
prev_sim_dir: os.PathLike = None, prev_sim_dir: os.PathLike = None,
): ):
configs = BareConfig.load_sequence(*config_files) configs = Config.load_sequence(*config_files)
prev = prev_sim_dir prev = prev_sim_dir
for config in configs: for config in configs:
@@ -707,7 +707,7 @@ def run_simulation_sequence(
def new_simulation( def new_simulation(
config: BareConfig, config: Config,
prev_sim_dir=None, prev_sim_dir=None,
method: Type[Simulations] = None, method: Type[Simulations] = None,
) -> Simulations: ) -> Simulations:

View File

@@ -11,7 +11,7 @@ import numpy as np
from ..initialize import validate_config_sequence from ..initialize import validate_config_sequence
from ..utils import Paths from ..utils import Paths
from ..utils.parameter import BareConfig from ..utils.parameter import Config
def primes(n): def primes(n):
@@ -128,7 +128,7 @@ def main():
) )
if args.command == "merge": if args.command == "merge":
final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name final_name = Config.load(Path(args.configs[0]) / "initial_config.toml").name
sim_num = "many" sim_num = "many"
args.nodes = 1 args.nodes = 1
args.cpus_per_node = 1 args.cpus_per_node = 1

View File

@@ -146,8 +146,6 @@ class Pulse(Sequence):
self.params = Parameters.load(self.path / "params.toml") self.params = Parameters.load(self.path / "params.toml")
initialize.build_sim_grid_in_place(self.params)
try: try:
self.z = np.load(os.path.join(path, "z.npy")) self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError: except FileNotFoundError:
@@ -161,7 +159,7 @@ class Pulse(Sequence):
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t self.t = self.params.t
w = initialize.wspace(self.t) + units.m(self.params.wavelength) w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w) self.w_order = np.argsort(w)
self.w = w self.w = w
self.wl = units.m.inv(self.w) self.wl = units.m.inv(self.w)

View File

@@ -84,6 +84,7 @@ class Paths:
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
prev_data_dir = Path(prev_data_dir)
num = find_last_spectrum_num(prev_data_dir) num = find_last_spectrum_num(prev_data_dir)
return np.load(prev_data_dir / SPEC1_FN.format(num)) return np.load(prev_data_dir / SPEC1_FN.format(num))

View File

@@ -336,21 +336,6 @@ valid_variable = {
"ideal_gas", "ideal_gas",
} }
hc_model_specific_parameters = dict(
marcatili=["core_radius", "he_mode"],
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
hasan=[
"core_radius",
"capillary_num",
"capillary_thickness",
"capillary_resonance_strengths",
"capillary_nested",
"capillary_spacing",
"capillary_outer_d",
],
)
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
mandatory_parameters = [ mandatory_parameters = [
"name", "name",
"w_c", "w_c",
@@ -360,6 +345,7 @@ mandatory_parameters = [
"alpha", "alpha",
"spec_0", "spec_0",
"field_0", "field_0",
"input_transmission",
"z_targets", "z_targets",
"length", "length",
"beta2_coefficients", "beta2_coefficients",
@@ -451,7 +437,7 @@ class Parameters:
time_window: float = Parameter(positive(float, int)) time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15)) dt: float = Parameter(in_range_excl(0, 5e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11) tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
step_size: float = Parameter(positive(float, int), default=0) step_size: float = Parameter(non_negative(float, int), default=0)
interpolation_range: tuple[float, float] = Parameter(float_pair) interpolation_range: tuple[float, float] = Parameter(float_pair)
interpolation_degree: int = Parameter(positive(int), default=8) interpolation_degree: int = Parameter(positive(int), default=8)
prev_sim_dir: str = Parameter(string) prev_sim_dir: str = Parameter(string)
@@ -463,7 +449,7 @@ class Parameters:
spec_0: np.ndarray = Parameter(type_checker(np.ndarray)) spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
beta2: float = Parameter(type_checker(int, float)) beta2: float = Parameter(type_checker(int, float))
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray)) alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
alpha: float = Parameter(positive(float, int), default=0) alpha: float = Parameter(non_negative(float, int), default=0)
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray)) gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray)) A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray)) w: np.ndarray = Parameter(type_checker(np.ndarray))
@@ -588,6 +574,7 @@ class Rule:
kwarg_names: list[str], kwarg_names: list[str],
n_var: int, n_var: int,
args_const: list[str] = None, args_const: list[str] = None,
priorities: Union[int, list[int]] = None,
) -> list["Rule"]: ) -> list["Rule"]:
"""given a function that doesn't need all its keyword arguemtn specified, will """given a function that doesn't need all its keyword arguemtn specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs return a list of Rule obj, one for each combination of n_var specified kwargs
@@ -625,7 +612,7 @@ class Rule:
new_func = func_rewrite(func, list(var_possibility), args_const) new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func)) rules.append(cls(target, new_func, priorities=priorities))
return rules return rules
@@ -732,8 +719,7 @@ class Evaluator:
prefix prefix
+ f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}" + f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
) )
self.params[param_name] = returned_value self.set_value(param_name, returned_value, param_priority)
self.eval_stats[param_name].priority = param_priority
if param_name == target: if param_name == target:
value = returned_value value = returned_value
break break
@@ -749,6 +735,7 @@ class Evaluator:
error = NoDefaultError(prefix + f"No default provided for {target}") error = NoDefaultError(prefix + f"No default provided for {target}")
else: else:
value = default value = default
self.set_value(target, value, 0)
if value is None and error is not None: if value is None and error is not None:
raise error raise error
@@ -756,6 +743,13 @@ class Evaluator:
self.__curent_lookup.remove(target) self.__curent_lookup.remove(target)
return value return value
def __getitem__(self, key: str) -> Any:
return self.params[key]
def set_value(self, key: str, value: Any, priority: int):
self.params[key] = value
self.eval_stats[key].priority = priority
def validate_condition(self, rule: Rule) -> bool: def validate_condition(self, rule: Rule) -> bool:
return all(self.compute(k) == v for k, v in rule.conditions.items()) return all(self.compute(k) == v for k, v in rule.conditions.items())
@@ -779,18 +773,25 @@ class Evaluator:
@dataclass @dataclass
class BareConfig(Parameters): class Config(Parameters):
variable: dict = VariableParameter(Parameters) variable: dict = VariableParameter(Parameters)
def __post_init__(self): def __post_init__(self):
pass pass
def check_validity(self):
conf_dict = asdict(self)
variable = conf_dict.pop("variable", {})
for k, v in variable.items():
conf_dict[k] = v[0]
Parameters(**conf_dict)
@classmethod @classmethod
def load(cls, path: os.PathLike) -> "BareConfig": def load(cls, path: os.PathLike) -> "Config":
return cls(**utils.load_toml(path)) return cls(**utils.load_toml(path))
@classmethod @classmethod
def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]: def load_sequence(cls, *config_paths: os.PathLike) -> list["Config"]:
"""Loads a sequence of """Loads a sequence of
Parameters Parameters
@@ -830,8 +831,13 @@ class PlotRange:
def __str__(self): def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def sort_axis(axis, plt_range: PlotRange) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
""" """
given an axis, returns this axis cropped according to the given range, converted and sorted given an axis, returns this axis cropped according to the given range, converted and sorted
@@ -893,7 +899,7 @@ def validate_arg_names(names: list[str]):
raise ValueError(f"{n} is an invalid parameter name") raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None): def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None: if arg_names is None:
arg_names = get_arg_names(func) arg_names = get_arg_names(func)
else: else:
@@ -972,7 +978,7 @@ def pretty_format_from_sim_name(name: str) -> str:
return PARAM_SEPARATOR.join(out) return PARAM_SEPARATOR.join(out)
def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]: def variable_iterator(config: Config) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
"""given a config with "variable" parameters, iterates through every possible combination, """given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary. yielding a a list of (parameter_name, value) tuples and a full config dictionary.
@@ -1011,7 +1017,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
def required_simulations( def required_simulations(
*configs: BareConfig, *configs: Config,
) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]: ) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different """takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
parameter set and iterates through every single necessary simulation parameter set and iterates through every single necessary simulation
@@ -1045,11 +1051,11 @@ def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple
return out return out
def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig: def override_config(new: Config, old: Config = None) -> Config:
"""makes sure all the parameters set in new are there, leaves untouched parameters in old""" """makes sure all the parameters set in new are there, leaves untouched parameters in old"""
new_dict = asdict(new) new_dict = asdict(new)
if old is None: if old is None:
return BareConfig(**new_dict) return Config(**new_dict)
variable = deepcopy(old.variable) variable = deepcopy(old.variable)
new_dict = {k: v for k, v in new_dict.items() if v is not None} new_dict = {k: v for k, v in new_dict.items() if v is not None}
@@ -1060,7 +1066,7 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
return replace(old, variable=variable, **new_dict) return replace(old, variable=variable, **new_dict)
def final_config_from_sequence(*configs: BareConfig) -> BareConfig: def final_config_from_sequence(*configs: Config) -> Config:
if len(configs) == 0: if len(configs) == 0:
raise ValueError("Must provide at least one config") raise ValueError("Must provide at least one config")
if len(configs) == 1: if len(configs) == 1:
@@ -1085,10 +1091,11 @@ default_rules: list[Rule] = [
Rule("spec_0", np.fft.fft, ["field_0"]), Rule("spec_0", np.fft.fft, ["field_0"]),
Rule("field_0", np.fft.ifft, ["spec_0"]), Rule("field_0", np.fft.ifft, ["spec_0"]),
Rule("spec_0", utils.load_previous_spectrum, priorities=3), Rule("spec_0", utils.load_previous_spectrum, priorities=3),
Rule( *Rule.deduce(
["pre_field_0", "peak_power", "energy", "width"], ["pre_field_0", "peak_power", "energy", "width"],
pulse.load_and_adjust_field_file, pulse.load_and_adjust_field_file,
["field_file", "t", "peak_power", "energy", "intensity_noise", "noise_correlation"], ["energy", "peak_power"],
1,
priorities=[2, 1, 1, 1], priorities=[2, 1, 1, 1],
), ),
Rule("pre_field_0", pulse.initial_field, priorities=1), Rule("pre_field_0", pulse.initial_field, priorities=1),
@@ -1099,7 +1106,6 @@ default_rules: list[Rule] = [
), ),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
Rule("peak_power", pulse.soliton_num_to_peak_power), Rule("peak_power", pulse.soliton_num_to_peak_power),
Rule(["width", "peak_power", "energy"], pulse.measure_custom_field),
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
Rule("energy", pulse.mean_power_to_energy), Rule("energy", pulse.mean_power_to_energy),
Rule("t0", pulse.width_to_t0), Rule("t0", pulse.width_to_t0),

View File

@@ -7,7 +7,7 @@ import toml
from scgenerator import defaults, utils, math from scgenerator import defaults, utils, math
from scgenerator.errors import * from scgenerator.errors import *
from scgenerator.physics import pulse, units from scgenerator.physics import pulse, units
from scgenerator.utils.parameter import BareConfig, Parameters from scgenerator.utils.parameter import Config, Parameters
def load_conf(name): def load_conf(name):

View File

@@ -16,7 +16,7 @@ def conf_maker(folder, val=True):
if val: if val:
return initialize.Config(**load_conf(folder + "/" + name)) return initialize.Config(**load_conf(folder + "/" + name))
else: else:
return initialize.BareConfig(**load_conf(folder + "/" + name)) return initialize.Config(**load_conf(folder + "/" + name))
return conf return conf