fix same variable keys with diff vals in seq

This commit is contained in:
Benoît Sierro
2021-06-14 11:48:45 +02:00
parent 7e2ba74520
commit 0da7561c55
4 changed files with 42 additions and 21 deletions

View File

@@ -3,6 +3,7 @@ from collections.abc import Mapping
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Tuple, Union
from collections import defaultdict
import numpy as np
from numpy import pi
@@ -102,9 +103,9 @@ class Params(BareParams):
self.energy,
self.soliton_num,
self.gamma,
self.beta,
self.beta[0],
)
logger.info(f"computed initial N = {self['soliton_num']:.3g}")
logger.info(f"computed initial N = {self.soliton_num:.3g}")
self.L_D = self.t0 ** 2 / abs(self.beta[0])
self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf
@@ -309,9 +310,7 @@ class ParamSequence:
self.name = self.config.name
self.logger = get_logger(__name__)
self.num_sim, self.num_variable = count_variations(self.config)
self.num_steps = self.num_sim * self.config.z_num
self.single_sim = self.num_sim == 1
self.update_num_sim(count_variations(self.config))
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
"""iterates through all possible parameters, yielding a config as well as a flattened
@@ -325,6 +324,11 @@ class ParamSequence:
def __repr__(self) -> str:
return f"dispatcher generated from config {self.name}"
def update_num_sim(self, num_sim):
self.num_sim = num_sim
self.num_steps = self.num_sim * self.config.z_num
self.single_sim = self.num_sim == 1
class ContinuationParamSequence(ParamSequence):
def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]):
@@ -348,18 +352,30 @@ class ContinuationParamSequence(ParamSequence):
for variable_list, _ in required_simulations(init_config)
]
new_variable_keys = set(new_config_dict.get("variable", {}).keys())
new_config = utils.override_config(new_config_dict, init_config)
super().__init__(new_config)
additional_sims_factor = int(
np.prod(
[
len(init_config.variable[k])
for k in (new_variable_keys & init_config.variable.keys())
]
)
)
self.update_num_sim(self.num_sim * additional_sims_factor)
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
"""iterates through all possible parameters, yielding a config as well as a flattened
computed parameters set each time"""
for variable_list, bare_params in required_simulations(self.config):
prev_data_dir = self.find_prev_data_dir(variable_list).resolve()
bare_params.prev_data_dir = str(prev_data_dir)
variable_list.insert(1, ("prev_data_dir", None))
for prev_data_dir in self.find_prev_data_dirs(variable_list):
variable_list[1] = ("prev_data_dir", str(prev_data_dir.name))
bare_params.prev_data_dir = str(prev_data_dir.resolve())
yield variable_list, Params.from_bare(bare_params)
def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
"""finds the previous simulation data that this new config should start from
Parameters
@@ -377,14 +393,16 @@ class ContinuationParamSequence(ParamSequence):
ValueError
no data folder found
"""
to_test = set(new_variable_list[1:])
for old_v_list, path in self.prev_variable_lists:
if to_test.issuperset(old_v_list):
return path
new_set = set(new_variable_list[1:])
path_dic = defaultdict(list)
max_in_common = 0
for stored_set, path in self.prev_variable_lists:
in_common = stored_set & new_set
num_in_common = len(in_common)
max_in_common = max(num_in_common, max_in_common)
path_dic[num_in_common].append(path)
raise ValueError(
f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}"
)
return path_dic[max_in_common]
class RecoveryParamSequence(ParamSequence):

View File

@@ -221,7 +221,10 @@ def setup_custom_field(params: BareParams) -> bool:
bool
True if the field has been modified
"""
field_0 = width = peak_power = energy = None
field_0 = params.field_0
width = params.width
peak_power = params.peak_power
energy = params.energy
did_set = True

View File

@@ -86,7 +86,8 @@ def create_parser():
parser.add_argument(
"--environment-setup",
required=False,
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_PBAR_POLICY=file",
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file",
help="commands to run to setup the environement (default : activate the sc environment with conda)",
)
parser.add_argument(

View File

@@ -179,12 +179,11 @@ def progress_worker(
def count_variations(config: BareConfig) -> Tuple[int, int]:
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
variable_params_num is the number of distinct parameters that will vary."""
variable_params_num = len(config.variable)
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
return sim_num, variable_params_num
return sim_num
def format_variable_list(l: List[tuple]):
def format_variable_list(l: List[Tuple[str, Any]]):
joints = 2 * PARAM_SEPARATOR
str_list = []
for p_name, p_value in l: