sim count fix

This commit is contained in:
Benoît Sierro
2021-06-14 12:23:30 +02:00
parent c4d55d18bf
commit 11345242e2
3 changed files with 8 additions and 5 deletions

View File

@@ -7,6 +7,7 @@ from collections import defaultdict
import numpy as np
from numpy import pi
from numpy.core.fromnumeric import var
from . import io, utils
from .defaults import default_parameters
@@ -14,7 +15,7 @@ from .errors import *
from .logger import get_logger
from .math import abs2, power_fact
from .physics import fiber, pulse, units
from .utils import count_variations, override_config, required_simulations
from .utils import count_variations, override_config, required_simulations, variable_iterator
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
@@ -491,7 +492,7 @@ class RecoveryParamSequence(ParamSequence):
return path_dic[max_in_common]
def validate_config_sequence(*configs: os.PathLike) -> Config:
def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
"""validates a sequence of configs where all but the first one may have
parameters missing
@@ -506,12 +507,14 @@ def validate_config_sequence(*configs: os.PathLike) -> Config:
the final config as would be simulated, but of course missing input fields in the middle
"""
previous = None
variables = set()
for config in configs:
if (p := Path(config)).is_dir():
config = p / "initial_config.toml"
dico = io.load_toml(config)
previous = Config.from_bare(override_config(dico, previous))
return previous
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
return previous, np.product([len(v) for k, v in variables])
def wspace(t, t_num=0):

View File

@@ -124,7 +124,7 @@ def main():
config_paths = args.configs
final_config = validate_config_sequence(*config_paths)
sim_num, _ = count_variations(final_config)
sim_num = count_variations(final_config)
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)

View File

@@ -176,7 +176,7 @@ def progress_worker(
pbars[0].update()
def count_variations(config: BareConfig) -> Tuple[int, int]:
def count_variations(config: BareConfig) -> int:
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
variable_params_num is the number of distinct parameters that will vary."""
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat