better np cache, integr. plotting, nfft

This commit is contained in:
Benoît Sierro
2021-06-16 10:38:29 +02:00
parent 5de51dde37
commit 0108617b8e
7 changed files with 273 additions and 115 deletions

View File

@@ -3,6 +3,7 @@ from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from .utils.cache import np_cache
def span(*vec):
@@ -89,6 +90,85 @@ def u_nm(n, m):
return jn_zeros(n - 1, m)[-1]
@np_cache
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
----------
t : np.ndarray, shape = (n,)
time array
f : np.ndarray, shape = (m,)
frequency array
Returns
-------
np.ndarray, shape = (m, n)
multiply x(t) by this matrix to get ~X(f)
"""
P, F = np.meshgrid(t, f)
return np.exp(-2j * np.pi * P * F)
@np_cache
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
----------
t : np.ndarray, shape = (n,)
time array
f : np.ndarray, shape = (m,)
frequency array
Returns
-------
np.ndarray, shape = (m, n)
multiply ~X(f) by this matrix to get x(t)
"""
return np.linalg.pinv(nfft_matrix(t, f))
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
"""computes the Fourier transform of an uneven signal
Parameters
----------
t : np.ndarray, shape = (n,)
time array
s : np.ndarray, shape = (n, )
amplitute at each point of t
f : np.ndarray, shape = (m, )
desired frequencies
Returns
-------
np.ndarray, shape = (m, )
amplitude at each frequency
"""
return nfft_matrix(t, f) @ s
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
"""computes the inverse Fourier transform of an uneven spectrum
Parameters
----------
f : np.ndarray, shape = (n,)
frequency array
a : np.ndarray, shape = (n, )
amplitude at each point of f
t : np.ndarray, shape = (m, )
time array
Returns
-------
np.ndarray, shape = (m, )
amplitude at each point of t
"""
return infft_matrix(t, f) @ a
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
"""Interpolates a 2D array with the help of griddata
Parameters

View File

@@ -9,7 +9,7 @@ from scipy.interpolate import interp1d
from .. import io
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareParams, hc_model_specific_parameters
from ..utils import np_cache
from ..utils.cache import np_cache
from . import materials as mat
from . import units
from .units import c, pi

View File

@@ -3,7 +3,8 @@
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from typing import Callable, Union
from dataclasses import dataclass
from ..utils.parameter import Parameter, type_checker
import numpy as np
from numpy import pi
@@ -224,6 +225,18 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
return unit
def is_unit(name, value):
if not hasattr(get_unit(value), "inv"):
raise TypeError("invalid unit specified")
@dataclass
class PlotRange:
left: float = Parameter(type_checker(int, float))
right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
def beta2_coef(beta):
fac = 1e27
out = np.zeros_like(beta)
@@ -263,45 +276,49 @@ def standardize_dictionary(dico):
return dico
def sort_axis(axis, plt_range):
def sort_axis(axis, plt_range: PlotRange):
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
----------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
----------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
# slice y according to the given ranges
y = y[ct][:, cw]
"""
r = np.array(plt_range[:2], dtype="float")
func = get_unit(plt_range[2])
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[(axis <= np.max(func(r))) & (axis >= np.min(func(r)))]
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(func.inv(cropped))
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = func.inv(cropped)
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])

View File

