Removed inheritence; cons_qty operators

This commit is contained in:
Benoît Sierro
2021-10-19 17:15:02 +02:00
parent ecb5ee681a
commit de12b0d5c1
5 changed files with 148 additions and 164 deletions

View File

@@ -1,13 +1,15 @@
from typing import Optional, Callable, Union, Any
from dataclasses import dataclass
from .physics import fiber, pulse, materials, units
from .utils import _mock_function, get_arg_names, get_logger, func_rewrite
from .errors import *
from collections import defaultdict
from .const import MANDATORY_PARAMETERS
import numpy as np
import itertools import itertools
from . import math, utils, operators from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import numpy as np
from . import math, operators, utils
from .const import MANDATORY_PARAMETERS
from .errors import *
from .physics import fiber, materials, pulse, units
from .utils import _mock_function, func_rewrite, get_arg_names, get_logger
class Rule: class Rule:
@@ -378,6 +380,7 @@ default_rules: list[Rule] = [
Rule("loss_op", operators.NoLoss, priorities=-1), Rule("loss_op", operators.NoLoss, priorities=-1),
Rule("disp_op", operators.ConstantPolyDispersion), Rule("disp_op", operators.ConstantPolyDispersion),
Rule("linear_operator", operators.LinearOperator), Rule("linear_operator", operators.LinearOperator),
Rule("conserved_quantity", operators.ConservedQuantity),
# gas # gas
Rule("n_gas_2", materials.n_gas_2), Rule("n_gas_2", materials.n_gas_2),
] ]

View File

@@ -6,12 +6,16 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps
from os import stat
from typing import Callable
import numpy as np import numpy as np
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
from .physics import fiber
from . import math from . import math
from .logger import get_logger
from .physics import fiber, pulse
class SpectrumDescriptor: class SpectrumDescriptor:
@@ -385,3 +389,78 @@ class CustomConstantLoss(ConstantLoss):
wl = loss_data["wavelength"] wl = loss_data["wavelength"]
loss = loss_data["loss"] loss = loss_data["loss"]
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l) self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
##################################################
############### CONSERVED QUANTITY ###############
##################################################
class ConservedQuantity(Operator):
def __new__(
raman_op: AbstractGamma, gamma_op: AbstractGamma, loss_op: AbstractLoss, w: np.ndarray
):
logger = get_logger(__name__)
raman = not isinstance(raman_op, NoRaman)
loss = not isinstance(raman_op, NoLoss)
if raman and loss:
logger.debug("Conserved quantity : photon number with loss")
return PhotonNumberLoss(w, gamma_op, loss_op)
elif raman:
logger.debug("Conserved quantity : photon number without loss")
return PhotonNumberNoLoss(w, gamma_op)
elif loss:
logger.debug("Conserved quantity : energy with loss")
return EnergyLoss(w, loss_op)
else:
logger.debug("Conserved quantity : energy without loss")
return EnergyNoLoss(w)
@abstractmethod
def __call__(self, state: CurrentState) -> float:
pass
class NoConservedQuantity(ConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return 0.0
class PhotonNumberLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma, loss_op=AbstractLoss):
self.w = w
self.dw = w[1] - w[0]
self.gamma_op = gamma_op
self.loss_op = loss_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss(
state.spectrum, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h
)
class PhotonNumberNoLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma):
self.w = w
self.dw = w[1] - w[0]
self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.spectrum, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, loss_op: AbstractLoss):
self.dw = w[1] - w[0]
self.loss_op = loss_op
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss(state.spectrum, self.dw, self.loss_op(state), state.h)
class EnergyNoLoss(ConservedQuantity):
def __init__(self, w: np.ndarray):
self.dw = w[1] - w[0]
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy(state.spectrum, self.dw)

View File

