new ConfigFile class

This commit is contained in:
Benoît Sierro
2021-10-18 11:29:17 +02:00
parent 1132b19012
commit 2a3d222d85
19 changed files with 330 additions and 351 deletions

View File

@@ -1,25 +1,20 @@
from . import math
from .legacy import convert_sim_folder
from .math import abs2, argclosest, span
from .parameter import Configuration, Parameters
from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange
from .plotting import (
get_extent,
mean_values_plot,
plot_spectrogram,
propagation_plot,
single_position_plot,
transform_2D_propagation,
transform_1D_values,
transform_2D_propagation,
transform_mean_values,
get_extent,
)
from .spectra import Spectrum, SimulationSeries
from ._utils import Paths, _open_config, parameter, open_single_config
from ._utils.parameter import Configuration, Parameters
from ._utils.utils import PlotRange
from ._utils.legacy import convert_sim_folder
from ._utils.variationer import (
Variationer,
VariationDescriptor,
VariationSpecsError,
DescriptorDict,
)
from .spectra import SimulationSeries, Spectrum
from .utils import Paths, _open_config, open_single_config
from .variationer import DescriptorDict, VariationDescriptor, Variationer, VariationSpecsError

View File

@@ -1,260 +0,0 @@
import inspect
import os
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, Iterator, Set
import numpy as np
import toml
from pydantic import BaseModel
from .._utils import load_toml, save_toml
from ..const import PARAM_FN, PARAM_SEPARATOR, Z_FN
from ..physics.units import get_unit
class HashableBaseModel(BaseModel):
"""Pydantic BaseModel that's immutable and can be hashed"""
def __hash__(self) -> int:
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
class Config:
allow_mutation = False
def to_62(i: int) -> str:
arr = []
if i == 0:
return "0"
i = abs(i)
while i:
i, value = divmod(i, 62)
arr.append(str_printable[value])
return "".join(reversed(arr))
class PlotRange(HashableBaseModel):
left: float
right: float
unit: Callable[[float], float]
conserved_quantity: bool = True
def __init__(self, left, right, unit, **kwargs):
super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])
def get_arg_names(func: Callable) -> list[str]:
# spec = inspect.getfullargspec(func)
# args = spec.args
# if spec.defaults is not None and len(spec.defaults) > 0:
# args = args[: -len(spec.defaults)]
# return args
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def combine_simulations(path: Path, dest: Path = None):
"""combines raw simulations into one folder per branch
Parameters
----------
path : Path
source of the simulations (must contain u_xx directories)
dest : Path, optional
if given, moves the simulations to dest, by default None
"""
paths: dict[str, list[Path]] = defaultdict(list)
if dest is None:
dest = path
for p in path.glob("u_*b_*"):
if p.is_dir():
paths[p.name.split()[1]].append(p)
for l in paths.values():
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
for pulses in paths.values():
new_path = dest / update_path(pulses[0].name)
os.makedirs(new_path, exist_ok=True)
for num, pulse in enumerate(pulses):
params_ok = False
for file in pulse.glob("*"):
if file.name == PARAM_FN:
if not params_ok:
update_params(new_path, file)
params_ok = True
else:
file.unlink()
elif file.name == Z_FN:
file.rename(new_path / file.name)
elif file.name.startswith("spectr") and num == 0:
file.rename(new_path / file.name)
else:
file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
pulse.rmdir()
def update_params(new_path: Path, file: Path):
params = load_toml(file)
if (p := params.get("prev_data_dir")) is not None:
p = Path(p)
params["prev_data_dir"] = str(p.parent / update_path(p.name))
params["output_path"] = str(new_path)
save_toml(new_path / PARAM_FN, params)
file.unlink()
def save_parameters(
params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
"""saves a parameter dictionary. Note that is does remove some entries, particularly
those that take a lot of space ("t", "w", ...)
Parameters
----------
params : dict[str, Any]
dictionary to save
destination_dir : Path
destination directory
Returns
-------
Path
path to newly created the paramter file
"""
file_path = destination_dir / file_name
os.makedirs(file_path.parent, exist_ok=True)
# save toml of the simulation
with open(file_path, "w") as file:
toml.dump(params, file, encoder=toml.TomlNumpyEncoder())
return file_path
def update_path(p: str) -> str:
return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
def simulations_list(path: os.PathLike) -> list[Path]:
"""finds simulations folders contained in a parent directory
Parameters
----------
path : os.PathLike
parent path
Returns
-------
list[Path]
Absolute Path to the simulation folder
"""
paths: list[Path] = []
for pwd, _, files in os.walk(path):
if PARAM_FN in files:
paths.append(Path(pwd))
paths.sort(key=lambda el: el.parent.name)
return [p for p in paths if p.parent.name == paths[-1].parent.name]

View File

@@ -7,7 +7,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from .. import _utils as utils
from .. import utils
from .. import const, env, scripts
from ..logger import get_logger
from ..physics.fiber import dispersion_coefficients

View File

@@ -7,11 +7,11 @@ from typing import Any, Set
import numpy as np
import toml
from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .parameter import Configuration, Parameters
from .utils import save_parameters
from .pbar import PBars
from .variationer import VariationDescriptor, Variationer
from .variationer import VariationDescriptor
def load_config(path: os.PathLike) -> dict[str, Any]:

View File

@@ -3,7 +3,7 @@ from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from ._utils.cache import np_cache
from .cache import np_cache
pi = np.pi
c = 299792458.0

View File

@@ -8,34 +8,22 @@ import os
import re
import time
from collections import defaultdict
from copy import copy, deepcopy
from copy import copy
from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache
from functools import lru_cache
from pathlib import Path
from typing import (
Any,
Callable,
Generator,
Iterable,
Iterator,
Literal,
Optional,
Sequence,
TypeVar,
Union,
)
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union
import numpy as np
from numpy.lib import isin
from .. import _utils as utils
from .. import env, math
from .._utils.variationer import VariationDescriptor, Variationer
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
from . import env, math, utils
from .const import PARAM_FN, __version__
from .errors import EvaluatorError, NoDefaultError
from .logger import get_logger
from .physics import fiber, materials, pulse, units
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path
from .variationer import VariationDescriptor, Variationer
T = TypeVar("T")

View File

@@ -10,7 +10,7 @@ from typing import Iterable, Union
from tqdm import tqdm
from ..env import pbar_policy
from .env import pbar_policy
T_ = typing.TypeVar("T_")

View File

@@ -10,7 +10,8 @@ from scipy.optimize import minimize_scalar
from .. import math
from . import fiber, materials, units, pulse
from .._utils import cache, load_material_dico
from ..cache import np_cache
from ..utils import load_material_dico
T = TypeVar("T")
@@ -21,7 +22,7 @@ def group_delay_to_gdd(wavelength: np.ndarray, group_delay: np.ndarray) -> np.nd
return gdd
@cache.np_cache
@np_cache
def material_dispersion(
wavelengths: np.ndarray,
material: str,

View File

@@ -1,21 +1,19 @@
from typing import Any, Iterable, Literal, TypeVar
import numpy as np
from numpy.fft import fft, ifft
from numpy import e
from numpy.fft import fft, ifft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from .. import utils
from ..cache import np_cache
from ..logger import get_logger
from .. import _utils as utils
from ..math import abs2, argclosest, power_fact, u_nm
from .._utils.cache import np_cache
from . import materials as mat
from . import units
from .units import c, pi
pipi = 2 * pi
T = TypeVar("T")

View File

@@ -1,13 +1,14 @@
from typing import Any, Callable
import numpy as np
import scipy.special
from scipy.integrate import cumulative_trapezoid
from .. import utils
from ..cache import np_cache
from ..logger import get_logger
from . import units
from .. import _utils
from .units import NA, c, kB, me, e, hbar
from .._utils.cache import np_cache
from .units import NA, c, e, hbar, kB, me
@np_cache
@@ -15,7 +16,7 @@ def n_gas_2(
wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature: float, ideal_gas: bool
):
"""Returns the sqare of the index of refraction of the specified gas"""
material_dico = _utils.load_material_dico(gas_name)
material_dico = utils.load_material_dico(gas_name)
if ideal_gas:
n_gas_2 = sellmeier(wl_for_disp, material_dico, pressure, temperature) + 1
@@ -218,7 +219,7 @@ def gas_n2(gas_name: str, pressure: float, temperature: float) -> float:
float
n2 in m2/W
"""
return non_linear_refractive_index(_utils.load_material_dico(gas_name), pressure, temperature)
return non_linear_refractive_index(utils.load_material_dico(gas_name), pressure, temperature)
def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray:

View File

@@ -7,17 +7,13 @@ from pathlib import Path
from typing import Any, Generator, Type, Union
import numpy as np
from send2trash import send2trash
from .. import env
from .. import _utils as utils
from .._utils.utils import combine_simulations, save_parameters
from .. import utils
from ..logger import get_logger
from .._utils.parameter import Configuration, Parameters
from .._utils.pbar import PBars, ProgressBarActor, progress_worker
from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
from scgenerator._utils import pbar
try:
import ray
@@ -505,7 +501,7 @@ class Simulations:
for variable, params in self.configuration:
params.compute()
v_list_str = variable.formatted_descriptor(True)
save_parameters(params.prepare_for_dump(), Path(params.output_path))
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
self.new_sim(v_list_str, params)
self.finish()
@@ -737,7 +733,7 @@ def run_simulation(
sim.run()
for path in config.fiber_paths:
combine_simulations(path)
utils.combine_simulations(path)
def new_simulation(

View File

@@ -3,7 +3,7 @@
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from typing import Callable, TypeVar, Union
from operator import itemgetter
import numpy as np
from numpy import pi
@@ -274,3 +274,75 @@ def to_log2D(arr, ref=None):
m = arr / ref
m = 10 * np.log10(m, out=np.zeros_like(m) - 100, where=m > 0)
return m
class PlotRange(tuple):
left: float = property(itemgetter(0))
right: float = property(itemgetter(1))
unit: Callable[[float], float] = property(itemgetter(2))
conserved_quantity: bool = property(itemgetter(3))
__slots__ = []
def __new__(cls, left, right, unit, conserved_quantity=True):
return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def __repr__(self):
return "PlotRange" + super().__repr__()
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if not isinstance(plt_range, PlotRange):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])

View File

@@ -11,13 +11,13 @@ from scipy.interpolate import UnivariateSpline
from scipy.interpolate.interpolate import interp1d
from . import math
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, sort_axis
from .const import PARAM_SEPARATOR, SPEC1_FN
from .defaults import default_plotting as defaults
from .math import abs2, span
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange, sort_axis
from .utils import load_spectrum
RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()

View File

@@ -9,15 +9,11 @@ from tqdm import tqdm
from .. import env, math
from ..const import PARAM_FN, PARAM_SEPARATOR
from ..parameter import Configuration, Parameters
from ..physics import fiber, units
from ..plotting import plot_setup
from ..spectra import SimulationSeries
from .._utils import auto_crop, _open_config, save_toml, translate_parameters
from .._utils.parameter import (
Configuration,
Parameters,
)
from .._utils.utils import simulations_list
from ..utils import _open_config, auto_crop, save_toml, simulations_list, translate_parameters
def fingerprint(params: Parameters):

View File

@@ -9,8 +9,8 @@ from typing import Tuple
import numpy as np
from .._utils import Paths
from .._utils.parameter import Configuration
from ..utils import Paths
from ..parameter import Configuration
def primes(n):

View File

@@ -8,18 +8,18 @@ import matplotlib.pyplot as plt
import numpy as np
from . import math
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, simulations_list
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from .logger import get_logger
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange
from .plotting import (
mean_values_plot,
propagation_plot,
single_position_plot,
transform_2D_propagation,
)
from .utils import load_spectrum, simulations_list
class Spectrum(np.ndarray):

View File

@@ -3,33 +3,27 @@ This files includes utility functions designed more or less to be used specifica
scgenerator module but some function may be used in any python program
"""
from __future__ import annotations
from dataclasses import dataclass
import inspect
import itertools
import os
from collections import abc
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from functools import cache
from typing import Any, MutableMapping, Sequence, TypeVar
from typing import Any, Callable, MutableMapping, Sequence, TypeVar
import numpy as np
from numpy.lib.arraysetops import isin
import pkg_resources as pkg
import toml
from tqdm import tqdm
import itertools
from ..const import SPEC1_FN, __version__
from ..logger import get_logger
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from .logger import get_logger
T_ = TypeVar("T_")
PathTree = list[tuple[Path, ...]]
class Paths:
_data_files = [
@@ -79,6 +73,42 @@ class Paths:
return os.path.join(cls.get("plots"), name)
class ConfigFileParser:
path: Path
repeat: int
master: ConfigFileParser.SubConfig
configs: list[ConfigFileParser.SubConfig]
@dataclass
class SubConfig:
fixed: dict[str, Any]
variable: dict[str, list]
def __init__(self, path: os.PathLike):
self.path = Path(path)
fiber_list: list[dict[str, Any]]
if self.path.name.lower().endswith(".toml"):
loaded_config = _open_config(self.path)
fiber_list = loaded_config.pop("Fiber")
else:
loaded_config = dict(name=self.path.name)
fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))]
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {self.path}")
final_path = loaded_config.get("name")
configs = []
for i, params in enumerate(fiber_list):
configs.append(loaded_config | params)
for root_vary, first_vary in itertools.product(
loaded_config["variable"], configs[0]["variable"]
):
if len(common := root_vary.keys() & first_vary.keys()) != 0:
raise ValueError(f"These variable keys are specified twice : {common!r}")
configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
prev_data_dir = Path(prev_data_dir)
num = find_last_spectrum_num(prev_data_dir)
@@ -329,3 +359,166 @@ def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
else:
new[old_names.get(k, k)] = v
return defaults_to_add | new
def to_62(i: int) -> str:
arr = []
if i == 0:
return "0"
i = abs(i)
while i:
i, value = divmod(i, 62)
arr.append(str_printable[value])
return "".join(reversed(arr))
def get_arg_names(func: Callable) -> list[str]:
# spec = inspect.getfullargspec(func)
# args = spec.args
# if spec.defaults is not None and len(spec.defaults) > 0:
# args = args[: -len(spec.defaults)]
# return args
return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def combine_simulations(path: Path, dest: Path = None):
"""combines raw simulations into one folder per branch
Parameters
----------
path : Path
source of the simulations (must contain u_xx directories)
dest : Path, optional
if given, moves the simulations to dest, by default None
"""
paths: dict[str, list[Path]] = defaultdict(list)
if dest is None:
dest = path
for p in path.glob("u_*b_*"):
if p.is_dir():
paths[p.name.split()[1]].append(p)
for l in paths.values():
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
for pulses in paths.values():
new_path = dest / update_path(pulses[0].name)
os.makedirs(new_path, exist_ok=True)
for num, pulse in enumerate(pulses):
params_ok = False
for file in pulse.glob("*"):
if file.name == PARAM_FN:
if not params_ok:
update_params(new_path, file)
params_ok = True
else:
file.unlink()
elif file.name == Z_FN:
file.rename(new_path / file.name)
elif file.name.startswith("spectr") and num == 0:
file.rename(new_path / file.name)
else:
file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
pulse.rmdir()
def update_params(new_path: Path, file: Path):
params = load_toml(file)
if (p := params.get("prev_data_dir")) is not None:
p = Path(p)
params["prev_data_dir"] = str(p.parent / update_path(p.name))
params["output_path"] = str(new_path)
save_toml(new_path / PARAM_FN, params)
file.unlink()
def save_parameters(
params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
"""saves a parameter dictionary. Note that is does remove some entries, particularly
those that take a lot of space ("t", "w", ...)
Parameters
----------
params : dict[str, Any]
dictionary to save
destination_dir : Path
destination directory
Returns
-------
Path
path to newly created the paramter file
"""
file_path = destination_dir / file_name
os.makedirs(file_path.parent, exist_ok=True)
# save toml of the simulation
with open(file_path, "w") as file:
toml.dump(params, file, encoder=toml.TomlNumpyEncoder())
return file_path
def update_path(p: str) -> str:
return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
def simulations_list(path: os.PathLike) -> list[Path]:
"""finds simulations folders contained in a parent directory
Parameters
----------
path : os.PathLike
parent path
Returns
-------
list[Path]
Absolute Path to the simulation folder
"""
paths: list[Path] = []
for pwd, _, files in os.walk(path):
if PARAM_FN in files:
paths.append(Path(pwd))
paths.sort(key=lambda el: el.parent.name)
return [p for p in paths if p.parent.name == paths[-1].parent.name]

View File

@@ -8,8 +8,7 @@ import numpy as np
from pydantic import validator
from pydantic.main import BaseModel
from ..const import PARAM_SEPARATOR
from . import utils
from .const import PARAM_SEPARATOR
T = TypeVar("T")