improved cache

This commit is contained in:
2024-04-10 13:50:57 +02:00
parent 22dcb9d15c
commit edadb4284b
3 changed files with 79 additions and 13 deletions

View File

@@ -7,10 +7,12 @@ import pickle
import re import re
import shutil import shutil
import string import string
from time import perf_counter
import tomllib import tomllib
import warnings
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Mapping, Self, TypeVar, TypeVarTuple from typing import Any, Callable, Generic, Mapping, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple
CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator" CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator"
CACHE_DIR = Path(CACHE_DIR) CACHE_DIR = Path(CACHE_DIR)
@@ -20,6 +22,7 @@ PRECACHED = {}
PATH_LEN = 250 PATH_LEN = 250
Ts = TypeVarTuple("Ts") Ts = TypeVarTuple("Ts")
Ps = ParamSpec("Ps")
T = TypeVar("T") T = TypeVar("T")
@@ -36,6 +39,11 @@ def normalize_path(s: str) -> str:
return path return path
class CachedFunction(Protocol[Ps, T]):
def __call__(self, *args: Ps.args, **kwargs: Ps.kwargs) -> T: ...
def cached_only(self, *args: Ps.args, **kwargs: Ps.kwargs) -> T | None: ...
class Cache: class Cache:
dir: Path dir: Path
NO_DATA = object() NO_DATA = object()
@@ -71,21 +79,35 @@ class Cache:
key = normalize_path(key) key = normalize_path(key)
return (self.dir / key).exists() return (self.dir / key).exists()
def __call__(self, key_func: Callable[[*Ts], str] = None): def __call__(
self, key_func: Callable[Ps, str] = None
) -> Callable[[Callable[Ps, T]], CachedFunction[Ps, T]]:
if key_func is None: if key_func is None:
def key_func(*args: *Ts) -> str: def key_func(*args: Ps.args, **kwargs: Ps.kwargs) -> str:
return " ".join(str(el) for el in args) try:
return hashlib.md5(pickle.dumps(args + tuple(kwargs.items()))).hexdigest()
except TypeError:
warnings.warn(f"cache '{self.dir}' couldn't use pickle to calculate key")
return str(args) + str(kwargs)
def wrapper(func: Callable[[*Ts], T]) -> Callable[[*Ts], T]: def wrapper(func: Callable[Ps, T]) -> CachedFunction[Ps, T]:
@wraps(func) @wraps(func)
def wrapped(*args: *Ts) -> T: def wrapped(*args: Ps.args, **kwargs: Ps.kwargs) -> T:
key = func.__qualname__ + " " + key_func(*args) key = func.__qualname__ + " " + key_func(*args, **kwargs)
if (data := self.load(key)) is self.NO_DATA: if (data := self.load(key)) is self.NO_DATA:
data = func(*args) data = func(*args, **kwargs)
self.save(key, data) self.save(key, data)
return data return data
def cached_only(*args: Ps.args, **kwargs: Ps.kwargs) -> T | None:
key = func.__qualname__ + " " + key_func(*args, **kwargs)
if (data := self.load(key)) is self.NO_DATA:
return None
return data
wrapped.cached_only = cached_only
return wrapped return wrapped
return wrapper return wrapper
@@ -96,7 +118,8 @@ class Cache:
fn = self.dir / key fn = self.dir / key
if not fn.exists(): if not fn.exists():
return self.NO_DATA return self.NO_DATA
return pickle.loads(fn.read_bytes()) stuff = pickle.loads(fn.read_bytes())
return stuff
@check_exists @check_exists
def save(self, key: str, value: Any): def save(self, key: str, value: Any):

View File

@@ -1,3 +1,4 @@
import colorsys
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -8,7 +9,7 @@ import matplotlib.gridspec as gs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from matplotlib.axes import Axes from matplotlib.axes import Axes
from matplotlib.colors import ListedColormap from matplotlib.colors import ColorConverter, ListedColormap
from matplotlib.figure import Figure from matplotlib.figure import Figure
from matplotlib.transforms import offset_copy from matplotlib.transforms import offset_copy
from scipy.interpolate import UnivariateSpline, interp1d from scipy.interpolate import UnivariateSpline, interp1d
@@ -37,6 +38,14 @@ class ImageData:
return get_extent(self.x, self.y) return get_extent(self.x, self.y)
def alt_color(c, fac: float):
if isinstance(c, (list, tuple)) or (isinstance(c, np.ndarray) and c.ndim > 1):
return np.array([alt_color(el, fac) for el in c])
*color, alpha = ColorConverter.to_rgba(c)
h, s, v = colorsys.rgb_to_hsv(*color)
return (*colorsys.hsv_to_rgb(h, s, min(1, v * fac)), alpha)
def get_extent(x, y, facx=1, facy=1): def get_extent(x, y, facx=1, facy=1):
""" """
returns the extent 4-tuple needed for imshow, aligning each pixel returns the extent 4-tuple needed for imshow, aligning each pixel
@@ -230,14 +239,31 @@ def create_zoom_axis(
return inset return inset
def corner_annotation(text, ax, position="tl", pts_x=4, pts_y=4, **text_kwargs): def corner_annotation(
text,
ax: Axes | None = None,
position="tl",
pts_x=4,
pts_y=4,
opaque: bool = False,
**text_kwargs,
):
"""puts an annotatin in a corner of an ax """puts an annotatin in a corner of an ax
Parameters Parameters
---------- ----------
text : str text : str
text to put in the corner text to put in the corner
ax : matplotlib axis object ax : matplotlib.axes.Axes object, optional
position : str {"tl", "tr", "bl", "br"} by default current Axes
position : str {"tl", "tr", "bl", "br"}, optional
short for top/bottom left/right, by default top left
pts_x/y : float
offset in points from absolute corner, by default 4
opaque : bool, optional
draw a white rectangle behind the text to make it stand out if the plot is busy
by default False
Returns Returns
---------- ----------
@@ -264,6 +290,13 @@ def corner_annotation(text, ax, position="tl", pts_x=4, pts_y=4, **text_kwargs):
pts_x = -pts_x pts_x = -pts_x
ha = "right" ha = "right"
if ax is None:
ax = plt.gca()
if opaque:
bbox = dict(facecolor="white", edgecolor=(1, 1, 1, 0), boxstyle="round,pad=0", alpha=1)
text_kwargs["bbox"] = text_kwargs.get("bbox", {}) | bbox
ax.annotate( ax.annotate(
text, text,
(x, y), (x, y),

View File

@@ -65,6 +65,10 @@ class Spectrum(np.ndarray):
data = np.frombuffer(buf[i:], np.complex128).reshape(shape) data = np.frombuffer(buf[i:], np.complex128).reshape(shape)
return cls(data, w, t, ifft) return cls(data, w, t, ifft)
@classmethod
def open(cls, path: os.PathLike) -> Spectrum:
return cls.from_bytes(Path(path).read_bytes())
def __new__( def __new__(
cls, cls,
input_array, input_array,
@@ -147,6 +151,12 @@ class Spectrum(np.ndarray):
+ self.astype(np.complex128, subok=False).tobytes() + self.astype(np.complex128, subok=False).tobytes()
) )
def tobytes(self, *args, **kwargs) -> bytes:
warnings.warn(
"Calling `tobytes` (numpy function) on Spectrum object. Did you mean `bytes(obj)`?"
)
return super().tobytes(*args, **kwargs)
@property @property
def wl_disp(self): def wl_disp(self):
return self.l[self.l_order] return self.l[self.l_order]