diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 5716394..9044176 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from dataclasses import asdict, dataclass from pathlib import Path from typing import Any, Dict, Iterator, List, Tuple, Union +from collections import defaultdict import numpy as np from numpy import pi @@ -102,9 +103,9 @@ class Params(BareParams): self.energy, self.soliton_num, self.gamma, - self.beta, + self.beta[0], ) - logger.info(f"computed initial N = {self['soliton_num']:.3g}") + logger.info(f"computed initial N = {self.soliton_num:.3g}") self.L_D = self.t0 ** 2 / abs(self.beta[0]) self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf @@ -309,9 +310,7 @@ class ParamSequence: self.name = self.config.name self.logger = get_logger(__name__) - self.num_sim, self.num_variable = count_variations(self.config) - self.num_steps = self.num_sim * self.config.z_num - self.single_sim = self.num_sim == 1 + self.update_num_sim(count_variations(self.config)) def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened @@ -325,6 +324,11 @@ class ParamSequence: def __repr__(self) -> str: return f"dispatcher generated from config {self.name}" + def update_num_sim(self, num_sim): + self.num_sim = num_sim + self.num_steps = self.num_sim * self.config.z_num + self.single_sim = self.num_sim == 1 + class ContinuationParamSequence(ParamSequence): def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]): @@ -348,18 +352,30 @@ class ContinuationParamSequence(ParamSequence): for variable_list, _ in required_simulations(init_config) ] + new_variable_keys = set(new_config_dict.get("variable", {}).keys()) new_config = utils.override_config(new_config_dict, init_config) super().__init__(new_config) + additional_sims_factor = int( + np.prod( + [ + len(init_config.variable[k]) + for k in (new_variable_keys & init_config.variable.keys()) + ] + ) + ) + self.update_num_sim(self.num_sim * additional_sims_factor) def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: """iterates through all possible parameters, yielding a config as well as a flattened computed parameters set each time""" for variable_list, bare_params in required_simulations(self.config): - prev_data_dir = self.find_prev_data_dir(variable_list).resolve() - bare_params.prev_data_dir = str(prev_data_dir) - yield variable_list, Params.from_bare(bare_params) + variable_list.insert(1, ("prev_data_dir", None)) + for prev_data_dir in self.find_prev_data_dirs(variable_list): + variable_list[1] = ("prev_data_dir", str(prev_data_dir.name)) + bare_params.prev_data_dir = str(prev_data_dir.resolve()) + yield variable_list, Params.from_bare(bare_params) - def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path: + def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]: """finds the previous simulation data that this new config should start from Parameters @@ -377,14 +393,16 @@ class ContinuationParamSequence(ParamSequence): ValueError no data folder found """ - to_test = set(new_variable_list[1:]) - for old_v_list, path in self.prev_variable_lists: - if to_test.issuperset(old_v_list): - return path + new_set = set(new_variable_list[1:]) + path_dic = defaultdict(list) + max_in_common = 0 + for stored_set, path in self.prev_variable_lists: + in_common = stored_set & new_set + num_in_common = len(in_common) + max_in_common = max(num_in_common, max_in_common) + path_dic[num_in_common].append(path) - raise ValueError( - f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}" - ) + return path_dic[max_in_common] class RecoveryParamSequence(ParamSequence): diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 881e3d9..ddff8c6 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -221,7 +221,10 @@ def setup_custom_field(params: BareParams) -> bool: bool True if the field has been modified """ - field_0 = width = peak_power = energy = None + field_0 = params.field_0 + width = params.width + peak_power = params.peak_power + energy = params.energy did_set = True diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index aa1abfe..783fcce 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -86,7 +86,8 @@ def create_parser(): parser.add_argument( "--environment-setup", required=False, - default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_PBAR_POLICY=file", + default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && " + "export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file", help="commands to run to setup the environement (default : activate the sc environment with conda)", ) parser.add_argument( diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 4b3e839..6730014 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -179,12 +179,11 @@ def progress_worker( def count_variations(config: BareConfig) -> Tuple[int, int]: """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and variable_params_num is the number of distinct parameters that will vary.""" - variable_params_num = len(config.variable) sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat - return sim_num, variable_params_num + return sim_num -def format_variable_list(l: List[tuple]): +def format_variable_list(l: List[Tuple[str, Any]]): joints = 2 * PARAM_SEPARATOR str_list = [] for p_name, p_value in l: