diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 32aa692..043be3d 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,9 +1,8 @@ -# isort: skip_file # ruff: noqa -from scgenerator import math, operators, plotting +from scgenerator import io, math, operators, plotting from scgenerator.helpers import * from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace from scgenerator.parameter import Parameters from scgenerator.physics import fiber, materials, plasma, pulse, units from scgenerator.physics.units import PlotRange -from scgenerator.solver import integrate, solve43, SimulationResult +from scgenerator.solver import SimulationResult, integrate, solve43 diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index f682929..152f84a 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -1,5 +1,12 @@ import datetime import json +from pathlib import Path + +import pkg_resources + + +def data_file(path: str) -> Path: + return Path(pkg_resources.resource_filename("scgenerator", path)) class DatetimeEncoder(json.JSONEncoder): diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 274e85d..374023b 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -6,7 +6,7 @@ from numpy.fft import fft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d -from scgenerator import utils +from scgenerator import io from scgenerator.math import argclosest, u_nm from scgenerator.physics import materials as mat from scgenerator.physics import units @@ -784,13 +784,12 @@ def delayed_raman_t(t: np.ndarray, raman_type: str) -> np.ndarray: elif raman_type == "measured": try: - path = utils.Paths.get("hr_t") - loaded = np.load(path) + file = io.data_file("raman_response.npy") + t_stored, hr_arr_stored = np.load(file) except FileNotFoundError: print("Not able to find the measured Raman response function. Going with agrawal model") return delayed_raman_t(t, raman_type="agrawal") - t_stored, hr_arr_stored = loaded["t"], loaded["hr_arr"] hr_arr = interp1d(t_stored, hr_arr_stored, bounds_error=False, fill_value=0)(t) else: print("invalid raman response function, aborting") diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 4c7b25e..860684f 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -17,10 +17,10 @@ from string import printable as str_printable from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union import numpy as np -import pkg_resources as pkg import tomli import tomli_w +from scgenerator import io from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN from scgenerator.logger import get_logger @@ -40,54 +40,6 @@ class TimedMessage: return False -class Paths: - _data_files = [ - "materials.toml", - "hr_t.npz", - "submit_job_template.txt", - "start_worker.sh", - "start_head.sh", - ] - - paths = { - f.split(".")[0]: os.path.abspath( - pkg.resource_filename("scgenerator", os.path.join("data", f)) - ) - for f in _data_files - } - - @classmethod - def get(cls, key): - if key not in cls.paths: - if os.path.exists("paths.toml"): - with open("paths.toml", "rb") as file: - paths_dico = tomli.load(file) - for k, v in paths_dico.items(): - cls.paths[k] = v - if key not in cls.paths: - get_logger(__name__).info( - f"{key} was not found in path index, returning current working directory." - ) - cls.paths[key] = os.getcwd() - - return cls.paths[key] - - @classmethod - def gets(cls, key): - """returned the specified file as a string""" - with open(cls.get(key)) as file: - return file.read() - - @classmethod - def plot(cls, name): - """returns the paths to the specified plot. Used to save new plot - example - --------- - fig.savefig(Paths.plot("figure5.pdf")) - """ - return os.path.join(cls.get("plots"), name) - - def conform_variable_entry(d) -> list[dict[str, list]]: if isinstance(d, MutableMapping): d = [{k: v} for k, v in d.items()] @@ -207,7 +159,7 @@ def load_material_dico(name: str) -> dict[str, Any]: ---------- material_dico : dict """ - return tomli.loads(Paths.gets("materials"))[name] + return json.loads(io.data_file("materials.json").read_text())[name] def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str):