A lot of improvements

This commit is contained in:
Benoît Sierro
2021-06-02 12:01:26 +02:00
parent 9fdf9ed525
commit 49513eff66
12 changed files with 328 additions and 443 deletions

View File

@@ -13,6 +13,10 @@ spectra, params = load_sim_data("varyTechNoise100kW_sim_data")
to plot to plot
plot_results_2D(spectra[0], (600, 1450, nm), params) plot_results_2D(spectra[0], (600, 1450, nm), params)
# Environment variables
SCGENERATOR_PBAR_POLICY : "none", "file", "print", "both", optional
whether progress should be printed to a file ("file"), to the standard output ("print") or both, default : print
# Configuration # Configuration

View File

@@ -2,6 +2,5 @@ numpy
matplotlib matplotlib
scipy scipy
ray ray
send2trash
toml toml
tqdm tqdm

View File

@@ -1,13 +1,16 @@
import argparse import argparse
import os import os
from pathlib import Path
import random import random
import sys
import ray import ray
from scgenerator import initialize from scgenerator.physics.simulate import (
from ..physics.simulate import run_simulation_sequence, resume_simulations, SequencialSimulations run_simulation_sequence,
from .. import io resume_simulations,
SequencialSimulations,
)
from scgenerator import io
def create_parser(): def create_parser():
@@ -73,7 +76,11 @@ def run_sim(args):
def merge(args): def merge(args):
io.append_and_merge(args.path, args.output_name) path_trees = io.build_path_trees(Path(args.path))
if args.output_name is None:
args.output_name = path_trees[-1][0][0].parent.name + " merged"
io.merge(args.output_name, path_trees)
def prep_ray(args): def prep_ray(args):
@@ -98,7 +105,7 @@ def resume_sim(args):
sim = resume_simulations(args.sim_dir, method=method) sim = resume_simulations(args.sim_dir, method=method)
sim.run() sim.run()
run_simulation_sequence( run_simulation_sequence(
*args.configs, method=method, prev_sim_dir=sim.data_folder, final_name=args.output_name *args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
) )

View File

@@ -246,7 +246,12 @@ valid_variable = dict(
) )
ENVIRON_KEY_BASE = "SCGENERATOR_" ENVIRON_KEY_BASE = "SCGENERATOR_"
HUSH_PROGRESS = ENVIRON_KEY_BASE + "HUSH_PROGRESS" PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_" TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_" PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
PARAM_SEPARATOR = " " PARAM_SEPARATOR = " "
SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy"
Z_FN = "z.npy"
PARAM_FN = "params.toml"

29
src/scgenerator/env.py Normal file
View File

@@ -0,0 +1,29 @@
import os
from pathlib import Path
from typing import Dict, List, Literal, Optional
from .const import ENVIRON_KEY_BASE, PBAR_POLICY, TMP_FOLDER_KEY_BASE
def data_folder(task_id: int) -> Optional[Path]:
idstr = str(int(task_id))
tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
return tmp
def all_environ() -> Dict[str, str]:
"""returns a dictionary of all environment variables set by any instance of scgenerator"""
d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
return d
def pbar_policy() -> List[Literal["print", "file"]]:
policy = os.getenv(PBAR_POLICY)
if policy == "print" or policy is None:
return ["print"]
elif policy == "file":
return ["file"]
elif policy == "both":
return ["file", "print"]
else:
return []

View File