@@ -248,14 +248,13 @@ def _finish_plot_2D(
ax,
file_name,
file_type,
params,
):
# apply log transform if required
if log != False:
if log is not False:
vmax = defaults["vmax"] if vmax is None else vmax
vmin = defaults["vmin"] if vmin is None else vmin
if isinstance(log, (float, int)) and log != True:
if isinstance(log, (float, int)) and log is not True:
values = units.to_log(values, ref=log)
elif log == "2D":
@@ -338,7 +337,7 @@ def _finish_plot_2D(
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")
if cbar_label is not None:
return fig, ax, cbar.ax
return fig, (ax, cbar.ax)
else:
return fig, ax
@@ -355,7 +354,7 @@ def plot_spectrogram(
vmax: float = None,
cbar_label: str = "normalized intensity (dB)",
file_type: str = "png",
file_name: str = None,
file_name: str = "plot",
cmap: str = None,
ax: plt.Axes = None,
):
@@ -448,13 +447,12 @@ def plot_spectrogram(
ax,
file_name,
file_type,
params,
)
def plot_results_2D(
values: np.ndarray,
plt_range: RangeType,
plt_range: Union[units.PlotRange, tuple],
params: BareParams,
log: Union[int, float, bool, str] = "1D",
skip: int = 16,
@@ -463,7 +461,7 @@ def plot_results_2D(
transpose: bool = False,
cbar_label: Optional[str] = "normalized intensity (dB)",
file_type: str = "png",
file_name: str = None,
file_name: str = "plot",
cmap: str = None,
ax: plt.Axes = None,
):
@@ -528,7 +526,7 @@ def plot_results_2D(
values = abs2(values)
# make uniform if converting to wavelength
if plt_range[2].type == "WL":
if plt_range.unit.type == "WL":
if is_spectrum:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
values = np.array(
@@ -554,7 +552,7 @@ def plot_results_2D(
return _finish_plot_2D(
values,
x_axis,
plt_range[2].label,
plt_range.unit.label,
params.z_targets,
"propagation distance (m)",
log,
@@ -566,13 +564,12 @@ def plot_results_2D(
ax,
file_name,
file_type,
params,
)
def plot_results_1D(
values: np.ndarray,
plt_range: RangeType,
plt_range: Union[units.PlotRange, tuple],
params: BareParams,
log: Union[str, int, float, bool] = False,
spacing: Union[int, float] = 1,
@@ -586,7 +583,7 @@ def plot_results_1D(
line_label: str = None,
transpose: bool = False,
**line_kwargs,
):
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
"""
Parameters
@@ -649,7 +646,7 @@ def plot_results_1D(
values *= yscaling
# make uniform if converting to wavelength
if plt_range[2].type == "WL":
if plt_range.unit.type == "WL":
if is_spectrum:
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
@@ -687,12 +684,12 @@ def plot_results_1D(
ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin)
ax.set_xlabel(ylabel)
ax.set_ylabel(plt_range[2].label)
ax.set_ylabel(plt_range.unit.label)
else:
ax.plot(x_axis, values, label=line_label, **line_kwargs)
ax.set_ylim(vmin, vmax)
ax.set_ylabel(ylabel)
ax.set_xlabel(plt_range[2].label)
ax.set_xlabel(plt_range.unit.label)
if is_new_plot:
fig.savefig(out_path, bbox_inches="tight", dpi=200)
@@ -700,10 +697,13 @@ def plot_results_1D(
return fig, ax, x_axis, values
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
def _prep_plot(
values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams
) -> tuple[bool, np.ndarray, units.PlotRange]:
is_spectrum = values.dtype == "complex"
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
if not isinstance(plt_range, units.PlotRange):
plt_range = units.PlotRange(*plt_range)
if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = params.w.copy()
else:
x_axis = params.t.copy()
@@ -712,7 +712,7 @@ def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
def plot_avg(
values: np.ndarray,
plt_range: RangeType,
plt_range: Union[units.PlotRange, tuple],
params: BareParams,
log: Union[float, int, str, bool] = False,
spacing: Union[float, int] = 1,
@@ -809,7 +809,7 @@ def plot_avg(
values = abs2(values)
values *= yscaling
mean_values = np.mean(values, axis=0)
if plt_range[2].type == "WL" and renormalize:
if plt_range.unit.type == "WL" and renormalize:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
mean_values = units.to_WL(mean_values, params.frep, x_axis)
@@ -886,7 +886,7 @@ def plot_avg(
top.set_ylim(*ext)
bot.yaxis.tick_right()
bot.yaxis.set_label_position("right")
bot.set_ylabel(plt_range[2].label)
bot.set_ylabel(plt_range.unit.label)
bot.set_ylim(*ext)
else:
for value in values:
@@ -898,7 +898,7 @@ def plot_avg(
top.set_ylim(bottom=vmin, top=vmax)
top.set_ylabel(ylabel)
top.set_xlim(*ext)
bot.set_xlabel(plt_range[2].label)
bot.set_xlabel(plt_range.unit.label)
bot.set_xlim(*ext)
custom_lines = [
@@ -961,7 +961,7 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
values = values[:, ind]
if plt_range[2].type == "WL":
if plt_range.unit.type == "WL":
values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis)
if isinstance(spacing, float):

View File

@@ -1,14 +1,16 @@
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Dict
from re import UNICODE
from typing import Callable, Dict, Iterable, Optional, Union
import numpy as np
from . import initialize, io, math
from .physics import units
from .const import SPECN_FN
from .logger import get_logger
from .plotting import units
from .plotting import plot_avg, plot_results_1D, plot_results_2D
class Spectrum(np.ndarray):
@@ -158,13 +160,12 @@ class Pulse(Sequence):
Parameters
----------
ind : int or list of int
if only certain spectra are desired.
- If left to None, returns every spectrum
- If only 1 int, will cast the (1, n, nt) array into a (n, nt) array
if only certain spectra are desired
Returns
----------
spectra : array
squeezed array of complex spectra (n simulation on a nt size grid at each ind)
spectra : array of shape (nz, m, nt)
array of complex spectra (pulse at nz positions consisting
of nm simulation on a nt size grid)
"""
self.logger.debug(f"opening {self.path}")
@@ -200,3 +201,46 @@ class Pulse(Sequence):
spec = Spectrum(spec, self.wl, self.params.frep)
self.cache[i] = spec
return spec
def plot_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: Union[int, Iterable[int]] = None,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
return plot_results_2D(vals, plt_range, self.params, **kwargs)
def plot_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
def plot_avg(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
return plot_avg(vals, plt_range, self.params, **kwargs)
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
plt_range = units.PlotRange(left, right, unit)
if plt_range.unit.type == "TIME":
vals = self.all_fields(ind=z_ind)[:, sim_ind]
else:
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
return plt_range, vals

View File

@@ -283,61 +283,3 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
for k in new:
variable.pop(k, None) # remove old ones
return replace(old, variable=variable, **{k: None for k in variable}, **new)
# def np_cache(function):
# """applies functools.cache to function that take numpy arrays as input"""
# @cache
# def cached_wrapper(*hashable_args, **hashable_kwargs):
# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args)
# kwargs = {
# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg
# for k, kwarg in hashable_kwargs.items()
# }
# return function(*args, **kwargs)
# @wraps(function)
# def wrapper(*args, **kwargs):
# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args)
# hashable_kwargs = {
# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg
# for k, kwarg in kwargs.items()
# }
# return cached_wrapper(*hashable_args, **hashable_kwargs)
# # copy lru_cache attributes over too
# wrapper.cache_info = cached_wrapper.cache_info
# wrapper.cache_clear = cached_wrapper.cache_clear
# return wrapper
class np_cache:
def __init__(self, function):
self.logger = get_logger(__name__)
self.func = function
self.cache = {}
self.hits = 0
self.misses = 0
update_wrapper(self, function)
def __call__(self, *args, **kwargs):
hashable_args = tuple(
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
)
hashable_kwargs = tuple(
{
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
for k, kwarg in kwargs.items()
}.items()
)
key = hash((hashable_args, hashable_kwargs))
if key not in self.cache:
self.logger.debug("cache miss")
self.misses += 1
self.cache[key] = self.func(*args, **kwargs)
else:
self.hits += 1
self.logger.debug("cache hit")
return copy(self.cache[key])

View File

@@ -0,0 +1,75 @@
from collections import namedtuple
from copy import copy
from functools import wraps
import numpy as np
CacheInfo = namedtuple("CacheInfo", "hits misses size")
def np_cache(func):
def new_cached_func():
cache = {}
hits = misses = 0
@wraps(func)
def wrapped(*args, **kwargs):
nonlocal cache, hits, misses
hashable_args = tuple(
tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args
)
hashable_kwargs = tuple(
{
k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg
for k, kwarg in kwargs.items()
}.items()
)
key = hash((hashable_args, hashable_kwargs))
if key not in cache:
misses += 1
cache[key] = func(*args, **kwargs)
else:
hits += 1
return copy(cache[key])
def reset():
nonlocal cache, hits, misses
cache = {}
hits = misses = 0
wrapped.cache_info = lambda: CacheInfo(hits, misses, len(cache))
wrapped.reset = reset
return wrapped
return new_cached_func()
if __name__ == "__main__":
import random
import time
@np_cache
def lol(a):
time.sleep(random.random() * 4)
return a / 2
@np_cache
def ggg(b):
time.sleep(random.random() * 4)
return b * 2
x = np.arange(6)
for i in range(5):
print(lol.cache_info())
print(lol(x))
print(f"{ggg.cache_info()=}")
print(f"{lol.cache_info()=}")
lol.reset()
print(ggg(np.arange(3)))
print(ggg(np.arange(8)))
print(ggg(np.arange(3)))
print(f"{ggg.cache_info()=}")
print(f"{lol.cache_info()=}")