Removed inheritence; cons_qty operators
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
from typing import Optional, Callable, Union, Any
|
||||
from dataclasses import dataclass
|
||||
from .physics import fiber, pulse, materials, units
|
||||
from .utils import _mock_function, get_arg_names, get_logger, func_rewrite
|
||||
from .errors import *
|
||||
from collections import defaultdict
|
||||
from .const import MANDATORY_PARAMETERS
|
||||
import numpy as np
|
||||
import itertools
|
||||
from . import math, utils, operators
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import math, operators, utils
|
||||
from .const import MANDATORY_PARAMETERS
|
||||
from .errors import *
|
||||
from .physics import fiber, materials, pulse, units
|
||||
from .utils import _mock_function, func_rewrite, get_arg_names, get_logger
|
||||
|
||||
|
||||
class Rule:
|
||||
@@ -378,6 +380,7 @@ default_rules: list[Rule] = [
|
||||
Rule("loss_op", operators.NoLoss, priorities=-1),
|
||||
Rule("disp_op", operators.ConstantPolyDispersion),
|
||||
Rule("linear_operator", operators.LinearOperator),
|
||||
Rule("conserved_quantity", operators.ConservedQuantity),
|
||||
# gas
|
||||
Rule("n_gas_2", materials.n_gas_2),
|
||||
]
|
||||
|
||||
@@ -6,12 +6,16 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from os import stat
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from .physics import fiber
|
||||
from . import math
|
||||
from .logger import get_logger
|
||||
from .physics import fiber, pulse
|
||||
|
||||
|
||||
class SpectrumDescriptor:
|
||||
@@ -385,3 +389,78 @@ class CustomConstantLoss(ConstantLoss):
|
||||
wl = loss_data["wavelength"]
|
||||
loss = loss_data["loss"]
|
||||
self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||
|
||||
|
||||
##################################################
|
||||
############### CONSERVED QUANTITY ###############
|
||||
##################################################
|
||||
|
||||
|
||||
class ConservedQuantity(Operator):
|
||||
def __new__(
|
||||
raman_op: AbstractGamma, gamma_op: AbstractGamma, loss_op: AbstractLoss, w: np.ndarray
|
||||
):
|
||||
logger = get_logger(__name__)
|
||||
raman = not isinstance(raman_op, NoRaman)
|
||||
loss = not isinstance(raman_op, NoLoss)
|
||||
if raman and loss:
|
||||
logger.debug("Conserved quantity : photon number with loss")
|
||||
return PhotonNumberLoss(w, gamma_op, loss_op)
|
||||
elif raman:
|
||||
logger.debug("Conserved quantity : photon number without loss")
|
||||
return PhotonNumberNoLoss(w, gamma_op)
|
||||
elif loss:
|
||||
logger.debug("Conserved quantity : energy with loss")
|
||||
return EnergyLoss(w, loss_op)
|
||||
else:
|
||||
logger.debug("Conserved quantity : energy without loss")
|
||||
return EnergyNoLoss(w)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
pass
|
||||
|
||||
|
||||
class NoConservedQuantity(ConservedQuantity):
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
class PhotonNumberLoss(ConservedQuantity):
|
||||
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma, loss_op=AbstractLoss):
|
||||
self.w = w
|
||||
self.dw = w[1] - w[0]
|
||||
self.gamma_op = gamma_op
|
||||
self.loss_op = loss_op
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number_with_loss(
|
||||
state.spectrum, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h
|
||||
)
|
||||
|
||||
|
||||
class PhotonNumberNoLoss(ConservedQuantity):
|
||||
def __init__(self, w: np.ndarray, gamma_op: AbstractGamma):
|
||||
self.w = w
|
||||
self.dw = w[1] - w[0]
|
||||
self.gamma_op = gamma_op
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.photon_number(state.spectrum, self.w, self.dw, self.gamma_op(state))
|
||||
|
||||
|
||||
class EnergyLoss(ConservedQuantity):
|
||||
def __init__(self, w: np.ndarray, loss_op: AbstractLoss):
|
||||
self.dw = w[1] - w[0]
|
||||
self.loss_op = loss_op
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.pulse_energy_with_loss(state.spectrum, self.dw, self.loss_op(state), state.h)
|
||||
|
||||
|
||||
class EnergyNoLoss(ConservedQuantity):
|
||||
def __init__(self, w: np.ndarray):
|
||||
self.dw = w[1] - w[0]
|
||||
|
||||
def __call__(self, state: CurrentState) -> float:
|
||||
return pulse.pulse_energy(state.spectrum, self.dw)
|
||||
|
||||
@@ -2,20 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import datetime as datetime_module
|
||||
import enum
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, Iterable, Iterator, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib import isin
|
||||
|
||||
from . import env, math, utils
|
||||
from . import env, utils
|
||||
from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS
|
||||
from .logger import get_logger
|
||||
from .utils import fiber_folder, update_path_name
|
||||
@@ -210,6 +207,7 @@ class Parameter:
|
||||
self.name = name
|
||||
if self.default is not None:
|
||||
Evaluator.register_default_param(self.name, self.default)
|
||||
VariationDescriptor.register_formatter(self.name, self.display)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
@@ -242,20 +240,7 @@ class Parameter:
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AbstractParameters:
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
cls.register_param_formatters()
|
||||
|
||||
@classmethod
|
||||
def register_param_formatters(cls):
|
||||
for k, v in cls.__dict__.items():
|
||||
if isinstance(v, Parameter):
|
||||
VariationDescriptor.register_formatter(k, v.display)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Parameters(_AbstractParameters):
|
||||
class Parameters:
|
||||
"""
|
||||
This class defines each valid parameter's name, type and valid value.
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,6 @@ import multiprocessing
|
||||
import multiprocessing.connection
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Type, Union
|
||||
@@ -13,9 +12,7 @@ from .. import utils
|
||||
from ..logger import get_logger
|
||||
from ..parameter import Configuration, Parameters
|
||||
from ..pbar import PBars, ProgressBarActor, progress_worker
|
||||
from ..operators import CurrentState
|
||||
from . import pulse
|
||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||
from ..operators import CurrentState, ConservedQuantity, NoConservedQuantity
|
||||
|
||||
try:
|
||||
import ray
|
||||
@@ -70,10 +67,6 @@ class RK4IP:
|
||||
|
||||
self.dw = self.params.w[1] - self.params.w[0]
|
||||
self.z_targets = self.params.z_targets
|
||||
self.beta2_coefficients = (
|
||||
params.beta_func if params.beta_func is not None else params.beta2_coefficients
|
||||
)
|
||||
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr
|
||||
self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
|
||||
self.error_ok = (
|
||||
params.tolerated_error if self.params.adapt_step_size else self.params.step_size
|
||||
@@ -83,55 +76,18 @@ class RK4IP:
|
||||
self._setup_sim_parameters()
|
||||
|
||||
def _setup_functions(self):
|
||||
self.N_func = create_non_linear_op(
|
||||
self.params.behaviors,
|
||||
self.params.w_c,
|
||||
self.params.w0,
|
||||
self.gamma,
|
||||
self.params.raman_type,
|
||||
hr_w=self.params.hr_w,
|
||||
)
|
||||
|
||||
if self.params.dynamic_dispersion:
|
||||
self.disp = lambda r: fast_dispersion_op(
|
||||
self.params.w_c,
|
||||
self.beta2_coefficients(r),
|
||||
self.params.w_power_fact,
|
||||
alpha=self.params.alpha_arr,
|
||||
)
|
||||
else:
|
||||
self.disp = lambda r: fast_dispersion_op(
|
||||
self.params.w_c,
|
||||
self.beta2_coefficients,
|
||||
self.params.w_power_fact,
|
||||
alpha=self.params.alpha_arr,
|
||||
)
|
||||
|
||||
# Set up which quantity is conserved for adaptive step size
|
||||
if self.params.adapt_step_size:
|
||||
if "raman" in self.params.behaviors and self.params.alpha_arr is not None:
|
||||
self.logger.debug("Conserved quantity : photon number with loss")
|
||||
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss(
|
||||
spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h
|
||||
)
|
||||
elif "raman" in self.params.behaviors:
|
||||
self.logger.debug("Conserved quantity : photon number without loss")
|
||||
self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number(
|
||||
spectrum, self.params.w, self.dw, self.gamma
|
||||
)
|
||||
elif self.params.alpha_arr is not None:
|
||||
self.logger.debug("Conserved quantity : energy with loss")
|
||||
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss(
|
||||
self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h
|
||||
)
|
||||
else:
|
||||
self.logger.debug("Conserved quantity : energy without loss")
|
||||
self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy(
|
||||
self.C_to_A_factor * spectrum, self.dw
|
||||
self.conserved_quantity_func = ConservedQuantity(
|
||||
self.params.nonlinear_operator.raman_op,
|
||||
self.params.nonlinear_operator.gamma_op,
|
||||
self.params.linear_operator.loss_op,
|
||||
self.params.w,
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}")
|
||||
self.conserved_quantity_func = lambda spectrum, h: 0.0
|
||||
self.conserved_quantity_func = NoConservedQuantity()
|
||||
|
||||
def _setup_sim_parameters(self):
|
||||
# making sure to keep only the z that we want
|
||||
@@ -140,27 +96,27 @@ class RK4IP:
|
||||
self.z_targets.sort()
|
||||
self.store_num = len(self.z_targets)
|
||||
|
||||
# Initial setup of simulation parameters
|
||||
self.z = self.z_targets.pop(0)
|
||||
|
||||
# Initial step size
|
||||
if self.params.adapt_step_size:
|
||||
initial_h = (self.z_targets[0] - self.z) / 2
|
||||
else:
|
||||
initial_h = self.error_ok
|
||||
# Setup initial values for every physical quantity that we want to track
|
||||
self.state = CurrentState(
|
||||
length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor
|
||||
length=self.params.length,
|
||||
z=self.z_targets.pop(0),
|
||||
h=initial_h,
|
||||
spectrum=self.params.spec_0.copy() / self.C_to_A_factor,
|
||||
)
|
||||
self.stored_spectra = self.params.recovery_last_stored * [None] + [
|
||||
self.state.spectrum.copy()
|
||||
]
|
||||
self.cons_qty = [
|
||||
self.conserved_quantity_func(self.state.spectrum, 0),
|
||||
self.conserved_quantity_func(self.state),
|
||||
0,
|
||||
]
|
||||
self.size_fac = 2 ** (1 / 5)
|
||||
|
||||
# Initial step size
|
||||
if self.params.adapt_step_size:
|
||||
self.initial_h = (self.z_targets[0] - self.z) / 2
|
||||
else:
|
||||
self.initial_h = self.error_ok
|
||||
|
||||
def _save_current_spectrum(self, num: int):
|
||||
"""saves the spectrum and the corresponding cons_qty array
|
||||
|
||||
107
tests.py
107
tests.py
@@ -1,84 +1,45 @@
|
||||
import numpy as np
|
||||
import scgenerator as sc
|
||||
import matplotlib.pyplot as plt
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def convert(l, beta2):
|
||||
return l[2:-2] * 1e9, sc.units.beta2_fs_cm.inv(beta2[2:-2])
|
||||
class Parameter:
|
||||
registered_params = defaultdict(dict)
|
||||
|
||||
def __init__(self, default_value, display_suffix=""):
|
||||
self.value = default_value
|
||||
self.display_suffix = display_suffix
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
self.name = name
|
||||
self.registered_params[owner.__name__][name] = self
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
self.value = value
|
||||
|
||||
def display(self):
|
||||
return str(self.value) + " " + self.display_suffix
|
||||
|
||||
|
||||
def test_empty_marcatili():
|
||||
l = np.linspace(250, 1200, 500) * 1e-9
|
||||
beta2 = sc.fiber.HCPCF_dispersion(l, 15e-6)
|
||||
plt.plot(*convert(l, beta2))
|
||||
plt.show()
|
||||
class A:
|
||||
x = Parameter("lol")
|
||||
y = Parameter(56.2)
|
||||
|
||||
|
||||
def test_empty_hasan_no_resonance():
|
||||
l = np.linspace(250, 1200, 500) * 1e-9
|
||||
beta2 = sc.fiber.HCPCF_dispersion(
|
||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
|
||||
)
|
||||
plt.plot(*convert(l, beta2))
|
||||
plt.show()
|
||||
class B:
|
||||
x = Parameter(slice(None))
|
||||
opt = None
|
||||
|
||||
|
||||
def test_empty_hasan():
|
||||
l = np.linspace(250, 1200, 500) * 1e-9
|
||||
fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 7), gridspec_kw=dict(height_ratios=[3, 1]))
|
||||
ax.set_ylim(-40, 20)
|
||||
ax2.set_ylim(-100, 0)
|
||||
beta2 = sc.fiber.HCPCF_dispersion(
|
||||
l,
|
||||
12e-6,
|
||||
model="hasan",
|
||||
model_params=dict(t=0.2e-6, g=1e-6, n=6, resonance_strength=(2e-6,)),
|
||||
)
|
||||
ax.plot(*convert(l, beta2))
|
||||
beta2 = sc.fiber.HCPCF_dispersion(
|
||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6)
|
||||
)
|
||||
ax.plot(*convert(l, beta2))
|
||||
|
||||
l = np.linspace(500, 1500, 500) * 1e-9
|
||||
beta2 = sc.fiber.HCPCF_dispersion(
|
||||
l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=10)
|
||||
)
|
||||
ax2.plot(*convert(l, beta2))
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_custom_initial_field():
|
||||
param = {
|
||||
"name": "test",
|
||||
"lambda0": [1030, "nm"],
|
||||
"E0": [6, "uJ"],
|
||||
"T0_FWHM": [27, "fs"],
|
||||
"frep": 151e3,
|
||||
"z_targets": [0, 0.07, 128],
|
||||
"gas": "argon",
|
||||
"pressure": 4e5,
|
||||
"temperature": 293,
|
||||
"pulse_shape": "sech",
|
||||
"behaviors": [],
|
||||
"fiber_model": "marcatili",
|
||||
"model_params": {"core_radius": 18e-6},
|
||||
"field_0": "exp(-(t/t0)**2)*P0 + P0/10 * cos(t/t0)*2*exp(-(0.05*t/t0)**2)",
|
||||
"nt": 16384,
|
||||
"T": 2e-12,
|
||||
"adapt_step_size": True,
|
||||
"error_ok": 1e-10,
|
||||
"interp_range": [120, 2000],
|
||||
"n_percent": 2,
|
||||
}
|
||||
|
||||
p = sc.compute_init_parameters(dictionary=param)
|
||||
fig, ax = plt.subplots()
|
||||
ax.plot(p["t"], abs(p["field_0"]))
|
||||
plt.show()
|
||||
def main():
|
||||
print(Parameter.registered_params["A"])
|
||||
print(Parameter.registered_params["B"])
|
||||
a = A()
|
||||
a.x = 5
|
||||
print(a.x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_empty_marcatili()
|
||||
# test_empty_hasan()
|
||||
test_custom_initial_field()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user