diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 9044176..5c54591 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -411,6 +411,27 @@ class RecoveryParamSequence(ParamSequence): self.id = task_id self.num_steps = 0 + self.prev_sim_dir = None + if self.config.prev_sim_dir is not None: + self.prev_sim_dir = Path(self.config.prev_sim_dir) + init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") + self.prev_variable_lists = [ + ( + set(variable_list[1:]), + self.prev_sim_dir / utils.format_variable_list(variable_list), + ) + for variable_list, _ in required_simulations(init_config) + ] + additional_sims_factor = int( + np.prod( + [ + len(init_config.variable[k]) + for k in (self.config.variable.keys() & init_config.variable.keys()) + if init_config.variable[k] != self.config.variable[k] + ] + ) + ) + self.update_num_sim(self.num_sim * additional_sims_factor) not_started = self.num_sim sub_folders = io.get_data_dirs(io.get_sim_dir(self.id)) @@ -426,18 +447,6 @@ class RecoveryParamSequence(ParamSequence): self.num_steps += not_started * self.config.z_num self.single_sim = self.num_sim == 1 - self.prev_sim_dir = None - if self.config.prev_sim_dir is not None: - self.prev_sim_dir = Path(self.config.prev_sim_dir) - init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") - self.prev_variable_lists = [ - ( - set(variable_list[1:]), - self.prev_sim_dir / utils.format_variable_list(variable_list), - ) - for variable_list, _ in required_simulations(init_config) - ] - def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: for variable_list, bare_params in required_simulations(self.config): @@ -452,7 +461,7 @@ class RecoveryParamSequence(ParamSequence): else: continue - def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path: + def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]: """finds the previous simulation data that this new config should start from Parameters @@ -470,16 +479,16 @@ class RecoveryParamSequence(ParamSequence): ValueError no data folder found """ - if self.prev_sim_dir is None: - return None - to_test = set(new_variable_list[1:]) - for old_v_list, path in self.prev_variable_lists: - if to_test.issuperset(old_v_list): - return path + new_set = set(new_variable_list[1:]) + path_dic = defaultdict(list) + max_in_common = 0 + for stored_set, path in self.prev_variable_lists: + in_common = stored_set & new_set + num_in_common = len(in_common) + max_in_common = max(num_in_common, max_in_common) + path_dic[num_in_common].append(path) - raise ValueError( - f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}" - ) + return path_dic[max_in_common] def validate_config_sequence(*configs: os.PathLike) -> Config: