Sim starts, merge not
This commit is contained in:
@@ -5,4 +5,5 @@ from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
|
||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||
from .spectra import Pulse, Spectrum
|
||||
from .utils import Paths, open_config, parameter
|
||||
from .utils.parameter import Configuration, Parameters, PlotRange
|
||||
from .utils.parameter import Configuration, Parameters
|
||||
from .utils.utils import PlotRange
|
||||
|
||||
@@ -11,7 +11,7 @@ from send2trash import send2trash
|
||||
|
||||
from .. import env, utils
|
||||
from ..logger import get_logger
|
||||
from ..utils.parameter import Configuration, Parameters, format_variable_list
|
||||
from ..utils.parameter import Configuration, Parameters
|
||||
from . import pulse
|
||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||
|
||||
@@ -466,14 +466,14 @@ class Simulations:
|
||||
|
||||
self.configuration = configuration
|
||||
|
||||
self.name = self.configuration.final_path
|
||||
self.sim_dir = self.configuration.final_sim_dir
|
||||
self.name = self.configuration.name
|
||||
self.sim_dir = self.configuration.final_path
|
||||
self.configuration.save_parameters()
|
||||
|
||||
self.sim_jobs_per_node = 1
|
||||
|
||||
def finished_and_complete(self):
|
||||
for sim in self.configuration.all_configs_dict.values():
|
||||
for sim in self.configuration.all_configs.values():
|
||||
if (
|
||||
self.configuration.sim_status(sim.output_path)[0]
|
||||
!= self.configuration.State.COMPLETE
|
||||
@@ -487,7 +487,7 @@ class Simulations:
|
||||
|
||||
def _run_available(self):
|
||||
for variable, params in self.configuration:
|
||||
v_list_str = format_variable_list(variable, add_iden=True)
|
||||
v_list_str = variable.formatted_descriptor(True)
|
||||
utils.save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
||||
|
||||
self.new_sim(v_list_str, params)
|
||||
@@ -526,7 +526,9 @@ class SequencialSimulations(Simulations, priority=0):
|
||||
def __init__(self, configuration: Configuration, task_id):
|
||||
super().__init__(configuration, task_id=task_id)
|
||||
self.pbars = utils.PBars(
|
||||
self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1
|
||||
self.configuration.total_num_steps,
|
||||
"Simulating " + self.configuration.final_path.name,
|
||||
1,
|
||||
)
|
||||
self.configuration.skip_callback = lambda num: self.pbars.update(0, num)
|
||||
|
||||
@@ -569,7 +571,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
self.p_worker = multiprocessing.Process(
|
||||
target=utils.progress_worker,
|
||||
args=(
|
||||
self.configuration.final_path,
|
||||
self.configuration.final_path.name,
|
||||
self.sim_jobs_per_node,
|
||||
self.configuration.total_num_steps,
|
||||
self.progress_queue,
|
||||
@@ -716,7 +718,7 @@ def run_simulation(
|
||||
|
||||
sim = new_simulation(config, method)
|
||||
sim.run()
|
||||
path_trees = utils.build_path_trees(config.sim_dirs[-1])
|
||||
path_trees = utils.build_path_trees(config.fiber_paths[-1])
|
||||
|
||||
final_name = env.get(env.OUTPUT_PATH)
|
||||
if final_name is None:
|
||||
@@ -724,7 +726,7 @@ def run_simulation(
|
||||
|
||||
utils.merge(final_name, path_trees)
|
||||
try:
|
||||
send2trash(config.sim_dirs)
|
||||
send2trash(config.fiber_paths)
|
||||
except (PermissionError, OSError):
|
||||
get_logger(__name__).error("Could not send temporary directories to trash")
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR
|
||||
from .defaults import default_plotting as defaults
|
||||
from .math import abs2, span
|
||||
from .physics import pulse, units
|
||||
from .utils.parameter import Parameters, PlotRange, sort_axis
|
||||
from .utils.parameter import Parameters
|
||||
from .utils.utils import PlotRange, sort_axis
|
||||
|
||||
RangeType = tuple[float, float, Union[str, Callable]]
|
||||
NO_LIM = object()
|
||||
|
||||
@@ -16,9 +16,8 @@ from ..utils import auto_crop, open_config, save_toml, translate_parameters
|
||||
from ..utils.parameter import (
|
||||
Configuration,
|
||||
Parameters,
|
||||
pretty_format_from_sim_name,
|
||||
pretty_format_value,
|
||||
)
|
||||
from ..utils.variationer import VariationDescriptor
|
||||
|
||||
|
||||
def fingerprint(params: Parameters):
|
||||
@@ -46,7 +45,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
|
||||
path, fig, ax = plot_setup(
|
||||
pulse.path.parent
|
||||
/ (
|
||||
pretty_format_from_sim_name(pulse.path.name)
|
||||
pulse.path.name
|
||||
+ PARAM_SEPARATOR
|
||||
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
|
||||
)
|
||||
|
||||
@@ -16,7 +16,8 @@ from .plotting import (
|
||||
single_position_plot,
|
||||
transform_2D_propagation,
|
||||
)
|
||||
from .utils.parameter import Parameters, PlotRange
|
||||
from .utils.parameter import Parameters
|
||||
from .utils.utils import PlotRange
|
||||
from .utils import load_spectrum
|
||||
|
||||
|
||||
|
||||
@@ -161,29 +161,35 @@ def save_toml(path: os.PathLike, dico):
|
||||
return dico
|
||||
|
||||
|
||||
def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]:
|
||||
loaded_config = open_config(final_config_path)
|
||||
final_name = loaded_config.get("name")
|
||||
fiber_list = loaded_config.pop("Fiber")
|
||||
configs = []
|
||||
if fiber_list is not None:
|
||||
master_variable = loaded_config.get("variable", {})
|
||||
for i, params in enumerate(fiber_list):
|
||||
params.setdefault("variable", master_variable if i == 0 else {})
|
||||
if i == 0:
|
||||
params["variable"] |= master_variable
|
||||
configs.append(loaded_config | params)
|
||||
else:
|
||||
configs.append(loaded_config)
|
||||
while "previous_config_file" in configs[0]:
|
||||
configs.insert(0, open_config(configs[0]["previous_config_file"]))
|
||||
configs[0].setdefault("variable", {})
|
||||
for pre, nex in zip(configs[:-1], configs[1:]):
|
||||
variable = nex.pop("variable", {})
|
||||
nex.update({k: v for k, v in pre.items() if k not in nex})
|
||||
nex["variable"] = variable
|
||||
def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]:
|
||||
"""loads a configuration file
|
||||
|
||||
return configs, final_name
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
path to the config toml file
|
||||
|
||||
Returns
|
||||
-------
|
||||
final_path : Path
|
||||
output name of the simulation
|
||||
list[dict[str, Any]]
|
||||
one config per fiber
|
||||
|
||||
"""
|
||||
loaded_config = open_config(path)
|
||||
|
||||
fiber_list: list[dict[str, Any]] = loaded_config.pop("Fiber")
|
||||
if len(fiber_list) == 0:
|
||||
raise ValueError(f"No fiber in config {path}")
|
||||
final_path = loaded_config.get("name")
|
||||
configs = []
|
||||
for i, params in enumerate(fiber_list):
|
||||
params.setdefault("variable", {})
|
||||
configs.append(loaded_config | params)
|
||||
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
|
||||
|
||||
return Path(final_path), configs
|
||||
|
||||
|
||||
def save_parameters(
|
||||
|
||||
@@ -11,14 +11,18 @@ from dataclasses import asdict, dataclass, fields
|
||||
from functools import cache, lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib import isin
|
||||
from scgenerator.utils import ensure_folder, variationer
|
||||
|
||||
from .. import math, utils
|
||||
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
|
||||
from ..errors import EvaluatorError, NoDefaultError
|
||||
from ..logger import get_logger
|
||||
from ..physics import fiber, materials, pulse, units
|
||||
from ..utils.variationer import VariationDescriptor, Variationer
|
||||
from .utils import func_rewrite, _mock_function, get_arg_names
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -256,7 +260,7 @@ class Parameter:
|
||||
----------
|
||||
tpe : type
|
||||
type of the paramter
|
||||
validators : Callable[[str, Any], None]
|
||||
validator : Callable[[str, Any], None]
|
||||
signature : validator(name, value)
|
||||
must raise a ValueError when value doesn't fit the criteria checked by
|
||||
validator. name is passed to validator to be included in the error message
|
||||
@@ -290,7 +294,6 @@ class Parameter:
|
||||
if isinstance(value, Parameter):
|
||||
defaut = None if self.default is None else copy(self.default)
|
||||
instance.__dict__[self.name] = defaut
|
||||
# instance.__dict__[self.name] = None
|
||||
else:
|
||||
if value is not None:
|
||||
if self.converter is not None:
|
||||
@@ -768,9 +771,11 @@ class Configuration:
|
||||
obj with the output path of the simulation saved in its output_path attribute.
|
||||
"""
|
||||
|
||||
master_configs: list[dict[str, Any]]
|
||||
sim_dirs: list[Path]
|
||||
fiber_configs: list[dict[str, Any]]
|
||||
master_config: dict[str, Any]
|
||||
fiber_paths: list[Path]
|
||||
num_sim: int
|
||||
num_fibers: int
|
||||
repeat: int
|
||||
z_num: int
|
||||
total_num_steps: int
|
||||
@@ -778,19 +783,17 @@ class Configuration:
|
||||
parallel: bool
|
||||
overwrite: bool
|
||||
final_path: str
|
||||
all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
|
||||
all_configs_list: list[list["Configuration.__SimConfig"]]
|
||||
all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class __SimConfig:
|
||||
vary_list: list[tuple[str, Any]]
|
||||
descriptor: VariationDescriptor
|
||||
config: dict[str, Any]
|
||||
output_path: Path
|
||||
index: tuple[tuple[int, ...], ...]
|
||||
|
||||
@property
|
||||
def sim_num(self) -> int:
|
||||
return len(self.index)
|
||||
return len(self.descriptor.index)
|
||||
|
||||
class State(enum.Enum):
|
||||
COMPLETE = enum.auto()
|
||||
@@ -810,48 +813,48 @@ class Configuration:
|
||||
):
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
self.master_configs, self.final_path = utils.load_config_sequence(final_config_path)
|
||||
if self.final_path is None:
|
||||
self.final_path = Parameters.name.default
|
||||
self.name = Path(self.final_path).name
|
||||
self.overwrite = overwrite
|
||||
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
||||
self.final_path = utils.ensure_folder(
|
||||
self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
|
||||
)
|
||||
self.master_config = self.fiber_configs[0]
|
||||
self.name = self.final_path.name
|
||||
self.z_num = 0
|
||||
self.total_num_steps = 0
|
||||
self.sim_dirs = []
|
||||
self.overwrite = overwrite
|
||||
self.fiber_paths = []
|
||||
self.all_configs = {}
|
||||
self.skip_callback = skip_callback
|
||||
self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2))
|
||||
self.repeat = self.master_configs[0].get("repeat", 1)
|
||||
self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2))
|
||||
self.repeat = self.master_config.get("repeat", 1)
|
||||
self.variationer = Variationer()
|
||||
|
||||
names = set()
|
||||
for i, config in enumerate(self.master_configs):
|
||||
fiber_names = set()
|
||||
self.num_fibers = 0
|
||||
for i, config in enumerate(self.fiber_configs):
|
||||
config.setdefault("name", Parameters.name.default)
|
||||
self.z_num += config["z_num"]
|
||||
config.setdefault("name", f"{Parameters.name.default} {i}")
|
||||
given_name = config["name"]
|
||||
fn_i = 0
|
||||
while config["name"] in names:
|
||||
config["name"] = given_name + f"_{fn_i}"
|
||||
fn_i += 1
|
||||
names.add(config["name"])
|
||||
|
||||
self.sim_dirs.append(
|
||||
fiber_names.add(config["name"])
|
||||
self.variationer.append(config.pop("variable"))
|
||||
self.fiber_paths.append(
|
||||
utils.ensure_folder(
|
||||
Path("_".join(["_", self.name, Path(config["name"]).name, "_"])),
|
||||
self.fiber_path(i, config),
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
)
|
||||
self.__validate_variable(config)
|
||||
self.__compute_sim_dirs()
|
||||
[Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list]
|
||||
self.num_sim = len(self.all_configs_list[-1])
|
||||
self.num_fibers += 1
|
||||
Evaluator.evaluate_default(config, True)
|
||||
self.num_sim = self.variationer.var_num()
|
||||
self.total_num_steps = sum(
|
||||
config["z_num"] * len(self.all_configs_list[i])
|
||||
for i, config in enumerate(self.master_configs)
|
||||
config["z_num"] * self.variationer.var_num(i)
|
||||
for i, config in enumerate(self.fiber_configs)
|
||||
)
|
||||
self.final_sim_dir = utils.ensure_folder(
|
||||
Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite
|
||||
)
|
||||
self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default)
|
||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||
|
||||
def fiber_path(self, i: int, full_config: dict[str, Any]) -> Path:
|
||||
return self.final_path / PARAM_SEPARATOR.join([format(i), self.name, full_config["name"]])
|
||||
|
||||
def __validate_variable(self, config: dict[str, Any]):
|
||||
for k, v in config.get("variable", {}).items():
|
||||
@@ -862,76 +865,62 @@ class Configuration:
|
||||
if len(v) == 0:
|
||||
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||
|
||||
def __compute_sim_dirs(self):
|
||||
self.all_configs_dict = {}
|
||||
self.all_configs_list = []
|
||||
self.master_configs[0]["variable"]["num"] = list(
|
||||
range(self.master_configs[0].get("repeat", 1))
|
||||
)
|
||||
dp = DataPather([c["variable"] for c in self.master_configs])
|
||||
for i, conf in enumerate(self.master_configs):
|
||||
self.all_configs_list.append([])
|
||||
for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i):
|
||||
this_conf = conf.copy()
|
||||
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
|
||||
for i in range(self.num_fibers):
|
||||
for sim_config in self.iterate_single_fiber(i):
|
||||
|
||||
if i > 0:
|
||||
prev_path = utils.ensure_folder(
|
||||
self.sim_dirs[i - 1] / prev_path, not self.overwrite, False
|
||||
sim_config.config["prev_data_dir"] = str(
|
||||
self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
|
||||
)
|
||||
this_conf["prev_data_dir"] = str(prev_path)
|
||||
|
||||
this_path = utils.ensure_folder(
|
||||
self.sim_dirs[i] / this_path, not self.overwrite, False
|
||||
)
|
||||
this_conf.pop("variable")
|
||||
conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
|
||||
self.all_configs_dict[sim_index] = self.__SimConfig(
|
||||
this_vary, conf_to_use, this_path, sim_index
|
||||
)
|
||||
self.all_configs_list[i].append(self.all_configs_dict[sim_index])
|
||||
|
||||
def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
|
||||
for i, sim_config_list in enumerate(self.all_configs_list):
|
||||
for sim_config, params in self.__iter_1_sim(sim_config_list):
|
||||
params = Parameters(**sim_config.config)
|
||||
params.compute()
|
||||
fiber_map = []
|
||||
for j in range(i + 1):
|
||||
this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config
|
||||
this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
|
||||
if j > 0:
|
||||
prev_conf = self.all_configs_dict[sim_config.index[:j]].config
|
||||
prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
|
||||
length = prev_conf["length"] + fiber_map[j - 1][0]
|
||||
else:
|
||||
length = 0.0
|
||||
fiber_map.append((length, this_conf["name"]))
|
||||
params.output_path = str(sim_config.output_path)
|
||||
params.fiber_map = fiber_map
|
||||
yield sim_config.vary_list, params
|
||||
yield sim_config.descriptor, params
|
||||
|
||||
def __iter_1_sim(
|
||||
self, configs: list["Configuration.__SimConfig"]
|
||||
) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]:
|
||||
def iterate_single_fiber(
|
||||
self, index: int
|
||||
) -> Generator["Configuration.__SimConfig", None, None]:
|
||||
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
||||
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
||||
|
||||
Parameters
|
||||
----------
|
||||
configs : list[__SimConfig]
|
||||
list of configuration obj
|
||||
index : int
|
||||
which fiber to iterate over
|
||||
|
||||
Yields
|
||||
-------
|
||||
__SimConfig
|
||||
configuration obj
|
||||
Parameters
|
||||
computed Parameters obj
|
||||
"""
|
||||
sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs}
|
||||
sim_dict: dict[Path, self.__SimConfig] = {}
|
||||
for descr in self.variationer.iterate(index):
|
||||
cfg = descr.update_config(self.fiber_configs[index])
|
||||
p = ensure_folder(
|
||||
self.fiber_paths[index] / descr.formatted_descriptor(),
|
||||
not self.overwrite,
|
||||
False,
|
||||
)
|
||||
cfg["output_path"] = str(p)
|
||||
sim_config = self.__SimConfig(descr, cfg, p)
|
||||
sim_dict[p] = sim_config
|
||||
self.all_configs[sim_config.descriptor.index] = sim_config
|
||||
while len(sim_dict) > 0:
|
||||
for data_dir, sim_config in sim_dict.items():
|
||||
task, config_dict = self.__decide(sim_config)
|
||||
if task == self.Action.RUN:
|
||||
sim_dict.pop(data_dir)
|
||||
p = Parameters(**config_dict)
|
||||
p.compute()
|
||||
yield sim_config, p
|
||||
yield sim_config
|
||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||
self.skip_callback(config_dict["recovery_last_stored"])
|
||||
break
|
||||
@@ -956,7 +945,7 @@ class Configuration:
|
||||
|
||||
Returns
|
||||
-------
|
||||
str : {'run', 'wait', 'skip'}
|
||||
str : Configuration.Action
|
||||
what to do
|
||||
config_dict : dict[str, Any]
|
||||
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
||||
@@ -1012,7 +1001,7 @@ class Configuration:
|
||||
raise ValueError(f"Too many spectra in {data_dir}")
|
||||
|
||||
def save_parameters(self):
|
||||
for config, sim_dir in zip(self.master_configs, self.sim_dirs):
|
||||
for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
|
||||
os.makedirs(sim_dir, exist_ok=True)
|
||||
utils.save_toml(sim_dir / f"initial_config.toml", config)
|
||||
|
||||
@@ -1022,144 +1011,6 @@ class Configuration:
|
||||
return param
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PlotRange:
|
||||
left: float = Parameter(type_checker(int, float))
|
||||
right: float = Parameter(type_checker(int, float))
|
||||
unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit)
|
||||
conserved_quantity: bool = Parameter(boolean, default=True)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||
|
||||
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
return sort_axis(axis, self)
|
||||
|
||||
def __iter__(self):
|
||||
yield self.left
|
||||
yield self.right
|
||||
yield self.unit.__name__
|
||||
|
||||
|
||||
def sort_axis(
|
||||
axis: np.ndarray, plt_range: PlotRange
|
||||
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
"""
|
||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axis : 1D array containing the original axis (usual the w or t array)
|
||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||
|
||||
Returns
|
||||
-------
|
||||
cropped : the axis cropped, converted and sorted
|
||||
indices : indices to use to slice and sort other array in the same fashion
|
||||
extent : tupple with min and max of cropped
|
||||
|
||||
Example
|
||||
-------
|
||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||
t = np.linspace(-10, 10, 400)
|
||||
W, T = np.meshgrid(w, t)
|
||||
y = np.exp(-W**2 - T**2)
|
||||
|
||||
# Define ranges
|
||||
rw = (-4, 4, s)
|
||||
rt = (-2, 6, s)
|
||||
|
||||
w, cw = sort_axis(w, rw)
|
||||
t, ct = sort_axis(t, rt)
|
||||
|
||||
# slice y according to the given ranges
|
||||
y = y[ct][:, cw]
|
||||
"""
|
||||
if isinstance(plt_range, tuple):
|
||||
plt_range = PlotRange(*plt_range)
|
||||
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||
|
||||
indices = np.arange(len(axis))[
|
||||
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
||||
]
|
||||
cropped = axis[indices]
|
||||
order = np.argsort(plt_range.unit.inv(cropped))
|
||||
indices = indices[order]
|
||||
cropped = cropped[order]
|
||||
out_ax = plt_range.unit.inv(cropped)
|
||||
|
||||
return out_ax, indices, (out_ax[0], out_ax[-1])
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
out_func = scope[tmp_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
@cache
|
||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||
if not isinstance(num_args, int) and isinstance(num_returns, int):
|
||||
raise TypeError(f"num_args and num_returns must be int")
|
||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||
return_str = ", ".join("True" for _ in range(num_returns))
|
||||
func_name = f"__mock_{num_args}_{num_returns}"
|
||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||
scope = {}
|
||||
exec(func_str, scope)
|
||||
out_func = scope[func_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
def pretty_format_from_sim_name(name: str) -> str:
|
||||
"""formats a pretty version of a simulation directory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
name of the simulation (directory name)
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
prettier name
|
||||
"""
|
||||
s = name.split(PARAM_SEPARATOR)
|
||||
out = []
|
||||
for key, value in zip(s[::2], s[1::2]):
|
||||
try:
|
||||
out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
|
||||
except (AttributeError, ValueError):
|
||||
out.append(key + PARAM_SEPARATOR + value)
|
||||
return PARAM_SEPARATOR.join(out)
|
||||
|
||||
|
||||
default_rules: list[Rule] = [
|
||||
# Grid
|
||||
*Rule.deduce(
|
||||
|
||||
@@ -1,4 +1,23 @@
|
||||
import inspect
|
||||
import re
|
||||
from functools import cache
|
||||
from string import printable as str_printable
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..physics.units import get_unit
|
||||
|
||||
|
||||
class HashableBaseModel(BaseModel):
|
||||
"""Pydantic BaseModel that's immutable and can be hashed"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
|
||||
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
|
||||
|
||||
def to_62(i: int) -> str:
|
||||
@@ -10,3 +29,118 @@ def to_62(i: int) -> str:
|
||||
i, value = divmod(i, 62)
|
||||
arr.append(str_printable[value])
|
||||
return "".join(reversed(arr))
|
||||
|
||||
|
||||
class PlotRange(HashableBaseModel):
|
||||
left: float
|
||||
right: float
|
||||
unit: Callable[[float], float]
|
||||
conserved_quantity: bool = True
|
||||
|
||||
def __init__(self, left, right, unit, **kwargs):
|
||||
super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||
|
||||
def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
return sort_axis(axis, self)
|
||||
|
||||
def __iter__(self):
|
||||
yield self.left
|
||||
yield self.right
|
||||
yield self.unit.__name__
|
||||
|
||||
|
||||
def sort_axis(
|
||||
axis: np.ndarray, plt_range: PlotRange
|
||||
) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]:
|
||||
"""
|
||||
given an axis, returns this axis cropped according to the given range, converted and sorted
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axis : 1D array containing the original axis (usual the w or t array)
|
||||
plt_range : tupple (min, max, conversion_function) used to crop the axis
|
||||
|
||||
Returns
|
||||
-------
|
||||
cropped : the axis cropped, converted and sorted
|
||||
indices : indices to use to slice and sort other array in the same fashion
|
||||
extent : tupple with min and max of cropped
|
||||
|
||||
Example
|
||||
-------
|
||||
w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20))
|
||||
t = np.linspace(-10, 10, 400)
|
||||
W, T = np.meshgrid(w, t)
|
||||
y = np.exp(-W**2 - T**2)
|
||||
|
||||
# Define ranges
|
||||
rw = (-4, 4, s)
|
||||
rt = (-2, 6, s)
|
||||
|
||||
w, cw = sort_axis(w, rw)
|
||||
t, ct = sort_axis(t, rt)
|
||||
|
||||
# slice y according to the given ranges
|
||||
y = y[ct][:, cw]
|
||||
"""
|
||||
if isinstance(plt_range, tuple):
|
||||
plt_range = PlotRange(*plt_range)
|
||||
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||
|
||||
indices = np.arange(len(axis))[
|
||||
(axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r)))
|
||||
]
|
||||
cropped = axis[indices]
|
||||
order = np.argsort(plt_range.unit.inv(cropped))
|
||||
indices = indices[order]
|
||||
cropped = cropped[order]
|
||||
out_ax = plt_range.unit.inv(cropped)
|
||||
|
||||
return out_ax, indices, (out_ax[0], out_ax[-1])
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable:
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
out_func = scope[tmp_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
@cache
|
||||
def _mock_function(num_args: int, num_returns: int) -> Callable:
|
||||
arg_str = ", ".join("a" * (n + 1) for n in range(num_args))
|
||||
return_str = ", ".join("True" for _ in range(num_returns))
|
||||
func_name = f"__mock_{num_args}_{num_returns}"
|
||||
func_str = f"def {func_name}({arg_str}):\n return {return_str}"
|
||||
scope = {}
|
||||
exec(func_str, scope)
|
||||
out_func = scope[func_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
@@ -1,45 +1,14 @@
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import Union, Iterable, Generator, Any
|
||||
from collections.abc import Sequence, MutableMapping
|
||||
from math import prod
|
||||
import itertools
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Iterable, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import validator
|
||||
|
||||
from ..const import PARAM_SEPARATOR
|
||||
from . import utils
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def format_value(name: str, value) -> str:
|
||||
if value is True or value is False:
|
||||
return str(value)
|
||||
elif isinstance(value, (float, int)):
|
||||
try:
|
||||
return getattr(Parameters, name).display(value)
|
||||
except AttributeError:
|
||||
return format(value, ".9g")
|
||||
elif isinstance(value, (list, tuple, np.ndarray)):
|
||||
return "-".join([str(v) for v in value])
|
||||
elif isinstance(value, str):
|
||||
p = Path(value)
|
||||
if p.exists():
|
||||
return p.stem
|
||||
return str(value)
|
||||
|
||||
|
||||
def pretty_format_value(name: str, value) -> str:
|
||||
try:
|
||||
return getattr(Parameters, name).display(value)
|
||||
except AttributeError:
|
||||
return name + PARAM_SEPARATOR + str(value)
|
||||
|
||||
|
||||
class HashableBaseModel(BaseModel):
|
||||
"""Pydantic BaseModel that's immutable and can be hashed"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(type(self)) + sum(hash(v) for v in self.__dict__.values())
|
||||
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
|
||||
|
||||
class VariationSpecsError(ValueError):
|
||||
@@ -67,23 +36,20 @@ class Variationer:
|
||||
all_indices: list[list[int]]
|
||||
all_dicts: list[list[dict[str, list]]]
|
||||
|
||||
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]]):
|
||||
def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None):
|
||||
self.all_indices = []
|
||||
self.all_dicts = []
|
||||
for i, el in enumerate(variables):
|
||||
if not isinstance(el, Sequence):
|
||||
el = [{k: v} for k, v in el.items()]
|
||||
else:
|
||||
el = list(el)
|
||||
self.append(el)
|
||||
if variables is not None:
|
||||
for i, el in enumerate(variables):
|
||||
self.append(el)
|
||||
|
||||
def append(self, var_list: list[dict[str, list]]):
|
||||
def append(self, var_list: Union[list[MutableMapping], MutableMapping]):
|
||||
"""append a list of variable parameter sets
|
||||
each call to append creates a new group of parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
var_list : list[dict[str, list]]
|
||||
var_list : Union[list[MutableMapping], MutableMapping]
|
||||
each dict in the list is treated as an independent parameter
|
||||
this means that if for one dict, len > 1, the lists of possible values
|
||||
must be the same length
|
||||
@@ -100,6 +66,10 @@ class Variationer:
|
||||
VariationSpecsError
|
||||
raised when possible values lists in a same dict are not the same length
|
||||
"""
|
||||
if not isinstance(var_list, Sequence):
|
||||
var_list = [{k: v} for k, v in var_list.items()]
|
||||
else:
|
||||
var_list = list(var_list)
|
||||
num_vars = []
|
||||
for d in var_list:
|
||||
values = list(d.values())
|
||||
@@ -114,30 +84,43 @@ class Variationer:
|
||||
self.all_indices.append(num_vars)
|
||||
self.all_dicts.append(var_list)
|
||||
|
||||
def iterate(self, index: int = -1) -> Generator["SimulationDescriptor", None, None]:
|
||||
if index < 0:
|
||||
index = len(self.all_indices) + index + 1
|
||||
flattened_indices = sum(self.all_indices[:index], [])
|
||||
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[:index]])
|
||||
def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]:
|
||||
index = self.__index(index)
|
||||
flattened_indices = sum(self.all_indices[: index + 1], [])
|
||||
index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]])
|
||||
ranges = [range(i) for i in flattened_indices]
|
||||
for r in itertools.product(*ranges):
|
||||
out: list[list[tuple[str, Any]]] = []
|
||||
indicies: list[list[int]] = []
|
||||
for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])):
|
||||
out.append([])
|
||||
indicies.append([])
|
||||
for value_index, var_d in zip(r[start:end], self.all_dicts[i]):
|
||||
for k, v in var_d.items():
|
||||
out[-1].append((k, v[value_index]))
|
||||
yield SimulationDescriptor(raw_descr=out)
|
||||
indicies[-1].append(value_index)
|
||||
yield VariationDescriptor(raw_descr=out, index=indicies)
|
||||
|
||||
def __index(self, index: int) -> int:
|
||||
if index < 0:
|
||||
index = len(self.all_indices) + index
|
||||
return index
|
||||
|
||||
def var_num(self, index: int = -1) -> int:
|
||||
index = self.__index(index)
|
||||
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
||||
|
||||
|
||||
class SimulationDescriptor(HashableBaseModel):
|
||||
class VariationDescriptor(utils.HashableBaseModel):
|
||||
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
||||
index: tuple[tuple[int, ...], ...]
|
||||
separator: str = "fiber"
|
||||
_format_registry: dict[str, Callable[..., str]] = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.descriptor(add_identifier=False)
|
||||
return self.formatted_descriptor(add_identifier=False)
|
||||
|
||||
def descriptor(self, add_identifier=False) -> str:
|
||||
def formatted_descriptor(self, add_identifier=False) -> str:
|
||||
"""formats a variable list into a str such that each simulation has a unique
|
||||
directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations)
|
||||
branch identifier can added at the beginning.
|
||||
@@ -156,7 +139,7 @@ class SimulationDescriptor(HashableBaseModel):
|
||||
|
||||
for p_name, p_value in self.flat:
|
||||
ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "")
|
||||
vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
|
||||
vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "")
|
||||
str_list.append(ps + PARAM_SEPARATOR + vs)
|
||||
tmp_name = PARAM_SEPARATOR.join(str_list)
|
||||
if not add_identifier:
|
||||
@@ -165,6 +148,34 @@ class SimulationDescriptor(HashableBaseModel):
|
||||
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||
cls._format_registry[p_name] = func
|
||||
|
||||
def format_value(self, name: str, value) -> str:
|
||||
if value is True or value is False:
|
||||
return str(value)
|
||||
elif isinstance(value, (float, int)):
|
||||
try:
|
||||
return self._format_registry[name](value)
|
||||
except KeyError:
|
||||
return format(value, ".9g")
|
||||
elif isinstance(value, (list, tuple, np.ndarray)):
|
||||
return "-".join([str(v) for v in value])
|
||||
elif isinstance(value, str):
|
||||
p = Path(value)
|
||||
if p.exists():
|
||||
return p.stem
|
||||
return str(value)
|
||||
|
||||
def __getitem__(self, key) -> "VariationDescriptor":
|
||||
return VariationDescriptor(
|
||||
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
||||
)
|
||||
|
||||
def update_config(self, cfg: dict[str, Any]):
|
||||
return cfg | {k: v for k, v in self.raw_descr[-1]}
|
||||
|
||||
@property
|
||||
def flat(self) -> list[tuple[str, Any]]:
|
||||
out = []
|
||||
@@ -177,17 +188,27 @@ class SimulationDescriptor(HashableBaseModel):
|
||||
|
||||
@property
|
||||
def branch(self) -> "BranchDescriptor":
|
||||
return SimulationDescriptor(raw_descr=self.raw_descr, separator=self.separator)
|
||||
for i in reversed(range(len(self.raw_descr))):
|
||||
for j in reversed(range(len(self.raw_descr[i]))):
|
||||
if self.raw_descr[i][j][0] == "num":
|
||||
del self.raw_descr[i][j]
|
||||
return VariationDescriptor(
|
||||
raw_descr=self.raw_descr, index=self.index, separator=self.separator
|
||||
)
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "u_" + utils.to_62(hash(str(self.flat)))
|
||||
|
||||
|
||||
class BranchDescriptor(SimulationDescriptor):
|
||||
class BranchDescriptor(VariationDescriptor):
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
@property
|
||||
def identifier(self) -> str:
|
||||
return "b_" + utils.to_62(hash(str(self.flat)))
|
||||
branch_id = hash(str(self.flat))
|
||||
self.__ids.setdefault(branch_id, len(self.__ids))
|
||||
return str(self.__ids[branch_id])
|
||||
|
||||
@validator("raw_descr")
|
||||
def validate_raw_descr(cls, v):
|
||||
|
||||
Reference in New Issue
Block a user