diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 4c9e73d..997e495 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -5,4 +5,5 @@ from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .spectra import Pulse, Spectrum from .utils import Paths, open_config, parameter -from .utils.parameter import Configuration, Parameters, PlotRange +from .utils.parameter import Configuration, Parameters +from .utils.utils import PlotRange diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 4d34743..477d9c5 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -11,7 +11,7 @@ from send2trash import send2trash from .. import env, utils from ..logger import get_logger -from ..utils.parameter import Configuration, Parameters, format_variable_list +from ..utils.parameter import Configuration, Parameters from . import pulse from .fiber import create_non_linear_op, fast_dispersion_op @@ -466,14 +466,14 @@ class Simulations: self.configuration = configuration - self.name = self.configuration.final_path - self.sim_dir = self.configuration.final_sim_dir + self.name = self.configuration.name + self.sim_dir = self.configuration.final_path self.configuration.save_parameters() self.sim_jobs_per_node = 1 def finished_and_complete(self): - for sim in self.configuration.all_configs_dict.values(): + for sim in self.configuration.all_configs.values(): if ( self.configuration.sim_status(sim.output_path)[0] != self.configuration.State.COMPLETE @@ -487,7 +487,7 @@ class Simulations: def _run_available(self): for variable, params in self.configuration: - v_list_str = format_variable_list(variable, add_iden=True) + v_list_str = variable.formatted_descriptor(True) utils.save_parameters(params.prepare_for_dump(), Path(params.output_path)) self.new_sim(v_list_str, params) @@ -526,7 +526,9 @@ class SequencialSimulations(Simulations, priority=0): def __init__(self, configuration: Configuration, task_id): super().__init__(configuration, task_id=task_id) self.pbars = utils.PBars( - self.configuration.total_num_steps, "Simulating " + self.configuration.final_path, 1 + self.configuration.total_num_steps, + "Simulating " + self.configuration.final_path.name, + 1, ) self.configuration.skip_callback = lambda num: self.pbars.update(0, num) @@ -569,7 +571,7 @@ class MultiProcSimulations(Simulations, priority=1): self.p_worker = multiprocessing.Process( target=utils.progress_worker, args=( - self.configuration.final_path, + self.configuration.final_path.name, self.sim_jobs_per_node, self.configuration.total_num_steps, self.progress_queue, @@ -716,7 +718,7 @@ def run_simulation( sim = new_simulation(config, method) sim.run() - path_trees = utils.build_path_trees(config.sim_dirs[-1]) + path_trees = utils.build_path_trees(config.fiber_paths[-1]) final_name = env.get(env.OUTPUT_PATH) if final_name is None: @@ -724,7 +726,7 @@ def run_simulation( utils.merge(final_name, path_trees) try: - send2trash(config.sim_dirs) + send2trash(config.fiber_paths) except (PermissionError, OSError): get_logger(__name__).error("Could not send temporary directories to trash") diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index fbaa903..712bf12 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -14,7 +14,8 @@ from .const import PARAM_SEPARATOR from .defaults import default_plotting as defaults from .math import abs2, span from .physics import pulse, units -from .utils.parameter import Parameters, PlotRange, sort_axis +from .utils.parameter import Parameters +from .utils.utils import PlotRange, sort_axis RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index ede8211..029268f 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -16,9 +16,8 @@ from ..utils import auto_crop, open_config, save_toml, translate_parameters from ..utils.parameter import ( Configuration, Parameters, - pretty_format_from_sim_name, - pretty_format_value, ) +from ..utils.variationer import VariationDescriptor def fingerprint(params: Parameters): @@ -46,7 +45,7 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): path, fig, ax = plot_setup( pulse.path.parent / ( - pretty_format_from_sim_name(pulse.path.name) + pulse.path.name + PARAM_SEPARATOR + f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}" ) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 1ec018c..fdd6b78 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -16,7 +16,8 @@ from .plotting import ( single_position_plot, transform_2D_propagation, ) -from .utils.parameter import Parameters, PlotRange +from .utils.parameter import Parameters +from .utils.utils import PlotRange from .utils import load_spectrum diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 8062bd3..fdcbcea 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -161,29 +161,35 @@ def save_toml(path: os.PathLike, dico): return dico -def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, Any]], str]: - loaded_config = open_config(final_config_path) - final_name = loaded_config.get("name") - fiber_list = loaded_config.pop("Fiber") - configs = [] - if fiber_list is not None: - master_variable = loaded_config.get("variable", {}) - for i, params in enumerate(fiber_list): - params.setdefault("variable", master_variable if i == 0 else {}) - if i == 0: - params["variable"] |= master_variable - configs.append(loaded_config | params) - else: - configs.append(loaded_config) - while "previous_config_file" in configs[0]: - configs.insert(0, open_config(configs[0]["previous_config_file"])) - configs[0].setdefault("variable", {}) - for pre, nex in zip(configs[:-1], configs[1:]): - variable = nex.pop("variable", {}) - nex.update({k: v for k, v in pre.items() if k not in nex}) - nex["variable"] = variable +def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]: + """loads a configuration file - return configs, final_name + Parameters + ---------- + path : os.PathLike + path to the config toml file + + Returns + ------- + final_path : Path + output name of the simulation + list[dict[str, Any]] + one config per fiber + + """ + loaded_config = open_config(path) + + fiber_list: list[dict[str, Any]] = loaded_config.pop("Fiber") + if len(fiber_list) == 0: + raise ValueError(f"No fiber in config {path}") + final_path = loaded_config.get("name") + configs = [] + for i, params in enumerate(fiber_list): + params.setdefault("variable", {}) + configs.append(loaded_config | params) + configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"] + + return Path(final_path), configs def save_parameters( diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index 17fe9f4..fa43c58 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -11,14 +11,18 @@ from dataclasses import asdict, dataclass, fields from functools import cache, lru_cache from pathlib import Path from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union + import numpy as np from numpy.lib import isin +from scgenerator.utils import ensure_folder, variationer from .. import math, utils from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ from ..errors import EvaluatorError, NoDefaultError from ..logger import get_logger from ..physics import fiber, materials, pulse, units +from ..utils.variationer import VariationDescriptor, Variationer +from .utils import func_rewrite, _mock_function, get_arg_names T = TypeVar("T") @@ -256,7 +260,7 @@ class Parameter: ---------- tpe : type type of the paramter - validators : Callable[[str, Any], None] + validator : Callable[[str, Any], None] signature : validator(name, value) must raise a ValueError when value doesn't fit the criteria checked by validator. name is passed to validator to be included in the error message @@ -290,7 +294,6 @@ class Parameter: if isinstance(value, Parameter): defaut = None if self.default is None else copy(self.default) instance.__dict__[self.name] = defaut - # instance.__dict__[self.name] = None else: if value is not None: if self.converter is not None: @@ -768,9 +771,11 @@ class Configuration: obj with the output path of the simulation saved in its output_path attribute. """ - master_configs: list[dict[str, Any]] - sim_dirs: list[Path] + fiber_configs: list[dict[str, Any]] + master_config: dict[str, Any] + fiber_paths: list[Path] num_sim: int + num_fibers: int repeat: int z_num: int total_num_steps: int @@ -778,19 +783,17 @@ class Configuration: parallel: bool overwrite: bool final_path: str - all_configs_dict: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] - all_configs_list: list[list["Configuration.__SimConfig"]] + all_configs: dict[tuple[tuple[int, ...], ...], "Configuration.__SimConfig"] @dataclass(frozen=True) class __SimConfig: - vary_list: list[tuple[str, Any]] + descriptor: VariationDescriptor config: dict[str, Any] output_path: Path - index: tuple[tuple[int, ...], ...] @property def sim_num(self) -> int: - return len(self.index) + return len(self.descriptor.index) class State(enum.Enum): COMPLETE = enum.auto() @@ -810,48 +813,48 @@ class Configuration: ): self.logger = get_logger(__name__) - self.master_configs, self.final_path = utils.load_config_sequence(final_config_path) - if self.final_path is None: - self.final_path = Parameters.name.default - self.name = Path(self.final_path).name + self.overwrite = overwrite + self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) + self.final_path = utils.ensure_folder( + self.final_path, mkdir=False, prevent_overwrite=not self.overwrite + ) + self.master_config = self.fiber_configs[0] + self.name = self.final_path.name self.z_num = 0 self.total_num_steps = 0 - self.sim_dirs = [] - self.overwrite = overwrite + self.fiber_paths = [] + self.all_configs = {} self.skip_callback = skip_callback - self.worker_num = self.master_configs[0].get("worker_num", max(1, os.cpu_count() // 2)) - self.repeat = self.master_configs[0].get("repeat", 1) + self.worker_num = self.master_config.get("worker_num", max(1, os.cpu_count() // 2)) + self.repeat = self.master_config.get("repeat", 1) + self.variationer = Variationer() - names = set() - for i, config in enumerate(self.master_configs): + fiber_names = set() + self.num_fibers = 0 + for i, config in enumerate(self.fiber_configs): + config.setdefault("name", Parameters.name.default) self.z_num += config["z_num"] - config.setdefault("name", f"{Parameters.name.default} {i}") - given_name = config["name"] - fn_i = 0 - while config["name"] in names: - config["name"] = given_name + f"_{fn_i}" - fn_i += 1 - names.add(config["name"]) - - self.sim_dirs.append( + fiber_names.add(config["name"]) + self.variationer.append(config.pop("variable")) + self.fiber_paths.append( utils.ensure_folder( - Path("_".join(["_", self.name, Path(config["name"]).name, "_"])), + self.fiber_path(i, config), mkdir=False, prevent_overwrite=not self.overwrite, ) ) self.__validate_variable(config) - self.__compute_sim_dirs() - [Evaluator.evaluate_default(c[0].config, True) for c in self.all_configs_list] - self.num_sim = len(self.all_configs_list[-1]) + self.num_fibers += 1 + Evaluator.evaluate_default(config, True) + self.num_sim = self.variationer.var_num() self.total_num_steps = sum( - config["z_num"] * len(self.all_configs_list[i]) - for i, config in enumerate(self.master_configs) + config["z_num"] * self.variationer.var_num(i) + for i, config in enumerate(self.fiber_configs) ) - self.final_sim_dir = utils.ensure_folder( - Path(self.master_configs[-1]["name"]), mkdir=False, prevent_overwrite=not self.overwrite - ) - self.parallel = self.master_configs[0].get("parallel", Parameters.parallel.default) + self.parallel = self.master_config.get("parallel", Parameters.parallel.default) + + def fiber_path(self, i: int, full_config: dict[str, Any]) -> Path: + return self.final_path / PARAM_SEPARATOR.join([format(i), self.name, full_config["name"]]) def __validate_variable(self, config: dict[str, Any]): for k, v in config.get("variable", {}).items(): @@ -862,76 +865,62 @@ class Configuration: if len(v) == 0: raise ValueError(f"variable parameter {k!r} must not be empty") - def __compute_sim_dirs(self): - self.all_configs_dict = {} - self.all_configs_list = [] - self.master_configs[0]["variable"]["num"] = list( - range(self.master_configs[0].get("repeat", 1)) - ) - dp = DataPather([c["variable"] for c in self.master_configs]) - for i, conf in enumerate(self.master_configs): - self.all_configs_list.append([]) - for sim_index, prev_path, this_path, this_vary in dp.all_vary_list(i): - this_conf = conf.copy() + def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]: + for i in range(self.num_fibers): + for sim_config in self.iterate_single_fiber(i): + if i > 0: - prev_path = utils.ensure_folder( - self.sim_dirs[i - 1] / prev_path, not self.overwrite, False + sim_config.config["prev_data_dir"] = str( + self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor() ) - this_conf["prev_data_dir"] = str(prev_path) - - this_path = utils.ensure_folder( - self.sim_dirs[i] / this_path, not self.overwrite, False - ) - this_conf.pop("variable") - conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf - self.all_configs_dict[sim_index] = self.__SimConfig( - this_vary, conf_to_use, this_path, sim_index - ) - self.all_configs_list[i].append(self.all_configs_dict[sim_index]) - - def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]: - for i, sim_config_list in enumerate(self.all_configs_list): - for sim_config, params in self.__iter_1_sim(sim_config_list): + params = Parameters(**sim_config.config) + params.compute() fiber_map = [] for j in range(i + 1): - this_conf = self.all_configs_dict[sim_config.index[: j + 1]].config + this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config if j > 0: - prev_conf = self.all_configs_dict[sim_config.index[:j]].config + prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config length = prev_conf["length"] + fiber_map[j - 1][0] else: length = 0.0 fiber_map.append((length, this_conf["name"])) - params.output_path = str(sim_config.output_path) params.fiber_map = fiber_map - yield sim_config.vary_list, params + yield sim_config.descriptor, params - def __iter_1_sim( - self, configs: list["Configuration.__SimConfig"] - ) -> Generator[tuple["Configuration.__SimConfig", Parameters], None, None]: + def iterate_single_fiber( + self, index: int + ) -> Generator["Configuration.__SimConfig", None, None]: """iterates through the parameters of only one fiber. It takes care of recovering partially completed simulations, skipping complete ones and waiting for the previous fiber to finish Parameters ---------- - configs : list[__SimConfig] - list of configuration obj + index : int + which fiber to iterate over Yields ------- __SimConfig configuration obj - Parameters - computed Parameters obj """ - sim_dict: dict[Path, Configuration.__SimConfig] = {s.output_path: s for s in configs} + sim_dict: dict[Path, self.__SimConfig] = {} + for descr in self.variationer.iterate(index): + cfg = descr.update_config(self.fiber_configs[index]) + p = ensure_folder( + self.fiber_paths[index] / descr.formatted_descriptor(), + not self.overwrite, + False, + ) + cfg["output_path"] = str(p) + sim_config = self.__SimConfig(descr, cfg, p) + sim_dict[p] = sim_config + self.all_configs[sim_config.descriptor.index] = sim_config while len(sim_dict) > 0: for data_dir, sim_config in sim_dict.items(): task, config_dict = self.__decide(sim_config) if task == self.Action.RUN: sim_dict.pop(data_dir) - p = Parameters(**config_dict) - p.compute() - yield sim_config, p + yield sim_config if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break @@ -956,7 +945,7 @@ class Configuration: Returns ------- - str : {'run', 'wait', 'skip'} + str : Configuration.Action what to do config_dict : dict[str, Any] config dictionary. The only key possibly modified is 'prev_data_dir', which @@ -1012,7 +1001,7 @@ class Configuration: raise ValueError(f"Too many spectra in {data_dir}") def save_parameters(self): - for config, sim_dir in zip(self.master_configs, self.sim_dirs): + for config, sim_dir in zip(self.fiber_configs, self.fiber_paths): os.makedirs(sim_dir, exist_ok=True) utils.save_toml(sim_dir / f"initial_config.toml", config) @@ -1022,144 +1011,6 @@ class Configuration: return param -@dataclass(frozen=True) -class PlotRange: - left: float = Parameter(type_checker(int, float)) - right: float = Parameter(type_checker(int, float)) - unit: Callable[[float], float] = Parameter(units.is_unit, converter=units.get_unit) - conserved_quantity: bool = Parameter(boolean, default=True) - - def __str__(self): - return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" - - def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - return sort_axis(axis, self) - - def __iter__(self): - yield self.left - yield self.right - yield self.unit.__name__ - - -def sort_axis( - axis: np.ndarray, plt_range: PlotRange -) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: - """ - given an axis, returns this axis cropped according to the given range, converted and sorted - - Parameters - ---------- - axis : 1D array containing the original axis (usual the w or t array) - plt_range : tupple (min, max, conversion_function) used to crop the axis - - Returns - ------- - cropped : the axis cropped, converted and sorted - indices : indices to use to slice and sort other array in the same fashion - extent : tupple with min and max of cropped - - Example - ------- - w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) - t = np.linspace(-10, 10, 400) - W, T = np.meshgrid(w, t) - y = np.exp(-W**2 - T**2) - - # Define ranges - rw = (-4, 4, s) - rt = (-2, 6, s) - - w, cw = sort_axis(w, rw) - t, ct = sort_axis(t, rt) - - # slice y according to the given ranges - y = y[ct][:, cw] - """ - if isinstance(plt_range, tuple): - plt_range = PlotRange(*plt_range) - r = np.array((plt_range.left, plt_range.right), dtype="float") - - indices = np.arange(len(axis))[ - (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) - ] - cropped = axis[indices] - order = np.argsort(plt_range.unit.inv(cropped)) - indices = indices[order] - cropped = cropped[order] - out_ax = plt_range.unit.inv(cropped) - - return out_ax, indices, (out_ax[0], out_ax[-1]) - - -def get_arg_names(func: Callable) -> list[str]: - spec = inspect.getfullargspec(func) - args = spec.args - if spec.defaults is not None and len(spec.defaults) > 0: - args = args[: -len(spec.defaults)] - return args - - -def validate_arg_names(names: list[str]): - for n in names: - if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: - raise ValueError(f"{n} is an invalid parameter name") - - -def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: - if arg_names is None: - arg_names = get_arg_names(func) - else: - validate_arg_names(arg_names) - validate_arg_names(kwarg_names) - sign_arg_str = ", ".join(arg_names + kwarg_names) - call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) - tmp_name = f"{func.__name__}_0" - func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" - scope = dict(__func__=func) - exec(func_str, scope) - out_func = scope[tmp_name] - out_func.__module__ = "evaluator" - return out_func - - -@cache -def _mock_function(num_args: int, num_returns: int) -> Callable: - if not isinstance(num_args, int) and isinstance(num_returns, int): - raise TypeError(f"num_args and num_returns must be int") - arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) - return_str = ", ".join("True" for _ in range(num_returns)) - func_name = f"__mock_{num_args}_{num_returns}" - func_str = f"def {func_name}({arg_str}):\n return {return_str}" - scope = {} - exec(func_str, scope) - out_func = scope[func_name] - out_func.__module__ = "evaluator" - return out_func - - -def pretty_format_from_sim_name(name: str) -> str: - """formats a pretty version of a simulation directory - - Parameters - ---------- - name : str - name of the simulation (directory name) - - Returns - ------- - str - prettier name - """ - s = name.split(PARAM_SEPARATOR) - out = [] - for key, value in zip(s[::2], s[1::2]): - try: - out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))] - except (AttributeError, ValueError): - out.append(key + PARAM_SEPARATOR + value) - return PARAM_SEPARATOR.join(out) - - default_rules: list[Rule] = [ # Grid *Rule.deduce( diff --git a/src/scgenerator/utils/utils.py b/src/scgenerator/utils/utils.py index b81744d..a18aacb 100644 --- a/src/scgenerator/utils/utils.py +++ b/src/scgenerator/utils/utils.py @@ -1,4 +1,23 @@ +import inspect +import re +from functools import cache from string import printable as str_printable +from typing import Callable + +import numpy as np +from pydantic import BaseModel + +from ..physics.units import get_unit + + +class HashableBaseModel(BaseModel): + """Pydantic BaseModel that's immutable and can be hashed""" + + def __hash__(self) -> int: + return hash(type(self)) + sum(hash(v) for v in self.__dict__.values()) + + class Config: + allow_mutation = False def to_62(i: int) -> str: @@ -10,3 +29,118 @@ def to_62(i: int) -> str: i, value = divmod(i, 62) arr.append(str_printable[value]) return "".join(reversed(arr)) + + +class PlotRange(HashableBaseModel): + left: float + right: float + unit: Callable[[float], float] + conserved_quantity: bool = True + + def __init__(self, left, right, unit, **kwargs): + super().__init__(left=left, right=right, unit=get_unit(unit), **kwargs) + + def __str__(self): + return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" + + def sort_axis(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + return sort_axis(axis, self) + + def __iter__(self): + yield self.left + yield self.right + yield self.unit.__name__ + + +def sort_axis( + axis: np.ndarray, plt_range: PlotRange +) -> tuple[np.ndarray, np.ndarray, tuple[float, float]]: + """ + given an axis, returns this axis cropped according to the given range, converted and sorted + + Parameters + ---------- + axis : 1D array containing the original axis (usual the w or t array) + plt_range : tupple (min, max, conversion_function) used to crop the axis + + Returns + ------- + cropped : the axis cropped, converted and sorted + indices : indices to use to slice and sort other array in the same fashion + extent : tupple with min and max of cropped + + Example + ------- + w = np.append(np.linspace(0, -10, 20), np.linspace(0, 10, 20)) + t = np.linspace(-10, 10, 400) + W, T = np.meshgrid(w, t) + y = np.exp(-W**2 - T**2) + + # Define ranges + rw = (-4, 4, s) + rt = (-2, 6, s) + + w, cw = sort_axis(w, rw) + t, ct = sort_axis(t, rt) + + # slice y according to the given ranges + y = y[ct][:, cw] + """ + if isinstance(plt_range, tuple): + plt_range = PlotRange(*plt_range) + r = np.array((plt_range.left, plt_range.right), dtype="float") + + indices = np.arange(len(axis))[ + (axis <= np.max(plt_range.unit(r))) & (axis >= np.min(plt_range.unit(r))) + ] + cropped = axis[indices] + order = np.argsort(plt_range.unit.inv(cropped)) + indices = indices[order] + cropped = cropped[order] + out_ax = plt_range.unit.inv(cropped) + + return out_ax, indices, (out_ax[0], out_ax[-1]) + + +def get_arg_names(func: Callable) -> list[str]: + spec = inspect.getfullargspec(func) + args = spec.args + if spec.defaults is not None and len(spec.defaults) > 0: + args = args[: -len(spec.defaults)] + return args + + +def validate_arg_names(names: list[str]): + for n in names: + if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None: + raise ValueError(f"{n} is an invalid parameter name") + + +def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None) -> Callable: + if arg_names is None: + arg_names = get_arg_names(func) + else: + validate_arg_names(arg_names) + validate_arg_names(kwarg_names) + sign_arg_str = ", ".join(arg_names + kwarg_names) + call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names]) + tmp_name = f"{func.__name__}_0" + func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})" + scope = dict(__func__=func) + exec(func_str, scope) + out_func = scope[tmp_name] + out_func.__module__ = "evaluator" + return out_func + + +@cache +def _mock_function(num_args: int, num_returns: int) -> Callable: + arg_str = ", ".join("a" * (n + 1) for n in range(num_args)) + return_str = ", ".join("True" for _ in range(num_returns)) + func_name = f"__mock_{num_args}_{num_returns}" + func_str = f"def {func_name}({arg_str}):\n return {return_str}" + scope = {} + exec(func_str, scope) + out_func = scope[func_name] + out_func.__module__ = "evaluator" + return out_func diff --git a/src/scgenerator/utils/variationer.py b/src/scgenerator/utils/variationer.py index 2876957..d53777d 100644 --- a/src/scgenerator/utils/variationer.py +++ b/src/scgenerator/utils/variationer.py @@ -1,45 +1,14 @@ -from pydantic import BaseModel, validator -from typing import Union, Iterable, Generator, Any -from collections.abc import Sequence, MutableMapping +from math import prod import itertools +from collections.abc import MutableMapping, Sequence +from pathlib import Path +from typing import Any, Callable, Generator, Iterable, Union + +import numpy as np +from pydantic import validator + from ..const import PARAM_SEPARATOR from . import utils -import numpy as np -from pathlib import Path - - -def format_value(name: str, value) -> str: - if value is True or value is False: - return str(value) - elif isinstance(value, (float, int)): - try: - return getattr(Parameters, name).display(value) - except AttributeError: - return format(value, ".9g") - elif isinstance(value, (list, tuple, np.ndarray)): - return "-".join([str(v) for v in value]) - elif isinstance(value, str): - p = Path(value) - if p.exists(): - return p.stem - return str(value) - - -def pretty_format_value(name: str, value) -> str: - try: - return getattr(Parameters, name).display(value) - except AttributeError: - return name + PARAM_SEPARATOR + str(value) - - -class HashableBaseModel(BaseModel): - """Pydantic BaseModel that's immutable and can be hashed""" - - def __hash__(self) -> int: - return hash(type(self)) + sum(hash(v) for v in self.__dict__.values()) - - class Config: - allow_mutation = False class VariationSpecsError(ValueError): @@ -67,23 +36,20 @@ class Variationer: all_indices: list[list[int]] all_dicts: list[list[dict[str, list]]] - def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]]): + def __init__(self, variables: Iterable[Union[list[MutableMapping], MutableMapping]] = None): self.all_indices = [] self.all_dicts = [] - for i, el in enumerate(variables): - if not isinstance(el, Sequence): - el = [{k: v} for k, v in el.items()] - else: - el = list(el) - self.append(el) + if variables is not None: + for i, el in enumerate(variables): + self.append(el) - def append(self, var_list: list[dict[str, list]]): + def append(self, var_list: Union[list[MutableMapping], MutableMapping]): """append a list of variable parameter sets each call to append creates a new group of parameters Parameters ---------- - var_list : list[dict[str, list]] + var_list : Union[list[MutableMapping], MutableMapping] each dict in the list is treated as an independent parameter this means that if for one dict, len > 1, the lists of possible values must be the same length @@ -100,6 +66,10 @@ class Variationer: VariationSpecsError raised when possible values lists in a same dict are not the same length """ + if not isinstance(var_list, Sequence): + var_list = [{k: v} for k, v in var_list.items()] + else: + var_list = list(var_list) num_vars = [] for d in var_list: values = list(d.values()) @@ -114,30 +84,43 @@ class Variationer: self.all_indices.append(num_vars) self.all_dicts.append(var_list) - def iterate(self, index: int = -1) -> Generator["SimulationDescriptor", None, None]: - if index < 0: - index = len(self.all_indices) + index + 1 - flattened_indices = sum(self.all_indices[:index], []) - index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[:index]]) + def iterate(self, index: int = -1) -> Generator["VariationDescriptor", None, None]: + index = self.__index(index) + flattened_indices = sum(self.all_indices[: index + 1], []) + index_positions = np.cumsum([0] + [len(i) for i in self.all_indices[: index + 1]]) ranges = [range(i) for i in flattened_indices] for r in itertools.product(*ranges): out: list[list[tuple[str, Any]]] = [] + indicies: list[list[int]] = [] for i, (start, end) in enumerate(zip(index_positions[:-1], index_positions[1:])): out.append([]) + indicies.append([]) for value_index, var_d in zip(r[start:end], self.all_dicts[i]): for k, v in var_d.items(): out[-1].append((k, v[value_index])) - yield SimulationDescriptor(raw_descr=out) + indicies[-1].append(value_index) + yield VariationDescriptor(raw_descr=out, index=indicies) + + def __index(self, index: int) -> int: + if index < 0: + index = len(self.all_indices) + index + return index + + def var_num(self, index: int = -1) -> int: + index = self.__index(index) + return max(1, prod(prod(el) for el in self.all_indices[: index + 1])) -class SimulationDescriptor(HashableBaseModel): +class VariationDescriptor(utils.HashableBaseModel): raw_descr: tuple[tuple[tuple[str, Any], ...], ...] + index: tuple[tuple[int, ...], ...] separator: str = "fiber" + _format_registry: dict[str, Callable[..., str]] = {} def __str__(self) -> str: - return self.descriptor(add_identifier=False) + return self.formatted_descriptor(add_identifier=False) - def descriptor(self, add_identifier=False) -> str: + def formatted_descriptor(self, add_identifier=False) -> str: """formats a variable list into a str such that each simulation has a unique directory name. A u_XXX unique identifier and b_XXX (ignoring repeat simulations) branch identifier can added at the beginning. @@ -156,7 +139,7 @@ class SimulationDescriptor(HashableBaseModel): for p_name, p_value in self.flat: ps = p_name.replace("/", "").replace("\\", "").replace(PARAM_SEPARATOR, "") - vs = format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") + vs = self.format_value(p_name, p_value).replace("/", "").replace(PARAM_SEPARATOR, "") str_list.append(ps + PARAM_SEPARATOR + vs) tmp_name = PARAM_SEPARATOR.join(str_list) if not add_identifier: @@ -165,6 +148,34 @@ class SimulationDescriptor(HashableBaseModel): self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name ) + @classmethod + def register_formatter(cls, p_name: str, func: Callable[..., str]): + cls._format_registry[p_name] = func + + def format_value(self, name: str, value) -> str: + if value is True or value is False: + return str(value) + elif isinstance(value, (float, int)): + try: + return self._format_registry[name](value) + except KeyError: + return format(value, ".9g") + elif isinstance(value, (list, tuple, np.ndarray)): + return "-".join([str(v) for v in value]) + elif isinstance(value, str): + p = Path(value) + if p.exists(): + return p.stem + return str(value) + + def __getitem__(self, key) -> "VariationDescriptor": + return VariationDescriptor( + raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator + ) + + def update_config(self, cfg: dict[str, Any]): + return cfg | {k: v for k, v in self.raw_descr[-1]} + @property def flat(self) -> list[tuple[str, Any]]: out = [] @@ -177,17 +188,27 @@ class SimulationDescriptor(HashableBaseModel): @property def branch(self) -> "BranchDescriptor": - return SimulationDescriptor(raw_descr=self.raw_descr, separator=self.separator) + for i in reversed(range(len(self.raw_descr))): + for j in reversed(range(len(self.raw_descr[i]))): + if self.raw_descr[i][j][0] == "num": + del self.raw_descr[i][j] + return VariationDescriptor( + raw_descr=self.raw_descr, index=self.index, separator=self.separator + ) @property def identifier(self) -> str: return "u_" + utils.to_62(hash(str(self.flat))) -class BranchDescriptor(SimulationDescriptor): +class BranchDescriptor(VariationDescriptor): + __ids: dict[int, int] = {} + @property def identifier(self) -> str: - return "b_" + utils.to_62(hash(str(self.flat))) + branch_id = hash(str(self.flat)) + self.__ids.setdefault(branch_id, len(self.__ids)) + return str(self.__ids[branch_id]) @validator("raw_descr") def validate_raw_descr(cls, v):