Still debugging new system
This commit is contained in:
@@ -5,7 +5,6 @@ from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import solver
|
||||
|
||||
from . import math, operators, utils
|
||||
from .const import MANDATORY_PARAMETERS
|
||||
@@ -378,17 +377,7 @@ default_rules: list[Rule] = [
|
||||
Rule("n_op", operators.HasanRefractiveIndex),
|
||||
Rule("raman_op", operators.NoRaman, priorities=-1),
|
||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||
# solvers
|
||||
Rule("integrator", solver.ConstantStepIntegrator, conditions=dict(adapt_step_size=False)),
|
||||
Rule(
|
||||
"integrator",
|
||||
solver.ConservedQuantityIntegrator,
|
||||
conditions=dict(adapt_step_size=True),
|
||||
priorities=1,
|
||||
),
|
||||
Rule("integrator", solver.RK4IPSD, conditions=dict(adapt_step_size=True)),
|
||||
Rule("integrator", solver.ERK43, conditions=dict(adapt_step_size=True)),
|
||||
Rule("integrator", solver.ERK54, conditions=dict(adapt_step_size=True), priorities=1),
|
||||
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
|
||||
]
|
||||
|
||||
envelope_rules = default_rules + [
|
||||
|
||||
@@ -20,7 +20,7 @@ from .evaluator import Evaluator
|
||||
from .logger import get_logger
|
||||
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
||||
from .solver import Integrator
|
||||
from .utils import fiber_folder, update_path_name
|
||||
from .utils import DebugDict, fiber_folder, update_path_name
|
||||
from .variationer import VariationDescriptor, Variationer
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -287,7 +287,7 @@ class Parameters:
|
||||
"""
|
||||
|
||||
# internal machinery
|
||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
||||
_param_dico: dict[str, Any] = field(init=False, default_factory=DebugDict, repr=False)
|
||||
_evaluator: Evaluator = field(init=False, repr=False)
|
||||
_p_names: ClassVar[Set[str]] = set()
|
||||
|
||||
@@ -346,7 +346,6 @@ class Parameters:
|
||||
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
|
||||
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
|
||||
soliton_num: float = Parameter(non_negative(float, int))
|
||||
quantum_noise: bool = Parameter(boolean, default=False)
|
||||
additional_noise_factor: float = Parameter(positive(float, int), default=1)
|
||||
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
||||
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
|
||||
@@ -355,12 +354,19 @@ class Parameters:
|
||||
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
|
||||
# Behaviors to include
|
||||
quantum_noise: bool = Parameter(boolean, default=False)
|
||||
self_steepening: bool = Parameter(boolean, default=True)
|
||||
ideal_gas: bool = Parameter(boolean, default=False)
|
||||
photoionization: bool = Parameter(boolean, default=False)
|
||||
|
||||
# simulation
|
||||
full_field: bool = Parameter(boolean, default=False)
|
||||
integration_scheme: str = Parameter(
|
||||
literal("erk43", "erk54", "cqe", "sd", "constant"), converter=str.lower, default="erk43"
|
||||
)
|
||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||
self_steepening: bool = Parameter(boolean, default=True)
|
||||
spm: bool = Parameter(boolean, default=True)
|
||||
ideal_gas: bool = Parameter(boolean, default=False)
|
||||
repeat: int = Parameter(positive(int), default=1)
|
||||
t_num: int = Parameter(positive(int))
|
||||
z_num: int = Parameter(positive(int))
|
||||
|
||||
@@ -163,17 +163,14 @@ class RK4IP:
|
||||
store = False
|
||||
state = self.init_state.copy()
|
||||
yield len(self.stored_spectra) - 1, state
|
||||
if self.params.adapt_step_size:
|
||||
integrator = solver.RK4IPSD(
|
||||
state,
|
||||
self.params.linear_operator,
|
||||
self.params.nonlinear_operator,
|
||||
self.params.tolerated_error,
|
||||
)
|
||||
else:
|
||||
integrator = solver.ConstantStepIntegrator(
|
||||
state, self.params.linear_operator, self.params.nonlinear_operator
|
||||
|
||||
integrator_args = [
|
||||
self.params.compute(a) for a in solver.Integrator.factory_args() if a != "init_state"
|
||||
]
|
||||
integrator = solver.Integrator.create(
|
||||
self.params.integration_scheme, state, *integrator_args
|
||||
)
|
||||
|
||||
for state in integrator:
|
||||
|
||||
new_tracked_values = integrator.all_values()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
@@ -15,6 +16,39 @@ from .operators import (
|
||||
NonLinearOperator,
|
||||
ValueTracker,
|
||||
)
|
||||
from .utils import get_arg_names
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
|
||||
class IntegratorFactory:
|
||||
arg_registry: dict[str, dict[str, int]]
|
||||
cls_registry: dict[str, Type[Integrator]]
|
||||
all_arg_names: list[str]
|
||||
|
||||
def __init__(self):
|
||||
self.arg_registry = defaultdict(dict)
|
||||
self.cls_registry = {}
|
||||
|
||||
def register(self, name: str, cls: Type[Integrator]):
|
||||
self.cls_registry[name] = cls
|
||||
arg_names = [a for a in get_arg_names(cls.__init__)[1:] if a not in {"init_state", "self"}]
|
||||
for i, a_name in enumerate(arg_names):
|
||||
self.arg_registry[a_name][name] = i
|
||||
self.all_arg_names = list(self.arg_registry.keys())
|
||||
|
||||
def create(self, scheme: str, state: CurrentState, *args) -> Integrator:
|
||||
cls = self.cls_registry[scheme]
|
||||
kwargs = dict(
|
||||
init_state=state,
|
||||
**{
|
||||
a: args[self.arg_registry[a][scheme]]
|
||||
for a in self.all_arg_names
|
||||
if scheme in self.arg_registry[a]
|
||||
},
|
||||
)
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class Integrator(ValueTracker):
|
||||
@@ -24,7 +58,8 @@ class Integrator(ValueTracker):
|
||||
tolerated_error: float
|
||||
_tracked_values: dict[str, float]
|
||||
logger: logging.Logger
|
||||
__registry: dict[str, Type[Integrator]] = {}
|
||||
__factory: IntegratorFactory = IntegratorFactory()
|
||||
steps_rejected = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -40,12 +75,16 @@ class Integrator(ValueTracker):
|
||||
self._tracked_values = {}
|
||||
self.logger = get_logger(self.__class__.__name__)
|
||||
|
||||
def __init_subclass__(cls):
|
||||
cls.__registry[cls.__name__] = cls
|
||||
def __init_subclass__(cls, scheme=""):
|
||||
cls.__factory.register(scheme, cls)
|
||||
|
||||
@classmethod
|
||||
def get(cls, integr: str) -> Type[Integrator]:
|
||||
return cls.__registry[integr]
|
||||
def create(cls, name: str, state: CurrentState, *args) -> Integrator:
|
||||
return cls.__factory.create(name, state, *args)
|
||||
|
||||
@classmethod
|
||||
def factory_args(cls) -> list[str]:
|
||||
return cls.__factory.all_arg_names
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
@@ -64,10 +103,15 @@ class Integrator(ValueTracker):
|
||||
dict[str, float]
|
||||
tracked values
|
||||
"""
|
||||
return self.values() | self._tracked_values | dict(z=self.state.z, step=self.state.step)
|
||||
return (
|
||||
self.values()
|
||||
| self._tracked_values
|
||||
| dict(z=self.state.z, step=self.state.step, steps_rejected=self.steps_rejected)
|
||||
)
|
||||
|
||||
def record_tracked_values(self):
|
||||
self._tracked_values = super().all_values()
|
||||
self.steps_rejected = 0
|
||||
|
||||
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||
@@ -83,7 +127,7 @@ class Integrator(ValueTracker):
|
||||
return self.state
|
||||
|
||||
|
||||
class ConstantStepIntegrator(Integrator):
|
||||
class ConstantStepIntegrator(Integrator, scheme="constant"):
|
||||
def __init__(
|
||||
self,
|
||||
init_state: CurrentState,
|
||||
@@ -112,7 +156,7 @@ class ConstantStepIntegrator(Integrator):
|
||||
)
|
||||
|
||||
|
||||
class ConservedQuantityIntegrator(Integrator):
|
||||
class ConservedQuantityIntegrator(Integrator, scheme="cqe"):
|
||||
last_qty: float
|
||||
conserved_quantity: AbstractConservedQuantity
|
||||
current_error: float = 0.0
|
||||
@@ -172,7 +216,7 @@ class ConservedQuantityIntegrator(Integrator):
|
||||
return dict(cons_qty=self.last_qty, relative_error=self.current_error)
|
||||
|
||||
|
||||
class RK4IPSD(Integrator):
|
||||
class RK4IPSD(Integrator, scheme="sd"):
|
||||
"""Runge-Kutta 4 in Interaction Picture with step doubling"""
|
||||
|
||||
next_h_factor: float = 1.0
|
||||
@@ -197,18 +241,26 @@ class RK4IPSD(Integrator):
|
||||
)
|
||||
new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
|
||||
self.current_error = compute_diff(new_coarse, new_fine)
|
||||
|
||||
if self.current_error > 2 * self.tolerated_error:
|
||||
h_next_step = h * 0.5
|
||||
elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
|
||||
h_next_step = h / size_fac
|
||||
break
|
||||
elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
|
||||
h_next_step = h
|
||||
break
|
||||
if self.current_error > 0.0:
|
||||
next_h_factor = max(
|
||||
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||
)
|
||||
else:
|
||||
h_next_step = h * size_fac
|
||||
next_h_factor = 2.0
|
||||
h_next_step = next_h_factor * h
|
||||
if self.current_error <= 2 * self.tolerated_error:
|
||||
break
|
||||
# if self.current_error > 2 * self.tolerated_error:
|
||||
# h_next_step = h * 0.5
|
||||
# elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
|
||||
# h_next_step = h / size_fac
|
||||
# break
|
||||
# elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
|
||||
# h_next_step = h
|
||||
# break
|
||||
# else:
|
||||
# h_next_step = h * size_fac
|
||||
# break
|
||||
|
||||
self.state.spectrum = new_fine
|
||||
yield self.accept_step(self.state, h, h_next_step)
|
||||
@@ -225,7 +277,7 @@ class RK4IPSD(Integrator):
|
||||
)
|
||||
|
||||
|
||||
class ERK43(RK4IPSD):
|
||||
class ERK43(RK4IPSD, scheme="erk43"):
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
h_next_step = self.state.current_step_size
|
||||
k5 = self.nonlinear_operator(self.state)
|
||||
@@ -259,6 +311,7 @@ class ERK43(RK4IPSD):
|
||||
if self.current_error <= 2 * self.tolerated_error:
|
||||
break
|
||||
h_next_step = min(0.9, next_h_factor) * h
|
||||
self.steps_rejected += 1
|
||||
self.logger.info(
|
||||
f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}"
|
||||
)
|
||||
@@ -268,7 +321,7 @@ class ERK43(RK4IPSD):
|
||||
yield self.accept_step(self.state, h, h_next_step)
|
||||
|
||||
|
||||
class ERK54(RK4IPSD):
|
||||
class ERK54(RK4IPSD, scheme="erk54"):
|
||||
def __iter__(self) -> Iterator[CurrentState]:
|
||||
self.logger.info("using ERK54")
|
||||
h_next_step = self.state.current_step_size
|
||||
@@ -362,3 +415,7 @@ def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
|
||||
diff = coarse_spec - fine_spec
|
||||
diff2 = diff.imag ** 2 + diff.real ** 2
|
||||
return np.sqrt(diff2.sum() / (fine_spec.real ** 2 + fine_spec.imag ** 2).sum())
|
||||
|
||||
|
||||
def get_integrator(integration_scheme: str):
|
||||
return Integrator.get(integration_scheme)()
|
||||
|
||||
@@ -28,6 +28,11 @@ from .errors import DuplicateParameterError
|
||||
T_ = TypeVar("T_")
|
||||
|
||||
|
||||
class DebugDict(dict):
|
||||
def __setitem__(self, k, v) -> None:
|
||||
return super().__setitem__(k, v)
|
||||
|
||||
|
||||
class Paths:
|
||||
_data_files = [
|
||||
"materials.toml",
|
||||
|
||||
Reference in New Issue
Block a user