better np cache, integr. plotting, nfft

This commit is contained in:
Benoît Sierro
2021-06-16 10:38:29 +02:00
parent 5de51dde37
commit 0108617b8e
7 changed files with 273 additions and 115 deletions

View File

@@ -3,6 +3,7 @@ from typing import Union
import numpy as np import numpy as np
from scipy.interpolate import griddata, interp1d from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros from scipy.special import jn_zeros
from .utils.cache import np_cache
def span(*vec): def span(*vec):
@@ -89,6 +90,85 @@ def u_nm(n, m):
return jn_zeros(n - 1, m)[-1] return jn_zeros(n - 1, m)[-1]
@np_cache
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
----------
t : np.ndarray, shape = (n,)
time array
f : np.ndarray, shape = (m,)
frequency array
Returns
-------
np.ndarray, shape = (m, n)
multiply x(t) by this matrix to get ~X(f)
"""
P, F = np.meshgrid(t, f)
return np.exp(-2j * np.pi * P * F)
@np_cache
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
----------
t : np.ndarray, shape = (n,)
time array
f : np.ndarray, shape = (m,)
frequency array
Returns
-------
np.ndarray, shape = (m, n)
multiply ~X(f) by this matrix to get x(t)
"""
return np.linalg.pinv(nfft_matrix(t, f))
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
"""computes the Fourier transform of an uneven signal
Parameters
----------
t : np.ndarray, shape = (n,)
time array
s : np.ndarray, shape = (n, )
amplitute at each point of t
f : np.ndarray, shape = (m, )
desired frequencies
Returns
-------
np.ndarray, shape = (m, )
amplitude at each frequency
"""
return nfft_matrix(t, f) @ s
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
"""computes the inverse Fourier transform of an uneven spectrum
Parameters
----------
f : np.ndarray, shape = (n,)
frequency array
a : np.ndarray, shape = (n, )
amplitude at each point of f
t : np.ndarray, shape = (m, )
time array
Returns
-------
np.ndarray, shape = (m, )
amplitude at each point of t
"""
return infft_matrix(t, f) @ a
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"): def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
"""Interpolates a 2D array with the help of griddata """Interpolates a 2D array with the help of griddata
Parameters Parameters

View File

@@ -9,7 +9,7 @@ from scipy.interpolate import interp1d
from .. import io from .. import io
from ..math import abs2, argclosest, power_fact, u_nm from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareParams, hc_model_specific_parameters from ..utils.parameter import BareParams, hc_model_specific_parameters
from ..utils import np_cache from ..utils.cache import np_cache
from . import materials as mat from . import materials as mat
from . import units from . import units
from .units import c, pi from .units import c, pi

View File

@@ -3,7 +3,8 @@
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from typing import Callable, Union from typing import Callable, Union
from dataclasses import dataclass
from ..utils.parameter import Parameter, type_checker
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -224,6 +225,18 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
return unit return unit
def is_unit(name, value):
if not hasattr(get_unit(value), "inv"):
raise TypeError("invalid unit specified")
@dataclass
class PlotRange:
left: float = Parameter(type_checker(int, float))
right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
def beta2_coef(beta): def beta2_coef(beta):
fac = 1e27 fac = 1e27
out = np.zeros_like(beta) out = np.zeros_like(beta)
@@ -263,45 +276,49 @@ def standardize_dictionary(dico):
return dico return dico
def sort_axis(axis, plt_range): def sort_axis(axis, plt_range: PlotRange):
""" """
given an axis, returns this axis cropped according to the given range, converted and sorted given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters Parameters
---------- ----------
axis : 1D array containing the original axis (usual the w or t array) axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns Returns
---------- -------
cropped : the axis cropped, converted and sorted cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped extent : tupple with min and max of cropped
Example Example
---------- -------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400) t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t) W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2) y = np.exp(-W**2 - T**2)
# Define ranges # Define ranges
rw = (-4, 4, s) rw = (-4, 4, s)
rt = (-2, 6, s) rt = (-2, 6, s)
w, cw = sort_axis(w, rw) w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt) t, ct = sort_axis(t, rt)
# slice y according to the given ranges # slice y according to the given ranges
y = y[ct][:, cw] y = y[ct][:, cw]
""" """
r = np.array(plt_range[:2], dtype="float") r = np.array((plt_range.left, plt_range.right), dtype="float")
func = get_unit(plt_range[2])
indices = np.arange(len(axis))[(axis <= np.max(func(r))) & (axis >= np.min(func(r)))] indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices] cropped = axis[indices]
order = np.argsort(func.inv(cropped)) order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order] indices = indices[order]
cropped = cropped[order] cropped = cropped[order]
out_ax = func.inv(cropped) out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1]) return out_ax, indices, (out_ax[0], out_ax[-1])

