Files
scgenerator/src/scgenerator/evaluator.py
Benoît Sierro 4b5563bf54 units refactor
2023-09-25 16:14:38 +02:00

614 lines
22 KiB
Python

from __future__ import annotations
import inspect
import traceback
from collections import ChainMap, defaultdict
from functools import cache
from inspect import Parameter as PARAM
from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union
import numpy as np
from scgenerator import io, math, operators, utils
from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import get_logger
class ErrorRecord(NamedTuple):
    """Snapshot of one failure recorded while a target was being computed."""

    # exception that was caught
    error: Exception
    # names of the targets that were being looked up when the error was recorded
    lookup_stack: tuple[str]
    # rules that were being applied when the error was recorded, outermost first
    rules_stack: tuple[Rule]
    # already formatted traceback text
    traceback: str

    def pretty_format(self) -> str:
        """
        builds a multi-line, human readable description of the error

        Returns
        -------
        str
            when rules were being applied: the qualified function name of every rule
            but the last one, the traceback, the formatted graph of the last rule and
            the error message, one per line. Otherwise only the traceback followed by
            the error message.
        """
        if not self.rules_stack:
            parts = [self.traceback, str(self.error)]
        else:
            *outer_rules, innermost_rule = self.rules_stack
            parts = [outer.func_name for outer in outer_rules]
            parts.append(self.traceback)
            parts.append(innermost_rule.pretty_format())
            parts.append(str(self.error))
        return "\n".join(parts)
class EvaluatorError(Exception):
    """
    Base class of every error raised while evaluating a target.

    `target` is the name of the parameter that could not be computed, or None when
    the error is not tied to one specific parameter.
    """

    target: str | None = None


class NoValidRuleError(EvaluatorError):
    """Raised when no rule is registered for the requested target."""

    def __init__(self, target: str):
        self.target = target
        super().__init__(f"no valid rule to compute {target!r}")


class CyclicDependencyError(EvaluatorError):
    """
    Raised when a target is requested while it is already being computed.

    Parameters
    ----------
    target : str
        name of the parameter that was requested a second time
    lookup_stack : list[str]
        names of the parameters being computed at that moment, outermost first
    """

    def __init__(self, target: str, lookup_stack: list[str]):
        self.target = target
        # show the chain of lookups and close it with the offending target so that the
        # loop is visible; joining with an empty string glued all the names together
        cycle = " -> ".join([*lookup_stack, target])
        super().__init__(f"cycle detected while computing {target!r}:\n{cycle}")


class AllRulesExhaustedError(EvaluatorError):
    """Raised when every rule of a target failed and no default value is available."""

    def __init__(self, target: str):
        self.target = target
        super().__init__(f"every rule for {target!r} failed and no default value is set")
class EvaluatorErrorTree:
    """Collects every error encountered while trying to compute one target."""

    # name of the target whose computation is being tracked
    target: str
    # errors in the order in which they were recorded
    all_errors: list[ErrorRecord]

    def __init__(self, target: str):
        self.target = target
        self.all_errors = []

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(target={self.target!r})"

    def __len__(self) -> int:
        """number of recorded errors"""
        return len(self.all_errors)

    def append(self, error: Exception, lookup_stack: list[str], rules_stack: list[Rule]):
        """
        records an error together with a snapshot of the current stacks

        The traceback is taken from the exception currently being handled, so this is
        meant to be called from inside an `except` block.

        Parameters
        ----------
        error : Exception
            the exception to record
        lookup_stack : list[str]
            names of the targets currently being looked up (copied)
        rules_stack : list[Rule]
            rules currently being applied (copied)
        """
        tr = traceback.format_exc().splitlines()
        # keep the header line and drop the two lines that follow it, i.e. the first
        # traceback entry
        tr = "\n".join(tr[:1] + tr[3:])
        self.all_errors.append(ErrorRecord(error, tuple(lookup_stack), tuple(rules_stack), tr))

    def compile_error(self) -> EvaluatorError:
        """
        builds one exception that summarizes every recorded error

        The exception is returned, not raised, so that the caller decides how to raise
        it (e.g. `raise tree.compile_error() from None`). Raising it from here would
        make a `from None` in the caller ineffective.

        Returns
        -------
        EvaluatorError
            error whose message lists all the recorded failures, separated by rulers
        """
        line = "\n" + "-" * 80 + "\n"
        failed_rules = line.join(rec.pretty_format() for rec in self.all_errors)
        return EvaluatorError(
            f"Couldn't compute {self.target}. {len(self)} rules failed.{line}{failed_rules}"
        )
