From edadb4284b040dcc7de9b9da47634c772ed9bec6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 10 Apr 2024 13:50:57 +0200 Subject: [PATCH] improved cache --- src/scgenerator/cache.py | 41 +++++++++++++++++++++++++++++-------- src/scgenerator/plotting.py | 41 +++++++++++++++++++++++++++++++++---- src/scgenerator/spectra.py | 10 +++++++++ 3 files changed, 79 insertions(+), 13 deletions(-) diff --git a/src/scgenerator/cache.py b/src/scgenerator/cache.py index eaefe39..4f0f51a 100644 --- a/src/scgenerator/cache.py +++ b/src/scgenerator/cache.py @@ -7,10 +7,12 @@ import pickle import re import shutil import string +from time import perf_counter import tomllib +import warnings from functools import wraps from pathlib import Path -from typing import Any, Callable, Mapping, Self, TypeVar, TypeVarTuple +from typing import Any, Callable, Generic, Mapping, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator" CACHE_DIR = Path(CACHE_DIR) @@ -20,6 +22,7 @@ PRECACHED = {} PATH_LEN = 250 Ts = TypeVarTuple("Ts") +Ps = ParamSpec("Ps") T = TypeVar("T") @@ -36,6 +39,11 @@ def normalize_path(s: str) -> str: return path +class CachedFunction(Protocol[Ps, T]): + def __call__(self, *args: Ps.args, **kwargs: Ps.kwargs) -> T: ... + def cached_only(self, *args: Ps.args, **kwargs: Ps.kwargs) -> T | None: ... + + class Cache: dir: Path NO_DATA = object() @@ -71,21 +79,35 @@ class Cache: key = normalize_path(key) return (self.dir / key).exists() - def __call__(self, key_func: Callable[[*Ts], str] = None): + def __call__( + self, key_func: Callable[Ps, str] = None + ) -> Callable[[Callable[Ps, T]], CachedFunction[Ps, T]]: if key_func is None: - def key_func(*args: *Ts) -> str: - return " ".join(str(el) for el in args) + def key_func(*args: Ps.args, **kwargs: Ps.kwargs) -> str: + try: + return hashlib.md5(pickle.dumps(args + tuple(kwargs.items()))).hexdigest() + except TypeError: + warnings.warn(f"cache '{self.dir}' couldn't use pickle to calculate key") + return str(args) + str(kwargs) - def wrapper(func: Callable[[*Ts], T]) -> Callable[[*Ts], T]: + def wrapper(func: Callable[Ps, T]) -> CachedFunction[Ps, T]: @wraps(func) - def wrapped(*args: *Ts) -> T: - key = func.__qualname__ + " " + key_func(*args) + def wrapped(*args: Ps.args, **kwargs: Ps.kwargs) -> T: + key = func.__qualname__ + " " + key_func(*args, **kwargs) if (data := self.load(key)) is self.NO_DATA: - data = func(*args) + data = func(*args, **kwargs) self.save(key, data) return data + def cached_only(*args: Ps.args, **kwargs: Ps.kwargs) -> T | None: + key = func.__qualname__ + " " + key_func(*args, **kwargs) + if (data := self.load(key)) is self.NO_DATA: + return None + return data + + wrapped.cached_only = cached_only + return wrapped return wrapper @@ -96,7 +118,8 @@ class Cache: fn = self.dir / key if not fn.exists(): return self.NO_DATA - return pickle.loads(fn.read_bytes()) + stuff = pickle.loads(fn.read_bytes()) + return stuff @check_exists def save(self, key: str, value: Any): diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 6f13b7b..377a0c7 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -1,3 +1,4 @@ +import colorsys import os from dataclasses import dataclass from pathlib import Path @@ -8,7 +9,7 @@ import matplotlib.gridspec as gs import matplotlib.pyplot as plt import numpy as np from matplotlib.axes import Axes -from matplotlib.colors import ListedColormap +from matplotlib.colors import ColorConverter, ListedColormap from matplotlib.figure import Figure from matplotlib.transforms import offset_copy from scipy.interpolate import UnivariateSpline, interp1d @@ -37,6 +38,14 @@ class ImageData: return get_extent(self.x, self.y) +def alt_color(c, fac: float): + if isinstance(c, (list, tuple)) or (isinstance(c, np.ndarray) and c.ndim > 1): + return np.array([alt_color(el, fac) for el in c]) + *color, alpha = ColorConverter.to_rgba(c) + h, s, v = colorsys.rgb_to_hsv(*color) + return (*colorsys.hsv_to_rgb(h, s, min(1, v * fac)), alpha) + + def get_extent(x, y, facx=1, facy=1): """ returns the extent 4-tuple needed for imshow, aligning each pixel @@ -230,14 +239,31 @@ def create_zoom_axis( return inset -def corner_annotation(text, ax, position="tl", pts_x=4, pts_y=4, **text_kwargs): +def corner_annotation( + text, + ax: Axes | None = None, + position="tl", + pts_x=4, + pts_y=4, + opaque: bool = False, + **text_kwargs, +): """puts an annotatin in a corner of an ax Parameters ---------- text : str text to put in the corner - ax : matplotlib axis object - position : str {"tl", "tr", "bl", "br"} + ax : matplotlib.axes.Axes object, optional + by default current Axes + position : str {"tl", "tr", "bl", "br"}, optional + short for top/bottom left/right, by default top left + pts_x/y : float + offset in points from absolute corner, by default 4 + opaque : bool, optional + draw a white rectangle behind the text to make it stand out if the plot is busy + by default False + + Returns ---------- @@ -264,6 +290,13 @@ def corner_annotation(text, ax, position="tl", pts_x=4, pts_y=4, **text_kwargs): pts_x = -pts_x ha = "right" + if ax is None: + ax = plt.gca() + + if opaque: + bbox = dict(facecolor="white", edgecolor=(1, 1, 1, 0), boxstyle="round,pad=0", alpha=1) + text_kwargs["bbox"] = text_kwargs.get("bbox", {}) | bbox + ax.annotate( text, (x, y), diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 54906c4..072d953 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -65,6 +65,10 @@ class Spectrum(np.ndarray): data = np.frombuffer(buf[i:], np.complex128).reshape(shape) return cls(data, w, t, ifft) + @classmethod + def open(cls, path: os.PathLike) -> Spectrum: + return cls.from_bytes(Path(path).read_bytes()) + def __new__( cls, input_array, @@ -147,6 +151,12 @@ class Spectrum(np.ndarray): + self.astype(np.complex128, subok=False).tobytes() ) + def tobytes(self, *args, **kwargs) -> bytes: + warnings.warn( + "Calling `tobytes` (numpy function) on Spectrum object. Did you mean `bytes(obj)`?" + ) + return super().tobytes(*args, **kwargs) + @property def wl_disp(self): return self.l[self.l_order]