removed Paths, introduced simple data_file function

This commit is contained in:
Benoît Sierro
2023-07-27 11:27:03 +02:00
parent 3d9b4f57f3
commit b95446df0d
4 changed files with 14 additions and 57 deletions

View File

@@ -1,9 +1,8 @@
# isort: skip_file
# ruff: noqa
from scgenerator import math, operators, plotting
from scgenerator import io, math, operators, plotting
from scgenerator.helpers import *
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
from scgenerator.parameter import Parameters
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.physics.units import PlotRange
from scgenerator.solver import integrate, solve43, SimulationResult
from scgenerator.solver import SimulationResult, integrate, solve43

View File

@@ -1,5 +1,12 @@
import datetime
import json
from pathlib import Path
import pkg_resources
def data_file(path: str) -> Path:
return Path(pkg_resources.resource_filename("scgenerator", path))
class DatetimeEncoder(json.JSONEncoder):

View File

@@ -6,7 +6,7 @@ from numpy.fft import fft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from scgenerator import utils
from scgenerator import io
from scgenerator.math import argclosest, u_nm
from scgenerator.physics import materials as mat
from scgenerator.physics import units
@@ -784,13 +784,12 @@ def delayed_raman_t(t: np.ndarray, raman_type: str) -> np.ndarray:
elif raman_type == "measured":
try:
path = utils.Paths.get("hr_t")
loaded = np.load(path)
file = io.data_file("raman_response.npy")
t_stored, hr_arr_stored = np.load(file)
except FileNotFoundError:
print("Not able to find the measured Raman response function. Going with agrawal model")
return delayed_raman_t(t, raman_type="agrawal")
t_stored, hr_arr_stored = loaded["t"], loaded["hr_arr"]
hr_arr = interp1d(t_stored, hr_arr_stored, bounds_error=False, fill_value=0)(t)
else:
print("invalid raman response function, aborting")

View File

@@ -17,10 +17,10 @@ from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union
import numpy as np
import pkg_resources as pkg
import tomli
import tomli_w
from scgenerator import io
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from scgenerator.logger import get_logger
@@ -40,54 +40,6 @@ class TimedMessage:
return False
class Paths:
_data_files = [
"materials.toml",
"hr_t.npz",
"submit_job_template.txt",
"start_worker.sh",
"start_head.sh",
]
paths = {
f.split(".")[0]: os.path.abspath(
pkg.resource_filename("scgenerator", os.path.join("data", f))
)
for f in _data_files
}
@classmethod
def get(cls, key):
if key not in cls.paths:
if os.path.exists("paths.toml"):
with open("paths.toml", "rb") as file:
paths_dico = tomli.load(file)
for k, v in paths_dico.items():
cls.paths[k] = v
if key not in cls.paths:
get_logger(__name__).info(
f"{key} was not found in path index, returning current working directory."
)
cls.paths[key] = os.getcwd()
return cls.paths[key]
@classmethod
def gets(cls, key):
"""returned the specified file as a string"""
with open(cls.get(key)) as file:
return file.read()
@classmethod
def plot(cls, name):
"""returns the paths to the specified plot. Used to save new plot
example
---------
fig.savefig(Paths.plot("figure5.pdf"))
"""
return os.path.join(cls.get("plots"), name)
def conform_variable_entry(d) -> list[dict[str, list]]:
if isinstance(d, MutableMapping):
d = [{k: v} for k, v in d.items()]
@@ -207,7 +159,7 @@ def load_material_dico(name: str) -> dict[str, Any]:
----------
material_dico : dict
"""
return tomli.loads(Paths.gets("materials"))[name]
return json.loads(io.data_file("materials.json").read_text())[name]
def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str):