class Rule:
targets: dict[str, int]
func: Callable
args: tuple[str]
conditions: dict[str, Any]
arg_defaults: dict[str, Any]
mock_func: Callable
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: tuple[str] | None = None,
priorities: Union[int, list[int]] | None = None,
conditions: dict[str, str] | None = None,
defaults: dict[str, Any] | None = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
if priorities is None:
priorities = [0] * len(targets)
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities] * len(targets)
self.targets = dict(zip(targets, priorities))
func_args, func_defaults = get_arg_names(func)
if args is not None:
try:
for func_arg, user_arg in zip(func_args, args, strict=True):
if func_arg in func_defaults:
func_defaults[user_arg] = func_defaults.pop(func_arg)
except ValueError as e:
raise ValueError(
f"length mismatch between arguments of {func.__name__}: "
f"{func_args!r} and provided ones: {args!r}"
) from e
func_args = args
func_defaults |= defaults or {}
self.args = tuple(func_args)
self.arg_defaults = func_defaults
self.mock_func = _mock_function(len(self.args), len(self.targets))
self.conditions = conditions or {}
def __repr__(self) -> str:
return f"Rule(targets={self.targets!r}, func={self.func_name}, args={self.args!r})"
def __str__(self) -> str:
return (
f"[{', '.join(self.args)}] -- {self.func.__module__}."
f"{self.func.__name__} --> [{', '.join(self.targets)}]"
)
def __eq__(self, other: Rule) -> bool:
return (
self.args == other.args
and tuple(self.targets) == tuple(other.targets)
and self.func == other.func
)
def __hash__(self) -> int:
return hash((self.args, self.func, tuple(self.targets)))
def pretty_format(self) -> str:
return io.format_graph(self.args, self.func_name, self.targets)
@property
def func_name(self) -> str:
return f"{self.func.__module__}.{self.func.__name__}"
class EvaluatedValue(NamedTuple):
    """A parameter value together with where it comes from."""

    # the value itself
    value: Any
    # when values are merged into `Evaluator.main_map`, a value only replaces an
    # existing one of strictly lower priority. Defaults to INF.
    priority: float = INF
    # rule that computed the value. `Evaluator.clear_computed` keeps the entries for
    # which this is None and refers to them as hard-set.
    rule: Rule | None = None
