rudimentary cache decorator

This commit is contained in:
2024-02-27 16:45:16 +01:00
parent 090100290a
commit 4ef82397bc
3 changed files with 68 additions and 5 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.4.6" version = "0.4.7"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
@@ -17,7 +17,7 @@ requires-python = ">=3.11"
keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"] keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"]
dependencies = [ dependencies = [
"numpy", "numpy",
"scipy", "scipy>=1.12.0",
"matplotlib", "matplotlib",
"numba", "numba",
"tqdm", "tqdm",

View File

@@ -4,13 +4,20 @@ import hashlib
import json import json
import os import os
import pickle import pickle
import re
import shutil import shutil
import string
import tomllib import tomllib
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Mapping, Self from typing import Any, Callable, Mapping, ParamSpec, Self, TypeVar, TypeVarTuple
CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator" CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator"
CACHE_DIR = Path(CACHE_DIR)
ACCEPTED = re.compile(string.ascii_letters + string.digits + r" \-_()\[\]\*~\.,=\+")
Ts = TypeVarTuple("Ts")
T = TypeVar("T")
def sort_dict(value: Any) -> dict[str, Any]: def sort_dict(value: Any) -> dict[str, Any]:
@@ -19,8 +26,13 @@ def sort_dict(value: Any) -> dict[str, Any]:
return {k: sort_dict(value[k]) for k in sorted(value)} return {k: sort_dict(value[k]) for k in sorted(value)}
def normalize_path(s: str) -> str:
return ACCEPTED.sub("_", s)
class Cache: class Cache:
dir: Path dir: Path
NO_DATA = object()
def check_exists(func): def check_exists(func):
@wraps(func) @wraps(func)
@@ -47,20 +59,42 @@ class Cache:
return cls(group) return cls(group)
def __init__(self, group: str): def __init__(self, group: str):
self.dir = CACHE_DIR / group self.dir = CACHE_DIR / normalize_path(group)
def __contains__(self, key: str): def __contains__(self, key: str):
key = normalize_path(key)
return (self.dir / key).exists() return (self.dir / key).exists()
def __call__(self, key_func: Callable[[*Ts], str] = None):
if key_func is None:
def key_func(*args: *Ts) -> str:
return " ".join(str(el) for el in args)
def wrapper(func: Callable[[*Ts], T]) -> Callable[[*Ts], T]:
@wraps(func)
def wrapped(*args: *Ts) -> T:
key = func.__qualname__ + " " + key_func(*args)
if (data := self.load(key)) is self.NO_DATA:
data = func(*args)
self.save(key, data)
return data
return wrapped
return wrapper
@check_exists @check_exists
def load(self, key: str) -> Any | None: def load(self, key: str) -> Any | None:
key = normalize_path(key)
fn = self.dir / key fn = self.dir / key
if not fn.exists(): if not fn.exists():
return None return self.NO_DATA
return pickle.loads(fn.read_bytes()) return pickle.loads(fn.read_bytes())
@check_exists @check_exists
def save(self, key: str, value: Any): def save(self, key: str, value: Any):
key = normalize_path(key)
fn = self.dir / key fn = self.dir / key
fn.write_bytes(pickle.dumps(value)) fn.write_bytes(pickle.dumps(value))

View File

@@ -1,3 +1,4 @@
import scgenerator.cache as scc
from scgenerator.cache import Cache from scgenerator.cache import Cache
@@ -57,3 +58,31 @@ def test_toml():
cache1.delete() cache1.delete()
assert not cache1.dir.exists() assert not cache1.dir.exists()
assert not cache3.dir.exists() assert not cache3.dir.exists()
def test_decorator():
cache = Cache("Test")
cache.delete()
call_count = 0
@cache()
def func(x: str) -> str:
nonlocal call_count
call_count += 1
return x + x
@cache(lambda li: f"{li[0]}-{li[-1]} {len(li)}")
def func2(some_list: list):
nonlocal call_count
call_count += 1
return sum(some_list)
assert func("hello") == "hellohello"
assert func2([0, 1, 2, 80]) == 83
assert func2([0, 1, 2, 80]) == 83
assert (scc.CACHE_DIR / "Test" / "test_decorator.<locals>.func hello").exists()
assert (scc.CACHE_DIR / "Test" / "test_decorator.<locals>.func2 0-80 4").exists()
assert call_count == 2
cache.delete()