@@ -2,20 +2,17 @@ from __future__ import annotations
import datetime as datetime_module import datetime as datetime_module
import enum import enum
import itertools
import os import os
import time import time
from collections import defaultdict
from copy import copy from copy import copy
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
import numpy as np import numpy as np
from numpy.lib import isin
from . import env, math, utils from . import env, utils
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
from .logger import get_logger from .logger import get_logger
from .utils import fiber_folder, update_path_name from .utils import fiber_folder, update_path_name
@@ -210,6 +207,7 @@ class Parameter:
self.name = name self.name = name
if self.default is not None: if self.default is not None:
Evaluator.register_default_param(self.name, self.default) Evaluator.register_default_param(self.name, self.default)
VariationDescriptor.register_formatter(self.name, self.display)
def __get__(self, instance, owner): def __get__(self, instance, owner):
if instance is None: if instance is None:
@@ -242,20 +240,7 @@ class Parameter:
@dataclass @dataclass
class _AbstractParameters: class Parameters:
@classmethod
def __init_subclass__(cls):
cls.register_param_formatters()
@classmethod
def register_param_formatters(cls):
for k, v in cls.__dict__.items():
if isinstance(v, Parameter):
VariationDescriptor.register_formatter(k, v.display)
@dataclass
class Parameters(_AbstractParameters):
""" """
This class defines each valid parameter's name, type and valid value. This class defines each valid parameter's name, type and valid value.
""" """

View File

@@ -2,7 +2,6 @@ import multiprocessing
import multiprocessing.connection import multiprocessing.connection
import os import os
import random import random
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Generator, Type, Union from typing import Any, Generator, Type, Union
@@ -13,9 +12,7 @@ from .. import utils
from ..logger import get_logger from ..logger import get_logger
from ..parameter import Configuration, Parameters from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker from ..pbar import PBars, ProgressBarActor, progress_worker
from ..operators import CurrentState from ..operators import CurrentState, ConservedQuantity, NoConservedQuantity
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
try: try:
import ray import ray
@@ -70,10 +67,6 @@ class RK4IP:
self.dw = self.params.w[1] - self.params.w[0] self.dw = self.params.w[1] - self.params.w[0]
self.z_targets = self.params.z_targets self.z_targets = self.params.z_targets
self.beta2_coefficients = (
params.beta_func if params.beta_func is not None else params.beta2_coefficients
)
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr
self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
self.error_ok = ( self.error_ok = (
params.tolerated_error if self.params.adapt_step_size else self.params.step_size params.tolerated_error if self.params.adapt_step_size else self.params.step_size
@@ -83,55 +76,18 @@ class RK4IP:
self._setup_sim_parameters() self._setup_sim_parameters()
def _setup_functions(self): def _setup_functions(self):
self.N_func = create_non_linear_op(
self.params.behaviors,
self.params.w_c,
self.params.w0,
self.gamma,
self.params.raman_type,
hr_w=self.params.hr_w,
)
if self.params.dynamic_dispersion:
self.disp = lambda r: fast_dispersion_op(
self.params.w_c,
self.beta2_coefficients(r),
self.params.w_power_fact,
alpha=self.params.alpha_arr,
)
else:
self.disp = lambda r: fast_dispersion_op(
self.params.w_c,
self.beta2_coefficients,
self.params.w_power_fact,
alpha=self.params.alpha_arr,
)
# Set up which quantity is conserved for adaptive step size # Set up which quantity is conserved for adaptive step size
if self.params.adapt_step_size: if self.params.adapt_step_size:
if "raman" in self.params.behaviors and self.params.alpha_arr is not None: self.conserved_quantity_func = ConservedQuantity(
self.logger.debug("Conserved quantity : photon number with loss") self.params.nonlinear_operator.raman_op,
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss( self.params.nonlinear_operator.gamma_op,
spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h self.params.linear_operator.loss_op,
) self.params.w,
elif "raman" in self.params.behaviors: )
self.logger.debug("Conserved quantity : photon number without loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number(
spectrum, self.params.w, self.dw, self.gamma
)
elif self.params.alpha_arr is not None:
self.logger.debug("Conserved quantity : energy with loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss(
self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h
)
else:
self.logger.debug("Conserved quantity : energy without loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy(
self.C_to_A_factor * spectrum, self.dw
)
else: else:
self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}") self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}")
self.conserved_quantity_func = lambda spectrum, h: 0.0 self.conserved_quantity_func = NoConservedQuantity()
def _setup_sim_parameters(self): def _setup_sim_parameters(self):
# making sure to keep only the z that we want # making sure to keep only the z that we want
@@ -140,27 +96,27 @@ class RK4IP:
self.z_targets.sort() self.z_targets.sort()
self.store_num = len(self.z_targets) self.store_num = len(self.z_targets)
# Initial setup of simulation parameters # Initial step size
self.z = self.z_targets.pop(0) if self.params.adapt_step_size:
initial_h = (self.z_targets[0] - self.z) / 2
else:
initial_h = self.error_ok
# Setup initial values for every physical quantity that we want to track # Setup initial values for every physical quantity that we want to track
self.state = CurrentState( self.state = CurrentState(
length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor length=self.params.length,
z=self.z_targets.pop(0),
h=initial_h,
spectrum=self.params.spec_0.copy() / self.C_to_A_factor,
) )
self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy() self.state.spectrum.copy()
] ]
self.cons_qty = [ self.cons_qty = [
self.conserved_quantity_func(self.state.spectrum, 0), self.conserved_quantity_func(self.state),
0, 0,
] ]
self.size_fac = 2 ** (1 / 5) self.size_fac = 2 ** (1 / 5)
# Initial step size
if self.params.adapt_step_size:
self.initial_h = (self.z_targets[0] - self.z) / 2
else:
self.initial_h = self.error_ok
def _save_current_spectrum(self, num: int): def _save_current_spectrum(self, num: int):
"""saves the spectrum and the corresponding cons_qty array """saves the spectrum and the corresponding cons_qty array

