added caching utility
@@ -10,3 +10,4 @@ from scgenerator.spectra import Spectrum, propagation, propagation_series
 from scgenerator.physics import fiber, materials, plasma, pulse
 from scgenerator.physics.units import PlotRange
 from scgenerator.solver import SimulationResult, integrate, solve43
+from scgenerator.cache import Cache
src/scgenerator/cache.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from __future__ import annotations

import hashlib
import json
import os
import pickle
import shutil
import tomllib
from functools import wraps
from pathlib import Path
from typing import Any, Mapping, Self

# Wrap in Path() so a value coming from the environment variable (a plain str)
# still supports the `/` operator used below.
CACHE_DIR = Path(os.getenv("SCGENERATOR_CACHE_DIR") or Path.home() / ".cache" / "scgenerator")


def sort_dict(value: Any) -> Any:
    """Recursively sort mapping keys so that equivalent configs hash identically."""
    if not isinstance(value, Mapping):
        return value
    return {k: sort_dict(value[k]) for k in sorted(value)}


class Cache:
    """Pickle-based key/value store kept in a per-group cache directory."""

    dir: Path

    def check_exists(func):
        """Decorator: create the cache directory on demand before the call."""

        @wraps(func)
        def _wrapped(self: Self, *args, **kwargs):
            if not self.dir.exists():
                os.makedirs(self.dir)
            return func(self, *args, **kwargs)

        return _wrapped

    @classmethod
    def from_json(cls, s: str, /) -> Self:
        """Create a cache group keyed by the hash of a JSON document.

        The parsed document is hashed as-is, so key order matters here.
        """
        hashed = hashlib.md5(pickle.dumps(json.loads(s))).hexdigest()
        group = f"JSON-{hashed}"
        os.makedirs(CACHE_DIR / group, exist_ok=True)
        return cls(group)

    @classmethod
    def from_toml(cls, s: str, /, create: bool = True) -> Self:
        """Create a cache group keyed by the hash of a TOML document.

        Keys are sorted first, so formatting, comments and key order do not
        change the resulting group.
        """
        hashed = hashlib.md5(pickle.dumps(sort_dict(tomllib.loads(s)))).hexdigest()
        group = f"TOML-{hashed}"
        if create:
            os.makedirs(CACHE_DIR / group, exist_ok=True)
        return cls(group)

    def __init__(self, group: str):
        self.dir = CACHE_DIR / group

    def __contains__(self, key: str):
        return (self.dir / key).exists()

    @check_exists
    def load(self, key: str) -> Any | None:
        """Return the value stored under `key`, or None if it is missing."""
        fn = self.dir / key
        if not fn.exists():
            return None
        return pickle.loads(fn.read_bytes())

    @check_exists
    def save(self, key: str, value: Any):
        """Pickle `value` to a file named after `key`."""
        fn = self.dir / key
        fn.write_bytes(pickle.dumps(value))

    @check_exists
    def reset(self):
        """Delete every entry in this group, leaving an empty directory."""
        shutil.rmtree(self.dir)
        os.makedirs(self.dir)

    def delete(self):
        """Remove the group directory entirely."""
        if self.dir.exists():
            shutil.rmtree(self.dir)
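Illustrative usage sketch (not part of the diff): memoizing an expensive, configuration-keyed computation with the Cache class above. `run_simulation` and `config_toml` are hypothetical placeholders.

# Not part of this commit; a minimal sketch of the intended workflow.
from scgenerator.cache import Cache


def run_simulation(toml_source: str) -> dict:
    # stand-in for a real, expensive computation derived from the config
    return {"source_length": len(toml_source)}


config_toml = """
[config]
num_segments = 31
num_frequencies = 513
"""

cache = Cache.from_toml(config_toml)  # group name derived from the sorted, parsed TOML
if "result" in cache:
    result = cache.load("result")
else:
    result = run_simulation(config_toml)
    cache.save("result", result)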
tests/test_cache.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from scgenerator.cache import Cache


def test_io():
    # A group that does not exist yet: the directory is only created on demand.
    cache = Cache("Hello")
    assert not cache.dir.exists()

    cache.reset()
    assert cache.dir.exists()

    cache.save("a", "bonjour")
    assert cache.load("a") == "bonjour"
    assert "a" in cache

    cache.delete()
    assert not cache.dir.exists()


def test_toml():
    s1 = """
[config]
plot_range = [750, 1350]
rin_measurement = "./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv"
num_segments = 31
num_frequencies = 513
noise_seed_start = 3012
anticorrelated_width = true
"""
    # Same data as s1 with a comment, different key order, whitespace and
    # quoting: it must hash to the same cache group.
    s2 = """
# some comment

[config]
anticorrelated_width = true
plot_range = [750,1350]
num_segments=31
num_frequencies = 513
noise_seed_start=3012
rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
"""
    # Differs from s2 by a single value (the upper plot_range bound), so it
    # must land in a different group.
    s3 = """
# some comment

[config]
anticorrelated_width = true
plot_range = [750,1351]
num_segments=31
num_frequencies = 513
noise_seed_start=3012
rin_measurement='./DualComb1GHz_updated_corrected_extrapolated_1GHz_noqn.csv'
"""

    cache1 = Cache.from_toml(s1)
    cache3 = Cache.from_toml(s3, create=False)
    assert cache1.dir == Cache.from_toml(s2).dir
    assert cache1.dir != cache3.dir
    assert cache1.dir.exists()
    cache1.delete()
    assert not cache1.dir.exists()
    assert not cache3.dir.exists()
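A possible follow-up, sketched here and not part of the commit: because CACHE_DIR is resolved at import time, tests that should never touch ~/.cache/scgenerator can patch the module attribute instead of the SCGENERATOR_CACHE_DIR environment variable. The fixture below is an assumption, not existing project code.

# Hypothetical conftest.py fixture: redirect every Cache group into a
# pytest temporary directory for the duration of each test.
import pytest

import scgenerator.cache as cache_module


@pytest.fixture(autouse=True)
def isolated_cache_dir(tmp_path, monkeypatch):
    # CACHE_DIR is read when the module is imported, so patch the module
    # attribute directly; setting the environment variable here would be too late.
    monkeypatch.setattr(cache_module, "CACHE_DIR", tmp_path)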