class Evaluator:
    """
    Computes parameters by recursively applying rules to already known values.

    Known values, whether set with `set` or computed by a rule, are kept in `main_map`.
    When a target is requested, its rules are tried from the highest to the lowest
    priority until one of them succeeds and has its conditions fulfilled.
    """

    # fallback values used when every rule of a target failed. This is a class
    # attribute, hence shared by all instances (see `register_default_param`).
    # NOTE(review): `_compute` reads `.value` on these entries, so they are presumably
    # expected to be `EvaluatedValue` instances — confirm against the callers of
    # `register_default_param`
    defaults: dict[str, Any] = {}
    # rules available for each target, highest priority first
    rules: dict[str, list[Rule]]
    # every value known so far
    main_map: dict[str, EvaluatedValue]
    lookup_stack: list[str]

    @classmethod
    def default(cls, full_field: bool = False) -> "Evaluator":
        """
        creates an evaluator loaded with the built-in set of rules

        Parameters
        ----------
        full_field : bool, optional
            if True, load `full_field_rules`, otherwise `envelope_rules`, by default False
        """
        evaluator = cls()
        logger = get_logger(__name__)
        if full_field:
            logger.debug("Full field simulation")
            evaluator.append(*full_field_rules)
        else:
            logger.debug("Envelope simulation")
            evaluator.append(*envelope_rules)
        return evaluator

    @classmethod
    def evaluate_default(
        cls, params: dict[str, Any], check_only=False
    ) -> dict[str, EvaluatedValue]:
        """
        computes every mandatory parameter with the built-in set of rules

        Parameters
        ----------
        params : dict[str, Any]
            already known parameters. The `full_field` entry, if present, selects the
            set of rules.
        check_only : bool, optional
            if True, call the mock function of each rule instead of the actual one,
            by default False

        Returns
        -------
        dict[str, EvaluatedValue]
            the `main_map` of the evaluator after computation
        """
        evaluator = cls.default(params.get("full_field", False))
        evaluator.set(**params)
        for target in MANDATORY_PARAMETERS:
            evaluator.compute(target, check_only=check_only)
        return evaluator.main_map

    @classmethod
    def register_default_param(cls, key, value):
        """registers a fallback value for `key`, shared by every evaluator"""
        cls.defaults[key] = value

    def __init__(self, *rules: Rule):
        self.rules = defaultdict(list)
        self.main_map = {}
        self.logger = get_logger(__name__)
        self.append(*rules)

    def __getitem__(self, key: str) -> Any:
        """returns the known value of `key` (no computation is triggered)"""
        return self.main_map[key].value

    def append(self, *rule: Rule):
        """
        registers rules under each of their targets, ignoring `None` targets

        The rules of each target are kept sorted by decreasing priority.
        """
        for r in rule:
            for t in r.targets:
                if t is None:
                    continue
                self.rules[t].append(r)
                self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)

    def set(self, **params: Any):
        """
        adds values to the internal set of parameters

        Parameters
        ----------
        params : Any
            values to store, by name. Each one is stored with priority `INF` and no
            associated rule, replacing any existing value of the same name.
        """
        for k, v in params.items():
            self.main_map[k] = EvaluatedValue(v, INF)

    def reset(self):
        """forgets every known value, whether set or computed"""
        self.main_map = {}

    def compute(self, target: str, check_only=False) -> Any:
        """
        computes a target

        Parameters
        ----------
        target : str
            name of the target
        check_only : bool, optional
            if True, call the mock function of each rule instead of the actual one,
            by default False

        Returns
        -------
        Any
            return type of the target function

        Raises
        ------
        EvaluatorError
            the target could not be computed (no rule for it, cyclic dependency or
            every rule failed). The message lists all errors recorded on the way.
        """
        errors = EvaluatorErrorTree(target)
        # computed values are stacked in front of main_map and only merged into it
        # once the whole computation succeeded
        param_chain_map = ChainMap(self.main_map)
        lookup_stack = []
        rules_stack = []
        try:
            value = self._compute(
                target, check_only, param_chain_map, lookup_stack, rules_stack, errors
            )
        except EvaluatorError as e:
            errors.append(e, lookup_stack, rules_stack)
            raise errors.compile_error() from None
        self.merge_chain_map(param_chain_map)
        return value

    def _compute(
        self,
        target: str,
        check_only: bool,
        param_chain_map: ChainMap[str, EvaluatedValue],
        lookup_stack: list[str],
        rules_stack: list[Rule],
        errors: EvaluatorErrorTree,
    ) -> Any:
        """
        recursive part of `compute`

        Parameters
        ----------
        target : str
            name of the target
        check_only : bool
            if True, call the mock function of each rule instead of the actual one
        param_chain_map : ChainMap[str, EvaluatedValue]
            known values. Newly computed ones are inserted as a new map at the front.
        lookup_stack : list[str]
            targets currently being computed, used to detect cyclic dependencies
        rules_stack : list[Rule]
            rules currently being applied, used for error reporting
        errors : EvaluatorErrorTree
            where the errors of failed rules are recorded

        Returns
        -------
        Any
            value of the target

        Raises
        ------
        NoValidRuleError
            no rule is registered for the target
        CyclicDependencyError
            the target is already being computed
        AllRulesExhaustedError
            no rule succeeded and there is no default value for the target
        """
        # a known value of None is treated as unknown
        if target in param_chain_map and param_chain_map[target].value is not None:
            return param_chain_map[target].value
        if target not in self.rules or len(self.rules[target]) == 0:
            raise NoValidRuleError(target)
        if target in lookup_stack:
            raise CyclicDependencyError(target, lookup_stack)

        lookup_stack.append(target)
        base_cm_length = len(param_chain_map.maps)
        # stays None unless a rule both succeeds and has its conditions fulfilled, so
        # that the output of a rule whose conditions are not met is never used
        values = None
        try:
            for rule in self.rules[target]:
                rules_stack.append(rule)
                try:
                    candidate = self.apply_rule(
                        rule, check_only, param_chain_map, lookup_stack, rules_stack, errors
                    )
                    if self.valid_values(
                        candidate,
                        rule,
                        check_only,
                        param_chain_map,
                        lookup_stack,
                        rules_stack,
                        errors,
                    ):
                        values = candidate
                        break
                finally:
                    # the rule is done, whatever the outcome
                    rules_stack.pop()
        finally:
            # leave the stack as it was found, even when an error propagates, otherwise
            # a later lookup of the same target is mistaken for a cyclic dependency
            lookup_stack.pop()

        if values is None:
            if target in self.defaults:
                values = {target: self.defaults[target]}
            else:
                # discard everything computed while trying the rules of this target
                self.clear_chain_map(param_chain_map, base_cm_length)
                raise AllRulesExhaustedError(target)
        param_chain_map.maps.insert(0, values)
        return values[target].value

    def valid_values(
        self,
        values: dict[str, EvaluatedValue] | None,
        rule: Rule,
        check_only: bool,
        param_chain_map: ChainMap[str, EvaluatedValue],
        lookup_stack: list[str],
        rules_stack: list[Rule],
        errors: EvaluatorErrorTree,
    ) -> bool:
        """
        tells whether the output of a rule can be accepted

        Returns
        -------
        bool
            False if the rule failed (`values` is None) or if one of its conditions is
            not fulfilled. Conditions are not verified when `check_only` is True.
        """
        if values is None:
            return False
        if check_only:
            return True
        for arg, target_value in rule.conditions.items():
            try:
                value = self._compute(
                    arg, False, param_chain_map, lookup_stack, rules_stack, errors
                )
            except EvaluatorError as e:
                # a condition that cannot be evaluated is not fulfilled (same choice as
                # in `validate_condition`); the next rule gets a chance
                errors.append(e, lookup_stack, rules_stack)
                return False
            if value != target_value:
                return False
        return True

    def apply_rule(
        self,
        rule: Rule,
        check_only: bool,
        param_chain_map: ChainMap[str, EvaluatedValue],
        lookup_stack: list[str],
        rules_stack: list[Rule],
        errors: EvaluatorErrorTree,
    ) -> dict[str, EvaluatedValue] | None:
        """
        computes the arguments of a rule and calls its function

        Returns
        -------
        dict[str, EvaluatedValue] | None
            the computed value of each target of the rule, or None if an argument
            without default could not be computed or if the function raised. In both
            cases, the error is recorded in `errors`.
        """
        arg_values = []
        for arg in rule.args:
            try:
                val = self._compute(
                    arg, check_only, param_chain_map, lookup_stack, rules_stack, errors
                )
            except Exception as e:
                if arg in rule.arg_defaults:
                    val = rule.arg_defaults[arg]
                else:
                    errors.append(e, lookup_stack, rules_stack)
                    return None
            arg_values.append(val)

        func = rule.mock_func if check_only else rule.func
        try:
            values = func(*arg_values)
        except Exception as e:
            errors.append(e, lookup_stack, rules_stack)
            return None

        # NOTE(review): a tuple is always paired element by element with the targets,
        # so a single-target rule whose function returns a tuple only keeps the first
        # element — confirm this is intended
        if not isinstance(values, tuple):
            values = (values,)
        return {k: EvaluatedValue(v, rule.targets[k], rule) for k, v in zip(rule.targets, values)}

    def merge_chain_map(self, param_chain_map: ChainMap[str, EvaluatedValue]):
        """
        moves the computed values of `param_chain_map` into `main_map`

        Maps are removed from the front until a single one is left. A value is stored
        only if its priority is strictly higher than the one already in `main_map`
        (`-INF` if there is none).
        """
        while len(param_chain_map.maps) > 1:
            params = param_chain_map.maps.pop(0)
            for k, v in params.items():
                target_priority = self.main_map[k].priority if k in self.main_map else -INF
                if v.priority > target_priority:
                    self.main_map[k] = v

    def clear_chain_map(self, param_chain_map: ChainMap[str, EvaluatedValue], target_size: int):
        """removes maps from the front of `param_chain_map` until `target_size` are left"""
        while len(param_chain_map.maps) > target_size:
            param_chain_map.maps.pop(0)

    def clear_computed(self):
        """deletes all computed values to leave only hard-set ones"""
        self.main_map = {k: v for k, v in self.main_map.items() if v.rule is None}

    def validate_condition(self, rule: Rule) -> bool:
        """
        tells whether every condition of `rule` is fulfilled

        Conditions are computed if needed. A condition that cannot be computed counts
        as not fulfilled.
        """
        try:
            return all(self.compute(k) == v for k, v in rule.conditions.items())
        except EvaluatorError:
            return False
