cache for somo init computations

This commit is contained in:
Benoît Sierro
2021-06-15 10:12:32 +02:00
parent 76607a76f1
commit 3ff6b1c3c9
3 changed files with 89 additions and 14 deletions

View File

@@ -9,6 +9,8 @@ from scgenerator.physics.simulate import (
resume_simulations,
run_simulation_sequence,
)
from scgenerator.physics.fiber import dispersion_coefficients
try:
import ray
@@ -71,6 +73,8 @@ def main():
args = parser.parse_args()
args.func(args)
print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}")
def run_sim(args):

View File

@@ -9,6 +9,7 @@ from scipy.interpolate import interp1d
from .. import io
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareParams, hc_model_specific_parameters
from ..utils import np_cache
from . import materials as mat
from . import units
from .units import c, pi
@@ -43,7 +44,7 @@ def is_dynamic_dispersion(pressure=None):
return out
def HCARF_gap(core_radius, capillary_num, capillary_outer_d):
def HCARF_gap(core_radius: float, capillary_num: int, capillary_outer_d: float):
"""computes the gap length between capillaries of a hollow core anti-resonance fiber
Parameters
@@ -64,7 +65,7 @@ def HCARF_gap(core_radius, capillary_num, capillary_outer_d):
) - capillary_outer_d
def dispersion_parameter(n_eff, lambda_):
def dispersion_parameter(n_eff: np.ndarray, lambda_: np.ndarray):
"""computes the dispersion parameter D from an effective index of refraction n_eff
Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is
advised to chop off the two values on either end of the returned array
@@ -179,17 +180,18 @@ def n_eff_marcatili_adjusted(lambda_, n_gas_2, core_radius, he_mode=(1, 1), fit_
return np.sqrt(n_gas_2 - (lambda_ * u / (2 * pi * corrected_radius)) ** 2)
@np_cache
def n_eff_hasan(
lambda_,
n_gas_2,
core_radius,
capillary_num,
capillary_thickness,
capillary_outer_d=None,
capillary_spacing=None,
capillary_resonance_strengths=[],
capillary_nested=0,
):
lambda_: np.ndarray,
n_gas_2: np.ndarray,
core_radius: float,
capillary_num: int,
capillary_thickness: float,
capillary_outer_d: float = None,
capillary_spacing: float = None,
capillary_resonance_strengths: list[float] = [],
capillary_nested: int = 0,
) -> np.ndarray:
"""computes the effective refractive index of the fundamental mode according to the Hasan model for a anti-resonance fiber
Parameters
@@ -410,7 +412,7 @@ def HCPF_ZDW(
return l[zdw_ind]
def beta2(w, n_eff):
def beta2(w: np.ndarray, n_eff: np.ndarray) -> np.ndarray:
"""computes the dispersion parameter beta2 according to the effective refractive index of the fiber and the frequency range
Parameters
@@ -556,6 +558,7 @@ def gamma_parameter(n2, w0, A_eff):
return n2 * w0 / (A_eff * c)
@np_cache
def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
"""
semi-analytical computation of the dispersion profile of a triangular Index-guiding PCF
@@ -752,7 +755,10 @@ def compute_dispersion(params: BareParams):
return beta2_coef, gamma
def dispersion_coefficients(lambda_, beta2, w0, interp_range=None, deg=8):
@np_cache
def dispersion_coefficients(
lambda_: np.ndarray, beta2: np.ndarray, w0: float, interp_range=None, deg=8
):
"""Computes the taylor expansion of beta2 to be used in dispersion_op
Parameters

View File

@@ -4,23 +4,30 @@ scgenerator module but some function may be used in any python program
"""
import functools
import itertools
import multiprocessing
import threading
from functools import update_wrapper
from collections import abc
from copy import deepcopy
from dataclasses import asdict, replace
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
from copy import copy
import numpy as np
from numpy.lib.index_tricks import nd_grid
from tqdm import tqdm
from ..logger import get_logger
from .. import env
from ..const import PARAM_SEPARATOR
from ..math import *
from .parameter import BareConfig, BareParams
from scgenerator import logger
T_ = TypeVar("T_")
@@ -276,3 +283,61 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
for k in new:
variable.pop(k, None) # remove old ones
return replace(old, variable=variable, **{k: None for k in variable}, **new)
# def np_cache(function):
# """applies functools.cache to function that take numpy arrays as input"""
# @cache
# def cached_wrapper(*hashable_args, **hashable_kwargs):
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
# kwargs = {
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
# for k, kwarg in hashable_kwargs.items()
# }
# return function(*args, **kwargs)
# @wraps(function)
# def wrapper(*args, **kwargs):
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
# hashable_kwargs = {
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
# for k, kwarg in kwargs.items()
# }
# return cached_wrapper(*hashable_args, **hashable_kwargs)
# # copy lru_cache attributes over too
# wrapper.cache_info = cached_wrapper.cache_info
# wrapper.cache_clear = cached_wrapper.cache_clear
# return wrapper
class np_cache:
def __init__(self, function):
self.logger = get_logger(__name__)
self.func = function
self.cache = {}
self.hits = 0
self.misses = 0
update_wrapper(self, function)
def __call__(self, *args, **kwargs):
hashable_args = tuple(
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
)
hashable_kwargs = tuple(
{
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
for k, kwarg in kwargs.items()
}.items()
)
key = hash((hashable_args, hashable_kwargs))
if key not in self.cache:
self.logger.debug("cache miss")
self.misses += 1
self.cache[key] = self.func(*args, **kwargs)
else:
self.hits += 1
self.logger.debug("cache hit")
return copy(self.cache[key])