fix same variable keys with diff vals in seq
This commit is contained in:
@@ -3,6 +3,7 @@ from collections.abc import Mapping
|
|||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterator, List, Tuple, Union
|
from typing import Any, Dict, Iterator, List, Tuple, Union
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi
|
from numpy import pi
|
||||||
@@ -102,9 +103,9 @@ class Params(BareParams):
|
|||||||
self.energy,
|
self.energy,
|
||||||
self.soliton_num,
|
self.soliton_num,
|
||||||
self.gamma,
|
self.gamma,
|
||||||
self.beta,
|
self.beta[0],
|
||||||
)
|
)
|
||||||
logger.info(f"computed initial N = {self['soliton_num']:.3g}")
|
logger.info(f"computed initial N = {self.soliton_num:.3g}")
|
||||||
|
|
||||||
self.L_D = self.t0 ** 2 / abs(self.beta[0])
|
self.L_D = self.t0 ** 2 / abs(self.beta[0])
|
||||||
self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf
|
self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf
|
||||||
@@ -309,9 +310,7 @@ class ParamSequence:
|
|||||||
self.name = self.config.name
|
self.name = self.config.name
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
|
||||||
self.num_sim, self.num_variable = count_variations(self.config)
|
self.update_num_sim(count_variations(self.config))
|
||||||
self.num_steps = self.num_sim * self.config.z_num
|
|
||||||
self.single_sim = self.num_sim == 1
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||||
"""iterates through all possible parameters, yielding a config as well as a flattened
|
"""iterates through all possible parameters, yielding a config as well as a flattened
|
||||||
@@ -325,6 +324,11 @@ class ParamSequence:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"dispatcher generated from config {self.name}"
|
return f"dispatcher generated from config {self.name}"
|
||||||
|
|
||||||
|
def update_num_sim(self, num_sim):
|
||||||
|
self.num_sim = num_sim
|
||||||
|
self.num_steps = self.num_sim * self.config.z_num
|
||||||
|
self.single_sim = self.num_sim == 1
|
||||||
|
|
||||||
|
|
||||||
class ContinuationParamSequence(ParamSequence):
|
class ContinuationParamSequence(ParamSequence):
|
||||||
def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]):
|
def __init__(self, prev_sim_dir: os.PathLike, new_config_dict: Dict[str, Any]):
|
||||||
@@ -348,18 +352,30 @@ class ContinuationParamSequence(ParamSequence):
|
|||||||
for variable_list, _ in required_simulations(init_config)
|
for variable_list, _ in required_simulations(init_config)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
new_variable_keys = set(new_config_dict.get("variable", {}).keys())
|
||||||
new_config = utils.override_config(new_config_dict, init_config)
|
new_config = utils.override_config(new_config_dict, init_config)
|
||||||
super().__init__(new_config)
|
super().__init__(new_config)
|
||||||
|
additional_sims_factor = int(
|
||||||
|
np.prod(
|
||||||
|
[
|
||||||
|
len(init_config.variable[k])
|
||||||
|
for k in (new_variable_keys & init_config.variable.keys())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.update_num_sim(self.num_sim * additional_sims_factor)
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||||
"""iterates through all possible parameters, yielding a config as well as a flattened
|
"""iterates through all possible parameters, yielding a config as well as a flattened
|
||||||
computed parameters set each time"""
|
computed parameters set each time"""
|
||||||
for variable_list, bare_params in required_simulations(self.config):
|
for variable_list, bare_params in required_simulations(self.config):
|
||||||
prev_data_dir = self.find_prev_data_dir(variable_list).resolve()
|
variable_list.insert(1, ("prev_data_dir", None))
|
||||||
bare_params.prev_data_dir = str(prev_data_dir)
|
for prev_data_dir in self.find_prev_data_dirs(variable_list):
|
||||||
yield variable_list, Params.from_bare(bare_params)
|
variable_list[1] = ("prev_data_dir", str(prev_data_dir.name))
|
||||||
|
bare_params.prev_data_dir = str(prev_data_dir.resolve())
|
||||||
|
yield variable_list, Params.from_bare(bare_params)
|
||||||
|
|
||||||
def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
|
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
|
||||||
"""finds the previous simulation data that this new config should start from
|
"""finds the previous simulation data that this new config should start from
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -377,14 +393,16 @@ class ContinuationParamSequence(ParamSequence):
|
|||||||
ValueError
|
ValueError
|
||||||
no data folder found
|
no data folder found
|
||||||
"""
|
"""
|
||||||
to_test = set(new_variable_list[1:])
|
new_set = set(new_variable_list[1:])
|
||||||
for old_v_list, path in self.prev_variable_lists:
|
path_dic = defaultdict(list)
|
||||||
if to_test.issuperset(old_v_list):
|
max_in_common = 0
|
||||||
return path
|
for stored_set, path in self.prev_variable_lists:
|
||||||
|
in_common = stored_set & new_set
|
||||||
|
num_in_common = len(in_common)
|
||||||
|
max_in_common = max(num_in_common, max_in_common)
|
||||||
|
path_dic[num_in_common].append(path)
|
||||||
|
|
||||||
raise ValueError(
|
return path_dic[max_in_common]
|
||||||
f"cannot find a previous data folder for {new_variable_list} in {self.prev_sim_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RecoveryParamSequence(ParamSequence):
|
class RecoveryParamSequence(ParamSequence):
|
||||||
|
|||||||
@@ -221,7 +221,10 @@ def setup_custom_field(params: BareParams) -> bool:
|
|||||||
bool
|
bool
|
||||||
True if the field has been modified
|
True if the field has been modified
|
||||||
"""
|
"""
|
||||||
field_0 = width = peak_power = energy = None
|
field_0 = params.field_0
|
||||||
|
width = params.width
|
||||||
|
peak_power = params.peak_power
|
||||||
|
energy = params.energy
|
||||||
|
|
||||||
did_set = True
|
did_set = True
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,8 @@ def create_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--environment-setup",
|
"--environment-setup",
|
||||||
required=False,
|
required=False,
|
||||||
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_PBAR_POLICY=file",
|
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
|
||||||
|
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file",
|
||||||
help="commands to run to setup the environement (default : activate the sc environment with conda)",
|
help="commands to run to setup the environement (default : activate the sc environment with conda)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -179,12 +179,11 @@ def progress_worker(
|
|||||||
def count_variations(config: BareConfig) -> Tuple[int, int]:
|
def count_variations(config: BareConfig) -> Tuple[int, int]:
|
||||||
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
||||||
variable_params_num is the number of distinct parameters that will vary."""
|
variable_params_num is the number of distinct parameters that will vary."""
|
||||||
variable_params_num = len(config.variable)
|
|
||||||
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
||||||
return sim_num, variable_params_num
|
return sim_num
|
||||||
|
|
||||||
|
|
||||||
def format_variable_list(l: List[tuple]):
|
def format_variable_list(l: List[Tuple[str, Any]]):
|
||||||
joints = 2 * PARAM_SEPARATOR
|
joints = 2 * PARAM_SEPARATOR
|
||||||
str_list = []
|
str_list = []
|
||||||
for p_name, p_value in l:
|
for p_name, p_value in l:
|
||||||
|
|||||||
Reference in New Issue
Block a user