better np cache, integr. plotting, nfft
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
from scipy.interpolate import griddata, interp1d
|
||||
from scipy.special import jn_zeros
|
||||
from .utils.cache import np_cache
|
||||
|
||||
|
||||
def span(*vec):
|
||||
@@ -89,6 +90,85 @@ def u_nm(n, m):
|
||||
return jn_zeros(n - 1, m)[-1]
|
||||
|
||||
|
||||
@np_cache
|
||||
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""creates the nfft matrix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : np.ndarray, shape = (n,)
|
||||
time array
|
||||
f : np.ndarray, shape = (m,)
|
||||
frequency array
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape = (m, n)
|
||||
multiply x(t) by this matrix to get ~X(f)
|
||||
"""
|
||||
P, F = np.meshgrid(t, f)
|
||||
return np.exp(-2j * np.pi * P * F)
|
||||
|
||||
|
||||
@np_cache
|
||||
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""creates the nfft matrix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : np.ndarray, shape = (n,)
|
||||
time array
|
||||
f : np.ndarray, shape = (m,)
|
||||
frequency array
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape = (m, n)
|
||||
multiply ~X(f) by this matrix to get x(t)
|
||||
"""
|
||||
return np.linalg.pinv(nfft_matrix(t, f))
|
||||
|
||||
|
||||
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""computes the Fourier transform of an uneven signal
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : np.ndarray, shape = (n,)
|
||||
time array
|
||||
s : np.ndarray, shape = (n, )
|
||||
amplitute at each point of t
|
||||
f : np.ndarray, shape = (m, )
|
||||
desired frequencies
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape = (m, )
|
||||
amplitude at each frequency
|
||||
"""
|
||||
return nfft_matrix(t, f) @ s
|
||||
|
||||
|
||||
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||
"""computes the inverse Fourier transform of an uneven spectrum
|
||||
|
||||
Parameters
|
||||
----------
|
||||
f : np.ndarray, shape = (n,)
|
||||
frequency array
|
||||
a : np.ndarray, shape = (n, )
|
||||
amplitude at each point of f
|
||||
t : np.ndarray, shape = (m, )
|
||||
time array
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape = (m, )
|
||||
amplitude at each point of t
|
||||
"""
|
||||
return infft_matrix(t, f) @ a
|
||||
|
||||
|
||||
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
|
||||
"""Interpolates a 2D array with the help of griddata
|
||||
Parameters
|
||||
|
||||
@@ -9,7 +9,7 @@ from scipy.interpolate import interp1d
|
||||
from .. import io
|
||||
from ..math import abs2, argclosest, power_fact, u_nm
|
||||
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||
from ..utils import np_cache
|
||||
from ..utils.cache import np_cache
|
||||
from . import materials as mat
|
||||
from . import units
|
||||
from .units import c, pi
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
|
||||
|
||||
from typing import Callable, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
from ..utils.parameter import Parameter, type_checker
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
|
||||
@@ -224,6 +225,18 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
|
||||
return unit
|
||||
|
||||
|
||||
def is_unit(name, value):
|
||||
if not hasattr(get_unit(value), "inv"):
|
||||
raise TypeError("invalid unit specified")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotRange:
|
||||
left: float = Parameter(type_checker(int, float))
|
||||
right: float = Parameter(type_checker(int, float))
|
||||
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
||||
|
||||
|
||||
def beta2_coef(beta):
|
||||
fac = 1e27
|
||||
out = np.zeros_like(beta)
|
||||
@@ -263,45 +276,49 @@ def standardize_dictionary(dico):
|
||||
return dico
|
||||
|
||||
|
||||
def sort_axis(axis, plt_range):
|
||||
def sort_axis(axis, plt_range: PlotRange):
|
||||
"""
|
||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axis : 1D array containing the original axis (usual the w or t array)
|
||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||
axis : 1D array containing the original axis (usual the w or t array)
|
||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||
|
||||
Returns
|
||||
----------
|
||||
cropped : the axis cropped, converted and sorted
|
||||
indices : indices to use to slice and sort other array in the same fashion
|
||||
extent : tupple with min and max of cropped
|
||||
-------
|
||||
cropped : the axis cropped, converted and sorted
|
||||
indices : indices to use to slice and sort other array in the same fashion
|
||||
extent : tupple with min and max of cropped
|
||||
|
||||
Example
|
||||
----------
|
||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||
t = np.linspace(-10, 10, 400)
|
||||
W, T = np.meshgrid(w, t)
|
||||
y = np.exp(-W**2 - T**2)
|
||||
-------
|
||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||
t = np.linspace(-10, 10, 400)
|
||||
W, T = np.meshgrid(w, t)
|
||||
y = np.exp(-W**2 - T**2)
|
||||
|
||||
# Define ranges
|
||||
rw = (-4, 4, s)
|
||||
rt = (-2, 6, s)
|
||||
# Define ranges
|
||||
rw = (-4, 4, s)
|
||||
rt = (-2, 6, s)
|
||||
|
||||
w, cw = sort_axis(w, rw)
|
||||
t, ct = sort_axis(t, rt)
|
||||
w, cw = sort_axis(w, rw)
|
||||
t, ct = sort_axis(t, rt)
|
||||
|
||||
# slice y according to the given ranges
|
||||
y = y[ct][:, cw]
|
||||
# slice y according to the given ranges
|
||||
y = y[ct][:, cw]
|
||||
"""
|
||||
|
||||
r = np.array(plt_range[:2], dtype="float")
|
||||
func = get_unit(plt_range[2])
|
||||
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||
|
||||
indices = np.arange(len(axis))[(axis <= np.max(func(r))) & (axis >= np.min(func(r)))]
|
||||
indices = np.arange(len(axis))[
|
||||
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
||||
]
|
||||
cropped = axis[indices]
|
||||
order = np.argsort(func.inv(cropped))
|
||||
order = np.argsort(plt_range.unit.inv(cropped))
|
||||
indices = indices[order]
|
||||
cropped = cropped[order]
|
||||
out_ax = func.inv(cropped)
|
||||
out_ax = plt_range.unit.inv(cropped)
|
||||
|
||||
return out_ax, indices, (out_ax[0], out_ax[-1])
|
||||
|
||||
|
||||
@@ -248,14 +248,13 @@ def _finish_plot_2D(
|
||||
ax,
|
||||
file_name,
|
||||
file_type,
|
||||
params,
|
||||
):
|
||||
|
||||
# apply log transform if required
|
||||
if log != False:
|
||||
if log is not False:
|
||||
vmax = defaults["vmax"] if vmax is None else vmax
|
||||
vmin = defaults["vmin"] if vmin is None else vmin
|
||||
if isinstance(log, (float, int)) and log != True:
|
||||
if isinstance(log, (float, int)) and log is not True:
|
||||
values = units.to_log(values, ref=log)
|
||||
|
||||
elif log == "2D":
|
||||
@@ -338,7 +337,7 @@ def _finish_plot_2D(
|
||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {out_path}")
|
||||
if cbar_label is not None:
|
||||
return fig, ax, cbar.ax
|
||||
return fig, (ax, cbar.ax)
|
||||
else:
|
||||
return fig, ax
|
||||
|
||||
@@ -355,7 +354,7 @@ def plot_spectrogram(
|
||||
vmax: float = None,
|
||||
cbar_label: str = "normalized intensity (dB)",
|
||||
file_type: str = "png",
|
||||
file_name: str = None,
|
||||
file_name: str = "plot",
|
||||
cmap: str = None,
|
||||
ax: plt.Axes = None,
|
||||
):
|
||||
@@ -448,13 +447,12 @@ def plot_spectrogram(
|
||||
ax,
|
||||
file_name,
|
||||
file_type,
|
||||
params,
|
||||
)
|
||||
|
||||
|
||||
def plot_results_2D(
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
plt_range: Union[units.PlotRange, tuple],
|
||||
params: BareParams,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
skip: int = 16,
|
||||
@@ -463,7 +461,7 @@ def plot_results_2D(
|
||||
transpose: bool = False,
|
||||
cbar_label: Optional[str] = "normalized intensity (dB)",
|
||||
file_type: str = "png",
|
||||
file_name: str = None,
|
||||
file_name: str = "plot",
|
||||
cmap: str = None,
|
||||
ax: plt.Axes = None,
|
||||
):
|
||||
@@ -528,7 +526,7 @@ def plot_results_2D(
|
||||
values = abs2(values)
|
||||
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range[2].type == "WL":
|
||||
if plt_range.unit.type == "WL":
|
||||
if is_spectrum:
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||
values = np.array(
|
||||
@@ -554,7 +552,7 @@ def plot_results_2D(
|
||||
return _finish_plot_2D(
|
||||
values,
|
||||
x_axis,
|
||||
plt_range[2].label,
|
||||
plt_range.unit.label,
|
||||
params.z_targets,
|
||||
"propagation distance (m)",
|
||||
log,
|
||||
@@ -566,13 +564,12 @@ def plot_results_2D(
|
||||
ax,
|
||||
file_name,
|
||||
file_type,
|
||||
params,
|
||||
)
|
||||
|
||||
|
||||
def plot_results_1D(
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
plt_range: Union[units.PlotRange, tuple],
|
||||
params: BareParams,
|
||||
log: Union[str, int, float, bool] = False,
|
||||
spacing: Union[int, float] = 1,
|
||||
@@ -586,7 +583,7 @@ def plot_results_1D(
|
||||
line_label: str = None,
|
||||
transpose: bool = False,
|
||||
**line_kwargs,
|
||||
):
|
||||
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -649,7 +646,7 @@ def plot_results_1D(
|
||||
values *= yscaling
|
||||
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range[2].type == "WL":
|
||||
if plt_range.unit.type == "WL":
|
||||
if is_spectrum:
|
||||
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
||||
|
||||
@@ -687,12 +684,12 @@ def plot_results_1D(
|
||||
ax.yaxis.set_label_position("right")
|
||||
ax.set_xlim(vmax, vmin)
|
||||
ax.set_xlabel(ylabel)
|
||||
ax.set_ylabel(plt_range[2].label)
|
||||
ax.set_ylabel(plt_range.unit.label)
|
||||
else:
|
||||
ax.plot(x_axis, values, label=line_label, **line_kwargs)
|
||||
ax.set_ylim(vmin, vmax)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_xlabel(plt_range[2].label)
|
||||
ax.set_xlabel(plt_range.unit.label)
|
||||
|
||||
if is_new_plot:
|
||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||
@@ -700,10 +697,13 @@ def plot_results_1D(
|
||||
return fig, ax, x_axis, values
|
||||
|
||||
|
||||
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
||||
def _prep_plot(
|
||||
values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams
|
||||
) -> tuple[bool, np.ndarray, units.PlotRange]:
|
||||
is_spectrum = values.dtype == "complex"
|
||||
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
|
||||
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
|
||||
if not isinstance(plt_range, units.PlotRange):
|
||||
plt_range = units.PlotRange(*plt_range)
|
||||
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = params.w.copy()
|
||||
else:
|
||||
x_axis = params.t.copy()
|
||||
@@ -712,7 +712,7 @@ def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
||||
|
||||
def plot_avg(
|
||||
values: np.ndarray,
|
||||
plt_range: RangeType,
|
||||
plt_range: Union[units.PlotRange, tuple],
|
||||
params: BareParams,
|
||||
log: Union[float, int, str, bool] = False,
|
||||
spacing: Union[float, int] = 1,
|
||||
@@ -809,7 +809,7 @@ def plot_avg(
|
||||
values = abs2(values)
|
||||
values *= yscaling
|
||||
mean_values = np.mean(values, axis=0)
|
||||
if plt_range[2].type == "WL" and renormalize:
|
||||
if plt_range.unit.type == "WL" and renormalize:
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||
mean_values = units.to_WL(mean_values, params.frep, x_axis)
|
||||
|
||||
@@ -886,7 +886,7 @@ def plot_avg(
|
||||
top.set_ylim(*ext)
|
||||
bot.yaxis.tick_right()
|
||||
bot.yaxis.set_label_position("right")
|
||||
bot.set_ylabel(plt_range[2].label)
|
||||
bot.set_ylabel(plt_range.unit.label)
|
||||
bot.set_ylim(*ext)
|
||||
else:
|
||||
for value in values:
|
||||
@@ -898,7 +898,7 @@ def plot_avg(
|
||||
top.set_ylim(bottom=vmin, top=vmax)
|
||||
top.set_ylabel(ylabel)
|
||||
top.set_xlim(*ext)
|
||||
bot.set_xlabel(plt_range[2].label)
|
||||
bot.set_xlabel(plt_range.unit.label)
|
||||
bot.set_xlim(*ext)
|
||||
|
||||
custom_lines = [
|
||||
@@ -961,7 +961,7 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
|
||||
|
||||
values = values[:, ind]
|
||||
|
||||
if plt_range[2].type == "WL":
|
||||
if plt_range.unit.type == "WL":
|
||||
values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis)
|
||||
|
||||
if isinstance(spacing, float):
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from re import UNICODE
|
||||
from typing import Callable, Dict, Iterable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import initialize, io, math
|
||||
from .physics import units
|
||||
from .const import SPECN_FN
|
||||
from .logger import get_logger
|
||||
from .plotting import units
|
||||
from .plotting import plot_avg, plot_results_1D, plot_results_2D
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
@@ -158,13 +160,12 @@ class Pulse(Sequence):
|
||||
Parameters
|
||||
----------
|
||||
ind : int or list of int
|
||||
if only certain spectra are desired.
|
||||
- If left to None, returns every spectrum
|
||||
- If only 1 int, will cast the (1, n, nt) array into a (n, nt) array
|
||||
if only certain spectra are desired
|
||||
Returns
|
||||
----------
|
||||
spectra : array
|
||||
squeezed array of complex spectra (n simulation on a nt size grid at each ind)
|
||||
spectra : array of shape (nz, m, nt)
|
||||
array of complex spectra (pulse at nz positions consisting
|
||||
of nm simulation on a nt size grid)
|
||||
"""
|
||||
|
||||
self.logger.debug(f"opening {self.path}")
|
||||
@@ -200,3 +201,46 @@ class Pulse(Sequence):
|
||||
spec = Spectrum(spec, self.wl, self.params.frep)
|
||||
self.cache[i] = spec
|
||||
return spec
|
||||
|
||||
def plot_2D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: Union[int, Iterable[int]] = None,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||
return plot_results_2D(vals, plt_range, self.params, **kwargs)
|
||||
|
||||
def plot_1D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: int,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
|
||||
|
||||
def plot_avg(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: int,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
|
||||
return plot_avg(vals, plt_range, self.params, **kwargs)
|
||||
|
||||
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
|
||||
plt_range = units.PlotRange(left, right, unit)
|
||||
if plt_range.unit.type == "TIME":
|
||||
vals = self.all_fields(ind=z_ind)[:, sim_ind]
|
||||
else:
|
||||
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
|
||||
return plt_range, vals
|
||||
|
||||
@@ -283,61 +283,3 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
|
||||
for k in new:
|
||||
variable.pop(k, None) # remove old ones
|
||||
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
||||
|
||||
|
||||
# def np_cache(function):
|
||||
# """applies functools.cache to function that take numpy arrays as input"""
|
||||
|
||||
# @cache
|
||||
# def cached_wrapper(*hashable_args, **hashable_kwargs):
|
||||
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
|
||||
# kwargs = {
|
||||
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
|
||||
# for k, kwarg in hashable_kwargs.items()
|
||||
# }
|
||||
# return function(*args, **kwargs)
|
||||
|
||||
# @wraps(function)
|
||||
# def wrapper(*args, **kwargs):
|
||||
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
|
||||
# hashable_kwargs = {
|
||||
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
|
||||
# for k, kwarg in kwargs.items()
|
||||
# }
|
||||
# return cached_wrapper(*hashable_args, **hashable_kwargs)
|
||||
|
||||
# # copy lru_cache attributes over too
|
||||
# wrapper.cache_info = cached_wrapper.cache_info
|
||||
# wrapper.cache_clear = cached_wrapper.cache_clear
|
||||
|
||||
# return wrapper
|
||||
|
||||
|
||||
class np_cache:
|
||||
def __init__(self, function):
|
||||
self.logger = get_logger(__name__)
|
||||
self.func = function
|
||||
self.cache = {}
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
update_wrapper(self, function)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
hashable_args = tuple(
|
||||
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
|
||||
)
|
||||
hashable_kwargs = tuple(
|
||||
{
|
||||
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
|
||||
for k, kwarg in kwargs.items()
|
||||
}.items()
|
||||
)
|
||||
key = hash((hashable_args, hashable_kwargs))
|
||||
if key not in self.cache:
|
||||
self.logger.debug("cache miss")
|
||||
self.misses += 1
|
||||
self.cache[key] = self.func(*args, **kwargs)
|
||||
else:
|
||||
self.hits += 1
|
||||
self.logger.debug("cache hit")
|
||||
return copy(self.cache[key])
|
||||
|
||||
75
src/scgenerator/utils/cache.py
Normal file
75
src/scgenerator/utils/cache.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from collections import namedtuple
|
||||
from copy import copy
|
||||
from functools import wraps
|
||||
import numpy as np
|
||||
|
||||
CacheInfo = namedtuple("CacheInfo", "hits misses size")
|
||||
|
||||
|
||||
def np_cache(func):
|
||||
def new_cached_func():
|
||||
cache = {}
|
||||
hits = misses = 0
|
||||
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
nonlocal cache, hits, misses
|
||||
hashable_args = tuple(
|
||||
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
|
||||
)
|
||||
hashable_kwargs = tuple(
|
||||
{
|
||||
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
|
||||
for k, kwarg in kwargs.items()
|
||||
}.items()
|
||||
)
|
||||
key = hash((hashable_args, hashable_kwargs))
|
||||
if key not in cache:
|
||||
misses += 1
|
||||
cache[key] = func(*args, **kwargs)
|
||||
else:
|
||||
hits += 1
|
||||
return copy(cache[key])
|
||||
|
||||
def reset():
|
||||
nonlocal cache, hits, misses
|
||||
cache = {}
|
||||
hits = misses = 0
|
||||
|
||||
wrapped.cache_info = lambda: CacheInfo(hits, misses, len(cache))
|
||||
wrapped.reset = reset
|
||||
|
||||
return wrapped
|
||||
|
||||
return new_cached_func()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import random
|
||||
import time
|
||||
|
||||
@np_cache
|
||||
def lol(a):
|
||||
time.sleep(random.random() * 4)
|
||||
return a / 2
|
||||
|
||||
@np_cache
|
||||
def ggg(b):
|
||||
time.sleep(random.random() * 4)
|
||||
return b * 2
|
||||
|
||||
x = np.arange(6)
|
||||
for i in range(5):
|
||||
print(lol.cache_info())
|
||||
print(lol(x))
|
||||
|
||||
print(f"{ggg.cache_info()=}")
|
||||
print(f"{lol.cache_info()=}")
|
||||
lol.reset()
|
||||
|
||||
print(ggg(np.arange(3)))
|
||||
print(ggg(np.arange(8)))
|
||||
print(ggg(np.arange(3)))
|
||||
|
||||
print(f"{ggg.cache_info()=}")
|
||||
print(f"{lol.cache_info()=}")
|
||||
Reference in New Issue
Block a user