parameter computation working
This commit is contained in:
12
play.py
12
play.py
@@ -1,8 +1,11 @@
|
|||||||
|
from dataclasses import fields
|
||||||
from scgenerator import Parameters
|
from scgenerator import Parameters
|
||||||
from scgenerator.physics.simulate import RK4IP
|
from scgenerator.physics.simulate import RK4IP
|
||||||
import os
|
import os
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
@@ -10,12 +13,11 @@ def main():
|
|||||||
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
|
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
|
||||||
|
|
||||||
pa = Parameters.load(
|
pa = Parameters.load(
|
||||||
"/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM1550_RIN.toml"
|
"/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM1550+PM2000D/PM2000D.toml"
|
||||||
)
|
)
|
||||||
|
x = 1, 2
|
||||||
plt.plot(pa.t, pa.field_0.imag)
|
print(pa.input_transmission)
|
||||||
plt.plot(pa.t, pa.field_0.real)
|
print(x)
|
||||||
plt.show()
|
|
||||||
finally:
|
finally:
|
||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
|
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ from .physics.simulate import RK4IP, new_simulation, resume_simulations
|
|||||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||||
from .spectra import Pulse, Spectrum
|
from .spectra import Pulse, Spectrum
|
||||||
from .utils import Paths, load_toml
|
from .utils import Paths, load_toml
|
||||||
from .utils.parameter import BareConfig, Parameters, PlotRange
|
from .utils.parameter import Config, Parameters, PlotRange
|
||||||
|
|||||||
@@ -11,183 +11,16 @@ from . import utils
|
|||||||
from .errors import *
|
from .errors import *
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .utils.parameter import (
|
from .utils.parameter import (
|
||||||
BareConfig,
|
Config,
|
||||||
Parameters,
|
Parameters,
|
||||||
hc_model_specific_parameters,
|
|
||||||
override_config,
|
override_config,
|
||||||
required_simulations,
|
required_simulations,
|
||||||
)
|
)
|
||||||
from scgenerator.utils import parameter
|
from scgenerator.utils import parameter
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Config(BareConfig):
|
|
||||||
@classmethod
|
|
||||||
def from_bare(cls, bare: BareConfig):
|
|
||||||
return cls(**asdict(bare))
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
for p_name, value in self.__dict__.items():
|
|
||||||
if value is not None and p_name in self.variable:
|
|
||||||
raise DuplicateParameterError(f"got multiple values for parameter {p_name!r}")
|
|
||||||
self.setdefault("name", "no name")
|
|
||||||
self.fiber_consistency()
|
|
||||||
if self.model in hc_model_specific_parameters:
|
|
||||||
self.gas_consistency()
|
|
||||||
self.pulse_consistency()
|
|
||||||
self.simulation_consistency()
|
|
||||||
|
|
||||||
def fiber_consistency(self):
|
|
||||||
|
|
||||||
if self.contains("dispersion_file") or self.contains("beta2_coefficients"):
|
|
||||||
if not (
|
|
||||||
self.contains("A_eff")
|
|
||||||
or self.contains("A_eff_file")
|
|
||||||
or self.contains("effective_mode_diameter")
|
|
||||||
):
|
|
||||||
self.get("gamma", specified_parameters=["custom fiber model"])
|
|
||||||
self.get("n2", specified_parameters=["custom fiber model"])
|
|
||||||
self.setdefault("model", "custom")
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.get("model")
|
|
||||||
|
|
||||||
if self.model == "pcf":
|
|
||||||
self.get_fiber("pitch")
|
|
||||||
self.get_fiber("pitch_ratio")
|
|
||||||
|
|
||||||
elif self.model == "hasan":
|
|
||||||
self.get_multiple(
|
|
||||||
["capillary_spacing", "capillary_outer_d"], 1, fiber_model="hasan"
|
|
||||||
)
|
|
||||||
for param in [
|
|
||||||
"core_radius",
|
|
||||||
"capillary_num",
|
|
||||||
"capillary_thickness",
|
|
||||||
"capillary_resonance_strengths",
|
|
||||||
"capillary_nested",
|
|
||||||
]:
|
|
||||||
self.get_fiber(param)
|
|
||||||
else:
|
|
||||||
for param in hc_model_specific_parameters[self.model]:
|
|
||||||
self.get_fiber(param)
|
|
||||||
if self.contains("loss"):
|
|
||||||
if self.loss == "capillary":
|
|
||||||
for param in ["core_radius", "he_mode"]:
|
|
||||||
self.get_fiber(param)
|
|
||||||
for param in ["length", "input_transmission"]:
|
|
||||||
self.get(param)
|
|
||||||
|
|
||||||
def gas_consistency(self):
|
|
||||||
for param in ["gas_name", "temperature", "pressure", "plasma_density"]:
|
|
||||||
self.get(param, specified_params=["gas"])
|
|
||||||
|
|
||||||
def pulse_consistency(self):
|
|
||||||
for param in ["wavelength", "quantum_noise", "intensity_noise"]:
|
|
||||||
self.get(param)
|
|
||||||
|
|
||||||
if not self.contains("field_file"):
|
|
||||||
self.get("shape")
|
|
||||||
|
|
||||||
if self.contains("soliton_num"):
|
|
||||||
self.get_multiple(
|
|
||||||
["peak_power", "mean_power", "energy", "width", "t0"],
|
|
||||||
1,
|
|
||||||
specified_parameters=["soliton_num"],
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.get_multiple(["t0", "width"], 1)
|
|
||||||
self.get_multiple(["peak_power", "energy", "mean_power"], 1)
|
|
||||||
if self.contains("mean_power"):
|
|
||||||
self.get("repetition_rate", specified_parameters=["mean_power"])
|
|
||||||
|
|
||||||
def simulation_consistency(self):
|
|
||||||
self.get_multiple(["dt", "t_num", "time_window"], 2)
|
|
||||||
|
|
||||||
for param in [
|
|
||||||
"behaviors",
|
|
||||||
"z_num",
|
|
||||||
"tolerated_error",
|
|
||||||
"parallel",
|
|
||||||
"repeat",
|
|
||||||
"interpolation_range",
|
|
||||||
"interpolation_degree",
|
|
||||||
"ideal_gas",
|
|
||||||
"recovery_last_stored",
|
|
||||||
]:
|
|
||||||
self.get(param)
|
|
||||||
|
|
||||||
if (
|
|
||||||
any(["raman" in l for l in self.variable.get("behaviors", [])])
|
|
||||||
or "raman" in self.behaviors
|
|
||||||
):
|
|
||||||
self.get("raman_type", specified_parameters=["raman"])
|
|
||||||
|
|
||||||
def contains(self, key):
|
|
||||||
return self.variable.get(key) is not None or getattr(self, key) is not None
|
|
||||||
|
|
||||||
def get(self, param, **kwargs) -> Any:
|
|
||||||
"""checks if param is in the parameter section dict and attempts to fill in a default value
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
param : str
|
|
||||||
the name of the parameter (dict key)
|
|
||||||
kwargs : any
|
|
||||||
key word arguments passed to the MissingParameterError constructor
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
MissingFiberParameterError
|
|
||||||
raised when a parameter is missing and no default exists
|
|
||||||
"""
|
|
||||||
|
|
||||||
# whether the parameter is in the right place and valid is checked elsewhere,
|
|
||||||
# here, we just make sure it is present.
|
|
||||||
if not self.contains(param):
|
|
||||||
try:
|
|
||||||
setattr(self, param, default_parameters[param])
|
|
||||||
except KeyError:
|
|
||||||
raise MissingParameterError(param, **kwargs)
|
|
||||||
|
|
||||||
def get_fiber(self, param, **kwargs):
|
|
||||||
"""wrapper for fiber parameters that depend on fiber model"""
|
|
||||||
self.get(param, fiber_model=self.model, **kwargs)
|
|
||||||
|
|
||||||
def get_multiple(self, params, num, **kwargs):
|
|
||||||
"""similar to the get method but works with several parameters
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
params : list of str
|
|
||||||
names of the required parameters
|
|
||||||
num : int
|
|
||||||
how many of the parameters in params are required
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
MissingParameterError
|
|
||||||
raised when not enough parameters are provided and no defaults exist
|
|
||||||
"""
|
|
||||||
gotten = 0
|
|
||||||
for param in params:
|
|
||||||
try:
|
|
||||||
self.get(param, **kwargs)
|
|
||||||
gotten += 1
|
|
||||||
except MissingParameterError:
|
|
||||||
pass
|
|
||||||
if gotten >= num:
|
|
||||||
return
|
|
||||||
raise MissingParameterError(params, num_required=num, **kwargs)
|
|
||||||
|
|
||||||
def setdefault(self, param, value):
|
|
||||||
if getattr(self, param) is None:
|
|
||||||
setattr(self, param, value)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamSequence:
|
class ParamSequence:
|
||||||
def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, BareConfig]):
|
def __init__(self, config_dict: Union[dict[str, Any], os.PathLike, Config]):
|
||||||
"""creates a param sequence from a base config
|
"""creates a param sequence from a base config
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -197,7 +30,7 @@ class ParamSequence:
|
|||||||
"""
|
"""
|
||||||
if isinstance(config_dict, Config):
|
if isinstance(config_dict, Config):
|
||||||
self.config = config_dict
|
self.config = config_dict
|
||||||
elif isinstance(config_dict, BareConfig):
|
elif isinstance(config_dict, Config):
|
||||||
self.config = Config.from_bare(config_dict)
|
self.config = Config.from_bare(config_dict)
|
||||||
else:
|
else:
|
||||||
if not isinstance(config_dict, Mapping):
|
if not isinstance(config_dict, Mapping):
|
||||||
@@ -231,7 +64,7 @@ class ParamSequence:
|
|||||||
|
|
||||||
|
|
||||||
class ContinuationParamSequence(ParamSequence):
|
class ContinuationParamSequence(ParamSequence):
|
||||||
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig):
|
def __init__(self, prev_sim_dir: os.PathLike, new_config: Config):
|
||||||
"""Parameter sequence that builds on a previous simulation but with a new configuration
|
"""Parameter sequence that builds on a previous simulation but with a new configuration
|
||||||
It is recommended that only the fiber and the number of points stored may be changed and
|
It is recommended that only the fiber and the number of points stored may be changed and
|
||||||
changing other parameters could results in unexpected behaviors. The new config doesn't have to
|
changing other parameters could results in unexpected behaviors. The new config doesn't have to
|
||||||
@@ -245,9 +78,9 @@ class ContinuationParamSequence(ParamSequence):
|
|||||||
new config
|
new config
|
||||||
"""
|
"""
|
||||||
self.prev_sim_dir = Path(prev_sim_dir)
|
self.prev_sim_dir = Path(prev_sim_dir)
|
||||||
self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file)
|
self.bare_configs = Config.load_sequence(new_config.previous_config_file)
|
||||||
self.bare_configs.append(new_config)
|
self.bare_configs.append(new_config)
|
||||||
self.bare_configs[0] = Config.from_bare(self.bare_configs[0])
|
self.bare_configs[0].check_validity()
|
||||||
final_config = parameter.final_config_from_sequence(*self.bare_configs)
|
final_config = parameter.final_config_from_sequence(*self.bare_configs)
|
||||||
super().__init__(final_config)
|
super().__init__(final_config)
|
||||||
|
|
||||||
@@ -293,7 +126,7 @@ class ContinuationParamSequence(ParamSequence):
|
|||||||
return count_variations(*self.bare_configs)
|
return count_variations(*self.bare_configs)
|
||||||
|
|
||||||
|
|
||||||
def count_variations(*bare_configs: BareConfig) -> int:
|
def count_variations(*bare_configs: Config) -> int:
|
||||||
sim_num = 1
|
sim_num = 1
|
||||||
for conf in bare_configs:
|
for conf in bare_configs:
|
||||||
for l in conf.variable.values():
|
for l in conf.variable.values():
|
||||||
@@ -310,7 +143,7 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
self.prev_sim_dir = None
|
self.prev_sim_dir = None
|
||||||
if self.config.prev_sim_dir is not None:
|
if self.config.prev_sim_dir is not None:
|
||||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||||
init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml")
|
init_config = Config.load(self.prev_sim_dir / "initial_config.toml")
|
||||||
self.prev_variable_lists = [
|
self.prev_variable_lists = [
|
||||||
(
|
(
|
||||||
set(variable_list[1:]),
|
set(variable_list[1:]),
|
||||||
@@ -403,7 +236,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
previous = None
|
previous = None
|
||||||
configs = BareConfig.load_sequence(*configs)
|
configs = Config.load_sequence(*configs)
|
||||||
for config in configs:
|
for config in configs:
|
||||||
# if (p := Path(config)).is_dir():
|
# if (p := Path(config)).is_dir():
|
||||||
# config = p / "initial_config.toml"
|
# config = p / "initial_config.toml"
|
||||||
|
|||||||
@@ -326,7 +326,12 @@ def load_and_adjust_field_file(
|
|||||||
field_0 = load_field_file(field_file, t)
|
field_0 = load_field_file(field_file, t)
|
||||||
if energy is not None:
|
if energy is not None:
|
||||||
curr_energy = np.trapz(abs2(field_0), t)
|
curr_energy = np.trapz(abs2(field_0), t)
|
||||||
field_0 *=
|
field_0 = field_0 * np.sqrt(energy / curr_energy)
|
||||||
|
elif peak_power is not None:
|
||||||
|
ratio = np.sqrt(peak_power / abs2(field_0).max())
|
||||||
|
field_0 = field_0 * ratio
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Not enough parameters specified to load {field_file} correctly")
|
||||||
|
|
||||||
field_0 = field_0 * modify_field_ratio(
|
field_0 = field_0 * modify_field_ratio(
|
||||||
t, field_0, peak_power, energy, intensity_noise, noise_correlation
|
t, field_0, peak_power, energy, intensity_noise, noise_correlation
|
||||||
@@ -1056,10 +1061,6 @@ def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float
|
|||||||
return fwhm, peak_power, energy
|
return fwhm, peak_power, energy
|
||||||
|
|
||||||
|
|
||||||
def measure_custom_field(field_file: str, t: np.ndarray) -> float:
|
|
||||||
return measure_field(t, load_field_file(field_file, t))[0]
|
|
||||||
|
|
||||||
|
|
||||||
def remove_2nd_order_dispersion(
|
def remove_2nd_order_dispersion(
|
||||||
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -100.0
|
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -100.0
|
||||||
) -> tuple[T, OptimizeResult]:
|
) -> tuple[T, OptimizeResult]:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Type
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .. import env, initialize, utils
|
from .. import env, initialize, utils
|
||||||
from ..utils.parameter import Parameters, BareConfig
|
from ..utils.parameter import Parameters, Config, format_variable_list
|
||||||
from ..const import PARAM_SEPARATOR
|
from ..const import PARAM_SEPARATOR
|
||||||
from ..errors import IncompleteDataFolderError
|
from ..errors import IncompleteDataFolderError
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
@@ -471,7 +471,7 @@ class Simulations:
|
|||||||
|
|
||||||
def _run_available(self):
|
def _run_available(self):
|
||||||
for variable, params in self.param_seq:
|
for variable, params in self.param_seq:
|
||||||
v_list_str = utils.format_variable_list(variable)
|
v_list_str = format_variable_list(variable)
|
||||||
utils.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
|
utils.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
|
||||||
|
|
||||||
self.new_sim(v_list_str, params)
|
self.new_sim(v_list_str, params)
|
||||||
@@ -690,7 +690,7 @@ def run_simulation_sequence(
|
|||||||
method=None,
|
method=None,
|
||||||
prev_sim_dir: os.PathLike = None,
|
prev_sim_dir: os.PathLike = None,
|
||||||
):
|
):
|
||||||
configs = BareConfig.load_sequence(*config_files)
|
configs = Config.load_sequence(*config_files)
|
||||||
|
|
||||||
prev = prev_sim_dir
|
prev = prev_sim_dir
|
||||||
for config in configs:
|
for config in configs:
|
||||||
@@ -707,7 +707,7 @@ def run_simulation_sequence(
|
|||||||
|
|
||||||
|
|
||||||
def new_simulation(
|
def new_simulation(
|
||||||
config: BareConfig,
|
config: Config,
|
||||||
prev_sim_dir=None,
|
prev_sim_dir=None,
|
||||||
method: Type[Simulations] = None,
|
method: Type[Simulations] = None,
|
||||||
) -> Simulations:
|
) -> Simulations:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ..initialize import validate_config_sequence
|
from ..initialize import validate_config_sequence
|
||||||
from ..utils import Paths
|
from ..utils import Paths
|
||||||
from ..utils.parameter import BareConfig
|
from ..utils.parameter import Config
|
||||||
|
|
||||||
|
|
||||||
def primes(n):
|
def primes(n):
|
||||||
@@ -128,7 +128,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.command == "merge":
|
if args.command == "merge":
|
||||||
final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name
|
final_name = Config.load(Path(args.configs[0]) / "initial_config.toml").name
|
||||||
sim_num = "many"
|
sim_num = "many"
|
||||||
args.nodes = 1
|
args.nodes = 1
|
||||||
args.cpus_per_node = 1
|
args.cpus_per_node = 1
|
||||||
|
|||||||
@@ -146,8 +146,6 @@ class Pulse(Sequence):
|
|||||||
|
|
||||||
self.params = Parameters.load(self.path / "params.toml")
|
self.params = Parameters.load(self.path / "params.toml")
|
||||||
|
|
||||||
initialize.build_sim_grid_in_place(self.params)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.z = np.load(os.path.join(path, "z.npy"))
|
self.z = np.load(os.path.join(path, "z.npy"))
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -161,7 +159,7 @@ class Pulse(Sequence):
|
|||||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||||
|
|
||||||
self.t = self.params.t
|
self.t = self.params.t
|
||||||
w = initialize.wspace(self.t) + units.m(self.params.wavelength)
|
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||||
self.w_order = np.argsort(w)
|
self.w_order = np.argsort(w)
|
||||||
self.w = w
|
self.w = w
|
||||||
self.wl = units.m.inv(self.w)
|
self.wl = units.m.inv(self.w)
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ class Paths:
|
|||||||
|
|
||||||
|
|
||||||
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
|
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
|
||||||
|
prev_data_dir = Path(prev_data_dir)
|
||||||
num = find_last_spectrum_num(prev_data_dir)
|
num = find_last_spectrum_num(prev_data_dir)
|
||||||
return np.load(prev_data_dir / SPEC1_FN.format(num))
|
return np.load(prev_data_dir / SPEC1_FN.format(num))
|
||||||
|
|
||||||
|
|||||||
@@ -336,21 +336,6 @@ valid_variable = {
|
|||||||
"ideal_gas",
|
"ideal_gas",
|
||||||
}
|
}
|
||||||
|
|
||||||
hc_model_specific_parameters = dict(
|
|
||||||
marcatili=["core_radius", "he_mode"],
|
|
||||||
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
|
|
||||||
hasan=[
|
|
||||||
"core_radius",
|
|
||||||
"capillary_num",
|
|
||||||
"capillary_thickness",
|
|
||||||
"capillary_resonance_strengths",
|
|
||||||
"capillary_nested",
|
|
||||||
"capillary_spacing",
|
|
||||||
"capillary_outer_d",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
|
||||||
|
|
||||||
mandatory_parameters = [
|
mandatory_parameters = [
|
||||||
"name",
|
"name",
|
||||||
"w_c",
|
"w_c",
|
||||||
@@ -360,6 +345,7 @@ mandatory_parameters = [
|
|||||||
"alpha",
|
"alpha",
|
||||||
"spec_0",
|
"spec_0",
|
||||||
"field_0",
|
"field_0",
|
||||||
|
"input_transmission",
|
||||||
"z_targets",
|
"z_targets",
|
||||||
"length",
|
"length",
|
||||||
"beta2_coefficients",
|
"beta2_coefficients",
|
||||||
@@ -451,7 +437,7 @@ class Parameters:
|
|||||||
time_window: float = Parameter(positive(float, int))
|
time_window: float = Parameter(positive(float, int))
|
||||||
dt: float = Parameter(in_range_excl(0, 5e-15))
|
dt: float = Parameter(in_range_excl(0, 5e-15))
|
||||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||||
step_size: float = Parameter(positive(float, int), default=0)
|
step_size: float = Parameter(non_negative(float, int), default=0)
|
||||||
interpolation_range: tuple[float, float] = Parameter(float_pair)
|
interpolation_range: tuple[float, float] = Parameter(float_pair)
|
||||||
interpolation_degree: int = Parameter(positive(int), default=8)
|
interpolation_degree: int = Parameter(positive(int), default=8)
|
||||||
prev_sim_dir: str = Parameter(string)
|
prev_sim_dir: str = Parameter(string)
|
||||||
@@ -463,7 +449,7 @@ class Parameters:
|
|||||||
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
beta2: float = Parameter(type_checker(int, float))
|
beta2: float = Parameter(type_checker(int, float))
|
||||||
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
alpha_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
alpha: float = Parameter(positive(float, int), default=0)
|
alpha: float = Parameter(non_negative(float, int), default=0)
|
||||||
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
gamma_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
A_eff_arr: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
@@ -588,6 +574,7 @@ class Rule:
|
|||||||
kwarg_names: list[str],
|
kwarg_names: list[str],
|
||||||
n_var: int,
|
n_var: int,
|
||||||
args_const: list[str] = None,
|
args_const: list[str] = None,
|
||||||
|
priorities: Union[int, list[int]] = None,
|
||||||
) -> list["Rule"]:
|
) -> list["Rule"]:
|
||||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||||
@@ -625,7 +612,7 @@ class Rule:
|
|||||||
|
|
||||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||||
|
|
||||||
rules.append(cls(target, new_func))
|
rules.append(cls(target, new_func, priorities=priorities))
|
||||||
return rules
|
return rules
|
||||||
|
|
||||||
|
|
||||||
@@ -732,8 +719,7 @@ class Evaluator:
|
|||||||
prefix
|
prefix
|
||||||
+ f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
|
+ f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
|
||||||
)
|
)
|
||||||
self.params[param_name] = returned_value
|
self.set_value(param_name, returned_value, param_priority)
|
||||||
self.eval_stats[param_name].priority = param_priority
|
|
||||||
if param_name == target:
|
if param_name == target:
|
||||||
value = returned_value
|
value = returned_value
|
||||||
break
|
break
|
||||||
@@ -749,6 +735,7 @@ class Evaluator:
|
|||||||
error = NoDefaultError(prefix + f"No default provided for {target}")
|
error = NoDefaultError(prefix + f"No default provided for {target}")
|
||||||
else:
|
else:
|
||||||
value = default
|
value = default
|
||||||
|
self.set_value(target, value, 0)
|
||||||
|
|
||||||
if value is None and error is not None:
|
if value is None and error is not None:
|
||||||
raise error
|
raise error
|
||||||
@@ -756,6 +743,13 @@ class Evaluator:
|
|||||||
self.__curent_lookup.remove(target)
|
self.__curent_lookup.remove(target)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
return self.params[key]
|
||||||
|
|
||||||
|
def set_value(self, key: str, value: Any, priority: int):
|
||||||
|
self.params[key] = value
|
||||||
|
self.eval_stats[key].priority = priority
|
||||||
|
|
||||||
def validate_condition(self, rule: Rule) -> bool:
|
def validate_condition(self, rule: Rule) -> bool:
|
||||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||||
|
|
||||||
@@ -779,18 +773,25 @@ class Evaluator:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BareConfig(Parameters):
|
class Config(Parameters):
|
||||||
variable: dict = VariableParameter(Parameters)
|
variable: dict = VariableParameter(Parameters)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def check_validity(self):
|
||||||
|
conf_dict = asdict(self)
|
||||||
|
variable = conf_dict.pop("variable", {})
|
||||||
|
for k, v in variable.items():
|
||||||
|
conf_dict[k] = v[0]
|
||||||
|
Parameters(**conf_dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: os.PathLike) -> "BareConfig":
|
def load(cls, path: os.PathLike) -> "Config":
|
||||||
return cls(**utils.load_toml(path))
|
return cls(**utils.load_toml(path))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]:
|
def load_sequence(cls, *config_paths: os.PathLike) -> list["Config"]:
|
||||||
"""Loads a sequence of
|
"""Loads a sequence of
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -830,8 +831,13 @@ class PlotRange:
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||||
|
|
||||||
|
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||||
|
return sort_axis(axis, self)
|
||||||
|
|
||||||
def sort_axis(axis, plt_range: PlotRange) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
|
||||||
|
def sort_axis(
|
||||||
|
axis: np.ndarray, plt_range: PlotRange
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||||
"""
|
"""
|
||||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||||
|
|
||||||
@@ -893,7 +899,7 @@ def validate_arg_names(names: list[str]):
|
|||||||
raise ValueError(f"{n} is an invalid parameter name")
|
raise ValueError(f"{n} is an invalid parameter name")
|
||||||
|
|
||||||
|
|
||||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||||
if arg_names is None:
|
if arg_names is None:
|
||||||
arg_names = get_arg_names(func)
|
arg_names = get_arg_names(func)
|
||||||
else:
|
else:
|
||||||
@@ -972,7 +978,7 @@ def pretty_format_from_sim_name(name: str) -> str:
|
|||||||
return PARAM_SEPARATOR.join(out)
|
return PARAM_SEPARATOR.join(out)
|
||||||
|
|
||||||
|
|
||||||
def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
|
def variable_iterator(config: Config) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
|
||||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||||
|
|
||||||
@@ -1011,7 +1017,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
|
|||||||
|
|
||||||
|
|
||||||
def required_simulations(
|
def required_simulations(
|
||||||
*configs: BareConfig,
|
*configs: Config,
|
||||||
) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
|
) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
|
||||||
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
||||||
parameter set and iterates through every single necessary simulation
|
parameter set and iterates through every single necessary simulation
|
||||||
@@ -1045,11 +1051,11 @@ def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
|
def override_config(new: Config, old: Config = None) -> Config:
|
||||||
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
|
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
|
||||||
new_dict = asdict(new)
|
new_dict = asdict(new)
|
||||||
if old is None:
|
if old is None:
|
||||||
return BareConfig(**new_dict)
|
return Config(**new_dict)
|
||||||
variable = deepcopy(old.variable)
|
variable = deepcopy(old.variable)
|
||||||
new_dict = {k: v for k, v in new_dict.items() if v is not None}
|
new_dict = {k: v for k, v in new_dict.items() if v is not None}
|
||||||
|
|
||||||
@@ -1060,7 +1066,7 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
|
|||||||
return replace(old, variable=variable, **new_dict)
|
return replace(old, variable=variable, **new_dict)
|
||||||
|
|
||||||
|
|
||||||
def final_config_from_sequence(*configs: BareConfig) -> BareConfig:
|
def final_config_from_sequence(*configs: Config) -> Config:
|
||||||
if len(configs) == 0:
|
if len(configs) == 0:
|
||||||
raise ValueError("Must provide at least one config")
|
raise ValueError("Must provide at least one config")
|
||||||
if len(configs) == 1:
|
if len(configs) == 1:
|
||||||
@@ -1085,10 +1091,11 @@ default_rules: list[Rule] = [
|
|||||||
Rule("spec_0", np.fft.fft, ["field_0"]),
|
Rule("spec_0", np.fft.fft, ["field_0"]),
|
||||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||||
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
|
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
|
||||||
Rule(
|
*Rule.deduce(
|
||||||
["pre_field_0", "peak_power", "energy", "width"],
|
["pre_field_0", "peak_power", "energy", "width"],
|
||||||
pulse.load_and_adjust_field_file,
|
pulse.load_and_adjust_field_file,
|
||||||
["field_file", "t", "peak_power", "energy", "intensity_noise", "noise_correlation"],
|
["energy", "peak_power"],
|
||||||
|
1,
|
||||||
priorities=[2, 1, 1, 1],
|
priorities=[2, 1, 1, 1],
|
||||||
),
|
),
|
||||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||||
@@ -1099,7 +1106,6 @@ default_rules: list[Rule] = [
|
|||||||
),
|
),
|
||||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||||
Rule(["width", "peak_power", "energy"], pulse.measure_custom_field),
|
|
||||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||||
Rule("energy", pulse.mean_power_to_energy),
|
Rule("energy", pulse.mean_power_to_energy),
|
||||||
Rule("t0", pulse.width_to_t0),
|
Rule("t0", pulse.width_to_t0),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import toml
|
|||||||
from scgenerator import defaults, utils, math
|
from scgenerator import defaults, utils, math
|
||||||
from scgenerator.errors import *
|
from scgenerator.errors import *
|
||||||
from scgenerator.physics import pulse, units
|
from scgenerator.physics import pulse, units
|
||||||
from scgenerator.utils.parameter import BareConfig, Parameters
|
from scgenerator.utils.parameter import Config, Parameters
|
||||||
|
|
||||||
|
|
||||||
def load_conf(name):
|
def load_conf(name):
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ def conf_maker(folder, val=True):
|
|||||||
if val:
|
if val:
|
||||||
return initialize.Config(**load_conf(folder + "/" + name))
|
return initialize.Config(**load_conf(folder + "/" + name))
|
||||||
else:
|
else:
|
||||||
return initialize.BareConfig(**load_conf(folder + "/" + name))
|
return initialize.Config(**load_conf(folder + "/" + name))
|
||||||
|
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user