From 4ef82397bc2ce792d17a9dd82e10eddd33ed73f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 27 Feb 2024 16:45:16 +0100 Subject: [PATCH] rudimentary cache decorator --- pyproject.toml | 4 ++-- src/scgenerator/cache.py | 40 +++++++++++++++++++++++++++++++++++++--- tests/test_cache.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69d1153..8fbfa8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.4.6" +version = "0.4.7" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] @@ -17,7 +17,7 @@ requires-python = ">=3.11" keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"] dependencies = [ "numpy", - "scipy", + "scipy>=1.12.0", "matplotlib", "numba", "tqdm", diff --git a/src/scgenerator/cache.py b/src/scgenerator/cache.py index 22d5319..13b2b87 100644 --- a/src/scgenerator/cache.py +++ b/src/scgenerator/cache.py @@ -4,13 +4,20 @@ import hashlib import json import os import pickle +import re import shutil +import string import tomllib from functools import wraps from pathlib import Path -from typing import Any, Mapping, Self +from typing import Any, Callable, Mapping, ParamSpec, Self, TypeVar, TypeVarTuple CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator" +CACHE_DIR = Path(CACHE_DIR) +ACCEPTED = re.compile(string.ascii_letters + string.digits + r" \-_()\[\]\*~\.,=\+") + +Ts = TypeVarTuple("Ts") +T = TypeVar("T") def sort_dict(value: Any) -> dict[str, Any]: @@ -19,8 +26,13 @@ def sort_dict(value: Any) -> dict[str, Any]: return {k: sort_dict(value[k]) for k in sorted(value)} +def normalize_path(s: str) -> str: + return ACCEPTED.sub("_", s) + + class Cache: dir: Path + NO_DATA = object() def check_exists(func): @wraps(func) @@ -47,20 +59,42 @@ class Cache: return cls(group) def __init__(self, group: str): - self.dir = CACHE_DIR / group + self.dir = CACHE_DIR / normalize_path(group) def __contains__(self, key: str): + key = normalize_path(key) return (self.dir / key).exists() + def __call__(self, key_func: Callable[[*Ts], str] = None): + if key_func is None: + + def key_func(*args: *Ts) -> str: + return " ".join(str(el) for el in args) + + def wrapper(func: Callable[[*Ts], T]) -> Callable[[*Ts], T]: + @wraps(func) + def wrapped(*args: *Ts) -> T: + key = func.__qualname__ + " " + key_func(*args) + if (data := self.load(key)) is self.NO_DATA: + data = func(*args) + self.save(key, data) + return data + + return wrapped + + return wrapper + @check_exists def load(self, key: str) -> Any | None: + key = normalize_path(key) fn = self.dir / key if not fn.exists(): - return None + return self.NO_DATA return pickle.loads(fn.read_bytes()) @check_exists def save(self, key: str, value: Any): + key = normalize_path(key) fn = self.dir / key fn.write_bytes(pickle.dumps(value)) diff --git a/tests/test_cache.py b/tests/test_cache.py index f6bf9ad..754d287 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,3 +1,4 @@ +import scgenerator.cache as scc from scgenerator.cache import Cache @@ -57,3 +58,31 @@ def test_toml(): cache1.delete() assert not cache1.dir.exists() assert not cache3.dir.exists() + + +def test_decorator(): + cache = Cache("Test") + cache.delete() + call_count = 0 + + @cache() + def func(x: str) -> str: + nonlocal call_count + call_count += 1 + return x + x + + @cache(lambda li: f"{li[0]}-{li[-1]} {len(li)}") + def func2(some_list: list): + nonlocal call_count + call_count += 1 + return sum(some_list) + + assert func("hello") == "hellohello" + assert func2([0, 1, 2, 80]) == 83 + assert func2([0, 1, 2, 80]) == 83 + + assert (scc.CACHE_DIR / "Test" / "test_decorator..func hello").exists() + assert (scc.CACHE_DIR / "Test" / "test_decorator..func2 0-80 4").exists() + + assert call_count == 2 + cache.delete()