sim count fix
This commit is contained in:
@@ -7,6 +7,7 @@ from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
from numpy.core.fromnumeric import var
|
||||
|
||||
from . import io, utils
|
||||
from .defaults import default_parameters
|
||||
@@ -14,7 +15,7 @@ from .errors import *
|
||||
from .logger import get_logger
|
||||
from .math import abs2, power_fact
|
||||
from .physics import fiber, pulse, units
|
||||
from .utils import count_variations, override_config, required_simulations
|
||||
from .utils import count_variations, override_config, required_simulations, variable_iterator
|
||||
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
||||
|
||||
|
||||
@@ -491,7 +492,7 @@ class RecoveryParamSequence(ParamSequence):
|
||||
return path_dic[max_in_common]
|
||||
|
||||
|
||||
def validate_config_sequence(*configs: os.PathLike) -> Config:
|
||||
def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
|
||||
"""validates a sequence of configs where all but the first one may have
|
||||
parameters missing
|
||||
|
||||
@@ -506,12 +507,14 @@ def validate_config_sequence(*configs: os.PathLike) -> Config:
|
||||
the final config as would be simulated, but of course missing input fields in the middle
|
||||
"""
|
||||
previous = None
|
||||
variables = set()
|
||||
for config in configs:
|
||||
if (p := Path(config)).is_dir():
|
||||
config = p / "initial_config.toml"
|
||||
dico = io.load_toml(config)
|
||||
previous = Config.from_bare(override_config(dico, previous))
|
||||
return previous
|
||||
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
|
||||
return previous, np.product([len(v) for k, v in variables])
|
||||
|
||||
|
||||
def wspace(t, t_num=0):
|
||||
|
||||
@@ -124,7 +124,7 @@ def main():
|
||||
config_paths = args.configs
|
||||
final_config = validate_config_sequence(*config_paths)
|
||||
|
||||
sim_num, _ = count_variations(final_config)
|
||||
sim_num = count_variations(final_config)
|
||||
|
||||
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ def progress_worker(
|
||||
pbars[0].update()
|
||||
|
||||
|
||||
def count_variations(config: BareConfig) -> Tuple[int, int]:
|
||||
def count_variations(config: BareConfig) -> int:
|
||||
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
||||
variable_params_num is the number of distinct parameters that will vary."""
|
||||
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
||||
|
||||
Reference in New Issue
Block a user