def get_arg_names(func: Callable) -> tuple[list[str], dict[str, Any]]:
    """
    inspects the signature of func and lists the arguments that can be passed by position.

    Keyword-only arguments that have a default value are left out.

    Parameters
    ----------
    func : Callable
        callable whose signature is inspected

    Returns
    -------
    list[str]
        names of the arguments, in the order of the signature
    dict[str, Any]
        default value of each listed argument that has one

    Raises
    ------
    TypeError
        func has a variadic argument (`*args` or `**kwargs`) or a keyword-only
        argument without default value
    """
    names: list[str] = []
    default_values: dict[str, Any] = {}
    for param in inspect.signature(func).parameters.values():
        kind = param.kind
        if kind is PARAM.VAR_KEYWORD or kind is PARAM.VAR_POSITIONAL:
            raise TypeError(
                f"function {func.__name__} has variadic argument {param.name!r}, which is not allowed"
            )
        has_default = param.default is not PARAM.empty
        if kind is PARAM.KEYWORD_ONLY:
            if has_default:
                # func falls back on its own default for this one
                continue
            raise TypeError(
                f"function {func.__name__} has keyword-only argument {param.name!r} "
                "with no default. This is not supported at the moment."
            )
        names.append(param.name)
        if has_default:
            default_values[param.name] = param.default
    return names, default_values
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
    """
    creates a placeholder function with a given number of arguments and returned values

    The function is generated from source code, so that it has a real signature with
    `num_args` positional arguments (named `a`, `aa`, `aaa`, ...). It returns `True`
    `num_returns` times: a single `True` when `num_returns` is 1, a tuple of them when
    it is larger. Results are cached: the same function object is returned for
    identical arguments.
    """
    func_name = f"__mock_{num_args}_{num_returns}"
    parameters = ", ".join("a" * length for length in range(1, num_args + 1))
    returned = ", ".join(["True"] * num_returns)
    namespace = {}
    # the source only contains generated identifiers and literals
    exec(f"def {func_name}({parameters}):\n return {returned}", namespace)
    mock = namespace[func_name]
    mock.__module__ = "evaluator"
    return mock
