A lot of improvements
@@ -13,6 +13,10 @@ spectra, params = load_sim_data("varyTechNoise100kW_sim_data")
 to plot
 plot_results_2D(spectra[0], (600, 1450, nm), params)

+# Environment variables
+
+SCGENERATOR_PBAR_POLICY : "none", "file", "print", "both", optional
+    whether progress should be printed to a file ("file"), to the standard output ("print") or both, default : print

 # Configuration

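Note (a minimal sketch, not part of the commit): the policy is read from the environment at run time, so it can be chosen per invocation, assuming the variable name documented above:

    import os

    # silence the console and write progress to a file instead
    os.environ["SCGENERATOR_PBAR_POLICY"] = "file"

    # or, from a shell, before launching a run:
    #   export SCGENERATOR_PBAR_POLICY=both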
@@ -2,6 +2,5 @@ numpy
 matplotlib
 scipy
 ray
-send2trash
 toml
 tqdm
@@ -1,13 +1,16 @@
 import argparse
 import os
+from pathlib import Path
 import random
-import sys

 import ray

-from scgenerator import initialize
-from ..physics.simulate import run_simulation_sequence, resume_simulations, SequencialSimulations
-from .. import io
+from scgenerator.physics.simulate import (
+    run_simulation_sequence,
+    resume_simulations,
+    SequencialSimulations,
+)
+from scgenerator import io


 def create_parser():
@@ -73,7 +76,11 @@ def run_sim(args):


 def merge(args):
-    io.append_and_merge(args.path, args.output_name)
+    path_trees = io.build_path_trees(Path(args.path))
+    if args.output_name is None:
+        args.output_name = path_trees[-1][0][0].parent.name + " merged"
+    io.merge(args.output_name, path_trees)


 def prep_ray(args):
@@ -98,7 +105,7 @@ def resume_sim(args):
     sim = resume_simulations(args.sim_dir, method=method)
     sim.run()
     run_simulation_sequence(
-        *args.configs, method=method, prev_sim_dir=sim.data_folder, final_name=args.output_name
+        *args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
     )

@@ -246,7 +246,12 @@ valid_variable = dict(
 )

 ENVIRON_KEY_BASE = "SCGENERATOR_"
-HUSH_PROGRESS = ENVIRON_KEY_BASE + "HUSH_PROGRESS"
+PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
 TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
 PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
 PARAM_SEPARATOR = " "
+
+SPEC1_FN = "spectrum_{}.npy"
+SPECN_FN = "spectra_{}.npy"
+Z_FN = "z.npy"
+PARAM_FN = "params.toml"
src/scgenerator/env.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import os
+from pathlib import Path
+from typing import Dict, List, Literal, Optional
+
+from .const import ENVIRON_KEY_BASE, PBAR_POLICY, TMP_FOLDER_KEY_BASE
+
+
+def data_folder(task_id: int) -> Optional[Path]:
+    idstr = str(int(task_id))
+    tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
+    return tmp
+
+
+def all_environ() -> Dict[str, str]:
+    """returns a dictionary of all environment variables set by any instance of scgenerator"""
+    d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
+    return d
+
+
+def pbar_policy() -> List[Literal["print", "file"]]:
+    policy = os.getenv(PBAR_POLICY)
+    if policy == "print" or policy is None:
+        return ["print"]
+    elif policy == "file":
+        return ["file"]
+    elif policy == "both":
+        return ["file", "print"]
+    else:
+        return []
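Note (illustration only, not part of the commit; it assumes the package layout shown above): pbar_policy treats an unset variable as "print" and any unrecognized value, e.g. "none", as "report nowhere":

    import os
    from scgenerator import env

    os.environ["SCGENERATOR_PBAR_POLICY"] = "both"
    print(env.pbar_policy())  # ['file', 'print']

    os.environ["SCGENERATOR_PBAR_POLICY"] = "none"
    print(env.pbar_policy())  # [] -- progress is reported nowhere

    del os.environ["SCGENERATOR_PBAR_POLICY"]
    print(env.pbar_policy())  # ['print'] -- the default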
@@ -76,8 +76,8 @@ class ContinuationParamSequence(ParamSequence):
         """iterates through all possible parameters, yielding a config as well as a flattened
         computed parameters set each time"""
         for variable_list, full_config in required_simulations(self.config):
-            prev_data_dir = self.find_prev_data_dir(variable_list)
-            full_config["prev_data_dir"] = str(prev_data_dir.resolve())
+            prev_data_dir = self.find_prev_data_dir(variable_list).resolve()
+            full_config["prev_data_dir"] = str(prev_data_dir)
             yield variable_list, compute_init_parameters(full_config)

     def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
@@ -116,7 +116,7 @@ class RecoveryParamSequence(ParamSequence):

         z_num = config["simulation"]["z_num"]
         started = self.num_sim
-        sub_folders = io.get_data_subfolders(self.id)
+        sub_folders = io.get_data_dirs(io.get_sim_dir(self.id))

         pbar_store = utils.PBars(
             tqdm(
@@ -157,7 +157,7 @@ class RecoveryParamSequence(ParamSequence):
     def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
         for variable_list, params in required_simulations(self.config):

-            data_dir = io.get_data_folder(self.id) / utils.format_variable_list(variable_list)
+            data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list)

             if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0:
                 if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None:
@@ -565,7 +565,8 @@ def _ensure_consistency(config):
 def recover_params(config: Dict[str, Any], data_folder: Path) -> Dict[str, Any]:
     params = compute_init_parameters(config)
     try:
-        prev_params = io.load_toml(data_folder / "params.toml")
+        prev_params = io.load_previous_parameters(data_folder / "params.toml")
+        prev_params = build_sim_grid(prev_params)
     except FileNotFoundError:
         prev_params = {}
     for k, v in prev_params.items():
@@ -602,7 +603,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
         for key, value in config.get(section, {}).items():
             params[key] = value

-    params = _generate_sim_grid(params)
+    params = build_sim_grid(params)

     # Initial field may influence the grid
     if "mean_power" in params:
@@ -789,7 +790,7 @@ def _interp_range(w, upper, lower):
     return interp_range


-def _generate_sim_grid(params):
+def build_sim_grid(params):
     """computes a bunch of values that relate to the simulation grid

     Parameters
@@ -1,30 +1,29 @@
 import os
-import shutil
 from datetime import datetime
-from glob import glob
-from typing import Any, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, Generator, List, Sequence, Tuple
+import shutil

 import numpy as np
 import pkg_resources as pkg
 import toml
-from send2trash import TrashPermissionError, send2trash
-from tqdm import tqdm
 from pathlib import Path
 import itertools

-from . import utils
-from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE
+from . import utils, env
+from .const import (
+    ENVIRON_KEY_BASE,
+    PARAM_SEPARATOR,
+    PBAR_POLICY,
+    TMP_FOLDER_KEY_BASE,
+    SPEC1_FN,
+    SPECN_FN,
+    Z_FN,
+    PARAM_FN,
+)
 from .errors import IncompleteDataFolderError
 from .logger import get_logger

-using_ray = False
-try:
-    import ray
-    from ray.util.queue import Queue
-
-    using_ray = True
-except ModuleNotFoundError:
-    pass
+PathTree = List[Tuple[Path, ...]]


 class Paths:
@@ -81,51 +80,6 @@ class Paths:
         return os.path.join(cls.get("plots"), name)


-class DataBuffer:
-    def __init__(self, task_id):
-        self.logger = get_logger(__name__)
-        self.id = task_id
-        self.queue = Queue()
-
-    def empty(self):
-        num = self.queue.size()
-        if num == 0:
-            return 0
-        self.logger.info(f"buffer length at time of emptying : {num}")
-        while not self.queue.empty():
-            name, identifier, data = self.queue.get()
-            save_data(data, name, self.id, identifier)
-
-        return num
-
-    def append(self, file_name: str, identifier: str, data: np.ndarray):
-        self.queue.put((file_name, identifier, data))
-
-
-# def abspath(rel_path: str):
-#     """returns the complete path with the correct root. In other words, allows to modify absolute paths
-#     in case the process accessing this function is a sub-process started from another device.
-
-#     Parameters
-#     ----------
-#     rel_path : str
-#         relative path
-
-#     Returns
-#     -------
-#     str
-#         absolute path
-#     """
-#     key = utils.formatted_hostname()
-#     prefix = os.getenv(key)
-#     if prefix is None:
-#         p = os.path.abspath(rel_path)
-#     else:
-#         p = os.path.join(prefix, rel_path)
-
-#     return os.path.normpath(p)
-
-
 def conform_toml_path(path: os.PathLike) -> Path:
     path = Path(path)
     if not path.name.lower().endswith(".toml"):
@@ -159,7 +113,7 @@ def serializable(val):
     return out


-def _prepare_for_serialization(dico: Dict[str, Any]):
+def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]:
     """prepares a dictionary for serialization. Some keys may not be preserved
     (dropped due to no conversion available)

@@ -168,7 +122,7 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
     dico : dict
         dictionary
     """
-    forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w"]
+    forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
     types = (np.ndarray, float, int, str, list, tuple, dict)
     out = {}
     for key, value in dico.items():
@@ -177,7 +131,7 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
         if not isinstance(value, types):
             continue
         if isinstance(value, dict):
-            out[key] = _prepare_for_serialization(value)
+            out[key] = prepare_for_serialization(value)
         elif isinstance(value, np.ndarray) and value.dtype == complex:
             continue
         else:
@@ -186,11 +140,11 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
     return out


-def save_parameters(param_dict: Dict[str, Any], task_id: int, data_dir_name: str):
+def save_parameters(param_dict: Dict[str, Any], data_dir: Path) -> Path:
     param = param_dict.copy()
-    file_path = generate_file_path("params.toml", task_id, data_dir_name)
+    file_path = data_dir / "params.toml"

-    param = _prepare_for_serialization(param)
+    param = prepare_for_serialization(param)
     param["datetime"] = datetime.now()

     file_path.parent.mkdir(exist_ok=True)
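Note (hypothetical call site, not part of the commit): the new signature takes the target directory directly instead of resolving it from a task id:

    from pathlib import Path
    from scgenerator import io

    params = {"z_num": 128, "wavelength": 1500e-9}  # invented values
    # writes <data_dir>/params.toml and returns its path
    toml_path = io.save_parameters(params, Path("scgenerator data/id 0 wavelength 1.5e-06 num 0"))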
@@ -202,35 +156,6 @@ def save_parameters(param_dict: Dict[str, Any], task_id: int, data_dir_name: str
     return file_path


-# def save_parameters_old(param_dict, file_name="param"):
-#     """Writes the flattened parameters dictionary specific to a single simulation into a toml file
-
-#     Parameters
-#     ----------
-#     param_dict : dictionary of parameters. Only floats, int and arrays of
-#         non complex values are stored in the json
-#     folder_name : folder where to save the files (relative to cwd)
-#     file_name : name of the readable file.
-#     """
-#     param = param_dict.copy()
-
-#     folder_name, file_name = os.path.split(file_name)
-#     folder_name = "tmp" if folder_name == "" else folder_name
-#     file_name = os.path.splitext(file_name)[0]
-
-#     if not os.path.exists(folder_name):
-#         os.makedirs(folder_name)
-
-#     param = _prepare_for_serialization(param)
-#     param["datetime"] = datetime.now()
-
-#     # save toml of the simulation
-#     with open(os.path.join(folder_name, file_name + ".toml"), "w") as file:
-#         toml.dump(param, file, encoder=toml.TomlNumpyEncoder())
-
-#     return os.path.join(folder_name, file_name)
-
-
 def load_previous_parameters(path: os.PathLike):
     """loads a parameters toml files and converts data to appropriate type

     Parameters
@@ -267,31 +192,21 @@ def load_material_dico(name):
     return toml.loads(Paths.gets("gas"))[name]


-def get_all_environ() -> Dict[str, str]:
-    """returns a dictionary of all environment variables set by any instance of scgenerator"""
-    d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
-    return d
-
-
-def load_single_spectrum(folder: Path, index) -> np.ndarray:
-    return np.load(folder / f"spectra_{index}.npy")
-
-
-def get_data_subfolders(task_id: int) -> List[Path]:
-    """returns a list of relative path/subfolders in the specified directory
+def get_data_dirs(sim_dir: Path) -> List[Path]:
+    """returns a list of absolute paths corresponding to a particular run

     Parameters
     ----------
-    path : str
+    sim_dir : Path
         path to directory containing the initial config file and the spectra sub folders

     Returns
     -------
-    List[str]
+    List[Path]
         paths to sub folders
     """

-    return [p.resolve() for p in get_data_folder(task_id).glob("*") if p.is_dir()]
+    return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()]


 def check_data_integrity(sub_folders: List[Path], init_z_num: int):
@@ -336,8 +251,7 @@ def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
     IncompleteDataFolderError
         raised if init_z_num doesn't match that specified in the individual parameter file
     """
-    params = load_toml(sub_folder / "params.toml")
-    z_num = params["z_num"]
+    z_num = load_toml(sub_folder / "params.toml")["z_num"]
     num_spectra = find_last_spectrum_num(sub_folder) + 1  # because of zero-indexing

     if z_num != init_z_num:
@@ -351,7 +265,7 @@ def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:

 def find_last_spectrum_num(data_dir: Path):
     for num in itertools.count(1):
-        p_to_test = data_dir / f"spectrum_{num}.npy"
+        p_to_test = data_dir / SPEC1_FN.format(num)
         if not p_to_test.is_file() or len(p_to_test.read_bytes()) == 0:
             return num - 1

@@ -359,142 +273,168 @@ def find_last_spectrum_num(data_dir: Path):
 def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
     """return the last spectrum stored in path as well as its id"""
     num = find_last_spectrum_num(data_dir)
-    return num, np.load(data_dir / f"spectrum_{num}.npy")
+    return num, np.load(data_dir / SPEC1_FN.format(num))


-def append_and_merge(final_sim_path: os.PathLike, new_name=None):
-    final_sim_path = Path(final_sim_path).resolve()
-    if new_name is None:
-        new_name = final_sim_path.name + " appended"
-
-    destination_path = final_sim_path.parent / new_name
-    destination_path.mkdir(exist_ok=True)
-
-    sim_paths = list(final_sim_path.glob("id*num*"))
-    pbars = utils.PBars.auto(
-        len(sim_paths),
-        0,
-        head_kwargs=dict(desc="Appending"),
-        worker_kwargs=dict(desc=""),
-    )
-
-    for sim_path in sim_paths:
-        path_tree = [sim_path]
-        sim_name = sim_path.name
-        appended_sim_path = destination_path / sim_name
-        appended_sim_path.mkdir(exist_ok=True)
-
-        while (
-            prev_sim_path := load_toml(path_tree[-1] / "params.toml").get("prev_data_dir")
-        ) is not None:
-            path_tree.append(Path(prev_sim_path).resolve())
-
-        z: List[np.ndarray] = []
-        z_num = 0
-        last_z = 0
-        paths_r = list(reversed(path_tree))
-
-        for path in paths_r:
-            curr_z_num = load_toml(path / "params.toml")["z_num"]
-            for i in range(curr_z_num):
-                shutil.copy(
-                    path / f"spectrum_{i}.npy",
-                    appended_sim_path / f"spectrum_{i + z_num}.npy",
-                )
-            z_arr = np.load(path / "z.npy")
-            z.append(z_arr + last_z)
-            last_z += z_arr[-1]
-            z_num += curr_z_num
-        z_arr = np.concatenate(z)
-        update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr)
-        np.save(appended_sim_path / "z.npy", z_arr)
-        pbars.update(0)
-
-    update_appended_params(
-        final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr
-    )
-    pbars.close()
-    merge(destination_path, delete=True)
-
-
-def update_appended_params(param_path: Path, new_path: Path, z):
+def update_appended_params(source: Path, destination: Path, z: Sequence):
     z_num = len(z)
-    params = load_toml(param_path)
+    params = load_toml(source)
     if "simulation" in params:
         params["simulation"]["z_num"] = z_num
-        params["simulation"]["z_targets"] = z
+        params["fiber"]["length"] = float(z[-1] - z[0])
     else:
         params["z_num"] = z_num
-        params["z_targets"] = z
-    save_toml(new_path, params)
+        params["length"] = float(z[-1] - z[0])
+    save_toml(destination, params)


-def merge(paths: Union[Path, List[Path]], delete=False):
-    if isinstance(paths, Path):
-        paths = [paths]
-    for path in paths:
-        merge_same_simulations(path, delete=delete)
-
-
-def merge_same_simulations(path: Path, delete=True):
-    logger = get_logger(__name__)
-    num_separator = PARAM_SEPARATOR + "num" + PARAM_SEPARATOR
-    sub_folders = [p for p in path.glob("*") if p.is_dir()]
-    config = load_toml(path / "initial_config.toml")
-    repeat = config["simulation"].get("repeat", 1)
-    max_repeat_id = repeat - 1
-    z_num = config["simulation"]["z_num"]
-
-    check_data_integrity(sub_folders, z_num)
-
-    sim_num, param_num = utils.count_variations(config)
-    pbar = utils.PBars.auto(sim_num * z_num, head_kwargs=dict(desc="Merging data"))
-
-    spectra = []
-    for z_id in range(z_num):
-        for variable_and_ind, _ in utils.required_simulations(config):
-            repeat_id = variable_and_ind[-1][1]
-
-            # reset the buffer once we move to a new parameter set
-            if repeat_id == 0:
-                spectra = []
-
-            in_path = path / utils.format_variable_list(variable_and_ind)
-            spectra.append(np.load(in_path / f"spectrum_{z_id}.npy"))
-            pbar.update()
-
-            # write new files only once all those from one parameter set are collected
-            if repeat_id == max_repeat_id:
-                out_path = path / (
-                    utils.format_variable_list(variable_and_ind[1:-1]) + PARAM_SEPARATOR + "merged"
-                )
-                out_path = ensure_folder(out_path, prevent_overwrite=False)
-                spectra = np.array(spectra).reshape(repeat, len(spectra[0]))
-                np.save(out_path / f"spectra_{z_id}.npy", spectra.squeeze())
-
-                # copy other files only once
-                if z_id == 0:
-                    for file_name in ["z.npy", "params.toml"]:
-                        shutil.copy(in_path / file_name, out_path)
-    pbar.close()
-
-    if delete:
-        for sub_folder in sub_folders:
-            try:
-                send2trash(str(sub_folder))
-            except TrashPermissionError:
-                logger.warning(f"could not send send {sub_folder} to trash")
+def build_path_trees(sim_dir: Path) -> List[PathTree]:
+    sim_dir = sim_dir.resolve()
+    path_branches: List[Tuple[Path, ...]] = []
+    to_check = list(sim_dir.glob("id*num*"))
+    pbar = utils.PBars.auto(len(to_check), desc="Building path trees")
+    for branch in map(build_path_branch, to_check):
+        if branch is not None:
+            path_branches.append(branch)
+        pbar.update()
+    pbar.close()
+    path_trees = group_path_branches(path_branches)
+    return path_trees
+
+
+def build_path_branch(data_dir: Path) -> Tuple[Path, ...]:
+    if not data_dir.is_dir():
+        return None
+    path_branch = [data_dir]
+    while (prev_sim_path := load_toml(path_branch[-1] / PARAM_FN).get("prev_data_dir")) is not None:
+        p = Path(prev_sim_path).resolve()
+        if not p.exists():
+            p = Path(*p.parts[-2:]).resolve()
+        path_branch.append(p)
+    return tuple(reversed(path_branch))
+
+
+def group_path_branches(path_branches: List[Tuple[Path, ...]]) -> List[PathTree]:
+    """groups path lists
+
+    [
+        ("a/id 0 wavelength 100 num 0", "b/id 0 wavelength 100 num 0"),
+        ("a/id 2 wavelength 100 num 1", "b/id 2 wavelength 100 num 1"),
+        ("a/id 1 wavelength 200 num 0", "b/id 1 wavelength 200 num 0"),
+        ("a/id 3 wavelength 200 num 1", "b/id 3 wavelength 200 num 1")
+    ]
+    ->
+    [
+        (
+            ("a/id 0 wavelength 100 num 0", "a/id 2 wavelength 100 num 1"),
+            ("b/id 0 wavelength 100 num 0", "b/id 2 wavelength 100 num 1"),
+        )
+        (
+            ("a/id 1 wavelength 200 num 0", "a/id 3 wavelength 200 num 1"),
+            ("b/id 1 wavelength 200 num 0", "b/id 3 wavelength 200 num 1"),
+        )
+    ]
+
+    Parameters
+    ----------
+    path_branches : List[Tuple[Path, ...]]
+        each element of the list is a path to a folder containing data of one simulation
+
+    Returns
+    -------
+    List[PathTree]
+        List of PathTrees to be used in merge
+    """
+    sort_key = lambda el: el[0]
+
+    size = len(path_branches[0])
+    out_trees_map: Dict[str, Dict[int, Dict[int, Path]]] = {}
+    for branch in path_branches:
+        b_id = utils.branch_id(branch)
+        out_trees_map.setdefault(b_id, {i: {} for i in range(size)})
+        for sim_part, data_dir in enumerate(branch):
+            *_, num = data_dir.name.split()
+            out_trees_map[b_id][sim_part][int(num)] = data_dir
+
+    return [
+        tuple(
+            tuple(w for _, w in sorted(v.items(), key=sort_key))
+            for __, v in sorted(d.items(), key=sort_key)
+        )
+        for d in out_trees_map.values()
+    ]
+
+
+def merge_path_tree(path_tree: PathTree, destination: Path):
+    """given a path tree, copies the file into the right location
+
+    Parameters
+    ----------
+    path_tree : PathTree
+        elements of the list returned by group_path_branches
+    destination : Path
+        dir where to save the data
+    """
+    z_arr: List[float] = []
+
+    destination.mkdir(exist_ok=True)
+
+    for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
+        z_arr.append(z)
+        spec_out_name = SPECN_FN.format(i)
+        np.save(destination / spec_out_name, merged_spectra)
+    d = np.diff(z_arr)
+    d[d < 0] = 0
+    z_arr = np.concatenate(([z_arr[0]], np.cumsum(d)))
+    np.save(destination / Z_FN, z_arr)
+    update_appended_params(path_tree[-1][0] / PARAM_FN, destination / PARAM_FN, z_arr)
+
+
+def merge_spectra(
+    path_tree: PathTree,
+) -> Generator[Tuple[float, np.ndarray], None, None]:
+    for same_sim_paths in path_tree:
+        z_arr = np.load(same_sim_paths[0] / Z_FN)
+        for i, z in enumerate(z_arr):
+            spectra: List[np.ndarray] = []
+            for data_dir in same_sim_paths:
+                spec = np.load(data_dir / SPEC1_FN.format(i))
+                spectra.append(spec)
+            yield z, np.atleast_2d(spectra)
+
+
+def merge(destination: os.PathLike, path_trees: List[PathTree] = None):
+
+    destination = ensure_folder(Path(destination))
+
+    for i, sim_dir in enumerate(sim_dirs(path_trees)):
+        shutil.copy(
+            sim_dir / "initial_config.toml",
+            destination / f"initial_config_{i}.toml",
+        )
+
+    pbar = utils.PBars.auto(len(path_trees), desc="Merging")
+    for path_tree in path_trees:
+        iden = PARAM_SEPARATOR.join(path_tree[-1][0].name.split()[2:-2])
+        merge_path_tree(path_tree, destination / iden)
+        pbar.update()
+    pbar.close()
+
+
+def sim_dirs(path_trees: List[PathTree]) -> Generator[Path, None, None]:
+    for p in path_trees[0]:
+        yield p[0].parent


-def get_data_folder(task_id: int, name_if_new: str = "data") -> Path:
+def get_sim_dir(task_id: int, name_if_new: str = "data") -> Path:
     if name_if_new == "":
         name_if_new = "data"
-    idstr = str(int(task_id))
-    tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
+    tmp = env.data_folder(task_id)
     if tmp is None:
         tmp = ensure_folder(Path("scgenerator" + PARAM_SEPARATOR + name_if_new))
-        os.environ[TMP_FOLDER_KEY_BASE + idstr] = str(tmp)
+        os.environ[TMP_FOLDER_KEY_BASE + str(task_id)] = str(tmp)
     tmp = Path(tmp).resolve()
     if not tmp.exists():
         tmp.mkdir()
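Note (numeric illustration, values invented): merge_path_tree rebuilds the merged z axis from increments because each fiber segment stores a z array that restarts near zero; clamping the negative jumps keeps the axis monotonic:

    import numpy as np

    # z values collected across two consecutive segments,
    # as merge_spectra would yield them in order
    z_collected = [0.0, 0.5, 1.0, 0.0, 0.5, 1.0]

    d = np.diff(z_collected)
    d[d < 0] = 0  # flatten the restart between segments
    z_merged = np.concatenate(([z_collected[0]], np.cumsum(d)))
    print(z_merged)  # [0.  0.5 1.  1.  1.5 2. ]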
@@ -515,30 +455,7 @@ def set_data_folder(task_id: int, path: os.PathLike):
     os.environ[TMP_FOLDER_KEY_BASE + idstr] = str(path)


-def generate_file_path(file_name: str, task_id: int, identifier: str = "") -> Path:
-    """generates a path for the desired file name
-
-    Parameters
-    ----------
-    file_name : str
-        desired file name. May be altered if it already exists
-    task_id : int
-        unique id of the process
-    identifier : str
-        subfolder in which to store the file. default : ""
-
-    Returns
-    -------
-    str
-        the full path
-    """
-    path = get_data_folder(task_id) / identifier / file_name
-    path.parent.mkdir(exist_ok=True)
-
-    return path
-
-
-def save_data(data: np.ndarray, file_name: str, task_id: int, identifier: str = ""):
+def save_data(data: np.ndarray, data_dir: Path, file_name: str):
     """saves numpy array to disk

     Parameters
@@ -552,7 +469,7 @@ def save_data(data: np.ndarray, file_name: str, task_id: int, identifier: str =
     identifier : str, optional
         identifier in the main data folder of the task, by default ""
     """
-    path = generate_file_path(file_name, task_id, identifier)
+    path = data_dir / file_name
     np.save(path, data)
     get_logger(__name__).debug(f"saved data in {path}")
     return
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Type
 import numpy as np
 from tqdm import tqdm

-from .. import initialize, io, utils, const
+from .. import initialize, io, utils, const, env
 from ..errors import IncompleteDataFolderError
 from ..logger import get_logger
 from . import pulse
@@ -70,11 +70,13 @@ class RK4IP:
             print/log progress update every n_percent, by default 10
         """

-        self.set_new_params(sim_params, save_data, job_identifier, task_id, n_percent)
-
-    def set_new_params(self, sim_params, save_data, job_identifier, task_id, n_percent):
         self.job_identifier = job_identifier
         self.id = task_id
+        self.sim_dir = io.get_sim_dir(self.id)
+        self.sim_dir.mkdir(exist_ok=True)
+        self.data_dir = self.sim_dir / self.job_identifier
+
         self.n_percent = n_percent
         self.logger = get_logger(self.job_identifier)
         self.resuming = False
@@ -177,7 +179,7 @@ class RK4IP:
         name : str
             file name
         """
-        io.save_data(data, name, self.id, self.job_identifier)
+        io.save_data(data, self.data_dir, name)

     def run(self):

@@ -307,14 +309,13 @@ class SequentialRK4IP(RK4IP):
     def __init__(
         self,
         sim_params,
-        overall_pbar: tqdm,
+        pbars: utils.PBars,
         save_data=False,
         job_identifier="",
         task_id=0,
         n_percent=10,
     ):
-        self.overall_pbar = overall_pbar
-        self.pbar = tqdm(**const.pbar_format(1))
+        self.pbars = pbars
         super().__init__(
             sim_params,
             save_data=save_data,
@@ -324,8 +325,8 @@ class SequentialRK4IP(RK4IP):
         )

     def step_saved(self):
-        self.overall_pbar.update()
-        self.pbar.update(self.z / self.z_final - self.pbar.n)
+        self.pbars.update(0)
+        self.pbars.update(1, self.z / self.z_final - self.pbars[1].n)


 class MutliProcRK4IP(RK4IP):
@@ -441,8 +442,8 @@ class Simulations:
         self.update(param_seq)

         self.name = self.param_seq.name
-        self.data_folder = io.get_data_folder(self.id, name_if_new=self.name)
-        io.save_toml(os.path.join(self.data_folder, "initial_config.toml"), self.param_seq.config)
+        self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
+        io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config)

         self.sim_jobs_per_node = 1
         self.max_concurrent_jobs = np.inf
@@ -451,7 +452,7 @@ class Simulations:
     def finished_and_complete(self):
         try:
             io.check_data_integrity(
-                io.get_data_subfolders(self.id), self.param_seq["simulation", "z_num"]
+                io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"]
             )
             return True
         except IncompleteDataFolderError:
@@ -469,21 +470,21 @@ class Simulations:

     def _run_available(self):
         for variable, params in self.param_seq:
-            io.save_parameters(params, self.id, utils.format_variable_list(variable))
+            v_list_str = utils.format_variable_list(variable)
+            io.save_parameters(params, self.sim_dir / v_list_str)

-            self.new_sim(variable, params)
+            self.new_sim(v_list_str, params)
         self.finish()

-    def new_sim(self, variable_list: List[tuple], params: dict):
+    def new_sim(self, v_list_str: str, params: dict):
         """responsible to launch a new simulation

         Parameters
         ----------
-        variable_list : list[tuple]
-            list of tuples (name, value) where name is the name of a
-            variable parameter and value is its current value
+        v_list_str : str
+            string that uniquely identifies the simulation as returned by utils.format_variable_list
         params : dict
-            a flattened parameter dictionary, as returned by scgenerator.initialize.compute_init_parameters
+            a flattened parameter dictionary, as returned by initialize.compute_init_parameters
         """
         raise NotImplementedError()

@@ -508,15 +509,14 @@ class SequencialSimulations(Simulations, priority=0):

     def __init__(self, param_seq: initialize.ParamSequence, task_id):
         super().__init__(param_seq, task_id=task_id)
-        self.overall_pbar = tqdm(
-            total=self.param_seq.num_steps, desc="Simulating", unit="step", **const.pbar_format(0)
+        self.pbars = utils.PBars.auto(
+            self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1
         )

-    def new_sim(self, variable_list: List[tuple], params: Dict[str, Any]):
-        v_list_str = utils.format_variable_list(variable_list)
+    def new_sim(self, v_list_str: str, params: Dict[str, Any]):
         self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
         SequentialRK4IP(
-            params, self.overall_pbar, save_data=True, job_identifier=v_list_str, task_id=self.id
+            params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
         ).run()

     def stop(self):
@@ -545,7 +545,12 @@ class MultiProcSimulations(Simulations, priority=1):
         ]
         self.p_worker = multiprocessing.Process(
             target=utils.progress_worker,
-            args=(self.sim_jobs_per_node, self.param_seq.num_steps, self.progress_queue),
+            args=(
+                self.param_seq.name,
+                self.sim_jobs_per_node,
+                self.param_seq.num_steps,
+                self.progress_queue,
+            ),
         )
         self.p_worker.start()

@@ -554,8 +559,8 @@ class MultiProcSimulations(Simulations, priority=1):
             worker.start()
         super().run()

-    def new_sim(self, variable_list: List[tuple], params: dict):
-        self.queue.put((variable_list, params), block=True, timeout=None)
+    def new_sim(self, v_list_str: str, params: dict):
+        self.queue.put((v_list_str, params), block=True, timeout=None)

     def finish(self):
         """0 means finished"""
@@ -581,8 +586,7 @@ class MultiProcSimulations(Simulations, priority=1):
             if raw_data == 0:
                 queue.task_done()
                 return
-            variable_list, params = raw_data
-            v_list_str = utils.format_variable_list(variable_list)
+            v_list_str, params = raw_data
             MutliProcRK4IP(
                 params,
                 p_queue,
@@ -620,7 +624,7 @@ class RaySimulations(Simulations, priority=2):
         )

         self.propagator = ray.remote(RayRK4IP).options(
-            override_environment_variables=io.get_all_environ()
+            override_environment_variables=env.all_environ()
         )
         self.sim_jobs_per_node = min(
             self.param_seq.num_sim, self.param_seq["simulation", "parallel"]
@@ -631,16 +635,15 @@ class RaySimulations(Simulations, priority=2):
         self.rolling_id = 0
         self.p_actor = (
             ray.remote(utils.ProgressBarActor)
-            .options(override_environment_variables=io.get_all_environ())
-            .remote(self.sim_jobs_total, self.param_seq.num_steps)
+            .options(override_environment_variables=env.all_environ())
+            .remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
         )

-    def new_sim(self, variable_list: List[tuple], params: dict):
+    def new_sim(self, v_list_str: str, params: dict):
         while len(self.jobs) >= self.sim_jobs_total:
             self._collect_1_job()

         self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total
-        v_list_str = utils.format_variable_list(variable_list)

         new_actor = self.propagator.remote(
             params,
@@ -693,8 +696,9 @@ def run_simulation_sequence(
     for config_file in config_files:
         sim = new_simulation(config_file, prev, method)
         sim.run()
-        prev = sim.data_folder
-    io.append_and_merge(prev, final_name)
+        prev = sim.sim_dir
+    path_trees = io.build_path_trees(sim.sim_dir)
+    io.merge(final_name, path_trees)


 def new_simulation(
@@ -1,11 +1,12 @@
 import argparse
 import os
-from pathlib import Path
 import re
 import shutil
 import subprocess
 from datetime import datetime, timedelta
+from pathlib import Path
 from typing import Tuple

 import numpy as np

 from ..initialize import validate_config_sequence
@@ -85,7 +86,7 @@ def create_parser():
     parser.add_argument(
         "--environment-setup",
         required=False,
-        default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_HUSH_PROGRESS=\"\"",
+        default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_PBAR_POLICY=file",
         help="commands to run to setup the environement (default : activate the sc environment with conda)",
     )
     parser.add_argument(
@@ -6,6 +6,8 @@ from pathlib import Path

 import numpy as np

+from scgenerator.const import SPECN_FN
+
 from . import io, initialize, math
 from .plotting import units
 from .logger import get_logger
@@ -45,6 +47,8 @@ class Pulse(Sequence):
         except FileNotFoundError:
             self.logger.info(f"parameters corresponding to {self.path} not found")

+        self.params = initialize.build_sim_grid(self.params)
+
         try:
             self.z = np.load(os.path.join(path, "z.npy"))
         except FileNotFoundError:
@@ -192,7 +196,7 @@ class Pulse(Sequence):
             i = self.nmax + i
         if i in self.cache:
             return self.cache[i]
-        spec = io.load_single_spectrum(self.path, i)
+        spec = np.load(self.path / SPECN_FN.format(i))
         if self.__ensure_2d:
             spec = np.atleast_2d(spec)
         spec = Spectrum(spec, self.wl, self.params["frep"])
@@ -10,18 +10,19 @@ import datetime as dt
 import itertools
 import logging
 import multiprocessing
-import socket
-import os
-from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
+from copy import deepcopy
 from io import StringIO
+from pathlib import Path
+import threading
+from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
+import time

 import numpy as np
 import ray
-from copy import deepcopy

 from tqdm import tqdm

-from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRESS
+from . import env
+from .const import PARAM_SEPARATOR, valid_variable
 from .logger import get_logger
 from .math import *

@@ -29,19 +30,19 @@ from .math import *
 class PBars:
     @classmethod
     def auto(
-        cls, num_tot: int, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None
+        cls, num_tot: int, desc: str, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None
     ) -> "PBars":
         if head_kwargs is None:
-            head_kwargs = dict(unit="step", desc="Simulating", smoothing=0)
+            head_kwargs = dict()
         if worker_kwargs is None:
             worker_kwargs = dict(
                 total=1,
                 desc="Worker {worker_id}",
                 bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
             )
-        if os.getenv(HUSH_PROGRESS) is not None:
+        if "print" not in env.pbar_policy():
             head_kwargs["file"] = worker_kwargs["file"] = StringIO()
+        head_kwargs["desc"] = desc
         p = cls([tqdm(total=num_tot, ncols=100, ascii=False, **head_kwargs)])
         for i in range(1, num_sub_bars + 1):
             kwargs = {k: v for k, v in worker_kwargs.items()}
@@ -51,20 +52,35 @@ class PBars:
         return p

     def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None:
+        self.policy = env.pbar_policy()
+        self.print_path = Path("progress " + pbars[0].desc).resolve()
         if isinstance(pbars, tqdm):
             self.pbars = [pbars]
         else:
             self.pbars = pbars
-        self.logger = get_logger(__name__)
+        self.open = True
+        if "file" in self.policy:
+            self.thread = threading.Thread(target=self.print_worker, daemon=True)
+            self.thread.start()

     def print(self):
+        if "file" not in self.policy:
+            return
         if len(self.pbars) > 1:
             s = [""]
         else:
             s = []
         for pbar in self.pbars:
             s.append(str(pbar))
-        self.logger.info("\n".join(s))
+        self.print_path.write_text("\n".join(s))
+
+    def print_worker(self):
+        while True:
+            for _ in range(100):
+                if not self.open:
+                    return
+                time.sleep(0.02)
+            self.print()

     def __iter__(self):
         yield from self.pbars
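Note (sketch only, mirroring the diff above; loop values invented): with the "file" policy, a daemon thread rewrites a text file named after the head bar's description roughly every two seconds (100 x 0.02 s), so a cluster job can be followed from outside:

    import os

    os.environ["SCGENERATOR_PBAR_POLICY"] = "file"

    pbars = PBars.auto(100, "Simulating demo")  # head bar only
    for _ in range(100):
        pbars.update(0)  # advance the head bar, as SequentialRK4IP does
    pbars.close()
    # follow with: tail -f "progress Simulating demo"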
@@ -79,7 +95,6 @@ class PBars:
         else:
             self.pbars[i].update(value)
         self.pbars[0].update()
-        self.print()

     def append(self, pbar: tqdm):
         self.pbars.append(pbar)
@@ -89,73 +104,20 @@ class PBars:
         self.print()

     def close(self):
+        self.print()
+        self.open = False
+        if "file" in self.policy:
+            self.thread.join()
         for pbar in self.pbars:
             pbar.close()


-class ProgressTracker:
-    def __init__(
-        self,
-        max: Union[int, float],
-        prefix: str = "",
-        suffix: str = "",
-        logger: logging.Logger = None,
-        auto_print: bool = True,
-        percent_incr: Union[int, float] = 5,
-        default_update: Union[int, float] = 1,
-    ):
-        self.max = max
-        self.current = 0
-        self.prefix = prefix
-        self.suffix = suffix
-        self.start_time = dt.datetime.now()
-        self.auto_print = auto_print
-        self.next_percent = percent_incr
-        self.percent_incr = percent_incr
-        self.default_update = default_update
-        self.logger = logger if logger is not None else get_logger()
-
-    def _update(self):
-        if self.auto_print and self.current / self.max >= self.next_percent / 100:
-            self.next_percent += self.percent_incr
-            self.logger.info(self.prefix + self.ETA + self.suffix)
-
-    def update(self, num=None):
-        if num is None:
-            num = self.default_update
-        self.current += num
-        self._update()
-
-    def set(self, value):
-        self.current = value
-        self._update()
-
-    @property
-    def ETA(self):
-        if self.current <= 0:
-            return "\033[31mETA : unknown\033[0m"
-        eta = (
-            (dt.datetime.now() - self.start_time).seconds / self.current * (self.max - self.current)
-        )
-        H = eta // 3600
-        M = (eta - H * 3600) // 60
-        S = eta % 60
-        percent = int(100 * self.current / self.max)
-        return "\033[34mremaining : {:.0f}h {:.0f}min {:.0f}s ({:.0f}% in total). \033[31mETA : {:%Y-%m-%d %H:%M:%S}\033[0m".format(
-            H, M, S, percent, dt.datetime.now() + dt.timedelta(seconds=eta)
-        )
-
-    def get_eta(self):
-        return self.ETA
-
-    def __str__(self):
-        return "{}/{}".format(self.current, self.max)
-
-
 class ProgressBarActor:
-    def __init__(self, num_workers: int, num_steps: int) -> None:
+    def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
         self.counters = [0 for _ in range(num_workers + 1)]
-        self.p_bars = PBars.auto(num_steps, num_workers)
+        self.p_bars = PBars.auto(
+            num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
+        )

     def update(self, worker_id: int, rel_pos: float = None) -> None:
         """update a counter
@@ -182,7 +144,9 @@ class ProgressBarActor:
         self.p_bars.close()


-def progress_worker(num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue):
+def progress_worker(
+    name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
+):
     """keeps track of progress on a separate thread

     Parameters
@@ -194,7 +158,7 @@ def progress_worker(num_workers: int, num_steps: int, progress_queue: multiproce
         Literal[0] : stop the worker and close the progress bars
         Tuple[int, float] : worker id and relative progress between 0 and 1
     """
-    pbars = PBars.auto(num_steps, num_workers)
+    pbars = PBars.auto(num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step"))
     while True:
         raw = progress_queue.get()
         if raw == 0:
@@ -230,6 +194,10 @@ def format_variable_list(l: List[tuple]):
     return joints[0].join(str_list)


+def branch_id(branch: Tuple[Path, ...]) -> str:
+    return "".join("".join(b.name.split()[2:-2]) for b in branch)
+
+
 def format_value(value):
     if type(value) == type(False):
         return str(value)
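Note (illustration only, paths invented following the "id ... num ..." naming convention used throughout the diff): branch_id drops the "id <n>" prefix and "num <n>" suffix of each directory name and concatenates what remains, so all repeats of one parameter set share the same id:

    from pathlib import Path

    branch = (
        Path("a/id 0 wavelength 100 num 0"),
        Path("b/id 2 wavelength 100 num 0"),
    )
    print(branch_id(branch))  # 'wavelength100wavelength100'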
@@ -304,55 +272,6 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]
             yield variable_ind, full_config


-def parallelize(func, arg_iter, sim_jobs=4, progress_tracker_kwargs=None, const_kwarg={}):
-    """given a function and an iterable of arguments, runs the function in parallel
-
-    Parameters
-    ----------
-    func : a function
-    arg_iter : an iterable that yields a tuple to be unpacked to the function as argument(s)
-    sim_jobs : number of parallel runs
-    progress_tracker_kwargs : key word arguments to be passed to the ProgressTracker
-    const_kwarg : keyword arguments to be passed to the function on every run
-
-    Returns
-    ----------
-    a list of the result ordered like arg_iter
-    """
-    pt = None
-    if progress_tracker_kwargs is not None:
-        progress_tracker_kwargs["auto_print"] = True
-        pt = ray.remote(ProgressTracker).remote(**progress_tracker_kwargs)
-
-    # Initial setup
-    func = ray.remote(func)
-    jobs = []
-    results = []
-    dico = {}  # to keep track of the order, as tasks may no finish in order
-    for k, args in enumerate(arg_iter):
-        if not isinstance(args, tuple):
-            print("iterator must return a tuple")
-            quit()
-        # as we got through the iterator, wait for first one to finish before
-        # adding a new job
-        if len(jobs) >= sim_jobs:
-            res, jobs = ray.wait(jobs)
-            results[dico[res[0].task_id()]] = ray.get(res[0])
-            if pt is not None:
-                ray.get(pt.update.remote())
-        newJob = func.remote(*args, **const_kwarg)
-        jobs.append(newJob)
-        dico[newJob.task_id()] = k
-        results.append(None)
-
-    # still have to wait for the last few jobs when there is no more new jobs
-    for j in jobs:
-        results[dico[j.task_id()]] = ray.get(j)
-        if pt is not None:
-            ray.get(pt.update.remote())
-
-    return np.array(results)
-
-
 def deep_update(d: Mapping, u: Mapping) -> dict:
     for k, v in u.items():
         if isinstance(v, collections.abc.Mapping):
@@ -391,8 +310,3 @@ def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str
         else:
             out[section_name] = section
     return out
-
-
-def formatted_hostname():
-    s = socket.gethostname().replace(".", "_")
-    return (PREFIX_KEY_BASE + s).upper()
@@ -173,32 +173,32 @@ class TestInitializeMethods(unittest.TestCase):
         t = d["time"]
         field = d["field"]
         conf = load_conf("custom_field/no_change")
-        conf = init._generate_sim_grid(conf)
+        conf = init.build_sim_grid(conf)
         result = init.setup_custom_field(conf)
         self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4)
         self.assertTrue(result)

         conf = load_conf("custom_field/peak_power")
-        conf = init._generate_sim_grid(conf)
+        conf = init.build_sim_grid(conf)
         result = init.setup_custom_field(conf)
         self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4)
         self.assertTrue(result)
         self.assertNotAlmostEqual(conf["wavelength"], 1593e-9)

         conf = load_conf("custom_field/mean_power")
-        conf = init._generate_sim_grid(conf)
+        conf = init.build_sim_grid(conf)
         result = init.setup_custom_field(conf)
         self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4)
         self.assertTrue(result)

         conf = load_conf("custom_field/recover1")
-        conf = init._generate_sim_grid(conf)
+        conf = init.build_sim_grid(conf)
         result = init.setup_custom_field(conf)
         self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0)
         self.assertTrue(result)

         conf = load_conf("custom_field/recover2")
-        conf = init._generate_sim_grid(conf)
+        conf = init.build_sim_grid(conf)
         result = init.setup_custom_field(conf)
         self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0)
         self.assertTrue(result)