better np cache, integr. plotting, nfft
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.interpolate import griddata, interp1d
|
from scipy.interpolate import griddata, interp1d
|
||||||
from scipy.special import jn_zeros
|
from scipy.special import jn_zeros
|
||||||
|
from .utils.cache import np_cache
|
||||||
|
|
||||||
|
|
||||||
def span(*vec):
|
def span(*vec):
|
||||||
@@ -89,6 +90,85 @@ def u_nm(n, m):
|
|||||||
return jn_zeros(n - 1, m)[-1]
|
return jn_zeros(n - 1, m)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@np_cache
|
||||||
|
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||||
|
"""creates the nfft matrix
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
t : np.ndarray, shape = (n,)
|
||||||
|
time array
|
||||||
|
f : np.ndarray, shape = (m,)
|
||||||
|
frequency array
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray, shape = (m, n)
|
||||||
|
multiply x(t) by this matrix to get ~X(f)
|
||||||
|
"""
|
||||||
|
P, F = np.meshgrid(t, f)
|
||||||
|
return np.exp(-2j * np.pi * P * F)
|
||||||
|
|
||||||
|
|
||||||
|
@np_cache
|
||||||
|
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||||
|
"""creates the nfft matrix
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
t : np.ndarray, shape = (n,)
|
||||||
|
time array
|
||||||
|
f : np.ndarray, shape = (m,)
|
||||||
|
frequency array
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray, shape = (m, n)
|
||||||
|
multiply ~X(f) by this matrix to get x(t)
|
||||||
|
"""
|
||||||
|
return np.linalg.pinv(nfft_matrix(t, f))
|
||||||
|
|
||||||
|
|
||||||
|
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||||
|
"""computes the Fourier transform of an uneven signal
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
t : np.ndarray, shape = (n,)
|
||||||
|
time array
|
||||||
|
s : np.ndarray, shape = (n, )
|
||||||
|
amplitute at each point of t
|
||||||
|
f : np.ndarray, shape = (m, )
|
||||||
|
desired frequencies
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray, shape = (m, )
|
||||||
|
amplitude at each frequency
|
||||||
|
"""
|
||||||
|
return nfft_matrix(t, f) @ s
|
||||||
|
|
||||||
|
|
||||||
|
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||||
|
"""computes the inverse Fourier transform of an uneven spectrum
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
f : np.ndarray, shape = (n,)
|
||||||
|
frequency array
|
||||||
|
a : np.ndarray, shape = (n, )
|
||||||
|
amplitude at each point of f
|
||||||
|
t : np.ndarray, shape = (m, )
|
||||||
|
time array
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray, shape = (m, )
|
||||||
|
amplitude at each point of t
|
||||||
|
"""
|
||||||
|
return infft_matrix(t, f) @ a
|
||||||
|
|
||||||
|
|
||||||
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
|
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
|
||||||
"""Interpolates a 2D array with the help of griddata
|
"""Interpolates a 2D array with the help of griddata
|
||||||
Parameters
|
Parameters
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from scipy.interpolate import interp1d
|
|||||||
from .. import io
|
from .. import io
|
||||||
from ..math import abs2, argclosest, power_fact, u_nm
|
from ..math import abs2, argclosest, power_fact, u_nm
|
||||||
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||||
from ..utils import np_cache
|
from ..utils.cache import np_cache
|
||||||
from . import materials as mat
|
from . import materials as mat
|
||||||
from . import units
|
from . import units
|
||||||
from .units import c, pi
|
from .units import c, pi
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
|
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
|
||||||
|
|
||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from ..utils.parameter import Parameter, type_checker
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi
|
from numpy import pi
|
||||||
|
|
||||||
@@ -224,6 +225,18 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
|
|||||||
return unit
|
return unit
|
||||||
|
|
||||||
|
|
||||||
|
def is_unit(name, value):
|
||||||
|
if not hasattr(get_unit(value), "inv"):
|
||||||
|
raise TypeError("invalid unit specified")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlotRange:
|
||||||
|
left: float = Parameter(type_checker(int, float))
|
||||||
|
right: float = Parameter(type_checker(int, float))
|
||||||
|
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
||||||
|
|
||||||
|
|
||||||
def beta2_coef(beta):
|
def beta2_coef(beta):
|
||||||
fac = 1e27
|
fac = 1e27
|
||||||
out = np.zeros_like(beta)
|
out = np.zeros_like(beta)
|
||||||
@@ -263,45 +276,49 @@ def standardize_dictionary(dico):
|
|||||||
return dico
|
return dico
|
||||||
|
|
||||||
|
|
||||||
def sort_axis(axis, plt_range):
|
def sort_axis(axis, plt_range: PlotRange):
|
||||||
"""
|
"""
|
||||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
axis : 1D array containing the original axis (usual the w or t array)
|
axis : 1D array containing the original axis (usual the w or t array)
|
||||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
-------
|
||||||
cropped : the axis cropped, converted and sorted
|
cropped : the axis cropped, converted and sorted
|
||||||
indices : indices to use to slice and sort other array in the same fashion
|
indices : indices to use to slice and sort other array in the same fashion
|
||||||
extent : tupple with min and max of cropped
|
extent : tupple with min and max of cropped
|
||||||
|
|
||||||
Example
|
Example
|
||||||
----------
|
-------
|
||||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||||
t = np.linspace(-10, 10, 400)
|
t = np.linspace(-10, 10, 400)
|
||||||
W, T = np.meshgrid(w, t)
|
W, T = np.meshgrid(w, t)
|
||||||
y = np.exp(-W**2 - T**2)
|
y = np.exp(-W**2 - T**2)
|
||||||
|
|
||||||
# Define ranges
|
# Define ranges
|
||||||
rw = (-4, 4, s)
|
rw = (-4, 4, s)
|
||||||
rt = (-2, 6, s)
|
rt = (-2, 6, s)
|
||||||
|
|
||||||
w, cw = sort_axis(w, rw)
|
w, cw = sort_axis(w, rw)
|
||||||
t, ct = sort_axis(t, rt)
|
t, ct = sort_axis(t, rt)
|
||||||
|
|
||||||
# slice y according to the given ranges
|
# slice y according to the given ranges
|
||||||
y = y[ct][:, cw]
|
y = y[ct][:, cw]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
r = np.array(plt_range[:2], dtype="float")
|
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||||
func = get_unit(plt_range[2])
|
|
||||||
|
|
||||||
indices = np.arange(len(axis))[(axis <= np.max(func(r))) & (axis >= np.min(func(r)))]
|
indices = np.arange(len(axis))[
|
||||||
|
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
||||||
|
]
|
||||||
cropped = axis[indices]
|
cropped = axis[indices]
|
||||||
order = np.argsort(func.inv(cropped))
|
order = np.argsort(plt_range.unit.inv(cropped))
|
||||||
indices = indices[order]
|
indices = indices[order]
|
||||||
cropped = cropped[order]
|
cropped = cropped[order]
|
||||||
out_ax = func.inv(cropped)
|
out_ax = plt_range.unit.inv(cropped)
|
||||||
|
|
||||||
return out_ax, indices, (out_ax[0], out_ax[-1])
|
return out_ax, indices, (out_ax[0], out_ax[-1])
|
||||||
|
|
||||||
|
|||||||
@@ -248,14 +248,13 @@ def _finish_plot_2D(
|
|||||||
ax,
|
ax,
|
||||||
file_name,
|
file_name,
|
||||||
file_type,
|
file_type,
|
||||||
params,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
# apply log transform if required
|
# apply log transform if required
|
||||||
if log != False:
|
if log is not False:
|
||||||
vmax = defaults["vmax"] if vmax is None else vmax
|
vmax = defaults["vmax"] if vmax is None else vmax
|
||||||
vmin = defaults["vmin"] if vmin is None else vmin
|
vmin = defaults["vmin"] if vmin is None else vmin
|
||||||
if isinstance(log, (float, int)) and log != True:
|
if isinstance(log, (float, int)) and log is not True:
|
||||||
values = units.to_log(values, ref=log)
|
values = units.to_log(values, ref=log)
|
||||||
|
|
||||||
elif log == "2D":
|
elif log == "2D":
|
||||||
@@ -338,7 +337,7 @@ def _finish_plot_2D(
|
|||||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
print(f"plot saved in {out_path}")
|
print(f"plot saved in {out_path}")
|
||||||
if cbar_label is not None:
|
if cbar_label is not None:
|
||||||
return fig, ax, cbar.ax
|
return fig, (ax, cbar.ax)
|
||||||
else:
|
else:
|
||||||
return fig, ax
|
return fig, ax
|
||||||
|
|
||||||
@@ -355,7 +354,7 @@ def plot_spectrogram(
|
|||||||
vmax: float = None,
|
vmax: float = None,
|
||||||
cbar_label: str = "normalized intensity (dB)",
|
cbar_label: str = "normalized intensity (dB)",
|
||||||
file_type: str = "png",
|
file_type: str = "png",
|
||||||
file_name: str = None,
|
file_name: str = "plot",
|
||||||
cmap: str = None,
|
cmap: str = None,
|
||||||
ax: plt.Axes = None,
|
ax: plt.Axes = None,
|
||||||
):
|
):
|
||||||
@@ -448,13 +447,12 @@ def plot_spectrogram(
|
|||||||
ax,
|
ax,
|
||||||
file_name,
|
file_name,
|
||||||
file_type,
|
file_type,
|
||||||
params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def plot_results_2D(
|
def plot_results_2D(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: RangeType,
|
plt_range: Union[units.PlotRange, tuple],
|
||||||
params: BareParams,
|
params: BareParams,
|
||||||
log: Union[int, float, bool, str] = "1D",
|
log: Union[int, float, bool, str] = "1D",
|
||||||
skip: int = 16,
|
skip: int = 16,
|
||||||
@@ -463,7 +461,7 @@ def plot_results_2D(
|
|||||||
transpose: bool = False,
|
transpose: bool = False,
|
||||||
cbar_label: Optional[str] = "normalized intensity (dB)",
|
cbar_label: Optional[str] = "normalized intensity (dB)",
|
||||||
file_type: str = "png",
|
file_type: str = "png",
|
||||||
file_name: str = None,
|
file_name: str = "plot",
|
||||||
cmap: str = None,
|
cmap: str = None,
|
||||||
ax: plt.Axes = None,
|
ax: plt.Axes = None,
|
||||||
):
|
):
|
||||||
@@ -528,7 +526,7 @@ def plot_results_2D(
|
|||||||
values = abs2(values)
|
values = abs2(values)
|
||||||
|
|
||||||
# make uniform if converting to wavelength
|
# make uniform if converting to wavelength
|
||||||
if plt_range[2].type == "WL":
|
if plt_range.unit.type == "WL":
|
||||||
if is_spectrum:
|
if is_spectrum:
|
||||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||||
values = np.array(
|
values = np.array(
|
||||||
@@ -554,7 +552,7 @@ def plot_results_2D(
|
|||||||
return _finish_plot_2D(
|
return _finish_plot_2D(
|
||||||
values,
|
values,
|
||||||
x_axis,
|
x_axis,
|
||||||
plt_range[2].label,
|
plt_range.unit.label,
|
||||||
params.z_targets,
|
params.z_targets,
|
||||||
"propagation distance (m)",
|
"propagation distance (m)",
|
||||||
log,
|
log,
|
||||||
@@ -566,13 +564,12 @@ def plot_results_2D(
|
|||||||
ax,
|
ax,
|
||||||
file_name,
|
file_name,
|
||||||
file_type,
|
file_type,
|
||||||
params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def plot_results_1D(
|
def plot_results_1D(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: RangeType,
|
plt_range: Union[units.PlotRange, tuple],
|
||||||
params: BareParams,
|
params: BareParams,
|
||||||
log: Union[str, int, float, bool] = False,
|
log: Union[str, int, float, bool] = False,
|
||||||
spacing: Union[int, float] = 1,
|
spacing: Union[int, float] = 1,
|
||||||
@@ -586,7 +583,7 @@ def plot_results_1D(
|
|||||||
line_label: str = None,
|
line_label: str = None,
|
||||||
transpose: bool = False,
|
transpose: bool = False,
|
||||||
**line_kwargs,
|
**line_kwargs,
|
||||||
):
|
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -649,7 +646,7 @@ def plot_results_1D(
|
|||||||
values *= yscaling
|
values *= yscaling
|
||||||
|
|
||||||
# make uniform if converting to wavelength
|
# make uniform if converting to wavelength
|
||||||
if plt_range[2].type == "WL":
|
if plt_range.unit.type == "WL":
|
||||||
if is_spectrum:
|
if is_spectrum:
|
||||||
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
||||||
|
|
||||||
@@ -687,12 +684,12 @@ def plot_results_1D(
|
|||||||
ax.yaxis.set_label_position("right")
|
ax.yaxis.set_label_position("right")
|
||||||
ax.set_xlim(vmax, vmin)
|
ax.set_xlim(vmax, vmin)
|
||||||
ax.set_xlabel(ylabel)
|
ax.set_xlabel(ylabel)
|
||||||
ax.set_ylabel(plt_range[2].label)
|
ax.set_ylabel(plt_range.unit.label)
|
||||||
else:
|
else:
|
||||||
ax.plot(x_axis, values, label=line_label, **line_kwargs)
|
ax.plot(x_axis, values, label=line_label, **line_kwargs)
|
||||||
ax.set_ylim(vmin, vmax)
|
ax.set_ylim(vmin, vmax)
|
||||||
ax.set_ylabel(ylabel)
|
ax.set_ylabel(ylabel)
|
||||||
ax.set_xlabel(plt_range[2].label)
|
ax.set_xlabel(plt_range.unit.label)
|
||||||
|
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
@@ -700,10 +697,13 @@ def plot_results_1D(
|
|||||||
return fig, ax, x_axis, values
|
return fig, ax, x_axis, values
|
||||||
|
|
||||||
|
|
||||||
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
def _prep_plot(
|
||||||
|
values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams
|
||||||
|
) -> tuple[bool, np.ndarray, units.PlotRange]:
|
||||||
is_spectrum = values.dtype == "complex"
|
is_spectrum = values.dtype == "complex"
|
||||||
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
|
if not isinstance(plt_range, units.PlotRange):
|
||||||
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
|
plt_range = units.PlotRange(*plt_range)
|
||||||
|
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||||
x_axis = params.w.copy()
|
x_axis = params.w.copy()
|
||||||
else:
|
else:
|
||||||
x_axis = params.t.copy()
|
x_axis = params.t.copy()
|
||||||
@@ -712,7 +712,7 @@ def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
|
|||||||
|
|
||||||
def plot_avg(
|
def plot_avg(
|
||||||
values: np.ndarray,
|
values: np.ndarray,
|
||||||
plt_range: RangeType,
|
plt_range: Union[units.PlotRange, tuple],
|
||||||
params: BareParams,
|
params: BareParams,
|
||||||
log: Union[float, int, str, bool] = False,
|
log: Union[float, int, str, bool] = False,
|
||||||
spacing: Union[float, int] = 1,
|
spacing: Union[float, int] = 1,
|
||||||
@@ -809,7 +809,7 @@ def plot_avg(
|
|||||||
values = abs2(values)
|
values = abs2(values)
|
||||||
values *= yscaling
|
values *= yscaling
|
||||||
mean_values = np.mean(values, axis=0)
|
mean_values = np.mean(values, axis=0)
|
||||||
if plt_range[2].type == "WL" and renormalize:
|
if plt_range.unit.type == "WL" and renormalize:
|
||||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||||
mean_values = units.to_WL(mean_values, params.frep, x_axis)
|
mean_values = units.to_WL(mean_values, params.frep, x_axis)
|
||||||
|
|
||||||
@@ -886,7 +886,7 @@ def plot_avg(
|
|||||||
top.set_ylim(*ext)
|
top.set_ylim(*ext)
|
||||||
bot.yaxis.tick_right()
|
bot.yaxis.tick_right()
|
||||||
bot.yaxis.set_label_position("right")
|
bot.yaxis.set_label_position("right")
|
||||||
bot.set_ylabel(plt_range[2].label)
|
bot.set_ylabel(plt_range.unit.label)
|
||||||
bot.set_ylim(*ext)
|
bot.set_ylim(*ext)
|
||||||
else:
|
else:
|
||||||
for value in values:
|
for value in values:
|
||||||
@@ -898,7 +898,7 @@ def plot_avg(
|
|||||||
top.set_ylim(bottom=vmin, top=vmax)
|
top.set_ylim(bottom=vmin, top=vmax)
|
||||||
top.set_ylabel(ylabel)
|
top.set_ylabel(ylabel)
|
||||||
top.set_xlim(*ext)
|
top.set_xlim(*ext)
|
||||||
bot.set_xlabel(plt_range[2].label)
|
bot.set_xlabel(plt_range.unit.label)
|
||||||
bot.set_xlim(*ext)
|
bot.set_xlim(*ext)
|
||||||
|
|
||||||
custom_lines = [
|
custom_lines = [
|
||||||
@@ -961,7 +961,7 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
|
|||||||
|
|
||||||
values = values[:, ind]
|
values = values[:, ind]
|
||||||
|
|
||||||
if plt_range[2].type == "WL":
|
if plt_range.unit.type == "WL":
|
||||||
values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis)
|
values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis)
|
||||||
|
|
||||||
if isinstance(spacing, float):
|
if isinstance(spacing, float):
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from re import UNICODE
|
||||||
|
from typing import Callable, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from . import initialize, io, math
|
from . import initialize, io, math
|
||||||
|
from .physics import units
|
||||||
from .const import SPECN_FN
|
from .const import SPECN_FN
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .plotting import units
|
from .plotting import plot_avg, plot_results_1D, plot_results_2D
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
@@ -158,13 +160,12 @@ class Pulse(Sequence):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
ind : int or list of int
|
ind : int or list of int
|
||||||
if only certain spectra are desired.
|
if only certain spectra are desired
|
||||||
- If left to None, returns every spectrum
|
|
||||||
- If only 1 int, will cast the (1, n, nt) array into a (n, nt) array
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
spectra : array
|
spectra : array of shape (nz, m, nt)
|
||||||
squeezed array of complex spectra (n simulation on a nt size grid at each ind)
|
array of complex spectra (pulse at nz positions consisting
|
||||||
|
of nm simulation on a nt size grid)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.debug(f"opening {self.path}")
|
self.logger.debug(f"opening {self.path}")
|
||||||
@@ -200,3 +201,46 @@ class Pulse(Sequence):
|
|||||||
spec = Spectrum(spec, self.wl, self.params.frep)
|
spec = Spectrum(spec, self.wl, self.params.frep)
|
||||||
self.cache[i] = spec
|
self.cache[i] = spec
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
def plot_2D(
|
||||||
|
self,
|
||||||
|
left: float,
|
||||||
|
right: float,
|
||||||
|
unit: Union[Callable[[float], float], str],
|
||||||
|
z_ind: Union[int, Iterable[int]] = None,
|
||||||
|
sim_ind: int = 0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||||
|
return plot_results_2D(vals, plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
|
def plot_1D(
|
||||||
|
self,
|
||||||
|
left: float,
|
||||||
|
right: float,
|
||||||
|
unit: Union[Callable[[float], float], str],
|
||||||
|
z_ind: int,
|
||||||
|
sim_ind: int = 0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||||
|
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
|
def plot_avg(
|
||||||
|
self,
|
||||||
|
left: float,
|
||||||
|
right: float,
|
||||||
|
unit: Union[Callable[[float], float], str],
|
||||||
|
z_ind: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
|
||||||
|
return plot_avg(vals, plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
|
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
|
||||||
|
plt_range = units.PlotRange(left, right, unit)
|
||||||
|
if plt_range.unit.type == "TIME":
|
||||||
|
vals = self.all_fields(ind=z_ind)[:, sim_ind]
|
||||||
|
else:
|
||||||
|
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
|
||||||
|
return plt_range, vals
|
||||||
|
|||||||
@@ -283,61 +283,3 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
|
|||||||
for k in new:
|
for k in new:
|
||||||
variable.pop(k, None) # remove old ones
|
variable.pop(k, None) # remove old ones
|
||||||
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
||||||
|
|
||||||
|
|
||||||
# def np_cache(function):
|
|
||||||
# """applies functools.cache to function that take numpy arrays as input"""
|
|
||||||
|
|
||||||
# @cache
|
|
||||||
# def cached_wrapper(*hashable_args, **hashable_kwargs):
|
|
||||||
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
|
|
||||||
# kwargs = {
|
|
||||||
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
|
|
||||||
# for k, kwarg in hashable_kwargs.items()
|
|
||||||
# }
|
|
||||||
# return function(*args, **kwargs)
|
|
||||||
|
|
||||||
# @wraps(function)
|
|
||||||
# def wrapper(*args, **kwargs):
|
|
||||||
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
|
|
||||||
# hashable_kwargs = {
|
|
||||||
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
|
|
||||||
# for k, kwarg in kwargs.items()
|
|
||||||
# }
|
|
||||||
# return cached_wrapper(*hashable_args, **hashable_kwargs)
|
|
||||||
|
|
||||||
# # copy lru_cache attributes over too
|
|
||||||
# wrapper.cache_info = cached_wrapper.cache_info
|
|
||||||
# wrapper.cache_clear = cached_wrapper.cache_clear
|
|
||||||
|
|
||||||
# return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class np_cache:
|
|
||||||
def __init__(self, function):
|
|
||||||
self.logger = get_logger(__name__)
|
|
||||||
self.func = function
|
|
||||||
self.cache = {}
|
|
||||||
self.hits = 0
|
|
||||||
self.misses = 0
|
|
||||||
update_wrapper(self, function)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
hashable_args = tuple(
|
|
||||||
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
|
|
||||||
)
|
|
||||||
hashable_kwargs = tuple(
|
|
||||||
{
|
|
||||||
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
|
|
||||||
for k, kwarg in kwargs.items()
|
|
||||||
}.items()
|
|
||||||
)
|
|
||||||
key = hash((hashable_args, hashable_kwargs))
|
|
||||||
if key not in self.cache:
|
|
||||||
self.logger.debug("cache miss")
|
|
||||||
self.misses += 1
|
|
||||||
self.cache[key] = self.func(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
self.hits += 1
|
|
||||||
self.logger.debug("cache hit")
|
|
||||||
return copy(self.cache[key])
|
|
||||||
|
|||||||
75
src/scgenerator/utils/cache.py
Normal file
75
src/scgenerator/utils/cache.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from collections import namedtuple
|
||||||
|
from copy import copy
|
||||||
|
from functools import wraps
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
CacheInfo = namedtuple("CacheInfo", "hits misses size")
|
||||||
|
|
||||||
|
|
||||||
|
def np_cache(func):
|
||||||
|
def new_cached_func():
|
||||||
|
cache = {}
|
||||||
|
hits = misses = 0
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
nonlocal cache, hits, misses
|
||||||
|
hashable_args = tuple(
|
||||||
|
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
|
||||||
|
)
|
||||||
|
hashable_kwargs = tuple(
|
||||||
|
{
|
||||||
|
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
|
||||||
|
for k, kwarg in kwargs.items()
|
||||||
|
}.items()
|
||||||
|
)
|
||||||
|
key = hash((hashable_args, hashable_kwargs))
|
||||||
|
if key not in cache:
|
||||||
|
misses += 1
|
||||||
|
cache[key] = func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
hits += 1
|
||||||
|
return copy(cache[key])
|
||||||
|
|
||||||
|
def reset():
|
||||||
|
nonlocal cache, hits, misses
|
||||||
|
cache = {}
|
||||||
|
hits = misses = 0
|
||||||
|
|
||||||
|
wrapped.cache_info = lambda: CacheInfo(hits, misses, len(cache))
|
||||||
|
wrapped.reset = reset
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
return new_cached_func()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
@np_cache
|
||||||
|
def lol(a):
|
||||||
|
time.sleep(random.random() * 4)
|
||||||
|
return a / 2
|
||||||
|
|
||||||
|
@np_cache
|
||||||
|
def ggg(b):
|
||||||
|
time.sleep(random.random() * 4)
|
||||||
|
return b * 2
|
||||||
|
|
||||||
|
x = np.arange(6)
|
||||||
|
for i in range(5):
|
||||||
|
print(lol.cache_info())
|
||||||
|
print(lol(x))
|
||||||
|
|
||||||
|
print(f"{ggg.cache_info()=}")
|
||||||
|
print(f"{lol.cache_info()=}")
|
||||||
|
lol.reset()
|
||||||
|
|
||||||
|
print(ggg(np.arange(3)))
|
||||||
|
print(ggg(np.arange(8)))
|
||||||
|
print(ggg(np.arange(3)))
|
||||||
|
|
||||||
|
print(f"{ggg.cache_info()=}")
|
||||||
|
print(f"{lol.cache_info()=}")
|
||||||
Reference in New Issue
Block a user