View File

@@ -248,14 +248,13 @@ def _finish_plot_2D(
ax, ax,
file_name, file_name,
file_type, file_type,
params,
): ):
# apply log transform if required # apply log transform if required
if log != False: if log is not False:
vmax = defaults["vmax"] if vmax is None else vmax vmax = defaults["vmax"] if vmax is None else vmax
vmin = defaults["vmin"] if vmin is None else vmin vmin = defaults["vmin"] if vmin is None else vmin
if isinstance(log, (float, int)) and log != True: if isinstance(log, (float, int)) and log is not True:
values = units.to_log(values, ref=log) values = units.to_log(values, ref=log)
elif log == "2D": elif log == "2D":
@@ -338,7 +337,7 @@ def _finish_plot_2D(
fig.savefig(out_path, bbox_inches="tight", dpi=200) fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}") print(f"plot saved in {out_path}")
if cbar_label is not None: if cbar_label is not None:
return fig, ax, cbar.ax return fig, (ax, cbar.ax)
else: else:
return fig, ax return fig, ax
@@ -355,7 +354,7 @@ def plot_spectrogram(
vmax: float = None, vmax: float = None,
cbar_label: str = "normalized intensity (dB)", cbar_label: str = "normalized intensity (dB)",
file_type: str = "png", file_type: str = "png",
file_name: str = None, file_name: str = "plot",
cmap: str = None, cmap: str = None,
ax: plt.Axes = None, ax: plt.Axes = None,
): ):
@@ -448,13 +447,12 @@ def plot_spectrogram(
ax, ax,
file_name, file_name,
file_type, file_type,
params,
) )
def plot_results_2D( def plot_results_2D(
values: np.ndarray, values: np.ndarray,
plt_range: RangeType, plt_range: Union[units.PlotRange, tuple],
params: BareParams, params: BareParams,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
skip: int = 16, skip: int = 16,
@@ -463,7 +461,7 @@ def plot_results_2D(
transpose: bool = False, transpose: bool = False,
cbar_label: Optional[str] = "normalized intensity (dB)", cbar_label: Optional[str] = "normalized intensity (dB)",
file_type: str = "png", file_type: str = "png",
file_name: str = None, file_name: str = "plot",
cmap: str = None, cmap: str = None,
ax: plt.Axes = None, ax: plt.Axes = None,
): ):
@@ -528,7 +526,7 @@ def plot_results_2D(
values = abs2(values) values = abs2(values)
# make uniform if converting to wavelength # make uniform if converting to wavelength
if plt_range[2].type == "WL": if plt_range.unit.type == "WL":
if is_spectrum: if is_spectrum:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
values = np.array( values = np.array(
@@ -554,7 +552,7 @@ def plot_results_2D(
return _finish_plot_2D( return _finish_plot_2D(
values, values,
x_axis, x_axis,
plt_range[2].label, plt_range.unit.label,
params.z_targets, params.z_targets,
"propagation distance (m)", "propagation distance (m)",
log, log,
@@ -566,13 +564,12 @@ def plot_results_2D(
ax, ax,
file_name, file_name,
file_type, file_type,
params,
) )
def plot_results_1D( def plot_results_1D(
values: np.ndarray, values: np.ndarray,
plt_range: RangeType, plt_range: Union[units.PlotRange, tuple],
params: BareParams, params: BareParams,
log: Union[str, int, float, bool] = False, log: Union[str, int, float, bool] = False,
spacing: Union[int, float] = 1, spacing: Union[int, float] = 1,
@@ -586,7 +583,7 @@ def plot_results_1D(
line_label: str = None, line_label: str = None,
transpose: bool = False, transpose: bool = False,
**line_kwargs, **line_kwargs,
): ) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
""" """
Parameters Parameters
@@ -649,7 +646,7 @@ def plot_results_1D(
values *= yscaling values *= yscaling
# make uniform if converting to wavelength # make uniform if converting to wavelength
if plt_range[2].type == "WL": if plt_range.unit.type == "WL":
if is_spectrum: if is_spectrum:
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind])) values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
@@ -687,12 +684,12 @@ def plot_results_1D(
ax.yaxis.set_label_position("right") ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin) ax.set_xlim(vmax, vmin)
ax.set_xlabel(ylabel) ax.set_xlabel(ylabel)
ax.set_ylabel(plt_range[2].label) ax.set_ylabel(plt_range.unit.label)
else: else:
ax.plot(x_axis, values, label=line_label, **line_kwargs) ax.plot(x_axis, values, label=line_label, **line_kwargs)
ax.set_ylim(vmin, vmax) ax.set_ylim(vmin, vmax)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.set_xlabel(plt_range[2].label) ax.set_xlabel(plt_range.unit.label)
if is_new_plot: if is_new_plot:
fig.savefig(out_path, bbox_inches="tight", dpi=200) fig.savefig(out_path, bbox_inches="tight", dpi=200)
@@ -700,10 +697,13 @@ def plot_results_1D(
return fig, ax, x_axis, values return fig, ax, x_axis, values
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams): def _prep_plot(
values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams
) -> tuple[bool, np.ndarray, units.PlotRange]:
is_spectrum = values.dtype == "complex" is_spectrum = values.dtype == "complex"
plt_range = (*plt_range[:2], units.get_unit(plt_range[2])) if not isinstance(plt_range, units.PlotRange):
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]: plt_range = units.PlotRange(*plt_range)
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = params.w.copy() x_axis = params.w.copy()
else: else:
x_axis = params.t.copy() x_axis = params.t.copy()
@@ -712,7 +712,7 @@ def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
def plot_avg( def plot_avg(
values: np.ndarray, values: np.ndarray,
plt_range: RangeType, plt_range: Union[units.PlotRange, tuple],
params: BareParams, params: BareParams,
log: Union[float, int, str, bool] = False, log: Union[float, int, str, bool] = False,
spacing: Union[float, int] = 1, spacing: Union[float, int] = 1,
@@ -809,7 +809,7 @@ def plot_avg(
values = abs2(values) values = abs2(values)
values *= yscaling values *= yscaling
mean_values = np.mean(values, axis=0) mean_values = np.mean(values, axis=0)
if plt_range[2].type == "WL" and renormalize: if plt_range.unit.type == "WL" and renormalize:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
mean_values = units.to_WL(mean_values, params.frep, x_axis) mean_values = units.to_WL(mean_values, params.frep, x_axis)
@@ -886,7 +886,7 @@ def plot_avg(
top.set_ylim(*ext) top.set_ylim(*ext)
bot.yaxis.tick_right() bot.yaxis.tick_right()
bot.yaxis.set_label_position("right") bot.yaxis.set_label_position("right")
bot.set_ylabel(plt_range[2].label) bot.set_ylabel(plt_range.unit.label)
bot.set_ylim(*ext) bot.set_ylim(*ext)
else: else:
for value in values: for value in values:
@@ -898,7 +898,7 @@ def plot_avg(
top.set_ylim(bottom=vmin, top=vmax) top.set_ylim(bottom=vmin, top=vmax)
top.set_ylabel(ylabel) top.set_ylabel(ylabel)
top.set_xlim(*ext) top.set_xlim(*ext)
bot.set_xlabel(plt_range[2].label) bot.set_xlabel(plt_range.unit.label)
bot.set_xlim(*ext) bot.set_xlim(*ext)
custom_lines = [ custom_lines = [
@@ -961,7 +961,7 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
values = values[:, ind] values = values[:, ind]
if plt_range[2].type == "WL": if plt_range.unit.type == "WL":
values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis) values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis)
if isinstance(spacing, float): if isinstance(spacing, float):

View File

@@ -1,14 +1,16 @@
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import Dict from re import UNICODE
from typing import Callable, Dict, Iterable, Optional, Union
import numpy as np import numpy as np
from . import initialize, io, math from . import initialize, io, math
from .physics import units
from .const import SPECN_FN from .const import SPECN_FN
from .logger import get_logger from .logger import get_logger
from .plotting import units from .plotting import plot_avg, plot_results_1D, plot_results_2D
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
@@ -158,13 +160,12 @@ class Pulse(Sequence):
Parameters Parameters
---------- ----------
ind : int or list of int ind : int or list of int
if only certain spectra are desired. if only certain spectra are desired
- If left to None, returns every spectrum
- If only 1 int, will cast the (1, n, nt) array into a (n, nt) array
Returns Returns
---------- ----------
spectra : array spectra : array of shape (nz, m, nt)
squeezed array of complex spectra (n simulation on a nt size grid at each ind) array of complex spectra (pulse at nz positions consisting
of nm simulation on a nt size grid)
""" """
self.logger.debug(f"opening {self.path}") self.logger.debug(f"opening {self.path}")
@@ -200,3 +201,46 @@ class Pulse(Sequence):
spec = Spectrum(spec, self.wl, self.params.frep) spec = Spectrum(spec, self.wl, self.params.frep)
self.cache[i] = spec self.cache[i] = spec
return spec return spec
def plot_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: Union[int, Iterable[int]] = None,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
return plot_results_2D(vals, plt_range, self.params, **kwargs)
def plot_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
def plot_avg(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
return plot_avg(vals, plt_range, self.params, **kwargs)
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
plt_range = units.PlotRange(left, right, unit)
if plt_range.unit.type == "TIME":
vals = self.all_fields(ind=z_ind)[:, sim_ind]
else:
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
return plt_range, vals

View File

@@ -283,61 +283,3 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
for k in new: for k in new:
variable.pop(k, None) # remove old ones variable.pop(k, None) # remove old ones
return replace(old, variable=variable, **{k: None for k in variable}, **new) return replace(old, variable=variable, **{k: None for k in variable}, **new)
# def np_cache(function):
# """applies functools.cache to function that take numpy arrays as input"""
# @cache
# def cached_wrapper(*hashable_args, **hashable_kwargs):
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
# kwargs = {
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
# for k, kwarg in hashable_kwargs.items()
# }
# return function(*args, **kwargs)
# @wraps(function)
# def wrapper(*args, **kwargs):
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
# hashable_kwargs = {
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
# for k, kwarg in kwargs.items()
# }
# return cached_wrapper(*hashable_args, **hashable_kwargs)
# # copy lru_cache attributes over too
# wrapper.cache_info = cached_wrapper.cache_info
# wrapper.cache_clear = cached_wrapper.cache_clear
# return wrapper
class np_cache:
def __init__(self, function):
self.logger = get_logger(__name__)
self.func = function
self.cache = {}
self.hits = 0
self.misses = 0
update_wrapper(self, function)
def __call__(self, *args, **kwargs):
hashable_args = tuple(
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
)
hashable_kwargs = tuple(
{
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
for k, kwarg in kwargs.items()
}.items()
)
key = hash((hashable_args, hashable_kwargs))
if key not in self.cache:
self.logger.debug("cache miss")
self.misses += 1
self.cache[key] = self.func(*args, **kwargs)
else:
self.hits += 1
self.logger.debug("cache hit")
return copy(self.cache[key])

View File

@@ -0,0 +1,75 @@
from collections import namedtuple
from copy import copy
from functools import wraps
import numpy as np
CacheInfo = namedtuple("CacheInfo", "hits misses size")
def np_cache(func):
def new_cached_func():
cache = {}
hits = misses = 0
@wraps(func)
def wrapped(*args, **kwargs):
nonlocal cache, hits, misses
hashable_args = tuple(
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
)
hashable_kwargs = tuple(
{
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
for k, kwarg in kwargs.items()
}.items()
)
key = hash((hashable_args, hashable_kwargs))
if key not in cache:
misses += 1
cache[key] = func(*args, **kwargs)
else:
hits += 1
return copy(cache[key])
def reset():
nonlocal cache, hits, misses
cache = {}
hits = misses = 0
wrapped.cache_info = lambda: CacheInfo(hits, misses, len(cache))
wrapped.reset = reset
return wrapped
return new_cached_func()
if __name__ == "__main__":
import random
import time
@np_cache
def lol(a):
time.sleep(random.random() * 4)
return a / 2
@np_cache
def ggg(b):
time.sleep(random.random() * 4)
return b * 2
x = np.arange(6)
for i in range(5):
print(lol.cache_info())
print(lol(x))
print(f"{ggg.cache_info()=}")
print(f"{lol.cache_info()=}")
lol.reset()
print(ggg(np.arange(3)))
print(ggg(np.arange(8)))
print(ggg(np.arange(3)))
print(f"{ggg.cache_info()=}")
print(f"{lol.cache_info()=}")