added sequence fix to recovery
This commit is contained in:
@@ -411,6 +411,27 @@ class RecoveryParamSequence(ParamSequence):
|
||||
self.id = task_id
|
||||
self.num_steps = 0
|
||||
|
||||
self.prev_sim_dir = None
|
||||
if self.config.prev_sim_dir is not None:
|
||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml")
|
||||
self.prev_variable_lists = [
|
||||
(
|
||||
set(variable_list[1:]),
|
||||
self.prev_sim_dir / utils.format_variable_list(variable_list),
|
||||
)
|
||||
for variable_list, _ in required_simulations(init_config)
|
||||
]
|
||||
additional_sims_factor = int(
|
||||
np.prod(
|
||||
[
|
||||
len(init_config.variable[k])
|
||||
for k in (self.config.variable.keys() & init_config.variable.keys())
|
||||
if init_config.variable[k] != self.config.variable[k]
|
||||
]
|
||||
)
|
||||
)
|
||||
self.update_num_sim(self.num_sim * additional_sims_factor)
|
||||
not_started = self.num_sim
|
||||
sub_folders = io.get_data_dirs(io.get_sim_dir(self.id))
|
||||
|
||||
@@ -426,18 +447,6 @@ class RecoveryParamSequence(ParamSequence):
|
||||
self.num_steps += not_started * self.config.z_num
|
||||
self.single_sim = self.num_sim == 1
|
||||
|
||||
self.prev_sim_dir = None
|
||||
if self.config.prev_sim_dir is not None:
|
||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml")
|
||||
self.prev_variable_lists = [
|
||||
(
|
||||
set(variable_list[1:]),
|
||||
self.prev_sim_dir / utils.format_variable_list(variable_list),
|
||||
)
|
||||
for variable_list, _ in required_simulations(init_config)
|
||||
]
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||
for variable_list, bare_params in required_simulations(self.config):
|
||||
|
||||
@@ -452,7 +461,7 @@ class RecoveryParamSequence(ParamSequence):
|
||||
else:
|
||||
continue
|
||||
|
||||
def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
|
||||
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
|
||||
"""finds the previous simulation data that this new config should start from
|
||||
|
||||
Parameters
|
||||
@@ -470,16 +479,16 @@ class RecoveryParamSequence(ParamSequence):
|
||||
ValueError
|
||||
no data folder found
|
||||
"""
|
||||
if self.prev_sim_dir is None:
|
||||
return None
|
||||
to_test = set(new_variable_list[1:])
|
||||
for old_v_list, path in self.prev_variable_lists:
|
||||
if to_test.issuperset(old_v_list):
|
||||
return path
|
||||
new_set = set(new_variable_list[1:])
|
||||
path_dic = defaultdict(list)
|
||||
max_in_common = 0
|
||||
for stored_set, path in self.prev_variable_lists:
|
||||
in_common = stored_set & new_set
|
||||
num_in_common = len(in_common)
|
||||
max_in_common = max(num_in_common, max_in_common)
|
||||
path_dic[num_in_common].append(path)
|
||||
|
||||
raise ValueError(
|
||||
f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}"
|
||||
)
|
||||
return path_dic[max_in_common]
|
||||
|
||||
|
||||
def validate_config_sequence(*configs: os.PathLike) -> Config:
|
||||
|
||||
Reference in New Issue
Block a user