diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 6e1b0c5..61c1150 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -3,6 +3,7 @@ from typing import Union import numpy as np from scipy.interpolate import griddata, interp1d from scipy.special import jn_zeros +from .utils.cache import np_cache def span(*vec): @@ -89,6 +90,85 @@ def u_nm(n, m): return jn_zeros(n - 1, m)[-1] +@np_cache +def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: + """creates the nfft matrix + + Parameters + ---------- + t : np.ndarray, shape = (n,) + time array + f : np.ndarray, shape = (m,) + frequency array + + Returns + ------- + np.ndarray, shape = (m, n) + multiply x(t) by this matrix to get ~X(f) + """ + P, F = np.meshgrid(t, f) + return np.exp(-2j * np.pi * P * F) + + +@np_cache +def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: + """creates the nfft matrix + + Parameters + ---------- + t : np.ndarray, shape = (n,) + time array + f : np.ndarray, shape = (m,) + frequency array + + Returns + ------- + np.ndarray, shape = (m, n) + multiply ~X(f) by this matrix to get x(t) + """ + return np.linalg.pinv(nfft_matrix(t, f)) + + +def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray: + """computes the Fourier transform of an uneven signal + + Parameters + ---------- + t : np.ndarray, shape = (n,) + time array + s : np.ndarray, shape = (n, ) + amplitute at each point of t + f : np.ndarray, shape = (m, ) + desired frequencies + + Returns + ------- + np.ndarray, shape = (m, ) + amplitude at each frequency + """ + return nfft_matrix(t, f) @ s + + +def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: + """computes the inverse Fourier transform of an uneven spectrum + + Parameters + ---------- + f : np.ndarray, shape = (n,) + frequency array + a : np.ndarray, shape = (n, ) + amplitude at each point of f + t : np.ndarray, shape = (m, ) + time array + + Returns + ------- + np.ndarray, shape = (m, ) + amplitude at each point of t + """ + return infft_matrix(t, f) @ a + + def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"): """Interpolates a 2D array with the help of griddata Parameters diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index edb218f..d08e529 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -9,7 +9,7 @@ from scipy.interpolate import interp1d from .. import io from ..math import abs2, argclosest, power_fact, u_nm from ..utils.parameter import BareParams, hc_model_specific_parameters -from ..utils import np_cache +from ..utils.cache import np_cache from . import materials as mat from . import units from .units import c, pi diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index a9cb97b..f832c15 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -3,7 +3,8 @@ # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... from typing import Callable, Union - +from dataclasses import dataclass +from ..utils.parameter import Parameter, type_checker import numpy as np from numpy import pi @@ -224,6 +225,18 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]: return unit +def is_unit(name, value): + if not hasattr(get_unit(value), "inv"): + raise TypeError("invalid unit specified") + + +@dataclass +class PlotRange: + left: float = Parameter(type_checker(int, float)) + right: float = Parameter(type_checker(int, float)) + unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit) + + def beta2_coef(beta): fac = 1e27 out = np.zeros_like(beta) @@ -263,45 +276,49 @@ def standardize_dictionary(dico): return dico -def sort_axis(axis, plt_range): +def sort_axis(axis, plt_range: PlotRange): """ given an axis, returns this axis cropped according to the given range, converted and sorted + Parameters ---------- - axis : 1D array containing the original axis (usual the w or t array) - plt_range : tupple (min, max, conversion_function) used to crop the axis + axis : 1D array containing the original axis (usual the w or t array) + plt_range : tupple (min, max, conversion_function) used to crop the axis + Returns - ---------- - cropped : the axis cropped, converted and sorted - indices : indices to use to slice and sort other array in the same fashion - extent : tupple with min and max of cropped + ------- + cropped : the axis cropped, converted and sorted + indices : indices to use to slice and sort other array in the same fashion + extent : tupple with min and max of cropped + Example - ---------- - w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) - t = np.linspace(-10, 10, 400) - W, T = np.meshgrid(w, t) - y = np.exp(-W**2 - T**2) + ------- + w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) + t = np.linspace(-10, 10, 400) + W, T = np.meshgrid(w, t) + y = np.exp(-W**2 - T**2) - # Define ranges - rw = (-4, 4, s) - rt = (-2, 6, s) + # Define ranges + rw = (-4, 4, s) + rt = (-2, 6, s) - w, cw = sort_axis(w, rw) - t, ct = sort_axis(t, rt) + w, cw = sort_axis(w, rw) + t, ct = sort_axis(t, rt) - # slice y according to the given ranges - y = y[ct][:, cw] + # slice y according to the given ranges + y = y[ct][:, cw] """ - r = np.array(plt_range[:2], dtype="float") - func = get_unit(plt_range[2]) + r = np.array((plt_range.left, plt_range.right), dtype="float") - indices = np.arange(len(axis))[(axis <= np.max(func(r))) & (axis >= np.min(func(r)))] + indices = np.arange(len(axis))[ + (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) + ] cropped = axis[indices] - order = np.argsort(func.inv(cropped)) + order = np.argsort(plt_range.unit.inv(cropped)) indices = indices[order] cropped = cropped[order] - out_ax = func.inv(cropped) + out_ax = plt_range.unit.inv(cropped) return out_ax, indices, (out_ax[0], out_ax[-1]) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 864489c..3de9cba 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -248,14 +248,13 @@ def _finish_plot_2D( ax, file_name, file_type, - params, ): # apply log transform if required - if log != False: + if log is not False: vmax = defaults["vmax"] if vmax is None else vmax vmin = defaults["vmin"] if vmin is None else vmin - if isinstance(log, (float, int)) and log != True: + if isinstance(log, (float, int)) and log is not True: values = units.to_log(values, ref=log) elif log == "2D": @@ -338,7 +337,7 @@ def _finish_plot_2D( fig.savefig(out_path, bbox_inches="tight", dpi=200) print(f"plot saved in {out_path}") if cbar_label is not None: - return fig, ax, cbar.ax + return fig, (ax, cbar.ax) else: return fig, ax @@ -355,7 +354,7 @@ def plot_spectrogram( vmax: float = None, cbar_label: str = "normalized intensity (dB)", file_type: str = "png", - file_name: str = None, + file_name: str = "plot", cmap: str = None, ax: plt.Axes = None, ): @@ -448,13 +447,12 @@ def plot_spectrogram( ax, file_name, file_type, - params, ) def plot_results_2D( values: np.ndarray, - plt_range: RangeType, + plt_range: Union[units.PlotRange, tuple], params: BareParams, log: Union[int, float, bool, str] = "1D", skip: int = 16, @@ -463,7 +461,7 @@ def plot_results_2D( transpose: bool = False, cbar_label: Optional[str] = "normalized intensity (dB)", file_type: str = "png", - file_name: str = None, + file_name: str = "plot", cmap: str = None, ax: plt.Axes = None, ): @@ -528,7 +526,7 @@ def plot_results_2D( values = abs2(values) # make uniform if converting to wavelength - if plt_range[2].type == "WL": + if plt_range.unit.type == "WL": if is_spectrum: values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.array( @@ -554,7 +552,7 @@ def plot_results_2D( return _finish_plot_2D( values, x_axis, - plt_range[2].label, + plt_range.unit.label, params.z_targets, "propagation distance (m)", log, @@ -566,13 +564,12 @@ def plot_results_2D( ax, file_name, file_type, - params, ) def plot_results_1D( values: np.ndarray, - plt_range: RangeType, + plt_range: Union[units.PlotRange, tuple], params: BareParams, log: Union[str, int, float, bool] = False, spacing: Union[int, float] = 1, @@ -586,7 +583,7 @@ def plot_results_1D( line_label: str = None, transpose: bool = False, **line_kwargs, -): +) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]: """ Parameters @@ -649,7 +646,7 @@ def plot_results_1D( values *= yscaling # make uniform if converting to wavelength - if plt_range[2].type == "WL": + if plt_range.unit.type == "WL": if is_spectrum: values = units.to_WL(values, params.frep, units.m.inv(params.w[ind])) @@ -687,12 +684,12 @@ def plot_results_1D( ax.yaxis.set_label_position("right") ax.set_xlim(vmax, vmin) ax.set_xlabel(ylabel) - ax.set_ylabel(plt_range[2].label) + ax.set_ylabel(plt_range.unit.label) else: ax.plot(x_axis, values, label=line_label, **line_kwargs) ax.set_ylim(vmin, vmax) ax.set_ylabel(ylabel) - ax.set_xlabel(plt_range[2].label) + ax.set_xlabel(plt_range.unit.label) if is_new_plot: fig.savefig(out_path, bbox_inches="tight", dpi=200) @@ -700,10 +697,13 @@ def plot_results_1D( return fig, ax, x_axis, values -def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams): +def _prep_plot( + values: np.ndarray, plt_range: Union[units.PlotRange, tuple], params: BareParams +) -> tuple[bool, np.ndarray, units.PlotRange]: is_spectrum = values.dtype == "complex" - plt_range = (*plt_range[:2], units.get_unit(plt_range[2])) - if plt_range[2].type in ["WL", "FREQ", "AFREQ"]: + if not isinstance(plt_range, units.PlotRange): + plt_range = units.PlotRange(*plt_range) + if plt_range.unit.type in ["WL", "FREQ", "AFREQ"]: x_axis = params.w.copy() else: x_axis = params.t.copy() @@ -712,7 +712,7 @@ def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams): def plot_avg( values: np.ndarray, - plt_range: RangeType, + plt_range: Union[units.PlotRange, tuple], params: BareParams, log: Union[float, int, str, bool] = False, spacing: Union[float, int] = 1, @@ -809,7 +809,7 @@ def plot_avg( values = abs2(values) values *= yscaling mean_values = np.mean(values, axis=0) - if plt_range[2].type == "WL" and renormalize: + if plt_range.unit.type == "WL" and renormalize: values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) mean_values = units.to_WL(mean_values, params.frep, x_axis) @@ -886,7 +886,7 @@ def plot_avg( top.set_ylim(*ext) bot.yaxis.tick_right() bot.yaxis.set_label_position("right") - bot.set_ylabel(plt_range[2].label) + bot.set_ylabel(plt_range.unit.label) bot.set_ylim(*ext) else: for value in values: @@ -898,7 +898,7 @@ def plot_avg( top.set_ylim(bottom=vmin, top=vmax) top.set_ylabel(ylabel) top.set_xlim(*ext) - bot.set_xlabel(plt_range[2].label) + bot.set_xlabel(plt_range.unit.label) bot.set_xlim(*ext) custom_lines = [ @@ -961,7 +961,7 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6) values = values[:, ind] - if plt_range[2].type == "WL": + if plt_range.unit.type == "WL": values = np.apply_along_axis(units.to_WL, -1, values, frep, x_axis) if isinstance(spacing, float): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 159c590..6329256 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,14 +1,16 @@ import os from collections.abc import Sequence from pathlib import Path -from typing import Dict +from re import UNICODE +from typing import Callable, Dict, Iterable, Optional, Union import numpy as np from . import initialize, io, math +from .physics import units from .const import SPECN_FN from .logger import get_logger -from .plotting import units +from .plotting import plot_avg, plot_results_1D, plot_results_2D class Spectrum(np.ndarray): @@ -158,13 +160,12 @@ class Pulse(Sequence): Parameters ---------- ind : int or list of int - if only certain spectra are desired. - - If left to None, returns every spectrum - - If only 1 int, will cast the (1, n, nt) array into a (n, nt) array + if only certain spectra are desired Returns ---------- - spectra : array - squeezed array of complex spectra (n simulation on a nt size grid at each ind) + spectra : array of shape (nz, m, nt) + array of complex spectra (pulse at nz positions consisting + of nm simulation on a nt size grid) """ self.logger.debug(f"opening {self.path}") @@ -200,3 +201,46 @@ class Pulse(Sequence): spec = Spectrum(spec, self.wl, self.params.frep) self.cache[i] = spec return spec + + def plot_2D( + self, + left: float, + right: float, + unit: Union[Callable[[float], float], str], + z_ind: Union[int, Iterable[int]] = None, + sim_ind: int = 0, + **kwargs, + ): + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) + return plot_results_2D(vals, plt_range, self.params, **kwargs) + + def plot_1D( + self, + left: float, + right: float, + unit: Union[Callable[[float], float], str], + z_ind: int, + sim_ind: int = 0, + **kwargs, + ): + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) + return plot_results_1D(vals[0], plt_range, self.params, **kwargs) + + def plot_avg( + self, + left: float, + right: float, + unit: Union[Callable[[float], float], str], + z_ind: int, + **kwargs, + ): + plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None)) + return plot_avg(vals, plt_range, self.params, **kwargs) + + def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind): + plt_range = units.PlotRange(left, right, unit) + if plt_range.unit.type == "TIME": + vals = self.all_fields(ind=z_ind)[:, sim_ind] + else: + vals = self.all_spectra(ind=z_ind)[:, sim_ind] + return plt_range, vals diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 00750f4..7c81e99 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -283,61 +283,3 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig: for k in new: variable.pop(k, None) # remove old ones return replace(old, variable=variable, **{k: None for k in variable}, **new) - - -# def np_cache(function): -# """applies functools.cache to function that take numpy arrays as input""" - -# @cache -# def cached_wrapper(*hashable_args, **hashable_kwargs): -# args = tuple(np.array(arg) if isinstance(arg, tuple) else arg for arg in hashable_args) -# kwargs = { -# k: np.array(kwarg) if isinstance(kwarg, tuple) else kwarg -# for k, kwarg in hashable_kwargs.items() -# } -# return function(*args, **kwargs) - -# @wraps(function) -# def wrapper(*args, **kwargs): -# hashable_args = tuple(tuple(arg) if isinstance(arg, np.ndarray) else arg for arg in args) -# hashable_kwargs = { -# k: tuple(kwarg) if isinstance(kwarg, np.ndarray) else kwarg -# for k, kwarg in kwargs.items() -# } -# return cached_wrapper(*hashable_args, **hashable_kwargs) - -# # copy lru_cache attributes over too -# wrapper.cache_info = cached_wrapper.cache_info -# wrapper.cache_clear = cached_wrapper.cache_clear - -# return wrapper - - -class np_cache: - def __init__(self, function): - self.logger = get_logger(__name__) - self.func = function - self.cache = {} - self.hits = 0 - self.misses = 0 - update_wrapper(self, function) - - def __call__(self, *args, **kwargs): - hashable_args = tuple( - tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args - ) - hashable_kwargs = tuple( - { - k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg - for k, kwarg in kwargs.items() - }.items() - ) - key = hash((hashable_args, hashable_kwargs)) - if key not in self.cache: - self.logger.debug("cache miss") - self.misses += 1 - self.cache[key] = self.func(*args, **kwargs) - else: - self.hits += 1 - self.logger.debug("cache hit") - return copy(self.cache[key]) diff --git a/src/scgenerator/utils/cache.py b/src/scgenerator/utils/cache.py new file mode 100644 index 0000000..fbc4ee7 --- /dev/null +++ b/src/scgenerator/utils/cache.py @@ -0,0 +1,75 @@ +from collections import namedtuple +from copy import copy +from functools import wraps +import numpy as np + +CacheInfo = namedtuple("CacheInfo", "hits misses size") + + +def np_cache(func): + def new_cached_func(): + cache = {} + hits = misses = 0 + + @wraps(func) + def wrapped(*args, **kwargs): + nonlocal cache, hits, misses + hashable_args = tuple( + tuple(arg) if isinstance(arg, (np.ndarray, list)) else arg for arg in args + ) + hashable_kwargs = tuple( + { + k: tuple(kwarg) if isinstance(kwarg, (np.ndarray, list)) else kwarg + for k, kwarg in kwargs.items() + }.items() + ) + key = hash((hashable_args, hashable_kwargs)) + if key not in cache: + misses += 1 + cache[key] = func(*args, **kwargs) + else: + hits += 1 + return copy(cache[key]) + + def reset(): + nonlocal cache, hits, misses + cache = {} + hits = misses = 0 + + wrapped.cache_info = lambda: CacheInfo(hits, misses, len(cache)) + wrapped.reset = reset + + return wrapped + + return new_cached_func() + + +if __name__ == "__main__": + import random + import time + + @np_cache + def lol(a): + time.sleep(random.random() * 4) + return a / 2 + + @np_cache + def ggg(b): + time.sleep(random.random() * 4) + return b * 2 + + x = np.arange(6) + for i in range(5): + print(lol.cache_info()) + print(lol(x)) + + print(f"{ggg.cache_info()=}") + print(f"{lol.cache_info()=}") + lol.reset() + + print(ggg(np.arange(3))) + print(ggg(np.arange(8))) + print(ggg(np.arange(3))) + + print(f"{ggg.cache_info()=}") + print(f"{lol.cache_info()=}")