added caching utility
This commit is contained in:
@@ -10,3 +10,4 @@ from scgenerator.spectra import Spectrum, propagation, propagation_series
|
|||||||
from scgenerator.physics import fiber, materials, plasma, pulse
|
from scgenerator.physics import fiber, materials, plasma, pulse
|
||||||
from scgenerator.physics.units import PlotRange
|
from scgenerator.physics.units import PlotRange
|
||||||
from scgenerator.solver import SimulationResult, integrate, solve43
|
from scgenerator.solver import SimulationResult, integrate, solve43
|
||||||
|
from scgenerator.cache import Cache
|
||||||
|
|||||||
74
src/scgenerator/cache.py
Normal file
74
src/scgenerator/cache.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
import tomllib
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Mapping, Self
|
||||||
|
|
||||||
|
CACHE_DIR = os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator"
|
||||||
|
|
||||||
|
|
||||||
|
def sort_dict(value: Any) -> dict[str, Any]:
|
||||||
|
if not isinstance(value, Mapping):
|
||||||
|
return value
|
||||||
|
return {k: sort_dict(value[k]) for k in sorted(value)}
|
||||||
|
|
||||||
|
|
||||||
|
class Cache:
|
||||||
|
dir: Path
|
||||||
|
|
||||||
|
def check_exists(func):
|
||||||
|
@wraps(func)
|
||||||
|
def _wrapped(self: Self, *args, **kwargs):
|
||||||
|
if not self.dir.exists():
|
||||||
|
os.makedirs(self.dir)
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, s: str, /) -> Self:
|
||||||
|
hashed = hashlib.md5(pickle.dumps(json.loads(s))).hexdigest()
|
||||||
|
group = f"JSON-{hashed}"
|
||||||
|
os.makedirs(CACHE_DIR / group, exist_ok=True)
|
||||||
|
return cls(group)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, s: str, /, create: bool = True) -> Self:
|
||||||
|
hashed = hashlib.md5(pickle.dumps(sort_dict(tomllib.loads(s)))).hexdigest()
|
||||||
|
group = f"TOML-{hashed}"
|
||||||
|
if create:
|
||||||
|
os.makedirs(CACHE_DIR / group, exist_ok=True)
|
||||||
|
return cls(group)
|
||||||
|
|
||||||
|
def __init__(self, group: str):
|
||||||
|
self.dir = CACHE_DIR / group
|
||||||
|
|
||||||
|
def __contains__(self, key: str):
|
||||||
|
return (self.dir / key).exists()
|
||||||
|
|
||||||
|
@check_exists
|
||||||
|
def load(self, key: str) -> Any | None:
|
||||||
|
fn = self.dir / key
|
||||||
|
if not fn.exists():
|
||||||
|
return None
|
||||||
|
return pickle.loads(fn.read_bytes())
|
||||||
|
|
||||||
|
@check_exists
|
||||||
|
def save(self, key: str, value: Any):
|
||||||
|
fn = self.dir / key
|
||||||
|
fn.write_bytes(pickle.dumps(value))
|
||||||
|
|
||||||
|
@check_exists
|
||||||
|
def reset(self):
|
||||||
|
shutil.rmtree(self.dir)
|
||||||
|
os.makedirs(self.dir)
|
||||||
|
|
||||||
|
def delete(self):
|
||||||
|
if self.dir.exists():
|
||||||
|
shutil.rmtree(self.dir)
|
||||||
59
tests/test_cache.py
Normal file
59
tests/test_cache.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from scgenerator.cache import Cache
|
||||||
|
|
||||||
|
|
||||||
|
def test_io():
|
||||||
|
cache = Cache("Hello")
|
||||||
|
assert not cache.dir.exists()
|
||||||
|
|
||||||
|
cache.reset()
|
||||||
|
assert cache.dir.exists()
|
||||||
|
|
||||||
|
cache.save("a", "bonjour")
|
||||||
|
assert cache.load("a") == "bonjour"
|
||||||
|
assert "a" in cache
|
||||||
|
|
||||||
|
cache.delete()
|
||||||
|
assert not cache.dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_toml():
|
||||||
|
s1 = """
|
||||||
|
[config]
|
||||||
|
plot_range = [750, 1350]
|
||||||
|
rin_measurement = "./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv"
|
||||||
|
num_segments = 31
|
||||||
|
num_frequencies = 513
|
||||||
|
noise_seed_start = 3012
|
||||||
|
anticorrelated_width = true
|
||||||
|
"""
|
||||||
|
s2 = """
|
||||||
|
# some commment
|
||||||
|
|
||||||
|
[config]
|
||||||
|
anticorrelated_width = true
|
||||||
|
plot_range = [750,1350]
|
||||||
|
num_segments=31
|
||||||
|
num_frequencies = 513
|
||||||
|
noise_seed_start=3012
|
||||||
|
rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
|
||||||
|
"""
|
||||||
|
s3 = """
|
||||||
|
# some commment
|
||||||
|
|
||||||
|
[config]
|
||||||
|
anticorrelated_width = true
|
||||||
|
plot_range = [750,1351]
|
||||||
|
num_segments=31
|
||||||
|
num_frequencies = 513
|
||||||
|
noise_seed_start=3012
|
||||||
|
rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
|
||||||
|
"""
|
||||||
|
|
||||||
|
cache1 = Cache.from_toml(s1)
|
||||||
|
cache3 = Cache.from_toml(s3, create=False)
|
||||||
|
assert cache1.dir == Cache.from_toml(s2).dir
|
||||||
|
assert cache1.dir != cache3.dir
|
||||||
|
assert cache1.dir.exists()
|
||||||
|
cache1.delete()
|
||||||
|
assert not cache1.dir.exists()
|
||||||
|
assert not cache3.dir.exists()
|
||||||
Reference in New Issue
Block a user