Removed inheritence; cons_qty operators
This commit is contained in:
@@ -1,13 +1,15 @@
|
|||||||
from typing import Optional, Callable, Union, Any
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from .physics import fiber, pulse, materials, units
|
|
||||||
from .utils import _mock_function, get_arg_names, get_logger, func_rewrite
|
|
||||||
from .errors import *
|
|
||||||
from collections import defaultdict
|
|
||||||
from .const import MANDATORY_PARAMETERS
|
|
||||||
import numpy as np
|
|
||||||
import itertools
|
import itertools
|
||||||
from . import math, utils, operators
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from . import math, operators, utils
|
||||||
|
from .const import MANDATORY_PARAMETERS
|
||||||
|
from .errors import *
|
||||||
|
from .physics import fiber, materials, pulse, units
|
||||||
|
from .utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
||||||
|
|
||||||
|
|
||||||
class Rule:
|
class Rule:
|
||||||
@@ -378,6 +380,7 @@ default_rules: list[Rule] = [
|
|||||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||||
Rule("disp_op", operators.ConstantPolyDispersion),
|
Rule("disp_op", operators.ConstantPolyDispersion),
|
||||||
Rule("linear_operator", operators.LinearOperator),
|
Rule("linear_operator", operators.LinearOperator),
|
||||||
|
Rule("conserved_quantity", operators.ConservedQuantity),
|
||||||
# gas
|
# gas
|
||||||
Rule("n_gas_2", materials.n_gas_2),
|
Rule("n_gas_2", materials.n_gas_2),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,12 +6,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from functools import wraps
|
||||||
|
from os import stat
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
from .physics import fiber
|
|
||||||
from . import math
|
from . import math
|
||||||
|
from .logger import get_logger
|
||||||
|
from .physics import fiber, pulse
|
||||||
|
|
||||||
|
|
||||||
class SpectrumDescriptor:
|
class SpectrumDescriptor:
|
||||||
@@ -385,3 +389,78 @@ class CustomConstantLoss(ConstantLoss):
|
|||||||
wl = loss_data["wavelength"]
|
wl = loss_data["wavelength"]
|
||||||
loss = loss_data["loss"]
|
loss = loss_data["loss"]
|
||||||
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
############### CONSERVED QUANTITY ###############
|
||||||
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
|
class ConservedQuantity(Operator):
|
||||||
|
def __new__(
|
||||||
|
raman_op: AbstractGamma, gamma_op: AbstractGamma, loss_op: AbstractLoss, w: np.ndarray
|
||||||
|
):
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
raman = not isinstance(raman_op, NoRaman)
|
||||||
|
loss = not isinstance(raman_op, NoLoss)
|
||||||
|
if raman and loss:
|
||||||
|
logger.debug("Conserved quantity : photon number with loss")
|
||||||
|
return PhotonNumberLoss(w, gamma_op, loss_op)
|
||||||
|
elif raman:
|
||||||
|
logger.debug("Conserved quantity : photon number without loss")
|
||||||
|
return PhotonNumberNoLoss(w, gamma_op)
|
||||||
|
elif loss:
|
||||||
|
logger.debug("Conserved quantity : energy with loss")
|
||||||
|
return EnergyLoss(w, loss_op)
|
||||||
|
else:
|
||||||
|
logger.debug("Conserved quantity : energy without loss")
|
||||||
|
return EnergyNoLoss(w)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NoConservedQuantity(ConservedQuantity):
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class PhotonNumberLoss(ConservedQuantity):
|
||||||
|
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma, loss_op=AbstractLoss):
|
||||||
|
self.w = w
|
||||||
|
self.dw = w[1] - w[0]
|
||||||
|
self.gamma_op = gamma_op
|
||||||
|
self.loss_op = loss_op
|
||||||
|
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
return pulse.photon_number_with_loss(
|
||||||
|
state.spectrum, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PhotonNumberNoLoss(ConservedQuantity):
|
||||||
|
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma):
|
||||||
|
self.w = w
|
||||||
|
self.dw = w[1] - w[0]
|
||||||
|
self.gamma_op = gamma_op
|
||||||
|
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
return pulse.photon_number(state.spectrum, self.w, self.dw, self.gamma_op(state))
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyLoss(ConservedQuantity):
|
||||||
|
def __init__(self, w: np.ndarray, loss_op: AbstractLoss):
|
||||||
|
self.dw = w[1] - w[0]
|
||||||
|
self.loss_op = loss_op
|
||||||
|
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
return pulse.pulse_energy_with_loss(state.spectrum, self.dw, self.loss_op(state), state.h)
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyNoLoss(ConservedQuantity):
|
||||||
|
def __init__(self, w: np.ndarray):
|
||||||
|
self.dw = w[1] - w[0]
|
||||||
|
|
||||||
|
def __call__(self, state: CurrentState) -> float:
|
||||||
|
return pulse.pulse_energy(state.spectrum, self.dw)
|
||||||
|
|||||||
@@ -2,20 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
import enum
|
import enum
|
||||||
import itertools
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import asdict, dataclass, fields
|
from dataclasses import asdict, dataclass, fields
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union
|
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.lib import isin
|
|
||||||
|
|
||||||
from . import env, math, utils
|
from . import env, utils
|
||||||
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
|
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .utils import fiber_folder, update_path_name
|
from .utils import fiber_folder, update_path_name
|
||||||
@@ -210,6 +207,7 @@ class Parameter:
|
|||||||
self.name = name
|
self.name = name
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
Evaluator.register_default_param(self.name, self.default)
|
Evaluator.register_default_param(self.name, self.default)
|
||||||
|
VariationDescriptor.register_formatter(self.name, self.display)
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
@@ -242,20 +240,7 @@ class Parameter:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _AbstractParameters:
|
class Parameters:
|
||||||
@classmethod
|
|
||||||
def __init_subclass__(cls):
|
|
||||||
cls.register_param_formatters()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_param_formatters(cls):
|
|
||||||
for k, v in cls.__dict__.items():
|
|
||||||
if isinstance(v, Parameter):
|
|
||||||
VariationDescriptor.register_formatter(k, v.display)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Parameters(_AbstractParameters):
|
|
||||||
"""
|
"""
|
||||||
This class defines each valid parameter's name, type and valid value.
|
This class defines each valid parameter's name, type and valid value.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import multiprocessing
|
|||||||
import multiprocessing.connection
|
import multiprocessing.connection
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Generator, Type, Union
|
from typing import Any, Generator, Type, Union
|
||||||
@@ -13,9 +12,7 @@ from .. import utils
|
|||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..parameter import Configuration, Parameters
|
from ..parameter import Configuration, Parameters
|
||||||
from ..pbar import PBars, ProgressBarActor, progress_worker
|
from ..pbar import PBars, ProgressBarActor, progress_worker
|
||||||
from ..operators import CurrentState
|
from ..operators import CurrentState, ConservedQuantity, NoConservedQuantity
|
||||||
from . import pulse
|
|
||||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
@@ -70,10 +67,6 @@ class RK4IP:
|
|||||||
|
|
||||||
self.dw = self.params.w[1] - self.params.w[0]
|
self.dw = self.params.w[1] - self.params.w[0]
|
||||||
self.z_targets = self.params.z_targets
|
self.z_targets = self.params.z_targets
|
||||||
self.beta2_coefficients = (
|
|
||||||
params.beta_func if params.beta_func is not None else params.beta2_coefficients
|
|
||||||
)
|
|
||||||
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr
|
|
||||||
self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
||||||
self.error_ok = (
|
self.error_ok = (
|
||||||
params.tolerated_error if self.params.adapt_step_size else self.params.step_size
|
params.tolerated_error if self.params.adapt_step_size else self.params.step_size
|
||||||
@@ -83,55 +76,18 @@ class RK4IP:
|
|||||||
self._setup_sim_parameters()
|
self._setup_sim_parameters()
|
||||||
|
|
||||||
def _setup_functions(self):
|
def _setup_functions(self):
|
||||||
self.N_func = create_non_linear_op(
|
|
||||||
self.params.behaviors,
|
|
||||||
self.params.w_c,
|
|
||||||
self.params.w0,
|
|
||||||
self.gamma,
|
|
||||||
self.params.raman_type,
|
|
||||||
hr_w=self.params.hr_w,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.params.dynamic_dispersion:
|
|
||||||
self.disp = lambda r: fast_dispersion_op(
|
|
||||||
self.params.w_c,
|
|
||||||
self.beta2_coefficients(r),
|
|
||||||
self.params.w_power_fact,
|
|
||||||
alpha=self.params.alpha_arr,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.disp = lambda r: fast_dispersion_op(
|
|
||||||
self.params.w_c,
|
|
||||||
self.beta2_coefficients,
|
|
||||||
self.params.w_power_fact,
|
|
||||||
alpha=self.params.alpha_arr,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set up which quantity is conserved for adaptive step size
|
# Set up which quantity is conserved for adaptive step size
|
||||||
if self.params.adapt_step_size:
|
if self.params.adapt_step_size:
|
||||||
if "raman" in self.params.behaviors and self.params.alpha_arr is not None:
|
self.conserved_quantity_func = ConservedQuantity(
|
||||||
self.logger.debug("Conserved quantity : photon number with loss")
|
self.params.nonlinear_operator.raman_op,
|
||||||
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss(
|
self.params.nonlinear_operator.gamma_op,
|
||||||
spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h
|
self.params.linear_operator.loss_op,
|
||||||
)
|
self.params.w,
|
||||||
elif "raman" in self.params.behaviors:
|
)
|
||||||
self.logger.debug("Conserved quantity : photon number without loss")
|
|
||||||
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number(
|
|
||||||
spectrum, self.params.w, self.dw, self.gamma
|
|
||||||
)
|
|
||||||
elif self.params.alpha_arr is not None:
|
|
||||||
self.logger.debug("Conserved quantity : energy with loss")
|
|
||||||
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss(
|
|
||||||
self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.debug("Conserved quantity : energy without loss")
|
|
||||||
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy(
|
|
||||||
self.C_to_A_factor * spectrum, self.dw
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}")
|
self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}")
|
||||||
self.conserved_quantity_func = lambda spectrum, h: 0.0
|
self.conserved_quantity_func = NoConservedQuantity()
|
||||||
|
|
||||||
def _setup_sim_parameters(self):
|
def _setup_sim_parameters(self):
|
||||||
# making sure to keep only the z that we want
|
# making sure to keep only the z that we want
|
||||||
@@ -140,27 +96,27 @@ class RK4IP:
|
|||||||
self.z_targets.sort()
|
self.z_targets.sort()
|
||||||
self.store_num = len(self.z_targets)
|
self.store_num = len(self.z_targets)
|
||||||
|
|
||||||
# Initial setup of simulation parameters
|
# Initial step size
|
||||||
self.z = self.z_targets.pop(0)
|
if self.params.adapt_step_size:
|
||||||
|
initial_h = (self.z_targets[0] - self.z) / 2
|
||||||
|
else:
|
||||||
|
initial_h = self.error_ok
|
||||||
# Setup initial values for every physical quantity that we want to track
|
# Setup initial values for every physical quantity that we want to track
|
||||||
self.state = CurrentState(
|
self.state = CurrentState(
|
||||||
length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor
|
length=self.params.length,
|
||||||
|
z=self.z_targets.pop(0),
|
||||||
|
h=initial_h,
|
||||||
|
spectrum=self.params.spec_0.copy() / self.C_to_A_factor,
|
||||||
)
|
)
|
||||||
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
||||||
self.state.spectrum.copy()
|
self.state.spectrum.copy()
|
||||||
]
|
]
|
||||||
self.cons_qty = [
|
self.cons_qty = [
|
||||||
self.conserved_quantity_func(self.state.spectrum, 0),
|
self.conserved_quantity_func(self.state),
|
||||||
0,
|
0,
|
||||||
]
|
]
|
||||||
self.size_fac = 2 ** (1 / 5)
|
self.size_fac = 2 ** (1 / 5)
|
||||||
|
|
||||||
# Initial step size
|
|
||||||
if self.params.adapt_step_size:
|
|
||||||
self.initial_h = (self.z_targets[0] - self.z) / 2
|
|
||||||
else:
|
|
||||||
self.initial_h = self.error_ok
|
|
||||||
|
|
||||||
def _save_current_spectrum(self, num: int):
|
def _save_current_spectrum(self, num: int):
|
||||||
"""saves the spectrum and the corresponding cons_qty array
|
"""saves the spectrum and the corresponding cons_qty array
|
||||||
|
|||||||
107
tests.py
107
tests.py
@@ -1,84 +1,45 @@
|
|||||||
import numpy as np
|
from __future__ import annotations
|
||||||
import scgenerator as sc
|
from collections import defaultdict
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def convert(l, beta2):
|
class Parameter:
|
||||||
return l[2:-2] * 1e9, sc.units.beta2_fs_cm.inv(beta2[2:-2])
|
registered_params = defaultdict(dict)
|
||||||
|
|
||||||
|
def __init__(self, default_value, display_suffix=""):
|
||||||
|
self.value = default_value
|
||||||
|
self.display_suffix = display_suffix
|
||||||
|
|
||||||
|
def __set_name__(self, owner, name):
|
||||||
|
self.name = name
|
||||||
|
self.registered_params[owner.__name__][name] = self
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def display(self):
|
||||||
|
return str(self.value) + " " + self.display_suffix
|
||||||
|
|
||||||
|
|
||||||
def test_empty_marcatili():
|
class A:
|
||||||
l = np.linspace(250, 1200, 500) * 1e-9
|
x = Parameter("lol")
|
||||||
beta2 = sc.fiber.HCPCF_dispersion(l, 15e-6)
|
y = Parameter(56.2)
|
||||||
plt.plot(*convert(l, beta2))
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_hasan_no_resonance():
|
class B:
|
||||||
l = np.linspace(250, 1200, 500) * 1e-9
|
x = Parameter(slice(None))
|
||||||
beta2 = sc.fiber.HCPCF_dispersion(
|
opt = None
|
||||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
|
|
||||||
)
|
|
||||||
plt.plot(*convert(l, beta2))
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_hasan():
|
def main():
|
||||||
l = np.linspace(250, 1200, 500) * 1e-9
|
print(Parameter.registered_params["A"])
|
||||||
fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 7), gridspec_kw=dict(height_ratios=[3, 1]))
|
print(Parameter.registered_params["B"])
|
||||||
ax.set_ylim(-40, 20)
|
a = A()
|
||||||
ax2.set_ylim(-100, 0)
|
a.x = 5
|
||||||
beta2 = sc.fiber.HCPCF_dispersion(
|
print(a.x)
|
||||||
l,
|
|
||||||
12e-6,
|
|
||||||
model="hasan",
|
|
||||||
model_params=dict(t=0.2e-6, g=1e-6, n=6, resonance_strength=(2e-6,)),
|
|
||||||
)
|
|
||||||
ax.plot(*convert(l, beta2))
|
|
||||||
beta2 = sc.fiber.HCPCF_dispersion(
|
|
||||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
|
|
||||||
)
|
|
||||||
ax.plot(*convert(l, beta2))
|
|
||||||
|
|
||||||
l = np.linspace(500, 1500, 500) * 1e-9
|
|
||||||
beta2 = sc.fiber.HCPCF_dispersion(
|
|
||||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=10)
|
|
||||||
)
|
|
||||||
ax2.plot(*convert(l, beta2))
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_initial_field():
|
|
||||||
param = {
|
|
||||||
"name": "test",
|
|
||||||
"lambda0": [1030, "nm"],
|
|
||||||
"E0": [6, "uJ"],
|
|
||||||
"T0_FWHM": [27, "fs"],
|
|
||||||
"frep": 151e3,
|
|
||||||
"z_targets": [0, 0.07, 128],
|
|
||||||
"gas": "argon",
|
|
||||||
"pressure": 4e5,
|
|
||||||
"temperature": 293,
|
|
||||||
"pulse_shape": "sech",
|
|
||||||
"behaviors": [],
|
|
||||||
"fiber_model": "marcatili",
|
|
||||||
"model_params": {"core_radius": 18e-6},
|
|
||||||
"field_0": "exp(-(t/t0)**2)*P0 + P0/10 * cos(t/t0)*2*exp(-(0.05*t/t0)**2)",
|
|
||||||
"nt": 16384,
|
|
||||||
"T": 2e-12,
|
|
||||||
"adapt_step_size": True,
|
|
||||||
"error_ok": 1e-10,
|
|
||||||
"interp_range": [120, 2000],
|
|
||||||
"n_percent": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
p = sc.compute_init_parameters(dictionary=param)
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.plot(p["t"], abs(p["field_0"]))
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test_empty_marcatili()
|
main()
|
||||||
# test_empty_hasan()
|
|
||||||
test_custom_initial_field()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user