From de12b0d5c1694799af77d22605f4679bc7cbf979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 19 Oct 2021 17:15:02 +0200 Subject: [PATCH] Removed inheritence; cons_qty operators --- src/scgenerator/evaluator.py | 21 +++--- src/scgenerator/operators.py | 81 ++++++++++++++++++++- src/scgenerator/parameter.py | 23 ++---- src/scgenerator/physics/simulate.py | 80 +++++---------------- tests.py | 107 +++++++++------------------- 5 files changed, 148 insertions(+), 164 deletions(-) diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index a374929..977619b 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -1,13 +1,15 @@ -from typing import Optional, Callable, Union, Any -from dataclasses import dataclass -from .physics import fiber, pulse, materials, units -from .utils import _mock_function, get_arg_names, get_logger, func_rewrite -from .errors import * -from collections import defaultdict -from .const import MANDATORY_PARAMETERS -import numpy as np import itertools -from . import math, utils, operators +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np + +from . import math, operators, utils +from .const import MANDATORY_PARAMETERS +from .errors import * +from .physics import fiber, materials, pulse, units +from .utils import _mock_function, func_rewrite, get_arg_names, get_logger class Rule: @@ -378,6 +380,7 @@ default_rules: list[Rule] = [ Rule("loss_op", operators.NoLoss, priorities=-1), Rule("disp_op", operators.ConstantPolyDispersion), Rule("linear_operator", operators.LinearOperator), + Rule("conserved_quantity", operators.ConservedQuantity), # gas Rule("n_gas_2", materials.n_gas_2), ] diff --git a/src/scgenerator/operators.py b/src/scgenerator/operators.py index b7396c8..e1b0854 100644 --- a/src/scgenerator/operators.py +++ b/src/scgenerator/operators.py @@ -6,12 +6,16 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass, field +from functools import wraps +from os import stat +from typing import Callable import numpy as np from scipy.interpolate import interp1d -from .physics import fiber from . import math +from .logger import get_logger +from .physics import fiber, pulse class SpectrumDescriptor: @@ -385,3 +389,78 @@ class CustomConstantLoss(ConstantLoss): wl = loss_data["wavelength"] loss = loss_data["loss"] self.alpha_arr = interp1d(wl, loss, fill_value=0, bounds_error=False)(l) + + +################################################## +############### CONSERVED QUANTITY ############### +################################################## + + +class ConservedQuantity(Operator): + def __new__( + raman_op: AbstractGamma, gamma_op: AbstractGamma, loss_op: AbstractLoss, w: np.ndarray + ): + logger = get_logger(__name__) + raman = not isinstance(raman_op, NoRaman) + loss = not isinstance(raman_op, NoLoss) + if raman and loss: + logger.debug("Conserved quantity : photon number with loss") + return PhotonNumberLoss(w, gamma_op, loss_op) + elif raman: + logger.debug("Conserved quantity : photon number without loss") + return PhotonNumberNoLoss(w, gamma_op) + elif loss: + logger.debug("Conserved quantity : energy with loss") + return EnergyLoss(w, loss_op) + else: + logger.debug("Conserved quantity : energy without loss") + return EnergyNoLoss(w) + + @abstractmethod + def __call__(self, state: CurrentState) -> float: + pass + + +class NoConservedQuantity(ConservedQuantity): + def __call__(self, state: CurrentState) -> float: + return 0.0 + + +class PhotonNumberLoss(ConservedQuantity): + def __init__(self, w: np.ndarray, gamma_op: AbstractGamma, loss_op=AbstractLoss): + self.w = w + self.dw = w[1] - w[0] + self.gamma_op = gamma_op + self.loss_op = loss_op + + def __call__(self, state: CurrentState) -> float: + return pulse.photon_number_with_loss( + state.spectrum, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h + ) + + +class PhotonNumberNoLoss(ConservedQuantity): + def __init__(self, w: np.ndarray, gamma_op: AbstractGamma): + self.w = w + self.dw = w[1] - w[0] + self.gamma_op = gamma_op + + def __call__(self, state: CurrentState) -> float: + return pulse.photon_number(state.spectrum, self.w, self.dw, self.gamma_op(state)) + + +class EnergyLoss(ConservedQuantity): + def __init__(self, w: np.ndarray, loss_op: AbstractLoss): + self.dw = w[1] - w[0] + self.loss_op = loss_op + + def __call__(self, state: CurrentState) -> float: + return pulse.pulse_energy_with_loss(state.spectrum, self.dw, self.loss_op(state), state.h) + + +class EnergyNoLoss(ConservedQuantity): + def __init__(self, w: np.ndarray): + self.dw = w[1] - w[0] + + def __call__(self, state: CurrentState) -> float: + return pulse.pulse_energy(state.spectrum, self.dw) diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 804d366..29782b3 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -2,20 +2,17 @@ from __future__ import annotations import datetime as datetime_module import enum -import itertools import os import time -from collections import defaultdict from copy import copy from dataclasses import asdict, dataclass, fields from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, Iterator, TypeVar, Union import numpy as np -from numpy.lib import isin -from . import env, math, utils +from . import env, utils from .const import PARAM_FN, __version__, VALID_VARIABLE, MANDATORY_PARAMETERS from .logger import get_logger from .utils import fiber_folder, update_path_name @@ -210,6 +207,7 @@ class Parameter: self.name = name if self.default is not None: Evaluator.register_default_param(self.name, self.default) + VariationDescriptor.register_formatter(self.name, self.display) def __get__(self, instance, owner): if instance is None: @@ -242,20 +240,7 @@ class Parameter: @dataclass -class _AbstractParameters: - @classmethod - def __init_subclass__(cls): - cls.register_param_formatters() - - @classmethod - def register_param_formatters(cls): - for k, v in cls.__dict__.items(): - if isinstance(v, Parameter): - VariationDescriptor.register_formatter(k, v.display) - - -@dataclass -class Parameters(_AbstractParameters): +class Parameters: """ This class defines each valid parameter's name, type and valid value. """ diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index dfe7189..a66d740 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -2,7 +2,6 @@ import multiprocessing import multiprocessing.connection import os import random -from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any, Generator, Type, Union @@ -13,9 +12,7 @@ from .. import utils from ..logger import get_logger from ..parameter import Configuration, Parameters from ..pbar import PBars, ProgressBarActor, progress_worker -from ..operators import CurrentState -from . import pulse -from .fiber import create_non_linear_op, fast_dispersion_op +from ..operators import CurrentState, ConservedQuantity, NoConservedQuantity try: import ray @@ -70,10 +67,6 @@ class RK4IP: self.dw = self.params.w[1] - self.params.w[0] self.z_targets = self.params.z_targets - self.beta2_coefficients = ( - params.beta_func if params.beta_func is not None else params.beta2_coefficients - ) - self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma_arr self.C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4) self.error_ok = ( params.tolerated_error if self.params.adapt_step_size else self.params.step_size @@ -83,55 +76,18 @@ class RK4IP: self._setup_sim_parameters() def _setup_functions(self): - self.N_func = create_non_linear_op( - self.params.behaviors, - self.params.w_c, - self.params.w0, - self.gamma, - self.params.raman_type, - hr_w=self.params.hr_w, - ) - - if self.params.dynamic_dispersion: - self.disp = lambda r: fast_dispersion_op( - self.params.w_c, - self.beta2_coefficients(r), - self.params.w_power_fact, - alpha=self.params.alpha_arr, - ) - else: - self.disp = lambda r: fast_dispersion_op( - self.params.w_c, - self.beta2_coefficients, - self.params.w_power_fact, - alpha=self.params.alpha_arr, - ) # Set up which quantity is conserved for adaptive step size if self.params.adapt_step_size: - if "raman" in self.params.behaviors and self.params.alpha_arr is not None: - self.logger.debug("Conserved quantity : photon number with loss") - self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number_with_loss( - spectrum, self.params.w, self.dw, self.gamma, self.params.alpha_arr, h - ) - elif "raman" in self.params.behaviors: - self.logger.debug("Conserved quantity : photon number without loss") - self.conserved_quantity_func = lambda spectrum, h: pulse.photon_number( - spectrum, self.params.w, self.dw, self.gamma - ) - elif self.params.alpha_arr is not None: - self.logger.debug("Conserved quantity : energy with loss") - self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy_with_loss( - self.C_to_A_factor * spectrum, self.dw, self.params.alpha_arr, h - ) - else: - self.logger.debug("Conserved quantity : energy without loss") - self.conserved_quantity_func = lambda spectrum, h: pulse.pulse_energy( - self.C_to_A_factor * spectrum, self.dw - ) + self.conserved_quantity_func = ConservedQuantity( + self.params.nonlinear_operator.raman_op, + self.params.nonlinear_operator.gamma_op, + self.params.linear_operator.loss_op, + self.params.w, + ) else: self.logger.debug(f"Using constant step size of {1e6*self.error_ok:.3f}") - self.conserved_quantity_func = lambda spectrum, h: 0.0 + self.conserved_quantity_func = NoConservedQuantity() def _setup_sim_parameters(self): # making sure to keep only the z that we want @@ -140,27 +96,27 @@ class RK4IP: self.z_targets.sort() self.store_num = len(self.z_targets) - # Initial setup of simulation parameters - self.z = self.z_targets.pop(0) - + # Initial step size + if self.params.adapt_step_size: + initial_h = (self.z_targets[0] - self.z) / 2 + else: + initial_h = self.error_ok # Setup initial values for every physical quantity that we want to track self.state = CurrentState( - length=self.params.length, spectrum=self.params.spec_0.copy() / self.C_to_A_factor + length=self.params.length, + z=self.z_targets.pop(0), + h=initial_h, + spectrum=self.params.spec_0.copy() / self.C_to_A_factor, ) self.stored_spectra = self.params.recovery_last_stored * [None] + [ self.state.spectrum.copy() ] self.cons_qty = [ - self.conserved_quantity_func(self.state.spectrum, 0), + self.conserved_quantity_func(self.state), 0, ] self.size_fac = 2 ** (1 / 5) - # Initial step size - if self.params.adapt_step_size: - self.initial_h = (self.z_targets[0] - self.z) / 2 - else: - self.initial_h = self.error_ok def _save_current_spectrum(self, num: int): """saves the spectrum and the corresponding cons_qty array diff --git a/tests.py b/tests.py index b75edb9..8a0e1c9 100644 --- a/tests.py +++ b/tests.py @@ -1,84 +1,45 @@ -import numpy as np -import scgenerator as sc -import matplotlib.pyplot as plt +from __future__ import annotations +from collections import defaultdict -def convert(l, beta2): - return l[2:-2] * 1e9, sc.units.beta2_fs_cm.inv(beta2[2:-2]) +class Parameter: + registered_params = defaultdict(dict) + + def __init__(self, default_value, display_suffix=""): + self.value = default_value + self.display_suffix = display_suffix + + def __set_name__(self, owner, name): + self.name = name + self.registered_params[owner.__name__][name] = self + + def __get__(self, instance, owner): + return self.value + + def __set__(self, instance, value): + self.value = value + + def display(self): + return str(self.value) + " " + self.display_suffix -def test_empty_marcatili(): - l = np.linspace(250, 1200, 500) * 1e-9 - beta2 = sc.fiber.HCPCF_dispersion(l, 15e-6) - plt.plot(*convert(l, beta2)) - plt.show() +class A: + x = Parameter("lol") + y = Parameter(56.2) -def test_empty_hasan_no_resonance(): - l = np.linspace(250, 1200, 500) * 1e-9 - beta2 = sc.fiber.HCPCF_dispersion( - l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6) - ) - plt.plot(*convert(l, beta2)) - plt.show() +class B: + x = Parameter(slice(None)) + opt = None -def test_empty_hasan(): - l = np.linspace(250, 1200, 500) * 1e-9 - fig, (ax, ax2) = plt.subplots(2, 1, figsize=(6, 7), gridspec_kw=dict(height_ratios=[3, 1])) - ax.set_ylim(-40, 20) - ax2.set_ylim(-100, 0) - beta2 = sc.fiber.HCPCF_dispersion( - l, - 12e-6, - model="hasan", - model_params=dict(t=0.2e-6, g=1e-6, n=6, resonance_strength=(2e-6,)), - ) - ax.plot(*convert(l, beta2)) - beta2 = sc.fiber.HCPCF_dispersion( - l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=6) - ) - ax.plot(*convert(l, beta2)) - - l = np.linspace(500, 1500, 500) * 1e-9 - beta2 = sc.fiber.HCPCF_dispersion( - l, 12e-6, model="hasan", model_params=dict(t=0.2e-6, g=1e-6, n=10) - ) - ax2.plot(*convert(l, beta2)) - plt.show() - - -def test_custom_initial_field(): - param = { - "name": "test", - "lambda0": [1030, "nm"], - "E0": [6, "uJ"], - "T0_FWHM": [27, "fs"], - "frep": 151e3, - "z_targets": [0, 0.07, 128], - "gas": "argon", - "pressure": 4e5, - "temperature": 293, - "pulse_shape": "sech", - "behaviors": [], - "fiber_model": "marcatili", - "model_params": {"core_radius": 18e-6}, - "field_0": "exp(-(t/t0)**2)*P0 + P0/10 * cos(t/t0)*2*exp(-(0.05*t/t0)**2)", - "nt": 16384, - "T": 2e-12, - "adapt_step_size": True, - "error_ok": 1e-10, - "interp_range": [120, 2000], - "n_percent": 2, - } - - p = sc.compute_init_parameters(dictionary=param) - fig, ax = plt.subplots() - ax.plot(p["t"], abs(p["field_0"])) - plt.show() +def main(): + print(Parameter.registered_params["A"]) + print(Parameter.registered_params["B"]) + a = A() + a.x = 5 + print(a.x) if __name__ == "__main__": - # test_empty_marcatili() - # test_empty_hasan() - test_custom_initial_field() \ No newline at end of file + main()