removed a bunch of stuff
Removed: - Variationer - FileConfiguration - Scripts (slurm, ...) - CLI
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
.DS_store
|
||||
.idea
|
||||
**/*.npy
|
||||
.conda-env
|
||||
|
||||
pyrightconfig.json
|
||||
|
||||
@@ -17,15 +18,8 @@ __pycache__
|
||||
tmp*
|
||||
paths.json
|
||||
scgenerator_log*
|
||||
scgenerator.log
|
||||
.scgenerator_tmp
|
||||
sc-*.log
|
||||
|
||||
.vscode
|
||||
|
||||
|
||||
# latex
|
||||
*.aux
|
||||
*.fdb_latexmk
|
||||
*.fls
|
||||
*.log
|
||||
*.synctex.gz
|
||||
|
||||
@@ -27,6 +27,7 @@ dependencies = [
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
ignore = ["E741"]
|
||||
|
||||
[tool.ruff.pydocstyle]
|
||||
convention = "numpy"
|
||||
@@ -34,3 +35,6 @@ convention = "numpy"
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
# # flake8: noqa
|
||||
# isort: skip_file
|
||||
# ruff: noqa
|
||||
from scgenerator import math, operators, plotting
|
||||
from scgenerator.helpers import *
|
||||
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
|
||||
from scgenerator.parameter import FileConfiguration, Parameters
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||
from scgenerator.physics.units import PlotRange
|
||||
from scgenerator.solver import integrate, solve43
|
||||
from scgenerator.utils import (Paths, _open_config, open_single_config,
|
||||
simulations_list)
|
||||
|
||||
@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
|
||||
Rule("w_num", len, ["w"]),
|
||||
Rule("dw", lambda w: w[1] - w[0]),
|
||||
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
|
||||
Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
|
||||
Rule("wavelength_window", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
|
||||
# Pulse
|
||||
Rule("field_0", pulse.finalize_pulse),
|
||||
Rule(["input_time", "input_field"], pulse.load_custom_field),
|
||||
@@ -393,7 +393,7 @@ default_rules: list[Rule] = [
|
||||
Rule(
|
||||
"V_eff_arr",
|
||||
fiber.V_eff_step_index,
|
||||
["l", "core_radius", "numerical_aperture", "interpolation_range"],
|
||||
["l", "core_radius", "numerical_aperture", "wavelength_window"],
|
||||
),
|
||||
Rule("n2", materials.gas_n2),
|
||||
Rule("n2", lambda: 2.2e-20, priorities=-1),
|
||||
@@ -403,7 +403,7 @@ default_rules: list[Rule] = [
|
||||
# Raman
|
||||
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
|
||||
Rule("raman_fraction", fiber.raman_fraction),
|
||||
Rule("raman_fraction", lambda:0, priorities=-1),
|
||||
Rule("raman_fraction", lambda: 0, priorities=-1),
|
||||
# loss
|
||||
Rule("alpha_arr", fiber.scalar_loss),
|
||||
Rule("alpha_arr", fiber.safe_capillary_loss, conditions=dict(loss="capillary")),
|
||||
@@ -434,7 +434,7 @@ envelope_rules = default_rules + [
|
||||
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
||||
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
||||
Rule(
|
||||
["wl_for_disp", "beta2_arr", "interpolation_range"],
|
||||
["wl_for_disp", "beta2_arr", "wavelength_window"],
|
||||
fiber.load_custom_dispersion,
|
||||
priorities=[2, 2, 2],
|
||||
),
|
||||
@@ -442,7 +442,7 @@ envelope_rules = default_rules + [
|
||||
Rule("gamma_op", operators.variable_gamma, priorities=2),
|
||||
Rule("gamma_op", operators.constant_quantity, ["gamma_arr"], priorities=1),
|
||||
Rule("gamma_op", lambda w_num, gamma: operators.constant_quantity(np.ones(w_num) * gamma)),
|
||||
Rule("gamma_op", operators.no_op_freq, priorities=-1),
|
||||
Rule("gamma_op", lambda: operators.constant_quantity(0.0), priorities=-1),
|
||||
Rule("ss_op", lambda w_c, w0: operators.constant_quantity(w_c / w0)),
|
||||
Rule("ss_op", lambda: operators.constant_quantity(0), priorities=-1),
|
||||
Rule("spm_op", operators.envelope_spm),
|
||||
|
||||
@@ -48,7 +48,6 @@ def configure_logger(logger: logging.Logger):
|
||||
updated logger
|
||||
"""
|
||||
if not hasattr(logger, "already_configured"):
|
||||
|
||||
print_lvl = lvl_map.get(log_print_level(), logging.NOTSET)
|
||||
file_lvl = lvl_map.get(log_file_level(), logging.NOTSET)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import math
|
||||
from scgenerator.logger import get_logger
|
||||
from scgenerator.physics import fiber, materials, plasma, pulse, units
|
||||
@@ -266,8 +267,7 @@ def constant_wave_vector(
|
||||
##################################################
|
||||
|
||||
|
||||
def envelope_raman(hr_w:np.ndarra, raman_fraction: float) -> FieldOperator:
|
||||
|
||||
def envelope_raman(hr_w: np.ndarra, raman_fraction: float) -> FieldOperator:
|
||||
def operate(field: np.ndarray, z: float) -> np.ndarray:
|
||||
return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field)))
|
||||
|
||||
@@ -336,7 +336,6 @@ def ionization(
|
||||
N0 = number_density(z)
|
||||
plasma_info = plasma_obj(field, N0)
|
||||
|
||||
|
||||
# state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
|
||||
# state.stats["electron_density"] = plasma_info.electron_density[-1]
|
||||
return plasma_info.polarization
|
||||
|
||||
@@ -1,32 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as datetime_module
|
||||
import enum
|
||||
import os
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, fields
|
||||
from functools import lru_cache, wraps
|
||||
from math import isnan
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar
|
||||
from typing import (Any, Callable, ClassVar, Iterable, Iterator, Set, Type,
|
||||
TypeVar)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import env, utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
|
||||
from scgenerator import utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||
from scgenerator.errors import EvaluatorError
|
||||
from scgenerator.evaluator import Evaluator
|
||||
from scgenerator.logger import get_logger
|
||||
from scgenerator.operators import Qualifier, SpecOperator
|
||||
from scgenerator.utils import fiber_folder, update_path_name
|
||||
from scgenerator.variationer import VariationDescriptor, Variationer
|
||||
from scgenerator.utils import update_path_name
|
||||
|
||||
T = TypeVar("T")
|
||||
DISPLAY_INFO = {}
|
||||
|
||||
|
||||
def format_value(name: str, value) -> str:
|
||||
if value is True or value is False:
|
||||
return str(value)
|
||||
elif isinstance(value, (float, int)):
|
||||
try:
|
||||
return DISPLAY_INFO[name](value)
|
||||
except KeyError:
|
||||
return format(value, ".9g")
|
||||
elif isinstance(value, np.ndarray):
|
||||
return np.array2string(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return "-".join([str(v) for v in value])
|
||||
elif isinstance(value, str):
|
||||
p = Path(value)
|
||||
if p.exists():
|
||||
return p.stem
|
||||
elif callable(value):
|
||||
return getattr(value, "__name__", repr(value))
|
||||
return str(value)
|
||||
|
||||
|
||||
# Validator
|
||||
|
||||
|
||||
@lru_cache
|
||||
def type_checker(*types):
|
||||
def _type_checker_wrapper(validator, n=None):
|
||||
@@ -224,7 +242,7 @@ class Parameter:
|
||||
pass
|
||||
if self.default is not None:
|
||||
Evaluator.register_default_param(self.name, self.default)
|
||||
VariationDescriptor.register_formatter(self.name, self.display)
|
||||
DISPLAY_INFO[self.name] = self.display
|
||||
|
||||
def __get__(self, instance: Parameters, owner):
|
||||
if instance is None:
|
||||
@@ -382,7 +400,7 @@ class Parameters:
|
||||
dt: float = Parameter(in_range_excl(0, 10e-15))
|
||||
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
|
||||
step_size: float = Parameter(non_negative(float, int), default=0)
|
||||
interpolation_range: tuple[float, float] = Parameter(
|
||||
wavelength_window: tuple[float, float] = Parameter(
|
||||
validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9)))
|
||||
)
|
||||
interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18)))
|
||||
@@ -469,11 +487,7 @@ class Parameters:
|
||||
exclude = exclude or []
|
||||
if isinstance(exclude, str):
|
||||
exclude = [exclude]
|
||||
p_pairs = [
|
||||
(k, VariationDescriptor.format_value(k, getattr(self, k)))
|
||||
for k in params
|
||||
if k not in exclude
|
||||
]
|
||||
p_pairs = [(k, format_value(k, getattr(self, k))) for k in params if k not in exclude]
|
||||
max_left = max(len(el[0]) for el in p_pairs)
|
||||
max_right = max(len(el[1]) for el in p_pairs)
|
||||
return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
|
||||
@@ -544,262 +558,6 @@ class Parameters:
|
||||
return None
|
||||
|
||||
|
||||
class AbstractConfiguration:
|
||||
fiber_paths: list[Path]
|
||||
num_sim: int
|
||||
total_num_steps: int
|
||||
worker_num: int
|
||||
final_path: Path
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def save_parameters(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FileConfiguration(AbstractConfiguration):
|
||||
"""
|
||||
Primary role is to load the final config file of the simulation and deduce every
|
||||
simulatin that has to happen. Iterating through the Configuration obj yields a list of
|
||||
parameter names and values that change throughout the simulation as well as parameter
|
||||
obj with the output path of the simulation saved in its output_path attribute.
|
||||
"""
|
||||
|
||||
fiber_configs: list[utils.SubConfig]
|
||||
master_config_dict: dict[str, Any]
|
||||
num_fibers: int
|
||||
repeat: int
|
||||
z_num: int
|
||||
overwrite: bool
|
||||
all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class __SimConfig:
|
||||
descriptor: VariationDescriptor
|
||||
config: dict[str, Any]
|
||||
output_path: Path
|
||||
|
||||
@property
|
||||
def sim_num(self) -> int:
|
||||
return len(self.descriptor.index)
|
||||
|
||||
class State(enum.Enum):
|
||||
COMPLETE = enum.auto()
|
||||
PARTIAL = enum.auto()
|
||||
ABSENT = enum.auto()
|
||||
|
||||
class Action(enum.Enum):
|
||||
RUN = enum.auto()
|
||||
WAIT = enum.auto()
|
||||
SKIP = enum.auto()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: os.PathLike,
|
||||
overwrite: bool = True,
|
||||
wait: bool = False,
|
||||
skip_callback: Callable[[int], None] = None,
|
||||
final_output_path: os.PathLike = None,
|
||||
):
|
||||
self.logger = get_logger(__name__)
|
||||
self.wait = wait
|
||||
|
||||
self.overwrite = overwrite
|
||||
self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
|
||||
self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
|
||||
if final_output_path is not None:
|
||||
self.final_path = final_output_path
|
||||
self.final_path = utils.ensure_folder(
|
||||
Path(self.final_path),
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
self.master_config_dict = self.fiber_configs[0].fixed | {
|
||||
k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items()
|
||||
}
|
||||
self.name = self.final_path.name
|
||||
self.z_num = 0
|
||||
self.total_num_steps = 0
|
||||
self.fiber_paths = []
|
||||
self.all_configs = {}
|
||||
self.skip_callback = skip_callback
|
||||
self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2))
|
||||
self.repeat = self.master_config_dict.get("repeat", 1)
|
||||
self.variationer = Variationer()
|
||||
|
||||
fiber_names = set()
|
||||
self.num_fibers = 0
|
||||
for i, config in enumerate(self.fiber_configs):
|
||||
config.fixed.setdefault("name", Parameters.name.default)
|
||||
self.z_num += config.fixed["z_num"]
|
||||
fiber_names.add(config.fixed["name"])
|
||||
self.variationer.append(config.variable)
|
||||
self.fiber_paths.append(
|
||||
utils.ensure_folder(
|
||||
self.final_path / fiber_folder(i, self.name, Path(config.fixed["name"]).name),
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
)
|
||||
self.__validate_variable(config.variable)
|
||||
self.num_fibers += 1
|
||||
Evaluator.evaluate_default(
|
||||
self.master_config_dict
|
||||
| config.fixed
|
||||
| {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()},
|
||||
True,
|
||||
)
|
||||
self.num_sim = self.variationer.var_num()
|
||||
self.total_num_steps = sum(
|
||||
config.fixed["z_num"] * self.variationer.var_num(i)
|
||||
for i, config in enumerate(self.fiber_configs)
|
||||
)
|
||||
|
||||
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
|
||||
for vary_dict in vary_dict_list:
|
||||
for k, v in vary_dict.items():
|
||||
p: Parameter = getattr(Parameters, k)
|
||||
validator_list(p.validator)("variable " + k, v)
|
||||
if k not in VALID_VARIABLE:
|
||||
raise TypeError(f"{k!r} is not a valid variable parameter")
|
||||
if len(v) == 0:
|
||||
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||
for i in range(self.num_fibers):
|
||||
yield from self.iterate_single_fiber(i)
|
||||
|
||||
def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
||||
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : int
|
||||
which fiber to iterate over
|
||||
|
||||
Yields
|
||||
-------
|
||||
__SimConfig
|
||||
configuration obj
|
||||
"""
|
||||
if index < 0:
|
||||
index = self.num_fibers + index
|
||||
sim_dict: dict[Path, FileConfiguration.__SimConfig] = {}
|
||||
for descriptor in self.variationer.iterate(index):
|
||||
cfg = descriptor.update_config(self.fiber_configs[index].fixed)
|
||||
if index > 0:
|
||||
cfg["prev_data_dir"] = str(
|
||||
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
|
||||
)
|
||||
p = utils.ensure_folder(
|
||||
self.fiber_paths[index] / descriptor.formatted_descriptor(True),
|
||||
not self.overwrite,
|
||||
False,
|
||||
)
|
||||
cfg["output_path"] = p
|
||||
sim_config = self.__SimConfig(descriptor, cfg, p)
|
||||
sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
|
||||
while len(sim_dict) > 0:
|
||||
for data_dir, sim_config in sim_dict.items():
|
||||
task, config_dict = self.__decide(sim_config)
|
||||
if task == self.Action.RUN:
|
||||
sim_dict.pop(data_dir)
|
||||
param_dict = sim_config.config
|
||||
yield sim_config.descriptor, Parameters(**param_dict)
|
||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||
self.skip_callback(config_dict["recovery_last_stored"])
|
||||
break
|
||||
elif task == self.Action.SKIP:
|
||||
sim_dict.pop(data_dir)
|
||||
self.logger.debug(f"skipping {data_dir} as it is already complete")
|
||||
if self.skip_callback is not None:
|
||||
self.skip_callback(config_dict["z_num"])
|
||||
break
|
||||
else:
|
||||
self.logger.debug("sleeping while waiting for other simulations to complete")
|
||||
time.sleep(1)
|
||||
|
||||
def __decide(
|
||||
self, sim_config: "FileConfiguration.__SimConfig"
|
||||
) -> tuple["FileConfiguration.Action", dict[str, Any]]:
|
||||
"""decide what to to with a particular simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sim_config : __SimConfig
|
||||
|
||||
Returns
|
||||
-------
|
||||
str : Configuration.Action
|
||||
what to do
|
||||
config_dict : dict[str, Any]
|
||||
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
||||
gets set if the simulation is partially completed
|
||||
"""
|
||||
if not self.wait:
|
||||
return self.Action.RUN, sim_config.config
|
||||
out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
|
||||
if out_status == self.State.COMPLETE:
|
||||
return self.Action.SKIP, sim_config.config
|
||||
elif out_status == self.State.PARTIAL:
|
||||
sim_config.config["recovery_data_dir"] = str(sim_config.output_path)
|
||||
sim_config.config["recovery_last_stored"] = num
|
||||
return self.Action.RUN, sim_config.config
|
||||
|
||||
if "prev_data_dir" in sim_config.config:
|
||||
prev_data_path = Path(sim_config.config["prev_data_dir"])
|
||||
prev_status, _ = self.sim_status(prev_data_path)
|
||||
if prev_status in {self.State.PARTIAL, self.State.ABSENT}:
|
||||
return self.Action.WAIT, sim_config.config
|
||||
return self.Action.RUN, sim_config.config
|
||||
|
||||
def sim_status(
|
||||
self, data_dir: Path, config_dict: dict[str, Any] = None
|
||||
) -> tuple["FileConfiguration.State", int]:
|
||||
"""returns the status of a simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_dir : Path
|
||||
directory where simulation data is to be saved
|
||||
config_dict : dict[str, Any], optional
|
||||
configuration of the simulation. If None, will attempt to load
|
||||
the params.toml file if present, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
Configuration.State
|
||||
status
|
||||
"""
|
||||
num = utils.find_last_spectrum_num(data_dir)
|
||||
if config_dict is None:
|
||||
try:
|
||||
config_dict = utils.load_toml(data_dir / PARAM_FN)
|
||||
except FileNotFoundError:
|
||||
self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}")
|
||||
return self.State.ABSENT, 0
|
||||
if num == config_dict["z_num"] - 1:
|
||||
return self.State.COMPLETE, num
|
||||
elif config_dict["z_num"] - 1 > num > 0:
|
||||
return self.State.PARTIAL, num
|
||||
elif num == 0:
|
||||
return self.State.ABSENT, 0
|
||||
else:
|
||||
raise ValueError(f"Too many spectra in {data_dir}")
|
||||
|
||||
def save_parameters(self):
|
||||
os.makedirs(self.final_path, exist_ok=True)
|
||||
cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs]
|
||||
utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs))
|
||||
|
||||
@property
|
||||
def first(self) -> Parameters:
|
||||
for _, param in self:
|
||||
return param
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
numero = type_checker(int)
|
||||
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import typing
|
||||
from collections import abc
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from scgenerator.env import pbar_policy
|
||||
|
||||
T_ = typing.TypeVar("T_")
|
||||
|
||||
|
||||
class PBars:
|
||||
def __init__(
|
||||
self,
|
||||
task: Union[int, Iterable[T_]],
|
||||
desc: str,
|
||||
num_sub_bars: int = 0,
|
||||
head_kwargs=None,
|
||||
worker_kwargs=None,
|
||||
) -> "PBars":
|
||||
"""creates a PBars obj
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : int | Iterable
|
||||
if int : total length of the main task
|
||||
if Iterable : behaves like tqdm
|
||||
desc : str
|
||||
description of the main task
|
||||
num_sub_bars : int
|
||||
number of sub-tasks
|
||||
|
||||
"""
|
||||
self.id = random.randint(100000, 999999)
|
||||
try:
|
||||
self.width = os.get_terminal_size().columns
|
||||
except OSError:
|
||||
self.width = 120
|
||||
if isinstance(task, abc.Iterable):
|
||||
self.iterator: Iterable[T_] = iter(task)
|
||||
self.num_tot: int = len(task)
|
||||
else:
|
||||
self.num_tot: int = task
|
||||
self.iterator = None
|
||||
|
||||
self.policy = pbar_policy()
|
||||
if head_kwargs is None:
|
||||
head_kwargs = dict()
|
||||
if worker_kwargs is None:
|
||||
worker_kwargs = dict(
|
||||
total=1,
|
||||
desc="Worker {worker_id}",
|
||||
bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
|
||||
)
|
||||
if "print" not in pbar_policy():
|
||||
head_kwargs["file"] = worker_kwargs["file"] = StringIO()
|
||||
self.width = 80
|
||||
head_kwargs["desc"] = desc
|
||||
self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)]
|
||||
for i in range(1, num_sub_bars + 1):
|
||||
kwargs = {k: v for k, v in worker_kwargs.items()}
|
||||
if "desc" in kwargs:
|
||||
kwargs["desc"] = kwargs["desc"].format(worker_id=i)
|
||||
self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs))
|
||||
self.print_path = Path(
|
||||
f"progress {self.pbars[0].desc.replace('/', '')} {self.id}"
|
||||
).resolve()
|
||||
self.close_ev = threading.Event()
|
||||
if "file" in self.policy:
|
||||
self.thread = threading.Thread(target=self.print_worker, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def print(self):
|
||||
if "file" not in self.policy:
|
||||
return
|
||||
s = []
|
||||
for pbar in self.pbars:
|
||||
s.append(str(pbar))
|
||||
self.print_path.write_text("\n".join(s))
|
||||
|
||||
def print_worker(self):
|
||||
while True:
|
||||
if self.close_ev.wait(2.0):
|
||||
return
|
||||
self.print()
|
||||
|
||||
def __iter__(self):
|
||||
with self as pb:
|
||||
for thing in self.iterator:
|
||||
yield thing
|
||||
pb.update()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.pbars[key]
|
||||
|
||||
def update(self, i=None, value=1):
|
||||
if i is None:
|
||||
for pbar in self.pbars[1:]:
|
||||
pbar.update(value)
|
||||
elif i > 0:
|
||||
self.pbars[i].update(value)
|
||||
self.pbars[0].update()
|
||||
|
||||
def append(self, pbar: tqdm):
|
||||
self.pbars.append(pbar)
|
||||
|
||||
def reset(self, i):
|
||||
self.pbars[i].update(-self.pbars[i].n)
|
||||
self.print()
|
||||
|
||||
def close(self):
|
||||
self.print()
|
||||
self.close_ev.set()
|
||||
if "file" in self.policy:
|
||||
self.thread.join()
|
||||
for pbar in self.pbars:
|
||||
pbar.close()
|
||||
|
||||
|
||||
class ProgressBarActor:
|
||||
def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
|
||||
self.counters = [0 for _ in range(num_workers + 1)]
|
||||
self.p_bars = PBars(
|
||||
num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
|
||||
)
|
||||
|
||||
def update(self, worker_id: int, rel_pos: float = None) -> None:
|
||||
"""update a counter
|
||||
|
||||
Parameters
|
||||
----------
|
||||
worker_id : int
|
||||
id of the worker. 0 is the overall progress
|
||||
rel_pos : float, optional
|
||||
if None, increase the counter by one, if set, will set
|
||||
the counter to the specified value (instead of incrementing it), by default None
|
||||
"""
|
||||
if rel_pos is None:
|
||||
self.counters[worker_id] += 1
|
||||
else:
|
||||
self.counters[worker_id] = rel_pos
|
||||
|
||||
def update_pbars(self):
|
||||
for counter, pbar in zip(self.counters, self.p_bars.pbars):
|
||||
pbar.update(counter - pbar.n)
|
||||
|
||||
def close(self):
|
||||
self.p_bars.close()
|
||||
|
||||
|
||||
def progress_worker(
|
||||
name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
|
||||
):
|
||||
"""keeps track of progress on a separate thread
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_steps : int
|
||||
total number of steps, used for the main progress bar (position 0)
|
||||
progress_queue : multiprocessing.Queue
|
||||
values are either
|
||||
Literal[0] : stop the worker and close the progress bars
|
||||
tuple[int, float] : worker id and relative progress between 0 and 1
|
||||
"""
|
||||
with PBars(
|
||||
num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
|
||||
) as pbars:
|
||||
while True:
|
||||
raw = progress_queue.get()
|
||||
if raw == 0:
|
||||
return
|
||||
i, rel_pos = raw
|
||||
if i > 0:
|
||||
pbars[i].update(rel_pos - pbars[i].n)
|
||||
pbars[0].update()
|
||||
elif i == 0:
|
||||
pbars[0].update(rel_pos)
|
||||
@@ -18,7 +18,7 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def lambda_for_envelope_dispersion(
|
||||
l: np.ndarray, interpolation_range: tuple[float, float]
|
||||
l: np.ndarray, wavelength_window: tuple[float, float]
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Returns a wl vector for dispersion calculation in envelope mode
|
||||
|
||||
@@ -30,10 +30,10 @@ def lambda_for_envelope_dispersion(
|
||||
np.ndarray
|
||||
indices of the original l where the values are valid (i.e. without the two extra on each side)
|
||||
"""
|
||||
su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0]
|
||||
if l[su].min() > 1.01 * interpolation_range[0]:
|
||||
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
|
||||
if l[su].min() > 1.01 * wavelength_window[0]:
|
||||
raise ValueError(
|
||||
f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. "
|
||||
f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. "
|
||||
f"Minimum of grid is {1e9*l[su].min():.1f}nm. Try a finer grid"
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ def lambda_for_envelope_dispersion(
|
||||
|
||||
|
||||
def lambda_for_full_field_dispersion(
|
||||
l: np.ndarray, interpolation_range: tuple[float, float]
|
||||
l: np.ndarray, wavelength_window: tuple[float, float]
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Returns a wl vector for dispersion calculation in full field mode
|
||||
|
||||
@@ -60,10 +60,10 @@ def lambda_for_full_field_dispersion(
|
||||
np.ndarray
|
||||
indices of the original l where the values are valid (i.e. without the two extra on each side)
|
||||
"""
|
||||
su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0]
|
||||
if l[su].min() > 1.01 * interpolation_range[0]:
|
||||
su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
|
||||
if l[su].min() > 1.01 * wavelength_window[0]:
|
||||
raise ValueError(
|
||||
f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. "
|
||||
f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. "
|
||||
"try a finer grid"
|
||||
)
|
||||
fu = np.concatenate((su[:2] - 2, su, su[-2:] + 2))
|
||||
@@ -385,7 +385,7 @@ def V_eff_step_index(
|
||||
l: T,
|
||||
core_radius: float,
|
||||
numerical_aperture: float,
|
||||
interpolation_range: tuple[float, float] = None,
|
||||
wavelength_window: tuple[float, float] = None,
|
||||
) -> T:
|
||||
"""computes the V parameter of a step-index fiber
|
||||
|
||||
@@ -397,7 +397,7 @@ def V_eff_step_index(
|
||||
radius of the core
|
||||
numerical_aperture : float
|
||||
as a decimal number
|
||||
interpolation_range : tuple[float, float], optional
|
||||
wavelength_window : tuple[float, float], optional
|
||||
when provided, only computes V over this range, wavelengths outside this range will
|
||||
yield V=inf, by default None
|
||||
|
||||
@@ -407,8 +407,8 @@ def V_eff_step_index(
|
||||
V parameter
|
||||
"""
|
||||
pi2cn = 2 * pi * core_radius * numerical_aperture
|
||||
if interpolation_range is not None and isinstance(l, np.ndarray):
|
||||
low, high = interpolation_range
|
||||
if wavelength_window is not None and isinstance(l, np.ndarray):
|
||||
low, high = wavelength_window
|
||||
l = np.where((l >= low) & (l <= high), l, np.inf)
|
||||
return pi2cn / l
|
||||
|
||||
@@ -805,7 +805,6 @@ def delayed_raman_w(t: np.ndarray, raman_type: str) -> tuple[np.ndarray, float]:
|
||||
return hr_w, raman_fraction(raman_type)
|
||||
|
||||
|
||||
|
||||
def fast_poly_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
|
||||
"""
|
||||
dispersive operator
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -106,14 +106,16 @@ class Gas:
|
||||
chi3_0: float
|
||||
ionization_energy: float | None
|
||||
|
||||
_raw_sellmeier: dict[str, Any]
|
||||
|
||||
def __init__(self, gas_name: str):
|
||||
self.name = gas_name
|
||||
self.mat_dico = utils.load_material_dico(gas_name)
|
||||
self.atomic_mass = self.mat_dico["atomic_mass"]
|
||||
self.atomic_number = self.mat_dico["atomic_number"]
|
||||
self.ionization_energy = self.mat_dico.get("ionization_energy")
|
||||
self._raw_sellmeier = utils.load_material_dico(gas_name)
|
||||
self.atomic_mass = self._raw_sellmeier["atomic_mass"]
|
||||
self.atomic_number = self._raw_sellmeier["atomic_number"]
|
||||
self.ionization_energy = self._raw_sellmeier.get("ionization_energy")
|
||||
|
||||
s = self.mat_dico.get("sellmeier", {})
|
||||
s = self._raw_sellmeier.get("sellmeier", {})
|
||||
self.sellmeier = Sellmeier(
|
||||
**{
|
||||
newk: s.get(k, None)
|
||||
@@ -124,7 +126,7 @@ class Gas:
|
||||
if k in s
|
||||
}
|
||||
)
|
||||
kerr = self.mat_dico["kerr"]
|
||||
kerr = self._raw_sellmeier["kerr"]
|
||||
n2_0 = kerr["n2"]
|
||||
self._kerr_wl = kerr.get("wavelength", 800e-9)
|
||||
self.chi3_0 = (
|
||||
@@ -212,18 +214,23 @@ class Gas:
|
||||
|
||||
Raises
|
||||
----------
|
||||
ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution
|
||||
ValueError : Since the Van der Waals equation is a cubic one, there could be more than one
|
||||
real, positive solution
|
||||
"""
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if pressure == 0:
|
||||
return 0
|
||||
a = self.mat_dico.get("a", 0)
|
||||
b = self.mat_dico.get("b", 0)
|
||||
pressure = self.mat_dico["sellmeier"].get("P0", 101325) if pressure is None else pressure
|
||||
a = self._raw_sellmeier.get("a", 0)
|
||||
b = self._raw_sellmeier.get("b", 0)
|
||||
pressure = (
|
||||
self._raw_sellmeier["sellmeier"].get("P0", 101325) if pressure is None else pressure
|
||||
)
|
||||
temperature = (
|
||||
self.mat_dico["sellmeier"].get("T0", 273.15) if temperature is None else temperature
|
||||
self._raw_sellmeier["sellmeier"].get("T0", 273.15)
|
||||
if temperature is None
|
||||
else temperature
|
||||
)
|
||||
ap = a / NA**2
|
||||
bp = b / NA
|
||||
@@ -302,10 +309,10 @@ class Gas:
|
||||
return Z**3 / (16 * ns**4) * 5.14220670712125e11
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.mat_dico.get(key, default)
|
||||
return self._raw_sellmeier.get(key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.mat_dico[key]
|
||||
return self._raw_sellmeier[key]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.name!r})"
|
||||
@@ -317,7 +324,7 @@ def n_gas_2(wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature
|
||||
return Sellmeier.load(gas_name).n_gas_2(wl_for_disp, temperature, pressure)
|
||||
|
||||
|
||||
def pressure_from_gradient(ratio, p0, p1):
|
||||
def pressure_from_gradient(ratio: float, p0: float, p1: float) -> float:
|
||||
"""returns the pressure as function of distance with eq. 20 in Markos et al. (2017)
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -1,346 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from cycler import cycler
|
||||
from tqdm import tqdm
|
||||
|
||||
from scgenerator import env, math
|
||||
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN
|
||||
from scgenerator.parameter import FileConfiguration, Parameters
|
||||
from scgenerator.physics import fiber, units
|
||||
from scgenerator.plotting import plot_setup, transform_2D_propagation, get_extent
|
||||
from scgenerator.spectra import SimulationSeries
|
||||
from scgenerator.utils import _open_config, auto_crop, save_toml, simulations_list, load_toml, load_spectrum
|
||||
|
||||
|
||||
def fingerprint(params: Parameters):
|
||||
h1 = hash(params.field_0.tobytes())
|
||||
h2 = tuple(params.beta2_coefficients)
|
||||
return h1, h2
|
||||
|
||||
|
||||
def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
|
||||
for k, v in opts.items():
|
||||
if k in ["skip"]:
|
||||
opts[k] = int(v)
|
||||
if v == "True":
|
||||
opts[k] = True
|
||||
elif v == "False":
|
||||
opts[k] = False
|
||||
dir_list = simulations_list(sim_dir)
|
||||
if len(dir_list) == 0:
|
||||
dir_list = [sim_dir]
|
||||
limits = [
|
||||
tuple(func(el) for func, el in zip([float, float, str], lim.split(","))) for lim in limits
|
||||
]
|
||||
with tqdm(total=len(dir_list) * max(1, len(limits))) as bar:
|
||||
for p in dir_list:
|
||||
pulse = SimulationSeries(p)
|
||||
if not limits:
|
||||
limits = [
|
||||
(
|
||||
pulse.params.interpolation_range[0] * 1e9,
|
||||
pulse.params.interpolation_range[1] * 1e9,
|
||||
"nm",
|
||||
)
|
||||
]
|
||||
for left, right, unit in limits:
|
||||
path, fig, ax = plot_setup(
|
||||
pulse.path.parent
|
||||
/ (
|
||||
pulse.path.name
|
||||
+ PARAM_SEPARATOR
|
||||
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
|
||||
)
|
||||
)
|
||||
fig.suptitle(p.name)
|
||||
pulse.plot_2D(
|
||||
left,
|
||||
right,
|
||||
unit,
|
||||
ax,
|
||||
**opts,
|
||||
)
|
||||
bar.update()
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
fig.savefig(path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_init_field_spec(
|
||||
config_path: Path,
|
||||
lim_t: tuple[float, float] = None,
|
||||
lim_l: tuple[float, float] = None,
|
||||
):
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
|
||||
all_labels = []
|
||||
already_plotted = set()
|
||||
for style, lbl, params in plot_helper(config_path):
|
||||
if (bbb := hash(params.field_0.tobytes())) not in already_plotted:
|
||||
already_plotted.add(bbb)
|
||||
else:
|
||||
continue
|
||||
|
||||
lbl = plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params)
|
||||
all_labels.append(lbl)
|
||||
finish_plot(fig, left, right, all_labels, params)
|
||||
|
||||
|
||||
def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
|
||||
left.grid()
|
||||
right.grid()
|
||||
all_labels = []
|
||||
already_plotted = set()
|
||||
loss_ax = None
|
||||
plt.sca(left)
|
||||
for style, lbl, params in plot_helper(config_path):
|
||||
if params.alpha_arr is not None and loss_ax is None:
|
||||
loss_ax = right.twinx()
|
||||
if (bbb := tuple(params.beta2_coefficients)) not in already_plotted:
|
||||
already_plotted.add(bbb)
|
||||
else:
|
||||
continue
|
||||
|
||||
lbl = plot_1_dispersion(lim, left, right, style, lbl, params, loss_ax)
|
||||
all_labels.append(lbl)
|
||||
finish_plot(fig, right, all_labels, params)
|
||||
|
||||
|
||||
def plot_init(
|
||||
config_path: Path,
|
||||
lim_field: tuple[float, float] = None,
|
||||
lim_spec: tuple[float, float] = None,
|
||||
lim_disp: tuple[float, float] = None,
|
||||
):
|
||||
fig, ((tl, tr), (bl, br)) = plt.subplots(2, 2, figsize=(14, 10))
|
||||
loss_ax = None
|
||||
tl.grid()
|
||||
tr.grid()
|
||||
all_labels = []
|
||||
already_plotted = set()
|
||||
for style, lbl, params in plot_helper(config_path):
|
||||
if params.alpha_arr is not None and loss_ax is None:
|
||||
loss_ax = tr.twinx()
|
||||
if (fp := fingerprint(params)) not in already_plotted:
|
||||
already_plotted.add(fp)
|
||||
else:
|
||||
continue
|
||||
lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params, loss_ax)
|
||||
lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params)
|
||||
all_labels.append(lbl)
|
||||
print(params.pretty_str(exclude="beta2_coefficients"))
|
||||
finish_plot(fig, tr, all_labels, params)
|
||||
|
||||
|
||||
def plot_1_init_spec_field(
|
||||
lim_t: Optional[tuple[float, float]],
|
||||
lim_l: Optional[tuple[float, float]],
|
||||
left: plt.Axes,
|
||||
right: plt.Axes,
|
||||
style: dict[str, Any],
|
||||
lbl: str,
|
||||
params: Parameters,
|
||||
):
|
||||
field = math.abs2(params.field_0)
|
||||
spec = math.abs2(params.spec_0)
|
||||
t = units.fs.inv(params.t)
|
||||
wl = units.nm.inv(params.w)
|
||||
|
||||
lbl += f" max at {wl[spec.argmax()]:.1f} nm"
|
||||
|
||||
mt = np.ones_like(t, dtype=bool)
|
||||
if lim_t is not None:
|
||||
mt &= t >= lim_t[0]
|
||||
mt &= t <= lim_t[1]
|
||||
else:
|
||||
mt = auto_crop(t, field)
|
||||
ml = np.ones_like(wl, dtype=bool)
|
||||
if lim_l is not None:
|
||||
ml &= wl >= lim_l[0]
|
||||
ml &= wl <= lim_l[1]
|
||||
else:
|
||||
ml = auto_crop(wl, spec)
|
||||
|
||||
left.plot(t[mt], field[mt])
|
||||
right.plot(wl[ml], spec[ml], label=" ", **style)
|
||||
return lbl
|
||||
|
||||
|
||||
def plot_1_dispersion(
|
||||
lim: Optional[tuple[float, float]],
|
||||
left: plt.Axes,
|
||||
right: plt.Axes,
|
||||
style: dict[str, Any],
|
||||
lbl: list[str],
|
||||
params: Parameters,
|
||||
loss: plt.Axes = None,
|
||||
):
|
||||
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients)
|
||||
wl = units.m.inv(params.w)
|
||||
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
|
||||
|
||||
zdw = math.all_zeros(wl, beta_arr)
|
||||
zdw = zdw[(zdw >= params.interpolation_range[0]) & (zdw <= params.interpolation_range[1])]
|
||||
if len(zdw) > 0:
|
||||
zdw = zdw[np.argmin(abs(zdw - params.wavelength))]
|
||||
lbl += f" ZDW at {zdw*1e9:.1f}nm"
|
||||
else:
|
||||
lbl += ""
|
||||
|
||||
m = np.ones_like(wl, dtype=bool)
|
||||
if lim is None:
|
||||
lim = params.interpolation_range
|
||||
m &= wl >= (lim[0] if lim[0] < 1 else lim[0] * 1e-9)
|
||||
m &= wl <= (lim[1] if lim[1] < 1 else lim[1] * 1e-9)
|
||||
|
||||
info_str = (
|
||||
rf"$\lambda_{{\mathrm{{min}}}}={np.min(params.l[params.l>0])*1e9:.1f}$ nm"
|
||||
+ f"\nlower interpolation limit : {params.interpolation_range[0]*1e9:.1f} nm\n"
|
||||
+ f"max time delay : {params.t.max()*1e12:.1f} ps"
|
||||
)
|
||||
|
||||
left.annotate(
|
||||
info_str,
|
||||
xy=(1, 1),
|
||||
xytext=(-12, -12),
|
||||
xycoords="axes fraction",
|
||||
textcoords="offset points",
|
||||
va="top",
|
||||
ha="right",
|
||||
backgroundcolor=(1, 1, 1, 0.4),
|
||||
)
|
||||
|
||||
m = np.argwhere(m)[:, 0]
|
||||
m = np.array(sorted(m, key=lambda el: wl[el]))
|
||||
|
||||
if len(m) == 0:
|
||||
raise ValueError(f"nothing to plot in the range {lim!r}")
|
||||
|
||||
# plot D
|
||||
right.plot(1e9 * wl[m], D[m], label=" ", **style)
|
||||
right.set_ylabel(units.D_ps_nm_km.label)
|
||||
|
||||
# plot beta2
|
||||
left.plot(units.nm.inv(params.w[m]), units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
|
||||
left.set_ylabel(units.beta2_fs_cm.label)
|
||||
|
||||
left.set_xlabel(units.nm.label)
|
||||
right.set_xlabel("wavelength (nm)")
|
||||
|
||||
if params.alpha_arr is not None and loss is not None:
|
||||
loss.plot(1e9 * wl[m], params.alpha_arr[m], c="r", ls="--")
|
||||
loss.set_ylabel("loss (1/m)", color="r")
|
||||
loss.set_yscale("log")
|
||||
loss.tick_params(axis="y", labelcolor="r")
|
||||
|
||||
return lbl
|
||||
|
||||
|
||||
def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], params: Parameters):
|
||||
fig.suptitle(params.name)
|
||||
plt.tight_layout()
|
||||
|
||||
handles, _ = legend_axes.get_legend_handles_labels()
|
||||
|
||||
legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
|
||||
|
||||
out_path = env.output_path()
|
||||
|
||||
show = out_path is None
|
||||
if not show:
|
||||
file_name = out_path.stem + ".pdf"
|
||||
out_path = out_path.parent / file_name
|
||||
if (
|
||||
out_path.exists()
|
||||
and input(f"{out_path.name} already exsits, overwrite ? (y/[n])\n > ") != "y"
|
||||
):
|
||||
show = True
|
||||
else:
|
||||
fig.savefig(out_path, bbox_inches="tight")
|
||||
if show:
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
|
||||
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
|
||||
for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)):
|
||||
yield style, descriptor.branch.formatted_descriptor(), params
|
||||
|
||||
|
||||
def convert_params(params_file: os.PathLike):
|
||||
p = Path(params_file)
|
||||
if p.name == PARAM_FN:
|
||||
d = _open_config(params_file)
|
||||
save_toml(params_file, d)
|
||||
print(f"converted {p}")
|
||||
else:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
for pp in p.glob(PARAM_FN):
|
||||
convert_params(pp)
|
||||
for pp in p.glob("fiber*"):
|
||||
if pp.is_dir():
|
||||
convert_params(pp)
|
||||
|
||||
|
||||
def partial_plot(root: os.PathLike, lim: str = None):
|
||||
path = Path(root)
|
||||
fig, ax = plt.subplots(figsize=(12, 8))
|
||||
fig.suptitle(path.name)
|
||||
spec_list = sorted(
|
||||
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
params = Parameters(**load_toml(path / "params.toml"))
|
||||
params.z_targets = params.z_targets[: len(spec_list)]
|
||||
raw_values = np.array([load_spectrum(s) for s in spec_list])
|
||||
if lim is None:
|
||||
plot_range = units.PlotRange(
|
||||
0.5 * params.interpolation_range[0] * 1e9,
|
||||
1.1 * params.interpolation_range[1] * 1e9,
|
||||
"nm",
|
||||
)
|
||||
else:
|
||||
left_u, right_u, unit = lim.split(",")
|
||||
plot_range = units.PlotRange(float(left_u), float(right_u), unit)
|
||||
if plot_range.unit.type == "TIME":
|
||||
values = params.ifft(raw_values)
|
||||
log = False
|
||||
vmin = None
|
||||
else:
|
||||
values = raw_values
|
||||
log = "2D"
|
||||
vmin = -60
|
||||
|
||||
x, y, values = transform_2D_propagation(
|
||||
values,
|
||||
plot_range,
|
||||
params,
|
||||
log=log,
|
||||
)
|
||||
ax.imshow(
|
||||
values,
|
||||
origin="lower",
|
||||
aspect="auto",
|
||||
vmin=vmin,
|
||||
interpolation="nearest",
|
||||
extent=get_extent(x, y),
|
||||
)
|
||||
|
||||
return ax
|
||||
@@ -1,157 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import Paths
|
||||
from ..parameter import FileConfiguration
|
||||
|
||||
|
||||
def primes(n):
|
||||
prime_factors = []
|
||||
d = 2
|
||||
while d * d <= n:
|
||||
while (n % d) == 0:
|
||||
prime_factors.append(d)
|
||||
n //= d
|
||||
d += 1
|
||||
if n > 1:
|
||||
prime_factors.append(n)
|
||||
return prime_factors
|
||||
|
||||
|
||||
def balance(n, lim=(32, 32)):
|
||||
factors = primes(n)
|
||||
if len(factors) == 1:
|
||||
factors = primes(n + 1)
|
||||
a, b, x, y = 1, 1, 1, 1
|
||||
while len(factors) > 0 and x <= lim[0] and y <= lim[1]:
|
||||
a = x
|
||||
b = y
|
||||
if y >= x:
|
||||
x *= factors.pop(0)
|
||||
else:
|
||||
y *= factors.pop()
|
||||
return a, b
|
||||
|
||||
|
||||
def distribute(
|
||||
num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32)
|
||||
) -> Tuple[int, int]:
|
||||
if nodes is None and cpus_per_node is None:
|
||||
balanced = balance(num, lim)
|
||||
if num > max(lim):
|
||||
while np.product(balanced) < min(lim):
|
||||
num += 1
|
||||
balanced = balance(num, lim)
|
||||
nodes = min(balanced)
|
||||
cpus_per_node = max(balanced)
|
||||
|
||||
elif nodes is None:
|
||||
nodes = num // cpus_per_node
|
||||
while nodes > lim[0]:
|
||||
nodes //= 2
|
||||
elif cpus_per_node is None:
|
||||
cpus_per_node = num // nodes
|
||||
while cpus_per_node > lim[1]:
|
||||
cpus_per_node //= 2
|
||||
return nodes, cpus_per_node
|
||||
|
||||
|
||||
def format_time(t):
|
||||
try:
|
||||
t = float(t)
|
||||
return timedelta(minutes=t)
|
||||
except ValueError:
|
||||
return t
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description="submit a job to a slurm cluster")
|
||||
parser.add_argument("config", help="path to the toml configuration file")
|
||||
parser.add_argument(
|
||||
"-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node"
|
||||
)
|
||||
parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required")
|
||||
parser.add_argument(
|
||||
"--environment-setup",
|
||||
required=False,
|
||||
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
|
||||
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info",
|
||||
help="commands to run to setup the environement (default : activate the sc environment with conda)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--command", default="run", choices=["run", "resume", "merge"], help="command to run"
|
||||
)
|
||||
parser.add_argument("--dependency", default=None, help="sbatch dependency argument")
|
||||
return parser
|
||||
|
||||
|
||||
def copy_starting_files():
|
||||
for name in ["start_worker", "start_head"]:
|
||||
path = Paths.get(name)
|
||||
file_name = os.path.split(path)[1]
|
||||
shutil.copy(path, file_name)
|
||||
mode = os.stat(file_name)
|
||||
os.chmod(file_name, 0o100 | mode.st_mode)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
command_map = dict(run="Propagate", resume="Resuming", merge="Merging")
|
||||
|
||||
parser = create_parser()
|
||||
template = Paths.gets("submit_job_template")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dependency is None:
|
||||
args.dependency = ""
|
||||
else:
|
||||
args.dependency = f"#SBATCH --dependency={args.dependency}"
|
||||
|
||||
if not re.match(r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", args.time) and not re.match(
|
||||
r"^[0-9]+$", args.time
|
||||
):
|
||||
|
||||
raise ValueError(
|
||||
"time format must be an integer number of minute or must match the pattern hh:mm:ss"
|
||||
)
|
||||
|
||||
config = FileConfiguration(args.config)
|
||||
final_name = config.final_path
|
||||
sim_num = config.num_sim
|
||||
|
||||
if args.command == "merge":
|
||||
args.nodes = 1
|
||||
args.cpus_per_node = 1
|
||||
else:
|
||||
args.nodes, args.cpus_per_node = distribute(config.num_sim, args.nodes, args.cpus_per_node)
|
||||
|
||||
submit_path = Path(
|
||||
"submit " + final_name.replace("/", "") + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
|
||||
)
|
||||
tmp_path = Path("submit tmp.sh")
|
||||
|
||||
job_name = f"supercontinuum {final_name}"
|
||||
submit_sh = template.format(job_name=job_name, **vars(args))
|
||||
|
||||
tmp_path.write_text(submit_sh)
|
||||
subprocess.run(["sbatch", "--test-only", str(tmp_path)])
|
||||
submit = input(
|
||||
f"{command_map[args.command]} {sim_num} pulses from config {args.config} with {args.cpus_per_node} cpus"
|
||||
+ f" per node on {args.nodes} nodes for {format_time(args.time)} ? (y/[n])\n"
|
||||
)
|
||||
if submit.lower() in ["y", "yes"]:
|
||||
submit_path.write_text(submit_sh)
|
||||
copy_starting_files()
|
||||
subprocess.run(["sbatch", submit_path])
|
||||
tmp_path.unlink()
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Iterator, Sequence
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
|
||||
from scgenerator.math import abs2
|
||||
from scgenerator.operators import SpecOperator
|
||||
from scgenerator.utils import TimedMessage
|
||||
@@ -133,6 +134,7 @@ def solve43(
|
||||
targets = list(sorted(set(targets)))
|
||||
if targets[0] == 0:
|
||||
targets.pop(0)
|
||||
h = min(h, targets[0] / 2)
|
||||
|
||||
step_ind = 0
|
||||
msg = TimedMessage(2)
|
||||
|
||||
@@ -405,6 +405,7 @@ class SimulatedFiber:
|
||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
||||
else:
|
||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
|
||||
psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path})"
|
||||
|
||||
@@ -23,15 +23,13 @@ import pkg_resources as pkg
|
||||
import tomli
|
||||
import tomli_w
|
||||
|
||||
from scgenerator.const import (PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS,
|
||||
SPEC1_FN, Z_FN)
|
||||
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN
|
||||
from scgenerator.errors import DuplicateParameterError
|
||||
from scgenerator.logger import get_logger
|
||||
|
||||
T_ = TypeVar("T_")
|
||||
|
||||
|
||||
|
||||
class TimedMessage:
|
||||
def __init__(self, interval: float = 10.0):
|
||||
self.interval = datetime.timedelta(seconds=interval)
|
||||
@@ -179,7 +177,8 @@ def _open_config(path: os.PathLike):
|
||||
|
||||
return dico
|
||||
|
||||
def resolve_relative_paths(d:dict[str, Any], root:os.PathLike | None=None):
|
||||
|
||||
def resolve_relative_paths(d: dict[str, Any], root: os.PathLike | None = None):
|
||||
root = Path(root) if root is not None else Path.cwd()
|
||||
for k, v in d.items():
|
||||
if isinstance(v, MutableMapping):
|
||||
@@ -192,7 +191,6 @@ def resolve_relative_paths(d:dict[str, Any], root:os.PathLike | None=None):
|
||||
d[k] = str(root / v)
|
||||
|
||||
|
||||
|
||||
def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
|
||||
if (f_list := dico.pop("INCLUDE", None)) is not None:
|
||||
if isinstance(f_list, str):
|
||||
|
||||
@@ -1,336 +0,0 @@
|
||||
import itertools
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from math import prod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import validator
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from scgenerator.const import PARAM_SEPARATOR
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class VariationSpecsError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class Variationer:
|
||||
"""
|
||||
manages possible combinations of values given dicts of lists
|
||||
|
||||
Example
|
||||
-------
|
||||
```
|
||||
>> var = Variationer([dict(a=[1, 2]), [dict(b=["000", "111"], c=["a", "-1"])]])
|
||||
list(v.raw_descr for v in var.iterate())
|
||||
|
||||
[
|
||||
((("a", 1),), (("b", "000"), ("c", "a"))),
|
||||
((("a", 1),), (("b", "111"), ("c", "-1"))),
|
||||
((("a", 2),), (("b", "000"), ("c", "a"))),
|
||||
((("a", 2),), (("b", "111"), ("c", "-1"))),
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
all_indices: list[list[int]]
|
||||
all_dicts: list[list[dict[str, list]]]
|
||||
|
||||
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
|
||||
self.all_indices = []
|
||||
self.all_dicts = []
|
||||
if variables is not None:
|
||||
for i, el in enumerate(variables):
|
||||
self.append(el)
|
||||
|
||||
def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
|
||||
"""append a list of variable parameter sets
|
||||
each call to append creates a new group of parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
var_list : Union[list[MutableMapping], MutableMapping]
|
||||
each dict in the list is treated as an independent parameter
|
||||
this means that if for one dict, len > 1, the lists of possible values
|
||||
must be the same length
|
||||
|
||||
Example
|
||||
-------
|
||||
`>> append([dict(wavelength=[800e-9, 900e-9], power=[1e3, 2e3]), dict(length=[3e-2, 3.5e-2, 4e-2])])`
|
||||
|
||||
means that for every parameter variations, wavelength=800e-9 will always occur when power=1e3 and
|
||||
vice versa, while length is free to vary independently
|
||||
|
||||
Raises
|
||||
------
|
||||
VariationSpecsError
|
||||
raised when possible values lists in a same dict are not the same length
|
||||
"""
|
||||
if not isinstance(var_list, Sequence):
|
||||
var_list = [{k: v} for k, v in var_list.items()]
|
||||
else:
|
||||
var_list = list(var_list)
|
||||
num_vars = []
|
||||
for d in var_list:
|
||||
values = list(d.values())
|
||||
if len(values) > 0:
|
||||
len_to_test = len(values[0])
|
||||
if not all(len(v) == len_to_test for v in values[1:]):
|
||||
raise VariationSpecsError(
|
||||
"variable items should all have the same number of parameters"
|
||||
)
|
||||
num_vars.append(len_to_test)
|
||||
if len(num_vars) == 0:
|
||||
num_vars = [1]
|
||||
self.all_indices.append(num_vars)
|
||||
self.all_dicts.append(var_list)
|
||||
|
||||
def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
|
||||
index = self.__index(index)
|
||||
flattened_indices = sum(self.all_indices[: index + 1], [])
|
||||
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
|
||||
ranges = [range(i) for i in flattened_indices]
|
||||
for r in itertools.product(*ranges):
|
||||
out: list[list[tuple[str, Any]]] = []
|
||||
indicies: list[list[int]] = []
|
||||
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
|
||||
out.append([])
|
||||
indicies.append([])
|
||||
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
|
||||
for k, v in var_d.items():
|
||||
out[-1].append((k, v[value_index]))
|
||||
indicies[-1].append(value_index)
|
||||
yield VariationDescriptor(raw_descr=out, index=indicies)
|
||||
|
||||
def __index(self, index: int) -> int:
|
||||
if index < 0:
|
||||
index = len(self.all_indices) + index
|
||||
return index
|
||||
|
||||
def var_num(self, index: int = -1) -> int:
|
||||
index = self.__index(index)
|
||||
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
||||
|
||||
|
||||
class VariationDescriptor(BaseModel):
|
||||
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
||||
index: tuple[tuple[int, ...], ...]
|
||||
separator: str = "fiber"
|
||||
_format_registry: dict[str, Callable[..., str]] = {}
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
@classmethod
|
||||
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||
"""register a function that formats a particular parameter
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p_name : str
|
||||
name of the parameter
|
||||
func : Callable[..., str]
|
||||
function that takes as single argument the value of the parameter and returns a string
|
||||
"""
|
||||
cls._format_registry[p_name] = func
|
||||
|
||||
@classmethod
|
||||
def format_value(cls, name: str, value) -> str:
|
||||
if value is True or value is False:
|
||||
return str(value)
|
||||
elif isinstance(value, (float, int)):
|
||||
try:
|
||||
return cls._format_registry[name](value)
|
||||
except KeyError:
|
||||
return format(value, ".9g")
|
||||
elif isinstance(value, (list, tuple, np.ndarray)):
|
||||
return "-".join([str(v) for v in value])
|
||||
elif isinstance(value, str):
|
||||
p = Path(value)
|
||||
if p.exists():
|
||||
return p.stem
|
||||
elif callable(value):
|
||||
return getattr(value, "__name__", repr(value))
|
||||
return str(value)
|
||||
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
|
||||
def formatted_descriptor(self, add_identifier=False) -> str:
|
||||
"""formats a variable list into a str such that each simulation has a unique
|
||||
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
|
||||
branch identifier can added at the beginning.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
add_identifier : bool
|
||||
add unique simulation and parameter-set identifiers
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
simulation descriptor
|
||||
"""
|
||||
str_list = []
|
||||
|
||||
for p_name, p_value in self.flat:
|
||||
ps, vs = self._format_single_pair(p_name, p_value)
|
||||
str_list.append(ps + PARAM_SEPARATOR + vs)
|
||||
tmp_name = PARAM_SEPARATOR.join(str_list)
|
||||
if not add_identifier:
|
||||
return tmp_name
|
||||
return (
|
||||
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
||||
)
|
||||
|
||||
def _format_single_pair(self, p_name: str, p_value: Any) -> tuple[str, str]:
|
||||
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
|
||||
vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
|
||||
return ps, vs
|
||||
|
||||
def __getitem__(self, key) -> "VariationDescriptor":
|
||||
return VariationDescriptor(
|
||||
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.formatted_descriptor(add_identifier=False)
|
||||
|
||||
def __lt__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr < other.raw_descr
|
||||
|
||||
def __le__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr <= other.raw_descr
|
||||
|
||||
def __gt__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr > other.raw_descr
|
||||
|
||||
def __ge__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr >= other.raw_descr
|
||||
|
||||
def __eq__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr == other.raw_descr
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.raw_descr)
|
||||
|
||||
def __contains__(self, other: "VariationDescriptor") -> bool:
|
||||
return all(el in self.raw_descr for el in other.raw_descr)
|
||||
|
||||
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
||||
"""updates a dictionary with the value of the descriptor
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cfg : dict[str, Any]
|
||||
dict to be updated
|
||||
index : int, optional
|
||||
index of the fiber from which to apply the parameters, by default -1
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]
|
||||
same as cfg but with key from the descriptor added/updated.
|
||||
"""
|
||||
out_cfg = cfg.copy()
|
||||
out_cfg.pop("variable", None)
|
||||
return out_cfg | {k: v for k, v in self.raw_descr[index]}
|
||||
|
||||
def iter_parents(self) -> Iterator["VariationDescriptor"]:
|
||||
if (p := self.parent) is not None:
|
||||
yield from p.iter_parents()
|
||||
yield self
|
||||
|
||||
@property
|
||||
def short(self) -> str:
|
||||
"""shortened description of the simulation"""
|
||||
return " ".join(
|
||||
self._format_single_pair(p, v)[1] for p, v in self.flat if p not in {"fiber", "num"}
|
||||
)
|
||||
|
||||
@property
|
||||
def flat(self) -> list[tuple[str, Any]]:
|
||||
out = []
|
||||
for n, variable_list in enumerate(self.raw_descr):
|
||||
out += [
|
||||
(self.separator, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)),
|
||||
*variable_list,
|
||||
]
|
||||
return out
|
||||
|
||||
@property
|
||||
def branch(self) -> "BranchDescriptor":
|
||||
descr: list[list[tuple[str, Any]]] = []
|
||||
ind: list[list[int]] = []
|
||||
for i, l in enumerate(self.raw_descr):
|
||||
descr.append([])
|
||||
ind.append([])
|
||||
for j, (k, v) in enumerate(l):
|
||||
if k != "num":
|
||||
descr[-1].append((k, v))
|
||||
ind[-1].append(self.index[i][j])
|
||||
return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
unique_id = hash(str(self.flat))
|
||||
self.__ids.setdefault(unique_id, len(self.__ids))
|
||||
return "u_" + str(self.__ids[unique_id])
|
||||
|
||||
@property
|
||||
def parent(self) -> Optional["VariationDescriptor"]:
|
||||
if len(self.raw_descr) < 2:
|
||||
return None
|
||||
return VariationDescriptor(
|
||||
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
|
||||
)
|
||||
|
||||
@property
|
||||
def num_fibers(self) -> int:
|
||||
return len(self.raw_descr)
|
||||
|
||||
|
||||
class BranchDescriptor(VariationDescriptor):
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
branch_id = hash(str(self.flat))
|
||||
self.__ids.setdefault(branch_id, len(self.__ids))
|
||||
return "b_" + str(self.__ids[branch_id])
|
||||
|
||||
@validator("raw_descr")
|
||||
def validate_raw_descr(cls, v):
|
||||
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
|
||||
|
||||
|
||||
class DescriptorDict(Generic[T]):
|
||||
def __init__(self, dico: dict[VariationDescriptor, T] = None):
|
||||
self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
|
||||
if dico is not None:
|
||||
for k, v in dico.items():
|
||||
self[k] = v
|
||||
|
||||
def __setitem__(self, key: VariationDescriptor, value: T):
|
||||
if not isinstance(key, VariationDescriptor):
|
||||
raise TypeError("key must be a VariationDescriptor instance")
|
||||
self.dico[key.raw_descr] = (key, value)
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
|
||||
) -> T:
|
||||
if isinstance(key, VariationDescriptor):
|
||||
return self.dico[key.raw_descr][1]
|
||||
else:
|
||||
return self.dico[key][1]
|
||||
|
||||
def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
|
||||
for k, v in self.dico.items():
|
||||
yield k, v[1]
|
||||
|
||||
def keys(self) -> list[VariationDescriptor]:
|
||||
return [v[0] for v in self.dico.values()]
|
||||
|
||||
def values(self) -> list[T]:
|
||||
return [v[1] for v in self.dico.values()]
|
||||
Reference in New Issue
Block a user