From f0502d93c5618c746da7800e70e65239e170631a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Beno=C3=AEt=20Sierro?=
Date: Tue, 13 Feb 2024 15:57:03 +0100
Subject: [PATCH] added caching utility

---
 src/scgenerator/__init__.py |  1 +
 src/scgenerator/cache.py    | 74 +++++++++++++++++++++++++++++++++++++
 tests/test_cache.py         | 59 +++++++++++++++++++++++++++++
 3 files changed, 134 insertions(+)
 create mode 100644 src/scgenerator/cache.py
 create mode 100644 tests/test_cache.py

diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py
index abe2950..b9bd287 100644
--- a/src/scgenerator/__init__.py
+++ b/src/scgenerator/__init__.py
@@ -10,3 +10,4 @@ from scgenerator.spectra import Spectrum, propagation, propagation_series
 from scgenerator.physics import fiber, materials, plasma, pulse
 from scgenerator.physics.units import PlotRange
 from scgenerator.solver import SimulationResult, integrate, solve43
+from scgenerator.cache import Cache
diff --git a/src/scgenerator/cache.py b/src/scgenerator/cache.py
new file mode 100644
index 0000000..22d5319
--- /dev/null
+++ b/src/scgenerator/cache.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import os
+import pickle
+import shutil
+import tomllib
+from functools import wraps
+from pathlib import Path
+from typing import Any, Mapping, Self
+
+CACHE_DIR = Path(os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator")
+
+
+def sort_dict(value: Any) -> Any:
+    if not isinstance(value, Mapping):
+        return value
+    return {k: sort_dict(value[k]) for k in sorted(value)}
+
+
+class Cache:
+    dir: Path
+
+    def check_exists(func):
+        @wraps(func)
+        def _wrapped(self: Self, *args, **kwargs):
+            if not self.dir.exists():
+                os.makedirs(self.dir)
+            return func(self, *args, **kwargs)
+
+        return _wrapped
+
+    @classmethod
+    def from_json(cls, s: str, /) -> Self:
+        hashed = hashlib.md5(pickle.dumps(json.loads(s))).hexdigest()
+        group = f"JSON-{hashed}"
+        os.makedirs(CACHE_DIR / group, exist_ok=True)
+        return cls(group)
+
+    @classmethod
+    def from_toml(cls, s: str, /, create: bool = True) -> Self:
+        hashed = hashlib.md5(pickle.dumps(sort_dict(tomllib.loads(s)))).hexdigest()
+        group = f"TOML-{hashed}"
+        if create:
+            os.makedirs(CACHE_DIR / group, exist_ok=True)
+        return cls(group)
+
+    def __init__(self, group: str):
+        self.dir = CACHE_DIR / group
+
+    def __contains__(self, key: str):
+        return (self.dir / key).exists()
+
+    @check_exists
+    def load(self, key: str) -> Any | None:
+        fn = self.dir / key
+        if not fn.exists():
+            return None
+        return pickle.loads(fn.read_bytes())
+
+    @check_exists
+    def save(self, key: str, value: Any):
+        fn = self.dir / key
+        fn.write_bytes(pickle.dumps(value))
+
+    @check_exists
+    def reset(self):
+        shutil.rmtree(self.dir)
+        os.makedirs(self.dir)
+
+    def delete(self):
+        if self.dir.exists():
+            shutil.rmtree(self.dir)
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..f6bf9ad
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,59 @@
+from scgenerator.cache import Cache
+
+
+def test_io():
+    cache = Cache("Hello")
+    assert not cache.dir.exists()
+
+    cache.reset()
+    assert cache.dir.exists()
+
+    cache.save("a", "bonjour")
+    assert cache.load("a") == "bonjour"
+    assert "a" in cache
+
+    cache.delete()
+    assert not cache.dir.exists()
+
+
+def test_toml():
+    s1 = """
+    [config]
+    plot_range = [750, 1350]
+    rin_measurement = "./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv"
+    num_segments = 31
+    num_frequencies = 513
+    noise_seed_start = 3012
+    anticorrelated_width = true
+    """
+    s2 = """
+    # some commment
+
+    [config]
+    anticorrelated_width = true
+    plot_range = [750,1350]
+    num_segments=31
+    num_frequencies = 513
+    noise_seed_start=3012
+    rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
+    """
+    s3 = """
+    # some commment
+
+    [config]
+    anticorrelated_width = true
+    plot_range = [750,1351]
+    num_segments=31
+    num_frequencies = 513
+    noise_seed_start=3012
+    rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
+    """
+
+    cache1 = Cache.from_toml(s1)
+    cache3 = Cache.from_toml(s3, create=False)
+    assert cache1.dir == Cache.from_toml(s2).dir
+    assert cache1.dir != cache3.dir
+    assert cache1.dir.exists()
+    cache1.delete()
+    assert not cache1.dir.exists()
+    assert not cache3.dir.exists()
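
Usage note (illustration only, not part of the diff): a minimal sketch of how the Cache class added above could be used once this commit is installed. The TOML snippet, the "grid" key, and the stand-in computation are made up for the example.

    from scgenerator import Cache

    config_str = """
    [config]
    num_segments = 31
    num_frequencies = 513
    """

    # one cache directory per distinct config; key order, whitespace, and comments do not matter
    cache = Cache.from_toml(config_str)

    if "grid" in cache:              # __contains__ checks for a pickled entry on disk
        grid = cache.load("grid")
    else:
        grid = list(range(513))      # stand-in for an expensive computation
        cache.save("grid", grid)     # pickled under CACHE_DIR/TOML-<md5>/grid
                                     # (CACHE_DIR defaults to ~/.cache/scgenerator
                                     # unless SCGENERATOR_CACHE_DIR is set)

    cache.delete()                   # remove the whole group when it is no longer needed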