diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 9a20003..dbd9f2d 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -9,6 +9,8 @@ from scgenerator.physics.simulate import ( resume_simulations, run_simulation_sequence, ) +from scgenerator.physics.fiber import dispersion_coefficients + try: import ray @@ -71,6 +73,8 @@ def main(): args = parser.parse_args() args.func(args) + print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}") + def run_sim(args): diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index f231de3..edb218f 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -9,6 +9,7 @@ from scipy.interpolate import interp1d from .. import io from ..math import abs2, argclosest, power_fact, u_nm from ..utils.parameter import BareParams, hc_model_specific_parameters +from ..utils import np_cache from . import materials as mat from . import units from .units import c, pi @@ -43,7 +44,7 @@ def is_dynamic_dispersion(pressure=None): return out -def HCARF_gap(core_radius, capillary_num, capillary_outer_d): +def HCARF_gap(core_radius: float, capillary_num: int, capillary_outer_d: float): """computes the gap length between capillaries of a hollow core anti-resonance fiber Parameters @@ -64,7 +65,7 @@ def HCARF_gap(core_radius, capillary_num, capillary_outer_d): ) - capillary_outer_d -def dispersion_parameter(n_eff, lambda_): +def dispersion_parameter(n_eff: np.ndarray, lambda_: np.ndarray): """computes the dispersion parameter D from an effective index of refraction n_eff Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is advised to chop off the two values on either end of the returned array @@ -179,17 +180,18 @@ def n_eff_marcatili_adjusted(lambda_, n_gas_2, core_radius, he_mode=(1, 1), fit_ return np.sqrt(n_gas_2 - (lambda_ * u / (2 * pi * corrected_radius)) ** 2) +@np_cache def n_eff_hasan( - lambda_, - n_gas_2, - core_radius, - capillary_num, - capillary_thickness, - capillary_outer_d=None, - capillary_spacing=None, - capillary_resonance_strengths=[], - capillary_nested=0, -): + lambda_: np.ndarray, + n_gas_2: np.ndarray, + core_radius: float, + capillary_num: int, + capillary_thickness: float, + capillary_outer_d: float = None, + capillary_spacing: float = None, + capillary_resonance_strengths: list[float] = [], + capillary_nested: int = 0, +) -> np.ndarray: """computes the effective refractive index of the fundamental mode according to the Hasan model for a anti-resonance fiber Parameters @@ -410,7 +412,7 @@ def HCPF_ZDW( return l[zdw_ind] -def beta2(w, n_eff): +def beta2(w: np.ndarray, n_eff: np.ndarray) -> np.ndarray: """computes the dispersion parameter beta2 according to the effective refractive index of the fiber and the frequency range Parameters @@ -556,6 +558,7 @@ def gamma_parameter(n2, w0, A_eff): return n2 * w0 / (A_eff * c) +@np_cache def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None): """ semi-analytical computation of the dispersion profile of a triangular Index-guiding PCF @@ -752,7 +755,10 @@ def compute_dispersion(params: BareParams): return beta2_coef, gamma -def dispersion_coefficients(lambda_, beta2, w0, interp_range=None, deg=8): +@np_cache +def dispersion_coefficients( + lambda_: np.ndarray, beta2: np.ndarray, w0: float, interp_range=None, deg=8 +): """Computes the taylor expansion of beta2 to be used in dispersion_op Parameters diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index e71bb14..00750f4 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -4,23 +4,30 @@ scgenerator module but some function may be used in any python program """ +import functools import itertools import multiprocessing import threading +from functools import update_wrapper from collections import abc from copy import deepcopy from dataclasses import asdict, replace from io import StringIO from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union +from copy import copy import numpy as np +from numpy.lib.index_tricks import nd_grid from tqdm import tqdm +from ..logger import get_logger + from .. import env from ..const import PARAM_SEPARATOR from ..math import * from .parameter import BareConfig, BareParams +from scgenerator import logger T_ = TypeVar("T_") @@ -276,3 +283,61 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig: for k in new: variable.pop(k, None) # remove old ones return replace(old, variable=variable, **{k: None for k in variable}, **new) + + +# def np_cache(function): +# """applies functools.cache to function that take numpy arrays as input""" + +# @cache +# def cached_wrapper(*hashable_args, **hashable_kwargs): +# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args) +# kwargs = { +# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg +# for k, kwarg in hashable_kwargs.items() +# } +# return function(*args, **kwargs) + +# @wraps(function) +# def wrapper(*args, **kwargs): +# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args) +# hashable_kwargs = { +# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg +# for k, kwarg in kwargs.items() +# } +# return cached_wrapper(*hashable_args, **hashable_kwargs) + +# # copy lru_cache attributes over too +# wrapper.cache_info = cached_wrapper.cache_info +# wrapper.cache_clear = cached_wrapper.cache_clear + +# return wrapper + + +class np_cache: + def __init__(self, function): + self.logger = get_logger(__name__) + self.func = function + self.cache = {} + self.hits = 0 + self.misses = 0 + update_wrapper(self, function) + + def __call__(self, *args, **kwargs): + hashable_args = tuple( + tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args + ) + hashable_kwargs = tuple( + { + k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg + for k, kwarg in kwargs.items() + }.items() + ) + key = hash((hashable_args, hashable_kwargs)) + if key not in self.cache: + self.logger.debug("cache miss") + self.misses += 1 + self.cache[key] = self.func(*args, **kwargs) + else: + self.hits += 1 + self.logger.debug("cache hit") + return copy(self.cache[key])