Removed inheritence; cons_qty operators

This commit is contained in:
Benoît Sierro
2021-10-19 17:15:02 +02:00
parent ecb5ee681a
commit de12b0d5c1
5 changed files with 148 additions and 164 deletions

View File

@@ -1,13 +1,15 @@
from typing import Optional, Callable, Union, Any
from dataclasses import dataclass
from .physics import fiber, pulse, materials, units
from .utils import _mock_function, get_arg_names, get_logger, func_rewrite
from .errors import *
from collections import defaultdict
from .const import MANDATORY_PARAMETERS
import numpy as np
import itertools
from . import math, utils, operators
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import numpy as np
from . import math, operators, utils
from .const import MANDATORY_PARAMETERS
from .errors import *
from .physics import fiber, materials, pulse, units
from .utils import _mock_function, func_rewrite, get_arg_names, get_logger
class Rule:
@@ -378,6 +380,7 @@ default_rules: list[Rule] = [
Rule("loss_op", operators.NoLoss, priorities=-1),
Rule("disp_op", operators.ConstantPolyDispersion),
Rule("linear_operator", operators.LinearOperator),
Rule("conserved_quantity", operators.ConservedQuantity),
# gas
Rule("n_gas_2", materials.n_gas_2),
]

View File

@@ -6,12 +6,16 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from os import stat
from typing import Callable
import numpy as np
from scipy.interpolate import interp1d
from .physics import fiber
from . import math
from .logger import get_logger
from .physics import fiber, pulse
class SpectrumDescriptor:
@@ -385,3 +389,78 @@ class CustomConstantLoss(ConstantLoss):
wl = loss_data["wavelength"]
loss = loss_data["loss"]
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
##################################################
############### CONSERVED QUANTITY ###############
##################################################
class ConservedQuantity(Operator):
def __new__(
raman_op: AbstractGamma, gamma_op: AbstractGamma, loss_op: AbstractLoss, w: np.ndarray
):
logger = get_logger(__name__)
raman = not isinstance(raman_op, NoRaman)
loss = not isinstance(raman_op, NoLoss)
if raman and loss:
logger.debug("Conserved quantity : photon number with loss")
return PhotonNumberLoss(w, gamma_op, loss_op)
elif raman:
logger.debug("Conserved quantity : photon number without loss")
return PhotonNumberNoLoss(w, gamma_op)
elif loss:
logger.debug("Conserved quantity : energy with loss")
return EnergyLoss(w, loss_op)
else:
logger.debug("Conserved quantity : energy without loss")
return EnergyNoLoss(w)
@abstractmethod
def __call__(self, state: CurrentState) -> float:
pass
class NoConservedQuantity(ConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return 0.0
class PhotonNumberLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma, loss_op=AbstractLoss):
self.w = w
self.dw = w[1] - w[0]
self.gamma_op = gamma_op
self.loss_op = loss_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss(
state.spectrum, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h
)
class PhotonNumberNoLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma):
self.w = w
self.dw = w[1] - w[0]
self.gamma_op = gamma_op
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number(state.spectrum, self.w, self.dw, self.gamma_op(state))
class EnergyLoss(ConservedQuantity):
def __init__(self, w: np.ndarray, loss_op: AbstractLoss):
self.dw = w[1] - w[0]
self.loss_op = loss_op
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss(state.spectrum, self.dw, self.loss_op(state), state.h)
class EnergyNoLoss(ConservedQuantity):
def __init__(self, w: np.ndarray):
self.dw = w[1] - w[0]
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy(state.spectrum, self.dw)

View File

