Sim starts, merge not

This commit is contained in:
Benoît Sierro
2021-09-27 13:21:02 +02:00
parent ef48711aa6
commit 695ac3bd73
9 changed files with 335 additions and 319 deletions

View File

@@ -5,4 +5,5 @@ from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum from .spectra import Pulse, Spectrum
from .utils import Paths, open_config, parameter from .utils import Paths, open_config, parameter
from .utils.parameter import Configuration, Parameters, PlotRange from .utils.parameter import Configuration, Parameters
from .utils.utils import PlotRange

View File

@@ -11,7 +11,7 @@ from send2trash import send2trash
from .. import env, utils from .. import env, utils
from ..logger import get_logger from ..logger import get_logger
from ..utils.parameter import Configuration, Parameters, format_variable_list from ..utils.parameter import Configuration, Parameters
from . import pulse from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op from .fiber import create_non_linear_op, fast_dispersion_op
@@ -466,14 +466,14 @@ class Simulations:
self.configuration = configuration self.configuration = configuration
self.name = self.configuration.final_path self.name = self.configuration.name
self.sim_dir = self.configuration.final_sim_dir self.sim_dir = self.configuration.final_path
self.configuration.save_parameters() self.configuration.save_parameters()
self.sim_jobs_per_node = 1 self.sim_jobs_per_node = 1
def finished_and_complete(self): def finished_and_complete(self):
for sim in self.configuration.all_configs_dict.values(): for sim in self.configuration.all_configs.values():
if ( if (
self.configuration.sim_status(sim.output_path)[0] self.configuration.sim_status(sim.output_path)[0]
!= self.configuration.State.COMPLETE != self.configuration.State.COMPLETE
@@ -487,7 +487,7 @@ class Simulations:
def _run_available(self): def _run_available(self):
for variable, params in self.configuration: for variable, params in self.configuration:
v_list_str = format_variable_list(variable, add_iden=True) v_list_str = variable.formatted_descriptor(True)
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path)) utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
self.new_sim(v_list_str, params) self.new_sim(v_list_str, params)
@@ -526,7 +526,9 @@ class SequencialSimulations(Simulations, priority=0):
def __init__(self, configuration: Configuration, task_id): def __init__(self, configuration: Configuration, task_id):
super().__init__(configuration, task_id=task_id) super().__init__(configuration, task_id=task_id)
self.pbars = utils.PBars( self.pbars = utils.PBars(
self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1 self.configuration.total_num_steps,
"Simulating " + self.configuration.final_path.name,
1,
) )
self.configuration.skip_callback = lambda num: self.pbars.update(0, num) self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
@@ -569,7 +571,7 @@ class MultiProcSimulations(Simulations, priority=1):
self.p_worker = multiprocessing.Process( self.p_worker = multiprocessing.Process(
target=utils.progress_worker, target=utils.progress_worker,
args=( args=(
self.configuration.final_path, self.configuration.final_path.name,
self.sim_jobs_per_node, self.sim_jobs_per_node,
self.configuration.total_num_steps, self.configuration.total_num_steps,
self.progress_queue, self.progress_queue,
@@ -716,7 +718,7 @@ def run_simulation(
sim = new_simulation(config, method) sim = new_simulation(config, method)
sim.run() sim.run()
path_trees = utils.build_path_trees(config.sim_dirs[-1]) path_trees = utils.build_path_trees(config.fiber_paths[-1])
final_name = env.get(env.OUTPUT_PATH) final_name = env.get(env.OUTPUT_PATH)
if final_name is None: if final_name is None:
@@ -724,7 +726,7 @@ def run_simulation(
utils.merge(final_name, path_trees) utils.merge(final_name, path_trees)
try: try:
send2trash(config.sim_dirs) send2trash(config.fiber_paths)
except (PermissionError, OSError): except (PermissionError, OSError):
get_logger(__name__).error("Could not send temporary directories to trash") get_logger(__name__).error("Could not send temporary directories to trash")

View File

@@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults from .defaults import default_plotting as defaults
from .math import abs2, span from .math import abs2, span
from .physics import pulse, units from .physics import pulse, units
from .utils.parameter import Parameters, PlotRange, sort_axis from .utils.parameter import Parameters
from .utils.utils import PlotRange, sort_axis
RangeType = tuple[float, float, Union[str, Callable]] RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object() NO_LIM = object()

View File

@@ -16,9 +16,8 @@ from ..utils import auto_crop, open_config, save_toml, translate_parameters
from ..utils.parameter import ( from ..utils.parameter import (
Configuration, Configuration,
Parameters, Parameters,
pretty_format_from_sim_name,
pretty_format_value,
) )
from ..utils.variationer import VariationDescriptor
def fingerprint(params: Parameters): def fingerprint(params: Parameters):
@@ -46,7 +45,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
path, fig, ax = plot_setup( path, fig, ax = plot_setup(
pulse.path.parent pulse.path.parent
/ ( / (
pretty_format_from_sim_name(pulse.path.name) pulse.path.name
+ PARAM_SEPARATOR + PARAM_SEPARATOR
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}" + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
) )

View File

@@ -16,7 +16,8 @@ from .plotting import (
single_position_plot, single_position_plot,
transform_2D_propagation, transform_2D_propagation,
) )
from .utils.parameter import Parameters, PlotRange from .utils.parameter import Parameters
from .utils.utils import PlotRange
from .utils import load_spectrum from .utils import load_spectrum

View File

@@ -161,29 +161,35 @@ def save_toml(path: os.PathLike, dico):
return dico return dico
def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]: def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
loaded_config = open_config(final_config_path) """loads a configuration file
final_name = loaded_config.get("name")
fiber_list = loaded_config.pop("Fiber")
configs = []
if fiber_list is not None:
master_variable = loaded_config.get("variable", {})
for i, params in enumerate(fiber_list):
params.setdefault("variable", master_variable if i == 0 else {})
if i == 0:
params["variable"] |= master_variable
configs.append(loaded_config | params)
else:
configs.append(loaded_config)
while "previous_config_file" in configs[0]:
configs.insert(0, open_config(configs[0]["previous_config_file"]))
configs[0].setdefault("variable", {})
for pre, nex in zip(configs[:-1], configs[1:]):
variable = nex.pop("variable", {})
nex.update({k: v for k, v in pre.items() if k not in nex})
nex["variable"] = variable
return configs, final_name Parameters
----------
path : os.PathLike
path to the config toml file
Returns
-------
final_path : Path
output name of the simulation
list[dict[str, Any]]
one config per fiber
"""
loaded_config = open_config(path)
fiber_list: list[dict[str, Any]] = loaded_config.pop("Fiber")
if len(fiber_list) == 0:
raise ValueError(f"No fiber in config {path}")
final_path = loaded_config.get("name")
configs = []
for i, params in enumerate(fiber_list):
params.setdefault("variable", {})
configs.append(loaded_config | params)
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
return Path(final_path), configs
def save_parameters( def save_parameters(

View File

@@ -11,14 +11,18 @@ from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache from functools import cache, lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
import numpy as np import numpy as np
from numpy.lib import isin from numpy.lib import isin
from scgenerator.utils import ensure_folder, variationer
from .. import math, utils from .. import math, utils
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger from ..logger import get_logger
from ..physics import fiber, materials, pulse, units from ..physics import fiber, materials, pulse, units
from ..utils.variationer import VariationDescriptor, Variationer
from .utils import func_rewrite, _mock_function, get_arg_names
T = TypeVar("T") T = TypeVar("T")
@@ -256,7 +260,7 @@ class Parameter:
---------- ----------
tpe : type tpe : type
type of the paramter type of the paramter
validators : Callable[[str, Any], None] validator : Callable[[str, Any], None]
signature : validator(name, value) signature : validator(name, value)
must raise a ValueError when value doesn't fit the criteria checked by must raise a ValueError when value doesn't fit the criteria checked by
validator. name is passed to validator to be included in the error message validator. name is passed to validator to be included in the error message
@@ -290,7 +294,6 @@ class Parameter:
if isinstance(value, Parameter): if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default) defaut = None if self.default is None else copy(self.default)
instance.__dict__[self.name] = defaut instance.__dict__[self.name] = defaut
# instance.__dict__[self.name] = None
else: else:
if value is not None: if value is not None:
if self.converter is not None: if self.converter is not None:
@@ -768,9 +771,11 @@ class Configuration:
obj with the output path of the simulation saved in its output_path attribute. obj with the output path of the simulation saved in its output_path attribute.
""" """
master_configs: list[dict[str, Any]] fiber_configs: list[dict[str, Any]]
sim_dirs: list[Path] master_config: dict[str, Any]
fiber_paths: list[Path]
num_sim: int num_sim: int
num_fibers: int
repeat: int repeat: int
z_num: int z_num: int
total_num_steps: int total_num_steps: int
@@ -778,19 +783,17 @@ class Configuration:
parallel: bool parallel: bool
overwrite: bool overwrite: bool
final_path: str final_path: str
all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
all_configs_list: list[list["Configuration.__SimConfig"]]
@dataclass(frozen=True) @dataclass(frozen=True)
class __SimConfig: class __SimConfig:
vary_list: list[tuple[str, Any]] descriptor: VariationDescriptor
config: dict[str, Any] config: dict[str, Any]
output_path: Path output_path: Path
index: tuple[tuple[int, ...], ...]
@property @property
def sim_num(self) -> int: def sim_num(self) -> int:
return len(self.index) return len(self.descriptor.index)
class State(enum.Enum): class State(enum.Enum):
COMPLETE = enum.auto() COMPLETE = enum.auto()
@@ -810,48 +813,48 @@ class Configuration:
): ):
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.master_configs, self.final_path = utils.load_config_sequence(final_config_path) self.overwrite = overwrite
if self.final_path is None: self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
self.final_path = Parameters.name.default self.final_path = utils.ensure_folder(
self.name = Path(self.final_path).name self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
)
self.master_config = self.fiber_configs[0]
self.name = self.final_path.name
self.z_num = 0 self.z_num = 0
self.total_num_steps = 0 self.total_num_steps = 0
self.sim_dirs = [] self.fiber_paths = []
self.overwrite = overwrite self.all_configs = {}
self.skip_callback = skip_callback self.skip_callback = skip_callback
self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2)) self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
self.repeat = self.master_configs[0].get("repeat", 1) self.repeat = self.master_config.get("repeat", 1)
self.variationer = Variationer()
names = set() fiber_names = set()
for i, config in enumerate(self.master_configs): self.num_fibers = 0
for i, config in enumerate(self.fiber_configs):
config.setdefault("name", Parameters.name.default)
self.z_num += config["z_num"] self.z_num += config["z_num"]
config.setdefault("name", f"{Parameters.name.default} {i}") fiber_names.add(config["name"])
given_name = config["name"] self.variationer.append(config.pop("variable"))
fn_i = 0 self.fiber_paths.append(
while config["name"] in names:
config["name"] = given_name + f"_{fn_i}"
fn_i += 1
names.add(config["name"])
self.sim_dirs.append(
utils.ensure_folder( utils.ensure_folder(
Path("_".join(["_", self.name, Path(config["name"]).name, "_"])), self.fiber_path(i, config),
mkdir=False, mkdir=False,
prevent_overwrite=not self.overwrite, prevent_overwrite=not self.overwrite,
) )
) )
self.__validate_variable(config) self.__validate_variable(config)
self.__compute_sim_dirs() self.num_fibers += 1
[Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list] Evaluator.evaluate_default(config, True)
self.num_sim = len(self.all_configs_list[-1]) self.num_sim = self.variationer.var_num()
self.total_num_steps = sum( self.total_num_steps = sum(
config["z_num"] * len(self.all_configs_list[i]) config["z_num"] * self.variationer.var_num(i)
for i, config in enumerate(self.master_configs) for i, config in enumerate(self.fiber_configs)
) )
self.final_sim_dir = utils.ensure_folder( self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite
) def fiber_path(self, i: int, full_config: dict[str, Any]) -> Path:
self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default) return self.final_path / PARAM_SEPARATOR.join([format(i), self.name, full_config["name"]])
def __validate_variable(self, config: dict[str, Any]): def __validate_variable(self, config: dict[str, Any]):
for k, v in config.get("variable", {}).items(): for k, v in config.get("variable", {}).items():
@@ -862,76 +865,62 @@ class Configuration:
if len(v) == 0: if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty") raise ValueError(f"variable parameter {k!r} must not be empty")
def __compute_sim_dirs(self): def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
self.all_configs_dict = {} for i in range(self.num_fibers):
self.all_configs_list = [] for sim_config in self.iterate_single_fiber(i):
self.master_configs[0]["variable"]["num"] = list(
range(self.master_configs[0].get("repeat", 1))
)
dp = DataPather([c["variable"] for c in self.master_configs])
for i, conf in enumerate(self.master_configs):
self.all_configs_list.append([])
for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i):
this_conf = conf.copy()
if i > 0: if i > 0:
prev_path = utils.ensure_folder( sim_config.config["prev_data_dir"] = str(
self.sim_dirs[i - 1] / prev_path, not self.overwrite, False self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
) )
this_conf["prev_data_dir"] = str(prev_path) params = Parameters(**sim_config.config)
params.compute()
this_path = utils.ensure_folder(
self.sim_dirs[i] / this_path, not self.overwrite, False
)
this_conf.pop("variable")
conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
self.all_configs_dict[sim_index] = self.__SimConfig(
this_vary, conf_to_use, this_path, sim_index
)
self.all_configs_list[i].append(self.all_configs_dict[sim_index])
def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
for i, sim_config_list in enumerate(self.all_configs_list):
for sim_config, params in self.__iter_1_sim(sim_config_list):
fiber_map = [] fiber_map = []
for j in range(i + 1): for j in range(i + 1):
this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
if j > 0: if j > 0:
prev_conf = self.all_configs_dict[sim_config.index[:j]].config prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
length = prev_conf["length"] + fiber_map[j - 1][0] length = prev_conf["length"] + fiber_map[j - 1][0]
else: else:
length = 0.0 length = 0.0
fiber_map.append((length, this_conf["name"])) fiber_map.append((length, this_conf["name"]))
params.output_path = str(sim_config.output_path)
params.fiber_map = fiber_map params.fiber_map = fiber_map
yield sim_config.vary_list, params yield sim_config.descriptor, params
def __iter_1_sim( def iterate_single_fiber(
self, configs: list["Configuration.__SimConfig"] self, index: int
) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]: ) -> Generator["Configuration.__SimConfig", None, None]:
"""iterates through the parameters of only one fiber. It takes care of recovering partially """iterates through the parameters of only one fiber. It takes care of recovering partially
completed simulations, skipping complete ones and waiting for the previous fiber to finish completed simulations, skipping complete ones and waiting for the previous fiber to finish
Parameters Parameters
---------- ----------
configs : list[__SimConfig] index : int
list of configuration obj which fiber to iterate over
Yields Yields
------- -------
__SimConfig __SimConfig
configuration obj configuration obj
Parameters
computed Parameters obj
""" """
sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs} sim_dict: dict[Path, self.__SimConfig] = {}
for descr in self.variationer.iterate(index):
cfg = descr.update_config(self.fiber_configs[index])
p = ensure_folder(
self.fiber_paths[index] / descr.formatted_descriptor(),
not self.overwrite,
False,
)
cfg["output_path"] = str(p)
sim_config = self.__SimConfig(descr, cfg, p)
sim_dict[p] = sim_config
self.all_configs[sim_config.descriptor.index] = sim_config
while len(sim_dict) > 0: while len(sim_dict) > 0:
for data_dir, sim_config in sim_dict.items(): for data_dir, sim_config in sim_dict.items():
task, config_dict = self.__decide(sim_config) task, config_dict = self.__decide(sim_config)
if task == self.Action.RUN: if task == self.Action.RUN:
sim_dict.pop(data_dir) sim_dict.pop(data_dir)
p = Parameters(**config_dict) yield sim_config
p.compute()
yield sim_config, p
if "recovery_last_stored" in config_dict and self.skip_callback is not None: if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"]) self.skip_callback(config_dict["recovery_last_stored"])
break break
@@ -956,7 +945,7 @@ class Configuration:
Returns Returns
------- -------
str : {'run', 'wait', 'skip'} str : Configuration.Action
what to do what to do
config_dict : dict[str, Any] config_dict : dict[str, Any]
config dictionary. The only key possibly modified is 'prev_data_dir', which config dictionary. The only key possibly modified is 'prev_data_dir', which
@@ -1012,7 +1001,7 @@ class Configuration:
raise ValueError(f"Too many spectra in {data_dir}") raise ValueError(f"Too many spectra in {data_dir}")
def save_parameters(self): def save_parameters(self):
for config, sim_dir in zip(self.master_configs, self.sim_dirs): for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
os.makedirs(sim_dir, exist_ok=True) os.makedirs(sim_dir, exist_ok=True)
utils.save_toml(sim_dir / f"initial_config.toml", config) utils.save_toml(sim_dir / f"initial_config.toml", config)
@@ -1022,144 +1011,6 @@ class Configuration:
return param return param
@dataclass(frozen=True)
class PlotRange:
left: float = Parameter(type_checker(int, float))
right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit)
conserved_quantity: bool = Parameter(boolean, default=True)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
if not isinstance(num_args, int) and isinstance(num_returns, int):
raise TypeError(f"num_args and num_returns must be int")
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def pretty_format_from_sim_name(name: str) -> str:
"""formats a pretty version of a simulation directory
Parameters
----------
name : str
name of the simulation (directory name)
Returns
-------
str
prettier name
"""
s = name.split(PARAM_SEPARATOR)
out = []
for key, value in zip(s[::2], s[1::2]):
try:
out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out)
default_rules: list[Rule] = [ default_rules: list[Rule] = [
# Grid # Grid
*Rule.deduce( *Rule.deduce(

View File

@@ -1,4 +1,23 @@
import inspect
import re
from functools import cache
from string import printable as str_printable from string import printable as str_printable
from typing import Callable
import numpy as np
from pydantic import BaseModel
from ..physics.units import get_unit
class HashableBaseModel(BaseModel):
"""Pydantic BaseModel that's immutable and can be hashed"""
def __hash__(self) -> int:
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
class Config:
allow_mutation = False
def to_62(i: int) -> str: def to_62(i: int) -> str:
@@ -10,3 +29,118 @@ def to_62(i: int) -> str:
i, value = divmod(i, 62) i, value = divmod(i, 62)
arr.append(str_printable[value]) arr.append(str_printable[value])
return "".join(reversed(arr)) return "".join(reversed(arr))
class PlotRange(HashableBaseModel):
left: float
right: float
unit: Callable[[float], float]
conserved_quantity: bool = True
def __init__(self, left, right, unit, **kwargs):
super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
return sort_axis(axis, self)
def __iter__(self):
yield self.left
yield self.right
yield self.unit.__name__
def sort_axis(
axis: np.ndarray, plt_range: PlotRange
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
"""
given an axis, returns this axis cropped according to the given range, converted and sorted
Parameters
----------
axis : 1D array containing the original axis (usual the w or t array)
plt_range : tupple (min, max, conversion_function) used to crop the axis
Returns
-------
cropped : the axis cropped, converted and sorted
indices : indices to use to slice and sort other array in the same fashion
extent : tupple with min and max of cropped
Example
-------
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
t = np.linspace(-10, 10, 400)
W, T = np.meshgrid(w, t)
y = np.exp(-W**2 - T**2)
# Define ranges
rw = (-4, 4, s)
rt = (-2, 6, s)
w, cw = sort_axis(w, rw)
t, ct = sort_axis(t, rt)
# slice y according to the given ranges
y = y[ct][:, cw]
"""
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
]
cropped = axis[indices]
order = np.argsort(plt_range.unit.inv(cropped))
indices = indices[order]
cropped = cropped[order]
out_ax = plt_range.unit.inv(cropped)
return out_ax, indices, (out_ax[0], out_ax[-1])
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
@cache
def _mock_function(num_args: int, num_returns: int) -> Callable:
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
return_str = ", ".join("True" for _ in range(num_returns))
func_name = f"__mock_{num_args}_{num_returns}"
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
scope = {}
exec(func_str, scope)
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func

View File

@@ -1,45 +1,14 @@
from pydantic import BaseModel, validator from math import prod
from typing import Union, Iterable, Generator, Any
from collections.abc import Sequence, MutableMapping
import itertools import itertools
from collections.abc import MutableMapping, Sequence
from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Union
import numpy as np
from pydantic import validator
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from . import utils from . import utils
import numpy as np
from pathlib import Path
def format_value(name: str, value) -> str:
if value is True or value is False:
return str(value)
elif isinstance(value, (float, int)):
try:
return getattr(Parameters, name).display(value)
except AttributeError:
return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)):
return "-".join([str(v) for v in value])
elif isinstance(value, str):
p = Path(value)
if p.exists():
return p.stem
return str(value)
def pretty_format_value(name: str, value) -> str:
try:
return getattr(Parameters, name).display(value)
except AttributeError:
return name + PARAM_SEPARATOR + str(value)
class HashableBaseModel(BaseModel):
"""Pydantic BaseModel that's immutable and can be hashed"""
def __hash__(self) -> int:
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
class Config:
allow_mutation = False
class VariationSpecsError(ValueError): class VariationSpecsError(ValueError):
@@ -67,23 +36,20 @@ class Variationer:
all_indices: list[list[int]] all_indices: list[list[int]]
all_dicts: list[list[dict[str, list]]] all_dicts: list[list[dict[str, list]]]
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]]): def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
self.all_indices = [] self.all_indices = []
self.all_dicts = [] self.all_dicts = []
for i, el in enumerate(variables): if variables is not None:
if not isinstance(el, Sequence): for i, el in enumerate(variables):
el = [{k: v} for k, v in el.items()] self.append(el)
else:
el = list(el)
self.append(el)
def append(self, var_list: list[dict[str, list]]): def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
"""append a list of variable parameter sets """append a list of variable parameter sets
each call to append creates a new group of parameters each call to append creates a new group of parameters
Parameters Parameters
---------- ----------
var_list : list[dict[str, list]] var_list : Union[list[MutableMapping], MutableMapping]
each dict in the list is treated as an independent parameter each dict in the list is treated as an independent parameter
this means that if for one dict, len > 1, the lists of possible values this means that if for one dict, len > 1, the lists of possible values
must be the same length must be the same length
@@ -100,6 +66,10 @@ class Variationer:
VariationSpecsError VariationSpecsError
raised when possible values lists in a same dict are not the same length raised when possible values lists in a same dict are not the same length
""" """
if not isinstance(var_list, Sequence):
var_list = [{k: v} for k, v in var_list.items()]
else:
var_list = list(var_list)
num_vars = [] num_vars = []
for d in var_list: for d in var_list:
values = list(d.values()) values = list(d.values())
@@ -114,30 +84,43 @@ class Variationer:
self.all_indices.append(num_vars) self.all_indices.append(num_vars)
self.all_dicts.append(var_list) self.all_dicts.append(var_list)
def iterate(self, index: int = -1) -> Generator["SimulationDescriptor", None, None]: def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
if index < 0: index = self.__index(index)
index = len(self.all_indices) + index + 1 flattened_indices = sum(self.all_indices[: index + 1], [])
flattened_indices = sum(self.all_indices[:index], []) index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[:index]])
ranges = [range(i) for i in flattened_indices] ranges = [range(i) for i in flattened_indices]
for r in itertools.product(*ranges): for r in itertools.product(*ranges):
out: list[list[tuple[str, Any]]] = [] out: list[list[tuple[str, Any]]] = []
indicies: list[list[int]] = []
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])): for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
out.append([]) out.append([])
indicies.append([])
for value_index, var_d in zip(r[start:end], self.all_dicts[i]): for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
for k, v in var_d.items(): for k, v in var_d.items():
out[-1].append((k, v[value_index])) out[-1].append((k, v[value_index]))
yield SimulationDescriptor(raw_descr=out) indicies[-1].append(value_index)
yield VariationDescriptor(raw_descr=out, index=indicies)
def __index(self, index: int) -> int:
if index < 0:
index = len(self.all_indices) + index
return index
def var_num(self, index: int = -1) -> int:
index = self.__index(index)
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
class SimulationDescriptor(HashableBaseModel): class VariationDescriptor(utils.HashableBaseModel):
raw_descr: tuple[tuple[tuple[str, Any], ...], ...] raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
index: tuple[tuple[int, ...], ...]
separator: str = "fiber" separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {}
def __str__(self) -> str: def __str__(self) -> str:
return self.descriptor(add_identifier=False) return self.formatted_descriptor(add_identifier=False)
def descriptor(self, add_identifier=False) -> str: def formatted_descriptor(self, add_identifier=False) -> str:
"""formats a variable list into a str such that each simulation has a unique """formats a variable list into a str such that each simulation has a unique
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations) directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
branch identifier can added at the beginning. branch identifier can added at the beginning.
@@ -156,7 +139,7 @@ class SimulationDescriptor(HashableBaseModel):
for p_name, p_value in self.flat: for p_name, p_value in self.flat:
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "") ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
str_list.append(ps + PARAM_SEPARATOR + vs) str_list.append(ps + PARAM_SEPARATOR + vs)
tmp_name = PARAM_SEPARATOR.join(str_list) tmp_name = PARAM_SEPARATOR.join(str_list)
if not add_identifier: if not add_identifier:
@@ -165,6 +148,34 @@ class SimulationDescriptor(HashableBaseModel):
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
) )
@classmethod
def register_formatter(cls, p_name: str, func: Callable[..., str]):
cls._format_registry[p_name] = func
def format_value(self, name: str, value) -> str:
if value is True or value is False:
return str(value)
elif isinstance(value, (float, int)):
try:
return self._format_registry[name](value)
except KeyError:
return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)):
return "-".join([str(v) for v in value])
elif isinstance(value, str):
p = Path(value)
if p.exists():
return p.stem
return str(value)
def __getitem__(self, key) -> "VariationDescriptor":
return VariationDescriptor(
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
)
def update_config(self, cfg: dict[str, Any]):
return cfg | {k: v for k, v in self.raw_descr[-1]}
@property @property
def flat(self) -> list[tuple[str, Any]]: def flat(self) -> list[tuple[str, Any]]:
out = [] out = []
@@ -177,17 +188,27 @@ class SimulationDescriptor(HashableBaseModel):
@property @property
def branch(self) -> "BranchDescriptor": def branch(self) -> "BranchDescriptor":
return SimulationDescriptor(raw_descr=self.raw_descr, separator=self.separator) for i in reversed(range(len(self.raw_descr))):
for j in reversed(range(len(self.raw_descr[i]))):
if self.raw_descr[i][j][0] == "num":
del self.raw_descr[i][j]
return VariationDescriptor(
raw_descr=self.raw_descr, index=self.index, separator=self.separator
)
@property @property
def identifier(self) -> str: def identifier(self) -> str:
return "u_" + utils.to_62(hash(str(self.flat))) return "u_" + utils.to_62(hash(str(self.flat)))
class BranchDescriptor(SimulationDescriptor): class BranchDescriptor(VariationDescriptor):
__ids: dict[int, int] = {}
@property @property
def identifier(self) -> str: def identifier(self) -> str:
return "b_" + utils.to_62(hash(str(self.flat))) branch_id = hash(str(self.flat))
self.__ids.setdefault(branch_id, len(self.__ids))
return str(self.__ids[branch_id])
@validator("raw_descr") @validator("raw_descr")
def validate_raw_descr(cls, v): def validate_raw_descr(cls, v):