sim count fix
This commit is contained in:
@@ -7,6 +7,7 @@ from collections import defaultdict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi
|
from numpy import pi
|
||||||
|
from numpy.core.fromnumeric import var
|
||||||
|
|
||||||
from . import io, utils
|
from . import io, utils
|
||||||
from .defaults import default_parameters
|
from .defaults import default_parameters
|
||||||
@@ -14,7 +15,7 @@ from .errors import *
|
|||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .math import abs2, power_fact
|
from .math import abs2, power_fact
|
||||||
from .physics import fiber, pulse, units
|
from .physics import fiber, pulse, units
|
||||||
from .utils import count_variations, override_config, required_simulations
|
from .utils import count_variations, override_config, required_simulations, variable_iterator
|
||||||
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
||||||
|
|
||||||
|
|
||||||
@@ -491,7 +492,7 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
return path_dic[max_in_common]
|
return path_dic[max_in_common]
|
||||||
|
|
||||||
|
|
||||||
def validate_config_sequence(*configs: os.PathLike) -> Config:
|
def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
|
||||||
"""validates a sequence of configs where all but the first one may have
|
"""validates a sequence of configs where all but the first one may have
|
||||||
parameters missing
|
parameters missing
|
||||||
|
|
||||||
@@ -506,12 +507,14 @@ def validate_config_sequence(*configs: os.PathLike) -> Config:
|
|||||||
the final config as would be simulated, but of course missing input fields in the middle
|
the final config as would be simulated, but of course missing input fields in the middle
|
||||||
"""
|
"""
|
||||||
previous = None
|
previous = None
|
||||||
|
variables = set()
|
||||||
for config in configs:
|
for config in configs:
|
||||||
if (p := Path(config)).is_dir():
|
if (p := Path(config)).is_dir():
|
||||||
config = p / "initial_config.toml"
|
config = p / "initial_config.toml"
|
||||||
dico = io.load_toml(config)
|
dico = io.load_toml(config)
|
||||||
previous = Config.from_bare(override_config(dico, previous))
|
previous = Config.from_bare(override_config(dico, previous))
|
||||||
return previous
|
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
|
||||||
|
return previous, np.product([len(v) for k, v in variables])
|
||||||
|
|
||||||
|
|
||||||
def wspace(t, t_num=0):
|
def wspace(t, t_num=0):
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ def main():
|
|||||||
config_paths = args.configs
|
config_paths = args.configs
|
||||||
final_config = validate_config_sequence(*config_paths)
|
final_config = validate_config_sequence(*config_paths)
|
||||||
|
|
||||||
sim_num, _ = count_variations(final_config)
|
sim_num = count_variations(final_config)
|
||||||
|
|
||||||
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
|
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
|
||||||
|
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ def progress_worker(
|
|||||||
pbars[0].update()
|
pbars[0].update()
|
||||||
|
|
||||||
|
|
||||||
def count_variations(config: BareConfig) -> Tuple[int, int]:
|
def count_variations(config: BareConfig) -> int:
|
||||||
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
||||||
variable_params_num is the number of distinct parameters that will vary."""
|
variable_params_num is the number of distinct parameters that will vary."""
|
||||||
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
||||||
|
|||||||
Reference in New Issue
Block a user