@@ -76,8 +76,8 @@ class ContinuationParamSequence(ParamSequence):
"""iterates through all possible parameters, yielding a config as well as a flattened """iterates through all possible parameters, yielding a config as well as a flattened
computed parameters set each time""" computed parameters set each time"""
for variable_list, full_config in required_simulations(self.config): for variable_list, full_config in required_simulations(self.config):
prev_data_dir = self.find_prev_data_dir(variable_list) prev_data_dir = self.find_prev_data_dir(variable_list).resolve()
full_config["prev_data_dir"] = str(prev_data_dir.resolve()) full_config["prev_data_dir"] = str(prev_data_dir)
yield variable_list, compute_init_parameters(full_config) yield variable_list, compute_init_parameters(full_config)
def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path: def find_prev_data_dir(self, new_variable_list: List[Tuple[str, Any]]) -> Path:
@@ -116,7 +116,7 @@ class RecoveryParamSequence(ParamSequence):
z_num = config["simulation"]["z_num"] z_num = config["simulation"]["z_num"]
started = self.num_sim started = self.num_sim
sub_folders = io.get_data_subfolders(self.id) sub_folders = io.get_data_dirs(io.get_sim_dir(self.id))
pbar_store = utils.PBars( pbar_store = utils.PBars(
tqdm( tqdm(
@@ -157,7 +157,7 @@ class RecoveryParamSequence(ParamSequence):
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
for variable_list, params in required_simulations(self.config): for variable_list, params in required_simulations(self.config):
data_dir = io.get_data_folder(self.id) / utils.format_variable_list(variable_list) data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list)
if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0: if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0:
if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None: if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None:
@@ -565,7 +565,8 @@ def _ensure_consistency(config):
def recover_params(config: Dict[str, Any], data_folder: Path) -> Dict[str, Any]: def recover_params(config: Dict[str, Any], data_folder: Path) -> Dict[str, Any]:
params = compute_init_parameters(config) params = compute_init_parameters(config)
try: try:
prev_params = io.load_toml(data_folder / "params.toml") prev_params = io.load_previous_parameters(data_folder / "params.toml")
prev_params = build_sim_grid(prev_params)
except FileNotFoundError: except FileNotFoundError:
prev_params = {} prev_params = {}
for k, v in prev_params.items(): for k, v in prev_params.items():
@@ -602,7 +603,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
for key, value in config.get(section, {}).items(): for key, value in config.get(section, {}).items():
params[key] = value params[key] = value
params = _generate_sim_grid(params) params = build_sim_grid(params)
# Initial field may influence the grid # Initial field may influence the grid
if "mean_power" in params: if "mean_power" in params:
@@ -789,7 +790,7 @@ def _interp_range(w, upper, lower):
return interp_range return interp_range
def _generate_sim_grid(params): def build_sim_grid(params):
"""computes a bunch of values that relate to the simulation grid """computes a bunch of values that relate to the simulation grid
Parameters Parameters

View File

@@ -1,30 +1,29 @@
import os import os
import shutil
from datetime import datetime from datetime import datetime
from glob import glob from typing import Any, Dict, Generator, List, Sequence, Tuple
from typing import Any, Dict, Iterable, List, Tuple, Union import shutil
import numpy as np import numpy as np
import pkg_resources as pkg import pkg_resources as pkg
import toml import toml
from send2trash import TrashPermissionError, send2trash
from tqdm import tqdm
from pathlib import Path from pathlib import Path
import itertools import itertools
from . import utils from . import utils, env
from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE from .const import (
ENVIRON_KEY_BASE,
PARAM_SEPARATOR,
PBAR_POLICY,
TMP_FOLDER_KEY_BASE,
SPEC1_FN,
SPECN_FN,
Z_FN,
PARAM_FN,
)
from .errors import IncompleteDataFolderError from .errors import IncompleteDataFolderError
from .logger import get_logger from .logger import get_logger
using_ray = False PathTree = List[Tuple[Path, ...]]
try:
import ray
from ray.util.queue import Queue
using_ray = True
except ModuleNotFoundError:
pass
class Paths: class Paths:
@@ -81,51 +80,6 @@ class Paths:
return os.path.join(cls.get("plots"), name) return os.path.join(cls.get("plots"), name)
class DataBuffer:
def __init__(self, task_id):
self.logger = get_logger(__name__)
self.id = task_id
self.queue = Queue()
def empty(self):
num = self.queue.size()
if num == 0:
return 0
self.logger.info(f"buffer length at time of emptying : {num}")
while not self.queue.empty():
name, identifier, data = self.queue.get()
save_data(data, name, self.id, identifier)
return num
def append(self, file_name: str, identifier: str, data: np.ndarray):
self.queue.put((file_name, identifier, data))
# def abspath(rel_path: str):
# """returns the complete path with the correct root. In other words, allows to modify absolute paths
# in case the process accessing this function is a sub-process started from another device.
# Parameters
# ----------
# rel_path : str
# relative path
# Returns
# -------
# str
# absolute path
# """
# key = utils.formatted_hostname()
# prefix = os.getenv(key)
# if prefix is None:
# p = os.path.abspath(rel_path)
# else:
# p = os.path.join(prefix, rel_path)
# return os.path.normpath(p)
def conform_toml_path(path: os.PathLike) -> Path: def conform_toml_path(path: os.PathLike) -> Path:
path = Path(path) path = Path(path)
if not path.name.lower().endswith(".toml"): if not path.name.lower().endswith(".toml"):
@@ -159,7 +113,7 @@ def serializable(val):
return out return out
def _prepare_for_serialization(dico: Dict[str, Any]): def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved """prepares a dictionary for serialization. Some keys may not be preserved
(dropped due to no conversion available) (dropped due to no conversion available)
@@ -168,7 +122,7 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
dico : dict dico : dict
dictionary dictionary
""" """
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w"] forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
types = (np.ndarray, float, int, str, list, tuple, dict) types = (np.ndarray, float, int, str, list, tuple, dict)
out = {} out = {}
for key, value in dico.items(): for key, value in dico.items():
@@ -177,7 +131,7 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
if not isinstance(value, types): if not isinstance(value, types):
continue continue
if isinstance(value, dict): if isinstance(value, dict):
out[key] = _prepare_for_serialization(value) out[key] = prepare_for_serialization(value)
elif isinstance(value, np.ndarray) and value.dtype == complex: elif isinstance(value, np.ndarray) and value.dtype == complex:
continue continue
else: else:
@@ -186,11 +140,11 @@ def _prepare_for_serialization(dico: Dict[str, Any]):
return out return out
def save_parameters(param_dict: Dict[str, Any], task_id: int, data_dir_name: str): def save_parameters(param_dict: Dict[str, Any], data_dir: Path) -> Path:
param = param_dict.copy() param = param_dict.copy()
file_path = generate_file_path("params.toml", task_id, data_dir_name) file_path = data_dir / "params.toml"
param = _prepare_for_serialization(param) param = prepare_for_serialization(param)
param["datetime"] = datetime.now() param["datetime"] = datetime.now()
file_path.parent.mkdir(exist_ok=True) file_path.parent.mkdir(exist_ok=True)
@@ -202,35 +156,6 @@ def save_parameters(param_dict: Dict[str, Any], task_id: int, data_dir_name: str
return file_path return file_path
# def save_parameters_old(param_dict, file_name="param"):
# """Writes the flattened parameters dictionary specific to a single simulation into a toml file
# Parameters
# ----------
# param_dict : dictionary of parameters. Only floats, int and arrays of
# non complex values are stored in the json
# folder_name : folder where to save the files (relative to cwd)
# file_name : name of the readable file.
# """
# param = param_dict.copy()
# folder_name, file_name = os.path.split(file_name)
# folder_name = "tmp" if folder_name == "" else folder_name
# file_name = os.path.splitext(file_name)[0]
# if not os.path.exists(folder_name):
# os.makedirs(folder_name)
# param = _prepare_for_serialization(param)
# param["datetime"] = datetime.now()
# # save toml of the simulation
# with open(os.path.join(folder_name, file_name + ".toml"), "w") as file:
# toml.dump(param, file, encoder=toml.TomlNumpyEncoder())
# return os.path.join(folder_name, file_name)
def load_previous_parameters(path: os.PathLike): def load_previous_parameters(path: os.PathLike):
"""loads a parameters toml files and converts data to appropriate type """loads a parameters toml files and converts data to appropriate type
Parameters Parameters
@@ -267,31 +192,21 @@ def load_material_dico(name):
return toml.loads(Paths.gets("gas"))[name] return toml.loads(Paths.gets("gas"))[name]
def get_all_environ() -> Dict[str, str]: def get_data_dirs(sim_dir: Path) -> List[Path]:
"""returns a dictionary of all environment variables set by any instance of scgenerator""" """returns a list of absolute paths corresponding to a particular run
d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
return d
def load_single_spectrum(folder: Path, index) -> np.ndarray:
return np.load(folder / f"spectra_{index}.npy")
def get_data_subfolders(task_id: int) -> List[Path]:
"""returns a list of relative path/subfolders in the specified directory
Parameters Parameters
---------- ----------
path : str sim_dir : Path
path to directory containing the initial config file and the spectra sub folders path to directory containing the initial config file and the spectra sub folders
Returns Returns
------- -------
List[str] List[Path]
paths to sub folders paths to sub folders
""" """
return [p.resolve() for p in get_data_folder(task_id).glob("*") if p.is_dir()] return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()]
def check_data_integrity(sub_folders: List[Path], init_z_num: int): def check_data_integrity(sub_folders: List[Path], init_z_num: int):
@@ -336,8 +251,7 @@ def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
IncompleteDataFolderError IncompleteDataFolderError
raised if init_z_num doesn't match that specified in the individual parameter file raised if init_z_num doesn't match that specified in the individual parameter file
""" """
params = load_toml(sub_folder / "params.toml") z_num = load_toml(sub_folder / "params.toml")["z_num"]
z_num = params["z_num"]
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
if z_num != init_z_num: if z_num != init_z_num:
@@ -351,7 +265,7 @@ def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
def find_last_spectrum_num(data_dir: Path): def find_last_spectrum_num(data_dir: Path):
for num in itertools.count(1): for num in itertools.count(1):
p_to_test = data_dir / f"spectrum_{num}.npy" p_to_test = data_dir / SPEC1_FN.format(num)
if not p_to_test.is_file() or len(p_to_test.read_bytes()) == 0: if not p_to_test.is_file() or len(p_to_test.read_bytes()) == 0:
return num - 1 return num - 1
@@ -359,142 +273,168 @@ def find_last_spectrum_num(data_dir: Path):
def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]: def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
"""return the last spectrum stored in path as well as its id""" """return the last spectrum stored in path as well as its id"""
num = find_last_spectrum_num(data_dir) num = find_last_spectrum_num(data_dir)
return num, np.load(data_dir / f"spectrum_{num}.npy") return num, np.load(data_dir / SPEC1_FN.format(num))
def append_and_merge(final_sim_path: os.PathLike, new_name=None): def update_appended_params(source: Path, destination: Path, z: Sequence):
final_sim_path = Path(final_sim_path).resolve()
if new_name is None:
new_name = final_sim_path.name + " appended"
destination_path = final_sim_path.parent / new_name
destination_path.mkdir(exist_ok=True)
sim_paths = list(final_sim_path.glob("id*num*"))
pbars = utils.PBars.auto(
len(sim_paths),
0,
head_kwargs=dict(desc="Appending"),
worker_kwargs=dict(desc=""),
)
for sim_path in sim_paths:
path_tree = [sim_path]
sim_name = sim_path.name
appended_sim_path = destination_path / sim_name
appended_sim_path.mkdir(exist_ok=True)
while (
prev_sim_path := load_toml(path_tree[-1] / "params.toml").get("prev_data_dir")
) is not None:
path_tree.append(Path(prev_sim_path).resolve())
z: List[np.ndarray] = []
z_num = 0
last_z = 0
paths_r = list(reversed(path_tree))
for path in paths_r:
curr_z_num = load_toml(path / "params.toml")["z_num"]
for i in range(curr_z_num):
shutil.copy(
path / f"spectrum_{i}.npy",
appended_sim_path / f"spectrum_{i + z_num}.npy",
)
z_arr = np.load(path / "z.npy")
z.append(z_arr + last_z)
last_z += z_arr[-1]
z_num += curr_z_num
z_arr = np.concatenate(z)
update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr)
np.save(appended_sim_path / "z.npy", z_arr)
pbars.update(0)
update_appended_params(
final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr
)
pbars.close()
merge(destination_path, delete=True)
def update_appended_params(param_path: Path, new_path: Path, z):
z_num = len(z) z_num = len(z)
params = load_toml(param_path) params = load_toml(source)
if "simulation" in params: if "simulation" in params:
params["simulation"]["z_num"] = z_num params["simulation"]["z_num"] = z_num
params["simulation"]["z_targets"] = z params["fiber"]["length"] = float(z[-1] - z[0])
else: else:
params["z_num"] = z_num params["z_num"] = z_num
params["z_targets"] = z params["length"] = float(z[-1] - z[0])
save_toml(new_path, params) save_toml(destination, params)
def merge(paths: Union[Path, List[Path]], delete=False): def build_path_trees(sim_dir: Path) -> List[PathTree]:
if isinstance(paths, Path): sim_dir = sim_dir.resolve()
paths = [paths] path_branches: List[Tuple[Path, ...]] = []
for path in paths: to_check = list(sim_dir.glob("id*num*"))
merge_same_simulations(path, delete=delete) pbar = utils.PBars.auto(len(to_check), desc="Building path trees")
for branch in map(build_path_branch, to_check):
if branch is not None:
def merge_same_simulations(path: Path, delete=True): path_branches.append(branch)
logger = get_logger(__name__)
num_separator = PARAM_SEPARATOR + "num" + PARAM_SEPARATOR
sub_folders = [p for p in path.glob("*") if p.is_dir()]
config = load_toml(path / "initial_config.toml")
repeat = config["simulation"].get("repeat", 1)
max_repeat_id = repeat - 1
z_num = config["simulation"]["z_num"]
check_data_integrity(sub_folders, z_num)
sim_num, param_num = utils.count_variations(config)
pbar = utils.PBars.auto(sim_num * z_num, head_kwargs=dict(desc="Merging data"))
spectra = []
for z_id in range(z_num):
for variable_and_ind, _ in utils.required_simulations(config):
repeat_id = variable_and_ind[-1][1]
# reset the buffer once we move to a new parameter set
if repeat_id == 0:
spectra = []
in_path = path / utils.format_variable_list(variable_and_ind)
spectra.append(np.load(in_path / f"spectrum_{z_id}.npy"))
pbar.update() pbar.update()
pbar.close()
path_trees = group_path_branches(path_branches)
return path_trees
# write new files only once all those from one parameter set are collected
if repeat_id == max_repeat_id:
out_path = path / (
utils.format_variable_list(variable_and_ind[1:-1]) + PARAM_SEPARATOR + "merged"
)
out_path = ensure_folder(out_path, prevent_overwrite=False) def build_path_branch(data_dir: Path) -> Tuple[Path, ...]:
spectra = np.array(spectra).reshape(repeat, len(spectra[0])) if not data_dir.is_dir():
np.save(out_path / f"spectra_{z_id}.npy", spectra.squeeze()) return None
path_branch = [data_dir]
while (prev_sim_path := load_toml(path_branch[-1] / PARAM_FN).get("prev_data_dir")) is not None:
p = Path(prev_sim_path).resolve()
if not p.exists():
p = Path(*p.parts[-2:]).resolve()
path_branch.append(p)
return tuple(reversed(path_branch))
def group_path_branches(path_branches: List[Tuple[Path, ...]]) -> List[PathTree]:
"""groups path lists
[
("a/id 0 wavelength 100 num 0"," b/id 0 wavelength 100 num 0"),
("a/id 2 wavelength 100 num 1"," b/id 2 wavelength 100 num 1"),
("a/id 1 wavelength 200 num 0"," b/id 1 wavelength 200 num 0"),
("a/id 3 wavelength 200 num 1"," b/id 3 wavelength 200 num 1")
]
->
[
(
("a/id 0 wavelength 100 num 0", "a/id 2 wavelength 100 num 1"),
("b/id 0 wavelength 100 num 0", "b/id 2 wavelength 100 num 1"),
)
(
("a/id 1 wavelength 200 num 0", "a/id 3 wavelength 200 num 1"),
("b/id 1 wavelength 200 num 0", "b/id 3 wavelength 200 num 1"),
)
]
Parameters
----------
path_branches : List[Tuple[Path, ...]]
each element of the list is a path to a folder containing data of one simulation
Returns
-------
List[PathTree]
List of PathTrees to be used in merge
"""
sort_key = lambda el: el[0]
size = len(path_branches[0])
out_trees_map: Dict[str, Dict[int, Dict[int, Path]]] = {}
for branch in path_branches:
b_id = utils.branch_id(branch)
out_trees_map.setdefault(b_id, {i: {} for i in range(size)})
for sim_part, data_dir in enumerate(branch):
*_, num = data_dir.name.split()
out_trees_map[b_id][sim_part][int(num)] = data_dir
return [
tuple(
tuple(w for _, w in sorted(v.items(), key=sort_key))
for __, v in sorted(d.items(), key=sort_key)
)
for d in out_trees_map.values()
]
def merge_path_tree(path_tree: PathTree, destination: Path):
"""given a path tree, copies the file into the right location
Parameters
----------
path_tree : PathTree
elements of the list returned by group_path_branches
destination : Path
dir where to save the data
"""
z_arr: List[float] = []
destination.mkdir(exist_ok=True)
for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
z_arr.append(z)
spec_out_name = SPECN_FN.format(i)
np.save(destination / spec_out_name, merged_spectra)
d = np.diff(z_arr)
d[d < 0] = 0
z_arr = np.concatenate(([z_arr[0]], np.cumsum(d)))
np.save(destination / Z_FN, z_arr)
update_appended_params(path_tree[-1][0] / PARAM_FN, destination / PARAM_FN, z_arr)
def merge_spectra(
path_tree: PathTree,
) -> Generator[Tuple[float, np.ndarray], None, None]:
for same_sim_paths in path_tree:
z_arr = np.load(same_sim_paths[0] / Z_FN)
for i, z in enumerate(z_arr):
spectra: List[np.ndarray] = []
for data_dir in same_sim_paths:
spec = np.load(data_dir / SPEC1_FN.format(i))
spectra.append(spec)
yield z, np.atleast_2d(spectra)
def merge(destination: os.PathLike, path_trees: List[PathTree] = None):
destination = ensure_folder(Path(destination))
for i, sim_dir in enumerate(sim_dirs(path_trees)):
shutil.copy(
sim_dir / "initial_config.toml",
destination / f"initial_config_{i}.toml",
)
pbar = utils.PBars.auto(len(path_trees), desc="Merging")
for path_tree in path_trees:
iden = PARAM_SEPARATOR.join(path_tree[-1][0].name.split()[2:-2])
merge_path_tree(path_tree, destination / iden)
pbar.update()
# copy other files only once
if z_id == 0:
for file_name in ["z.npy", "params.toml"]:
shutil.copy(in_path / file_name, out_path)
pbar.close() pbar.close()
if delete:
for sub_folder in sub_folders: def sim_dirs(path_trees: List[PathTree]) -> Generator[Path, None, None]:
try: for p in path_trees[0]:
send2trash(str(sub_folder)) yield p[0].parent
except TrashPermissionError:
logger.warning(f"could not send send {sub_folder} to trash")
def get_data_folder(task_id: int, name_if_new: str = "data") -> Path: def get_sim_dir(task_id: int, name_if_new: str = "data") -> Path:
if name_if_new == "": if name_if_new == "":
name_if_new = "data" name_if_new = "data"
idstr = str(int(task_id)) tmp = env.data_folder(task_id)
tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
if tmp is None: if tmp is None:
tmp = ensure_folder(Path("scgenerator" + PARAM_SEPARATOR + name_if_new)) tmp = ensure_folder(Path("scgenerator" + PARAM_SEPARATOR + name_if_new))
os.environ[TMP_FOLDER_KEY_BASE + idstr] = str(tmp) os.environ[TMP_FOLDER_KEY_BASE + str(task_id)] = str(tmp)
tmp = Path(tmp).resolve() tmp = Path(tmp).resolve()
if not tmp.exists(): if not tmp.exists():
tmp.mkdir() tmp.mkdir()
@@ -515,30 +455,7 @@ def set_data_folder(task_id: int, path: os.PathLike):
os.environ[TMP_FOLDER_KEY_BASE + idstr] = str(path) os.environ[TMP_FOLDER_KEY_BASE + idstr] = str(path)
def generate_file_path(file_name: str, task_id: int, identifier: str = "") -> Path: def save_data(data: np.ndarray, data_dir: Path, file_name: str):
"""generates a path for the desired file name
Parameters
----------
file_name : str
desired file name. May be altered if it already exists
task_id : int
unique id of the process
identifier : str
subfolder in which to store the file. default : ""
Returns
-------
str
the full path
"""
path = get_data_folder(task_id) / identifier / file_name
path.parent.mkdir(exist_ok=True)
return path
def save_data(data: np.ndarray, file_name: str, task_id: int, identifier: str = ""):
"""saves numpy array to disk """saves numpy array to disk
Parameters Parameters
@@ -552,7 +469,7 @@ def save_data(data: np.ndarray, file_name: str, task_id: int, identifier: str =
identifier : str, optional identifier : str, optional
identifier in the main data folder of the task, by default "" identifier in the main data folder of the task, by default ""
""" """
path = generate_file_path(file_name, task_id, identifier) path = data_dir / file_name
np.save(path, data) np.save(path, data)
get_logger(__name__).debug(f"saved data in {path}") get_logger(__name__).debug(f"saved data in {path}")
return return

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Type
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from .. import initialize, io, utils, const from .. import initialize, io, utils, const, env
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
from . import pulse from . import pulse
@@ -70,11 +70,13 @@ class RK4IP:
print/log progress update every n_percent, by default 10 print/log progress update every n_percent, by default 10
""" """
self.set_new_params(sim_params, save_data, job_identifier, task_id, n_percent)
def set_new_params(self, sim_params, save_data, job_identifier, task_id, n_percent):
self.job_identifier = job_identifier self.job_identifier = job_identifier
self.id = task_id self.id = task_id
self.sim_dir = io.get_sim_dir(self.id)
self.sim_dir.mkdir(exist_ok=True)
self.data_dir = self.sim_dir/self.job_identifier
self.n_percent = n_percent self.n_percent = n_percent
self.logger = get_logger(self.job_identifier) self.logger = get_logger(self.job_identifier)
self.resuming = False self.resuming = False
@@ -177,7 +179,7 @@ class RK4IP:
name : str name : str
file name file name
""" """
io.save_data(data, name, self.id, self.job_identifier) io.save_data(data, self.data_dir, name)
def run(self): def run(self):
@@ -307,14 +309,13 @@ class SequentialRK4IP(RK4IP):
def __init__( def __init__(
self, self,
sim_params, sim_params,
overall_pbar: tqdm, pbars: utils.PBars,
save_data=False, save_data=False,
job_identifier="", job_identifier="",
task_id=0, task_id=0,
n_percent=10, n_percent=10,
): ):
self.overall_pbar = overall_pbar self.pbars = pbars
self.pbar = tqdm(**const.pbar_format(1))
super().__init__( super().__init__(
sim_params, sim_params,
save_data=save_data, save_data=save_data,
@@ -324,8 +325,8 @@ class SequentialRK4IP(RK4IP):
) )
def step_saved(self): def step_saved(self):
self.overall_pbar.update() self.pbars.update(0)
self.pbar.update(self.z / self.z_final - self.pbar.n) self.pbars.update(1, self.z / self.z_final - self.pbars[1].n)
class MutliProcRK4IP(RK4IP): class MutliProcRK4IP(RK4IP):
@@ -441,8 +442,8 @@ class Simulations:
self.update(param_seq) self.update(param_seq)
self.name = self.param_seq.name self.name = self.param_seq.name
self.data_folder = io.get_data_folder(self.id, name_if_new=self.name) self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
io.save_toml(os.path.join(self.data_folder, "initial_config.toml"), self.param_seq.config) io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config)
self.sim_jobs_per_node = 1 self.sim_jobs_per_node = 1
self.max_concurrent_jobs = np.inf self.max_concurrent_jobs = np.inf
@@ -451,7 +452,7 @@ class Simulations:
def finished_and_complete(self): def finished_and_complete(self):
try: try:
io.check_data_integrity( io.check_data_integrity(
io.get_data_subfolders(self.id), self.param_seq["simulation", "z_num"] io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"]
) )
return True return True
except IncompleteDataFolderError: except IncompleteDataFolderError:
@@ -469,21 +470,21 @@ class Simulations:
def _run_available(self): def _run_available(self):
for variable, params in self.param_seq: for variable, params in self.param_seq:
io.save_parameters(params, self.id, utils.format_variable_list(variable)) v_list_str = utils.format_variable_list(variable)
io.save_parameters(params, self.sim_dir / v_list_str)
self.new_sim(variable, params) self.new_sim(v_list_str, params)
self.finish() self.finish()
def new_sim(self, variable_list: List[tuple], params: dict): def new_sim(self, v_list_str: str, params: dict):
"""responsible to launch a new simulation """responsible to launch a new simulation
Parameters Parameters
---------- ----------
variable_list : list[tuple] v_list_str : str
list of tuples (name, value) where name is the name of a string that uniquely identifies the simulation as returned by utils.format_variable_list
variable parameter and value is its current value
params : dict params : dict
a flattened parameter dictionary, as returned by scgenerator.initialize.compute_init_parameters a flattened parameter dictionary, as returned by initialize.compute_init_parameters
""" """
raise NotImplementedError() raise NotImplementedError()
@@ -508,15 +509,14 @@ class SequencialSimulations(Simulations, priority=0):
def __init__(self, param_seq: initialize.ParamSequence, task_id): def __init__(self, param_seq: initialize.ParamSequence, task_id):
super().__init__(param_seq, task_id=task_id) super().__init__(param_seq, task_id=task_id)
self.overall_pbar = tqdm( self.pbars = utils.PBars.auto(
total=self.param_seq.num_steps, desc="Simulating", unit="step", **const.pbar_format(0) self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1
) )
def new_sim(self, variable_list: List[tuple], params: Dict[str, Any]): def new_sim(self, v_list_str: str, params: Dict[str, Any]):
v_list_str = utils.format_variable_list(variable_list)
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}") self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
SequentialRK4IP( SequentialRK4IP(
params, self.overall_pbar, save_data=True, job_identifier=v_list_str, task_id=self.id params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
).run() ).run()
def stop(self): def stop(self):
@@ -545,7 +545,12 @@ class MultiProcSimulations(Simulations, priority=1):
] ]
self.p_worker = multiprocessing.Process( self.p_worker = multiprocessing.Process(
target=utils.progress_worker, target=utils.progress_worker,
args=(self.sim_jobs_per_node, self.param_seq.num_steps, self.progress_queue), args=(
self.param_seq.name,
self.sim_jobs_per_node,
self.param_seq.num_steps,
self.progress_queue,
),
) )
self.p_worker.start() self.p_worker.start()
@@ -554,8 +559,8 @@ class MultiProcSimulations(Simulations, priority=1):
worker.start() worker.start()
super().run() super().run()
def new_sim(self, variable_list: List[tuple], params: dict): def new_sim(self, v_list_str: str, params: dict):
self.queue.put((variable_list, params), block=True, timeout=None) self.queue.put((v_list_str, params), block=True, timeout=None)
def finish(self): def finish(self):
"""0 means finished""" """0 means finished"""
@@ -581,8 +586,7 @@ class MultiProcSimulations(Simulations, priority=1):
if raw_data == 0: if raw_data == 0:
queue.task_done() queue.task_done()
return return
variable_list, params = raw_data v_list_str, params = raw_data
v_list_str = utils.format_variable_list(variable_list)
MutliProcRK4IP( MutliProcRK4IP(
params, params,
p_queue, p_queue,
@@ -620,7 +624,7 @@ class RaySimulations(Simulations, priority=2):
) )
self.propagator = ray.remote(RayRK4IP).options( self.propagator = ray.remote(RayRK4IP).options(
override_environment_variables=io.get_all_environ() override_environment_variables=env.all_environ()
) )
self.sim_jobs_per_node = min( self.sim_jobs_per_node = min(
self.param_seq.num_sim, self.param_seq["simulation", "parallel"] self.param_seq.num_sim, self.param_seq["simulation", "parallel"]
@@ -631,16 +635,15 @@ class RaySimulations(Simulations, priority=2):
self.rolling_id = 0 self.rolling_id = 0
self.p_actor = ( self.p_actor = (
ray.remote(utils.ProgressBarActor) ray.remote(utils.ProgressBarActor)
.options(override_environment_variables=io.get_all_environ()) .options(override_environment_variables=env.all_environ())
.remote(self.sim_jobs_total, self.param_seq.num_steps) .remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
) )
def new_sim(self, variable_list: List[tuple], params: dict): def new_sim(self, v_list_str: str, params: dict):
while len(self.jobs) >= self.sim_jobs_total: while len(self.jobs) >= self.sim_jobs_total:
self._collect_1_job() self._collect_1_job()
self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total
v_list_str = utils.format_variable_list(variable_list)
new_actor = self.propagator.remote( new_actor = self.propagator.remote(
params, params,
@@ -693,8 +696,9 @@ def run_simulation_sequence(
for config_file in config_files: for config_file in config_files:
sim = new_simulation(config_file, prev, method) sim = new_simulation(config_file, prev, method)
sim.run() sim.run()
prev = sim.data_folder prev = sim.sim_dir
io.append_and_merge(prev, final_name) path_trees = io.build_path_trees(sim.sim_dir)
io.merge(final_name, path_trees)
def new_simulation( def new_simulation(

View File

@@ -1,11 +1,12 @@
import argparse import argparse
import os import os
from pathlib import Path
import re import re
import shutil import shutil
import subprocess import subprocess
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
from ..initialize import validate_config_sequence from ..initialize import validate_config_sequence
@@ -85,7 +86,7 @@ def create_parser():
parser.add_argument( parser.add_argument(
"--environment-setup", "--environment-setup",
required=False, required=False,
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_HUSH_PROGRESS=\"\"", default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && export SCGENERATOR_PBAR_POLICY=file",
help="commands to run to setup the environement (default : activate the sc environment with conda)", help="commands to run to setup the environement (default : activate the sc environment with conda)",
) )
parser.add_argument( parser.add_argument(

View File

@@ -6,6 +6,8 @@ from pathlib import Path
import numpy as np import numpy as np
from scgenerator.const import SPECN_FN
from . import io, initialize, math from . import io, initialize, math
from .plotting import units from .plotting import units
from .logger import get_logger from .logger import get_logger
@@ -45,6 +47,8 @@ class Pulse(Sequence):
except FileNotFoundError: except FileNotFoundError:
self.logger.info(f"parameters corresponding to {self.path} not found") self.logger.info(f"parameters corresponding to {self.path} not found")
self.params = initialize.build_sim_grid(self.params)
try: try:
self.z = np.load(os.path.join(path, "z.npy")) self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError: except FileNotFoundError:
@@ -192,7 +196,7 @@ class Pulse(Sequence):
i = self.nmax + i i = self.nmax + i
if i in self.cache: if i in self.cache:
return self.cache[i] return self.cache[i]
spec = io.load_single_spectrum(self.path, i) spec = np.load(self.path / SPECN_FN.format(i))
if self.__ensure_2d: if self.__ensure_2d:
spec = np.atleast_2d(spec) spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.wl, self.params["frep"]) spec = Spectrum(spec, self.wl, self.params["frep"])

View File

@@ -10,18 +10,19 @@ import datetime as dt
import itertools import itertools
import logging import logging
import multiprocessing import multiprocessing
import socket from copy import deepcopy
import os
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
from io import StringIO from io import StringIO
from pathlib import Path
import threading
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
import time
import numpy as np import numpy as np
import ray import ray
from copy import deepcopy
from tqdm import tqdm from tqdm import tqdm
from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRESS from . import env
from .const import PARAM_SEPARATOR, valid_variable
from .logger import get_logger from .logger import get_logger
from .math import * from .math import *
@@ -29,19 +30,19 @@ from .math import *
class PBars: class PBars:
@classmethod @classmethod
def auto( def auto(
cls, num_tot: int, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None cls, num_tot: int, desc: str, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None
) -> "PBars": ) -> "PBars":
if head_kwargs is None: if head_kwargs is None:
head_kwargs = dict(unit="step", desc="Simulating", smoothing=0) head_kwargs = dict()
if worker_kwargs is None: if worker_kwargs is None:
worker_kwargs = dict( worker_kwargs = dict(
total=1, total=1,
desc="Worker {worker_id}", desc="Worker {worker_id}",
bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
) )
if "print" not in env.pbar_policy():
if os.getenv(HUSH_PROGRESS) is not None:
head_kwargs["file"] = worker_kwargs["file"] = StringIO() head_kwargs["file"] = worker_kwargs["file"] = StringIO()
head_kwargs["desc"] = desc
p = cls([tqdm(total=num_tot, ncols=100, ascii=False, **head_kwargs)]) p = cls([tqdm(total=num_tot, ncols=100, ascii=False, **head_kwargs)])
for i in range(1, num_sub_bars + 1): for i in range(1, num_sub_bars + 1):
kwargs = {k: v for k, v in worker_kwargs.items()} kwargs = {k: v for k, v in worker_kwargs.items()}
@@ -51,20 +52,35 @@ class PBars:
return p return p
def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None: def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None:
self.policy = env.pbar_policy()
self.print_path = Path("progress " + pbars[0].desc).resolve()
if isinstance(pbars, tqdm): if isinstance(pbars, tqdm):
self.pbars = [pbars] self.pbars = [pbars]
else: else:
self.pbars = pbars self.pbars = pbars
self.logger = get_logger(__name__) self.open = True
if "file" in self.policy:
self.thread = threading.Thread(target=self.print_worker, daemon=True)
self.thread.start()
def print(self): def print(self):
if "file" not in self.policy:
return
if len(self.pbars) > 1: if len(self.pbars) > 1:
s = [""] s = [""]
else: else:
s = [] s = []
for pbar in self.pbars: for pbar in self.pbars:
s.append(str(pbar)) s.append(str(pbar))
self.logger.info("\n".join(s)) self.print_path.write_text("\n".join(s))
def print_worker(self):
while True:
for _ in range(100):
if not self.open:
return
time.sleep(0.02)
self.print()
def __iter__(self): def __iter__(self):
yield from self.pbars yield from self.pbars
@@ -79,7 +95,6 @@ class PBars:
else: else:
self.pbars[i].update(value) self.pbars[i].update(value)
self.pbars[0].update() self.pbars[0].update()
self.print()
def append(self, pbar: tqdm): def append(self, pbar: tqdm):
self.pbars.append(pbar) self.pbars.append(pbar)
@@ -89,73 +104,20 @@ class PBars:
self.print() self.print()
def close(self): def close(self):
self.print()
self.open = False
if "file" in self.policy:
self.thread.join()
for pbar in self.pbars: for pbar in self.pbars:
pbar.close() pbar.close()
class ProgressTracker:
def __init__(
self,
max: Union[int, float],
prefix: str = "",
suffix: str = "",
logger: logging.Logger = None,
auto_print: bool = True,
percent_incr: Union[int, float] = 5,
default_update: Union[int, float] = 1,
):
self.max = max
self.current = 0
self.prefix = prefix
self.suffix = suffix
self.start_time = dt.datetime.now()
self.auto_print = auto_print
self.next_percent = percent_incr
self.percent_incr = percent_incr
self.default_update = default_update
self.logger = logger if logger is not None else get_logger()
def _update(self):
if self.auto_print and self.current / self.max >= self.next_percent / 100:
self.next_percent += self.percent_incr
self.logger.info(self.prefix + self.ETA + self.suffix)
def update(self, num=None):
if num is None:
num = self.default_update
self.current += num
self._update()
def set(self, value):
self.current = value
self._update()
@property
def ETA(self):
if self.current <= 0:
return "\033[31mETA : unknown\033[0m"
eta = (
(dt.datetime.now() - self.start_time).seconds / self.current * (self.max - self.current)
)
H = eta // 3600
M = (eta - H * 3600) // 60
S = eta % 60
percent = int(100 * self.current / self.max)
return "\033[34mremaining : {:.0f}h {:.0f}min {:.0f}s ({:.0f}% in total). \033[31mETA : {:%Y-%m-%d %H:%M:%S}\033[0m".format(
H, M, S, percent, dt.datetime.now() + dt.timedelta(seconds=eta)
)
def get_eta(self):
return self.ETA
def __str__(self):
return "{}/{}".format(self.current, self.max)
class ProgressBarActor: class ProgressBarActor:
def __init__(self, num_workers: int, num_steps: int) -> None: def __init__(self, name: str, num_workers: int, num_steps: int) -> None:
self.counters = [0 for _ in range(num_workers + 1)] self.counters = [0 for _ in range(num_workers + 1)]
self.p_bars = PBars.auto(num_steps, num_workers) self.p_bars = PBars.auto(
num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")
)
def update(self, worker_id: int, rel_pos: float = None) -> None: def update(self, worker_id: int, rel_pos: float = None) -> None:
"""update a counter """update a counter
@@ -182,7 +144,9 @@ class ProgressBarActor:
self.p_bars.close() self.p_bars.close()
def progress_worker(num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue): def progress_worker(
name: str, num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue
):
"""keeps track of progress on a separate thread """keeps track of progress on a separate thread
Parameters Parameters
@@ -194,7 +158,7 @@ def progress_worker(num_workers: int, num_steps: int, progress_queue: multiproce
Literal[0] : stop the worker and close the progress bars Literal[0] : stop the worker and close the progress bars
Tuple[int, float] : worker id and relative progress between 0 and 1 Tuple[int, float] : worker id and relative progress between 0 and 1
""" """
pbars = PBars.auto(num_steps, num_workers) pbars = PBars.auto(num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step"))
while True: while True:
raw = progress_queue.get() raw = progress_queue.get()
if raw == 0: if raw == 0:
@@ -230,6 +194,10 @@ def format_variable_list(l: List[tuple]):
return joints[0].join(str_list) return joints[0].join(str_list)
def branch_id(branch: Tuple[Path, ...]) -> str:
return "".join("".join(b.name.split()[2:-2]) for b in branch)
def format_value(value): def format_value(value):
if type(value) == type(False): if type(value) == type(False):
return str(value) return str(value)
@@ -304,55 +272,6 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]
yield variable_ind, full_config yield variable_ind, full_config
def parallelize(func, arg_iter, sim_jobs=4, progress_tracker_kwargs=None, const_kwarg={}):
"""given a function and an iterable of arguments, runs the function in parallel
Parameters
----------
func : a function
arg_iter : an iterable that yields a tuple to be unpacked to the function as argument(s)
sim_jobs : number of parallel runs
progress_tracker_kwargs : key word arguments to be passed to the ProgressTracker
const_kwarg : keyword arguments to be passed to the function on every run
Returns
----------
a list of the result ordered like arg_iter
"""
pt = None
if progress_tracker_kwargs is not None:
progress_tracker_kwargs["auto_print"] = True
pt = ray.remote(ProgressTracker).remote(**progress_tracker_kwargs)
# Initial setup
func = ray.remote(func)
jobs = []
results = []
dico = {} # to keep track of the order, as tasks may no finish in order
for k, args in enumerate(arg_iter):
if not isinstance(args, tuple):
print("iterator must return a tuple")
quit()
# as we got through the iterator, wait for first one to finish before
# adding a new job
if len(jobs) >= sim_jobs:
res, jobs = ray.wait(jobs)
results[dico[res[0].task_id()]] = ray.get(res[0])
if pt is not None:
ray.get(pt.update.remote())
newJob = func.remote(*args, **const_kwarg)
jobs.append(newJob)
dico[newJob.task_id()] = k
results.append(None)
# still have to wait for the last few jobs when there is no more new jobs
for j in jobs:
results[dico[j.task_id()]] = ray.get(j)
if pt is not None:
ray.get(pt.update.remote())
return np.array(results)
def deep_update(d: Mapping, u: Mapping) -> dict: def deep_update(d: Mapping, u: Mapping) -> dict:
for k, v in u.items(): for k, v in u.items():
if isinstance(v, collections.abc.Mapping): if isinstance(v, collections.abc.Mapping):
@@ -391,8 +310,3 @@ def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str
else: else:
out[section_name] = section out[section_name] = section
return out return out
def formatted_hostname():
s = socket.gethostname().replace(".", "_")
return (PREFIX_KEY_BASE + s).upper()

View File

@@ -173,32 +173,32 @@ class TestInitializeMethods(unittest.TestCase):
t = d["time"] t = d["time"]
field = d["field"] field = d["field"]
conf = load_conf("custom_field/no_change") conf = load_conf("custom_field/no_change")
conf = init._generate_sim_grid(conf) conf = init.build_sim_grid(conf)
result = init.setup_custom_field(conf) result = init.setup_custom_field(conf)
self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4) self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4)
self.assertTrue(result) self.assertTrue(result)
conf = load_conf("custom_field/peak_power") conf = load_conf("custom_field/peak_power")
conf = init._generate_sim_grid(conf) conf = init.build_sim_grid(conf)
result = init.setup_custom_field(conf) result = init.setup_custom_field(conf)
self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4) self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4)
self.assertTrue(result) self.assertTrue(result)
self.assertNotAlmostEqual(conf["wavelength"], 1593e-9) self.assertNotAlmostEqual(conf["wavelength"], 1593e-9)
conf = load_conf("custom_field/mean_power") conf = load_conf("custom_field/mean_power")
conf = init._generate_sim_grid(conf) conf = init.build_sim_grid(conf)
result = init.setup_custom_field(conf) result = init.setup_custom_field(conf)
self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4) self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4)
self.assertTrue(result) self.assertTrue(result)
conf = load_conf("custom_field/recover1") conf = load_conf("custom_field/recover1")
conf = init._generate_sim_grid(conf) conf = init.build_sim_grid(conf)
result = init.setup_custom_field(conf) result = init.setup_custom_field(conf)
self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0) self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0)
self.assertTrue(result) self.assertTrue(result)
conf = load_conf("custom_field/recover2") conf = load_conf("custom_field/recover2")
conf = init._generate_sim_grid(conf) conf = init.build_sim_grid(conf)
result = init.setup_custom_field(conf) result = init.setup_custom_field(conf)
self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0) self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0)
self.assertTrue(result) self.assertTrue(result)