removed a bunch of stuff

Removed:
- Variationer
- FileConfiguration
- Scripts (slurm, ...)
- CLI
Benoît Sierro
2023-07-24 08:23:04 +02:00
parent d6ba1bec7f
commit 57c593cf4f
16 changed files with 87 additions and 1355 deletions

.gitignore
View File

@@ -1,6 +1,7 @@
 .DS_store
 .idea
 **/*.npy
+.conda-env
 pyrightconfig.json
@@ -17,15 +18,8 @@ __pycache__
 tmp*
 paths.json
 scgenerator_log*
-scgenerator.log
 .scgenerator_tmp
 sc-*.log
 .vscode
-# latex
-*.aux
-*.fdb_latexmk
-*.fls
-*.log
-*.synctex.gz

View File

@@ -27,6 +27,7 @@ dependencies = [
 [tool.ruff]
 line-length = 100
+ignore = ["E741"]
 [tool.ruff.pydocstyle]
 convention = "numpy"
@@ -34,3 +35,6 @@ convention = "numpy"
 [tool.black]
 line-length = 100
+[tool.isort]
+profile = "black"
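E741 is the "ambiguous variable name" check; it is most likely ignored here because the physics modules use the single letter l for the wavelength grid, as in this signature taken from the fiber.py hunk further down (body elided):

    def lambda_for_envelope_dispersion(
        l: np.ndarray, wavelength_window: tuple[float, float]
    ) -> tuple[np.ndarray, np.ndarray]:
        ...  # 'l' alone would be flagged as E741 without the ignore entry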

View File

@@ -1,10 +1,9 @@
-# # flake8: noqa # isort: skip_file
+# ruff: noqa
 from scgenerator import math, operators, plotting
 from scgenerator.helpers import *
 from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
-from scgenerator.parameter import FileConfiguration, Parameters
+from scgenerator.parameter import Parameters
 from scgenerator.physics import fiber, materials, plasma, pulse, units
 from scgenerator.physics.units import PlotRange
 from scgenerator.solver import integrate, solve43
-from scgenerator.utils import (Paths, _open_config, open_single_config,
-                               simulations_list)

View File

@@ -312,7 +312,7 @@ default_rules: list[Rule] = [
     Rule("w_num", len, ["w"]),
     Rule("dw", lambda w: w[1] - w[0]),
     Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
-    Rule("interpolation_range", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
+    Rule("wavelength_window", lambda dt: (max(100e-9, 2 * units.c * dt), 8e-6)),
     # Pulse
     Rule("field_0", pulse.finalize_pulse),
     Rule(["input_time", "input_field"], pulse.load_custom_field),
@@ -393,7 +393,7 @@ default_rules: list[Rule] = [
     Rule(
         "V_eff_arr",
         fiber.V_eff_step_index,
-        ["l", "core_radius", "numerical_aperture", "interpolation_range"],
+        ["l", "core_radius", "numerical_aperture", "wavelength_window"],
     ),
     Rule("n2", materials.gas_n2),
     Rule("n2", lambda: 2.2e-20, priorities=-1),
@@ -434,7 +434,7 @@ envelope_rules = default_rules + [
     Rule("beta2_arr", fiber.dispersion_from_coefficients),
     Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
     Rule(
-        ["wl_for_disp", "beta2_arr", "interpolation_range"],
+        ["wl_for_disp", "beta2_arr", "wavelength_window"],
         fiber.load_custom_dispersion,
         priorities=[2, 2, 2],
     ),
@@ -442,7 +442,7 @@ envelope_rules = default_rules + [
     Rule("gamma_op", operators.variable_gamma, priorities=2),
     Rule("gamma_op", operators.constant_quantity, ["gamma_arr"], priorities=1),
     Rule("gamma_op", lambda w_num, gamma: operators.constant_quantity(np.ones(w_num) * gamma)),
-    Rule("gamma_op", operators.no_op_freq, priorities=-1),
+    Rule("gamma_op", lambda: operators.constant_quantity(0.0), priorities=-1),
     Rule("ss_op", lambda w_c, w0: operators.constant_quantity(w_c / w0)),
     Rule("ss_op", lambda: operators.constant_quantity(0), priorities=-1),
     Rule("spm_op", operators.envelope_spm),

View File

@@ -48,7 +48,6 @@ def configure_logger(logger: logging.Logger):
         updated logger
     """
     if not hasattr(logger, "already_configured"):
         print_lvl = lvl_map.get(log_print_level(), logging.NOTSET)
         file_lvl = lvl_map.get(log_file_level(), logging.NOTSET)

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
 from typing import Callable
 import numpy as np
+from scgenerator import math
 from scgenerator.logger import get_logger
 from scgenerator.physics import fiber, materials, plasma, pulse, units
@@ -267,7 +268,6 @@ def constant_wave_vector(
 def envelope_raman(hr_w: np.ndarra, raman_fraction: float) -> FieldOperator:
     def operate(field: np.ndarray, z: float) -> np.ndarray:
         return raman_fraction * np.fft.ifft(hr_w * np.fft.fft(math.abs2(field)))
@@ -336,7 +336,6 @@ def ionization(
         N0 = number_density(z)
         plasma_info = plasma_obj(field, N0)
         # state.stats["ionization_fraction"] = plasma_info.electron_density[-1] / N0
         # state.stats["electron_density"] = plasma_info.electron_density[-1]
         return plasma_info.polarization
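The envelope_raman operator above is the convolution theorem at work: multiplying by the Raman response in the frequency domain is a circular convolution with |A|^2 in the time domain. A self-contained sketch (the random response below is only a stand-in, not the library's Raman model):

    import numpy as np

    rng = np.random.default_rng(0)
    field = rng.normal(size=64) + 1j * rng.normal(size=64)  # stand-in envelope A(t)
    hr_w = np.fft.fft(rng.normal(size=64))                   # stand-in response, frequency domain
    raman_fraction = 0.18

    polarization = raman_fraction * np.fft.ifft(hr_w * np.fft.fft(np.abs(field) ** 2))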

View File

@@ -1,32 +1,50 @@
 from __future__ import annotations
 import datetime as datetime_module
-import enum
 import os
-import time
 from copy import copy
 from dataclasses import dataclass, field, fields
 from functools import lru_cache, wraps
 from math import isnan
 from pathlib import Path
-from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeVar
+from typing import (Any, Callable, ClassVar, Iterable, Iterator, Set, Type,
+                    TypeVar)
 import numpy as np
-from scgenerator import env, utils
+from scgenerator import utils
-from scgenerator.const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
+from scgenerator.const import MANDATORY_PARAMETERS, __version__
 from scgenerator.errors import EvaluatorError
 from scgenerator.evaluator import Evaluator
-from scgenerator.logger import get_logger
 from scgenerator.operators import Qualifier, SpecOperator
-from scgenerator.utils import fiber_folder, update_path_name
+from scgenerator.utils import update_path_name
-from scgenerator.variationer import VariationDescriptor, Variationer
 T = TypeVar("T")
+DISPLAY_INFO = {}
+def format_value(name: str, value) -> str:
+    if value is True or value is False:
+        return str(value)
+    elif isinstance(value, (float, int)):
+        try:
+            return DISPLAY_INFO[name](value)
+        except KeyError:
+            return format(value, ".9g")
+    elif isinstance(value, np.ndarray):
+        return np.array2string(value)
+    elif isinstance(value, (list, tuple)):
+        return "-".join([str(v) for v in value])
+    elif isinstance(value, str):
+        p = Path(value)
+        if p.exists():
+            return p.stem
+    elif callable(value):
+        return getattr(value, "__name__", repr(value))
+    return str(value)
 # Validator
 @lru_cache
 def type_checker(*types):
     def _type_checker_wrapper(validator, n=None):
@@ -224,7 +242,7 @@ class Parameter:
                 pass
         if self.default is not None:
             Evaluator.register_default_param(self.name, self.default)
-        VariationDescriptor.register_formatter(self.name, self.display)
+        DISPLAY_INFO[self.name] = self.display
     def __get__(self, instance: Parameters, owner):
         if instance is None:
@@ -382,7 +400,7 @@ class Parameters:
     dt: float = Parameter(in_range_excl(0, 10e-15))
     tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-3), default=1e-11)
     step_size: float = Parameter(non_negative(float, int), default=0)
-    interpolation_range: tuple[float, float] = Parameter(
+    wavelength_window: tuple[float, float] = Parameter(
         validator_and(float_pair, validator_list(in_range_incl(100e-9, 10000e-9)))
     )
     interpolation_degree: int = Parameter(validator_and(type_checker(int), in_range_incl(2, 18)))
@@ -469,11 +487,7 @@ class Parameters:
         exclude = exclude or []
         if isinstance(exclude, str):
             exclude = [exclude]
-        p_pairs = [
-            (k, VariationDescriptor.format_value(k, getattr(self, k)))
-            for k in params
-            if k not in exclude
-        ]
+        p_pairs = [(k, format_value(k, getattr(self, k))) for k in params if k not in exclude]
         max_left = max(len(el[0]) for el in p_pairs)
         max_right = max(len(el[1]) for el in p_pairs)
         return "\n".join("{:>{l}} = {:{r}}".format(*p, l=max_left, r=max_right) for p in p_pairs)
@@ -544,262 +558,6 @@ class Parameters:
         return None
class AbstractConfiguration:
fiber_paths: list[Path]
num_sim: int
total_num_steps: int
worker_num: int
final_path: Path
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
raise NotImplementedError()
def save_parameters(self):
raise NotImplementedError()
class FileConfiguration(AbstractConfiguration):
"""
Primary role is to load the final config file of the simulation and deduce every
simulatin that has to happen. Iterating through the Configuration obj yields a list of
parameter names and values that change throughout the simulation as well as parameter
obj with the output path of the simulation saved in its output_path attribute.
"""
fiber_configs: list[utils.SubConfig]
master_config_dict: dict[str, Any]
num_fibers: int
repeat: int
z_num: int
overwrite: bool
all_configs: dict[tuple[tuple[int, ...], ...], "FileConfiguration.__SimConfig"]
@dataclass(frozen=True)
class __SimConfig:
descriptor: VariationDescriptor
config: dict[str, Any]
output_path: Path
@property
def sim_num(self) -> int:
return len(self.descriptor.index)
class State(enum.Enum):
COMPLETE = enum.auto()
PARTIAL = enum.auto()
ABSENT = enum.auto()
class Action(enum.Enum):
RUN = enum.auto()
WAIT = enum.auto()
SKIP = enum.auto()
def __init__(
self,
config_path: os.PathLike,
overwrite: bool = True,
wait: bool = False,
skip_callback: Callable[[int], None] = None,
final_output_path: os.PathLike = None,
):
self.logger = get_logger(__name__)
self.wait = wait
self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
if final_output_path is not None:
self.final_path = final_output_path
self.final_path = utils.ensure_folder(
Path(self.final_path),
mkdir=False,
prevent_overwrite=not self.overwrite,
)
self.master_config_dict = self.fiber_configs[0].fixed | {
k: v[0] for vary_dict in self.fiber_configs[0].variable for k, v in vary_dict.items()
}
self.name = self.final_path.name
self.z_num = 0
self.total_num_steps = 0
self.fiber_paths = []
self.all_configs = {}
self.skip_callback = skip_callback
self.worker_num = self.master_config_dict.get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_config_dict.get("repeat", 1)
self.variationer = Variationer()
fiber_names = set()
self.num_fibers = 0
for i, config in enumerate(self.fiber_configs):
config.fixed.setdefault("name", Parameters.name.default)
self.z_num += config.fixed["z_num"]
fiber_names.add(config.fixed["name"])
self.variationer.append(config.variable)
self.fiber_paths.append(
utils.ensure_folder(
self.final_path / fiber_folder(i, self.name, Path(config.fixed["name"]).name),
mkdir=False,
prevent_overwrite=not self.overwrite,
)
)
self.__validate_variable(config.variable)
self.num_fibers += 1
Evaluator.evaluate_default(
self.master_config_dict
| config.fixed
| {k: v[0] for vary_dict in config.variable for k, v in vary_dict.items()},
True,
)
self.num_sim = self.variationer.var_num()
self.total_num_steps = sum(
config.fixed["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.fiber_configs)
)
def __validate_variable(self, vary_dict_list: list[dict[str, list]]):
for vary_dict in vary_dict_list:
for k, v in vary_dict.items():
p: Parameter = getattr(Parameters, k)
validator_list(p.validator)("variable " + k, v)
if k not in VALID_VARIABLE:
raise TypeError(f"{k!r} is not a valid variable parameter")
if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty")
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
for i in range(self.num_fibers):
yield from self.iterate_single_fiber(i)
def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
"""iterates through the parameters of only one fiber. It takes care of recovering partially
completed simulations, skipping complete ones and waiting for the previous fiber to finish
Parameters
----------
index : int
which fiber to iterate over
Yields
-------
__SimConfig
configuration obj
"""
if index < 0:
index = self.num_fibers + index
sim_dict: dict[Path, FileConfiguration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index].fixed)
if index > 0:
cfg["prev_data_dir"] = str(
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
)
p = utils.ensure_folder(
self.fiber_paths[index] / descriptor.formatted_descriptor(True),
not self.overwrite,
False,
)
cfg["output_path"] = p
sim_config = self.__SimConfig(descriptor, cfg, p)
sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
while len(sim_dict) > 0:
for data_dir, sim_config in sim_dict.items():
task, config_dict = self.__decide(sim_config)
if task == self.Action.RUN:
sim_dict.pop(data_dir)
param_dict = sim_config.config
yield sim_config.descriptor, Parameters(**param_dict)
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"])
break
elif task == self.Action.SKIP:
sim_dict.pop(data_dir)
self.logger.debug(f"skipping {data_dir} as it is already complete")
if self.skip_callback is not None:
self.skip_callback(config_dict["z_num"])
break
else:
self.logger.debug("sleeping while waiting for other simulations to complete")
time.sleep(1)
def __decide(
self, sim_config: "FileConfiguration.__SimConfig"
) -> tuple["FileConfiguration.Action", dict[str, Any]]:
"""decide what to to with a particular simulation
Parameters
----------
sim_config : __SimConfig
Returns
-------
str : Configuration.Action
what to do
config_dict : dict[str, Any]
config dictionary. The only key possibly modified is 'prev_data_dir', which
gets set if the simulation is partially completed
"""
if not self.wait:
return self.Action.RUN, sim_config.config
out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
if out_status == self.State.COMPLETE:
return self.Action.SKIP, sim_config.config
elif out_status == self.State.PARTIAL:
sim_config.config["recovery_data_dir"] = str(sim_config.output_path)
sim_config.config["recovery_last_stored"] = num
return self.Action.RUN, sim_config.config
if "prev_data_dir" in sim_config.config:
prev_data_path = Path(sim_config.config["prev_data_dir"])
prev_status, _ = self.sim_status(prev_data_path)
if prev_status in {self.State.PARTIAL, self.State.ABSENT}:
return self.Action.WAIT, sim_config.config
return self.Action.RUN, sim_config.config
def sim_status(
self, data_dir: Path, config_dict: dict[str, Any] = None
) -> tuple["FileConfiguration.State", int]:
"""returns the status of a simulation
Parameters
----------
data_dir : Path
directory where simulation data is to be saved
config_dict : dict[str, Any], optional
configuration of the simulation. If None, will attempt to load
the params.toml file if present, by default None
Returns
-------
Configuration.State
status
"""
num = utils.find_last_spectrum_num(data_dir)
if config_dict is None:
try:
config_dict = utils.load_toml(data_dir / PARAM_FN)
except FileNotFoundError:
self.logger.warning(f"did not find {PARAM_FN!r} in {data_dir}")
return self.State.ABSENT, 0
if num == config_dict["z_num"] - 1:
return self.State.COMPLETE, num
elif config_dict["z_num"] - 1 > num > 0:
return self.State.PARTIAL, num
elif num == 0:
return self.State.ABSENT, 0
else:
raise ValueError(f"Too many spectra in {data_dir}")
def save_parameters(self):
os.makedirs(self.final_path, exist_ok=True)
cfgs = [cfg.fixed | dict(variable=cfg.variable) for cfg in self.fiber_configs]
utils.save_toml(self.final_path / "initial_config.toml", dict(name=self.name, Fiber=cfgs))
@property
def first(self) -> Parameters:
for _, param in self:
return param
 if __name__ == "__main__":
     numero = type_checker(int)
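A short sketch of how the new module-level format_value helper behaves (the parameter names are placeholders; for floats and ints the output also depends on whether a display formatter was registered in DISPLAY_INFO for that name):

    from scgenerator.parameter import format_value

    format_value("some_flag", True)                 # 'True'
    format_value("some_number", 2.2e-20)            # '2.2e-20' via format(value, ".9g") when no formatter is registered
    format_value("beta2_coefficients", [1.0, 2.0])  # '1.0-2.0'
    format_value("not_a_path", "hello")             # 'hello' (only existing paths are reduced to their stem)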

View File

@@ -1,189 +0,0 @@
import multiprocessing
import os
import random
import threading
import typing
from collections import abc
from io import StringIO
from pathlib import Path
from typing import Iterable, Union
from tqdm import tqdm
from scgenerator.env import pbar_policy
T_ = typing.TypeVar("T_")
class PBars:
def __init__(
self,
task: Union[int, Iterable[T_]],
desc: str,
num_sub_bars: int = 0,
head_kwargs=None,
worker_kwargs=None,
) -> "PBars":
"""creates a PBars obj
Parameters
----------
task : int | Iterable
if int : total length of the main task
if Iterable : behaves like tqdm
desc : str
description of the main task
num_sub_bars : int
number of sub-tasks
"""
self.id = random.randint(100000, 999999)
try:
self.width = os.get_terminal_size().columns
except OSError:
self.width = 120
if isinstance(task, abc.Iterable):
self.iterator: Iterable[T_] = iter(task)
self.num_tot: int = len(task)
else:
self.num_tot: int = task
self.iterator = None
self.policy = pbar_policy()
if head_kwargs is None:
head_kwargs = dict()
if worker_kwargs is None:
worker_kwargs = dict(
total=1,
desc="Worker {worker_id}",
bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
)
if "print" not in pbar_policy():
head_kwargs["file"] = worker_kwargs["file"] = StringIO()
self.width = 80
head_kwargs["desc"] = desc
self.pbars = [tqdm(total=self.num_tot, ncols=self.width, ascii=False, **head_kwargs)]
for i in range(1, num_sub_bars + 1):
kwargs = {k: v for k, v in worker_kwargs.items()}
if "desc" in kwargs:
kwargs["desc"] = kwargs["desc"].format(worker_id=i)
self.append(tqdm(position=i, ncols=self.width, ascii=False, **kwargs))
self.print_path = Path(
f"progress {self.pbars[0].desc.replace('/', '')} {self.id}"
).resolve()
self.close_ev = threading.Event()
if "file" in self.policy:
self.thread = threading.Thread(target=self.print_worker, daemon=True)
self.thread.start()
def print(self):
if "file" not in self.policy:
return
s = []
for pbar in self.pbars:
s.append(str(pbar))
self.print_path.write_text("\n".join(s))
def print_worker(self):
while True:
if self.close_ev.wait(2.0):
return
self.print()
def __iter__(self):
with self as pb:
for thing in self.iterator:
yield thing
pb.update()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __getitem__(self, key):
return self.pbars[key]
def update(self, i=None, value=1):
if i is None:
for pbar in self.pbars[1:]:
pbar.update(value)
elif i > 0:
self.pbars[i].update(value)
self.pbars[0].update()
def append(self, pbar: tqdm):
self.pbars.append(pbar)
def reset(self, i):
self.pbars[i].update(-self.pbars[i].n)
self.print()
def close(self):
self.print()
self.close_ev.set()
if "file" in self.policy:
self.thread.join()
for pbar in self.pbars:
pbar.close()
class ProgressBarActor:
def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
self.counters = [0 for _ in range(num_workers + 1)]
self.p_bars = PBars(
num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
)
def update(self, worker_id: int, rel_pos: float = None) -> None:
"""update a counter
Parameters
----------
worker_id : int
id of the worker. 0 is the overall progress
rel_pos : float, optional
if None, increase the counter by one, if set, will set
the counter to the specified value (instead of incrementing it), by default None
"""
if rel_pos is None:
self.counters[worker_id] += 1
else:
self.counters[worker_id] = rel_pos
def update_pbars(self):
for counter, pbar in zip(self.counters, self.p_bars.pbars):
pbar.update(counter - pbar.n)
def close(self):
self.p_bars.close()
def progress_worker(
name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
):
"""keeps track of progress on a separate thread
Parameters
----------
num_steps : int
total number of steps, used for the main progress bar (position 0)
progress_queue : multiprocessing.Queue
values are either
Literal[0] : stop the worker and close the progress bars
tuple[int, float] : worker id and relative progress between 0 and 1
"""
with PBars(
num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
) as pbars:
while True:
raw = progress_queue.get()
if raw == 0:
return
i, rel_pos = raw
if i > 0:
pbars[i].update(rel_pos - pbars[i].n)
pbars[0].update()
elif i == 0:
pbars[0].update(rel_pos)

View File

@@ -18,7 +18,7 @@ T = TypeVar("T")
 def lambda_for_envelope_dispersion(
-    l: np.ndarray, interpolation_range: tuple[float, float]
+    l: np.ndarray, wavelength_window: tuple[float, float]
 ) -> tuple[np.ndarray, np.ndarray]:
     """Returns a wl vector for dispersion calculation in envelope mode
@@ -30,10 +30,10 @@ def lambda_for_envelope_dispersion(
     np.ndarray
         indices of the original l where the values are valid (i.e. without the two extra on each side)
     """
-    su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0]
-    if l[su].min() > 1.01 * interpolation_range[0]:
+    su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
+    if l[su].min() > 1.01 * wavelength_window[0]:
         raise ValueError(
-            f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. "
+            f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. "
             f"Minimum of grid is {1e9*l[su].min():.1f}nm. Try a finer grid"
         )
@@ -48,7 +48,7 @@ def lambda_for_envelope_dispersion(
 def lambda_for_full_field_dispersion(
-    l: np.ndarray, interpolation_range: tuple[float, float]
+    l: np.ndarray, wavelength_window: tuple[float, float]
 ) -> tuple[np.ndarray, np.ndarray]:
     """Returns a wl vector for dispersion calculation in full field mode
@@ -60,10 +60,10 @@ def lambda_for_full_field_dispersion(
     np.ndarray
         indices of the original l where the values are valid (i.e. without the two extra on each side)
     """
-    su = np.where((l >= interpolation_range[0]) & (l <= interpolation_range[1]))[0]
-    if l[su].min() > 1.01 * interpolation_range[0]:
+    su = np.where((l >= wavelength_window[0]) & (l <= wavelength_window[1]))[0]
+    if l[su].min() > 1.01 * wavelength_window[0]:
         raise ValueError(
-            f"lower range of {1e9*interpolation_range[0]:.1f}nm is not reached by the grid. "
+            f"lower range of {1e9*wavelength_window[0]:.1f}nm is not reached by the grid. "
             "try a finer grid"
         )
     fu = np.concatenate((su[:2] - 2, su, su[-2:] + 2))
@@ -385,7 +385,7 @@ def V_eff_step_index(
     l: T,
     core_radius: float,
     numerical_aperture: float,
-    interpolation_range: tuple[float, float] = None,
+    wavelength_window: tuple[float, float] = None,
 ) -> T:
     """computes the V parameter of a step-index fiber
@@ -397,7 +397,7 @@ def V_eff_step_index(
         radius of the core
     numerical_aperture : float
         as a decimal number
-    interpolation_range : tuple[float, float], optional
+    wavelength_window : tuple[float, float], optional
         when provided, only computes V over this range, wavelengths outside this range will
         yield V=inf, by default None
@@ -407,8 +407,8 @@ def V_eff_step_index(
         V parameter
     """
     pi2cn = 2 * pi * core_radius * numerical_aperture
-    if interpolation_range is not None and isinstance(l, np.ndarray):
-        low, high = interpolation_range
+    if wavelength_window is not None and isinstance(l, np.ndarray):
+        low, high = wavelength_window
         l = np.where((l >= low) & (l <= high), l, np.inf)
     return pi2cn / l
@@ -805,7 +805,6 @@ def delayed_raman_w(t: np.ndarray, raman_type: str) -> tuple[np.ndarray, float]:
     return hr_w, raman_fraction(raman_type)
 def fast_poly_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
     """
     dispersive operator
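A small usage sketch of the renamed argument (numbers are illustrative; per the code above, wavelengths outside the window are replaced by inf before the division):

    import numpy as np
    from scgenerator.physics import fiber

    l = np.array([400e-9, 800e-9, 3e-6])
    V = fiber.V_eff_step_index(
        l,
        core_radius=2.5e-6,
        numerical_aperture=0.2,
        wavelength_window=(500e-9, 2000e-9),
    )
    # at 800 nm: V = 2*pi*2.5e-6*0.2 / 800e-9 ~ 3.93; the 400 nm and 3 um samples are masked out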

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field
 from functools import cache
-from typing import TypeVar
+from typing import Any, TypeVar
 import numpy as np
@@ -106,14 +106,16 @@ class Gas:
     chi3_0: float
     ionization_energy: float | None
+    _raw_sellmeier: dict[str, Any]
     def __init__(self, gas_name: str):
         self.name = gas_name
-        self.mat_dico = utils.load_material_dico(gas_name)
-        self.atomic_mass = self.mat_dico["atomic_mass"]
-        self.atomic_number = self.mat_dico["atomic_number"]
-        self.ionization_energy = self.mat_dico.get("ionization_energy")
-        s = self.mat_dico.get("sellmeier", {})
+        self._raw_sellmeier = utils.load_material_dico(gas_name)
+        self.atomic_mass = self._raw_sellmeier["atomic_mass"]
+        self.atomic_number = self._raw_sellmeier["atomic_number"]
+        self.ionization_energy = self._raw_sellmeier.get("ionization_energy")
+        s = self._raw_sellmeier.get("sellmeier", {})
         self.sellmeier = Sellmeier(
             **{
                 newk: s.get(k, None)
@@ -124,7 +126,7 @@ class Gas:
                 if k in s
             }
         )
-        kerr = self.mat_dico["kerr"]
+        kerr = self._raw_sellmeier["kerr"]
         n2_0 = kerr["n2"]
         self._kerr_wl = kerr.get("wavelength", 800e-9)
         self.chi3_0 = (
@@ -212,18 +214,23 @@ class Gas:
         Raises
         ----------
-        ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution
+        ValueError : Since the Van der Waals equation is a cubic one, there could be more than one
+            real, positive solution
         """
         logger = get_logger(__name__)
         if pressure == 0:
             return 0
-        a = self.mat_dico.get("a", 0)
-        b = self.mat_dico.get("b", 0)
-        pressure = self.mat_dico["sellmeier"].get("P0", 101325) if pressure is None else pressure
+        a = self._raw_sellmeier.get("a", 0)
+        b = self._raw_sellmeier.get("b", 0)
+        pressure = (
+            self._raw_sellmeier["sellmeier"].get("P0", 101325) if pressure is None else pressure
+        )
         temperature = (
-            self.mat_dico["sellmeier"].get("T0", 273.15) if temperature is None else temperature
+            self._raw_sellmeier["sellmeier"].get("T0", 273.15)
+            if temperature is None
+            else temperature
         )
         ap = a / NA**2
         bp = b / NA
@@ -302,10 +309,10 @@ class Gas:
         return Z**3 / (16 * ns**4) * 5.14220670712125e11
     def get(self, key, default=None):
-        return self.mat_dico.get(key, default)
+        return self._raw_sellmeier.get(key, default)
     def __getitem__(self, key):
-        return self.mat_dico[key]
+        return self._raw_sellmeier[key]
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.name!r})"
@@ -317,7 +324,7 @@ def n_gas_2(wl_for_disp: np.ndarray, gas_name: str, pressure: float, temperature
     return Sellmeier.load(gas_name).n_gas_2(wl_for_disp, temperature, pressure)
-def pressure_from_gradient(ratio, p0, p1):
+def pressure_from_gradient(ratio: float, p0: float, p1: float) -> float:
     """returns the pressure as function of distance with eq. 20 in Markos et al. (2017)
     Parameters
     ----------
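For reference, the cubic mentioned in the Raises section comes from the Van der Waals equation written per molecule: with ap = a/NA**2 and bp = b/NA as in the code above, (p + ap*N**2) * (1 - bp*N) = N*kB*T, i.e. -ap*bp*N**3 + ap*N**2 - (kB*T + p*bp)*N + p = 0 in the number density N. A sketch of one way to solve it (illustrative only, not necessarily how number_density does it):

    import numpy as np

    kB = 1.380649e-23  # Boltzmann constant, J/K

    def vdw_number_density(p: float, T: float, ap: float, bp: float) -> np.ndarray:
        # roots of -ap*bp*N^3 + ap*N^2 - (kB*T + p*bp)*N + p = 0
        roots = np.roots([-ap * bp, ap, -(kB * T + p * bp), p])
        real = roots[np.isreal(roots)].real
        return real[real > 0]  # several real positive roots is the ambiguity the docstring warns about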

View File

@@ -1,346 +0,0 @@
import os
import re
from pathlib import Path
from typing import Any, Iterable, Optional
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from tqdm import tqdm
from scgenerator import env, math
from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN
from scgenerator.parameter import FileConfiguration, Parameters
from scgenerator.physics import fiber, units
from scgenerator.plotting import plot_setup, transform_2D_propagation, get_extent
from scgenerator.spectra import SimulationSeries
from scgenerator.utils import _open_config, auto_crop, save_toml, simulations_list, load_toml, load_spectrum
def fingerprint(params: Parameters):
h1 = hash(params.field_0.tobytes())
h2 = tuple(params.beta2_coefficients)
return h1, h2
def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
for k, v in opts.items():
if k in ["skip"]:
opts[k] = int(v)
if v == "True":
opts[k] = True
elif v == "False":
opts[k] = False
dir_list = simulations_list(sim_dir)
if len(dir_list) == 0:
dir_list = [sim_dir]
limits = [
tuple(func(el) for func, el in zip([float, float, str], lim.split(","))) for lim in limits
]
with tqdm(total=len(dir_list) * max(1, len(limits))) as bar:
for p in dir_list:
pulse = SimulationSeries(p)
if not limits:
limits = [
(
pulse.params.interpolation_range[0] * 1e9,
pulse.params.interpolation_range[1] * 1e9,
"nm",
)
]
for left, right, unit in limits:
path, fig, ax = plot_setup(
pulse.path.parent
/ (
pulse.path.name
+ PARAM_SEPARATOR
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
)
)
fig.suptitle(p.name)
pulse.plot_2D(
left,
right,
unit,
ax,
**opts,
)
bar.update()
if show:
plt.show()
else:
fig.savefig(path, bbox_inches="tight")
plt.close(fig)
def plot_init_field_spec(
config_path: Path,
lim_t: tuple[float, float] = None,
lim_l: tuple[float, float] = None,
):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
all_labels = []
already_plotted = set()
for style, lbl, params in plot_helper(config_path):
if (bbb := hash(params.field_0.tobytes())) not in already_plotted:
already_plotted.add(bbb)
else:
continue
lbl = plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params)
all_labels.append(lbl)
finish_plot(fig, left, right, all_labels, params)
def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
left.grid()
right.grid()
all_labels = []
already_plotted = set()
loss_ax = None
plt.sca(left)
for style, lbl, params in plot_helper(config_path):
if params.alpha_arr is not None and loss_ax is None:
loss_ax = right.twinx()
if (bbb := tuple(params.beta2_coefficients)) not in already_plotted:
already_plotted.add(bbb)
else:
continue
lbl = plot_1_dispersion(lim, left, right, style, lbl, params, loss_ax)
all_labels.append(lbl)
finish_plot(fig, right, all_labels, params)
def plot_init(
config_path: Path,
lim_field: tuple[float, float] = None,
lim_spec: tuple[float, float] = None,
lim_disp: tuple[float, float] = None,
):
fig, ((tl, tr), (bl, br)) = plt.subplots(2, 2, figsize=(14, 10))
loss_ax = None
tl.grid()
tr.grid()
all_labels = []
already_plotted = set()
for style, lbl, params in plot_helper(config_path):
if params.alpha_arr is not None and loss_ax is None:
loss_ax = tr.twinx()
if (fp := fingerprint(params)) not in already_plotted:
already_plotted.add(fp)
else:
continue
lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params, loss_ax)
lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params)
all_labels.append(lbl)
print(params.pretty_str(exclude="beta2_coefficients"))
finish_plot(fig, tr, all_labels, params)
def plot_1_init_spec_field(
lim_t: Optional[tuple[float, float]],
lim_l: Optional[tuple[float, float]],
left: plt.Axes,
right: plt.Axes,
style: dict[str, Any],
lbl: str,
params: Parameters,
):
field = math.abs2(params.field_0)
spec = math.abs2(params.spec_0)
t = units.fs.inv(params.t)
wl = units.nm.inv(params.w)
lbl += f" max at {wl[spec.argmax()]:.1f} nm"
mt = np.ones_like(t, dtype=bool)
if lim_t is not None:
mt &= t >= lim_t[0]
mt &= t <= lim_t[1]
else:
mt = auto_crop(t, field)
ml = np.ones_like(wl, dtype=bool)
if lim_l is not None:
ml &= wl >= lim_l[0]
ml &= wl <= lim_l[1]
else:
ml = auto_crop(wl, spec)
left.plot(t[mt], field[mt])
right.plot(wl[ml], spec[ml], label=" ", **style)
return lbl
def plot_1_dispersion(
lim: Optional[tuple[float, float]],
left: plt.Axes,
right: plt.Axes,
style: dict[str, Any],
lbl: list[str],
params: Parameters,
loss: plt.Axes = None,
):
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients)
wl = units.m.inv(params.w)
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
zdw = math.all_zeros(wl, beta_arr)
zdw = zdw[(zdw >= params.interpolation_range[0]) & (zdw <= params.interpolation_range[1])]
if len(zdw) > 0:
zdw = zdw[np.argmin(abs(zdw - params.wavelength))]
lbl += f" ZDW at {zdw*1e9:.1f}nm"
else:
lbl += ""
m = np.ones_like(wl, dtype=bool)
if lim is None:
lim = params.interpolation_range
m &= wl >= (lim[0] if lim[0] < 1 else lim[0] * 1e-9)
m &= wl <= (lim[1] if lim[1] < 1 else lim[1] * 1e-9)
info_str = (
rf"$\lambda_{{\mathrm{{min}}}}={np.min(params.l[params.l>0])*1e9:.1f}$ nm"
+ f"\nlower interpolation limit : {params.interpolation_range[0]*1e9:.1f} nm\n"
+ f"max time delay : {params.t.max()*1e12:.1f} ps"
)
left.annotate(
info_str,
xy=(1, 1),
xytext=(-12, -12),
xycoords="axes fraction",
textcoords="offset points",
va="top",
ha="right",
backgroundcolor=(1, 1, 1, 0.4),
)
m = np.argwhere(m)[:, 0]
m = np.array(sorted(m, key=lambda el: wl[el]))
if len(m) == 0:
raise ValueError(f"nothing to plot in the range {lim!r}")
# plot D
right.plot(1e9 * wl[m], D[m], label=" ", **style)
right.set_ylabel(units.D_ps_nm_km.label)
# plot beta2
left.plot(units.nm.inv(params.w[m]), units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
left.set_ylabel(units.beta2_fs_cm.label)
left.set_xlabel(units.nm.label)
right.set_xlabel("wavelength (nm)")
if params.alpha_arr is not None and loss is not None:
loss.plot(1e9 * wl[m], params.alpha_arr[m], c="r", ls="--")
loss.set_ylabel("loss (1/m)", color="r")
loss.set_yscale("log")
loss.tick_params(axis="y", labelcolor="r")
return lbl
def finish_plot(fig: plt.Figure, legend_axes: plt.Axes, all_labels: list[str], params: Parameters):
fig.suptitle(params.name)
plt.tight_layout()
handles, _ = legend_axes.get_legend_handles_labels()
legend_axes.legend(handles, all_labels, prop=dict(size=8, family="monospace"))
out_path = env.output_path()
show = out_path is None
if not show:
file_name = out_path.stem + ".pdf"
out_path = out_path.parent / file_name
if (
out_path.exists()
and input(f"{out_path.name} already exsits, overwrite ? (y/[n])\n > ") != "y"
):
show = True
else:
fig.savefig(out_path, bbox_inches="tight")
if show:
plt.show()
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
for style, (descriptor, params), _ in zip(cc, FileConfiguration(config_path), range(20)):
yield style, descriptor.branch.formatted_descriptor(), params
def convert_params(params_file: os.PathLike):
p = Path(params_file)
if p.name == PARAM_FN:
d = _open_config(params_file)
save_toml(params_file, d)
print(f"converted {p}")
else:
for pp in p.glob(PARAM_FN):
convert_params(pp)
for pp in p.glob("fiber*"):
if pp.is_dir():
convert_params(pp)
def partial_plot(root: os.PathLike, lim: str = None):
path = Path(root)
fig, ax = plt.subplots(figsize=(12, 8))
fig.suptitle(path.name)
spec_list = sorted(
path.glob(SPEC1_FN.format("*")), key=lambda el: int(re.search("[0-9]+", el.name)[0])
)
params = Parameters(**load_toml(path / "params.toml"))
params.z_targets = params.z_targets[: len(spec_list)]
raw_values = np.array([load_spectrum(s) for s in spec_list])
if lim is None:
plot_range = units.PlotRange(
0.5 * params.interpolation_range[0] * 1e9,
1.1 * params.interpolation_range[1] * 1e9,
"nm",
)
else:
left_u, right_u, unit = lim.split(",")
plot_range = units.PlotRange(float(left_u), float(right_u), unit)
if plot_range.unit.type == "TIME":
values = params.ifft(raw_values)
log = False
vmin = None
else:
values = raw_values
log = "2D"
vmin = -60
x, y, values = transform_2D_propagation(
values,
plot_range,
params,
log=log,
)
ax.imshow(
values,
origin="lower",
aspect="auto",
vmin=vmin,
interpolation="nearest",
extent=get_extent(x, y),
)
return ax

View File

@@ -1,157 +0,0 @@
import argparse
import os
import re
import shutil
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Tuple
import numpy as np
from ..utils import Paths
from ..parameter import FileConfiguration
def primes(n):
prime_factors = []
d = 2
while d * d <= n:
while (n % d) == 0:
prime_factors.append(d)
n //= d
d += 1
if n > 1:
prime_factors.append(n)
return prime_factors
def balance(n, lim=(32, 32)):
factors = primes(n)
if len(factors) == 1:
factors = primes(n + 1)
a, b, x, y = 1, 1, 1, 1
while len(factors) > 0 and x <= lim[0] and y <= lim[1]:
a = x
b = y
if y >= x:
x *= factors.pop(0)
else:
y *= factors.pop()
return a, b
def distribute(
num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32)
) -> Tuple[int, int]:
if nodes is None and cpus_per_node is None:
balanced = balance(num, lim)
if num > max(lim):
while np.product(balanced) < min(lim):
num += 1
balanced = balance(num, lim)
nodes = min(balanced)
cpus_per_node = max(balanced)
elif nodes is None:
nodes = num // cpus_per_node
while nodes > lim[0]:
nodes //= 2
elif cpus_per_node is None:
cpus_per_node = num // nodes
while cpus_per_node > lim[1]:
cpus_per_node //= 2
return nodes, cpus_per_node
def format_time(t):
try:
t = float(t)
return timedelta(minutes=t)
except ValueError:
return t
def create_parser():
parser = argparse.ArgumentParser(description="submit a job to a slurm cluster")
parser.add_argument("config", help="path to the toml configuration file")
parser.add_argument(
"-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss"
)
parser.add_argument(
"-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node"
)
parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required")
parser.add_argument(
"--environment-setup",
required=False,
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info",
help="commands to run to setup the environement (default : activate the sc environment with conda)",
)
parser.add_argument(
"--command", default="run", choices=["run", "resume", "merge"], help="command to run"
)
parser.add_argument("--dependency", default=None, help="sbatch dependency argument")
return parser
def copy_starting_files():
for name in ["start_worker", "start_head"]:
path = Paths.get(name)
file_name = os.path.split(path)[1]
shutil.copy(path, file_name)
mode = os.stat(file_name)
os.chmod(file_name, 0o100 | mode.st_mode)
def main():
command_map = dict(run="Propagate", resume="Resuming", merge="Merging")
parser = create_parser()
template = Paths.gets("submit_job_template")
args = parser.parse_args()
if args.dependency is None:
args.dependency = ""
else:
args.dependency = f"#SBATCH --dependency={args.dependency}"
if not re.match(r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", args.time) and not re.match(
r"^[0-9]+$", args.time
):
raise ValueError(
"time format must be an integer number of minute or must match the pattern hh:mm:ss"
)
config = FileConfiguration(args.config)
final_name = config.final_path
sim_num = config.num_sim
if args.command == "merge":
args.nodes = 1
args.cpus_per_node = 1
else:
args.nodes, args.cpus_per_node = distribute(config.num_sim, args.nodes, args.cpus_per_node)
submit_path = Path(
"submit " + final_name.replace("/", "") + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
)
tmp_path = Path("submit tmp.sh")
job_name = f"supercontinuum {final_name}"
submit_sh = template.format(job_name=job_name, **vars(args))
tmp_path.write_text(submit_sh)
subprocess.run(["sbatch", "--test-only", str(tmp_path)])
submit = input(
f"{command_map[args.command]} {sim_num} pulses from config {args.config} with {args.cpus_per_node} cpus"
+ f" per node on {args.nodes} nodes for {format_time(args.time)} ? (y/[n])\n"
)
if submit.lower() in ["y", "yes"]:
submit_path.write_text(submit_sh)
copy_starting_files()
subprocess.run(["sbatch", submit_path])
tmp_path.unlink()

View File

@@ -6,6 +6,7 @@ from typing import Any, Iterator, Sequence
 import numba
 import numpy as np
+from scgenerator.math import abs2
 from scgenerator.operators import SpecOperator
 from scgenerator.utils import TimedMessage
@@ -133,6 +134,7 @@ def solve43(
     targets = list(sorted(set(targets)))
     if targets[0] == 0:
         targets.pop(0)
+    h = min(h, targets[0] / 2)
     step_ind = 0
     msg = TimedMessage(2)
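The one added line clamps the initial step size so the integrator cannot overshoot the first remaining z target; a tiny illustration with made-up numbers:

    targets = [0.0, 0.05, 0.1]
    h = 0.2
    targets = list(sorted(set(targets)))
    if targets[0] == 0:
        targets.pop(0)
    h = min(h, targets[0] / 2)  # 0.2 -> 0.025, half the distance to the first target at z = 0.05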

View File

@@ -405,6 +405,7 @@ class SimulatedFiber:
             return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
         else:
             return load_spectrum(self.path / SPEC1_FN.format(z_ind))
+        psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(path={self.path})"

View File

@@ -23,15 +23,13 @@ import pkg_resources as pkg
 import tomli
 import tomli_w
-from scgenerator.const import (PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS,
-                               SPEC1_FN, Z_FN)
+from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, ROOT_PARAMETERS, SPEC1_FN, Z_FN
 from scgenerator.errors import DuplicateParameterError
 from scgenerator.logger import get_logger
 T_ = TypeVar("T_")
 class TimedMessage:
     def __init__(self, interval: float = 10.0):
         self.interval = datetime.timedelta(seconds=interval)
@@ -179,6 +177,7 @@ def _open_config(path: os.PathLike):
     return dico
 def resolve_relative_paths(d: dict[str, Any], root: os.PathLike | None = None):
     root = Path(root) if root is not None else Path.cwd()
     for k, v in d.items():
@@ -192,7 +191,6 @@ def resolve_relative_paths(d:dict[str, Any], root:os.PathLike | None=None):
             d[k] = str(root / v)
 def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
     if (f_list := dico.pop("INCLUDE", None)) is not None:
         if isinstance(f_list, str):

View File

@@ -1,336 +0,0 @@
import itertools
from collections.abc import MutableMapping, Sequence
from math import prod
from pathlib import Path
from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union
import numpy as np
from pydantic import validator
from pydantic.main import BaseModel
from scgenerator.const import PARAM_SEPARATOR
T = TypeVar("T")
class VariationSpecsError(ValueError):
pass
class Variationer:
"""
manages possible combinations of values given dicts of lists
Example
-------
```
>> var = Variationer([dict(a=[1, 2]), [dict(b=["000", "111"], c=["a", "-1"])]])
list(v.raw_descr for v in var.iterate())
[
((("a", 1),), (("b", "000"), ("c", "a"))),
((("a", 1),), (("b", "111"), ("c", "-1"))),
((("a", 2),), (("b", "000"), ("c", "a"))),
((("a", 2),), (("b", "111"), ("c", "-1"))),
]
```
"""
all_indices: list[list[int]]
all_dicts: list[list[dict[str, list]]]
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
self.all_indices = []
self.all_dicts = []
if variables is not None:
for i, el in enumerate(variables):
self.append(el)
def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
"""append a list of variable parameter sets
each call to append creates a new group of parameters
Parameters
----------
var_list : Union[list[MutableMapping], MutableMapping]
each dict in the list is treated as an independent parameter
this means that if for one dict, len > 1, the lists of possible values
must be the same length
Example
-------
`>> append([dict(wavelength=[800e-9, 900e-9], power=[1e3, 2e3]), dict(length=[3e-2, 3.5e-2, 4e-2])])`
means that for every parameter variations, wavelength=800e-9 will always occur when power=1e3 and
vice versa, while length is free to vary independently
Raises
------
VariationSpecsError
raised when possible values lists in a same dict are not the same length
"""
if not isinstance(var_list, Sequence):
var_list = [{k: v} for k, v in var_list.items()]
else:
var_list = list(var_list)
num_vars = []
for d in var_list:
values = list(d.values())
if len(values) > 0:
len_to_test = len(values[0])
if not all(len(v) == len_to_test for v in values[1:]):
raise VariationSpecsError(
"variable items should all have the same number of parameters"
)
num_vars.append(len_to_test)
if len(num_vars) == 0:
num_vars = [1]
self.all_indices.append(num_vars)
self.all_dicts.append(var_list)
def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
index = self.__index(index)
flattened_indices = sum(self.all_indices[: index + 1], [])
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
ranges = [range(i) for i in flattened_indices]
for r in itertools.product(*ranges):
out: list[list[tuple[str, Any]]] = []
indicies: list[list[int]] = []
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
out.append([])
indicies.append([])
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
for k, v in var_d.items():
out[-1].append((k, v[value_index]))
indicies[-1].append(value_index)
yield VariationDescriptor(raw_descr=out, index=indicies)
def __index(self, index: int) -> int:
if index < 0:
index = len(self.all_indices) + index
return index
def var_num(self, index: int = -1) -> int:
index = self.__index(index)
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
class VariationDescriptor(BaseModel):
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
index: tuple[tuple[int, ...], ...]
separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {}
__ids: dict[int, int] = {}
@classmethod
def register_formatter(cls, p_name: str, func: Callable[..., str]):
"""register a function that formats a particular parameter
Parameters
----------
p_name : str
name of the parameter
func : Callable[..., str]
function that takes as single argument the value of the parameter and returns a string
"""
cls._format_registry[p_name] = func
@classmethod
def format_value(cls, name: str, value) -> str:
if value is True or value is False:
return str(value)
elif isinstance(value, (float, int)):
try:
return cls._format_registry[name](value)
except KeyError:
return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)):
return "-".join([str(v) for v in value])
elif isinstance(value, str):
p = Path(value)
if p.exists():
return p.stem
elif callable(value):
return getattr(value, "__name__", repr(value))
return str(value)
class Config:
allow_mutation = False
def formatted_descriptor(self, add_identifier=False) -> str:
"""formats a variable list into a str such that each simulation has a unique
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
branch identifier can added at the beginning.
Parameters
----------
add_identifier : bool
add unique simulation and parameter-set identifiers
Returns
-------
str
simulation descriptor
"""
str_list = []
for p_name, p_value in self.flat:
ps, vs = self._format_single_pair(p_name, p_value)
str_list.append(ps + PARAM_SEPARATOR + vs)
tmp_name = PARAM_SEPARATOR.join(str_list)
if not add_identifier:
return tmp_name
return (
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
)
def _format_single_pair(self, p_name: str, p_value: Any) -> tuple[str, str]:
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
return ps, vs
def __getitem__(self, key) -> "VariationDescriptor":
return VariationDescriptor(
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
)
def __str__(self) -> str:
return self.formatted_descriptor(add_identifier=False)
def __lt__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr < other.raw_descr
def __le__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr <= other.raw_descr
def __gt__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr > other.raw_descr
def __ge__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr >= other.raw_descr
def __eq__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr == other.raw_descr
def __hash__(self) -> int:
return hash(self.raw_descr)
def __contains__(self, other: "VariationDescriptor") -> bool:
return all(el in self.raw_descr for el in other.raw_descr)
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
"""updates a dictionary with the value of the descriptor
Parameters
----------
cfg : dict[str, Any]
dict to be updated
index : int, optional
index of the fiber from which to apply the parameters, by default -1
Returns
-------
dict[str, Any]
same as cfg but with key from the descriptor added/updated.
"""
out_cfg = cfg.copy()
out_cfg.pop("variable", None)
return out_cfg | {k: v for k, v in self.raw_descr[index]}
def iter_parents(self) -> Iterator["VariationDescriptor"]:
if (p := self.parent) is not None:
yield from p.iter_parents()
yield self
@property
def short(self) -> str:
"""shortened description of the simulation"""
return " ".join(
self._format_single_pair(p, v)[1] for p, v in self.flat if p not in {"fiber", "num"}
)
@property
def flat(self) -> list[tuple[str, Any]]:
out = []
for n, variable_list in enumerate(self.raw_descr):
out += [
(self.separator, "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[n % 26] * (n // 26 + 1)),
*variable_list,
]
return out
@property
def branch(self) -> "BranchDescriptor":
descr: list[list[tuple[str, Any]]] = []
ind: list[list[int]] = []
for i, l in enumerate(self.raw_descr):
descr.append([])
ind.append([])
for j, (k, v) in enumerate(l):
if k != "num":
descr[-1].append((k, v))
ind[-1].append(self.index[i][j])
return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)
@property
def identifier(self) -> str:
unique_id = hash(str(self.flat))
self.__ids.setdefault(unique_id, len(self.__ids))
return "u_" + str(self.__ids[unique_id])
@property
def parent(self) -> Optional["VariationDescriptor"]:
if len(self.raw_descr) < 2:
return None
return VariationDescriptor(
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
)
@property
def num_fibers(self) -> int:
return len(self.raw_descr)
class BranchDescriptor(VariationDescriptor):
__ids: dict[int, int] = {}
@property
def identifier(self) -> str:
branch_id = hash(str(self.flat))
self.__ids.setdefault(branch_id, len(self.__ids))
return "b_" + str(self.__ids[branch_id])
@validator("raw_descr")
def validate_raw_descr(cls, v):
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
class DescriptorDict(Generic[T]):
def __init__(self, dico: dict[VariationDescriptor, T] = None):
self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
if dico is not None:
for k, v in dico.items():
self[k] = v
def __setitem__(self, key: VariationDescriptor, value: T):
if not isinstance(key, VariationDescriptor):
raise TypeError("key must be a VariationDescriptor instance")
self.dico[key.raw_descr] = (key, value)
def __getitem__(
self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
) -> T:
if isinstance(key, VariationDescriptor):
return self.dico[key.raw_descr][1]
else:
return self.dico[key][1]
def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
for k, v in self.dico.items():
yield k, v[1]
def keys(self) -> list[VariationDescriptor]:
return [v[0] for v in self.dico.values()]
def values(self) -> list[T]:
return [v[1] for v in self.dico.values()]