Merge branch 'new_sim_descriptor'
@@ -141,8 +141,8 @@ hasan :
capillary_num : int
    number of capillaries

capillary_outer_d : float, optional if g is specified
    outer diameter of the capillaries
capillary_radius : float, optional if g is specified
    outer radius of the capillaries

capillary_thickness : float
    thickness of the capillary walls
play.py | 19
@@ -4,22 +4,3 @@ import scgenerator as sc
import matplotlib.pyplot as plt
from pathlib import Path
from pprint import pprint


def _main():
    print(os.getcwd())
    for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"):
        print(params.fiber_map)


def main():
    drr = os.getcwd()
    os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
    try:
        _main()
    finally:
        os.chdir(drr)


if __name__ == "__main__":
    main()
@@ -1,8 +1,25 @@
from . import math, utils
from . import math
from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum
from .utils import Paths, open_config, parameter
from .utils.parameter import Configuration, Parameters, PlotRange
from .plotting import (
    mean_values_plot,
    plot_spectrogram,
    propagation_plot,
    single_position_plot,
    transform_2D_propagation,
    transform_1D_values,
    transform_mean_values,
    get_extent,
)
from .spectra import Spectrum, SimulationSeries
from ._utils import Paths, _open_config, parameter, open_single_config
from ._utils.parameter import Configuration, Parameters
from ._utils.utils import PlotRange
from ._utils.legacy import convert_sim_folder
from ._utils.variationer import (
    Variationer,
    VariationDescriptor,
    VariationSpecsError,
    DescriptorDict,
)
src/scgenerator/_utils/__init__.py | 325 (new file)
@@ -0,0 +1,325 @@
"""
This file includes utility functions designed more or less to be used specifically with the
scgenerator module, but some functions may be used in any python program
"""

from __future__ import annotations

import itertools
import multiprocessing
import os
import random
import re
import shutil
import threading
from collections import abc
from io import StringIO
from pathlib import Path
from string import printable as str_printable
from functools import cache
from typing import Any, Callable, Generator, Iterable, MutableMapping, Sequence, TypeVar, Union


import numpy as np
import pkg_resources as pkg
import toml
from tqdm import tqdm

from .pbar import PBars
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__
from ..env import pbar_policy
from ..logger import get_logger

T_ = TypeVar("T_")

PathTree = list[tuple[Path, ...]]


class Paths:
    _data_files = [
        "materials.toml",
        "hr_t.npz",
        "submit_job_template.txt",
        "start_worker.sh",
        "start_head.sh",
    ]

    paths = {
        f.split(".")[0]: os.path.abspath(
            pkg.resource_filename("scgenerator", os.path.join("data", f))
        )
        for f in _data_files
    }

    @classmethod
    def get(cls, key):
        if key not in cls.paths:
            if os.path.exists("paths.toml"):
                with open("paths.toml") as file:
                    paths_dico = toml.load(file)
                for k, v in paths_dico.items():
                    cls.paths[k] = v
        if key not in cls.paths:
            get_logger(__name__).info(
                f"{key} was not found in path index, returning current working directory."
            )
            cls.paths[key] = os.getcwd()

        return cls.paths[key]

    @classmethod
    def gets(cls, key):
        """returns the specified file as a string"""
        with open(cls.get(key)) as file:
            return file.read()

    @classmethod
    def plot(cls, name):
        """returns the path to the specified plot. Used to save new plots

        example
        ---------
        fig.savefig(Paths.plot("figure5.pdf"))
        """
        return os.path.join(cls.get("plots"), name)


def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
    prev_data_dir = Path(prev_data_dir)
    num = find_last_spectrum_num(prev_data_dir)
    return load_spectrum(prev_data_dir / SPEC1_FN.format(num))


@cache
def load_spectrum(folder: os.PathLike) -> np.ndarray:
    return np.load(folder)


def conform_toml_path(path: os.PathLike) -> str:
    path: str = str(path)
    if not path.lower().endswith(".toml"):
        path = path + ".toml"
    return path


def open_single_config(path: os.PathLike) -> dict[str, Any]:
    d = _open_config(path)
    f = d.pop("Fiber")[0]
    return d | f


def _open_config(path: os.PathLike):
    """returns a dictionary parsed from the specified toml file.
    This also handles an 'INCLUDE' argument that will fill
    otherwise unspecified keys with what's in the INCLUDE file(s)"""

    path = conform_toml_path(path)
    dico = resolve_loadfile_arg(load_toml(path))

    dico.setdefault("variable", {})
    for key in {"simulation", "fiber", "gas", "pulse"} & dico.keys():
        section = dico.pop(key)
        dico["variable"].update(section.pop("variable", {}))
        dico.update(section)
    if len(dico["variable"]) == 0:
        dico.pop("variable")
    return dico


def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
    if (f_list := dico.pop("INCLUDE", None)) is not None:
        if isinstance(f_list, str):
            f_list = [f_list]
        for to_load in f_list:
            loaded = load_toml(to_load)
            for k, v in loaded.items():
                if k not in dico and k not in dico.get("variable", {}):
                    dico[k] = v
    for k, v in dico.items():
        if isinstance(v, MutableMapping):
            dico[k] = resolve_loadfile_arg(v)
        elif isinstance(v, Sequence):
            for i, vv in enumerate(v):
                if isinstance(vv, MutableMapping):
                    dico[k][i] = resolve_loadfile_arg(vv)
    return dico


def load_toml(descr: os.PathLike) -> dict[str, Any]:
    descr = str(descr)
    if ":" in descr:
        path, entry = descr.split(":", 1)
        with open(path) as file:
            return toml.load(file)[entry]
    else:
        with open(descr) as file:
            return toml.load(file)
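
For reference, a minimal sketch of how these loaders compose (the file names and keys below are invented for illustration):

# -- base.toml --          # shared defaults
#   repeat = 2
#   [simulation]
#   z_num = 128
# -- sim.toml --
#   INCLUDE = "base.toml"
#   name = "test"
from scgenerator._utils import _open_config, load_toml

cfg = _open_config("sim.toml")
# INCLUDE fills in missing keys, sections are flattened:
# {"name": "test", "repeat": 2, "z_num": 128}
table = load_toml("materials.toml:silica")  # "path:entry" selects a single table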


def save_toml(path: os.PathLike, dico):
    """saves a dictionary into a toml file"""
    path = conform_toml_path(path)
    with open(path, mode="w") as file:
        toml.dump(dico, file)
    return dico


def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
    """loads a configuration file

    Parameters
    ----------
    path : os.PathLike
        path to the config toml file or a directory containing config files

    Returns
    -------
    final_path : Path
        output name of the simulation
    list[dict[str, Any]]
        one config per fiber

    """
    path = Path(path)
    fiber_list: list[dict[str, Any]]
    if path.name.lower().endswith(".toml"):
        loaded_config = _open_config(path)
        fiber_list = loaded_config.pop("Fiber")
    else:
        loaded_config = dict(name=path.name)
        fiber_list = [_open_config(p) for p in sorted(path.glob("initial_config*.toml"))]

    if len(fiber_list) == 0:
        raise ValueError(f"No fiber in config {path}")
    final_path = loaded_config.get("name")
    configs = []
    for i, params in enumerate(fiber_list):
        params.setdefault("variable", {})
        configs.append(loaded_config | params)
    configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
    configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1)))

    return Path(final_path), configs
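
The sequence file is expected to hold one [[Fiber]] table per segment; a minimal sketch of the expected shape (contents invented):

# -- chain.toml --
#   name = "my chain"
#   repeat = 2
#   [[Fiber]]
#   z_num = 64
#   [[Fiber]]
#   z_num = 128
final_path, configs = load_config_sequence("chain.toml")
# final_path -> Path("my chain"); configs -> one merged dict per fiber,
# with configs[0]["variable"]["num"] == [0, 1] covering the two repeats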


def load_material_dico(name: str) -> dict[str, Any]:
    """loads a material dictionary

    Parameters
    ----------
    name : str
        name of the material

    Returns
    ----------
    material_dico : dict
    """
    return toml.loads(Paths.gets("materials"))[name]


def save_data(data: np.ndarray, data_dir: Path, file_name: str):
    """saves numpy array to disk

    Parameters
    ----------
    data : np.ndarray
        data to save
    data_dir : Path
        destination directory
    file_name : str
        file name
    """
    path = data_dir / file_name
    np.save(path, data)
    get_logger(__name__).debug(f"saved data in {path}")
    return


def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Path:
    """ensures a folder exists and doesn't overwrite anything if required

    Parameters
    ----------
    path : Path
        desired path
    prevent_overwrite : bool, optional
        whether to create a new directory when one already exists, by default True

    Returns
    -------
    Path
        final path
    """

    path = path.resolve()

    # is path root ?
    if len(path.parts) < 2:
        return path

    # is a part of path an existing *file* ?
    parts = path.parts
    path = Path(path.root)
    for part in parts:
        if path.is_file():
            path = ensure_folder(path, mkdir=mkdir, prevent_overwrite=False)
        path /= part

    folder_name = path.name

    for i in itertools.count():
        if not path.is_file() and (not prevent_overwrite or not path.is_dir()):
            if mkdir:
                path.mkdir(exist_ok=True)
            return path
        path = path.parent / (folder_name + f"_{i}")
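
A quick sketch of the overwrite-protection behaviour (paths invented):

out = ensure_folder(Path("results"))                            # creates ./results
out = ensure_folder(Path("results"))                            # exists -> creates ./results_0
out = ensure_folder(Path("results"), prevent_overwrite=False)   # reuses ./results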


def branch_id(branch: tuple[Path, ...]) -> str:
    return branch[-1].name.split()[1]


def find_last_spectrum_num(data_dir: Path):
    for num in itertools.count(1):
        p_to_test = data_dir / SPEC1_FN.format(num)
        if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
            return num - 1


def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray:
    threshold = y.min() + rel_thr * (y.max() - y.min())
    above_threshold = y > threshold
    ind = np.argsort(x)
    valid_ind = [
        np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
    ]
    ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
    width = len(ind_above)
    return np.concatenate(
        (
            np.arange(max(ind_above[0] - width, 0), ind_above[0]),
            ind_above,
            np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
        )
    )
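
auto_crop keeps the longest contiguous run of indices above the relative threshold and pads it on both sides by roughly the run's own width; a small worked example:

import numpy as np

x = np.arange(20)
y = np.zeros(20)
y[8:12] = 1.0       # widest above-threshold run: indices 8..11 (width 4)
auto_crop(x, y)     # -> indices 4..14: left pad 4..7, run 8..11, right pad 12..14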


def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
    old_names = dict(
        interp_degree="interpolation_degree",
        beta="beta2_coefficients",
        interp_range="interpolation_range",
    )
    deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
    defaults_to_add = dict(repeat=1)
    new = {}
    for k, v in d.items():
        if k == "error_ok":
            new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
        elif k in deleted_names:
            continue
        elif isinstance(v, MutableMapping):
            new[k] = translate_parameters(v)
        else:
            new[old_names.get(k, k)] = v
    return defaults_to_add | new
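
For instance, translating a legacy parameter dict (values invented):

old = {"interp_degree": 3, "error_ok": 1e-10, "lower_wavelength_interp_limit": 300e-9}
translate_parameters(old)
# -> {"repeat": 1, "interpolation_degree": 3, "tolerated_error": 1e-10}
# "error_ok" maps to "step_size" instead when adapt_step_size is False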
src/scgenerator/_utils/legacy.py | 99 (new file)
@@ -0,0 +1,99 @@
from genericpath import exists
import os
import sys
from pathlib import Path
from pprint import pprint
from typing import Any, Set

import numpy as np
import toml

from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
from .parameter import Configuration, Parameters
from .utils import fiber_folder, save_parameters
from .pbar import PBars
from .variationer import VariationDescriptor, Variationer


def load_config(path: os.PathLike) -> dict[str, Any]:
    with open(path) as file:
        d = toml.load(file)
    d.setdefault("variable", {})
    return d


def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str, Any]]]:
    paths = sorted(list(Path(path).glob("initial_config*.toml")))
    return paths, [load_config(cfg) for cfg in paths]


def convert_sim_folder(path: os.PathLike):
    path = Path(path).resolve()
    new_root = path.parent / "sc_legagy_converter" / path.name
    os.makedirs(new_root, exist_ok=True)
    config_paths, configs = load_config_sequence(path)
    master_config = dict(name=path.name, Fiber=configs)
    with open(new_root / "initial_config.toml", "w") as f:
        toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
    configuration = Configuration(path, final_output_path=new_root)
    pbar = PBars(configuration.total_num_steps, "Converting")

    new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)
    old_paths: Set[Path] = set()
    old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = []
    for descriptor, params in configuration.iterate_single_fiber(-1):
        old_path = path / descriptor.branch.formatted_descriptor()
        if not Path(old_path).is_dir():
            raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.")
        old_paths.add(old_path)
        for d in descriptor.iter_parents():
            z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1])
            z_limits = (z_num_start, z_num_start + params.z_num)
            old2new.append((old_path, d, new_paths[d], z_limits))

    processed_paths: Set[Path] = set()
    processed_specs: Set[VariationDescriptor] = set()

    for old_path, descr, new_params, (start_z, end_z) in old2new:
        move_specs = descr not in processed_specs
        processed_specs.add(descr)
        if (parent := descr.parent) is not None:
            new_params.prev_data_dir = str(new_paths[parent].final_path)
        save_parameters(new_params.prepare_for_dump(), new_params.final_path)
        for spec_num in range(start_z, end_z):
            old_spec = old_path / SPECN_FN1.format(spec_num)
            if move_specs:
                _mv_specs(pbar, new_params, start_z, spec_num, old_spec)
                old_spec.unlink()
        if old_path not in processed_paths:
            (old_path / PARAM_FN).unlink()
            (old_path / Z_FN).unlink()
            processed_paths.add(old_path)

    for old_path in processed_paths:
        old_path.rmdir()

    for cp in config_paths:
        cp.unlink()


def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path):
    os.makedirs(new_params.final_path, exist_ok=True)
    spec_data = np.load(old_spec)
    for j, spec1 in enumerate(spec_data):
        if j == 0:
            np.save(new_params.final_path / SPEC1_FN.format(spec_num - start_z), spec1)
        else:
            np.save(
                new_params.final_path / SPEC1_FN_N.format(spec_num - start_z, j),
                spec1,
            )
    pbar.update()


def main():
    convert_sim_folder(sys.argv[1])


if __name__ == "__main__":
    main()
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime as datetime_module
import enum
import inspect
@@ -10,15 +12,30 @@ from copy import copy, deepcopy
from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache
from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
from typing import (
    Any,
    Callable,
    Generator,
    Iterable,
    Iterator,
    Literal,
    Optional,
    Sequence,
    TypeVar,
    Union,
)

import numpy as np
from numpy.lib import isin

from .. import math, utils
from .. import _utils as utils
from .. import env, math
from .._utils.variationer import VariationDescriptor, Variationer
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path

T = TypeVar("T")

@@ -38,7 +55,7 @@ VALID_VARIABLE = {
    "effective_mode_diameter",
    "core_radius",
    "capillary_num",
    "capillary_outer_d",
    "capillary_radius",
    "capillary_thickness",
    "capillary_spacing",
    "capillary_resonance_strengths",
@@ -69,6 +86,7 @@ VALID_VARIABLE = {
    "interpolation_degree",
    "ideal_gas",
    "length",
    "num",
}

MANDATORY_PARAMETERS = [
@@ -256,7 +274,7 @@ class Parameter:
    ----------
    tpe : type
        type of the parameter
    validators : Callable[[str, Any], None]
    validator : Callable[[str, Any], None]
        signature : validator(name, value)
        must raise a ValueError when value doesn't fit the criteria checked by
        validator. name is passed to validator to be included in the error message
@@ -290,7 +308,6 @@ class Parameter:
        if isinstance(value, Parameter):
            defaut = None if self.default is None else copy(self.default)
            instance.__dict__[self.name] = defaut
            # instance.__dict__[self.name] = None
        else:
            if value is not None:
                if self.converter is not None:
@@ -298,7 +315,7 @@ class Parameter:
                self.validator(self.name, value)
            instance.__dict__[self.name] = value

    def display(self, num: float):
    def display(self, num: float) -> str:
        if self.display_info is None:
            return str(num)
        else:
@@ -309,18 +326,23 @@ class Parameter:
            return f"{num_str} {unit}"


def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]:
    if isinstance(d, dict):
        return [(float(k), v) for k, v in d.items()]
    else:
        return [(float(k), v) for k, v in d]
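
The converter accepts either a mapping or an already-converted list, e.g.:

fiber_map_converter({"0.0": "PM1550", "1.5": "PM2000D"})
# -> [(0.0, "PM1550"), (1.5, "PM2000D")]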

@dataclass
class _AbstractParameters:
    @classmethod
    def __init_subclass__(cls):
        cls.register_param_formatters()

    @classmethod
    def register_param_formatters(cls):
        for k, v in cls.__dict__.items():
            if isinstance(v, Parameter):
                VariationDescriptor.register_formatter(k, v.display)


@dataclass
class Parameters:
class Parameters(_AbstractParameters):
    """
    This class defines each valid parameter's name, type and valid value. Initializing
    such an obj will automatically compute all possible parameters
    This class defines each valid parameter's name, type and valid value.
    """

    # root
@@ -352,7 +374,7 @@ class Parameters:
    )
    length: float = Parameter(non_negative(float, int))
    capillary_num: int = Parameter(positive(int))
    capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
    capillary_radius: float = Parameter(in_range_excl(0, 1e-3))
    capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
    capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
    capillary_resonance_strengths: Iterable[float] = Parameter(num_list, default=[])
@@ -430,15 +452,13 @@ class Parameters:
    const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
    beta_func: Callable[[float], list[float]] = Parameter(func_validator)
    gamma_func: Callable[[float], float] = Parameter(func_validator)
    fiber_map: list[tuple[float, str]] = Parameter(
        validator_list(type_checker(tuple)), converter=fiber_map_converter
    )

    num: int = Parameter(non_negative(int))
    datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
    version: str = Parameter(string)

    def prepare_for_dump(self) -> dict[str, Any]:
        param = asdict(self)
        param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])]
        param = Parameters.strip_params_dict(param)
        param["datetime"] = datetime_module.datetime.now()
        param["version"] = __version__
@@ -461,7 +481,7 @@ class Parameters:

    @classmethod
    def load(cls, path: os.PathLike) -> "Parameters":
        return cls(**utils.open_config(path))
        return cls(**utils._open_config(path))

    @classmethod
    def load_and_compute(cls, path: os.PathLike) -> "Parameters":
@@ -512,6 +532,12 @@ class Parameters:

        return out

    @property
    def final_path(self) -> Path:
        if self.output_path is not None:
            return Path(update_path(self.output_path))
        return None


class Rule:
    def __init__(
@@ -769,9 +795,12 @@ class Configuration:
    obj with the output path of the simulation saved in its output_path attribute.
    """

    master_configs: list[dict[str, Any]]
    sim_dirs: list[Path]
    fiber_configs: list[dict[str, Any]]
    vary_dicts: list[dict[str, list]]
    master_config: dict[str, Any]
    fiber_paths: list[Path]
    num_sim: int
    num_fibers: int
    repeat: int
    z_num: int
    total_num_steps: int
@@ -779,19 +808,17 @@ class Configuration:
    parallel: bool
    overwrite: bool
    final_path: str
    all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
    all_configs_list: list[list["Configuration.__SimConfig"]]
    all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]

    @dataclass(frozen=True)
    class __SimConfig:
        vary_list: list[tuple[str, Any]]
        descriptor: VariationDescriptor
        config: dict[str, Any]
        output_path: Path
        index: tuple[tuple[int, ...], ...]

        @property
        def sim_num(self) -> int:
            return len(self.index)
            return len(self.descriptor.index)

    class State(enum.Enum):
        COMPLETE = enum.auto()
@@ -805,57 +832,70 @@ class Configuration:

    def __init__(
        self,
        final_config_path: os.PathLike,
        config_path: os.PathLike,
        overwrite: bool = True,
        wait: bool = False,
        skip_callback: Callable[[int], None] = None,
        final_output_path: os.PathLike = None,
    ):
        self.logger = get_logger(__name__)
        self.wait = wait

        self.master_configs, self.final_path = utils.load_config_sequence(final_config_path)
        if self.final_path is None:
            self.final_path = Parameters.name.default
        self.name = Path(self.final_path).name
        self.overwrite = overwrite
        self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
        self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
        if final_output_path is not None:
            self.final_path = final_output_path
        self.final_path = utils.ensure_folder(
            Path(self.final_path),
            mkdir=False,
            prevent_overwrite=not self.overwrite,
        )
        self.master_config = self.fiber_configs[0].copy()
        self.name = self.final_path.name
        self.z_num = 0
        self.total_num_steps = 0
        self.sim_dirs = []
        self.overwrite = overwrite
        self.fiber_paths = []
        self.all_configs = {}
        self.skip_callback = skip_callback
        self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2))
        self.repeat = self.master_configs[0].get("repeat", 1)
        self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
        self.repeat = self.master_config.get("repeat", 1)
        self.variationer = Variationer()

        names = set()
        for i, config in enumerate(self.master_configs):
        fiber_names = set()
        self.num_fibers = 0
        for i, config in enumerate(self.fiber_configs):
            config.setdefault("name", Parameters.name.default)
            self.z_num += config["z_num"]
            config.setdefault("name", f"{Parameters.name.default} {i}")
            given_name = config["name"]
            fn_i = 0
            while config["name"] in names:
                config["name"] = given_name + f"_{fn_i}"
                fn_i += 1
            names.add(config["name"])

            self.sim_dirs.append(
            fiber_names.add(config["name"])
            vary_dict = config.pop("variable")
            self.variationer.append(vary_dict)
            self.fiber_paths.append(
                utils.ensure_folder(
                    Path("_".join(["_", self.name, Path(config["name"]).name, "_"])),
                    self.final_path / fiber_folder(i, self.name, config["name"]),
                    mkdir=False,
                    prevent_overwrite=not self.overwrite,
                )
            )
            self.__validate_variable(config)
        self.__compute_sim_dirs()
        [Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list]
        self.num_sim = len(self.all_configs_list[-1])
            self.__validate_variable(vary_dict)
            self.num_fibers += 1
            Evaluator.evaluate_default(
                self.__build_base_config() | config | {k: v[0] for k, v in vary_dict.items()}, True
            )
        self.num_sim = self.variationer.var_num()
        self.total_num_steps = sum(
            config["z_num"] * len(self.all_configs_list[i])
            for i, config in enumerate(self.master_configs)
            config["z_num"] * self.variationer.var_num(i)
            for i, config in enumerate(self.fiber_configs)
        )
        self.final_sim_dir = utils.ensure_folder(
            Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite
        )
        self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default)
        self.parallel = self.master_config.get("parallel", Parameters.parallel.default)

    def __validate_variable(self, config: dict[str, Any]):
        for k, v in config.get("variable", {}).items():
    def __build_base_config(self):
        cfg = self.master_config.copy()
        vary = cfg.pop("variable", {})
        return cfg | {k: v[0] for k, v in vary.items()}

    def __validate_variable(self, vary_dict: dict[str, list]):
        for k, v in vary_dict.items():
            p = getattr(Parameters, k)
            validator_list(p.validator)("variable " + k, v)
            if k not in VALID_VARIABLE:
@@ -863,76 +903,47 @@ class Configuration:
            if len(v) == 0:
                raise ValueError(f"variable parameter {k!r} must not be empty")

    def __compute_sim_dirs(self):
        self.all_configs_dict = {}
        self.all_configs_list = []
        self.master_configs[0]["variable"]["num"] = list(
            range(self.master_configs[0].get("repeat", 1))
        )
        dp = DataPather([c["variable"] for c in self.master_configs])
        for i, conf in enumerate(self.master_configs):
            self.all_configs_list.append([])
            for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i):
                this_conf = conf.copy()
                if i > 0:
                    prev_path = utils.ensure_folder(
                        self.sim_dirs[i - 1] / prev_path, not self.overwrite, False
                    )
                    this_conf["prev_data_dir"] = str(prev_path)
    def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
        for i in range(self.num_fibers):
            yield from self.iterate_single_fiber(i)

                this_path = utils.ensure_folder(
                    self.sim_dirs[i] / this_path, not self.overwrite, False
                )
                this_conf.pop("variable")
                conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
                self.all_configs_dict[sim_index] = self.__SimConfig(
                    this_vary, conf_to_use, this_path, sim_index
                )
                self.all_configs_list[i].append(self.all_configs_dict[sim_index])

    def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
        for i, sim_config_list in enumerate(self.all_configs_list):
            for sim_config, params in self.__iter_1_sim(sim_config_list):
                fiber_map = []
                for j in range(i + 1):
                    this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config
                    if j > 0:
                        prev_conf = self.all_configs_dict[sim_config.index[:j]].config
                        length = prev_conf["length"] + fiber_map[j - 1][0]
                    else:
                        length = 0.0
                    fiber_map.append((length, this_conf["name"]))
                params.output_path = str(sim_config.output_path)
                params.fiber_map = fiber_map
                yield sim_config.vary_list, params

    def __iter_1_sim(
        self, configs: list["Configuration.__SimConfig"]
    ) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]:
    def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
        """iterates through the parameters of only one fiber. It takes care of recovering partially
        completed simulations, skipping complete ones and waiting for the previous fiber to finish

        Parameters
        ----------
        configs : list[__SimConfig]
            list of configuration obj
        index : int
            which fiber to iterate over

        Yields
        -------
        __SimConfig
            configuration obj
        Parameters
            computed Parameters obj
        """
        sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs}
        if index < 0:
            index = self.num_fibers + index
        sim_dict: dict[Path, Configuration.__SimConfig] = {}
        for descriptor in self.variationer.iterate(index):
            cfg = descriptor.update_config(self.fiber_configs[index])
            if index > 0:
                cfg["prev_data_dir"] = str(
                    self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
                )
            p = utils.ensure_folder(
                self.fiber_paths[index] / descriptor.formatted_descriptor(True),
                not self.overwrite,
                False,
            )
            cfg["output_path"] = str(p)
            sim_config = self.__SimConfig(descriptor, cfg, p)
            sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
        while len(sim_dict) > 0:
            for data_dir, sim_config in sim_dict.items():
                task, config_dict = self.__decide(sim_config)
                if task == self.Action.RUN:
                    sim_dict.pop(data_dir)
                    p = Parameters(**config_dict)
                    p.compute()
                    yield sim_config, p
                    yield sim_config.descriptor, Parameters(**sim_config.config)
                    if "recovery_last_stored" in config_dict and self.skip_callback is not None:
                        self.skip_callback(config_dict["recovery_last_stored"])
                    break
@@ -957,12 +968,14 @@ class Configuration:

        Returns
        -------
        str : {'run', 'wait', 'skip'}
        str : Configuration.Action
            what to do
        config_dict : dict[str, Any]
            config dictionary. The only key possibly modified is 'prev_data_dir', which
            gets set if the simulation is partially completed
        """
        if not self.wait:
            return self.Action.RUN, sim_config.config
        out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
        if out_status == self.State.COMPLETE:
            return self.Action.SKIP, sim_config.config
@@ -999,7 +1012,7 @@ class Configuration:
        num = utils.find_last_spectrum_num(data_dir)
        if config_dict is None:
            try:
                config_dict = utils.open_config(data_dir / PARAM_FN)
                config_dict = utils._open_config(data_dir / PARAM_FN)
            except FileNotFoundError:
                self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}")
                return self.State.ABSENT, 0
@@ -1013,9 +1026,12 @@ class Configuration:
            raise ValueError(f"Too many spectra in {data_dir}")

    def save_parameters(self):
        for config, sim_dir in zip(self.master_configs, self.sim_dirs):
            os.makedirs(sim_dir, exist_ok=True)
            utils.save_toml(sim_dir / f"initial_config.toml", config)
        os.makedirs(self.final_path, exist_ok=True)
        cfgs = [
            cfg | dict(variable=self.variationer.all_dicts[i])
            for i, cfg in enumerate(self.fiber_configs)
        ]
        utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs))

    @property
    def first(self) -> Parameters:
@@ -1023,336 +1039,6 @@ class Configuration:
        return param


class DataPather:
    def __init__(self, dl: list[dict[str, Any]]):
        self.dict_list = dl

    def vary_list_iterator(
        self, index: int
    ) -> Generator[tuple[tuple[tuple[int, ...]], list[list[tuple[str, Any]]]], None, None]:
        """iterates through every possible combination of a list of dicts of lists

        Parameters
        ----------
        index : int
            up to where in the stored dict_list to go

        Yields
        -------
        list[list[tuple[str, Any]]]
            list of lists of (key, value) pairs

        Example
        -------

        self.dict_list = [{a:[56, 57], b:["?", "!"]}, {c:[0, -1]}] ->
        [
            [[(a, 56), (b, "?")], [(c, 0)]],
            [[(a, 56), (b, "?")], [(c, 1)]],
            [[(a, 56), (b, "!")], [(c, 0)]],
            [[(a, 56), (b, "!")], [(c, 1)]],
            [[(a, 57), (b, "?")], [(c, 0)]],
            [[(a, 57), (b, "?")], [(c, 1)]],
            [[(a, 57), (b, "!")], [(c, 0)]],
            [[(a, 57), (b, "!")], [(c, 1)]],
        ]
        """
        if index < 0:
            index = len(self.dict_list) - index
        d_tem_list = [el for d in self.dict_list[: index + 1] for el in d.items()]
        dict_pos = np.cumsum([0] + [len(d) for d in self.dict_list[: index + 1]])
        ranges = [range(len(l)) for _, l in d_tem_list]

        for r in itertools.product(*ranges):
            flat = [(d_tem_list[i][0], d_tem_list[i][1][j]) for i, j in enumerate(r)]
            pos = tuple(r)
            out = [flat[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:])]
            pos = tuple(pos[left:right] for left, right in zip(dict_pos[:-1], dict_pos[1:]))
            yield pos, out

    def all_vary_list(self, index):
        for sim_index, l in self.vary_list_iterator(index):
            unique_vary: list[tuple[str, Any]] = []
            for ll in l[: index + 1]:
                for pname, pval in ll:
                    for i, (pn, _) in enumerate(unique_vary):
                        if pn == pname:
                            del unique_vary[i]
                            break
                    unique_vary.append((pname, pval))
            yield sim_index, format_variable_list(
                reduce_all_variable(l[:index]), add_iden=True
            ), format_variable_list(reduce_all_variable(l), add_iden=True), unique_vary

    def __repr__(self):
        return f"DataPather([{', '.join(repr(d) for d in self.dict_list)}])"


@dataclass(frozen=True)
class PlotRange:
    left: float = Parameter(type_checker(int, float))
    right: float = Parameter(type_checker(int, float))
    unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit)
    conserved_quantity: bool = Parameter(boolean, default=True)

    def __post_init__(self):
        if self.left >= self.right:
            raise ValueError(
                f"left value {self.left!r} must be strictly smaller than right value {self.right!r}"
            )

    def __str__(self):
        return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"

    def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
        return sort_axis(axis, self)

    def __iter__(self):
        yield self.left
        yield self.right
        yield self.unit.__name__


def sort_axis(
    axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
    """converts an array according to the given range

    Parameters
    ----------
    axis : np.ndarray, shape (n,)
        array
    plt_range : PlotRange
        range to crop to

    Returns
    -------
    np.ndarray
        new array converted to the desired unit and cropped to the given range
    np.ndarray
        indices of the conserved values
    tuple[float, float]
        actual minimum and maximum of the new axis

    Example
    -------
    >>> sort_axis([18.0, 19.0, 20.0, 13.0, 15.2], PlotRange(1400, 1900, "cm"))
    ([1520.0, 1800.0, 1900.0], [4, 0, 1], (1520.0, 1900.0))
    """
    if isinstance(plt_range, tuple):
        plt_range = PlotRange(*plt_range)

    masked = np.ma.array(axis, mask=~np.isfinite(axis))
    converted = plt_range.unit.inv(masked)
    converted[(converted < plt_range.left) | (converted > plt_range.right)] = np.ma.masked
    indices = np.arange(len(axis))[~converted.mask]
    cropped = converted.compressed()
    order = cropped.argsort()

    return cropped[order], indices[order], (cropped.min(), cropped.max())


def get_arg_names(func: Callable) -> list[str]:
    spec = inspect.getfullargspec(func)
    args = spec.args
    if spec.defaults is not None and len(spec.defaults) > 0:
        args = args[: -len(spec.defaults)]
    return args


def validate_arg_names(names: list[str]):
    for n in names:
        if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
            raise ValueError(f"{n} is an invalid parameter name")


def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
    if arg_names is None:
        arg_names = get_arg_names(func)
    else:
        validate_arg_names(arg_names)
    validate_arg_names(kwarg_names)
    sign_arg_str = ", ".join(arg_names + kwarg_names)
    call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
    tmp_name = f"{func.__name__}_0"
    func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
    scope = dict(__func__=func)
    exec(func_str, scope)
    out_func = scope[tmp_name]
    out_func.__module__ = "evaluator"
    return out_func
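
func_rewrite builds a thin wrapper whose signature names the requested keyword arguments explicitly, so they can be passed positionally; a small sketch (function invented):

def area(radius, pi=3.14159):
    return pi * radius**2

f = func_rewrite(area, ["pi"])  # wrapper signature: f(radius, pi)
f(2.0, 3.0)                     # 12.0 -- pi is forwarded as a keyword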


@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
    if not (isinstance(num_args, int) and isinstance(num_returns, int)):
        raise TypeError("num_args and num_returns must be int")
    arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
    return_str = ", ".join("True" for _ in range(num_returns))
    func_name = f"__mock_{num_args}_{num_returns}"
    func_str = f"def {func_name}({arg_str}):\n return {return_str}"
    scope = {}
    exec(func_str, scope)
    out_func = scope[func_name]
    out_func.__module__ = "evaluator"
    return out_func
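
_mock_function generates placeholder callables of a given arity, e.g.:

mock = _mock_function(2, 3)   # defines __mock_2_3(a, aa)
mock(1, 2)                    # -> (True, True, True)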


def format_variable_list(l: list[tuple[str, Any]], add_iden=False) -> str:
    """formats a variable list into a str such that each simulation has a unique
    directory name. A u_XXX unique identifier and a b_XXX branch identifier
    (ignoring repeat simulations) are added at the beginning.

    Parameters
    ----------
    l : list[tuple[str, Any]]
        list of variable parameters
    add_iden : bool
        add unique simulation and parameter-set identifiers

    Returns
    -------
    str
        directory name
    """
    str_list = []
    for p_name, p_value in l:
        ps = p_name.replace("/", "").replace(PARAM_SEPARATOR, "")
        vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
        str_list.append(ps + PARAM_SEPARATOR + vs)
    tmp_name = PARAM_SEPARATOR.join(str_list)
    if not add_iden:
        return tmp_name
    unique_id = unique_identifier(l)
    branch_id = branch_identifier(l)
    return unique_id + PARAM_SEPARATOR + branch_id + PARAM_SEPARATOR + tmp_name


def branch_identifier(l):
    branch_id = "b_" + utils.to_62(hash(str([el for el in l if el[0] != "num"])))
    return branch_id


def unique_identifier(l):
    unique_id = "u_" + utils.to_62(hash(str(l)))
    return unique_id


def format_value(name: str, value) -> str:
    if value is True or value is False:
        return str(value)
    elif isinstance(value, (float, int)):
        try:
            return getattr(Parameters, name).display(value)
        except AttributeError:
            return format(value, ".9g")
    elif isinstance(value, (list, tuple, np.ndarray)):
        return "-".join([str(v) for v in value])
    elif isinstance(value, str):
        p = Path(value)
        if p.exists():
            return p.stem
    return str(value)


def pretty_format_value(name: str, value) -> str:
    try:
        return getattr(Parameters, name).display(value)
    except AttributeError:
        return name + PARAM_SEPARATOR + str(value)


def pretty_format_from_sim_name(name: str) -> str:
    """formats a pretty version of a simulation directory name

    Parameters
    ----------
    name : str
        name of the simulation (directory name)

    Returns
    -------
    str
        prettier name
    """
    s = name.split(PARAM_SEPARATOR)
    out = []
    for key, value in zip(s[::2], s[1::2]):
        try:
            out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
        except (AttributeError, ValueError):
            out.append(key + PARAM_SEPARATOR + value)
    return PARAM_SEPARATOR.join(out)


def variable_iterator(
    config: dict[str, Any], first: bool
) -> Generator[tuple[list[tuple[str, Any]], dict[str, Any]], None, None]:
    """given a config with "variable" parameters, iterates through every possible combination,
    yielding a list of (parameter_name, value) tuples and a full config dictionary.

    Parameters
    ----------
    config : BareConfig
        initial config obj
    first : bool
        whether it is the first fiber or not (only the first fiber gets a sim number)

    Yields
    -------
    Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]
        variable_list : a list of (name, value) tuples of parameter names and values that are variable

        params : a dict[str, Any] to be fed to Parameters
    """
    possible_keys = []
    possible_ranges = []

    for key, values in config.get("variable", {}).items():
        possible_keys.append(key)
        possible_ranges.append(range(len(values)))

    combinations = itertools.product(*possible_ranges)

    master_index = 0
    repeat = config.get("repeat", 1) if first else 1
    for combination in combinations:
        indiv_config = {}
        variable_list = []
        for i, key in enumerate(possible_keys):
            parameter_value = config["variable"][key][combination[i]]
            indiv_config[key] = parameter_value
            variable_list.append((key, parameter_value))
        param_dict = deepcopy(config)
        param_dict.pop("variable")
        param_dict.update(indiv_config)
        for repeat_index in range(repeat):
            # variable_ind = [("id", master_index)] + variable_list
            variable_ind = variable_list
            if first:
                variable_ind += [("num", repeat_index)]
            yield variable_ind, param_dict
            master_index += 1


def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]:
    out = []
    for n, variable_list in enumerate(all_variable):
        out += [("fiber", "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)), *variable_list]
    return out
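
reduce_all_variable prefixes each fiber's variable list with a fiber label (A, B, ... then AA, BB after Z), e.g.:

reduce_all_variable([[("length", 1.0)], [("length", 2.0)]])
# -> [("fiber", "A"), ("length", 1.0), ("fiber", "B"), ("length", 2.0)]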


def strip_vary_list(all_variable: T) -> T:
    if len(all_variable) == 0:
        return all_variable
    elif isinstance(all_variable[0], Sequence) and (
        len(all_variable[0]) == 0 or not isinstance(all_variable[0][0], str)
    ):
        return [strip_vary_list(el) for el in all_variable]
    else:
        return [el for el in all_variable if el[0] != "num"]


default_rules: list[Rule] = [
    # Grid
    *Rule.deduce(
@@ -1417,6 +1103,7 @@ default_rules: list[Rule] = [
        priorities=[2, 2, 2],
    ),
    Rule("hr_w", fiber.delayed_raman_w),
    Rule("n_gas_2", materials.n_gas_2),
    Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
    Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
    Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
@@ -1426,12 +1113,13 @@ default_rules: list[Rule] = [
        ["wl_for_disp", "pitch", "pitch_ratio"],
        conditions=dict(model="pcf"),
    ),
    Rule("capillary_spacing", fiber.HCARF_gap),
    Rule("capillary_spacing", fiber.capillary_spacing_hasan),
    # Fiber nonlinearity
    Rule("A_eff", fiber.A_eff_from_V),
    Rule("A_eff", fiber.A_eff_from_diam),
    Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
    Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
    Rule("A_eff", fiber.A_eff_marcatili, priorities=-2),
    Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
    Rule("A_eff_arr", fiber.load_custom_A_eff),
    Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
src/scgenerator/_utils/pbar.py | 189 (new file)
@@ -0,0 +1,189 @@
import multiprocessing
import os
import random
import threading
import typing
from collections import abc
from io import StringIO
from pathlib import Path
from typing import Iterable, Union

from tqdm import tqdm

from ..env import pbar_policy

T_ = typing.TypeVar("T_")


class PBars:
    def __init__(
        self,
        task: Union[int, Iterable[T_]],
        desc: str,
        num_sub_bars: int = 0,
        head_kwargs=None,
        worker_kwargs=None,
    ) -> "PBars":
        """creates a PBars obj

        Parameters
        ----------
        task : int | Iterable
            if int : total length of the main task
            if Iterable : behaves like tqdm
        desc : str
            description of the main task
        num_sub_bars : int
            number of sub-tasks
        """
        self.id = random.randint(100000, 999999)
        try:
            self.width = os.get_terminal_size().columns
        except OSError:
            self.width = 80
        if isinstance(task, abc.Iterable):
            self.iterator: Iterable[T_] = iter(task)
            self.num_tot: int = len(task)
        else:
            self.num_tot: int = task
            self.iterator = None

        self.policy = pbar_policy()
        if head_kwargs is None:
            head_kwargs = dict()
        if worker_kwargs is None:
            worker_kwargs = dict(
                total=1,
                desc="Worker {worker_id}",
                bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
            )
        if "print" not in pbar_policy():
            head_kwargs["file"] = worker_kwargs["file"] = StringIO()
            self.width = 80
        head_kwargs["desc"] = desc
        self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)]
        for i in range(1, num_sub_bars + 1):
            kwargs = {k: v for k, v in worker_kwargs.items()}
            if "desc" in kwargs:
                kwargs["desc"] = kwargs["desc"].format(worker_id=i)
            self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs))
        self.print_path = Path(
            f"progress {self.pbars[0].desc.replace('/', '')} {self.id}"
        ).resolve()
        self.close_ev = threading.Event()
        if "file" in self.policy:
            self.thread = threading.Thread(target=self.print_worker, daemon=True)
            self.thread.start()

    def print(self):
        if "file" not in self.policy:
            return
        s = []
        for pbar in self.pbars:
            s.append(str(pbar))
        self.print_path.write_text("\n".join(s))

    def print_worker(self):
        while True:
            if self.close_ev.wait(2.0):
                return
            self.print()

    def __iter__(self):
        with self as pb:
            for thing in self.iterator:
                yield thing
                pb.update()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __getitem__(self, key):
        return self.pbars[key]

    def update(self, i=None, value=1):
        if i is None:
            for pbar in self.pbars[1:]:
                pbar.update(value)
        elif i > 0:
            self.pbars[i].update(value)
        self.pbars[0].update()

    def append(self, pbar: tqdm):
        self.pbars.append(pbar)

    def reset(self, i):
        self.pbars[i].update(-self.pbars[i].n)
        self.print()

    def close(self):
        self.print()
        self.close_ev.set()
        if "file" in self.policy:
            self.thread.join()
        for pbar in self.pbars:
            pbar.close()
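
A minimal usage sketch of PBars, in both supported forms (exact head-bar behaviour depends on the update() reconstruction above):

for item in PBars(range(100), "processing"):   # wraps a sized iterable like tqdm
    ...

pb = PBars(1000, "simulating", num_sub_bars=2)
pb.update(1)   # advance worker bar 1 by one step; the head bar advances too
pb.close()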


class ProgressBarActor:
    def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
        self.counters = [0 for _ in range(num_workers + 1)]
        self.p_bars = PBars(
            num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
        )

    def update(self, worker_id: int, rel_pos: float = None) -> None:
        """updates a counter

        Parameters
        ----------
        worker_id : int
            id of the worker. 0 is the overall progress
        rel_pos : float, optional
            if None, increases the counter by one; if set, sets
            the counter to the specified value (instead of incrementing it), by default None
        """
        if rel_pos is None:
            self.counters[worker_id] += 1
        else:
            self.counters[worker_id] = rel_pos

    def update_pbars(self):
        for counter, pbar in zip(self.counters, self.p_bars.pbars):
            pbar.update(counter - pbar.n)

    def close(self):
        self.p_bars.close()


def progress_worker(
    name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
):
    """keeps track of progress on a separate thread

    Parameters
    ----------
    num_steps : int
        total number of steps, used for the main progress bar (position 0)
    progress_queue : multiprocessing.Queue
        values are either
            Literal[0] : stop the worker and close the progress bars
            tuple[int, float] : worker id and relative progress between 0 and 1
    """
    with PBars(
        num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
    ) as pbars:
        while True:
            raw = progress_queue.get()
            if raw == 0:
                return
            i, rel_pos = raw
            if i > 0:
                pbars[i].update(rel_pos - pbars[i].n)
                pbars[0].update()
            elif i == 0:
                pbars[0].update(rel_pos)
src/scgenerator/_utils/utils.py | 260 (new file)
@@ -0,0 +1,260 @@
import inspect
import os
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, Iterator, Set

import numpy as np
import toml
from pydantic import BaseModel

from .._utils import load_toml, save_toml
from ..const import PARAM_FN, PARAM_SEPARATOR, Z_FN
from ..physics.units import get_unit


class HashableBaseModel(BaseModel):
    """Pydantic BaseModel that's immutable and can be hashed"""

    def __hash__(self) -> int:
        return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())

    class Config:
        allow_mutation = False


def to_62(i: int) -> str:
    arr = []
    if i == 0:
        return "0"
    i = abs(i)
    while i:
        i, value = divmod(i, 62)
        arr.append(str_printable[value])
    return "".join(reversed(arr))


class PlotRange(HashableBaseModel):
    left: float
    right: float
    unit: Callable[[float], float]
    conserved_quantity: bool = True

    def __init__(self, left, right, unit, **kwargs):
        super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)

    def __str__(self):
        return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"

    def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
        return sort_axis(axis, self)

    def __iter__(self):
        yield self.left
        yield self.right
        yield self.unit.__name__


def sort_axis(
    axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
    """
    given an axis, returns this axis cropped according to the given range, converted and sorted

    Parameters
    ----------
    axis : 1D array containing the original axis (usually the w or t array)
    plt_range : tuple (min, max, conversion_function) used to crop the axis

    Returns
    -------
    cropped : the axis cropped, converted and sorted
    indices : indices to use to slice and sort other arrays in the same fashion
    extent : tuple with min and max of cropped

    Example
    -------
    w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
    t = np.linspace(-10, 10, 400)
    W, T = np.meshgrid(w, t)
    y = np.exp(-W**2 - T**2)

    # Define ranges
    rw = (-4, 4, s)
    rt = (-2, 6, s)

    w, cw = sort_axis(w, rw)
    t, ct = sort_axis(t, rt)

    # slice y according to the given ranges
    y = y[ct][:, cw]
    """
    if isinstance(plt_range, tuple):
        plt_range = PlotRange(*plt_range)
    r = np.array((plt_range.left, plt_range.right), dtype="float")

    indices = np.arange(len(axis))[
        (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
    ]
    cropped = axis[indices]
    order = np.argsort(plt_range.unit.inv(cropped))
    indices = indices[order]
    cropped = cropped[order]
    out_ax = plt_range.unit.inv(cropped)

    return out_ax, indices, (out_ax[0], out_ax[-1])


def get_arg_names(func: Callable) -> list[str]:
    # spec = inspect.getfullargspec(func)
    # args = spec.args
    # if spec.defaults is not None and len(spec.defaults) > 0:
    #     args = args[: -len(spec.defaults)]
    # return args
    return [k for k, v in inspect.signature(func).parameters.items() if v.default is inspect._empty]


def validate_arg_names(names: list[str]):
    for n in names:
        if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
            raise ValueError(f"{n} is an invalid parameter name")


def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
    if arg_names is None:
        arg_names = get_arg_names(func)
    else:
        validate_arg_names(arg_names)
    validate_arg_names(kwarg_names)
    sign_arg_str = ", ".join(arg_names + kwarg_names)
    call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
    tmp_name = f"{func.__name__}_0"
    func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
    scope = dict(__func__=func)
    exec(func_str, scope)
    out_func = scope[tmp_name]
    out_func.__module__ = "evaluator"
return out_func
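
# A sketch of what func_rewrite produces (with a hypothetical function `gamma`):
#   def gamma(radius, wavelength): ...
#   g = func_rewrite(gamma, ["wavelength"], ["radius"])
# `g` is a generated wrapper with signature (radius, wavelength) whose body
# calls gamma(radius, wavelength=wavelength), i.e. the kwarg_names are always
# forwarded by keyword.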


@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
    arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
    return_str = ", ".join("True" for _ in range(num_returns))
    func_name = f"__mock_{num_args}_{num_returns}"
    func_str = f"def {func_name}({arg_str}):\n return {return_str}"
    scope = {}
    exec(func_str, scope)
    out_func = scope[func_name]
    out_func.__module__ = "evaluator"
return out_func
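
# Example of a generated mock (argument names are "a", "aa", ... by design):
#   f = _mock_function(2, 2)   # def __mock_2_2(a, aa): return True, True
#   f(1, 2)                    -> (True, True)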


def combine_simulations(path: Path, dest: Path = None):
    """combines raw simulations into one folder per branch

    Parameters
    ----------
    path : Path
        source of the simulations (must contain u_xx directories)
    dest : Path, optional
        if given, moves the simulations to dest, by default None
    """
    paths: dict[str, list[Path]] = defaultdict(list)
    if dest is None:
        dest = path

    for p in path.glob("u_*b_*"):
        if p.is_dir():
            paths[p.name.split()[1]].append(p)
    for l in paths.values():
        l.sort(key=lambda el: int(re.search(r"(?<=num )[0-9]+", el.name)[0]))
    for pulses in paths.values():
        new_path = dest / update_path(pulses[0].name)
        os.makedirs(new_path, exist_ok=True)
        for num, pulse in enumerate(pulses):
            params_ok = False
            for file in pulse.glob("*"):
                if file.name == PARAM_FN:
                    if not params_ok:
                        update_params(new_path, file)
                        params_ok = True
                    else:
                        file.unlink()
                elif file.name == Z_FN:
                    file.rename(new_path / file.name)
                elif file.name.startswith("spectr") and num == 0:
                    file.rename(new_path / file.name)
                else:
                    file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
            pulse.rmdir()


def update_params(new_path: Path, file: Path):
    params = load_toml(file)
    if (p := params.get("prev_data_dir")) is not None:
        p = Path(p)
        params["prev_data_dir"] = str(p.parent / update_path(p.name))
    params["output_path"] = str(new_path)
    save_toml(new_path / PARAM_FN, params)
    file.unlink()


def save_parameters(
    params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
    """saves a parameter dictionary. Note that it does remove some entries, particularly
    those that take a lot of space ("t", "w", ...)

    Parameters
    ----------
    params : dict[str, Any]
        dictionary to save
    destination_dir : Path
        destination directory

    Returns
    -------
    Path
        path to the newly created parameter file
    """
    file_path = destination_dir / file_name
    os.makedirs(file_path.parent, exist_ok=True)

    # save toml of the simulation
    with open(file_path, "w") as file:
        toml.dump(params, file, encoder=toml.TomlNumpyEncoder())

    return file_path


def update_path(p: str) -> str:
return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)
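
# Example: unique identifiers and repeat numbers are stripped from a folder name:
#   update_path("u_3 b_1 wavelength 8e-07 num 2") -> "b_1 wavelength 8e-07"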


def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
    return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])


def simulations_list(path: os.PathLike) -> list[Path]:
    """finds simulation folders contained in a parent directory

    Parameters
    ----------
    path : os.PathLike
        parent path

    Returns
    -------
    list[Path]
        absolute paths to the simulation folders
    """
    paths: list[Path] = []
    for pwd, _, files in os.walk(path):
        if PARAM_FN in files:
            paths.append(Path(pwd))
    paths.sort(key=lambda el: el.parent.name)
    return [p for p in paths if p.parent.name == paths[-1].parent.name]
321
src/scgenerator/_utils/variationer.py
Normal file
@@ -0,0 +1,321 @@
from math import prod
import itertools
from collections.abc import MutableMapping, Sequence
from pathlib import Path
from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union

import numpy as np
from pydantic import validator
from pydantic.main import BaseModel

from ..const import PARAM_SEPARATOR
from . import utils

T = TypeVar("T")


class VariationSpecsError(ValueError):
    pass


class Variationer:
    """
    manages possible combinations of values given dicts of lists

    Example
    -------
    >>> var = Variationer([dict(a=[1, 2]), [dict(b=["000", "111"], c=["a", "-1"])]])
    >>> list(v.raw_descr for v in var.iterate())
    [
        ((("a", 1),), (("b", "000"), ("c", "a"))),
        ((("a", 1),), (("b", "111"), ("c", "-1"))),
        ((("a", 2),), (("b", "000"), ("c", "a"))),
        ((("a", 2),), (("b", "111"), ("c", "-1"))),
    ]
    """

    all_indices: list[list[int]]
    all_dicts: list[list[dict[str, list]]]

    def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
        self.all_indices = []
        self.all_dicts = []
        if variables is not None:
            for el in variables:
                self.append(el)

    def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
        """append a list of variable parameter sets
        each call to append creates a new group of parameters

        Parameters
        ----------
        var_list : Union[list[MutableMapping], MutableMapping]
            each dict in the list is treated as an independent parameter
            this means that if for one dict, len > 1, the lists of possible values
            must be the same length

        Example
        -------
        >>> append([dict(wavelength=[800e-9, 900e-9], power=[1e3, 2e3]), dict(length=[3e-2, 3.5e-2, 4e-2])])

        means that for every parameter variation, wavelength=800e-9 will always occur when power=1e3 and
        vice versa, while length is free to vary independently

        Raises
        ------
        VariationSpecsError
            raised when the lists of possible values in the same dict are not the same length
        """
        if not isinstance(var_list, Sequence):
            var_list = [{k: v} for k, v in var_list.items()]
        else:
            var_list = list(var_list)
        num_vars = []
        for d in var_list:
            values = list(d.values())
            len_to_test = len(values[0])
            if not all(len(v) == len_to_test for v in values[1:]):
                raise VariationSpecsError(
                    "variable items should all have the same number of parameters"
                )
            num_vars.append(len_to_test)
        if len(num_vars) == 0:
            num_vars = [1]
        self.all_indices.append(num_vars)
        self.all_dicts.append(var_list)

    def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
        index = self.__index(index)
        flattened_indices = sum(self.all_indices[: index + 1], [])
        index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
        ranges = [range(i) for i in flattened_indices]
        for r in itertools.product(*ranges):
            out: list[list[tuple[str, Any]]] = []
            indices: list[list[int]] = []
            for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
                out.append([])
                indices.append([])
                for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
                    for k, v in var_d.items():
                        out[-1].append((k, v[value_index]))
                        indices[-1].append(value_index)
            yield VariationDescriptor(raw_descr=out, index=indices)

    def __index(self, index: int) -> int:
        if index < 0:
            index = len(self.all_indices) + index
        return index

    def var_num(self, index: int = -1) -> int:
        index = self.__index(index)
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
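
    # For the example in the class docstring, var_num() evaluates to
    # 2 * 2 = 4: the two values of `a` times the two linked value sets
    # of (`b`, `c`).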


class VariationDescriptor(BaseModel):
    raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
    index: tuple[tuple[int, ...], ...]
    separator: str = "fiber"
    _format_registry: dict[str, Callable[..., str]] = {}
    __ids: dict[int, int] = {}

    @classmethod
    def register_formatter(cls, p_name: str, func: Callable[..., str]):
        """register a function that formats a particular parameter

        Parameters
        ----------
        p_name : str
            name of the parameter
        func : Callable[..., str]
            function that takes as single argument the value of the parameter and returns a string
        """
        cls._format_registry[p_name] = func

    class Config:
        allow_mutation = False

    def formatted_descriptor(self, add_identifier=False) -> str:
        """formats a variable list into a str such that each simulation has a unique
        directory name. A u_XXX unique identifier and a b_XXX branch identifier
        (ignoring repeat simulations) can be added at the beginning.

        Parameters
        ----------
        add_identifier : bool
            add unique simulation and parameter-set identifiers

        Returns
        -------
        str
            simulation descriptor
        """
        str_list = []

        for p_name, p_value in self.flat:
            ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
            vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
            str_list.append(ps + PARAM_SEPARATOR + vs)
        tmp_name = PARAM_SEPARATOR.join(str_list)
        if not add_identifier:
            return tmp_name
        return (
            self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
        )

    def format_value(self, name: str, value) -> str:
        if value is True or value is False:
            return str(value)
        elif isinstance(value, (float, int)):
            try:
                return self._format_registry[name](value)
            except KeyError:
                return format(value, ".9g")
        elif isinstance(value, (list, tuple, np.ndarray)):
            return "-".join([str(v) for v in value])
        elif isinstance(value, str):
            p = Path(value)
            if p.exists():
                return p.stem
return str(value)
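
    # Formatting sketch: booleans -> "True"/"False"; numbers go through a
    # registered formatter when available, otherwise format(value, ".9g")
    # (e.g. 8e-07 -> "8e-07"); sequences are joined with "-"; strings that
    # point to an existing file are reduced to their stem.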

    def __getitem__(self, key) -> "VariationDescriptor":
        return VariationDescriptor(
            raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
        )

    def __str__(self) -> str:
        return self.formatted_descriptor(add_identifier=False)

    def __lt__(self, other: "VariationDescriptor") -> bool:
        return self.raw_descr < other.raw_descr

    def __le__(self, other: "VariationDescriptor") -> bool:
        return self.raw_descr <= other.raw_descr

    def __gt__(self, other: "VariationDescriptor") -> bool:
        return self.raw_descr > other.raw_descr

    def __ge__(self, other: "VariationDescriptor") -> bool:
        return self.raw_descr >= other.raw_descr

    def __eq__(self, other: "VariationDescriptor") -> bool:
        return self.raw_descr == other.raw_descr

    def __hash__(self) -> int:
        return hash(self.raw_descr)

    def __contains__(self, other: "VariationDescriptor") -> bool:
        return all(el in self.raw_descr for el in other.raw_descr)

    def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
        """updates a dictionary with the values of the descriptor

        Parameters
        ----------
        cfg : dict[str, Any]
            dict to be updated
        index : int, optional
            index of the fiber from which to apply the parameters, by default -1

        Returns
        -------
        dict[str, Any]
            same as cfg but with the keys from the descriptor added/updated
        """
        out_cfg = cfg.copy()
        out_cfg.pop("variable", None)
        return out_cfg | {k: v for k, v in self.raw_descr[index]}

    def iter_parents(self) -> Iterator["VariationDescriptor"]:
        if (p := self.parent) is not None:
            yield from p.iter_parents()
        yield self

    @property
    def flat(self) -> list[tuple[str, Any]]:
        out = []
        for n, variable_list in enumerate(self.raw_descr):
            out += [
                (self.separator, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)),
                *variable_list,
            ]
        return out

    @property
    def branch(self) -> "BranchDescriptor":
        descr: list[list[tuple[str, Any]]] = []
        ind: list[list[int]] = []
        for i, l in enumerate(self.raw_descr):
            descr.append([])
            ind.append([])
            for j, (k, v) in enumerate(l):
                if k != "num":
                    descr[-1].append((k, v))
                    ind[-1].append(self.index[i][j])
        return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)

    @property
    def identifier(self) -> str:
        unique_id = hash(str(self.flat))
        self.__ids.setdefault(unique_id, len(self.__ids))
        return "u_" + str(self.__ids[unique_id])

    @property
    def parent(self) -> Optional["VariationDescriptor"]:
        if len(self.raw_descr) < 2:
            return None
        return VariationDescriptor(
            raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
        )

    @property
    def num_fibers(self) -> int:
        return len(self.raw_descr)


class BranchDescriptor(VariationDescriptor):
    __ids: dict[int, int] = {}

    @property
    def identifier(self) -> str:
        branch_id = hash(str(self.flat))
        self.__ids.setdefault(branch_id, len(self.__ids))
        return "b_" + str(self.__ids[branch_id])

    @validator("raw_descr")
    def validate_raw_descr(cls, v):
        return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)


class DescriptorDict(Generic[T]):
    def __init__(self, dico: dict[VariationDescriptor, T] = None):
        self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
        if dico is not None:
            for k, v in dico.items():
                self[k] = v

    def __setitem__(self, key: VariationDescriptor, value: T):
        if not isinstance(key, VariationDescriptor):
            raise TypeError("key must be a VariationDescriptor instance")
        self.dico[key.raw_descr] = (key, value)

    def __getitem__(
        self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
    ) -> T:
        if isinstance(key, VariationDescriptor):
            return self.dico[key.raw_descr][1]
        else:
            return self.dico[key][1]

    def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
        for descriptor, value in self.dico.values():
            yield descriptor, value

    def keys(self) -> list[VariationDescriptor]:
        return [v[0] for v in self.dico.values()]

    def values(self) -> list[T]:
return [v[1] for v in self.dico.values()]
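
# A hedged usage sketch (descriptors as produced by Variationer.iterate()):
#   d = DescriptorDict()
#   for descr in Variationer([dict(a=[1, 2])]).iterate():
#       d[descr] = str(descr)
#   d.keys()   # the VariationDescriptor objects themselves
#   d[descr]   # lookup works with a descriptor or with its raw_descr tuple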

@@ -1,14 +1,13 @@
import argparse
import os
import re
import subprocess
import sys
from collections import ChainMap
from pathlib import Path

import numpy as np

from .. import const, env, scripts, utils
from .. import const, env, scripts
from .. import _utils as utils
from ..logger import get_logger
from ..physics.fiber import dispersion_coefficients
from ..physics.simulate import SequencialSimulations, run_simulation

@@ -20,7 +20,8 @@ def pbar_format(worker_id: int):


SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy"
SPECN_FN1 = "spectra_{}.npy"
SPEC1_FN_N = "spectrum_{}_{}.npy"
Z_FN = "z.npy"
PARAM_FN = "params.toml"
PARAM_SEPARATOR = " "

@@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]:
    return tmp


def get(key: str) -> Any:
def get(key: str, default=None) -> Any:
    str_value = os.environ.get(key)
    if isinstance(str_value, str):
        try:
@@ -58,7 +58,7 @@ def get(key: str) -> Any:
            return t(str_value)
        except (ValueError, KeyError):
            pass
    return None
    return default


def all_environ() -> Dict[str, str]:

@@ -3,7 +3,7 @@ from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from .utils.cache import np_cache
from ._utils.cache import np_cache

pi = np.pi
c = 299792458.0

@@ -10,8 +10,7 @@ from scipy.optimize import minimize_scalar

from .. import math
from . import fiber, materials, units, pulse
from .. import utils
from ..utils import cache
from .._utils import cache, load_material_dico

T = TypeVar("T")

@@ -62,7 +61,7 @@ def material_dispersion(
        )
        return disp
    else:
        material_dico = utils.load_material_dico(material)
        material_dico = load_material_dico(material)
        if ideal:
            n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1
        else:

@@ -8,9 +8,9 @@ from scipy.interpolate import interp1d

from ..logger import get_logger

from .. import utils
from .. import _utils as utils
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.cache import np_cache
from .._utils.cache import np_cache
from . import materials as mat
from . import units
from .units import c, pi
@@ -49,27 +49,6 @@ def is_dynamic_dispersion(pressure=None):
    return out


def HCARF_gap(core_radius: float, capillary_num: int, capillary_outer_d: float):
    """computes the gap length between capillaries of a hollow core anti-resonance fiber

    Parameters
    ----------
    core_radius : float
        radius of the core (m) (from center to edge of a capillary)
    capillary_num : int
        number of capillaries
    capillary_outer_d : float
        diameter of the capillaries including the wall thickness (m). The core together with the microstructure has a diameter of 2R + 2d

    Returns
    -------
    gap : float
    """
    return (core_radius + capillary_outer_d / 2) * 2 * np.sin(
        pi / capillary_num
    ) - capillary_outer_d


def gvd_from_n_eff(n_eff: np.ndarray, wl_for_disp: np.ndarray):
    """computes the dispersion parameter D from an effective index of refraction n_eff
    Since computing gradients/derivatives of discrete arrays is not well defined on the boundary, it is
@@ -193,6 +172,30 @@ def n_eff_marcatili_adjusted(wl_for_disp, n_gas_2, core_radius, he_mode=(1, 1),
    return np.sqrt(n_gas_2 - (wl_for_disp * u / (pipi * corrected_radius)) ** 2)


def A_eff_marcatili(core_radius: float) -> float:
    """Effective mode-field area for the fundamental mode of hollow capillaries

    Parameters
    ----------
    core_radius : float
        radius of the core

    Returns
    -------
    float
        effective mode field area
    """
return 1.5 * core_radius ** 2
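
# Note: this is the commonly used approximation A_eff ~ 1.5 a**2 for the
# fundamental mode of a hollow capillary of core radius a.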


def capillary_spacing_hasan(
    capillary_num: int, capillary_radius: float, core_radius: float
) -> float:
    return (
        2 * (capillary_radius + core_radius) * np.sin(np.pi / capillary_num) - 2 * capillary_radius
)
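
# Geometry sketch: with N capillaries of radius r placed around a core of
# radius R, the capillary centers sit on a circle of radius R + r, so the
# center-to-center distance is 2 (R + r) sin(pi / N) and the wall-to-wall
# spacing is that distance minus 2 r, which is what this function returns.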


@np_cache
def n_eff_hasan(
    wl_for_disp: np.ndarray,

@@ -5,14 +5,14 @@ from scipy.integrate import cumulative_trapezoid

from ..logger import get_logger
from . import units
from .. import utils
from .. import _utils
from .units import NA, c, kB, me, e, hbar


def n_gas_2(
    wl_for_disp: np.ndarray, gas: str, pressure: float, temperature: float, ideal_gas: bool
    wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature: float, ideal_gas: bool
):
    material_dico = utils.load_material_dico(gas)
    material_dico = _utils.load_material_dico(gas_name)

    if ideal_gas:
        n_gas_2 = sellmeier(wl_for_disp, material_dico, pressure, temperature) + 1

@@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar
from scipy.optimize.optimize import OptimizeResult

from scgenerator import utils

from ..defaults import default_plotting
from ..logger import get_logger
from ..math import *
@@ -820,7 +818,8 @@ def find_lobe_limits(x_axis, values, debug="", already_sorted=True):
            )
            ax.legend()
            fig.savefig(out_path, bbox_inches="tight")
            plt.close()
            if fig is not None:
                plt.close(fig)

        else:
            good_roots, left_lim, right_lim = _select_roots(d_spline, d_roots, dd_roots, fwhm_pos)

@@ -9,11 +9,15 @@ from typing import Any, Generator, Type, Union
import numpy as np
from send2trash import send2trash

from .. import env, utils
from .. import env
from .. import _utils as utils
from .._utils.utils import combine_simulations, save_parameters
from ..logger import get_logger
from ..utils.parameter import Configuration, Parameters, format_variable_list
from .._utils.parameter import Configuration, Parameters
from .._utils.pbar import PBars, ProgressBarActor, progress_worker
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
from scgenerator._utils import pbar

try:
    import ray
@@ -215,6 +219,17 @@ class RK4IP:
        return self.stored_spectra

    def irun(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
        """run the simulation as a generator object

        Yields
        ------
        int
            current simulation step
        int
            current number of spectra returned
        np.ndarray
            spectrum
        """

        # Print introduction
        self.logger.debug(
@@ -332,7 +347,7 @@ class SequentialRK4IP(RK4IP):
    def __init__(
        self,
        params: Parameters,
        pbars: utils.PBars,
        pbars: PBars,
        save_data=False,
        job_identifier="",
        task_id=0,
@@ -466,14 +481,14 @@ class Simulations:

        self.configuration = configuration

        self.name = self.configuration.final_path
        self.sim_dir = self.configuration.final_sim_dir
        self.name = self.configuration.name
        self.sim_dir = self.configuration.final_path
        self.configuration.save_parameters()

        self.sim_jobs_per_node = 1

    def finished_and_complete(self):
        for sim in self.configuration.all_configs_dict.values():
        for sim in self.configuration.all_configs.values():
            if (
                self.configuration.sim_status(sim.output_path)[0]
                != self.configuration.State.COMPLETE
@@ -487,8 +502,9 @@ class Simulations:

    def _run_available(self):
        for variable, params in self.configuration:
            v_list_str = format_variable_list(variable, add_iden=True)
            utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
            params.compute()
            v_list_str = variable.formatted_descriptor(True)
            save_parameters(params.prepare_for_dump(), Path(params.output_path))

            self.new_sim(v_list_str, params)
        self.finish()

@@ -525,8 +541,10 @@ class SequencialSimulations(Simulations, priority=0):

    def __init__(self, configuration: Configuration, task_id):
        super().__init__(configuration, task_id=task_id)
        self.pbars = utils.PBars(
            self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1
        self.pbars = PBars(
            self.configuration.total_num_steps,
            "Simulating " + self.configuration.final_path.name,
            1,
        )
        self.configuration.skip_callback = lambda num: self.pbars.update(0, num)

@@ -567,7 +585,7 @@ class MultiProcSimulations(Simulations, priority=1):
            for i in range(self.sim_jobs_per_node)
        ]
        self.p_worker = multiprocessing.Process(
            target=utils.progress_worker,
            target=progress_worker,
            args=(
                Path(self.configuration.final_path).name,
                self.sim_jobs_per_node,
@@ -656,7 +674,7 @@ class RaySimulations(Simulations, priority=2):
        self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total))
        self.num_submitted = 0
        self.rolling_id = 0
        self.p_actor = ray.remote(utils.ProgressBarActor).remote(
        self.p_actor = ray.remote(ProgressBarActor).remote(
            self.configuration.final_path, self.sim_jobs_total, self.configuration.total_num_steps
        )
        self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num))

@@ -712,21 +730,13 @@ def run_simulation(
    config_file: os.PathLike,
    method: Union[str, Type[Simulations]] = None,
):
    config = Configuration(config_file)
    config = Configuration(config_file, wait=True)

    sim = new_simulation(config, method)
    sim.run()
    path_trees = utils.build_path_trees(config.sim_dirs[-1])

    final_name = env.get(env.OUTPUT_PATH)
    if final_name is None:
        final_name = config.final_path

    utils.merge(final_name, path_trees)
    try:
        send2trash(config.sim_dirs)
    except (PermissionError, OSError):
        get_logger(__name__).error("Could not send temporary directories to trash")
    for path in config.fiber_paths:
        combine_simulations(path)


def new_simulation(
@@ -762,6 +772,8 @@ def parallel_RK4IP(
]:
    logger = get_logger(__name__)
    params = list(Configuration(config))
    for _, param in params:
        param.compute()
    n = len(params)
    z_num = params[0][1].z_num

@@ -2,7 +2,6 @@
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...

from dataclasses import dataclass
from typing import Callable, TypeVar, Union

import numpy as np

@@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults
from .math import abs2, span
from .physics import pulse, units
from .utils.parameter import Parameters, PlotRange, sort_axis
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, sort_axis

RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()

@@ -11,14 +11,13 @@ from .. import env, math
from ..const import PARAM_FN, PARAM_SEPARATOR
from ..physics import fiber, units
from ..plotting import plot_setup
from ..spectra import Pulse
from ..utils import auto_crop, open_config, save_toml, translate_parameters
from ..utils.parameter import (
from ..spectra import SimulationSeries
from .._utils import auto_crop, _open_config, save_toml, translate_parameters
from .._utils.parameter import (
    Configuration,
    Parameters,
    pretty_format_from_sim_name,
    pretty_format_value,
)
from .._utils.utils import simulations_list


def fingerprint(params: Parameters):
@@ -33,7 +32,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
            opts[k] = int(v)
        if k in {"log", "renormalize"}:
            opts[k] = True if v == "True" else False
    dir_list = list(p for p in sim_dir.glob("*") if p.is_dir())
    dir_list = simulations_list(sim_dir)
    if len(dir_list) == 0:
        dir_list = [sim_dir]
    limits = [
@@ -41,12 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
    ]
    with tqdm(total=len(dir_list) * len(limits)) as bar:
        for p in dir_list:
            pulse = Pulse(p)
            pulse = SimulationSeries(p)
            for left, right, unit in limits:
                path, fig, ax = plot_setup(
                    pulse.path.parent
                    / (
                        pretty_format_from_sim_name(pulse.path.name)
                        pulse.path.name
                        + PARAM_SEPARATOR
                        + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
                    )
@@ -259,7 +258,7 @@ def finish_plot(fig, legend_axes, all_labels, params):

def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
    cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
    pseq = Configuration(open_config(config_path))
    pseq = Configuration(_open_config(config_path))
    for style, (variables, params) in zip(cc, pseq):
        lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]]
        yield style, lbl, params
@@ -268,7 +267,7 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters
def convert_params(params_file: os.PathLike):
    p = Path(params_file)
    if p.name == PARAM_FN:
        d = open_config(params_file)
        d = _open_config(params_file)
        d = translate_parameters(d)
    save_toml(params_file, d)
    print(f"converted {p}")

@@ -9,8 +9,8 @@ from typing import Tuple

import numpy as np

from ..utils import Paths
from ..utils.parameter import Configuration
from .._utils import Paths
from .._utils.parameter import Configuration


def primes(n):

@@ -1,13 +1,17 @@
from __future__ import annotations

import os
from collections.abc import Sequence
from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Iterator, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from . import math
from .const import SPECN_FN
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, simulations_list
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from .logger import get_logger
from .physics import pulse, units
from .plotting import (
@@ -16,8 +20,6 @@ from .plotting import (
    single_position_plot,
    transform_2D_propagation,
)
from .utils.parameter import Parameters, PlotRange
from .utils import load_spectrum

class Spectrum(np.ndarray):
@@ -42,18 +44,6 @@ class Spectrum(np.ndarray):
    def __getitem__(self, key) -> "Spectrum":
        return super().__getitem__(key)

    def energy(self) -> Union[np.ndarray, float]:
        if self.ndim == 1:
            m = np.argwhere(self.params.l > 0)[:, 0]
            m = np.array(sorted(m, key=lambda el: self.params.l[el]))
            return np.trapz(self.wl_int[m], self.params.l[m])
        else:
            return np.array([s.energy() for s in self])

    def crop_wl(self, left: float, right: float) -> np.ndarray:
        cond = (self.params.l >= left) & (self.params.l <= right)
        return cond

    @property
    def wl_int(self):
        return units.to_WL(math.abs2(self), self.params.l)
@@ -118,7 +108,7 @@ class Spectrum(np.ndarray):
        return self.params.l[np.argmax(self.wl_int, axis=-1)]
        return np.array([s.wl_max for s in self])

    def mask_wl(self, pos: float, width: float) -> "Spectrum":
    def mask_wl(self, pos: float, width: float) -> Spectrum:
        return self * np.exp(
            -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
        )
@@ -127,189 +117,105 @@ class Spectrum(np.ndarray):
        return pulse.measure_field(self.params.t, self.time_amp)


class Pulse(Sequence):
    def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
        """load a data folder as a pulse
class SimulationSeries:
    path: Path
    params: Parameters
    total_length: float
    total_num_steps: int
    previous: SimulationSeries = None
    fiber_lengths: list[tuple[str, float]]
    fiber_positions: list[tuple[str, float]]
    z_inds: np.ndarray

        Parameters
        ----------
        path : os.PathLike
            path to the data (folder containing .npy files)
        default_ind : int | Iterable[int], optional
            default indices to be loaded, by default None

        Raises
        ------
        FileNotFoundError
            path does not contain proper data
        """
        self.logger = get_logger(__name__)
        self.path = Path(path)
        self.default_ind = default_ind

        if not self.path.is_dir():
            raise FileNotFoundError(f"Folder {self.path} does not exist")

        self.params = Parameters.load(self.path / "params.toml")
    def __init__(self, path: os.PathLike):
        self.logger = get_logger()
        for self.path in simulations_list(path):
            break
        else:
            raise FileNotFoundError(f"No simulation in {path}")
        self.params = Parameters.load(self.path / PARAM_FN)
        self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
        if self.params.fiber_map is None:
            self.params.fiber_map = [(0.0, self.params.name)]

        try:
            self.z = np.load(os.path.join(path, "z.npy"))
        except FileNotFoundError:
            if self.params is not None:
                self.z = self.params.z_targets
            else:
                raise
        self.nmax = len(list(self.path.glob("spectra_*.npy")))
        if self.nmax <= 0:
            raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")

        self.t = self.params.t
        w = math.wspace(self.t) + units.m(self.params.wavelength)
        self.w_order = np.argsort(w)
        self.w = w
        self.wl = units.m.inv(self.w)
        self.params.w = self.w
        self.params.z_targets = self.z
        self.w = self.params.w
        if self.params.prev_data_dir is not None:
            self.previous = SimulationSeries(self.params.prev_data_dir)
        self.total_length = self.accumulate_params("length")
        self.total_num_steps = self.accumulate_params("z_num")
        self.z_inds = np.arange(len(self.params.z_targets))
        self.z = self.params.z_targets
        if self.previous is not None:
            self.z += self.previous.params.z_targets[-1]
            self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets))
            self.z_inds += self.previous.z_inds[-1] + 1
        self.fiber_lengths = self.all_params("length")
        self.fiber_positions = [
            (this[0], following[1])
            for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths)
        ]

    def __iter__(self):
        """
        similar to all_spectra but works as an iterator
        """

        self.logger.debug(f"iterating through {self.path}")
        for i in range(self.nmax):
            yield self._load1(i)

    def __len__(self):
        return self.nmax

    def __getitem__(self, key) -> Spectrum:
        return self.all_spectra(key)

    def intensity(self, unit):
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.w)
        else:
            x_axis = unit.inv(self.t)

        order = np.argsort(x_axis)
        func = dict(
            WL=self._to_wl_int,
            FREQ=self._to_freq_int,
            AFREQ=self._to_afreq_int,
            TIME=self._to_time_int,
        )[unit.type]

        for spec in self:
            yield x_axis[order], func(spec)[:, order]

    def _to_wl_int(self, spectrum):
        return units.to_WL(math.abs2(spectrum), spectrum.wl)

    def _to_freq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_afreq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_time_int(self, spectrum):
        return math.abs2(np.fft.ifft(spectrum))

    def amplitude(self, unit):
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.w)
        else:
            x_axis = unit.inv(self.t)

        order = np.argsort(x_axis)
        func = dict(
            WL=self._to_wl_amp,
            FREQ=self._to_freq_amp,
            AFREQ=self._to_afreq_amp,
            TIME=self._to_time_amp,
        )[unit.type]

        for spec in self:
            yield x_axis[order], func(spec)[:, order]

    def _to_wl_amp(self, spectrum):
        return (
            np.sqrt(
                units.to_WL(
                    math.abs2(spectrum),
                    spectrum.wl,
                )
            )
            * spectrum
            / np.abs(spectrum)
        )

    def _to_freq_amp(self, spectrum):
        return spectrum

    def _to_afreq_amp(self, spectrum):
        return spectrum

    def _to_time_amp(self, spectrum):
        return np.fft.ifft(spectrum)

    def all_spectra(self, ind=None) -> Spectrum:
        """
        loads the data already simulated.
        default shape is (z_targets, n, nt)
    def all_params(self, key: str) -> list[tuple[str, Any]]:
        """returns the value of a parameter for each fiber

        Parameters
        ----------
        ind : int or list of int
            if only certain spectra are desired
        key : str
            name of the parameter

        Returns
        ----------
        spectra : array of shape (nz, m, nt)
            array of complex spectra (pulse at nz positions consisting
            of nm simulations on a nt size grid)
        -------
        list[tuple[str, Any]]
            list of (fiber_name, param_value) tuples
        """
        return list(reversed(self._all_params(key, [])))

        self.logger.debug(f"opening {self.path}")
    def accumulate_params(self, key: str) -> Any:
        """returns the sum of all the values a parameter takes. Useful to
        get the total length of the fiber, the total number of steps, etc.

        # Check if file exists and assert how many z positions there are
        Parameters
        ----------
        key : str
            name of the parameter

        if ind is None:
            if self.default_ind is None:
                ind = range(self.nmax)
        Returns
        -------
        Any
            final sum
        """
        return sum(el[1] for el in self.all_params(key))

    def spectra(
        self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
    ) -> Spectrum:
        if z_descr is None:
            out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)]
        else:
            ind = self.default_ind
            if isinstance(ind, (int, np.integer)):
                ind = [ind]
            elif isinstance(ind, (float, np.floating)):
                ind = [self.z_ind(ind)]
            elif isinstance(ind[0], (float, np.floating)):
                ind = [self.z_ind(ii) for ii in ind]

            # Load the spectra
            spectra = []
            for i in ind:
                spectra.append(self._load1(i))
            spectra = Spectrum(spectra, self.params)

            self.logger.debug(f"all spectra from {self.path} successfully loaded")
            if len(ind) == 1:
                return spectra[0]
            if isinstance(z_descr, (float, np.floating)):
                if self.z[0] <= z_descr <= self.z[-1]:
                    z_ind = self.z_inds[np.argmin(np.abs(self.z - z_descr))]
                elif 0 <= z_descr < self.z[0]:
                    return self.previous.spectra(z_descr, sim_ind)
                else:
                    return spectra
                    raise ValueError(
                        f"cannot match z={z_descr} with max length of {self.total_length}"
                    )
            else:
                z_ind = z_descr

    def all_fields(self, ind=None):
        return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
        if z_ind < self.z_inds[0]:
            return self.previous.spectra(z_ind, sim_ind)
        if sim_ind is None:
            out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
        else:
            out = self._load_1(z_ind)
        return Spectrum(out, self.params)

    def _load1(self, i: int):
        if i < 0:
            i = self.nmax + i
        spec = load_spectrum(self.path / SPECN_FN.format(i))
        spec = np.atleast_2d(spec)
        spec = Spectrum(spec, self.params)
        return spec
    def fields(
        self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
    ) -> Spectrum:
        return np.fft.ifft(self.spectra(z_descr, sim_ind))

    # Plotting

    def plot_2D(
        self,
@@ -317,12 +223,11 @@ class Pulse(Sequence):
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: Union[int, Iterable[int]] = None,
        sim_ind: int = 0,
        **kwargs,
    ):
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
        vals = self.retrieve_plot_values(plot_range, None, sim_ind)
        return propagation_plot(vals, plot_range, self.params, ax, **kwargs)

    def plot_1D(
@@ -349,7 +254,7 @@ class Pulse(Sequence):
        **kwargs,
    ):
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, slice(None))
        vals = self.retrieve_plot_values(plot_range, z_pos, None)
        return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)

    def retrieve_plot_values(
@@ -357,16 +262,9 @@ class Pulse(Sequence):
    ):

        if plot_range.unit.type == "TIME":
            vals = self.all_fields(ind=z_pos)
            return self.fields(z_pos, sim_ind)
        else:
            vals = self.all_spectra(ind=z_pos)

        if sim_ind is None:
            return vals
        elif z_pos is None:
            return vals[:, sim_ind]
        else:
            return vals[sim_ind]
        return self.spectra(z_pos, sim_ind)

    def rin_propagation(
        self, left: float, right: float, unit: str
@@ -392,22 +290,63 @@ class Pulse(Sequence):
            RIN
        """
        spectra = []
        for spec in np.moveaxis(self.all_spectra(), 1, 0):
        for spec in np.moveaxis(self.spectra(None, None), 1, 0):
            x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
            spectra.append(tmp)
        return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))

    def z_ind(self, z: float) -> int:
        """return the closest z index to the given target
    # Private

    def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
        """loads a spectrum file

        Parameters
        ----------
        z : float
            target
        z_ind : int
            z index relative to the entire simulation
        sim_ind : int, optional
            simulation index, used when repeated simulations with the same parameters are run, by default 0

        Returns
        -------
        int
            index
        np.ndarray
            loaded spectrum file
        """
        return math.argclosest(self.z, z)
        if sim_ind > 0:
            return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind))
        else:
            return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0]))

    def _all_params(self, key: str, l: list) -> list:
        l.append((self.params.name, getattr(self.params, key)))
        if self.previous is not None:
            return self.previous._all_params(key, l)
        return l

    # Magic methods

    def __iter__(self) -> Iterator[Spectrum]:
        for i in range(self.total_num_steps):
            yield self.spectra(i, None)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"

    def __eq__(self, other: SimulationSeries) -> bool:
        return (
            self.path == other.path
            and self.params == other.params
            and self.previous == other.previous
        )

    def __contains__(self, other: SimulationSeries) -> bool:
        if other is self or other == self:
            return True
        if self.previous is not None:
            return other in self.previous
        return False

    def __getitem__(self, key) -> Spectrum:
        if isinstance(key, tuple):
            return self.spectra(*key)
        else:
            return self.spectra(key, None)

@@ -1,677 +0,0 @@
"""
This file includes utility functions designed more or less to be used specifically with the
scgenerator module, but some functions may be used in any python program

"""

from __future__ import annotations

import itertools
import multiprocessing
import os
import random
import re
import shutil
import threading
from collections import abc
from io import StringIO
from pathlib import Path
from string import printable as str_printable
from functools import cache
from typing import Any, Callable, Generator, Iterable, MutableMapping, Sequence, TypeVar, Union

import numpy as np
import pkg_resources as pkg
import toml
from tqdm import tqdm

from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
from ..env import pbar_policy
from ..logger import get_logger

T_ = TypeVar("T_")

PathTree = list[tuple[Path, ...]]


class Paths:
    _data_files = [
        "materials.toml",
        "hr_t.npz",
        "submit_job_template.txt",
        "start_worker.sh",
        "start_head.sh",
    ]

    paths = {
        f.split(".")[0]: os.path.abspath(
            pkg.resource_filename("scgenerator", os.path.join("data", f))
        )
        for f in _data_files
    }

    @classmethod
    def get(cls, key):
        if key not in cls.paths:
            if os.path.exists("paths.toml"):
                with open("paths.toml") as file:
                    paths_dico = toml.load(file)
                for k, v in paths_dico.items():
                    cls.paths[k] = v
            if key not in cls.paths:
                get_logger(__name__).info(
                    f"{key} was not found in path index, returning current working directory."
                )
                cls.paths[key] = os.getcwd()

        return cls.paths[key]

    @classmethod
    def gets(cls, key):
        """returns the specified file as a string"""
        with open(cls.get(key)) as file:
            return file.read()

    @classmethod
    def plot(cls, name):
        """returns the path to the specified plot. Used to save new plots
        example
        ---------
        fig.savefig(Paths.plot("figure5.pdf"))
        """
        return os.path.join(cls.get("plots"), name)


def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
    prev_data_dir = Path(prev_data_dir)
    num = find_last_spectrum_num(prev_data_dir)
    return load_spectrum(prev_data_dir / SPEC1_FN.format(num))


@cache
def load_spectrum(folder: os.PathLike) -> np.ndarray:
    return np.load(folder)


def conform_toml_path(path: os.PathLike) -> str:
    path: str = str(path)
    if not path.lower().endswith(".toml"):
        path = path + ".toml"
    return path


def open_single_config(path: os.PathLike) -> dict[str, Any]:
    d = open_config(path)
    f = d.pop("Fiber")[0]
return d | f
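
# Sketch: for a config whose [[Fiber]] list has at least one entry, the first
# fiber's keys are merged into the top-level dict (hypothetical file name):
#   open_single_config("cfg.toml")  # -> {**top_level, **first_fiber}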


def open_config(path: os.PathLike):
    """returns a dictionary parsed from the specified toml file
    This also handles an 'INCLUDE' argument that will fill
    otherwise unspecified keys with what's in the INCLUDE file(s)"""

    path = conform_toml_path(path)
    dico = resolve_loadfile_arg(load_toml(path))

    dico.setdefault("variable", {})
    for key in {"simulation", "fiber", "gas", "pulse"} & dico.keys():
        section = dico.pop(key)
        dico["variable"].update(section.pop("variable", {}))
        dico.update(section)
    if len(dico["variable"]) == 0:
        dico.pop("variable")
    return dico


def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
    if (f_list := dico.pop("INCLUDE", None)) is not None:
        if isinstance(f_list, str):
            f_list = [f_list]
        for to_load in f_list:
            loaded = load_toml(to_load)
            for k, v in loaded.items():
                if k not in dico and k not in dico.get("variable", {}):
                    dico[k] = v
    for k, v in dico.items():
        if isinstance(v, MutableMapping):
            dico[k] = resolve_loadfile_arg(v)
        elif isinstance(v, Sequence):
            for i, vv in enumerate(v):
                if isinstance(vv, MutableMapping):
                    dico[k][i] = resolve_loadfile_arg(vv)
    return dico


def load_toml(descr: str) -> dict[str, Any]:
    if ":" in descr:
        path, entry = descr.split(":", 1)
        with open(path) as file:
            return toml.load(file)[entry]
    else:
        with open(descr) as file:
return toml.load(file)
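
# Sketch: a "path:entry" descriptor loads a single table from the file:
#   load_toml("params.toml:fiber")  # returns only the [fiber] table
#   load_toml("params.toml")        # returns the whole file as a dict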


def save_toml(path: os.PathLike, dico):
    """saves a dictionary into a toml file"""
    path = conform_toml_path(path)
    with open(path, mode="w") as file:
        toml.dump(dico, file)
    return dico


def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]:
    loaded_config = open_config(final_config_path)
    final_name = loaded_config.get("name")
    fiber_list = loaded_config.pop("Fiber")
    configs = []
    if fiber_list is not None:
        master_variable = loaded_config.get("variable", {})
        for i, params in enumerate(fiber_list):
            params.setdefault("variable", master_variable if i == 0 else {})
            if i == 0:
                params["variable"] |= master_variable
            configs.append(loaded_config | params)
    else:
        configs.append(loaded_config)
    while "previous_config_file" in configs[0]:
        configs.insert(0, open_config(configs[0]["previous_config_file"]))
    configs[0].setdefault("variable", {})
    for pre, nex in zip(configs[:-1], configs[1:]):
        variable = nex.pop("variable", {})
        nex.update({k: v for k, v in pre.items() if k not in nex})
        nex["variable"] = variable

    return configs, final_name


def save_parameters(
    params: dict[str, Any], destination_dir: Path, file_name: str = PARAM_FN
) -> Path:
    """saves a parameter dictionary. Note that it does remove some entries, particularly
    those that take a lot of space ("t", "w", ...)

    Parameters
    ----------
    params : dict[str, Any]
        dictionary to save
    destination_dir : Path
        destination directory

    Returns
    -------
    Path
        path to the newly created parameter file
    """
    file_path = destination_dir / file_name
    os.makedirs(file_path.parent, exist_ok=True)

    # save toml of the simulation
    with open(file_path, "w") as file:
        toml.dump(params, file, encoder=toml.TomlNumpyEncoder())

    return file_path


def load_material_dico(name: str) -> dict[str, Any]:
    """loads a material dictionary

    Parameters
    ----------
    name : str
        name of the material

    Returns
    -------
    material_dico : dict
    """
    return toml.loads(Paths.gets("materials"))[name]


def update_appended_params(source: Path, destination: Path, z: Sequence):
    z_num = len(z)
    params = open_config(source)
    params["z_num"] = z_num
    params["length"] = float(z[-1] - z[0])
    for p_name in ["recovery_data_dir", "prev_data_dir", "output_path"]:
        if p_name in params:
            del params[p_name]
    save_toml(destination, params)


def to_62(i: int) -> str:
    arr = []
    if i == 0:
        return "0"
    i = abs(i)
    while i:
        i, value = divmod(i, 62)
        arr.append(str_printable[value])
    return "".join(reversed(arr))


def build_path_trees(sim_dir: Path) -> list[PathTree]:
    sim_dir = sim_dir.resolve()
    path_branches: list[tuple[Path, ...]] = []
    to_check = list(sim_dir.glob("*fiber*num*"))
    with PBars(len(to_check), desc="Building path trees") as pbar:
        for branch in map(build_path_branch, to_check):
            if branch is not None:
                path_branches.append(branch)
            pbar.update()
    path_trees = group_path_branches(path_branches)
    return path_trees


def build_path_branch(data_dir: Path) -> tuple[Path, ...]:
    if not data_dir.is_dir():
        return None
    path_branch = [data_dir]
    while (
        prev_sim_path := open_config(path_branch[-1] / PARAM_FN).get("prev_data_dir")
    ) is not None:
        p = Path(prev_sim_path).resolve()
        if not p.exists():
            p = Path(*p.parts[-2:]).resolve()
        path_branch.append(p)
    return tuple(reversed(path_branch))


def group_path_branches(path_branches: list[tuple[Path, ...]]) -> list[PathTree]:
    """groups path lists

    [
        ("a/id 0 wavelength 100 num 0", "b/id 0 wavelength 100 num 0"),
        ("a/id 2 wavelength 100 num 1", "b/id 2 wavelength 100 num 1"),
        ("a/id 1 wavelength 200 num 0", "b/id 1 wavelength 200 num 0"),
        ("a/id 3 wavelength 200 num 1", "b/id 3 wavelength 200 num 1")
    ]
    ->
    [
        (
            ("a/id 0 wavelength 100 num 0", "a/id 2 wavelength 100 num 1"),
            ("b/id 0 wavelength 100 num 0", "b/id 2 wavelength 100 num 1"),
        )
        (
            ("a/id 1 wavelength 200 num 0", "a/id 3 wavelength 200 num 1"),
            ("b/id 1 wavelength 200 num 0", "b/id 3 wavelength 200 num 1"),
        )
    ]

    Parameters
    ----------
    path_branches : list[tuple[Path, ...]]
        each element of the list is a path to a folder containing data of one simulation

    Returns
    -------
    list[PathTree]
        list of PathTrees to be used in merge
    """
    sort_key = lambda el: el[0]

    size = len(path_branches[0])
    out_trees_map: dict[str, dict[int, dict[int, Path]]] = {}
    for branch in path_branches:
        b_id = branch_id(branch)
        out_trees_map.setdefault(b_id, {i: {} for i in range(size)})
        for sim_part, data_dir in enumerate(branch):
            num = re.search(r"(?<=num )[0-9]+", data_dir.name)[0]
            out_trees_map[b_id][sim_part][int(num)] = data_dir

    return [
        tuple(
            tuple(w for _, w in sorted(v.items(), key=sort_key))
            for __, v in sorted(d.items(), key=sort_key)
        )
        for d in out_trees_map.values()
    ]
|
||||
|
||||
|
||||
def merge_path_tree(
    path_tree: PathTree, destination: Path, z_callback: Callable[[int], None] = None
):
    """given a path tree, copies the files into the right location

    Parameters
    ----------
    path_tree : PathTree
        element of the list returned by group_path_branches
    destination : Path
        directory where to save the data
    z_callback : Callable[[int], None], optional
        called with the index of each saved spectrum, by default None
    """
    z_arr: list[float] = []

    destination.mkdir(exist_ok=True)

    for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
        z_arr.append(z)
        spec_out_name = SPECN_FN1.format(i)
        np.save(destination / spec_out_name, merged_spectra)
        if z_callback is not None:
            z_callback(i)
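    # z grids of consecutive fibers overlap at the joints: zero out negative
    # jumps and integrate back so the merged z array increases monotonically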
    d = np.diff(z_arr)
    d[d < 0] = 0
    z_arr = np.concatenate(([z_arr[0]], np.cumsum(d)))
    np.save(destination / Z_FN, z_arr)
    update_appended_params(path_tree[-1][0] / PARAM_FN, destination / PARAM_FN, z_arr)


def merge_spectra(
    path_tree: PathTree,
) -> Generator[tuple[float, np.ndarray], None, None]:
    for same_sim_paths in path_tree:
        z_arr = np.load(same_sim_paths[0] / Z_FN)
        for i, z in enumerate(z_arr):
            spectra: list[np.ndarray] = []
            for data_dir in same_sim_paths:
                spec = np.load(data_dir / SPEC1_FN.format(i))
                spectra.append(spec)
            yield z, np.atleast_2d(spectra)


def merge(destination: os.PathLike, path_trees: list[PathTree] = None):
    destination = ensure_folder(Path(destination))

    z_num = 0
    prev_z_num = 0

    for i, sim_dir in enumerate(sim_dirs(path_trees)):
        conf = sim_dir / "initial_config.toml"
        shutil.copy(
            conf,
            destination / f"initial_config_{i}.toml",
        )
        prev_z_num = open_config(conf).get("z_num", prev_z_num)
        z_num += prev_z_num

    pbars = PBars(
        len(path_trees) * z_num, "Merging", 1, worker_kwargs=dict(total=z_num, desc="current pos")
    )
    for path_tree in path_trees:
        pbars.reset(1)
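        # drop the leading "id <n>" tokens and every "num <n>" pair from the
        # folder name so that repeat runs of the same parameters merge together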
        iden_items = path_tree[-1][0].name.split()[2:]
        for i, p_name in list(enumerate(iden_items))[-2::-2]:
            if p_name == "num":
                del iden_items[i + 1]
                del iden_items[i]
        iden = PARAM_SEPARATOR.join(iden_items)
        merge_path_tree(path_tree, destination / iden, z_callback=lambda i: pbars.update(1))


def sim_dirs(path_trees: list[PathTree]) -> Generator[Path, None, None]:
    for p in path_trees[0]:
        yield p[0].parent


def save_data(data: np.ndarray, data_dir: Path, file_name: str):
    """saves numpy array to disk

    Parameters
    ----------
    data : np.ndarray
        data to save
    data_dir : Path
        directory in which to save the data
    file_name : str
        file name
    """
    path = data_dir / file_name
    np.save(path, data)
    get_logger(__name__).debug(f"saved data in {path}")


def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Path:
    """ensure a folder exists and doesn't overwrite anything if required

    Parameters
    ----------
    path : Path
        desired path
    prevent_overwrite : bool, optional
        whether to create a new directory when one already exists, by default True
    mkdir : bool, optional
        whether to actually create the directory, by default True

    Returns
    -------
    Path
        final path
    """

    path = path.resolve()

    # is path root ?
    if len(path.parts) < 2:
        return path

    # is a part of path an existing *file* ?
    parts = path.parts
    path = Path(path.root)
    for part in parts:
        if path.is_file():
            path = ensure_folder(path, mkdir=mkdir, prevent_overwrite=False)
        path /= part

    folder_name = path.name
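
    # append _0, _1, ... to the name until a path is found that doesn't
    # overwrite an existing file or directory (unless overwriting is allowed)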
    for i in itertools.count():
        if not path.is_file() and (not prevent_overwrite or not path.is_dir()):
            if mkdir:
                path.mkdir(exist_ok=True)
            return path
        path = path.parent / (folder_name + f"_{i}")


class PBars:
    def __init__(
        self,
        task: Union[int, Iterable[T_]],
        desc: str,
        num_sub_bars: int = 0,
        head_kwargs=None,
        worker_kwargs=None,
    ) -> None:
        self.id = random.randint(100000, 999999)
        try:
            self.width = os.get_terminal_size().columns
        except OSError:
            self.width = 80
        if isinstance(task, abc.Iterable):
            self.iterator: Iterable[T_] = iter(task)
            self.num_tot: int = len(task)
        else:
            self.num_tot: int = task
            self.iterator = None

        self.policy = pbar_policy()
        if head_kwargs is None:
            head_kwargs = dict()
        if worker_kwargs is None:
            worker_kwargs = dict(
                total=1,
                desc="Worker {worker_id}",
                bar_format="{l_bar}{bar}|[{elapsed}<{remaining}, {rate_fmt}{postfix}]",
            )
        if "print" not in self.policy:
            head_kwargs["file"] = worker_kwargs["file"] = StringIO()
            self.width = 80
        head_kwargs["desc"] = desc
        self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)]
        for i in range(1, num_sub_bars + 1):
            kwargs = {k: v for k, v in worker_kwargs.items()}
            if "desc" in kwargs:
                kwargs["desc"] = kwargs["desc"].format(worker_id=i)
            self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs))
        self.print_path = Path(
            f"progress {self.pbars[0].desc.replace('/', '')} {self.id}"
        ).resolve()
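        # when the "file" policy is active, a daemon thread periodically dumps
        # the rendered bars to a text file (e.g. when no terminal is attached)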
        self.close_ev = threading.Event()
        if "file" in self.policy:
            self.thread = threading.Thread(target=self.print_worker, daemon=True)
            self.thread.start()

    def print(self):
        if "file" not in self.policy:
            return
        s = []
        for pbar in self.pbars:
            s.append(str(pbar))
        self.print_path.write_text("\n".join(s))

    def print_worker(self):
        while True:
            if self.close_ev.wait(2.0):
                return
            self.print()

    def __iter__(self):
        with self as pb:
            for thing in self.iterator:
                yield thing
                pb.update()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __getitem__(self, key):
        return self.pbars[key]

    def update(self, i=None, value=1):
        if i is None:
            for pbar in self.pbars[1:]:
                pbar.update(value)
        elif i > 0:
            self.pbars[i].update(value)
        self.pbars[0].update()

    def append(self, pbar: tqdm):
        self.pbars.append(pbar)

    def reset(self, i):
        self.pbars[i].update(-self.pbars[i].n)
        self.print()

    def close(self):
        self.print()
        self.close_ev.set()
        if "file" in self.policy:
            self.thread.join()
        for pbar in self.pbars:
            pbar.close()


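# typical PBars usage (illustrative sketch; `param_list` and `run` are
# placeholders, not part of the library):
#     for params in PBars(param_list, "simulations"):
#         run(params)
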
class ProgressBarActor:
    def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
        self.counters = [0 for _ in range(num_workers + 1)]
        self.p_bars = PBars(
            num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
        )

    def update(self, worker_id: int, rel_pos: float = None) -> None:
        """update a counter

        Parameters
        ----------
        worker_id : int
            id of the worker. 0 is the overall progress
        rel_pos : float, optional
            if None, increase the counter by one, if set, will set
            the counter to the specified value (instead of incrementing it), by default None
        """
        if rel_pos is None:
            self.counters[worker_id] += 1
        else:
            self.counters[worker_id] = rel_pos

    def update_pbars(self):
        for counter, pbar in zip(self.counters, self.p_bars.pbars):
            pbar.update(counter - pbar.n)

    def close(self):
        self.p_bars.close()


def progress_worker(
    name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
):
    """keeps track of progress on a separate thread

    Parameters
    ----------
    name : str
        name displayed next to the main progress bar
    num_workers : int
        number of sub progress bars
    num_steps : int
        total number of steps, used for the main progress bar (position 0)
    progress_queue : multiprocessing.Queue
        values are either
            Literal[0] : stop the worker and close the progress bars
            tuple[int, float] : worker id and relative progress between 0 and 1
    """
    with PBars(
        num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
    ) as pbars:
        while True:
            raw = progress_queue.get()
            if raw == 0:
                return
            i, rel_pos = raw
            if i > 0:
                pbars[i].update(rel_pos - pbars[i].n)
                pbars[0].update()
            elif i == 0:
                pbars[0].update(rel_pos)


def branch_id(branch: tuple[Path, ...]) -> str:
    return branch[-1].name.split()[1]


def find_last_spectrum_num(data_dir: Path):
    for num in itertools.count(1):
        p_to_test = data_dir / SPEC1_FN.format(num)
        if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
            return num - 1


def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray:
    threshold = y.min() + rel_thr * (y.max() - y.min())
    above_threshold = y > threshold
    ind = np.argsort(x)
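    # group consecutive indices (in sorted-x order) by whether y is above the
    # threshold and keep the longest run that is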
    valid_ind = [
        np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
    ]
    ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
    width = len(ind_above)
    return np.concatenate(
        (
            np.arange(max(ind_above[0] - width, 0), ind_above[0]),
            ind_above,
            np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
        )
    )


def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
    old_names = dict(
        interp_degree="interpolation_degree",
        beta="beta2_coefficients",
        interp_range="interpolation_range",
    )
    deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
    defaults_to_add = dict(repeat=1)
    new = {}
    for k, v in d.items():
        if k == "error_ok":
            new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
        elif k in deleted_names:
            continue
        elif isinstance(v, MutableMapping):
            new[k] = translate_parameters(v)
        else:
            new[old_names.get(k, k)] = v
    return defaults_to_add | new
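
# illustrative example of the renaming above:
#   translate_parameters(dict(beta=[1e-26], error_ok=1e-10))
#   == dict(repeat=1, beta2_coefficients=[1e-26], tolerated_error=1e-10)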
@@ -1,35 +0,0 @@
import shutil
import unittest

import toml
from scgenerator import logger
from send2trash import send2trash

TMP = "testing/.tmp"


class TestRecoveryParamSequence(unittest.TestCase):
    def setUp(self):
        shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
        self.conf = toml.load(TMP + "/initial_config.toml")
        io.set_data_folder(55, TMP)

    def test_remaining_simulations_count(self):
        param_seq = initialize.RecoveryParamSequence(self.conf, 55)
        self.assertEqual(5, len(param_seq))

    def test_only_one_to_complete(self):
        param_seq = initialize.RecoveryParamSequence(self.conf, 55)
        i = 0
        for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq):
            i += 1
            self.assertEqual(expected, "recovery_last_stored" in params)

        self.assertEqual(5, i)

    def tearDown(self):
        send2trash(TMP)


if __name__ == "__main__":
    unittest.main()
@@ -1,216 +0,0 @@
import unittest
from copy import deepcopy

import numpy as np
import toml
from scgenerator import defaults, utils, math
from scgenerator.errors import *
from scgenerator.physics import pulse, units
from scgenerator.utils.parameter import Config, Parameters


def load_conf(name):
    with open("testing/configs/" + name + ".toml") as file:
        conf = toml.load(file)
    return conf


def conf_maker(folder):
    def conf(name):
        return load_conf(folder + "/" + name)

    return conf


class TestParamSequence(unittest.TestCase):
    def iterconf(self, files):
        conf = conf_maker("param_sequence")
        for path in files:
            yield init.ParamSequence(conf(path))

    def test_no_repeat_in_sub_folder_names(self):
        for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
            l = []
            s = []
            for vary_list, _ in utils.required_simulations(param_seq.config):
                self.assertNotIn(vary_list, l)
                self.assertNotIn(utils.format_variable_list(vary_list), s)
                l.append(vary_list)
                s.append(utils.format_variable_list(vary_list))

    def test_no_variations_yields_only_num_and_id(self):
        for param_seq in self.iterconf(["no_variations"]):
            for vary_list, _ in utils.required_simulations(param_seq.config):
                self.assertEqual(vary_list[1][0], "num")
                self.assertEqual(vary_list[0][0], "id")
                self.assertEqual(2, len(vary_list))


class TestInitializeMethods(unittest.TestCase):
    def test_validate_types(self):
        conf = lambda s: load_conf("validate_types/" + s)

        with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"):
            init.Config(**conf("bad2"))

        with self.assertRaisesRegex(TypeError, "value must be of type <class 'float'>"):
            init.Config(**conf("bad3"))

        with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
            init.Config(**conf("bad4"))

        with self.assertRaisesRegex(
            TypeError, "'variable intensity_noise' value must be of type <class 'list'>"
        ):
            init.Config(**conf("bad5"))

        with self.assertRaisesRegex(ValueError, "'repeat' must be positive"):
            init.Config(**conf("bad6"))

        with self.assertRaisesRegex(
            ValueError, "variable parameter 'intensity_noise' must not be empty"
        ):
            init.Config(**conf("bad7"))

        self.assertIsNone(init.Config(**conf("good")).hr_w)

    def test_ensure_consistency(self):
        conf = lambda s: load_conf("ensure_consistency/" + s)
        with self.assertRaisesRegex(
            MissingParameterError,
            r"1 of '\['t0', 'width'\]' is required and no defaults have been set",
        ):
            init.Config(**conf("bad1"))

        with self.assertRaisesRegex(
            MissingParameterError,
            r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
        ):
            init.Config(**conf("bad2"))

        with self.assertRaisesRegex(
            MissingParameterError,
            r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set",
        ):
            init.Config(**conf("bad3"))

        with self.assertRaisesRegex(
            DuplicateParameterError,
            r"got multiple values for parameter 'width'",
        ):
            init.Config(**conf("bad4"))

        with self.assertRaisesRegex(
            MissingParameterError,
            r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set",
        ):
            init.Config(**conf("bad5"))

        with self.assertRaisesRegex(
            MissingParameterError,
            r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set",
        ):
            init.Config(**conf("bad6"))

        self.assertLessEqual(
            {"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items()
        )

        self.assertIsNone(init.Config(**conf("good4")).gamma)

        self.assertLessEqual(
            {"raman_type": "agrawal"}.items(),
            init.Config(**conf("good2")).__dict__.items(),
        )

        self.assertLessEqual(
            {"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items()
        )

        self.assertLessEqual(
            {"capillary_nested": 0, "capillary_resonance_strengths": []}.items(),
            init.Config(**conf("good4")).__dict__.items(),
        )

        self.assertLessEqual(
            dict(he_mode=(1, 1)).items(),
            init.Config(**conf("good5")).__dict__.items(),
        )

        self.assertLessEqual(
            dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(),
            init.Config(**conf("good5")).__dict__.items(),
        )

    def setup_conf_custom_field(self, path) -> Parameters:
        conf = load_conf(path)
        conf = Parameters(**conf)
        init.build_sim_grid_in_place(conf)
        return conf

    def test_setup_custom_field(self):
        d = np.load("testing/configs/custom_field/init_field.npz")
        t = d["time"]
        field = d["field"]
        conf = self.setup_conf_custom_field("custom_field/no_change")
        result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
            conf
        )
        self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4)
        self.assertTrue(result)

        conf = self.setup_conf_custom_field("custom_field/peak_power")
        result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
            conf
        )
        conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0)
        self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4)
        self.assertTrue(result)
        self.assertNotAlmostEqual(conf.wavelength, 1593e-9)

        conf = self.setup_conf_custom_field("custom_field/mean_power")
        result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
            conf
        )
        self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4)
        self.assertTrue(result)

        conf = self.setup_conf_custom_field("custom_field/recover1")
        result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
            conf
        )
        self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0)
        self.assertTrue(result)

        conf = self.setup_conf_custom_field("custom_field/recover2")
        result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
            conf
        )
        self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0)
        self.assertTrue(result)

        conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
        result = Parameters(**conf)
        self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)

        conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
        conf.wavelength = 1593e-9
        result = Parameters(**conf)

        conf = load_conf("custom_field/wavelength_shift2")
        conf = init.Config(**conf)
        for target, (variable, config) in zip(
            [1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf)
        ):
            init.build_sim_grid_in_place(conf)
            self.assertAlmostEqual(
                units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target
            )
            print(config.wavelength, target)


if __name__ == "__main__":
    conf = conf_maker("validate_types")

    unittest.main()
@@ -1,41 +0,0 @@
import unittest
from scgenerator.physics.pulse import conform_pulse_params


class TestPulseMethods(unittest.TestCase):
    def test_conform_pulse_params(self):
        self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6))
        self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6))
        self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6))
        self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6))

        self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6)))
        self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6)))
        self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6)))
        self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6)))

        with self.assertRaisesRegex(
            TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
        ):
            conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01)
        with self.assertRaisesRegex(
            TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
        ):
            conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01)

        self.assertEqual(
            5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6))
        )
        self.assertEqual(
            5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6))
        )
        self.assertEqual(
            5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6))
        )
        self.assertEqual(
            5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6))
        )


if __name__ == "__main__":
    unittest.main()
@@ -1,67 +0,0 @@
import unittest

import numpy as np
import toml
from scgenerator import utils


def load_conf(name):
    with open("testing/configs/" + name + ".toml") as file:
        conf = toml.load(file)
    return conf


def conf_maker(folder, val=True):
    def conf(name):
        if val:
            return initialize.Config(**load_conf(folder + "/" + name))
        else:
            return initialize.Config(**load_conf(folder + "/" + name))

    return conf


class TestUtilsMethods(unittest.TestCase):
    def test_count_variations(self):
        conf = conf_maker("count_variations")

        for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]:
            self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary")))

    def test_format_value(self):
        values = [
            122e-6,
            True,
            ["raman", "ss"],
            np.arange(5),
            1.123,
            1.1230001,
            0.002e122,
            12.3456e-9,
        ]
        s = [
            "0.000122",
            "True",
            "raman-ss",
            "0-1-2-3-4",
            "1.123",
            "1.1230001",
            "2e+119",
            "1.23456e-08",
        ]

        for value, target in zip(values, s):
            self.assertEqual(target, utils.format_value(value))

    def test_override_config(self):
        conf = conf_maker("override", False)
        old = conf("initial_config")
        new = conf("fiber2")

        over = utils.override_config(vars(new), old)
        self.assertNotIn("input_transmission", over.variable)
        self.assertIsNone(over.input_transmission)


if __name__ == "__main__":
    unittest.main()
54
testing/test_variationer.py
Normal file
54
testing/test_variationer.py
Normal file
@@ -0,0 +1,54 @@
import scgenerator as sc


def test_descriptor():
    # Same branch
    var1 = sc.VariationDescriptor(
        raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
    )
    var2 = sc.VariationDescriptor(
        raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
    )
    assert var1.branch.identifier == "b_0"
    assert var1.identifier != var1.branch.identifier
    assert var1.identifier != var2.identifier
    assert var1.branch.identifier == var2.branch.identifier

    # different branch
    var3 = sc.VariationDescriptor(
        raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]]
    )
    assert var1.branch.identifier != var3.branch.identifier
    assert var1.formatted_descriptor() != var2.formatted_descriptor()
    assert var1.formatted_descriptor() != var3.formatted_descriptor()


def test_variationer():
    var = sc.Variationer(
        [
            dict(a=[1, 2], num=[0, 1, 2]),
            [dict(b=["000", "111"], c=["a", "-1"])],
            dict(),
            dict(),
            [dict(aaa=[True, False], bb=[1, 3])],
        ]
    )
    assert var.var_num(0) == 6
    assert var.var_num(1) == 12
    assert var.var_num() == 24

    cfg = dict(bb=None)
    branches = set()
    for descr in var.iterate():
        assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1])
        branches.add(descr.branch.identifier)
    assert len(branches) == 8


def main():
    test_descriptor()


if __name__ == "__main__":
    main()