107
tests.py
View File

@@ -1,84 +1,45 @@
import numpy as np from __future__ import annotations
import scgenerator as sc from collections import defaultdict
import matplotlib.pyplot as plt
def convert(l, beta2): class Parameter:
return l[2:-2] * 1e9, sc.units.beta2_fs_cm.inv(beta2[2:-2]) registered_params = defaultdict(dict)
def __init__(self, default_value, display_suffix=""):
self.value = default_value
self.display_suffix = display_suffix
def __set_name__(self, owner, name):
self.name = name
self.registered_params[owner.__name__][name] = self
def __get__(self, instance, owner):
return self.value
def __set__(self, instance, value):
self.value = value
def display(self):
return str(self.value) + " " + self.display_suffix
def test_empty_marcatili(): class A:
l = np.linspace(250, 1200, 500) * 1e-9 x = Parameter("lol")
beta2 = sc.fiber.HCPCF_dispersion(l, 15e-6) y = Parameter(56.2)
plt.plot(*convert(l, beta2))
plt.show()
def test_empty_hasan_no_resonance(): class B:
l = np.linspace(250, 1200, 500) * 1e-9 x = Parameter(slice(None))
beta2 = sc.fiber.HCPCF_dispersion( opt = None
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
)
plt.plot(*convert(l, beta2))
plt.show()
def test_empty_hasan(): def main():
l = np.linspace(250, 1200, 500) * 1e-9 print(Parameter.registered_params["A"])
fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 7), gridspec_kw=dict(height_ratios=[3, 1])) print(Parameter.registered_params["B"])
ax.set_ylim(-40, 20) a = A()
ax2.set_ylim(-100, 0) a.x = 5
beta2 = sc.fiber.HCPCF_dispersion( print(a.x)
l,
12e-6,
model="hasan",
model_params=dict(t=0.2e-6, g=1e-6, n=6, resonance_strength=(2e-6,)),
)
ax.plot(*convert(l, beta2))
beta2 = sc.fiber.HCPCF_dispersion(
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
)
ax.plot(*convert(l, beta2))
l = np.linspace(500, 1500, 500) * 1e-9
beta2 = sc.fiber.HCPCF_dispersion(
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=10)
)
ax2.plot(*convert(l, beta2))
plt.show()
def test_custom_initial_field():
param = {
"name": "test",
"lambda0": [1030, "nm"],
"E0": [6, "uJ"],
"T0_FWHM": [27, "fs"],
"frep": 151e3,
"z_targets": [0, 0.07, 128],
"gas": "argon",
"pressure": 4e5,
"temperature": 293,
"pulse_shape": "sech",
"behaviors": [],
"fiber_model": "marcatili",
"model_params": {"core_radius": 18e-6},
"field_0": "exp(-(t/t0)**2)*P0 + P0/10 * cos(t/t0)*2*exp(-(0.05*t/t0)**2)",
"nt": 16384,
"T": 2e-12,
"adapt_step_size": True,
"error_ok": 1e-10,
"interp_range": [120, 2000],
"n_percent": 2,
}
p = sc.compute_init_parameters(dictionary=param)
fig, ax = plt.subplots()
ax.plot(p["t"], abs(p["field_0"]))
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
# test_empty_marcatili() main()
# test_empty_hasan()
test_custom_initial_field()