Sim starts, merge not

This commit is contained in:
Benoît Sierro
2021-09-27 13:21:02 +02:00
parent ef48711aa6
commit 695ac3bd73
9 changed files with 335 additions and 319 deletions

View File

@@ -5,4 +5,5 @@ from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum
from .utils import Paths, open_config, parameter
from .utils.parameter import Configuration, Parameters, PlotRange
from .utils.parameter import Configuration, Parameters
from .utils.utils import PlotRange

View File

@@ -11,7 +11,7 @@ from send2trash import send2trash
from .. import env, utils
from ..logger import get_logger
from ..utils.parameter import Configuration, Parameters, format_variable_list
from ..utils.parameter import Configuration, Parameters
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
@@ -466,14 +466,14 @@ class Simulations:
self.configuration = configuration
self.name = self.configuration.final_path
self.sim_dir = self.configuration.final_sim_dir
self.name = self.configuration.name
self.sim_dir = self.configuration.final_path
self.configuration.save_parameters()
self.sim_jobs_per_node = 1
def finished_and_complete(self):
for sim in self.configuration.all_configs_dict.values():
for sim in self.configuration.all_configs.values():
if (
self.configuration.sim_status(sim.output_path)[0]
!= self.configuration.State.COMPLETE
@@ -487,7 +487,7 @@ class Simulations:
def _run_available(self):
for variable, params in self.configuration:
v_list_str = format_variable_list(variable, add_iden=True)
v_list_str = variable.formatted_descriptor(True)
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
self.new_sim(v_list_str, params)
@@ -526,7 +526,9 @@ class SequencialSimulations(Simulations, priority=0):
def __init__(self, configuration: Configuration, task_id):
super().__init__(configuration, task_id=task_id)
self.pbars = utils.PBars(
self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1
self.configuration.total_num_steps,
"Simulating " + self.configuration.final_path.name,
1,
)
self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
@@ -569,7 +571,7 @@ class MultiProcSimulations(Simulations, priority=1):
self.p_worker = multiprocessing.Process(
target=utils.progress_worker,
args=(
self.configuration.final_path,
self.configuration.final_path.name,
self.sim_jobs_per_node,
self.configuration.total_num_steps,
self.progress_queue,
@@ -716,7 +718,7 @@ def run_simulation(
sim = new_simulation(config, method)
sim.run()
path_trees = utils.build_path_trees(config.sim_dirs[-1])
path_trees = utils.build_path_trees(config.fiber_paths[-1])
final_name = env.get(env.OUTPUT_PATH)
if final_name is None:
@@ -724,7 +726,7 @@ def run_simulation(
utils.merge(final_name, path_trees)
try:
send2trash(config.sim_dirs)
send2trash(config.fiber_paths)
except (PermissionError, OSError):
get_logger(__name__).error("Could not send temporary directories to trash")

View File

@@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults
from .math import abs2, span
from .physics import pulse, units
from .utils.parameter import Parameters, PlotRange, sort_axis
from .utils.parameter import Parameters
from .utils.utils import PlotRange, sort_axis
RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()

View File

@@ -16,9 +16,8 @@ from ..utils import auto_crop, open_config, save_toml, translate_parameters
from ..utils.parameter import (
Configuration,
Parameters,
pretty_format_from_sim_name,
pretty_format_value,
)
from ..utils.variationer import VariationDescriptor
def fingerprint(params: Parameters):
@@ -46,7 +45,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
path, fig, ax = plot_setup(
pulse.path.parent
/ (
pretty_format_from_sim_name(pulse.path.name)
pulse.path.name
+ PARAM_SEPARATOR
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
)

View File

@@ -16,7 +16,8 @@ from .plotting import (
single_position_plot,
transform_2D_propagation,
)
from .utils.parameter import Parameters, PlotRange
from .utils.parameter import Parameters
from .utils.utils import PlotRange
from .utils import load_spectrum

View File

@@ -161,29 +161,35 @@ def save_toml(path: os.PathLike, dico):
return dico
def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]:
loaded_config = open_config(final_config_path)
final_name = loaded_config.get("name")
fiber_list = loaded_config.pop("Fiber")
configs = []
if fiber_list is not None:
master_variable = loaded_config.get("variable", {})
for i, params in enumerate(fiber_list):
params.setdefault("variable", master_variable if i == 0 else {})
if i == 0:
params["variable"] |= master_variable
configs.append(loaded_config | params)
else:
configs.append(loaded_config)
while "previous_config_file" in configs[0]:
configs.insert(0, open_config(configs[0]["previous_config_file"]))
configs[0].setdefault("variable", {})
for pre, nex in zip(configs[:-1], configs[1:]):
variable = nex.pop("variable", {})
nex.update({k: v for k, v in pre.items() if k not in nex})
nex["variable"] = variable
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
"""loads a configuration file
return configs, final_name
Parameters
----------
path : os.PathLike
path to the config toml file
Returns
-------
final_path : Path
output name of the simulation
list[dict[str, Any]]
one config per fiber
"""
loaded_config = open_config(path)
fiber_list: list[dict[str, Any]] = loaded_config.pop("Fiber")
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {path}")
final_path = loaded_config.get("name")
configs = []
for i, params in enumerate(fiber_list):
params.setdefault("variable", {})
configs.append(loaded_config | params)
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
return Path(final_path), configs
def save_parameters(

View File

@@ -11,14 +11,18 @@ from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache
from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
import numpy as np
from numpy.lib import isin
from scgenerator.utils import ensure_folder, variationer
from .. import math, utils
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
from ..utils.variationer import VariationDescriptor, Variationer
from .utils import func_rewrite, _mock_function, get_arg_names
T = TypeVar("T")
@@ -256,7 +260,7 @@ class Parameter:
----------
tpe : type
type of the paramter
validators : Callable[[str, Any], None]
validator : Callable[[str, Any], None]
signature : validator(name, value)
must raise a ValueError when value doesn't fit the criteria checked by
validator. name is passed to validator to be included in the error message
@@ -290,7 +294,6 @@ class Parameter:
if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default)
instance.__dict__[self.name] = defaut
# instance.__dict__[self.name] = None
else:
if value is not None:
if self.converter is not None:
@@ -768,9 +771,11 @@ class Configuration:
obj with the output path of the simulation saved in its output_path attribute.
"""
master_configs: list[dict[str, Any]]
sim_dirs: list[Path]
fiber_configs: list[dict[str, Any]]
master_config: dict[str, Any]
fiber_paths: list[Path]
num_sim: int
num_fibers: int
repeat: int
z_num: int
total_num_steps: int
@@ -778,19 +783,17 @@ class Configuration:
parallel: bool
overwrite: bool
final_path: str
all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
all_configs_list: list[list["Configuration.__SimConfig"]]
all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
@dataclass(frozen=True)
class __SimConfig:
vary_list: list[tuple[str, Any]]
descriptor: VariationDescriptor
config: dict[str, Any]
output_path: Path
index: tuple[tuple[int, ...], ...]
@property
def sim_num(self) -> int:
return len(self.index)
return len(self.descriptor.index)
class State(enum.Enum):
COMPLETE = enum.auto()
@@ -810,48 +813,48 @@ class Configuration:
):
self.logger = get_logger(__name__)
self.master_configs, self.final_path = utils.load_config_sequence(final_config_path)
if self.final_path is None:
self.final_path = Parameters.name.default
self.name = Path(self.final_path).name
self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
self.final_path = utils.ensure_folder(
self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
)
self.master_config = self.fiber_configs[0]
self.name = self.final_path.name
self.z_num = 0
self.total_num_steps = 0
self.sim_dirs = []
self.overwrite = overwrite
self.fiber_paths = []
self.all_configs = {}
self.skip_callback = skip_callback
self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_configs[0].get("repeat", 1)
self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_config.get("repeat", 1)
self.variationer = Variationer()
names = set()
for i, config in enumerate(self.master_configs):
fiber_names = set()
self.num_fibers = 0
for i, config in enumerate(self.fiber_configs):
config.setdefault("name", Parameters.name.default)
self.z_num += config["z_num"]
config.setdefault("name", f"{Parameters.name.default} {i}")
given_name = config["name"]
fn_i = 0
while config["name"] in names:
config["name"] = given_name + f"_{fn_i}"
fn_i += 1
names.add(config["name"])
self.sim_dirs.append(
fiber_names.add(config["name"])
self.variationer.append(config.pop("variable"))
self.fiber_paths.append(
utils.ensure_folder(
Path("_".join(["_", self.name, Path(config["name"]).name, "_"])),
self.fiber_path(i, config),
mkdir=False,
prevent_overwrite=not self.overwrite,
)
)
self.__validate_variable(config)
self.__compute_sim_dirs()
[Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list]
self.num_sim = len(self.all_configs_list[-1])
self.num_fibers += 1
Evaluator.evaluate_default(config, True)
self.num_sim = self.variationer.var_num()
self.total_num_steps = sum(
config["z_num"] * len(self.all_configs_list[i])
for i, config in enumerate(self.master_configs)
config["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.fiber_configs)
)
self.final_sim_dir = utils.ensure_folder(
Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite
)
self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default)
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
def fiber_path(self, i: int, full_config: dict[str, Any]) -> Path:
return self.final_path / PARAM_SEPARATOR.join([format(i), self.name, full_config["name"]])
def __validate_variable(self, config: dict[str, Any]):
for k, v in config.get("variable", {}).items():
@@ -862,76 +865,62 @@ class Configuration:
if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty")
def __compute_sim_dirs(self):
self.all_configs_dict = {}
self.all_configs_list = []
self.master_configs[0]["variable"]["num"] = list(
range(self.master_configs[0].get("repeat", 1))
)
dp = DataPather([c["variable"] for c in self.master_configs])
for i, conf in enumerate(self.master_configs):
self.all_configs_list.append([])
for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i):
this_conf = conf.copy()
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
for i in range(self.num_fibers):
for sim_config in self.iterate_single_fiber(i):
if i > 0:
prev_path = utils.ensure_folder(
self.sim_dirs[i - 1] / prev_path, not self.overwrite, False
sim_config.config["prev_data_dir"] = str(
self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
)
this_conf["prev_data_dir"] = str(prev_path)
this_path = utils.ensure_folder(
self.sim_dirs[i] / this_path, not self.overwrite, False
)
this_conf.pop("variable")
conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
self.all_configs_dict[sim_index] = self.__SimConfig(
this_vary, conf_to_use, this_path, sim_index
)
self.all_configs_list[i].append(self.all_configs_dict[sim_index])
def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
for i, sim_config_list in enumerate(self.all_configs_list):
for sim_config, params in self.__iter_1_sim(sim_config_list):
params = Parameters(**sim_config.config)
params.compute()
fiber_map = []
for j in range(i + 1):
this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config
this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
if j > 0:
prev_conf = self.all_configs_dict[sim_config.index[:j]].config
prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
length = prev_conf["length"] + fiber_map[j - 1][0]
else:
length = 0.0
fiber_map.append((length, this_conf["name"]))
params.output_path = str(sim_config.output_path)
params.fiber_map = fiber_map
yield sim_config.vary_list, params
yield sim_config.descriptor, params
def __iter_1_sim(
self, configs: list["Configuration.__SimConfig"]
) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]:
def iterate_single_fiber(
self, index: int
) -> Generator["Configuration.__SimConfig", None, None]:
"""iterates through the parameters of only one fiber. It takes care of recovering partially
completed simulations, skipping complete ones and waiting for the previous fiber to finish
Parameters
----------
configs : list[__SimConfig]
list of configuration obj
index : int
which fiber to iterate over
Yields
-------
__SimConfig
configuration obj
Parameters
computed Parameters obj
"""
sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs}
sim_dict: dict[Path, self.__SimConfig] = {}
for descr in self.variationer.iterate(index):
cfg = descr.update_config(self.fiber_configs[index])
p = ensure_folder(
self.fiber_paths[index] / descr.formatted_descriptor(),
not self.overwrite,
False,
)
cfg["output_path"] = str(p)
sim_config = self.__SimConfig(descr, cfg, p)
sim_dict[p] = sim_config
self.all_configs[sim_config.descriptor.index] = sim_config
while len(sim_dict) > 0:
for data_dir, sim_config in sim_dict.items():
task, config_dict = self.__decide(sim_config)
if task == self.Action.RUN:
sim_dict.pop(data_dir)
p = Parameters(**config_dict)
p.compute()
yield sim_config, p
yield sim_config
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"])
break
@@ -956,7 +945,7 @@ class Configuration:
Returns
-------
str : {'run', 'wait', 'skip'}
str : Configuration.Action
what to do
config_dict : dict[str, Any]
config dictionary. The only key possibly modified is 'prev_data_dir', which
@@ -1012,7 +1001,7 @@ class Configuration:
raise ValueError(f"Too many spectra in {data_dir}")
def save_parameters(self):
for config, sim_dir in zip(self.master_configs, self.sim_dirs):
for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
os.makedirs(sim_dir, exist_ok=True)
utils.save_toml(sim_dir / f"initial_config.toml", config)
@@ -1022,144 +1011,6 @@ class Configuration:
return param
@dataclass(frozen=True)
class PlotRange:
left: float = Parameter(type_checker(int, float))
right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit)
conserved_quantity: bool = Parameter(boolean, default=True)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
if not isinstance(num_args, int) and isinstance(num_returns, int):
raise TypeError(f"num_args and num_returns must be int")
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def pretty_format_from_sim_name(name: str) -> str:
"""formats a pretty version of a simulation directory
Parameters
----------
name : str
name of the simulation (directory name)
Returns
-------
str
prettier name
"""
s = name.split(PARAM_SEPARATOR)
out = []
for key, value in zip(s[::2], s[1::2]):
try:
out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out)
default_rules: list[Rule] = [
# Grid
*Rule.deduce(

View File

@@ -1,4 +1,23 @@
import inspect
import re
from functools import cache
from string import printable as str_printable
from typing import Callable
import numpy as np
from pydantic import BaseModel
from ..physics.units import get_unit
class HashableBaseModel(BaseModel):
"""Pydantic BaseModel that's immutable and can be hashed"""
def __hash__(self) -> int:
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
class Config:
allow_mutation = False
def to_62(i: int) -> str:
@@ -10,3 +29,118 @@ def to_62(i: int) -> str:
i, value = divmod(i, 62)
arr.append(str_printable[value])
return "".join(reversed(arr))
class PlotRange(HashableBaseModel):
left: float
right: float
unit: Callable[[float], float]
conserved_quantity: bool = True
def __init__(self, left, right, unit, **kwargs):
super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func

View File

@@ -1,45 +1,14 @@
from pydantic import BaseModel, validator
from typing import Union, Iterable, Generator, Any
from collections.abc import Sequence, MutableMapping
from math import prod
import itertools
from collections.abc import MutableMapping, Sequence
from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Union
import numpy as np
from pydantic import validator
from ..const import PARAM_SEPARATOR
from . import utils
import numpy as np
from pathlib import Path
def format_value(name: str, value) -> str:
if value is True or value is False:
return str(value)
elif isinstance(value, (float, int)):
try:
return getattr(Parameters, name).display(value)
except AttributeError:
return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)):
return "-".join([str(v) for v in value])
elif isinstance(value, str):
p = Path(value)
if p.exists():
return p.stem
return str(value)
def pretty_format_value(name: str, value) -> str:
try:
return getattr(Parameters, name).display(value)
except AttributeError:
return name + PARAM_SEPARATOR + str(value)
class HashableBaseModel(BaseModel):
"""Pydantic BaseModel that's immutable and can be hashed"""
def __hash__(self) -> int:
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
class Config:
allow_mutation = False
class VariationSpecsError(ValueError):
@@ -67,23 +36,20 @@ class Variationer:
all_indices: list[list[int]]
all_dicts: list[list[dict[str, list]]]
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]]):
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
self.all_indices = []
self.all_dicts = []
if variables is not None:
for i, el in enumerate(variables):
if not isinstance(el, Sequence):
el = [{k: v} for k, v in el.items()]
else:
el = list(el)
self.append(el)
def append(self, var_list: list[dict[str, list]]):
def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
"""append a list of variable parameter sets
each call to append creates a new group of parameters
Parameters
----------
var_list : list[dict[str, list]]
var_list : Union[list[MutableMapping], MutableMapping]
each dict in the list is treated as an independent parameter
this means that if for one dict, len > 1, the lists of possible values
must be the same length
@@ -100,6 +66,10 @@ class Variationer:
VariationSpecsError
raised when possible values lists in a same dict are not the same length
"""
if not isinstance(var_list, Sequence):
var_list = [{k: v} for k, v in var_list.items()]
else:
var_list = list(var_list)
num_vars = []
for d in var_list:
values = list(d.values())
@@ -114,30 +84,43 @@ class Variationer:
self.all_indices.append(num_vars)
self.all_dicts.append(var_list)
def iterate(self, index: int = -1) -> Generator["SimulationDescriptor", None, None]:
if index < 0:
index = len(self.all_indices) + index + 1
flattened_indices = sum(self.all_indices[:index], [])
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[:index]])
def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
index = self.__index(index)
flattened_indices = sum(self.all_indices[: index + 1], [])
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
ranges = [range(i) for i in flattened_indices]
for r in itertools.product(*ranges):
out: list[list[tuple[str, Any]]] = []
indicies: list[list[int]] = []
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
out.append([])
indicies.append([])
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
for k, v in var_d.items():
out[-1].append((k, v[value_index]))
yield SimulationDescriptor(raw_descr=out)
indicies[-1].append(value_index)
yield VariationDescriptor(raw_descr=out, index=indicies)
def __index(self, index: int) -> int:
if index < 0:
index = len(self.all_indices) + index
return index
def var_num(self, index: int = -1) -> int:
index = self.__index(index)
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
class SimulationDescriptor(HashableBaseModel):
class VariationDescriptor(utils.HashableBaseModel):
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
index: tuple[tuple[int, ...], ...]
separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {}
def __str__(self) -> str:
return self.descriptor(add_identifier=False)
return self.formatted_descriptor(add_identifier=False)
def descriptor(self, add_identifier=False) -> str:
def formatted_descriptor(self, add_identifier=False) -> str:
"""formats a variable list into a str such that each simulation has a unique
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
branch identifier can added at the beginning.
@@ -156,7 +139,7 @@ class SimulationDescriptor(HashableBaseModel):
for p_name, p_value in self.flat:
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
str_list.append(ps + PARAM_SEPARATOR + vs)
tmp_name = PARAM_SEPARATOR.join(str_list)
if not add_identifier:
@@ -165,6 +148,34 @@ class SimulationDescriptor(HashableBaseModel):
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
)
@classmethod
def register_formatter(cls, p_name: str, func: Callable[..., str]):
cls._format_registry[p_name] = func
def format_value(self, name: str, value) -> str:
if value is True or value is False:
return str(value)
elif isinstance(value, (float, int)):
try:
return self._format_registry[name](value)
except KeyError:
return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)):
return "-".join([str(v) for v in value])
elif isinstance(value, str):
p = Path(value)
if p.exists():
return p.stem
return str(value)
def __getitem__(self, key) -> "VariationDescriptor":
return VariationDescriptor(
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
)
def update_config(self, cfg: dict[str, Any]):
return cfg | {k: v for k, v in self.raw_descr[-1]}
@property
def flat(self) -> list[tuple[str, Any]]:
out = []
@@ -177,17 +188,27 @@ class SimulationDescriptor(HashableBaseModel):
@property
def branch(self) -> "BranchDescriptor":
return SimulationDescriptor(raw_descr=self.raw_descr, separator=self.separator)
for i in reversed(range(len(self.raw_descr))):
for j in reversed(range(len(self.raw_descr[i]))):
if self.raw_descr[i][j][0] == "num":
del self.raw_descr[i][j]
return VariationDescriptor(
raw_descr=self.raw_descr, index=self.index, separator=self.separator
)
@property
def identifier(self) -> str:
return "u_" + utils.to_62(hash(str(self.flat)))
class BranchDescriptor(SimulationDescriptor):
class BranchDescriptor(VariationDescriptor):
__ids: dict[int, int] = {}
@property
def identifier(self) -> str:
return "b_" + utils.to_62(hash(str(self.flat)))
branch_id = hash(str(self.flat))
self.__ids.setdefault(branch_id, len(self.__ids))
return str(self.__ids[branch_id])
@validator("raw_descr")
def validate_raw_descr(cls, v):