rudimentary cache decorator
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "scgenerator"
|
name = "scgenerator"
|
||||||
version = "0.4.6"
|
version = "0.4.7"
|
||||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||||
@@ -17,7 +17,7 @@ requires-python = ">=3.11"
|
|||||||
keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"]
|
keywords = ["nonlinear", "fiber optics", "simulation", "runge-kutta"]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy",
|
"numpy",
|
||||||
"scipy",
|
"scipy>=1.12.0",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"numba",
|
"numba",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
|
|||||||
@@ -4,13 +4,20 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import string
|
||||||
import tomllib
|
import tomllib
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Mapping, Self
|
from typing import Any, Callable, Mapping, ParamSpec, Self, TypeVar, TypeVarTuple
|
||||||
|
|
||||||
CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator"
|
CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator"
|
||||||
|
CACHE_DIR = Path(CACHE_DIR)
|
||||||
|
ACCEPTED = re.compile(string.ascii_letters + string.digits + r" \-_()\[\]\*~\.,=\+")
|
||||||
|
|
||||||
|
Ts = TypeVarTuple("Ts")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def sort_dict(value: Any) -> dict[str, Any]:
|
def sort_dict(value: Any) -> dict[str, Any]:
|
||||||
@@ -19,8 +26,13 @@ def sort_dict(value: Any) -> dict[str, Any]:
|
|||||||
return {k: sort_dict(value[k]) for k in sorted(value)}
|
return {k: sort_dict(value[k]) for k in sorted(value)}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_path(s: str) -> str:
|
||||||
|
return ACCEPTED.sub("_", s)
|
||||||
|
|
||||||
|
|
||||||
class Cache:
|
class Cache:
|
||||||
dir: Path
|
dir: Path
|
||||||
|
NO_DATA = object()
|
||||||
|
|
||||||
def check_exists(func):
|
def check_exists(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@@ -47,20 +59,42 @@ class Cache:
|
|||||||
return cls(group)
|
return cls(group)
|
||||||
|
|
||||||
def __init__(self, group: str):
|
def __init__(self, group: str):
|
||||||
self.dir = CACHE_DIR / group
|
self.dir = CACHE_DIR / normalize_path(group)
|
||||||
|
|
||||||
def __contains__(self, key: str):
|
def __contains__(self, key: str):
|
||||||
|
key = normalize_path(key)
|
||||||
return (self.dir / key).exists()
|
return (self.dir / key).exists()
|
||||||
|
|
||||||
|
def __call__(self, key_func: Callable[[*Ts], str] = None):
|
||||||
|
if key_func is None:
|
||||||
|
|
||||||
|
def key_func(*args: *Ts) -> str:
|
||||||
|
return " ".join(str(el) for el in args)
|
||||||
|
|
||||||
|
def wrapper(func: Callable[[*Ts], T]) -> Callable[[*Ts], T]:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped(*args: *Ts) -> T:
|
||||||
|
key = func.__qualname__ + " " + key_func(*args)
|
||||||
|
if (data := self.load(key)) is self.NO_DATA:
|
||||||
|
data = func(*args)
|
||||||
|
self.save(key, data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
@check_exists
|
@check_exists
|
||||||
def load(self, key: str) -> Any | None:
|
def load(self, key: str) -> Any | None:
|
||||||
|
key = normalize_path(key)
|
||||||
fn = self.dir / key
|
fn = self.dir / key
|
||||||
if not fn.exists():
|
if not fn.exists():
|
||||||
return None
|
return self.NO_DATA
|
||||||
return pickle.loads(fn.read_bytes())
|
return pickle.loads(fn.read_bytes())
|
||||||
|
|
||||||
@check_exists
|
@check_exists
|
||||||
def save(self, key: str, value: Any):
|
def save(self, key: str, value: Any):
|
||||||
|
key = normalize_path(key)
|
||||||
fn = self.dir / key
|
fn = self.dir / key
|
||||||
fn.write_bytes(pickle.dumps(value))
|
fn.write_bytes(pickle.dumps(value))
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import scgenerator.cache as scc
|
||||||
from scgenerator.cache import Cache
|
from scgenerator.cache import Cache
|
||||||
|
|
||||||
|
|
||||||
@@ -57,3 +58,31 @@ def test_toml():
|
|||||||
cache1.delete()
|
cache1.delete()
|
||||||
assert not cache1.dir.exists()
|
assert not cache1.dir.exists()
|
||||||
assert not cache3.dir.exists()
|
assert not cache3.dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_decorator():
|
||||||
|
cache = Cache("Test")
|
||||||
|
cache.delete()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@cache()
|
||||||
|
def func(x: str) -> str:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return x + x
|
||||||
|
|
||||||
|
@cache(lambda li: f"{li[0]}-{li[-1]} {len(li)}")
|
||||||
|
def func2(some_list: list):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return sum(some_list)
|
||||||
|
|
||||||
|
assert func("hello") == "hellohello"
|
||||||
|
assert func2([0, 1, 2, 80]) == 83
|
||||||
|
assert func2([0, 1, 2, 80]) == 83
|
||||||
|
|
||||||
|
assert (scc.CACHE_DIR / "Test" / "test_decorator.<locals>.func hello").exists()
|
||||||
|
assert (scc.CACHE_DIR / "Test" / "test_decorator.<locals>.func2 0-80 4").exists()
|
||||||
|
|
||||||
|
assert call_count == 2
|
||||||
|
cache.delete()
|
||||||
|
|||||||
Reference in New Issue
Block a user