# Rules common to both kinds of simulation: `envelope_rules` and `full_field_rules` are
# both built by extending this list.
# Several rules may produce the same target. `Evaluator.append` sorts the rules of a
# target by decreasing priority (0 when `priorities` is not given) and they are tried
# in that order.
default_rules: list[Rule] = [
    # Grid
    Rule(["t", "time_window", "dt", "t_num"], math.build_t_grid),
    Rule("z_targets", math.build_z_grid),
    Rule("adapt_step_size", lambda step_size: step_size == 0),
    Rule(
        "dynamic_dispersion",
        lambda pressure: isinstance(pressure, (list, tuple, np.ndarray)),
    ),
    Rule("w0", units.m_rads, ["wavelength"]),
    Rule("l", units.m_rads, ["w"]),
    Rule("w0_ind", math.argclosest, ["w_for_disp", "w0"]),
    Rule("w_num", len, ["w"]),
    Rule("dw", lambda w: w[1] - w[0]),
    Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
    # NOTE(review): this function returns a 2-tuple for a single target, while
    # `Evaluator.apply_rule` pairs the elements of a returned tuple with the targets,
    # so only the first element would be stored — confirm
    Rule("wavelength_window", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
    # Pulse
    Rule("field_0", pulse.finalize_pulse),
    Rule(["input_time", "input_field"], pulse.load_custom_field),
    Rule("spec_0", lambda fft, field_0: fft(field_0)),
    Rule("field_0", lambda ifft, spec_0: ifft(spec_0)),
    Rule(
        ["pre_field_0", "peak_power", "energy", "width"],
        pulse.adjust_custom_field,
        priorities=[2, 1, 1, 1],
    ),
    Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
    Rule("peak_power", pulse.soliton_num_to_peak_power),
    Rule("mean_power", pulse.energy_to_mean_power),
    Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
    Rule("energy", pulse.mean_power_to_energy, priorities=2),
    Rule("t0", pulse.width_to_t0),
    Rule("t0", pulse.soliton_num_to_t0),
    Rule("width", pulse.t0_to_width),
    Rule("soliton_num", pulse.soliton_num),
    Rule("dispersion_length", pulse.dispersion_length),
    Rule("nonlinear_length", pulse.nonlinear_length),
    Rule("soliton_length", pulse.soliton_length),
    Rule("c_to_a_factor", lambda: 1, priorities=-1),
    # Fiber Dispersion
    Rule("w_for_disp", units.m_rads, ["wl_for_disp"]),
    Rule("gas_info", materials.Gas),
    Rule("chi_gas", lambda gas_info, wl_for_disp: gas_info.sellmeier.chi(wl_for_disp)),
    Rule("n_gas_2", materials.n_gas_2),
    Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
    Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
    Rule(
        "n_eff",
        fiber.n_eff_marcatili_adjusted,
        conditions=dict(model="marcatili_adjusted"),
    ),
    Rule("n_eff", fiber.n_eff_pcf, conditions=dict(model="pcf")),
    Rule("n0", lambda w0_ind, n_eff: n_eff[w0_ind]),
    Rule("capillary_spacing", fiber.capillary_spacing_hasan),
    Rule("capillary_resonance_strengths", fiber.capillary_resonance_strengths),
    Rule("capillary_resonance_strengths", lambda: [], priorities=-1),
    Rule("beta_arr", fiber.beta),
    Rule("beta1_arr", fiber.beta1),
    Rule("beta2_arr", fiber.beta2),
    Rule(
        "zero_dispersion_wavelength",
        lambda beta2_arr, wl_for_disp: wl_for_disp[math.argclosest(beta2_arr, 0)],
    ),
    # Fiber nonlinearity
    Rule("effective_area", fiber.effective_area_pcf),
    Rule("effective_area", fiber.effective_area_from_V, priorities=-1),
    Rule("effective_area", fiber.effective_area_from_diam),
    Rule("effective_area", fiber.effective_area_hasan, conditions=dict(model="hasan")),
    Rule("effective_area", fiber.effective_area_from_gamma, priorities=-1),
    Rule("effective_area", fiber.effective_area_marcatili, priorities=-2),
    Rule("effective_area_arr", fiber.effective_area_from_V, ["core_radius", "V_eff_arr"]),
    Rule("effective_area_arr", fiber.load_custom_effective_area),
    Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
    Rule("V_eff_arr", fiber.V_eff_step_index),
    Rule("V_eff", lambda V_eff_arr: V_eff_arr[0]),
    Rule("n2", materials.gas_n2),
    Rule("n2", lambda: 2.2e-20, priorities=-1),
    Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
    Rule("gamma", fiber.gamma_parameter),
    Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
    # Raman
    Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
    Rule("raman_fraction", fiber.raman_fraction),
    Rule("raman_fraction", lambda: 0, priorities=-1),
    # Loss
    Rule("alpha_arr", fiber.scalar_loss),
    Rule("alpha_arr", fiber.safe_capillary_loss, conditions=dict(loss="capillary")),
    # Operators
    Rule("n_eff_op", operators.marcatili_refractive_index),
    Rule("n_eff_op", operators.marcatili_adjusted_refractive_index),
    Rule("n_eff_op", operators.vincetti_refractive_index),
    Rule("gas_op", operators.ConstantGas),
    Rule("gas_op", operators.PressureGradientGas),
    Rule("square_index", lambda gas_op: gas_op.square_index),
    Rule("number_density", lambda gas_op: gas_op.number_density),
    Rule("n2_op", lambda gas_op: gas_op.n2),
    Rule("chi3_op", lambda gas_op: gas_op.chi3),
    Rule("loss_op", operators.constant_quantity, ["alpha_arr"]),
    Rule("loss_op", lambda: operators.constant_quantity(0), priorities=-1),
]
# Rules loaded by `Evaluator.default` when `full_field` is False: the common rules
# followed by the ones below. The `Rule` objects coming from `default_rules` are shared
# with `full_field_rules`, not copied.
envelope_rules = default_rules + [
    # Grid
    Rule(["w_c", "w", "w_order"], math.build_envelope_w_grid),
    Rule("dt", math._dt_from_wl_window),
    # Pulse
    Rule("spectrum_factor", pulse.spectrum_factor_envelope, priorities=-1),
    Rule("pre_field_0", pulse.initial_field_envelope, priorities=1),
    Rule("c_to_a_factor", pulse.c_to_a_factor),
    # Dispersion
    Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_envelope_dispersion),
    Rule("beta2_coefficients", fiber.dispersion_coefficients),
    Rule("beta2_arr", fiber.dispersion_from_coefficients),
    Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
    Rule(
        ["wl_for_disp", "beta2_arr", "wavelength_window"],
        fiber.load_custom_dispersion,
        priorities=[2, 2, 2],
    ),
    # Operators
    Rule("gamma_op", operators.variable_gamma, priorities=2),
    Rule("gamma_op", operators.constant_quantity, ["gamma_arr"], priorities=1),
    Rule("gamma_op", lambda w_num, gamma: operators.constant_quantity(np.ones(w_num) * gamma)),
    Rule("gamma_op", lambda: operators.constant_quantity(0.0), priorities=-1),
    Rule("ss_op", lambda w_c, w0: operators.constant_quantity(w_c / w0)),
    Rule("ss_op", lambda: operators.constant_quantity(0), priorities=-1),
    Rule("spm_op", operators.envelope_spm, conditions=dict(spm=True)),
    Rule("spm_op", operators.no_op_freq, priorities=-1),
    Rule("raman_op", operators.envelope_raman),
    Rule("raman_op", operators.no_op_freq, priorities=-1),
    Rule("nonlinear_operator", operators.envelope_nonlinear_operator),
    Rule("dispersion_op", operators.constant_polynomial_dispersion),
    Rule("dispersion_op", operators.constant_direct_dispersion),
    Rule("dispersion_op", operators.direct_dispersion),
    Rule("linear_operator", operators.envelope_linear_operator),
    Rule("conserved_quantity", operators.conserved_quantity),
]
# Rules loaded by `Evaluator.default` when `full_field` is True: the common rules
# followed by the ones below. The `Rule` objects coming from `default_rules` are shared
# with `envelope_rules`, not copied.
full_field_rules = default_rules + [
    # Grid
    Rule(["w", "w_order", "l"], math.build_full_field_w_grid, priorities=1),
    # Pulse
    Rule("pre_field_0", pulse.initial_full_field),
    Rule("spectrum_factor", pulse.spectrum_factor_fullfield, priorities=-1),
    # Dispersion
    Rule(["wl_for_disp", "dispersion_ind"], fiber.lambda_for_full_field_dispersion),
    Rule("frame_velocity", fiber.frame_velocity),
    Rule("beta2", lambda beta2_arr, w0_ind: beta2_arr[w0_ind]),
    # Nonlinearity
    Rule("chi3", materials.gas_chi3),
    Rule("plasma_obj", lambda dt, gas_info: plasma.Plasma(dt, gas_info.ionization_energy)),
    # Operators
    Rule("spm_op", operators.full_field_spm),
    Rule("spm_op", operators.no_op_freq, priorities=-1),
    Rule("beta_op", operators.constant_wave_vector),
    Rule("linear_operator", operators.full_field_linear_operator),
    Rule("plasma_op", operators.ionization, conditions=dict(photoionization=True)),
    Rule("plasma_op", operators.no_op_freq, priorities=-1),
    Rule("raman_op", operators.no_op_freq, priorities=-1),
    Rule("nonlinear_operator", operators.full_field_nonlinear_operator),
]