added sequence fix to recovery
This commit is contained in:
@@ -411,6 +411,27 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
self.id = task_id
|
self.id = task_id
|
||||||
self.num_steps = 0
|
self.num_steps = 0
|
||||||
|
|
||||||
|
self.prev_sim_dir = None
|
||||||
|
if self.config.prev_sim_dir is not None:
|
||||||
|
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||||
|
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml")
|
||||||
|
self.prev_variable_lists = [
|
||||||
|
(
|
||||||
|
set(variable_list[1:]),
|
||||||
|
self.prev_sim_dir / utils.format_variable_list(variable_list),
|
||||||
|
)
|
||||||
|
for variable_list, _ in required_simulations(init_config)
|
||||||
|
]
|
||||||
|
additional_sims_factor = int(
|
||||||
|
np.prod(
|
||||||
|
[
|
||||||
|
len(init_config.variable[k])
|
||||||
|
for k in (self.config.variable.keys() & init_config.variable.keys())
|
||||||
|
if init_config.variable[k] != self.config.variable[k]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.update_num_sim(self.num_sim * additional_sims_factor)
|
||||||
not_started = self.num_sim
|
not_started = self.num_sim
|
||||||
sub_folders = io.get_data_dirs(io.get_sim_dir(self.id))
|
sub_folders = io.get_data_dirs(io.get_sim_dir(self.id))
|
||||||
|
|
||||||
@@ -426,18 +447,6 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
self.num_steps += not_started * self.config.z_num
|
self.num_steps += not_started * self.config.z_num
|
||||||
self.single_sim = self.num_sim == 1
|
self.single_sim = self.num_sim == 1
|
||||||
|
|
||||||
self.prev_sim_dir = None
|
|
||||||
if self.config.prev_sim_dir is not None:
|
|
||||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
|
||||||
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml")
|
|
||||||
self.prev_variable_lists = [
|
|
||||||
(
|
|
||||||
set(variable_list[1:]),
|
|
||||||
self.prev_sim_dir / utils.format_variable_list(variable_list),
|
|
||||||
)
|
|
||||||
for variable_list, _ in required_simulations(init_config)
|
|
||||||
]
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||||
for variable_list, bare_params in required_simulations(self.config):
|
for variable_list, bare_params in required_simulations(self.config):
|
||||||
|
|
||||||
@@ -452,7 +461,7 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
|
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
|
||||||
"""finds the previous simulation data that this new config should start from
|
"""finds the previous simulation data that this new config should start from
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -470,16 +479,16 @@ class RecoveryParamSequence(ParamSequence):
|
|||||||
ValueError
|
ValueError
|
||||||
no data folder found
|
no data folder found
|
||||||
"""
|
"""
|
||||||
if self.prev_sim_dir is None:
|
new_set = set(new_variable_list[1:])
|
||||||
return None
|
path_dic = defaultdict(list)
|
||||||
to_test = set(new_variable_list[1:])
|
max_in_common = 0
|
||||||
for old_v_list, path in self.prev_variable_lists:
|
for stored_set, path in self.prev_variable_lists:
|
||||||
if to_test.issuperset(old_v_list):
|
in_common = stored_set & new_set
|
||||||
return path
|
num_in_common = len(in_common)
|
||||||
|
max_in_common = max(num_in_common, max_in_common)
|
||||||
|
path_dic[num_in_common].append(path)
|
||||||
|
|
||||||
raise ValueError(
|
return path_dic[max_in_common]
|
||||||
f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config_sequence(*configs: os.PathLike) -> Config:
|
def validate_config_sequence(*configs: os.PathLike) -> Config:
|
||||||
|
|||||||
Reference in New Issue
Block a user