Still debugging new system
This commit is contained in:
@@ -5,7 +5,6 @@ from typing import Any, Callable, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator import solver
|
|
||||||
|
|
||||||
from . import math, operators, utils
|
from . import math, operators, utils
|
||||||
from .const import MANDATORY_PARAMETERS
|
from .const import MANDATORY_PARAMETERS
|
||||||
@@ -378,17 +377,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule("n_op", operators.HasanRefractiveIndex),
|
Rule("n_op", operators.HasanRefractiveIndex),
|
||||||
Rule("raman_op", operators.NoRaman, priorities=-1),
|
Rule("raman_op", operators.NoRaman, priorities=-1),
|
||||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||||
# solvers
|
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
|
||||||
Rule("integrator", solver.ConstantStepIntegrator, conditions=dict(adapt_step_size=False)),
|
|
||||||
Rule(
|
|
||||||
"integrator",
|
|
||||||
solver.ConservedQuantityIntegrator,
|
|
||||||
conditions=dict(adapt_step_size=True),
|
|
||||||
priorities=1,
|
|
||||||
),
|
|
||||||
Rule("integrator", solver.RK4IPSD, conditions=dict(adapt_step_size=True)),
|
|
||||||
Rule("integrator", solver.ERK43, conditions=dict(adapt_step_size=True)),
|
|
||||||
Rule("integrator", solver.ERK54, conditions=dict(adapt_step_size=True), priorities=1),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
envelope_rules = default_rules + [
|
envelope_rules = default_rules + [
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from .evaluator import Evaluator
|
|||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
|
||||||
from .solver import Integrator
|
from .solver import Integrator
|
||||||
from .utils import fiber_folder, update_path_name
|
from .utils import DebugDict, fiber_folder, update_path_name
|
||||||
from .variationer import VariationDescriptor, Variationer
|
from .variationer import VariationDescriptor, Variationer
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -287,7 +287,7 @@ class Parameters:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# internal machinery
|
# internal machinery
|
||||||
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
|
_param_dico: dict[str, Any] = field(init=False, default_factory=DebugDict, repr=False)
|
||||||
_evaluator: Evaluator = field(init=False, repr=False)
|
_evaluator: Evaluator = field(init=False, repr=False)
|
||||||
_p_names: ClassVar[Set[str]] = set()
|
_p_names: ClassVar[Set[str]] = set()
|
||||||
|
|
||||||
@@ -346,7 +346,6 @@ class Parameters:
|
|||||||
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
|
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
|
||||||
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
|
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
|
||||||
soliton_num: float = Parameter(non_negative(float, int))
|
soliton_num: float = Parameter(non_negative(float, int))
|
||||||
quantum_noise: bool = Parameter(boolean, default=False)
|
|
||||||
additional_noise_factor: float = Parameter(positive(float, int), default=1)
|
additional_noise_factor: float = Parameter(positive(float, int), default=1)
|
||||||
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
|
||||||
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
|
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
|
||||||
@@ -355,12 +354,19 @@ class Parameters:
|
|||||||
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||||
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||||
|
|
||||||
|
# Behaviors to include
|
||||||
|
quantum_noise: bool = Parameter(boolean, default=False)
|
||||||
|
self_steepening: bool = Parameter(boolean, default=True)
|
||||||
|
ideal_gas: bool = Parameter(boolean, default=False)
|
||||||
|
photoionization: bool = Parameter(boolean, default=False)
|
||||||
|
|
||||||
# simulation
|
# simulation
|
||||||
full_field: bool = Parameter(boolean, default=False)
|
full_field: bool = Parameter(boolean, default=False)
|
||||||
|
integration_scheme: str = Parameter(
|
||||||
|
literal("erk43", "erk54", "cqe", "sd", "constant"), converter=str.lower, default="erk43"
|
||||||
|
)
|
||||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||||
self_steepening: bool = Parameter(boolean, default=True)
|
|
||||||
spm: bool = Parameter(boolean, default=True)
|
spm: bool = Parameter(boolean, default=True)
|
||||||
ideal_gas: bool = Parameter(boolean, default=False)
|
|
||||||
repeat: int = Parameter(positive(int), default=1)
|
repeat: int = Parameter(positive(int), default=1)
|
||||||
t_num: int = Parameter(positive(int))
|
t_num: int = Parameter(positive(int))
|
||||||
z_num: int = Parameter(positive(int))
|
z_num: int = Parameter(positive(int))
|
||||||
|
|||||||
@@ -163,17 +163,14 @@ class RK4IP:
|
|||||||
store = False
|
store = False
|
||||||
state = self.init_state.copy()
|
state = self.init_state.copy()
|
||||||
yield len(self.stored_spectra) - 1, state
|
yield len(self.stored_spectra) - 1, state
|
||||||
if self.params.adapt_step_size:
|
|
||||||
integrator = solver.RK4IPSD(
|
integrator_args = [
|
||||||
state,
|
self.params.compute(a) for a in solver.Integrator.factory_args() if a != "init_state"
|
||||||
self.params.linear_operator,
|
]
|
||||||
self.params.nonlinear_operator,
|
integrator = solver.Integrator.create(
|
||||||
self.params.tolerated_error,
|
self.params.integration_scheme, state, *integrator_args
|
||||||
)
|
|
||||||
else:
|
|
||||||
integrator = solver.ConstantStepIntegrator(
|
|
||||||
state, self.params.linear_operator, self.params.nonlinear_operator
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for state in integrator:
|
for state in integrator:
|
||||||
|
|
||||||
new_tracked_values = integrator.all_values()
|
new_tracked_values = integrator.all_values()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
@@ -15,6 +16,39 @@ from .operators import (
|
|||||||
NonLinearOperator,
|
NonLinearOperator,
|
||||||
ValueTracker,
|
ValueTracker,
|
||||||
)
|
)
|
||||||
|
from .utils import get_arg_names
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|
||||||
|
class IntegratorFactory:
|
||||||
|
arg_registry: dict[str, dict[str, int]]
|
||||||
|
cls_registry: dict[str, Type[Integrator]]
|
||||||
|
all_arg_names: list[str]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.arg_registry = defaultdict(dict)
|
||||||
|
self.cls_registry = {}
|
||||||
|
|
||||||
|
def register(self, name: str, cls: Type[Integrator]):
|
||||||
|
self.cls_registry[name] = cls
|
||||||
|
arg_names = [a for a in get_arg_names(cls.__init__)[1:] if a not in {"init_state", "self"}]
|
||||||
|
for i, a_name in enumerate(arg_names):
|
||||||
|
self.arg_registry[a_name][name] = i
|
||||||
|
self.all_arg_names = list(self.arg_registry.keys())
|
||||||
|
|
||||||
|
def create(self, scheme: str, state: CurrentState, *args) -> Integrator:
|
||||||
|
cls = self.cls_registry[scheme]
|
||||||
|
kwargs = dict(
|
||||||
|
init_state=state,
|
||||||
|
**{
|
||||||
|
a: args[self.arg_registry[a][scheme]]
|
||||||
|
for a in self.all_arg_names
|
||||||
|
if scheme in self.arg_registry[a]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Integrator(ValueTracker):
|
class Integrator(ValueTracker):
|
||||||
@@ -24,7 +58,8 @@ class Integrator(ValueTracker):
|
|||||||
tolerated_error: float
|
tolerated_error: float
|
||||||
_tracked_values: dict[str, float]
|
_tracked_values: dict[str, float]
|
||||||
logger: logging.Logger
|
logger: logging.Logger
|
||||||
__registry: dict[str, Type[Integrator]] = {}
|
__factory: IntegratorFactory = IntegratorFactory()
|
||||||
|
steps_rejected = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -40,12 +75,16 @@ class Integrator(ValueTracker):
|
|||||||
self._tracked_values = {}
|
self._tracked_values = {}
|
||||||
self.logger = get_logger(self.__class__.__name__)
|
self.logger = get_logger(self.__class__.__name__)
|
||||||
|
|
||||||
def __init_subclass__(cls):
|
def __init_subclass__(cls, scheme=""):
|
||||||
cls.__registry[cls.__name__] = cls
|
cls.__factory.register(scheme, cls)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, integr: str) -> Type[Integrator]:
|
def create(cls, name: str, state: CurrentState, *args) -> Integrator:
|
||||||
return cls.__registry[integr]
|
return cls.__factory.create(name, state, *args)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def factory_args(cls) -> list[str]:
|
||||||
|
return cls.__factory.all_arg_names
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __iter__(self) -> Iterator[CurrentState]:
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
@@ -64,10 +103,15 @@ class Integrator(ValueTracker):
|
|||||||
dict[str, float]
|
dict[str, float]
|
||||||
tracked values
|
tracked values
|
||||||
"""
|
"""
|
||||||
return self.values() | self._tracked_values | dict(z=self.state.z, step=self.state.step)
|
return (
|
||||||
|
self.values()
|
||||||
|
| self._tracked_values
|
||||||
|
| dict(z=self.state.z, step=self.state.step, steps_rejected=self.steps_rejected)
|
||||||
|
)
|
||||||
|
|
||||||
def record_tracked_values(self):
|
def record_tracked_values(self):
|
||||||
self._tracked_values = super().all_values()
|
self._tracked_values = super().all_values()
|
||||||
|
self.steps_rejected = 0
|
||||||
|
|
||||||
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
def nl(self, spectrum: np.ndarray) -> np.ndarray:
|
||||||
return self.nonlinear_operator(self.state.replace(spectrum))
|
return self.nonlinear_operator(self.state.replace(spectrum))
|
||||||
@@ -83,7 +127,7 @@ class Integrator(ValueTracker):
|
|||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
|
|
||||||
class ConstantStepIntegrator(Integrator):
|
class ConstantStepIntegrator(Integrator, scheme="constant"):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
init_state: CurrentState,
|
init_state: CurrentState,
|
||||||
@@ -112,7 +156,7 @@ class ConstantStepIntegrator(Integrator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConservedQuantityIntegrator(Integrator):
|
class ConservedQuantityIntegrator(Integrator, scheme="cqe"):
|
||||||
last_qty: float
|
last_qty: float
|
||||||
conserved_quantity: AbstractConservedQuantity
|
conserved_quantity: AbstractConservedQuantity
|
||||||
current_error: float = 0.0
|
current_error: float = 0.0
|
||||||
@@ -172,7 +216,7 @@ class ConservedQuantityIntegrator(Integrator):
|
|||||||
return dict(cons_qty=self.last_qty, relative_error=self.current_error)
|
return dict(cons_qty=self.last_qty, relative_error=self.current_error)
|
||||||
|
|
||||||
|
|
||||||
class RK4IPSD(Integrator):
|
class RK4IPSD(Integrator, scheme="sd"):
|
||||||
"""Runge-Kutta 4 in Interaction Picture with step doubling"""
|
"""Runge-Kutta 4 in Interaction Picture with step doubling"""
|
||||||
|
|
||||||
next_h_factor: float = 1.0
|
next_h_factor: float = 1.0
|
||||||
@@ -197,18 +241,26 @@ class RK4IPSD(Integrator):
|
|||||||
)
|
)
|
||||||
new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
|
new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
|
||||||
self.current_error = compute_diff(new_coarse, new_fine)
|
self.current_error = compute_diff(new_coarse, new_fine)
|
||||||
|
if self.current_error > 0.0:
|
||||||
if self.current_error > 2 * self.tolerated_error:
|
next_h_factor = max(
|
||||||
h_next_step = h * 0.5
|
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
|
||||||
elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
|
)
|
||||||
h_next_step = h / size_fac
|
|
||||||
break
|
|
||||||
elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
|
|
||||||
h_next_step = h
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
h_next_step = h * size_fac
|
next_h_factor = 2.0
|
||||||
|
h_next_step = next_h_factor * h
|
||||||
|
if self.current_error <= 2 * self.tolerated_error:
|
||||||
break
|
break
|
||||||
|
# if self.current_error > 2 * self.tolerated_error:
|
||||||
|
# h_next_step = h * 0.5
|
||||||
|
# elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
|
||||||
|
# h_next_step = h / size_fac
|
||||||
|
# break
|
||||||
|
# elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
|
||||||
|
# h_next_step = h
|
||||||
|
# break
|
||||||
|
# else:
|
||||||
|
# h_next_step = h * size_fac
|
||||||
|
# break
|
||||||
|
|
||||||
self.state.spectrum = new_fine
|
self.state.spectrum = new_fine
|
||||||
yield self.accept_step(self.state, h, h_next_step)
|
yield self.accept_step(self.state, h, h_next_step)
|
||||||
@@ -225,7 +277,7 @@ class RK4IPSD(Integrator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ERK43(RK4IPSD):
|
class ERK43(RK4IPSD, scheme="erk43"):
|
||||||
def __iter__(self) -> Iterator[CurrentState]:
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
h_next_step = self.state.current_step_size
|
h_next_step = self.state.current_step_size
|
||||||
k5 = self.nonlinear_operator(self.state)
|
k5 = self.nonlinear_operator(self.state)
|
||||||
@@ -259,6 +311,7 @@ class ERK43(RK4IPSD):
|
|||||||
if self.current_error <= 2 * self.tolerated_error:
|
if self.current_error <= 2 * self.tolerated_error:
|
||||||
break
|
break
|
||||||
h_next_step = min(0.9, next_h_factor) * h
|
h_next_step = min(0.9, next_h_factor) * h
|
||||||
|
self.steps_rejected += 1
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}"
|
f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}"
|
||||||
)
|
)
|
||||||
@@ -268,7 +321,7 @@ class ERK43(RK4IPSD):
|
|||||||
yield self.accept_step(self.state, h, h_next_step)
|
yield self.accept_step(self.state, h, h_next_step)
|
||||||
|
|
||||||
|
|
||||||
class ERK54(RK4IPSD):
|
class ERK54(RK4IPSD, scheme="erk54"):
|
||||||
def __iter__(self) -> Iterator[CurrentState]:
|
def __iter__(self) -> Iterator[CurrentState]:
|
||||||
self.logger.info("using ERK54")
|
self.logger.info("using ERK54")
|
||||||
h_next_step = self.state.current_step_size
|
h_next_step = self.state.current_step_size
|
||||||
@@ -362,3 +415,7 @@ def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
|
|||||||
diff = coarse_spec - fine_spec
|
diff = coarse_spec - fine_spec
|
||||||
diff2 = diff.imag ** 2 + diff.real ** 2
|
diff2 = diff.imag ** 2 + diff.real ** 2
|
||||||
return np.sqrt(diff2.sum() / (fine_spec.real ** 2 + fine_spec.imag ** 2).sum())
|
return np.sqrt(diff2.sum() / (fine_spec.real ** 2 + fine_spec.imag ** 2).sum())
|
||||||
|
|
||||||
|
|
||||||
|
def get_integrator(integration_scheme: str):
|
||||||
|
return Integrator.get(integration_scheme)()
|
||||||
|
|||||||
@@ -28,6 +28,11 @@ from .errors import DuplicateParameterError
|
|||||||
T_ = TypeVar("T_")
|
T_ = TypeVar("T_")
|
||||||
|
|
||||||
|
|
||||||
|
class DebugDict(dict):
|
||||||
|
def __setitem__(self, k, v) -> None:
|
||||||
|
return super().__setitem__(k, v)
|
||||||
|
|
||||||
|
|
||||||
class Paths:
|
class Paths:
|
||||||
_data_files = [
|
_data_files = [
|
||||||
"materials.toml",
|
"materials.toml",
|
||||||
|
|||||||
Reference in New Issue
Block a user