@@ -2,20 +2,17 @@ from __future__ import annotations
import datetime as datetime_module
import enum
import itertools
import os
import time
from collections import defaultdict
from copy import copy
from dataclasses import asdict, dataclass, fields
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
import numpy as np
from numpy.lib import isin
from . import env, math, utils
from . import env, utils
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
from .logger import get_logger
from .utils import fiber_folder, update_path_name
@@ -210,6 +207,7 @@ class Parameter:
self.name = name
if self.default is not None:
Evaluator.register_default_param(self.name, self.default)
VariationDescriptor.register_formatter(self.name, self.display)
def __get__(self, instance, owner):
if instance is None:
@@ -242,20 +240,7 @@ class Parameter:
@dataclass
class _AbstractParameters:
@classmethod
def __init_subclass__(cls):
cls.register_param_formatters()
@classmethod
def register_param_formatters(cls):
for k, v in cls.__dict__.items():
if isinstance(v, Parameter):
VariationDescriptor.register_formatter(k, v.display)
@dataclass
class Parameters(_AbstractParameters):
class Parameters:
"""
This class defines each valid parameter's name, type and valid value.
"""

View File

@@ -2,7 +2,6 @@ import multiprocessing
import multiprocessing.connection
import os
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Generator, Type, Union
@@ -13,9 +12,7 @@ from .. import utils
from ..logger import get_logger
from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
from ..operators import CurrentState
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
from ..operators import CurrentState, ConservedQuantity, NoConservedQuantity
try:
import ray
@@ -70,10 +67,6 @@ class RK4IP:
self.dw = self.params.w[1] - self.params.w[0]
self.z_targets = self.params.z_targets
self.beta2_coefficients = (
params.beta_func if params.beta_func is not None else params.beta2_coefficients
)
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr
self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
self.error_ok = (
params.tolerated_error if self.params.adapt_step_size else self.params.step_size
@@ -83,55 +76,18 @@ class RK4IP:
self._setup_sim_parameters()
def _setup_functions(self):
self.N_func = create_non_linear_op(
self.params.behaviors,
self.params.w_c,
self.params.w0,
self.gamma,
self.params.raman_type,
hr_w=self.params.hr_w,
)
if self.params.dynamic_dispersion:
self.disp = lambda r: fast_dispersion_op(
self.params.w_c,
self.beta2_coefficients(r),
self.params.w_power_fact,
alpha=self.params.alpha_arr,
)
else:
self.disp = lambda r: fast_dispersion_op(
self.params.w_c,
self.beta2_coefficients,
self.params.w_power_fact,
alpha=self.params.alpha_arr,
)
# Set up which quantity is conserved for adaptive step size
if self.params.adapt_step_size:
if "raman" in self.params.behaviors and self.params.alpha_arr is not None:
self.logger.debug("Conserved quantity : photon number with loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss(
spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h
)
elif "raman" in self.params.behaviors:
self.logger.debug("Conserved quantity : photon number without loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number(
spectrum, self.params.w, self.dw, self.gamma
)
elif self.params.alpha_arr is not None:
self.logger.debug("Conserved quantity : energy with loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss(
self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h
)
else:
self.logger.debug("Conserved quantity : energy without loss")
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy(
self.C_to_A_factor * spectrum, self.dw
self.conserved_quantity_func = ConservedQuantity(
self.params.nonlinear_operator.raman_op,
self.params.nonlinear_operator.gamma_op,
self.params.linear_operator.loss_op,
self.params.w,
)
else:
self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}")
self.conserved_quantity_func = lambda spectrum, h: 0.0
self.conserved_quantity_func = NoConservedQuantity()
def _setup_sim_parameters(self):
# making sure to keep only the z that we want
@@ -140,27 +96,27 @@ class RK4IP:
self.z_targets.sort()
self.store_num = len(self.z_targets)
# Initial setup of simulation parameters
self.z = self.z_targets.pop(0)
# Initial step size
if self.params.adapt_step_size:
initial_h = (self.z_targets[0] - self.z) / 2
else:
initial_h = self.error_ok
# Setup initial values for every physical quantity that we want to track
self.state = CurrentState(
length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor
length=self.params.length,
z=self.z_targets.pop(0),
h=initial_h,
spectrum=self.params.spec_0.copy() / self.C_to_A_factor,
)
self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy()
]
self.cons_qty = [
self.conserved_quantity_func(self.state.spectrum, 0),
self.conserved_quantity_func(self.state),
0,
]
self.size_fac = 2 ** (1 / 5)
# Initial step size
if self.params.adapt_step_size:
self.initial_h = (self.z_targets[0] - self.z) / 2
else:
self.initial_h = self.error_ok
def _save_current_spectrum(self, num: int):
"""saves the spectrum and the corresponding cons_qty array

107
tests.py
View File

@@ -1,84 +1,45 @@
import numpy as np
import scgenerator as sc
import matplotlib.pyplot as plt
from __future__ import annotations
from collections import defaultdict
def convert(l, beta2):
return l[2:-2] * 1e9, sc.units.beta2_fs_cm.inv(beta2[2:-2])
class Parameter:
registered_params = defaultdict(dict)
def __init__(self, default_value, display_suffix=""):
self.value = default_value
self.display_suffix = display_suffix
def __set_name__(self, owner, name):
self.name = name
self.registered_params[owner.__name__][name] = self
def __get__(self, instance, owner):
return self.value
def __set__(self, instance, value):
self.value = value
def display(self):
return str(self.value) + " " + self.display_suffix
def test_empty_marcatili():
l = np.linspace(250, 1200, 500) * 1e-9
beta2 = sc.fiber.HCPCF_dispersion(l, 15e-6)
plt.plot(*convert(l, beta2))
plt.show()
class A:
x = Parameter("lol")
y = Parameter(56.2)
def test_empty_hasan_no_resonance():
l = np.linspace(250, 1200, 500) * 1e-9
beta2 = sc.fiber.HCPCF_dispersion(
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
)
plt.plot(*convert(l, beta2))
plt.show()
class B:
x = Parameter(slice(None))
opt = None
def test_empty_hasan():
l = np.linspace(250, 1200, 500) * 1e-9
fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 7), gridspec_kw=dict(height_ratios=[3, 1]))
ax.set_ylim(-40, 20)
ax2.set_ylim(-100, 0)
beta2 = sc.fiber.HCPCF_dispersion(
l,
12e-6,
model="hasan",
model_params=dict(t=0.2e-6, g=1e-6, n=6, resonance_strength=(2e-6,)),
)
ax.plot(*convert(l, beta2))
beta2 = sc.fiber.HCPCF_dispersion(
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
)
ax.plot(*convert(l, beta2))
l = np.linspace(500, 1500, 500) * 1e-9
beta2 = sc.fiber.HCPCF_dispersion(
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=10)
)
ax2.plot(*convert(l, beta2))
plt.show()
def test_custom_initial_field():
param = {
"name": "test",
"lambda0": [1030, "nm"],
"E0": [6, "uJ"],
"T0_FWHM": [27, "fs"],
"frep": 151e3,
"z_targets": [0, 0.07, 128],
"gas": "argon",
"pressure": 4e5,
"temperature": 293,
"pulse_shape": "sech",
"behaviors": [],
"fiber_model": "marcatili",
"model_params": {"core_radius": 18e-6},
"field_0": "exp(-(t/t0)**2)*P0 + P0/10 * cos(t/t0)*2*exp(-(0.05*t/t0)**2)",
"nt": 16384,
"T": 2e-12,
"adapt_step_size": True,
"error_ok": 1e-10,
"interp_range": [120, 2000],
"n_percent": 2,
}
p = sc.compute_init_parameters(dictionary=param)
fig, ax = plt.subplots()
ax.plot(p["t"], abs(p["field_0"]))
plt.show()
def main():
print(Parameter.registered_params["A"])
print(Parameter.registered_params["B"])
a = A()
a.x = 5
print(a.x)
if __name__ == "__main__":
# test_empty_marcatili()
# test_empty_hasan()
test_custom_initial_field()
main()