Sim starts, merge not
This commit is contained in:
@@ -5,4 +5,5 @@ from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
|
|||||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||||
from .spectra import Pulse, Spectrum
|
from .spectra import Pulse, Spectrum
|
||||||
from .utils import Paths, open_config, parameter
|
from .utils import Paths, open_config, parameter
|
||||||
from .utils.parameter import Configuration, Parameters, PlotRange
|
from .utils.parameter import Configuration, Parameters
|
||||||
|
from .utils.utils import PlotRange
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from send2trash import send2trash
|
|||||||
|
|
||||||
from .. import env, utils
|
from .. import env, utils
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..utils.parameter import Configuration, Parameters, format_variable_list
|
from ..utils.parameter import Configuration, Parameters
|
||||||
from . import pulse
|
from . import pulse
|
||||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||||
|
|
||||||
@@ -466,14 +466,14 @@ class Simulations:
|
|||||||
|
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
|
|
||||||
self.name = self.configuration.final_path
|
self.name = self.configuration.name
|
||||||
self.sim_dir = self.configuration.final_sim_dir
|
self.sim_dir = self.configuration.final_path
|
||||||
self.configuration.save_parameters()
|
self.configuration.save_parameters()
|
||||||
|
|
||||||
self.sim_jobs_per_node = 1
|
self.sim_jobs_per_node = 1
|
||||||
|
|
||||||
def finished_and_complete(self):
|
def finished_and_complete(self):
|
||||||
for sim in self.configuration.all_configs_dict.values():
|
for sim in self.configuration.all_configs.values():
|
||||||
if (
|
if (
|
||||||
self.configuration.sim_status(sim.output_path)[0]
|
self.configuration.sim_status(sim.output_path)[0]
|
||||||
!= self.configuration.State.COMPLETE
|
!= self.configuration.State.COMPLETE
|
||||||
@@ -487,7 +487,7 @@ class Simulations:
|
|||||||
|
|
||||||
def _run_available(self):
|
def _run_available(self):
|
||||||
for variable, params in self.configuration:
|
for variable, params in self.configuration:
|
||||||
v_list_str = format_variable_list(variable, add_iden=True)
|
v_list_str = variable.formatted_descriptor(True)
|
||||||
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
||||||
|
|
||||||
self.new_sim(v_list_str, params)
|
self.new_sim(v_list_str, params)
|
||||||
@@ -526,7 +526,9 @@ class SequencialSimulations(Simulations, priority=0):
|
|||||||
def __init__(self, configuration: Configuration, task_id):
|
def __init__(self, configuration: Configuration, task_id):
|
||||||
super().__init__(configuration, task_id=task_id)
|
super().__init__(configuration, task_id=task_id)
|
||||||
self.pbars = utils.PBars(
|
self.pbars = utils.PBars(
|
||||||
self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1
|
self.configuration.total_num_steps,
|
||||||
|
"Simulating " + self.configuration.final_path.name,
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
|
self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
|
||||||
|
|
||||||
@@ -569,7 +571,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
|||||||
self.p_worker = multiprocessing.Process(
|
self.p_worker = multiprocessing.Process(
|
||||||
target=utils.progress_worker,
|
target=utils.progress_worker,
|
||||||
args=(
|
args=(
|
||||||
self.configuration.final_path,
|
self.configuration.final_path.name,
|
||||||
self.sim_jobs_per_node,
|
self.sim_jobs_per_node,
|
||||||
self.configuration.total_num_steps,
|
self.configuration.total_num_steps,
|
||||||
self.progress_queue,
|
self.progress_queue,
|
||||||
@@ -716,7 +718,7 @@ def run_simulation(
|
|||||||
|
|
||||||
sim = new_simulation(config, method)
|
sim = new_simulation(config, method)
|
||||||
sim.run()
|
sim.run()
|
||||||
path_trees = utils.build_path_trees(config.sim_dirs[-1])
|
path_trees = utils.build_path_trees(config.fiber_paths[-1])
|
||||||
|
|
||||||
final_name = env.get(env.OUTPUT_PATH)
|
final_name = env.get(env.OUTPUT_PATH)
|
||||||
if final_name is None:
|
if final_name is None:
|
||||||
@@ -724,7 +726,7 @@ def run_simulation(
|
|||||||
|
|
||||||
utils.merge(final_name, path_trees)
|
utils.merge(final_name, path_trees)
|
||||||
try:
|
try:
|
||||||
send2trash(config.sim_dirs)
|
send2trash(config.fiber_paths)
|
||||||
except (PermissionError, OSError):
|
except (PermissionError, OSError):
|
||||||
get_logger(__name__).error("Could not send temporary directories to trash")
|
get_logger(__name__).error("Could not send temporary directories to trash")
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR
|
|||||||
from .defaults import default_plotting as defaults
|
from .defaults import default_plotting as defaults
|
||||||
from .math import abs2, span
|
from .math import abs2, span
|
||||||
from .physics import pulse, units
|
from .physics import pulse, units
|
||||||
from .utils.parameter import Parameters, PlotRange, sort_axis
|
from .utils.parameter import Parameters
|
||||||
|
from .utils.utils import PlotRange, sort_axis
|
||||||
|
|
||||||
RangeType = tuple[float, float, Union[str, Callable]]
|
RangeType = tuple[float, float, Union[str, Callable]]
|
||||||
NO_LIM = object()
|
NO_LIM = object()
|
||||||
|
|||||||
@@ -16,9 +16,8 @@ from ..utils import auto_crop, open_config, save_toml, translate_parameters
|
|||||||
from ..utils.parameter import (
|
from ..utils.parameter import (
|
||||||
Configuration,
|
Configuration,
|
||||||
Parameters,
|
Parameters,
|
||||||
pretty_format_from_sim_name,
|
|
||||||
pretty_format_value,
|
|
||||||
)
|
)
|
||||||
|
from ..utils.variationer import VariationDescriptor
|
||||||
|
|
||||||
|
|
||||||
def fingerprint(params: Parameters):
|
def fingerprint(params: Parameters):
|
||||||
@@ -46,7 +45,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
|
|||||||
path, fig, ax = plot_setup(
|
path, fig, ax = plot_setup(
|
||||||
pulse.path.parent
|
pulse.path.parent
|
||||||
/ (
|
/ (
|
||||||
pretty_format_from_sim_name(pulse.path.name)
|
pulse.path.name
|
||||||
+ PARAM_SEPARATOR
|
+ PARAM_SEPARATOR
|
||||||
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
|
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ from .plotting import (
|
|||||||
single_position_plot,
|
single_position_plot,
|
||||||
transform_2D_propagation,
|
transform_2D_propagation,
|
||||||
)
|
)
|
||||||
from .utils.parameter import Parameters, PlotRange
|
from .utils.parameter import Parameters
|
||||||
|
from .utils.utils import PlotRange
|
||||||
from .utils import load_spectrum
|
from .utils import load_spectrum
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -161,29 +161,35 @@ def save_toml(path: os.PathLike, dico):
|
|||||||
return dico
|
return dico
|
||||||
|
|
||||||
|
|
||||||
def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]:
|
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
|
||||||
loaded_config = open_config(final_config_path)
|
"""loads a configuration file
|
||||||
final_name = loaded_config.get("name")
|
|
||||||
fiber_list = loaded_config.pop("Fiber")
|
|
||||||
configs = []
|
|
||||||
if fiber_list is not None:
|
|
||||||
master_variable = loaded_config.get("variable", {})
|
|
||||||
for i, params in enumerate(fiber_list):
|
|
||||||
params.setdefault("variable", master_variable if i == 0 else {})
|
|
||||||
if i == 0:
|
|
||||||
params["variable"] |= master_variable
|
|
||||||
configs.append(loaded_config | params)
|
|
||||||
else:
|
|
||||||
configs.append(loaded_config)
|
|
||||||
while "previous_config_file" in configs[0]:
|
|
||||||
configs.insert(0, open_config(configs[0]["previous_config_file"]))
|
|
||||||
configs[0].setdefault("variable", {})
|
|
||||||
for pre, nex in zip(configs[:-1], configs[1:]):
|
|
||||||
variable = nex.pop("variable", {})
|
|
||||||
nex.update({k: v for k, v in pre.items() if k not in nex})
|
|
||||||
nex["variable"] = variable
|
|
||||||
|
|
||||||
return configs, final_name
|
Parameters
|
||||||
|
----------
|
||||||
|
path : os.PathLike
|
||||||
|
path to the config toml file
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
final_path : Path
|
||||||
|
output name of the simulation
|
||||||
|
list[dict[str, Any]]
|
||||||
|
one config per fiber
|
||||||
|
|
||||||
|
"""
|
||||||
|
loaded_config = open_config(path)
|
||||||
|
|
||||||
|
fiber_list: list[dict[str, Any]] = loaded_config.pop("Fiber")
|
||||||
|
if len(fiber_list) == 0:
|
||||||
|
raise ValueError(f"No fiber in config {path}")
|
||||||
|
final_path = loaded_config.get("name")
|
||||||
|
configs = []
|
||||||
|
for i, params in enumerate(fiber_list):
|
||||||
|
params.setdefault("variable", {})
|
||||||
|
configs.append(loaded_config | params)
|
||||||
|
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
|
||||||
|
|
||||||
|
return Path(final_path), configs
|
||||||
|
|
||||||
|
|
||||||
def save_parameters(
|
def save_parameters(
|
||||||
|
|||||||
@@ -11,14 +11,18 @@ from dataclasses import asdict, dataclass, fields
|
|||||||
from functools import cache, lru_cache
|
from functools import cache, lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
|
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.lib import isin
|
from numpy.lib import isin
|
||||||
|
from scgenerator.utils import ensure_folder, variationer
|
||||||
|
|
||||||
from .. import math, utils
|
from .. import math, utils
|
||||||
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
|
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
|
||||||
from ..errors import EvaluatorError, NoDefaultError
|
from ..errors import EvaluatorError, NoDefaultError
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..physics import fiber, materials, pulse, units
|
from ..physics import fiber, materials, pulse, units
|
||||||
|
from ..utils.variationer import VariationDescriptor, Variationer
|
||||||
|
from .utils import func_rewrite, _mock_function, get_arg_names
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -256,7 +260,7 @@ class Parameter:
|
|||||||
----------
|
----------
|
||||||
tpe : type
|
tpe : type
|
||||||
type of the paramter
|
type of the paramter
|
||||||
validators : Callable[[str, Any], None]
|
validator : Callable[[str, Any], None]
|
||||||
signature : validator(name, value)
|
signature : validator(name, value)
|
||||||
must raise a ValueError when value doesn't fit the criteria checked by
|
must raise a ValueError when value doesn't fit the criteria checked by
|
||||||
validator. name is passed to validator to be included in the error message
|
validator. name is passed to validator to be included in the error message
|
||||||
@@ -290,7 +294,6 @@ class Parameter:
|
|||||||
if isinstance(value, Parameter):
|
if isinstance(value, Parameter):
|
||||||
defaut = None if self.default is None else copy(self.default)
|
defaut = None if self.default is None else copy(self.default)
|
||||||
instance.__dict__[self.name] = defaut
|
instance.__dict__[self.name] = defaut
|
||||||
# instance.__dict__[self.name] = None
|
|
||||||
else:
|
else:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if self.converter is not None:
|
if self.converter is not None:
|
||||||
@@ -768,9 +771,11 @@ class Configuration:
|
|||||||
obj with the output path of the simulation saved in its output_path attribute.
|
obj with the output path of the simulation saved in its output_path attribute.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
master_configs: list[dict[str, Any]]
|
fiber_configs: list[dict[str, Any]]
|
||||||
sim_dirs: list[Path]
|
master_config: dict[str, Any]
|
||||||
|
fiber_paths: list[Path]
|
||||||
num_sim: int
|
num_sim: int
|
||||||
|
num_fibers: int
|
||||||
repeat: int
|
repeat: int
|
||||||
z_num: int
|
z_num: int
|
||||||
total_num_steps: int
|
total_num_steps: int
|
||||||
@@ -778,19 +783,17 @@ class Configuration:
|
|||||||
parallel: bool
|
parallel: bool
|
||||||
overwrite: bool
|
overwrite: bool
|
||||||
final_path: str
|
final_path: str
|
||||||
all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
|
all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
|
||||||
all_configs_list: list[list["Configuration.__SimConfig"]]
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class __SimConfig:
|
class __SimConfig:
|
||||||
vary_list: list[tuple[str, Any]]
|
descriptor: VariationDescriptor
|
||||||
config: dict[str, Any]
|
config: dict[str, Any]
|
||||||
output_path: Path
|
output_path: Path
|
||||||
index: tuple[tuple[int, ...], ...]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sim_num(self) -> int:
|
def sim_num(self) -> int:
|
||||||
return len(self.index)
|
return len(self.descriptor.index)
|
||||||
|
|
||||||
class State(enum.Enum):
|
class State(enum.Enum):
|
||||||
COMPLETE = enum.auto()
|
COMPLETE = enum.auto()
|
||||||
@@ -810,48 +813,48 @@ class Configuration:
|
|||||||
):
|
):
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
self.master_configs, self.final_path = utils.load_config_sequence(final_config_path)
|
self.overwrite = overwrite
|
||||||
if self.final_path is None:
|
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
||||||
self.final_path = Parameters.name.default
|
self.final_path = utils.ensure_folder(
|
||||||
self.name = Path(self.final_path).name
|
self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
|
||||||
|
)
|
||||||
|
self.master_config = self.fiber_configs[0]
|
||||||
|
self.name = self.final_path.name
|
||||||
self.z_num = 0
|
self.z_num = 0
|
||||||
self.total_num_steps = 0
|
self.total_num_steps = 0
|
||||||
self.sim_dirs = []
|
self.fiber_paths = []
|
||||||
self.overwrite = overwrite
|
self.all_configs = {}
|
||||||
self.skip_callback = skip_callback
|
self.skip_callback = skip_callback
|
||||||
self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2))
|
self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
|
||||||
self.repeat = self.master_configs[0].get("repeat", 1)
|
self.repeat = self.master_config.get("repeat", 1)
|
||||||
|
self.variationer = Variationer()
|
||||||
|
|
||||||
names = set()
|
fiber_names = set()
|
||||||
for i, config in enumerate(self.master_configs):
|
self.num_fibers = 0
|
||||||
|
for i, config in enumerate(self.fiber_configs):
|
||||||
|
config.setdefault("name", Parameters.name.default)
|
||||||
self.z_num += config["z_num"]
|
self.z_num += config["z_num"]
|
||||||
config.setdefault("name", f"{Parameters.name.default} {i}")
|
fiber_names.add(config["name"])
|
||||||
given_name = config["name"]
|
self.variationer.append(config.pop("variable"))
|
||||||
fn_i = 0
|
self.fiber_paths.append(
|
||||||
while config["name"] in names:
|
|
||||||
config["name"] = given_name + f"_{fn_i}"
|
|
||||||
fn_i += 1
|
|
||||||
names.add(config["name"])
|
|
||||||
|
|
||||||
self.sim_dirs.append(
|
|
||||||
utils.ensure_folder(
|
utils.ensure_folder(
|
||||||
Path("_".join(["_", self.name, Path(config["name"]).name, "_"])),
|
self.fiber_path(i, config),
|
||||||
mkdir=False,
|
mkdir=False,
|
||||||
prevent_overwrite=not self.overwrite,
|
prevent_overwrite=not self.overwrite,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.__validate_variable(config)
|
self.__validate_variable(config)
|
||||||
self.__compute_sim_dirs()
|
self.num_fibers += 1
|
||||||
[Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list]
|
Evaluator.evaluate_default(config, True)
|
||||||
self.num_sim = len(self.all_configs_list[-1])
|
self.num_sim = self.variationer.var_num()
|
||||||
self.total_num_steps = sum(
|
self.total_num_steps = sum(
|
||||||
config["z_num"] * len(self.all_configs_list[i])
|
config["z_num"] * self.variationer.var_num(i)
|
||||||
for i, config in enumerate(self.master_configs)
|
for i, config in enumerate(self.fiber_configs)
|
||||||
)
|
)
|
||||||
self.final_sim_dir = utils.ensure_folder(
|
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||||
Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite
|
|
||||||
)
|
def fiber_path(self, i: int, full_config: dict[str, Any]) -> Path:
|
||||||
self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default)
|
return self.final_path / PARAM_SEPARATOR.join([format(i), self.name, full_config["name"]])
|
||||||
|
|
||||||
def __validate_variable(self, config: dict[str, Any]):
|
def __validate_variable(self, config: dict[str, Any]):
|
||||||
for k, v in config.get("variable", {}).items():
|
for k, v in config.get("variable", {}).items():
|
||||||
@@ -862,76 +865,62 @@ class Configuration:
|
|||||||
if len(v) == 0:
|
if len(v) == 0:
|
||||||
raise ValueError(f"variable parameter {k!r} must not be empty")
|
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||||
|
|
||||||
def __compute_sim_dirs(self):
|
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
|
||||||
self.all_configs_dict = {}
|
for i in range(self.num_fibers):
|
||||||
self.all_configs_list = []
|
for sim_config in self.iterate_single_fiber(i):
|
||||||
self.master_configs[0]["variable"]["num"] = list(
|
|
||||||
range(self.master_configs[0].get("repeat", 1))
|
|
||||||
)
|
|
||||||
dp = DataPather([c["variable"] for c in self.master_configs])
|
|
||||||
for i, conf in enumerate(self.master_configs):
|
|
||||||
self.all_configs_list.append([])
|
|
||||||
for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i):
|
|
||||||
this_conf = conf.copy()
|
|
||||||
if i > 0:
|
if i > 0:
|
||||||
prev_path = utils.ensure_folder(
|
sim_config.config["prev_data_dir"] = str(
|
||||||
self.sim_dirs[i - 1] / prev_path, not self.overwrite, False
|
self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
|
||||||
)
|
)
|
||||||
this_conf["prev_data_dir"] = str(prev_path)
|
params = Parameters(**sim_config.config)
|
||||||
|
params.compute()
|
||||||
this_path = utils.ensure_folder(
|
|
||||||
self.sim_dirs[i] / this_path, not self.overwrite, False
|
|
||||||
)
|
|
||||||
this_conf.pop("variable")
|
|
||||||
conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
|
|
||||||
self.all_configs_dict[sim_index] = self.__SimConfig(
|
|
||||||
this_vary, conf_to_use, this_path, sim_index
|
|
||||||
)
|
|
||||||
self.all_configs_list[i].append(self.all_configs_dict[sim_index])
|
|
||||||
|
|
||||||
def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
|
|
||||||
for i, sim_config_list in enumerate(self.all_configs_list):
|
|
||||||
for sim_config, params in self.__iter_1_sim(sim_config_list):
|
|
||||||
fiber_map = []
|
fiber_map = []
|
||||||
for j in range(i + 1):
|
for j in range(i + 1):
|
||||||
this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config
|
this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
|
||||||
if j > 0:
|
if j > 0:
|
||||||
prev_conf = self.all_configs_dict[sim_config.index[:j]].config
|
prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
|
||||||
length = prev_conf["length"] + fiber_map[j - 1][0]
|
length = prev_conf["length"] + fiber_map[j - 1][0]
|
||||||
else:
|
else:
|
||||||
length = 0.0
|
length = 0.0
|
||||||
fiber_map.append((length, this_conf["name"]))
|
fiber_map.append((length, this_conf["name"]))
|
||||||
params.output_path = str(sim_config.output_path)
|
|
||||||
params.fiber_map = fiber_map
|
params.fiber_map = fiber_map
|
||||||
yield sim_config.vary_list, params
|
yield sim_config.descriptor, params
|
||||||
|
|
||||||
def __iter_1_sim(
|
def iterate_single_fiber(
|
||||||
self, configs: list["Configuration.__SimConfig"]
|
self, index: int
|
||||||
) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]:
|
) -> Generator["Configuration.__SimConfig", None, None]:
|
||||||
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
||||||
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
configs : list[__SimConfig]
|
index : int
|
||||||
list of configuration obj
|
which fiber to iterate over
|
||||||
|
|
||||||
Yields
|
Yields
|
||||||
-------
|
-------
|
||||||
__SimConfig
|
__SimConfig
|
||||||
configuration obj
|
configuration obj
|
||||||
Parameters
|
|
||||||
computed Parameters obj
|
|
||||||
"""
|
"""
|
||||||
sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs}
|
sim_dict: dict[Path, self.__SimConfig] = {}
|
||||||
|
for descr in self.variationer.iterate(index):
|
||||||
|
cfg = descr.update_config(self.fiber_configs[index])
|
||||||
|
p = ensure_folder(
|
||||||
|
self.fiber_paths[index] / descr.formatted_descriptor(),
|
||||||
|
not self.overwrite,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
cfg["output_path"] = str(p)
|
||||||
|
sim_config = self.__SimConfig(descr, cfg, p)
|
||||||
|
sim_dict[p] = sim_config
|
||||||
|
self.all_configs[sim_config.descriptor.index] = sim_config
|
||||||
while len(sim_dict) > 0:
|
while len(sim_dict) > 0:
|
||||||
for data_dir, sim_config in sim_dict.items():
|
for data_dir, sim_config in sim_dict.items():
|
||||||
task, config_dict = self.__decide(sim_config)
|
task, config_dict = self.__decide(sim_config)
|
||||||
if task == self.Action.RUN:
|
if task == self.Action.RUN:
|
||||||
sim_dict.pop(data_dir)
|
sim_dict.pop(data_dir)
|
||||||
p = Parameters(**config_dict)
|
yield sim_config
|
||||||
p.compute()
|
|
||||||
yield sim_config, p
|
|
||||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||||
self.skip_callback(config_dict["recovery_last_stored"])
|
self.skip_callback(config_dict["recovery_last_stored"])
|
||||||
break
|
break
|
||||||
@@ -956,7 +945,7 @@ class Configuration:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str : {'run', 'wait', 'skip'}
|
str : Configuration.Action
|
||||||
what to do
|
what to do
|
||||||
config_dict : dict[str, Any]
|
config_dict : dict[str, Any]
|
||||||
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
||||||
@@ -1012,7 +1001,7 @@ class Configuration:
|
|||||||
raise ValueError(f"Too many spectra in {data_dir}")
|
raise ValueError(f"Too many spectra in {data_dir}")
|
||||||
|
|
||||||
def save_parameters(self):
|
def save_parameters(self):
|
||||||
for config, sim_dir in zip(self.master_configs, self.sim_dirs):
|
for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
|
||||||
os.makedirs(sim_dir, exist_ok=True)
|
os.makedirs(sim_dir, exist_ok=True)
|
||||||
utils.save_toml(sim_dir / f"initial_config.toml", config)
|
utils.save_toml(sim_dir / f"initial_config.toml", config)
|
||||||
|
|
||||||
@@ -1022,144 +1011,6 @@ class Configuration:
|
|||||||
return param
|
return param
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PlotRange:
|
|
||||||
left: float = Parameter(type_checker(int, float))
|
|
||||||
right: float = Parameter(type_checker(int, float))
|
|
||||||
unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit)
|
|
||||||
conserved_quantity: bool = Parameter(boolean, default=True)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
|
||||||
|
|
||||||
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
|
||||||
return sort_axis(axis, self)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
yield self.left
|
|
||||||
yield self.right
|
|
||||||
yield self.unit.__name__
|
|
||||||
|
|
||||||
|
|
||||||
def sort_axis(
|
|
||||||
axis: np.ndarray, plt_range: PlotRange
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
|
||||||
"""
|
|
||||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
axis : 1D array containing the original axis (usual the w or t array)
|
|
||||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
cropped : the axis cropped, converted and sorted
|
|
||||||
indices : indices to use to slice and sort other array in the same fashion
|
|
||||||
extent : tupple with min and max of cropped
|
|
||||||
|
|
||||||
Example
|
|
||||||
-------
|
|
||||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
|
||||||
t = np.linspace(-10, 10, 400)
|
|
||||||
W, T = np.meshgrid(w, t)
|
|
||||||
y = np.exp(-W**2 - T**2)
|
|
||||||
|
|
||||||
# Define ranges
|
|
||||||
rw = (-4, 4, s)
|
|
||||||
rt = (-2, 6, s)
|
|
||||||
|
|
||||||
w, cw = sort_axis(w, rw)
|
|
||||||
t, ct = sort_axis(t, rt)
|
|
||||||
|
|
||||||
# slice y according to the given ranges
|
|
||||||
y = y[ct][:, cw]
|
|
||||||
"""
|
|
||||||
if isinstance(plt_range, tuple):
|
|
||||||
plt_range = PlotRange(*plt_range)
|
|
||||||
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
|
||||||
|
|
||||||
indices = np.arange(len(axis))[
|
|
||||||
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
|
||||||
]
|
|
||||||
cropped = axis[indices]
|
|
||||||
order = np.argsort(plt_range.unit.inv(cropped))
|
|
||||||
indices = indices[order]
|
|
||||||
cropped = cropped[order]
|
|
||||||
out_ax = plt_range.unit.inv(cropped)
|
|
||||||
|
|
||||||
return out_ax, indices, (out_ax[0], out_ax[-1])
|
|
||||||
|
|
||||||
|
|
||||||
def get_arg_names(func: Callable) -> list[str]:
|
|
||||||
spec = inspect.getfullargspec(func)
|
|
||||||
args = spec.args
|
|
||||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
|
||||||
args = args[: -len(spec.defaults)]
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def validate_arg_names(names: list[str]):
|
|
||||||
for n in names:
|
|
||||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
|
||||||
raise ValueError(f"{n} is an invalid parameter name")
|
|
||||||
|
|
||||||
|
|
||||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
|
||||||
if arg_names is None:
|
|
||||||
arg_names = get_arg_names(func)
|
|
||||||
else:
|
|
||||||
validate_arg_names(arg_names)
|
|
||||||
validate_arg_names(kwarg_names)
|
|
||||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
|
||||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
|
||||||
tmp_name = f"{func.__name__}_0"
|
|
||||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
|
||||||
scope = dict(__func__=func)
|
|
||||||
exec(func_str, scope)
|
|
||||||
out_func = scope[tmp_name]
|
|
||||||
out_func.__module__ = "evaluator"
|
|
||||||
return out_func
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
|
||||||
if not isinstance(num_args, int) and isinstance(num_returns, int):
|
|
||||||
raise TypeError(f"num_args and num_returns must be int")
|
|
||||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
|
||||||
return_str = ", ".join("True" for _ in range(num_returns))
|
|
||||||
func_name = f"__mock_{num_args}_{num_returns}"
|
|
||||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
|
||||||
scope = {}
|
|
||||||
exec(func_str, scope)
|
|
||||||
out_func = scope[func_name]
|
|
||||||
out_func.__module__ = "evaluator"
|
|
||||||
return out_func
|
|
||||||
|
|
||||||
|
|
||||||
def pretty_format_from_sim_name(name: str) -> str:
|
|
||||||
"""formats a pretty version of a simulation directory
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
name of the simulation (directory name)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str
|
|
||||||
prettier name
|
|
||||||
"""
|
|
||||||
s = name.split(PARAM_SEPARATOR)
|
|
||||||
out = []
|
|
||||||
for key, value in zip(s[::2], s[1::2]):
|
|
||||||
try:
|
|
||||||
out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
|
|
||||||
except (AttributeError, ValueError):
|
|
||||||
out.append(key + PARAM_SEPARATOR + value)
|
|
||||||
return PARAM_SEPARATOR.join(out)
|
|
||||||
|
|
||||||
|
|
||||||
default_rules: list[Rule] = [
|
default_rules: list[Rule] = [
|
||||||
# Grid
|
# Grid
|
||||||
*Rule.deduce(
|
*Rule.deduce(
|
||||||
|
|||||||
@@ -1,4 +1,23 @@
|
|||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
from functools import cache
|
||||||
from string import printable as str_printable
|
from string import printable as str_printable
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..physics.units import get_unit
|
||||||
|
|
||||||
|
|
||||||
|
class HashableBaseModel(BaseModel):
|
||||||
|
"""Pydantic BaseModel that's immutable and can be hashed"""
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
allow_mutation = False
|
||||||
|
|
||||||
|
|
||||||
def to_62(i: int) -> str:
|
def to_62(i: int) -> str:
|
||||||
@@ -10,3 +29,118 @@ def to_62(i: int) -> str:
|
|||||||
i, value = divmod(i, 62)
|
i, value = divmod(i, 62)
|
||||||
arr.append(str_printable[value])
|
arr.append(str_printable[value])
|
||||||
return "".join(reversed(arr))
|
return "".join(reversed(arr))
|
||||||
|
|
||||||
|
|
||||||
|
class PlotRange(HashableBaseModel):
|
||||||
|
left: float
|
||||||
|
right: float
|
||||||
|
unit: Callable[[float], float]
|
||||||
|
conserved_quantity: bool = True
|
||||||
|
|
||||||
|
def __init__(self, left, right, unit, **kwargs):
|
||||||
|
super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||||
|
|
||||||
|
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||||
|
return sort_axis(axis, self)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
yield self.left
|
||||||
|
yield self.right
|
||||||
|
yield self.unit.__name__
|
||||||
|
|
||||||
|
|
||||||
|
def sort_axis(
|
||||||
|
axis: np.ndarray, plt_range: PlotRange
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||||
|
"""
|
||||||
|
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
axis : 1D array containing the original axis (usual the w or t array)
|
||||||
|
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
cropped : the axis cropped, converted and sorted
|
||||||
|
indices : indices to use to slice and sort other array in the same fashion
|
||||||
|
extent : tupple with min and max of cropped
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||||
|
t = np.linspace(-10, 10, 400)
|
||||||
|
W, T = np.meshgrid(w, t)
|
||||||
|
y = np.exp(-W**2 - T**2)
|
||||||
|
|
||||||
|
# Define ranges
|
||||||
|
rw = (-4, 4, s)
|
||||||
|
rt = (-2, 6, s)
|
||||||
|
|
||||||
|
w, cw = sort_axis(w, rw)
|
||||||
|
t, ct = sort_axis(t, rt)
|
||||||
|
|
||||||
|
# slice y according to the given ranges
|
||||||
|
y = y[ct][:, cw]
|
||||||
|
"""
|
||||||
|
if isinstance(plt_range, tuple):
|
||||||
|
plt_range = PlotRange(*plt_range)
|
||||||
|
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||||
|
|
||||||
|
indices = np.arange(len(axis))[
|
||||||
|
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
||||||
|
]
|
||||||
|
cropped = axis[indices]
|
||||||
|
order = np.argsort(plt_range.unit.inv(cropped))
|
||||||
|
indices = indices[order]
|
||||||
|
cropped = cropped[order]
|
||||||
|
out_ax = plt_range.unit.inv(cropped)
|
||||||
|
|
||||||
|
return out_ax, indices, (out_ax[0], out_ax[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def get_arg_names(func: Callable) -> list[str]:
|
||||||
|
spec = inspect.getfullargspec(func)
|
||||||
|
args = spec.args
|
||||||
|
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||||
|
args = args[: -len(spec.defaults)]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def validate_arg_names(names: list[str]):
|
||||||
|
for n in names:
|
||||||
|
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||||
|
raise ValueError(f"{n} is an invalid parameter name")
|
||||||
|
|
||||||
|
|
||||||
|
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||||
|
if arg_names is None:
|
||||||
|
arg_names = get_arg_names(func)
|
||||||
|
else:
|
||||||
|
validate_arg_names(arg_names)
|
||||||
|
validate_arg_names(kwarg_names)
|
||||||
|
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||||
|
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||||
|
tmp_name = f"{func.__name__}_0"
|
||||||
|
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||||
|
scope = dict(__func__=func)
|
||||||
|
exec(func_str, scope)
|
||||||
|
out_func = scope[tmp_name]
|
||||||
|
out_func.__module__ = "evaluator"
|
||||||
|
return out_func
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||||
|
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||||
|
return_str = ", ".join("True" for _ in range(num_returns))
|
||||||
|
func_name = f"__mock_{num_args}_{num_returns}"
|
||||||
|
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||||
|
scope = {}
|
||||||
|
exec(func_str, scope)
|
||||||
|
out_func = scope[func_name]
|
||||||
|
out_func.__module__ = "evaluator"
|
||||||
|
return out_func
|
||||||
|
|||||||
@@ -1,45 +1,14 @@
|
|||||||
from pydantic import BaseModel, validator
|
from math import prod
|
||||||
from typing import Union, Iterable, Generator, Any
|
|
||||||
from collections.abc import Sequence, MutableMapping
|
|
||||||
import itertools
|
import itertools
|
||||||
|
from collections.abc import MutableMapping, Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Generator, Iterable, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import validator
|
||||||
|
|
||||||
from ..const import PARAM_SEPARATOR
|
from ..const import PARAM_SEPARATOR
|
||||||
from . import utils
|
from . import utils
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def format_value(name: str, value) -> str:
|
|
||||||
if value is True or value is False:
|
|
||||||
return str(value)
|
|
||||||
elif isinstance(value, (float, int)):
|
|
||||||
try:
|
|
||||||
return getattr(Parameters, name).display(value)
|
|
||||||
except AttributeError:
|
|
||||||
return format(value, ".9g")
|
|
||||||
elif isinstance(value, (list, tuple, np.ndarray)):
|
|
||||||
return "-".join([str(v) for v in value])
|
|
||||||
elif isinstance(value, str):
|
|
||||||
p = Path(value)
|
|
||||||
if p.exists():
|
|
||||||
return p.stem
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
|
|
||||||
def pretty_format_value(name: str, value) -> str:
|
|
||||||
try:
|
|
||||||
return getattr(Parameters, name).display(value)
|
|
||||||
except AttributeError:
|
|
||||||
return name + PARAM_SEPARATOR + str(value)
|
|
||||||
|
|
||||||
|
|
||||||
class HashableBaseModel(BaseModel):
|
|
||||||
"""Pydantic BaseModel that's immutable and can be hashed"""
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
allow_mutation = False
|
|
||||||
|
|
||||||
|
|
||||||
class VariationSpecsError(ValueError):
|
class VariationSpecsError(ValueError):
|
||||||
@@ -67,23 +36,20 @@ class Variationer:
|
|||||||
all_indices: list[list[int]]
|
all_indices: list[list[int]]
|
||||||
all_dicts: list[list[dict[str, list]]]
|
all_dicts: list[list[dict[str, list]]]
|
||||||
|
|
||||||
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]]):
|
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
|
||||||
self.all_indices = []
|
self.all_indices = []
|
||||||
self.all_dicts = []
|
self.all_dicts = []
|
||||||
for i, el in enumerate(variables):
|
if variables is not None:
|
||||||
if not isinstance(el, Sequence):
|
for i, el in enumerate(variables):
|
||||||
el = [{k: v} for k, v in el.items()]
|
self.append(el)
|
||||||
else:
|
|
||||||
el = list(el)
|
|
||||||
self.append(el)
|
|
||||||
|
|
||||||
def append(self, var_list: list[dict[str, list]]):
|
def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
|
||||||
"""append a list of variable parameter sets
|
"""append a list of variable parameter sets
|
||||||
each call to append creates a new group of parameters
|
each call to append creates a new group of parameters
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
var_list : list[dict[str, list]]
|
var_list : Union[list[MutableMapping], MutableMapping]
|
||||||
each dict in the list is treated as an independent parameter
|
each dict in the list is treated as an independent parameter
|
||||||
this means that if for one dict, len > 1, the lists of possible values
|
this means that if for one dict, len > 1, the lists of possible values
|
||||||
must be the same length
|
must be the same length
|
||||||
@@ -100,6 +66,10 @@ class Variationer:
|
|||||||
VariationSpecsError
|
VariationSpecsError
|
||||||
raised when possible values lists in a same dict are not the same length
|
raised when possible values lists in a same dict are not the same length
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(var_list, Sequence):
|
||||||
|
var_list = [{k: v} for k, v in var_list.items()]
|
||||||
|
else:
|
||||||
|
var_list = list(var_list)
|
||||||
num_vars = []
|
num_vars = []
|
||||||
for d in var_list:
|
for d in var_list:
|
||||||
values = list(d.values())
|
values = list(d.values())
|
||||||
@@ -114,30 +84,43 @@ class Variationer:
|
|||||||
self.all_indices.append(num_vars)
|
self.all_indices.append(num_vars)
|
||||||
self.all_dicts.append(var_list)
|
self.all_dicts.append(var_list)
|
||||||
|
|
||||||
def iterate(self, index: int = -1) -> Generator["SimulationDescriptor", None, None]:
|
def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
|
||||||
if index < 0:
|
index = self.__index(index)
|
||||||
index = len(self.all_indices) + index + 1
|
flattened_indices = sum(self.all_indices[: index + 1], [])
|
||||||
flattened_indices = sum(self.all_indices[:index], [])
|
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
|
||||||
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[:index]])
|
|
||||||
ranges = [range(i) for i in flattened_indices]
|
ranges = [range(i) for i in flattened_indices]
|
||||||
for r in itertools.product(*ranges):
|
for r in itertools.product(*ranges):
|
||||||
out: list[list[tuple[str, Any]]] = []
|
out: list[list[tuple[str, Any]]] = []
|
||||||
|
indicies: list[list[int]] = []
|
||||||
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
|
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
|
||||||
out.append([])
|
out.append([])
|
||||||
|
indicies.append([])
|
||||||
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
|
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
|
||||||
for k, v in var_d.items():
|
for k, v in var_d.items():
|
||||||
out[-1].append((k, v[value_index]))
|
out[-1].append((k, v[value_index]))
|
||||||
yield SimulationDescriptor(raw_descr=out)
|
indicies[-1].append(value_index)
|
||||||
|
yield VariationDescriptor(raw_descr=out, index=indicies)
|
||||||
|
|
||||||
|
def __index(self, index: int) -> int:
|
||||||
|
if index < 0:
|
||||||
|
index = len(self.all_indices) + index
|
||||||
|
return index
|
||||||
|
|
||||||
|
def var_num(self, index: int = -1) -> int:
|
||||||
|
index = self.__index(index)
|
||||||
|
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
||||||
|
|
||||||
|
|
||||||
class SimulationDescriptor(HashableBaseModel):
|
class VariationDescriptor(utils.HashableBaseModel):
|
||||||
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
||||||
|
index: tuple[tuple[int, ...], ...]
|
||||||
separator: str = "fiber"
|
separator: str = "fiber"
|
||||||
|
_format_registry: dict[str, Callable[..., str]] = {}
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.descriptor(add_identifier=False)
|
return self.formatted_descriptor(add_identifier=False)
|
||||||
|
|
||||||
def descriptor(self, add_identifier=False) -> str:
|
def formatted_descriptor(self, add_identifier=False) -> str:
|
||||||
"""formats a variable list into a str such that each simulation has a unique
|
"""formats a variable list into a str such that each simulation has a unique
|
||||||
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
|
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
|
||||||
branch identifier can added at the beginning.
|
branch identifier can added at the beginning.
|
||||||
@@ -156,7 +139,7 @@ class SimulationDescriptor(HashableBaseModel):
|
|||||||
|
|
||||||
for p_name, p_value in self.flat:
|
for p_name, p_value in self.flat:
|
||||||
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
|
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
|
||||||
vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
|
vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
|
||||||
str_list.append(ps + PARAM_SEPARATOR + vs)
|
str_list.append(ps + PARAM_SEPARATOR + vs)
|
||||||
tmp_name = PARAM_SEPARATOR.join(str_list)
|
tmp_name = PARAM_SEPARATOR.join(str_list)
|
||||||
if not add_identifier:
|
if not add_identifier:
|
||||||
@@ -165,6 +148,34 @@ class SimulationDescriptor(HashableBaseModel):
|
|||||||
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||||
|
cls._format_registry[p_name] = func
|
||||||
|
|
||||||
|
def format_value(self, name: str, value) -> str:
|
||||||
|
if value is True or value is False:
|
||||||
|
return str(value)
|
||||||
|
elif isinstance(value, (float, int)):
|
||||||
|
try:
|
||||||
|
return self._format_registry[name](value)
|
||||||
|
except KeyError:
|
||||||
|
return format(value, ".9g")
|
||||||
|
elif isinstance(value, (list, tuple, np.ndarray)):
|
||||||
|
return "-".join([str(v) for v in value])
|
||||||
|
elif isinstance(value, str):
|
||||||
|
p = Path(value)
|
||||||
|
if p.exists():
|
||||||
|
return p.stem
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def __getitem__(self, key) -> "VariationDescriptor":
|
||||||
|
return VariationDescriptor(
|
||||||
|
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_config(self, cfg: dict[str, Any]):
|
||||||
|
return cfg | {k: v for k, v in self.raw_descr[-1]}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flat(self) -> list[tuple[str, Any]]:
|
def flat(self) -> list[tuple[str, Any]]:
|
||||||
out = []
|
out = []
|
||||||
@@ -177,17 +188,27 @@ class SimulationDescriptor(HashableBaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def branch(self) -> "BranchDescriptor":
|
def branch(self) -> "BranchDescriptor":
|
||||||
return SimulationDescriptor(raw_descr=self.raw_descr, separator=self.separator)
|
for i in reversed(range(len(self.raw_descr))):
|
||||||
|
for j in reversed(range(len(self.raw_descr[i]))):
|
||||||
|
if self.raw_descr[i][j][0] == "num":
|
||||||
|
del self.raw_descr[i][j]
|
||||||
|
return VariationDescriptor(
|
||||||
|
raw_descr=self.raw_descr, index=self.index, separator=self.separator
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self) -> str:
|
def identifier(self) -> str:
|
||||||
return "u_" + utils.to_62(hash(str(self.flat)))
|
return "u_" + utils.to_62(hash(str(self.flat)))
|
||||||
|
|
||||||
|
|
||||||
class BranchDescriptor(SimulationDescriptor):
|
class BranchDescriptor(VariationDescriptor):
|
||||||
|
__ids: dict[int, int] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self) -> str:
|
def identifier(self) -> str:
|
||||||
return "b_" + utils.to_62(hash(str(self.flat)))
|
branch_id = hash(str(self.flat))
|
||||||
|
self.__ids.setdefault(branch_id, len(self.__ids))
|
||||||
|
return str(self.__ids[branch_id])
|
||||||
|
|
||||||
@validator("raw_descr")
|
@validator("raw_descr")
|
||||||
def validate_raw_descr(cls, v):
|
def validate_raw_descr(cls, v):
|
||||||
|
|||||||
Reference in New Issue
Block a user