Still debugging new system

This commit is contained in:
Benoît Sierro
2021-11-16 08:32:33 +01:00
parent 38b3d063e5
commit 46102a03d9
5 changed files with 103 additions and 49 deletions

View File

@@ -5,7 +5,6 @@ from typing import Any, Callable, Optional, Union
import numpy as np
from scgenerator import solver
from . import math, operators, utils
from .const import MANDATORY_PARAMETERS
@@ -378,17 +377,7 @@ default_rules: list[Rule] = [
Rule("n_op", operators.HasanRefractiveIndex),
Rule("raman_op", operators.NoRaman, priorities=-1),
Rule("loss_op", operators.NoLoss, priorities=-1),
# solvers
Rule("integrator", solver.ConstantStepIntegrator, conditions=dict(adapt_step_size=False)),
Rule(
"integrator",
solver.ConservedQuantityIntegrator,
conditions=dict(adapt_step_size=True),
priorities=1,
),
Rule("integrator", solver.RK4IPSD, conditions=dict(adapt_step_size=True)),
Rule("integrator", solver.ERK43, conditions=dict(adapt_step_size=True)),
Rule("integrator", solver.ERK54, conditions=dict(adapt_step_size=True), priorities=1),
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
]
envelope_rules = default_rules + [

View File

@@ -20,7 +20,7 @@ from .evaluator import Evaluator
from .logger import get_logger
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
from .solver import Integrator
from .utils import fiber_folder, update_path_name
from .utils import DebugDict, fiber_folder, update_path_name
from .variationer import VariationDescriptor, Variationer
T = TypeVar("T")
@@ -287,7 +287,7 @@ class Parameters:
"""
# internal machinery
_param_dico: dict[str, Any] = field(init=False, default_factory=dict, repr=False)
_param_dico: dict[str, Any] = field(init=False, default_factory=DebugDict, repr=False)
_evaluator: Evaluator = field(init=False, repr=False)
_p_names: ClassVar[Set[str]] = set()
@@ -346,7 +346,6 @@ class Parameters:
mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW"))
energy: float = Parameter(positive(float, int), display_info=(1e6, "μJ"))
soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean, default=False)
additional_noise_factor: float = Parameter(positive(float, int), default=1)
shape: str = Parameter(literal("gaussian", "sech"), default="gaussian")
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
@@ -355,12 +354,19 @@ class Parameters:
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
# Behaviors to include
quantum_noise: bool = Parameter(boolean, default=False)
self_steepening: bool = Parameter(boolean, default=True)
ideal_gas: bool = Parameter(boolean, default=False)
photoionization: bool = Parameter(boolean, default=False)
# simulation
full_field: bool = Parameter(boolean, default=False)
integration_scheme: str = Parameter(
literal("erk43", "erk54", "cqe", "sd", "constant"), converter=str.lower, default="erk43"
)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
self_steepening: bool = Parameter(boolean, default=True)
spm: bool = Parameter(boolean, default=True)
ideal_gas: bool = Parameter(boolean, default=False)
repeat: int = Parameter(positive(int), default=1)
t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int))

View File

@@ -163,17 +163,14 @@ class RK4IP:
store = False
state = self.init_state.copy()
yield len(self.stored_spectra) - 1, state
if self.params.adapt_step_size:
integrator = solver.RK4IPSD(
state,
self.params.linear_operator,
self.params.nonlinear_operator,
self.params.tolerated_error,
)
else:
integrator = solver.ConstantStepIntegrator(
state, self.params.linear_operator, self.params.nonlinear_operator
)
integrator_args = [
self.params.compute(a) for a in solver.Integrator.factory_args() if a != "init_state"
]
integrator = solver.Integrator.create(
self.params.integration_scheme, state, *integrator_args
)
for state in integrator:
new_tracked_values = integrator.all_values()

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict
import logging
from abc import abstractmethod
@@ -15,6 +16,39 @@ from .operators import (
NonLinearOperator,
ValueTracker,
)
from .utils import get_arg_names
import warnings
warnings.filterwarnings("error")
class IntegratorFactory:
arg_registry: dict[str, dict[str, int]]
cls_registry: dict[str, Type[Integrator]]
all_arg_names: list[str]
def __init__(self):
self.arg_registry = defaultdict(dict)
self.cls_registry = {}
def register(self, name: str, cls: Type[Integrator]):
self.cls_registry[name] = cls
arg_names = [a for a in get_arg_names(cls.__init__)[1:] if a not in {"init_state", "self"}]
for i, a_name in enumerate(arg_names):
self.arg_registry[a_name][name] = i
self.all_arg_names = list(self.arg_registry.keys())
def create(self, scheme: str, state: CurrentState, *args) -> Integrator:
cls = self.cls_registry[scheme]
kwargs = dict(
init_state=state,
**{
a: args[self.arg_registry[a][scheme]]
for a in self.all_arg_names
if scheme in self.arg_registry[a]
},
)
return cls(**kwargs)
class Integrator(ValueTracker):
@@ -24,7 +58,8 @@ class Integrator(ValueTracker):
tolerated_error: float
_tracked_values: dict[str, float]
logger: logging.Logger
__registry: dict[str, Type[Integrator]] = {}
__factory: IntegratorFactory = IntegratorFactory()
steps_rejected = 0
def __init__(
self,
@@ -40,12 +75,16 @@ class Integrator(ValueTracker):
self._tracked_values = {}
self.logger = get_logger(self.__class__.__name__)
def __init_subclass__(cls):
cls.__registry[cls.__name__] = cls
def __init_subclass__(cls, scheme=""):
cls.__factory.register(scheme, cls)
@classmethod
def get(cls, integr: str) -> Type[Integrator]:
return cls.__registry[integr]
def create(cls, name: str, state: CurrentState, *args) -> Integrator:
return cls.__factory.create(name, state, *args)
@classmethod
def factory_args(cls) -> list[str]:
return cls.__factory.all_arg_names
@abstractmethod
def __iter__(self) -> Iterator[CurrentState]:
@@ -64,10 +103,15 @@ class Integrator(ValueTracker):
dict[str, float]
tracked values
"""
return self.values() | self._tracked_values | dict(z=self.state.z, step=self.state.step)
return (
self.values()
| self._tracked_values
| dict(z=self.state.z, step=self.state.step, steps_rejected=self.steps_rejected)
)
def record_tracked_values(self):
self._tracked_values = super().all_values()
self.steps_rejected = 0
def nl(self, spectrum: np.ndarray) -> np.ndarray:
return self.nonlinear_operator(self.state.replace(spectrum))
@@ -83,7 +127,7 @@ class Integrator(ValueTracker):
return self.state
class ConstantStepIntegrator(Integrator):
class ConstantStepIntegrator(Integrator, scheme="constant"):
def __init__(
self,
init_state: CurrentState,
@@ -112,7 +156,7 @@ class ConstantStepIntegrator(Integrator):
)
class ConservedQuantityIntegrator(Integrator):
class ConservedQuantityIntegrator(Integrator, scheme="cqe"):
last_qty: float
conserved_quantity: AbstractConservedQuantity
current_error: float = 0.0
@@ -172,7 +216,7 @@ class ConservedQuantityIntegrator(Integrator):
return dict(cons_qty=self.last_qty, relative_error=self.current_error)
class RK4IPSD(Integrator):
class RK4IPSD(Integrator, scheme="sd"):
"""Runge-Kutta 4 in Interaction Picture with step doubling"""
next_h_factor: float = 1.0
@@ -197,18 +241,26 @@ class RK4IPSD(Integrator):
)
new_coarse = self.take_step(h, self.state.spectrum, lin, nonlin)
self.current_error = compute_diff(new_coarse, new_fine)
if self.current_error > 2 * self.tolerated_error:
h_next_step = h * 0.5
elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
h_next_step = h / size_fac
break
elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
h_next_step = h
break
if self.current_error > 0.0:
next_h_factor = max(
0.5, min(2.0, (self.tolerated_error / self.current_error) ** 0.25)
)
else:
h_next_step = h * size_fac
next_h_factor = 2.0
h_next_step = next_h_factor * h
if self.current_error <= 2 * self.tolerated_error:
break
# if self.current_error > 2 * self.tolerated_error:
# h_next_step = h * 0.5
# elif self.tolerated_error <= self.current_error <= 2 * self.tolerated_error:
# h_next_step = h / size_fac
# break
# elif 0.5 * self.tolerated_error <= self.current_error < self.tolerated_error:
# h_next_step = h
# break
# else:
# h_next_step = h * size_fac
# break
self.state.spectrum = new_fine
yield self.accept_step(self.state, h, h_next_step)
@@ -225,7 +277,7 @@ class RK4IPSD(Integrator):
)
class ERK43(RK4IPSD):
class ERK43(RK4IPSD, scheme="erk43"):
def __iter__(self) -> Iterator[CurrentState]:
h_next_step = self.state.current_step_size
k5 = self.nonlinear_operator(self.state)
@@ -259,6 +311,7 @@ class ERK43(RK4IPSD):
if self.current_error <= 2 * self.tolerated_error:
break
h_next_step = min(0.9, next_h_factor) * h
self.steps_rejected += 1
self.logger.info(
f"step {self.state.step} rejected : {h=}, {self.current_error=}, {h_next_step=}"
)
@@ -268,7 +321,7 @@ class ERK43(RK4IPSD):
yield self.accept_step(self.state, h, h_next_step)
class ERK54(RK4IPSD):
class ERK54(RK4IPSD, scheme="erk54"):
def __iter__(self) -> Iterator[CurrentState]:
self.logger.info("using ERK54")
h_next_step = self.state.current_step_size
@@ -362,3 +415,7 @@ def compute_diff(coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
diff = coarse_spec - fine_spec
diff2 = diff.imag ** 2 + diff.real ** 2
return np.sqrt(diff2.sum() / (fine_spec.real ** 2 + fine_spec.imag ** 2).sum())
def get_integrator(integration_scheme: str):
return Integrator.get(integration_scheme)()

View File

@@ -28,6 +28,11 @@ from .errors import DuplicateParameterError
T_ = TypeVar("T_")
class DebugDict(dict):
def __setitem__(self, k, v) -> None:
return super().__setitem__(k, v)
class Paths:
_data_files = [
"materials.toml",