new ConfigFile class

@@ -1,25 +1,20 @@
from . import math
from .legacy import convert_sim_folder
from .math import abs2, argclosest, span
from .parameter import Configuration, Parameters
from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange
from .plotting import (
    get_extent,
    mean_values_plot,
    plot_spectrogram,
    propagation_plot,
    single_position_plot,
    transform_2D_propagation,
    transform_1D_values,
    transform_2D_propagation,
    transform_mean_values,
    get_extent,
)
from .spectra import Spectrum, SimulationSeries
from ._utils import Paths, _open_config, parameter, open_single_config
from ._utils.parameter import Configuration, Parameters
from ._utils.utils import PlotRange
from ._utils.legacy import convert_sim_folder
from ._utils.variationer import (
    Variationer,
    VariationDescriptor,
    VariationSpecsError,
    DescriptorDict,
)
from .spectra import SimulationSeries, Spectrum
from .utils import Paths, _open_config, open_single_config
from .variationer import DescriptorDict, VariationDescriptor, Variationer, VariationSpecsError

@@ -1,260 +0,0 @@
import inspect
import os
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, Iterator, Set

import numpy as np
import toml
from pydantic import BaseModel

from .._utils import load_toml, save_toml
from ..const import PARAM_FN, PARAM_SEPARATOR, Z_FN
from ..physics.units import get_unit


class HashableBaseModel(BaseModel):
    """Pydantic BaseModel that's immutable and can be hashed"""

    def __hash__(self) -> int:
        return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())

    class Config:
        allow_mutation = False


def to_62(i: int) -> str:
    arr = []
    if i == 0:
        return "0"
    i = abs(i)
    while i:
        i, value = divmod(i, 62)
        arr.append(str_printable[value])
    return "".join(reversed(arr))
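
# Illustrative usage sketch, not part of the commit: to_62 is a plain base-62
# positional encoding over string.printable[:62] (digits, then lowercase, then
# uppercase letters). The import path assumes the post-commit layout, where the
# function lives in scgenerator/utils.py.
from scgenerator.utils import to_62

assert to_62(0) == "0"
assert to_62(61) == "Z"     # printable[61] is the last uppercase letter
assert to_62(62) == "10"
assert to_62(1000) == "g8"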


class PlotRange(HashableBaseModel):
    left: float
    right: float
    unit: Callable[[float], float]
    conserved_quantity: bool = True

    def __init__(self, left, right, unit, **kwargs):
        super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)

    def __str__(self):
        return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"

    def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
        return sort_axis(axis, self)

    def __iter__(self):
        yield self.left
        yield self.right
        yield self.unit.__name__


def sort_axis(
    axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
    """
    given an axis, returns this axis cropped according to the given range, converted and sorted

    Parameters
    ----------
    axis : 1D array containing the original axis (usually the w or t array)
    plt_range : tuple (min, max, conversion_function) used to crop the axis

    Returns
    -------
    cropped : the axis cropped, converted and sorted
    indices : indices to use to slice and sort other arrays in the same fashion
    extent : tuple with min and max of cropped

    Example
    -------
    w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
    t = np.linspace(-10, 10, 400)
    W, T = np.meshgrid(w, t)
    y = np.exp(-W**2 - T**2)

    # Define ranges
    rw = (-4, 4, s)
    rt = (-2, 6, s)

    w, cw, _ = sort_axis(w, rw)
    t, ct, _ = sort_axis(t, rt)

    # slice y according to the given ranges
    y = y[ct][:, cw]
    """
    if isinstance(plt_range, tuple):
        plt_range = PlotRange(*plt_range)
    r = np.array((plt_range.left, plt_range.right), dtype="float")

    indices = np.arange(len(axis))[
        (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
    ]
    cropped = axis[indices]
    order = np.argsort(plt_range.unit.inv(cropped))
    indices = indices[order]
    cropped = cropped[order]
    out_ax = plt_range.unit.inv(cropped)

    return out_ax, indices, (out_ax[0], out_ax[-1])


def get_arg_names(func: Callable) -> list[str]:
    # spec = inspect.getfullargspec(func)
    # args = spec.args
    # if spec.defaults is not None and len(spec.defaults) > 0:
    #     args = args[: -len(spec.defaults)]
    # return args
    return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]
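
# Sketch, not part of the commit: get_arg_names keeps only the parameters that
# have no default value (the function below is made up for the example).
def _example(a, b, c=3):
    return a + b + c

assert get_arg_names(_example) == ["a", "b"]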


def validate_arg_names(names: list[str]):
    for n in names:
        if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
            raise ValueError(f"{n} is an invalid parameter name")


def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
    if arg_names is None:
        arg_names = get_arg_names(func)
    else:
        validate_arg_names(arg_names)
    validate_arg_names(kwarg_names)
    sign_arg_str = ", ".join(arg_names + kwarg_names)
    call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
    tmp_name = f"{func.__name__}_0"
    func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
    scope = dict(__func__=func)
    exec(func_str, scope)
    out_func = scope[tmp_name]
    out_func.__module__ = "evaluator"
    return out_func
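
# Sketch, not part of the commit: func_rewrite builds a thin wrapper whose
# explicit signature is arg_names + kwarg_names and which forwards the
# kwarg_names entries by keyword. The function and names below are made up.
def _gamma(w0, n2=1.0, a_eff=1.0):
    return n2 * w0 / a_eff

_g = func_rewrite(_gamma, ["n2", "a_eff"], ["w0"])
# _g now has the signature _g(w0, n2, a_eff) and calls _gamma(w0, n2=n2, a_eff=a_eff)
assert _g(2.0, 3.0, 4.0) == _gamma(2.0, n2=3.0, a_eff=4.0)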


@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
    arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
    return_str = ", ".join("True" for _ in range(num_returns))
    func_name = f"__mock_{num_args}_{num_returns}"
    func_str = f"def {func_name}({arg_str}):\n return {return_str}"
    scope = {}
    exec(func_str, scope)
    out_func = scope[func_name]
    out_func.__module__ = "evaluator"
    return out_func
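
# Sketch, not part of the commit: _mock_function fabricates a placeholder that
# accepts num_args positional arguments and returns num_returns dummy values.
_f = _mock_function(2, 3)
assert _f("anything", 123) == (True, True, True)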


def combine_simulations(path: Path, dest: Path = None):
    """combines raw simulations into one folder per branch

    Parameters
    ----------
    path : Path
        source of the simulations (must contain u_xx directories)
    dest : Path, optional
        if given, moves the simulations to dest, by default None
    """
    paths: dict[str, list[Path]] = defaultdict(list)
    if dest is None:
        dest = path

    for p in path.glob("u_*b_*"):
        if p.is_dir():
            paths[p.name.split()[1]].append(p)
    for l in paths.values():
        l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
    for pulses in paths.values():
        new_path = dest / update_path(pulses[0].name)
        os.makedirs(new_path, exist_ok=True)
        for num, pulse in enumerate(pulses):
            params_ok = False
            for file in pulse.glob("*"):
                if file.name == PARAM_FN:
                    if not params_ok:
                        update_params(new_path, file)
                        params_ok = True
                    else:
                        file.unlink()
                elif file.name == Z_FN:
                    file.rename(new_path / file.name)
                elif file.name.startswith("spectr") and num == 0:
                    file.rename(new_path / file.name)
                else:
                    file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
            pulse.rmdir()
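
# Usage sketch, not part of the commit (folder names are hypothetical):
# per-repeat folders such as "u_0 b_0 num 0" and "u_0 b_0 num 1" are merged in
# place into a single "b_0" directory, spectrum files from later repeats
# getting a _1, _2, ... suffix.
from pathlib import Path

combine_simulations(Path("simulations/my_run"))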


def update_params(new_path: Path, file: Path):
    params = load_toml(file)
    if (p := params.get("prev_data_dir")) is not None:
        p = Path(p)
        params["prev_data_dir"] = str(p.parent / update_path(p.name))
    params["output_path"] = str(new_path)
    save_toml(new_path / PARAM_FN, params)
    file.unlink()


def save_parameters(
    params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
    """saves a parameter dictionary. Note that it does remove some entries, particularly
    those that take a lot of space ("t", "w", ...)

    Parameters
    ----------
    params : dict[str, Any]
        dictionary to save
    destination_dir : Path
        destination directory

    Returns
    -------
    Path
        path to the newly created parameter file
    """
    file_path = destination_dir / file_name
    os.makedirs(file_path.parent, exist_ok=True)

    # save toml of the simulation
    with open(file_path, "w") as file:
        toml.dump(params, file, encoder=toml.TomlNumpyEncoder())

    return file_path


def update_path(p: str) -> str:
    return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)
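
# Sketch, not part of the commit (the folder name is made up): update_path
# drops the "u_XX " prefix and the " num N" suffix that mark individual repeats.
assert update_path("u_3 b_1 num 2") == "b_1"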


def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
    return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])


def simulations_list(path: os.PathLike) -> list[Path]:
    """finds simulations folders contained in a parent directory

    Parameters
    ----------
    path : os.PathLike
        parent path

    Returns
    -------
    list[Path]
        absolute paths to the simulation folders
    """
    paths: list[Path] = []
    for pwd, _, files in os.walk(path):
        if PARAM_FN in files:
            paths.append(Path(pwd))
    paths.sort(key=lambda el: el.parent.name)
    return [p for p in paths if p.parent.name == paths[-1].parent.name]

@@ -7,7 +7,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

from .. import _utils as utils
from .. import utils
from .. import const, env, scripts
from ..logger import get_logger
from ..physics.fiber import dispersion_coefficients

@@ -7,11 +7,11 @@ from typing import Any, Set
import numpy as np
import toml

from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
from .const import SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .parameter import Configuration, Parameters
from .utils import save_parameters
from .pbar import PBars
from .variationer import VariationDescriptor, Variationer
from .variationer import VariationDescriptor


def load_config(path: os.PathLike) -> dict[str, Any]:

@@ -3,7 +3,7 @@ from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from ._utils.cache import np_cache
from .cache import np_cache

pi = np.pi
c = 299792458.0

@@ -8,34 +8,22 @@ import os
import re
import time
from collections import defaultdict
from copy import copy, deepcopy
from copy import copy
from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache
from functools import lru_cache
from pathlib import Path
from typing import (
    Any,
    Callable,
    Generator,
    Iterable,
    Iterator,
    Literal,
    Optional,
    Sequence,
    TypeVar,
    Union,
)
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar, Union

import numpy as np
from numpy.lib import isin

from .. import _utils as utils
from .. import env, math
from .._utils.variationer import VariationDescriptor, Variationer
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
from . import env, math, utils
from .const import PARAM_FN, __version__
from .errors import EvaluatorError, NoDefaultError
from .logger import get_logger
from .physics import fiber, materials, pulse, units
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path
from .variationer import VariationDescriptor, Variationer

T = TypeVar("T")

@@ -10,7 +10,7 @@ from typing import Iterable, Union

from tqdm import tqdm

from ..env import pbar_policy
from .env import pbar_policy

T_ = typing.TypeVar("T_")

@@ -10,7 +10,8 @@ from scipy.optimize import minimize_scalar

from .. import math
from . import fiber, materials, units, pulse
from .._utils import cache, load_material_dico
from ..cache import np_cache
from ..utils import load_material_dico

T = TypeVar("T")

@@ -21,7 +22,7 @@ def group_delay_to_gdd(wavelength: np.ndarray, group_delay: np.ndarray) -> np.nd
    return gdd


@cache.np_cache
@np_cache
def material_dispersion(
    wavelengths: np.ndarray,
    material: str,

@@ -1,21 +1,19 @@
from typing import Any, Iterable, Literal, TypeVar

import numpy as np
from numpy.fft import fft, ifft
from numpy import e
from numpy.fft import fft, ifft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d

from .. import utils
from ..cache import np_cache
from ..logger import get_logger

from .. import _utils as utils
from ..math import abs2, argclosest, power_fact, u_nm
from .._utils.cache import np_cache
from . import materials as mat
from . import units
from .units import c, pi


pipi = 2 * pi
T = TypeVar("T")


@@ -1,13 +1,14 @@
from typing import Any, Callable

import numpy as np
import scipy.special
from scipy.integrate import cumulative_trapezoid

from .. import utils
from ..cache import np_cache
from ..logger import get_logger
from . import units
from .. import _utils
from .units import NA, c, kB, me, e, hbar
from .._utils.cache import np_cache
from .units import NA, c, e, hbar, kB, me


@np_cache

@@ -15,7 +16,7 @@ def n_gas_2(
    wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature: float, ideal_gas: bool
):
    """Returns the square of the index of refraction of the specified gas"""
    material_dico = _utils.load_material_dico(gas_name)
    material_dico = utils.load_material_dico(gas_name)

    if ideal_gas:
        n_gas_2 = sellmeier(wl_for_disp, material_dico, pressure, temperature) + 1

@@ -218,7 +219,7 @@ def gas_n2(gas_name: str, pressure: float, temperature: float) -> float:
    float
        n2 in m2/W
    """
    return non_linear_refractive_index(_utils.load_material_dico(gas_name), pressure, temperature)
    return non_linear_refractive_index(utils.load_material_dico(gas_name), pressure, temperature)


def adiabadicity(w: np.ndarray, I: float, field: np.ndarray) -> np.ndarray:

@@ -7,17 +7,13 @@ from pathlib import Path
from typing import Any, Generator, Type, Union

import numpy as np
from send2trash import send2trash

from .. import env
from .. import _utils as utils
from .._utils.utils import combine_simulations, save_parameters
from .. import utils
from ..logger import get_logger
from .._utils.parameter import Configuration, Parameters
from .._utils.pbar import PBars, ProgressBarActor, progress_worker
from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
from scgenerator._utils import pbar

try:
    import ray

@@ -505,7 +501,7 @@ class Simulations:
        for variable, params in self.configuration:
            params.compute()
            v_list_str = variable.formatted_descriptor(True)
            save_parameters(params.prepare_for_dump(), Path(params.output_path))
            utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))

            self.new_sim(v_list_str, params)
        self.finish()

@@ -737,7 +733,7 @@ def run_simulation(
    sim.run()

    for path in config.fiber_paths:
        combine_simulations(path)
        utils.combine_simulations(path)


def new_simulation(

@@ -3,7 +3,7 @@
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...

from typing import Callable, TypeVar, Union

from operator import itemgetter
import numpy as np
from numpy import pi

@@ -274,3 +274,75 @@ def to_log2D(arr, ref=None):
    m = arr / ref
    m = 10 * np.log10(m, out=np.zeros_like(m) - 100, where=m > 0)
    return m


class PlotRange(tuple):
    left: float = property(itemgetter(0))
    right: float = property(itemgetter(1))
    unit: Callable[[float], float] = property(itemgetter(2))
    conserved_quantity: bool = property(itemgetter(3))
    __slots__ = []

    def __new__(cls, left, right, unit, conserved_quantity=True):
        return tuple.__new__(cls, (left, right, get_unit(unit), conserved_quantity))

    def __iter__(self):
        yield self.left
        yield self.right
        yield self.unit.__name__

    def __repr__(self):
        return "PlotRange" + super().__repr__()

    def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
        return sort_axis(axis, self)
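
# Sketch, not part of the commit: the new PlotRange behaves like a plain tuple
# but exposes named fields; "nm" is assumed to be a valid argument to get_unit,
# as suggested by the comment near the top of this file.
rng = PlotRange(400, 1400, "nm")
left, right, unit_name = rng            # __iter__ yields the unit's *name*
assert (rng.left, rng.right, rng.conserved_quantity) == (400, 1400, True)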


def sort_axis(
    axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
    """
    given an axis, returns this axis cropped according to the given range, converted and sorted

    Parameters
    ----------
    axis : 1D array containing the original axis (usually the w or t array)
    plt_range : tuple (min, max, conversion_function) used to crop the axis

    Returns
    -------
    cropped : the axis cropped, converted and sorted
    indices : indices to use to slice and sort other arrays in the same fashion
    extent : tuple with min and max of cropped

    Example
    -------
    w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
    t = np.linspace(-10, 10, 400)
    W, T = np.meshgrid(w, t)
    y = np.exp(-W**2 - T**2)

    # Define ranges
    rw = (-4, 4, s)
    rt = (-2, 6, s)

    w, cw, _ = sort_axis(w, rw)
    t, ct, _ = sort_axis(t, rt)

    # slice y according to the given ranges
    y = y[ct][:, cw]
    """
    if not isinstance(plt_range, PlotRange):
        plt_range = PlotRange(*plt_range)
    r = np.array((plt_range.left, plt_range.right), dtype="float")

    indices = np.arange(len(axis))[
        (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
    ]
    cropped = axis[indices]
    order = np.argsort(plt_range.unit.inv(cropped))
    indices = indices[order]
    cropped = cropped[order]
    out_ax = plt_range.unit.inv(cropped)

    return out_ax, indices, (out_ax[0], out_ax[-1])

@@ -11,13 +11,13 @@ from scipy.interpolate import UnivariateSpline
from scipy.interpolate.interpolate import interp1d

from . import math
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, sort_axis
from .const import PARAM_SEPARATOR, SPEC1_FN
from .defaults import default_plotting as defaults
from .math import abs2, span
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange, sort_axis
from .utils import load_spectrum

RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()

@@ -9,15 +9,11 @@ from tqdm import tqdm

from .. import env, math
from ..const import PARAM_FN, PARAM_SEPARATOR
from ..parameter import Configuration, Parameters
from ..physics import fiber, units
from ..plotting import plot_setup
from ..spectra import SimulationSeries
from .._utils import auto_crop, _open_config, save_toml, translate_parameters
from .._utils.parameter import (
    Configuration,
    Parameters,
)
from .._utils.utils import simulations_list
from ..utils import _open_config, auto_crop, save_toml, simulations_list, translate_parameters


def fingerprint(params: Parameters):

@@ -9,8 +9,8 @@ from typing import Tuple

import numpy as np

from .._utils import Paths
from .._utils.parameter import Configuration
from ..utils import Paths
from ..parameter import Configuration


def primes(n):

@@ -8,18 +8,18 @@ import matplotlib.pyplot as plt
import numpy as np

from . import math
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, simulations_list
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from .logger import get_logger
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange
from .plotting import (
    mean_values_plot,
    propagation_plot,
    single_position_plot,
    transform_2D_propagation,
)
from .utils import load_spectrum, simulations_list


class Spectrum(np.ndarray):

@@ -3,33 +3,27 @@ This files includes utility functions designed more or less to be used specifica
scgenerator module but some functions may be used in any python program

"""

from __future__ import annotations

from dataclasses import dataclass
import inspect
import itertools
import os
from collections import abc
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from functools import cache
from typing import Any, MutableMapping, Sequence, TypeVar

from typing import Any, Callable, MutableMapping, Sequence, TypeVar

import numpy as np
from numpy.lib.arraysetops import isin
import pkg_resources as pkg
import toml
from tqdm import tqdm
import itertools


from ..const import SPEC1_FN, __version__
from ..logger import get_logger
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, Z_FN
from .logger import get_logger

T_ = TypeVar("T_")

PathTree = list[tuple[Path, ...]]


class Paths:
    _data_files = [

@@ -79,6 +73,42 @@ class Paths:
        return os.path.join(cls.get("plots"), name)


class ConfigFileParser:
    path: Path
    repeat: int
    master: ConfigFileParser.SubConfig
    configs: list[ConfigFileParser.SubConfig]

    @dataclass
    class SubConfig:
        fixed: dict[str, Any]
        variable: dict[str, list]

    def __init__(self, path: os.PathLike):
        self.path = Path(path)
        fiber_list: list[dict[str, Any]]
        if self.path.name.lower().endswith(".toml"):
            loaded_config = _open_config(self.path)
            fiber_list = loaded_config.pop("Fiber")
        else:
            loaded_config = dict(name=self.path.name)
            fiber_list = [_open_config(p) for p in sorted(self.path.glob("initial_config*.toml"))]

        if len(fiber_list) == 0:
            raise ValueError(f"No fiber in config {self.path}")
        final_path = loaded_config.get("name")
        configs = []
        for i, params in enumerate(fiber_list):
            configs.append(loaded_config | params)
        for root_vary, first_vary in itertools.product(
            loaded_config["variable"], configs[0]["variable"]
        ):
            if len(common := root_vary.keys() & first_vary.keys()) != 0:
                raise ValueError(f"These variable keys are specified twice : {common!r}")
        configs[0] |= {k: v for k, v in loaded_config.items() if k != "variable"}
        configs[0]["variable"].append(dict(num=list(range(configs[0].get("repeat", 1)))))
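
# Sketch, not part of the commit: rough shape of the configuration this parser
# expects once the TOML is loaded (all values below are invented). A single
# .toml file must carry a [[Fiber]] array of tables; a directory is instead
# scanned for initial_config*.toml, one file per fiber.
_example_config = {
    "name": "example_sim",
    "repeat": 2,
    "variable": [{"energy": [1e-6, 2e-6]}],
    "Fiber": [
        {"length": 1.0, "variable": [{"core_radius": [10e-6, 15e-6]}]},
        {"length": 2.0},
    ],
}
# Root-level keys are merged into each per-fiber dict; a key appearing in both
# the root "variable" list and the first fiber's "variable" list raises
# ValueError, and a num = range(repeat) variable is appended for repeats.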


def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
    prev_data_dir = Path(prev_data_dir)
    num = find_last_spectrum_num(prev_data_dir)

@@ -329,3 +359,166 @@ def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
        else:
            new[old_names.get(k, k)] = v
    return defaults_to_add | new


def to_62(i: int) -> str:
    arr = []
    if i == 0:
        return "0"
    i = abs(i)
    while i:
        i, value = divmod(i, 62)
        arr.append(str_printable[value])
    return "".join(reversed(arr))


def get_arg_names(func: Callable) -> list[str]:
    # spec = inspect.getfullargspec(func)
    # args = spec.args
    # if spec.defaults is not None and len(spec.defaults) > 0:
    #     args = args[: -len(spec.defaults)]
    # return args
    return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]


def validate_arg_names(names: list[str]):
    for n in names:
        if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
            raise ValueError(f"{n} is an invalid parameter name")


def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
    if arg_names is None:
        arg_names = get_arg_names(func)
    else:
        validate_arg_names(arg_names)
    validate_arg_names(kwarg_names)
    sign_arg_str = ", ".join(arg_names + kwarg_names)
    call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
    tmp_name = f"{func.__name__}_0"
    func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
    scope = dict(__func__=func)
    exec(func_str, scope)
    out_func = scope[tmp_name]
    out_func.__module__ = "evaluator"
    return out_func


@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
    arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
    return_str = ", ".join("True" for _ in range(num_returns))
    func_name = f"__mock_{num_args}_{num_returns}"
    func_str = f"def {func_name}({arg_str}):\n return {return_str}"
    scope = {}
    exec(func_str, scope)
    out_func = scope[func_name]
    out_func.__module__ = "evaluator"
    return out_func


def combine_simulations(path: Path, dest: Path = None):
    """combines raw simulations into one folder per branch

    Parameters
    ----------
    path : Path
        source of the simulations (must contain u_xx directories)
    dest : Path, optional
        if given, moves the simulations to dest, by default None
    """
    paths: dict[str, list[Path]] = defaultdict(list)
    if dest is None:
        dest = path

    for p in path.glob("u_*b_*"):
        if p.is_dir():
            paths[p.name.split()[1]].append(p)
    for l in paths.values():
        l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
    for pulses in paths.values():
        new_path = dest / update_path(pulses[0].name)
        os.makedirs(new_path, exist_ok=True)
        for num, pulse in enumerate(pulses):
            params_ok = False
            for file in pulse.glob("*"):
                if file.name == PARAM_FN:
                    if not params_ok:
                        update_params(new_path, file)
                        params_ok = True
                    else:
                        file.unlink()
                elif file.name == Z_FN:
                    file.rename(new_path / file.name)
                elif file.name.startswith("spectr") and num == 0:
                    file.rename(new_path / file.name)
                else:
                    file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
            pulse.rmdir()


def update_params(new_path: Path, file: Path):
    params = load_toml(file)
    if (p := params.get("prev_data_dir")) is not None:
        p = Path(p)
        params["prev_data_dir"] = str(p.parent / update_path(p.name))
    params["output_path"] = str(new_path)
    save_toml(new_path / PARAM_FN, params)
    file.unlink()


def save_parameters(
    params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
    """saves a parameter dictionary. Note that it does remove some entries, particularly
    those that take a lot of space ("t", "w", ...)

    Parameters
    ----------
    params : dict[str, Any]
        dictionary to save
    destination_dir : Path
        destination directory

    Returns
    -------
    Path
        path to the newly created parameter file
    """
    file_path = destination_dir / file_name
    os.makedirs(file_path.parent, exist_ok=True)

    # save toml of the simulation
    with open(file_path, "w") as file:
        toml.dump(params, file, encoder=toml.TomlNumpyEncoder())

    return file_path


def update_path(p: str) -> str:
    return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)


def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
    return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])


def simulations_list(path: os.PathLike) -> list[Path]:
    """finds simulations folders contained in a parent directory

    Parameters
    ----------
    path : os.PathLike
        parent path

    Returns
    -------
    list[Path]
        absolute paths to the simulation folders
    """
    paths: list[Path] = []
    for pwd, _, files in os.walk(path):
        if PARAM_FN in files:
            paths.append(Path(pwd))
    paths.sort(key=lambda el: el.parent.name)
    return [p for p in paths if p.parent.name == paths[-1].parent.name]

@@ -8,8 +8,7 @@ import numpy as np
from pydantic import validator
from pydantic.main import BaseModel

from ..const import PARAM_SEPARATOR
from . import utils
from .const import PARAM_SEPARATOR

T = TypeVar("T")