diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index ccccd6e..ec312c2 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Optional, Union import numpy as np -from scgenerator import solver from . import math, operators, utils from .const import MANDATORY_PARAMETERS @@ -378,17 +377,7 @@ default_rules: list[Rule] = [ Rule("n_op", operators.HasanRefractiveIndex), Rule("raman_op", operators.NoRaman, priorities=-1), Rule("loss_op", operators.NoLoss, priorities=-1), - # solvers - Rule("integrator", solver.ConstantStepIntegrator, conditions=dict(adapt_step_size=False)), - Rule( - "integrator", - solver.ConservedQuantityIntegrator, - conditions=dict(adapt_step_size=True), - priorities=1, - ), - Rule("integrator", solver.RK4IPSD, conditions=dict(adapt_step_size=True)), - Rule("integrator", solver.ERK43, conditions=dict(adapt_step_size=True)), - Rule("integrator", solver.ERK54, conditions=dict(adapt_step_size=True), priorities=1), + Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1), ] envelope_rules = default_rules + [ diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index c1c8815..7328ef7 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -20,7 +20,7 @@ from .evaluator import Evaluator from .logger import get_logger from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator from .solver import Integrator -from .utils import fiber_folder, update_path_name +from .utils import DebugDict, fiber_folder, update_path_name from .variationer import VariationDescriptor, Variationer T = TypeVar("T") @@ -287,7 +287,7 @@ class Parameters: """ # internal machinery - _param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False) + _param_dico: dict[str, Any] = field(init=False, default_factory=DebugDict, repr=False) _evaluator: Evaluator = field(init=False, repr=False) _p_names: ClassVar[Set[str]] = set() @@ -346,7 +346,6 @@ class Parameters: mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ")) soliton_num: float = Parameter(non_negative(float, int)) - quantum_noise: bool = Parameter(boolean, default=False) additional_noise_factor: float = Parameter(positive(float, int), default=1) shape: str = Parameter(literal("gaussian", "sech"), default="gaussian") wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) @@ -355,12 +354,19 @@ class Parameters: width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) + # Behaviors to include + quantum_noise: bool = Parameter(boolean, default=False) + self_steepening: bool = Parameter(boolean, default=True) + ideal_gas: bool = Parameter(boolean, default=False) + photoionization: bool = Parameter(boolean, default=False) + # simulation full_field: bool = Parameter(boolean, default=False) + integration_scheme: str = Parameter( + literal("erk43", "erk54", "cqe", "sd", "constant"), converter=str.lower, default="erk43" + ) raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower) - self_steepening: bool = Parameter(boolean, default=True) spm: bool = Parameter(boolean, default=True) - ideal_gas: bool = Parameter(boolean, default=False) repeat: int = Parameter(positive(int), default=1) t_num: int = Parameter(positive(int)) z_num: int = Parameter(positive(int)) diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 8f99cdf..52d2f37 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -163,17 +163,14 @@ class RK4IP: store = False state = self.init_state.copy() yield len(self.stored_spectra) - 1, state - if self.params.adapt_step_size: - integrator = solver.RK4IPSD( - state, - self.params.linear_operator, - self.params.nonlinear_operator, - self.params.tolerated_error, - ) - else: - integrator = solver.ConstantStepIntegrator( - state, self.params.linear_operator, self.params.nonlinear_operator - ) + + integrator_args = [ + self.params.compute(a) for a in solver.Integrator.factory_args() if a != "init_state" + ] + integrator = solver.Integrator.create( + self.params.integration_scheme, state, *integrator_args + ) + for state in integrator: new_tracked_values = integrator.all_values() diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index beba2dc..cea643e 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import logging from abc import abstractmethod @@ -15,6 +16,39 @@ from .operators import ( NonLinearOperator, ValueTracker, ) +from .utils import get_arg_names +import warnings + +warnings.filterwarnings("error") + + +class IntegratorFactory: + arg_registry: dict[str, dict[str, int]] + cls_registry: dict[str, Type[Integrator]] + all_arg_names: list[str] + + def __init__(self): + self.arg_registry = defaultdict(dict) + self.cls_registry = {} + + def register(self, name: str, cls: Type[Integrator]): + self.cls_registry[name] = cls + arg_names = [a for a in get_arg_names(cls.__init__)[1:] if a not in {"init_state", "self"}] + for i, a_name in enumerate(arg_names): + self.arg_registry[a_name][name] = i + self.all_arg_names = list(self.arg_registry.keys()) + + def create(self, scheme: str, state: CurrentState, *args) -> Integrator: + cls = self.cls_registry[scheme] + kwargs = dict( + init_state=state, + **{ + a: args[self.arg_registry[a][scheme]] + for a in self.all_arg_names + if scheme in self.arg_registry[a] + }, + ) + return cls(**kwargs) class Integrator(ValueTracker): @@ -24,7 +58,8 @@ class Integrator(ValueTracker): tolerated_error: float _tracked_values: dict[str, float] logger: logging.Logger - __registry: dict[str, Type[Integrator]] = {} + __factory: IntegratorFactory = IntegratorFactory() + steps_rejected = 0 def __init__( self, @@ -40,12 +75,16 @@ class Integrator(ValueTracker): self._tracked_values = {} self.logger = get_logger(self.__class__.__name__) - def __init_subclass__(cls): - cls.__registry[cls.__name__] = cls + def __init_subclass__(cls, scheme=""): + cls.__factory.register(scheme, cls) @classmethod - def get(cls, integr: str) -> Type[Integrator]: - return cls.__registry[integr] + def create(cls, name: str, state: CurrentState, *args) -> Integrator: + return cls.__factory.create(name, state, *args) + + @classmethod + def factory_args(cls) -> list[str]: + return cls.__factory.all_arg_names @abstractmethod def __iter__(self) -> Iterator[CurrentState]: @@ -64,10 +103,15 @@ class Integrator(ValueTracker): dict[str, float] tracked values """ - return self.values() | self._tracked_values | dict(z=self.state.z, step=self.state.step) + return ( + self.values() + | self._tracked_values + | dict(z=self.state.z, step=self.state.step, steps_rejected=self.steps_rejected) + ) def record_tracked_values(self): self._tracked_values = super().all_values() + self.steps_rejected = 0 def nl(self, spectrum: np.ndarray) -> np.ndarray: return self.nonlinear_operator(self.state.replace(spectrum)) @@ -83,7 +127,7 @@ class Integrator(ValueTracker): return self.state -class ConstantStepIntegrator(Integrator): +class ConstantStepIntegrator(Integrator, scheme="constant"): def __init__( self, init_state: CurrentState, @@ -112,7 +156,7 @@ class ConstantStepIntegrator(Integrator): ) -class ConservedQuantityIntegrator(Integrator): +class ConservedQuantityIntegrator(Integrator, scheme="cqe"): last_qty: float conserved_quantity: AbstractConservedQuantity current_error: float = 0.0 @@ -172,7 +216,7 @@ class ConservedQuantityIntegrator(Integrator): return dict(cons_qty=self.last_qty, relative_error=self.current_error) -class RK4IPSD(Integrator): +class RK4IPSD(Integrator, scheme="sd"): """Runge-Kutta 4 in Interaction Picture with step doubling""" next_h_factor: float = 1.0 @@ -197,18 +241,26 @@ class RK4IPSD(Integrator): ) new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin) self.current_error = compute_diff(new_coarse, new_fine) - - if self.current_error > 2 * self.tolerated_error: - h_next_step = h * 0.5 - elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error: - h_next_step = h / size_fac - break - elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error: - h_next_step = h - break + if self.current_error > 0.0: + next_h_factor = max( + 0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25) + ) else: - h_next_step = h * size_fac + next_h_factor = 2.0 + h_next_step = next_h_factor * h + if self.current_error <= 2 * self.tolerated_error: break + # if self.current_error > 2 * self.tolerated_error: + # h_next_step = h * 0.5 + # elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error: + # h_next_step = h / size_fac + # break + # elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error: + # h_next_step = h + # break + # else: + # h_next_step = h * size_fac + # break self.state.spectrum = new_fine yield self.accept_step(self.state, h, h_next_step) @@ -225,7 +277,7 @@ class RK4IPSD(Integrator): ) -class ERK43(RK4IPSD): +class ERK43(RK4IPSD, scheme="erk43"): def __iter__(self) -> Iterator[CurrentState]: h_next_step = self.state.current_step_size k5 = self.nonlinear_operator(self.state) @@ -259,6 +311,7 @@ class ERK43(RK4IPSD): if self.current_error <= 2 * self.tolerated_error: break h_next_step = min(0.9, next_h_factor) * h + self.steps_rejected += 1 self.logger.info( f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}" ) @@ -268,7 +321,7 @@ class ERK43(RK4IPSD): yield self.accept_step(self.state, h, h_next_step) -class ERK54(RK4IPSD): +class ERK54(RK4IPSD, scheme="erk54"): def __iter__(self) -> Iterator[CurrentState]: self.logger.info("using ERK54") h_next_step = self.state.current_step_size @@ -362,3 +415,7 @@ def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float: diff = coarse_spec - fine_spec diff2 = diff.imag ** 2 + diff.real ** 2 return np.sqrt(diff2.sum() / (fine_spec.real ** 2 + fine_spec.imag ** 2).sum()) + + +def get_integrator(integration_scheme: str): + return Integrator.get(integration_scheme)() diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 1c7d38e..221edbd 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -28,6 +28,11 @@ from .errors import DuplicateParameterError T_ = TypeVar("T_") +class DebugDict(dict): + def __setitem__(self, k, v) -> None: + return super().__setitem__(k, v) + + class Paths: _data_files = [ "materials.toml",