From 11345242e2751c6a4b7278c30ae4a974f5fca72e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 14 Jun 2021 12:23:30 +0200 Subject: [PATCH] sim count fix --- src/scgenerator/initialize.py | 9 ++++++--- src/scgenerator/scripts/slurm_submit.py | 2 +- src/scgenerator/utils/__init__.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 5c54591..6a7f3ab 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -7,6 +7,7 @@ from collections import defaultdict import numpy as np from numpy import pi +from numpy.core.fromnumeric import var from . import io, utils from .defaults import default_parameters @@ -14,7 +15,7 @@ from .errors import * from .logger import get_logger from .math import abs2, power_fact from .physics import fiber, pulse, units -from .utils import count_variations, override_config, required_simulations +from .utils import count_variations, override_config, required_simulations, variable_iterator from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters @@ -491,7 +492,7 @@ class RecoveryParamSequence(ParamSequence): return path_dic[max_in_common] -def validate_config_sequence(*configs: os.PathLike) -> Config: +def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]: """validates a sequence of configs where all but the first one may have parameters missing @@ -506,12 +507,14 @@ def validate_config_sequence(*configs: os.PathLike) -> Config: the final config as would be simulated, but of course missing input fields in the middle """ previous = None + variables = set() for config in configs: if (p := Path(config)).is_dir(): config = p / "initial_config.toml" dico = io.load_toml(config) previous = Config.from_bare(override_config(dico, previous)) - return previous + variables |= {(k, tuple(v)) for k, v in previous.variable.items()} + return previous, np.product([len(v) for k, v in variables]) def wspace(t, t_num=0): diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 783fcce..6a65b7d 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -124,7 +124,7 @@ def main(): config_paths = args.configs final_config = validate_config_sequence(*config_paths) - sim_num, _ = count_variations(final_config) + sim_num = count_variations(final_config) args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 6730014..e71bb14 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -176,7 +176,7 @@ def progress_worker( pbars[0].update() -def count_variations(config: BareConfig) -> Tuple[int, int]: +def count_variations(config: BareConfig) -> int: """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and variable_params_num is the number of distinct parameters that will vary.""" sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat