cache for somo init computations
This commit is contained in:
@@ -9,6 +9,8 @@ from scgenerator.physics.simulate import (
|
|||||||
resume_simulations,
|
resume_simulations,
|
||||||
run_simulation_sequence,
|
run_simulation_sequence,
|
||||||
)
|
)
|
||||||
|
from scgenerator.physics.fiber import dispersion_coefficients
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
@@ -71,6 +73,8 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.func(args)
|
args.func(args)
|
||||||
|
|
||||||
|
print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}")
|
||||||
|
|
||||||
|
|
||||||
def run_sim(args):
|
def run_sim(args):
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from scipy.interpolate import interp1d
|
|||||||
from .. import io
|
from .. import io
|
||||||
from ..math import abs2, argclosest, power_fact, u_nm
|
from ..math import abs2, argclosest, power_fact, u_nm
|
||||||
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||||
|
from ..utils import np_cache
|
||||||
from . import materials as mat
|
from . import materials as mat
|
||||||
from . import units
|
from . import units
|
||||||
from .units import c, pi
|
from .units import c, pi
|
||||||
@@ -43,7 +44,7 @@ def is_dynamic_dispersion(pressure=None):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def HCARF_gap(core_radius, capillary_num, capillary_outer_d):
|
def HCARF_gap(core_radius: float, capillary_num: int, capillary_outer_d: float):
|
||||||
"""computes the gap length between capillaries of a hollow core anti-resonance fiber
|
"""computes the gap length between capillaries of a hollow core anti-resonance fiber
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -64,7 +65,7 @@ def HCARF_gap(core_radius, capillary_num, capillary_outer_d):
|
|||||||
) - capillary_outer_d
|
) - capillary_outer_d
|
||||||
|
|
||||||
|
|
||||||
def dispersion_parameter(n_eff, lambda_):
|
def dispersion_parameter(n_eff: np.ndarray, lambda_: np.ndarray):
|
||||||
"""computes the dispersion parameter D from an effective index of refraction n_eff
|
"""computes the dispersion parameter D from an effective index of refraction n_eff
|
||||||
Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is
|
Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is
|
||||||
advised to chop off the two values on either end of the returned array
|
advised to chop off the two values on either end of the returned array
|
||||||
@@ -179,17 +180,18 @@ def n_eff_marcatili_adjusted(lambda_, n_gas_2, core_radius, he_mode=(1, 1), fit_
|
|||||||
return np.sqrt(n_gas_2 - (lambda_ * u / (2 * pi * corrected_radius)) ** 2)
|
return np.sqrt(n_gas_2 - (lambda_ * u / (2 * pi * corrected_radius)) ** 2)
|
||||||
|
|
||||||
|
|
||||||
|
@np_cache
|
||||||
def n_eff_hasan(
|
def n_eff_hasan(
|
||||||
lambda_,
|
lambda_: np.ndarray,
|
||||||
n_gas_2,
|
n_gas_2: np.ndarray,
|
||||||
core_radius,
|
core_radius: float,
|
||||||
capillary_num,
|
capillary_num: int,
|
||||||
capillary_thickness,
|
capillary_thickness: float,
|
||||||
capillary_outer_d=None,
|
capillary_outer_d: float = None,
|
||||||
capillary_spacing=None,
|
capillary_spacing: float = None,
|
||||||
capillary_resonance_strengths=[],
|
capillary_resonance_strengths: list[float] = [],
|
||||||
capillary_nested=0,
|
capillary_nested: int = 0,
|
||||||
):
|
) -> np.ndarray:
|
||||||
"""computes the effective refractive index of the fundamental mode according to the Hasan model for a anti-resonance fiber
|
"""computes the effective refractive index of the fundamental mode according to the Hasan model for a anti-resonance fiber
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -410,7 +412,7 @@ def HCPF_ZDW(
|
|||||||
return l[zdw_ind]
|
return l[zdw_ind]
|
||||||
|
|
||||||
|
|
||||||
def beta2(w, n_eff):
|
def beta2(w: np.ndarray, n_eff: np.ndarray) -> np.ndarray:
|
||||||
"""computes the dispersion parameter beta2 according to the effective refractive index of the fiber and the frequency range
|
"""computes the dispersion parameter beta2 according to the effective refractive index of the fiber and the frequency range
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -556,6 +558,7 @@ def gamma_parameter(n2, w0, A_eff):
|
|||||||
return n2 * w0 / (A_eff * c)
|
return n2 * w0 / (A_eff * c)
|
||||||
|
|
||||||
|
|
||||||
|
@np_cache
|
||||||
def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
|
def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
|
||||||
"""
|
"""
|
||||||
semi-analytical computation of the dispersion profile of a triangular Index-guiding PCF
|
semi-analytical computation of the dispersion profile of a triangular Index-guiding PCF
|
||||||
@@ -752,7 +755,10 @@ def compute_dispersion(params: BareParams):
|
|||||||
return beta2_coef, gamma
|
return beta2_coef, gamma
|
||||||
|
|
||||||
|
|
||||||
def dispersion_coefficients(lambda_, beta2, w0, interp_range=None, deg=8):
|
@np_cache
|
||||||
|
def dispersion_coefficients(
|
||||||
|
lambda_: np.ndarray, beta2: np.ndarray, w0: float, interp_range=None, deg=8
|
||||||
|
):
|
||||||
"""Computes the taylor expansion of beta2 to be used in dispersion_op
|
"""Computes the taylor expansion of beta2 to be used in dispersion_op
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@@ -4,23 +4,30 @@ scgenerator module but some function may be used in any python program
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import threading
|
import threading
|
||||||
|
from functools import update_wrapper
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict, replace
|
from dataclasses import asdict, replace
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
|
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.lib.index_tricks import nd_grid
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..logger import get_logger
|
||||||
|
|
||||||
from .. import env
|
from .. import env
|
||||||
from ..const import PARAM_SEPARATOR
|
from ..const import PARAM_SEPARATOR
|
||||||
from ..math import *
|
from ..math import *
|
||||||
from .parameter import BareConfig, BareParams
|
from .parameter import BareConfig, BareParams
|
||||||
|
from scgenerator import logger
|
||||||
|
|
||||||
T_ = TypeVar("T_")
|
T_ = TypeVar("T_")
|
||||||
|
|
||||||
@@ -276,3 +283,61 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
|
|||||||
for k in new:
|
for k in new:
|
||||||
variable.pop(k, None) # remove old ones
|
variable.pop(k, None) # remove old ones
|
||||||
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
||||||
|
|
||||||
|
|
||||||
|
# def np_cache(function):
|
||||||
|
# """applies functools.cache to function that take numpy arrays as input"""
|
||||||
|
|
||||||
|
# @cache
|
||||||
|
# def cached_wrapper(*hashable_args, **hashable_kwargs):
|
||||||
|
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
|
||||||
|
# kwargs = {
|
||||||
|
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
|
||||||
|
# for k, kwarg in hashable_kwargs.items()
|
||||||
|
# }
|
||||||
|
# return function(*args, **kwargs)
|
||||||
|
|
||||||
|
# @wraps(function)
|
||||||
|
# def wrapper(*args, **kwargs):
|
||||||
|
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
|
||||||
|
# hashable_kwargs = {
|
||||||
|
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
|
||||||
|
# for k, kwarg in kwargs.items()
|
||||||
|
# }
|
||||||
|
# return cached_wrapper(*hashable_args, **hashable_kwargs)
|
||||||
|
|
||||||
|
# # copy lru_cache attributes over too
|
||||||
|
# wrapper.cache_info = cached_wrapper.cache_info
|
||||||
|
# wrapper.cache_clear = cached_wrapper.cache_clear
|
||||||
|
|
||||||
|
# return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class np_cache:
|
||||||
|
def __init__(self, function):
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.func = function
|
||||||
|
self.cache = {}
|
||||||
|
self.hits = 0
|
||||||
|
self.misses = 0
|
||||||
|
update_wrapper(self, function)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
hashable_args = tuple(
|
||||||
|
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
|
||||||
|
)
|
||||||
|
hashable_kwargs = tuple(
|
||||||
|
{
|
||||||
|
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
|
||||||
|
for k, kwarg in kwargs.items()
|
||||||
|
}.items()
|
||||||
|
)
|
||||||
|
key = hash((hashable_args, hashable_kwargs))
|
||||||
|
if key not in self.cache:
|
||||||
|
self.logger.debug("cache miss")
|
||||||
|
self.misses += 1
|
||||||
|
self.cache[key] = self.func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
self.hits += 1
|
||||||
|
self.logger.debug("cache hit")
|
||||||
|
return copy(self.cache[key])
|
||||||
|
|||||||
Reference in New Issue
Block a user