custom worker num, better logging, better config

This commit is contained in:
Benoît Sierro
2021-06-17 10:41:06 +02:00
parent 0108617b8e
commit 1dc90439fe
17 changed files with 193 additions and 141 deletions

View File

@@ -4,3 +4,4 @@ from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, new_simulation, resume_simulations
from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram
from .spectra import Pulse

View File

@@ -2,15 +2,19 @@ import argparse
import os
import random
from pathlib import Path
from collections import ChainMap
from scgenerator import io
from scgenerator.physics.simulate import (
from ray.worker import get
from .. import io, env, const
from ..logger import get_logger
from ..physics.simulate import (
SequencialSimulations,
resume_simulations,
run_simulation_sequence,
)
from scgenerator.physics.fiber import dispersion_coefficients
from ..physics.fiber import dispersion_coefficients
from pprint import pprint
try:
import ray
@@ -18,29 +22,26 @@ except ImportError:
ray = None
def set_env_variables(cmd_line_args: dict[str, str]):
cm = ChainMap(cmd_line_args, os.environ)
for env_key in const.global_config:
k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower()
v = cm.get(k)
if v is not None:
os.environ[env_key] = str(v)
def create_parser():
parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator")
subparsers = parser.add_subparsers(help="sub-command help")
parser.add_argument(
"--id",
type=int,
default=random.randint(0, 1e18),
help="Unique id of the session. Only useful when running several processes at the same time.",
)
parser.add_argument(
"--start-ray",
action="store_true",
help="initialize ray (ray must be installed)",
)
parser.add_argument(
"--no-ray",
action="store_true",
help="force not to use ray",
)
parser.add_argument("--output-name", "-o", help="path to the final output folder", default=None)
for key, args in const.global_config.items():
names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()]
if "short_name" in args:
names.append(args["short_name"])
parser.add_argument(
*names, **{k: v for k, v in args.items() if k not in {"short_name", "type"}}
)
run_parser = subparsers.add_parser("run", help="run a simulation from a config file")
run_parser.add_argument("configs", help="path(s) to the toml configuration file(s)", nargs="+")
@@ -71,15 +72,19 @@ def create_parser():
def main():
parser = create_parser()
args = parser.parse_args()
set_env_variables({k: v for k, v in vars(args).items() if v is not None})
args.func(args)
print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}")
logger = get_logger(__name__)
logger.info(f"dispersion cache : {dispersion_coefficients.cache_info()}")
def run_sim(args):
method = prep_ray(args)
run_simulation_sequence(*args.configs, method=method, final_name=args.output_name)
run_simulation_sequence(*args.configs, method=method)
def merge(args):
@@ -91,19 +96,20 @@ def merge(args):
def prep_ray(args):
logger = get_logger(__name__)
if ray:
if args.start_ray:
if env.get(const.START_RAY):
init_str = ray.init()
elif not args.no_ray:
elif not env.get(const.NO_RAY):
try:
init_str = ray.init(
address="auto",
_redis_password=os.environ.get("redis_password", "caco1234"),
)
print(init_str)
logger.info(init_str)
except ConnectionError as e:
print(e)
return SequencialSimulations if args.no_ray else None
logger.error(e)
return SequencialSimulations if env.get(const.NO_RAY) else None
def resume_sim(args):
@@ -111,9 +117,7 @@ def resume_sim(args):
method = prep_ray(args)
sim = resume_simulations(Path(args.sim_dir), method=method)
sim.run()
run_simulation_sequence(
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
)
run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir)
if __name__ == "__main__":

View File

@@ -1,6 +1,9 @@
__version__ = "0.1.0"
from typing import Any
def pbar_format(worker_id: int):
if worker_id == 0:
return dict(
@@ -17,12 +20,45 @@ def pbar_format(worker_id: int):
ENVIRON_KEY_BASE = "SCGENERATOR_"
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
PARAM_SEPARATOR = " "
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL"
LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL"
START_RAY = ENVIRON_KEY_BASE + "START_RAY"
NO_RAY = ENVIRON_KEY_BASE + "NO_RAY"
OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH"
global_config: dict[str, dict[str, Any]] = {
LOG_FILE_LEVEL: dict(
help="minimum lvl of message to be saved in the log file",
choices=["critical", "error", "warning", "info", "debug"],
default=None,
type=str,
),
LOG_PRINT_LEVEL: dict(
help="minimum lvl of message to be printed to the standard output",
choices=["critical", "error", "warning", "info", "debug"],
default="error",
type=str,
),
PBAR_POLICY: dict(
help="what to do with progress pars (print them, make them a txt file or nothing), default is print",
choices=["print", "file", "both", "none"],
default=None,
type=str,
),
START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool),
NO_RAY: dict(action="store_true", help="force not to use ray", type=bool),
OUTPUT_PATH: dict(
short_name="-o", help="path to the final output folder", default=None, type=str
),
}
SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy"
Z_FN = "z.npy"

View File

@@ -38,6 +38,14 @@ b = 3.978e-5
a = 0.425
b = 5.105e-5
[air]
a = 0.1358
b = 3.64e-5
[vacuum]
a = 0
b = 0
[air.sellmeier]
B = [57921050000.0, 1679170000.0]
C = [238018500000000.0, 57362000000000.0]
@@ -45,6 +53,11 @@ P0 = 101325
T0 = 288.15
kind = 2
[air.kerr]
P0 = 101325
T0 = 293.15
n2 = 3.01e-23
[nitrogen.sellmeier]
B = [32431570000.0]
C = [144000000000000.0]
@@ -181,3 +194,16 @@ P0 = 30400.0
T0 = 273.15
n2 = 5.8e-23
source = "Wahlstrand, J. K., Cheng, Y. H., & Milchberg, H. M. (2012). High field optical nonlinearity and the Kramers-Kronig relations. Physical review letters, 109(11), 113904."
[vacuum.sellmeier]
B = []
C = []
P0 = 101325
T0 = 273.15
kind = 1
[vacuum.kerr]
P0 = 30400.0
T0 = 273.15
n2 = 0
source = "none"

View File

@@ -1,7 +1,14 @@
import os
from typing import Dict, Literal, Optional, Set
from typing import Any, Dict, Literal, Optional, Set
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE
from .const import (
ENVIRON_KEY_BASE,
LOG_FILE_LEVEL,
LOG_PRINT_LEVEL,
PBAR_POLICY,
TMP_FOLDER_KEY_BASE,
global_config,
)
def data_folder(task_id: int) -> Optional[str]:
@@ -10,6 +17,14 @@ def data_folder(task_id: int) -> Optional[str]:
return tmp
def get(key: str) -> Any:
str_value = os.environ.get(key)
try:
return global_config[key]["type"](str_value)
except (ValueError, KeyError):
return None
def all_environ() -> Dict[str, str]:
"""returns a dictionary of all environment variables set by any instance of scgenerator"""
d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
@@ -17,7 +32,7 @@ def all_environ() -> Dict[str, str]:
def pbar_policy() -> Set[Literal["print", "file"]]:
policy = os.getenv(PBAR_POLICY)
policy = get(PBAR_POLICY)
if policy == "print" or policy is None:
return {"print"}
elif policy == "file":
@@ -28,13 +43,23 @@ def pbar_policy() -> Set[Literal["print", "file"]]:
return set()
def log_policy() -> Set[Literal["print", "file"]]:
policy = os.getenv(LOG_POLICY)
if policy == "print" or policy is None:
return {"print"}
elif policy == "file":
return {"file"}
elif policy == "both":
return {"file", "print"}
else:
return set()
def log_file_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
policy = get(LOG_FILE_LEVEL)
try:
policy = policy.lower()
if policy in {"critical", "error", "warning", "info", "debug"}:
return policy
except AttributeError:
pass
return None
def log_print_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
policy = get(LOG_PRINT_LEVEL)
try:
policy = policy.lower()
if policy in {"critical", "error", "warning", "info", "debug"}:
return policy
except AttributeError:
pass
return None

View File

@@ -212,7 +212,6 @@ class Config(BareConfig):
for param in [
"behaviors",
"z_num",
"frep",
"tolerated_error",
"parallel",
"repeat",

View File

@@ -1,37 +1,16 @@
import logging
from .env import log_policy
# class DebugOnlyFileHandler(logging.FileHandler):
# def __init__(
# self, filename, mode: str, encoding: Optional[str] = None, delay: bool = False
# ) -> None:
# super().__init__(filename, mode=mode, encoding=encoding, delay=delay)
# self.setLevel(logging.DEBUG)
# def emit(self, record: logging.LogRecord) -> None:
# if not record.levelno == logging.DEBUG:
# return
# return super().emit(record)
from .env import log_file_level, log_print_level
DEFAULT_LEVEL = logging.INFO
lvl_map = dict(
lvl_map: dict[str, int] = dict(
debug=logging.DEBUG,
info=logging.INFO,
warning=logging.WARNING,
error=logging.ERROR,
fatal=logging.FATAL,
critical=logging.CRITICAL,
)
loggers = []
def _set_debug():
DEFAULT_LEVEL = logging.DEBUG
def get_logger(name=None):
"""returns a logging.Logger instance. This function is there because if scgenerator
@@ -50,21 +29,10 @@ def get_logger(name=None):
"""
name = __name__ if name is None else name
logger = logging.getLogger(name)
if name not in loggers:
loggers.append(logger)
return configure_logger(logger)
# def set_level_all(lvl):
# _default_lvl =
# logging.basicConfig(level=lvl_map[lvl])
# for logger in loggers:
# logger.setLevel(lvl_map[lvl])
# for handler in logger.handlers:
# handler.setLevel(lvl_map[lvl])
def configure_logger(logger):
def configure_logger(logger: logging.Logger):
"""configures a logging.Logger obj
Parameters
@@ -81,17 +49,20 @@ def configure_logger(logger):
"""
if not hasattr(logger, "already_configured"):
if "file" in log_policy():
print_lvl = lvl_map.get(log_print_level())
file_lvl = lvl_map.get(log_file_level())
if file_lvl is not None:
formatter = logging.Formatter("{levelname}: {name}: {message}", style="{")
file_handler1 = logging.FileHandler("scgenerator.log", "a+")
file_handler1.setFormatter(formatter)
file_handler1.setLevel(logging.DEBUG)
file_handler1.setLevel(file_lvl)
logger.addHandler(file_handler1)
if "print" in log_policy():
if print_lvl is not None:
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setLevel(print_lvl)
logger.addHandler(stream_handler)
logger.setLevel(logging.DEBUG)
logger.setLevel(min(print_lvl, file_lvl))
logger.already_configured = True
return logger

View File

@@ -91,7 +91,7 @@ def u_nm(n, m):
@np_cache
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
def ndft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
@@ -111,7 +111,7 @@ def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
@np_cache
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
def indft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix
Parameters
@@ -126,10 +126,10 @@ def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, n)
multiply ~X(f) by this matrix to get x(t)
"""
return np.linalg.pinv(nfft_matrix(t, f))
return np.linalg.pinv(ndft_matrix(t, f))
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
def ndft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
"""computes the Fourier transform of an uneven signal
Parameters
@@ -146,10 +146,10 @@ def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, )
amplitude at each frequency
"""
return nfft_matrix(t, f) @ s
return ndft_matrix(t, f) @ s
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
"""computes the inverse Fourier transform of an uneven spectrum
Parameters
@@ -166,7 +166,7 @@ def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, )
amplitude at each point of t
"""
return infft_matrix(t, f) @ a
return indft_matrix(t, f) @ a
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):

View File

@@ -130,7 +130,7 @@ def n_eff_marcatili(lambda_, n_gas_2, core_radius, he_mode=(1, 1)):
lambda_ : ndarray, shape (n, )
wavelengths array (m)
n_gas_2 : ndarray, shape (n, )
refractive index of the gas as function of lambda_
square of the refractive index of the gas as function of lambda_
core_radius : float
inner radius of the capillary (m)
he_mode : tuple, shape (2, ), optional
@@ -455,7 +455,6 @@ def HCPCF_dispersion(
Temperature of the material
pressure : float
constant pressure
FIXME tupple : a pressure gradient from pressure[0] to pressure[1] is computed
Returns
-------
@@ -692,15 +691,7 @@ def compute_dispersion(params: BareParams):
)
else:
# Load material info
gas_name = params.gas_name
if gas_name == "vacuum":
material_dico = None
else:
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
# compute dispersion
material_dico = toml.loads(io.Paths.gets("gas"))[params.gas_name]
if params.dynamic_dispersion:
return dynamic_HCPCF_dispersion(
lambda_,
@@ -716,9 +707,6 @@ def compute_dispersion(params: BareParams):
params.interp_degree,
)
else:
# actually compute the dispersion
beta2 = HCPCF_dispersion(
lambda_,
material_dico,
@@ -730,10 +718,14 @@ def compute_dispersion(params: BareParams):
)
if material_dico is not None:
A_eff = 1.5 * params.core_radius ** 2
n2 = mat.non_linear_refractive_index(
material_dico, params.pressure, params.temperature
)
A_eff = 1.5 * params.core_radius ** 2 if params.A_eff is None else params.A_eff
if params.n2 is None:
n2 = mat.non_linear_refractive_index(
material_dico, params.pressure, params.temperature
)
else:
n2 = params.n2
gamma = gamma_parameter(n2, params.w0, A_eff)
else:
gamma = None

View File

@@ -39,7 +39,7 @@ def number_density_van_der_waals(
ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution
"""
logger = get_logger
logger = get_logger(__name__)
if pressure == 0:
return 0
@@ -73,7 +73,7 @@ def number_density_van_der_waals(
s = f"Van der Waals eq with parameters P={pressure}, T={temperature}, a={a}, b={b}"
s += f", There is more than one possible number density : {roots}."
s += f", {np.min(roots)} was returned"
logger.info(s)
logger.warning(s)
return np.min(roots)
@@ -90,6 +90,8 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
----------
an array n(lambda_)^2 - 1
"""
logger = get_logger(__name__)
WL_THRESHOLD = 8.285e-6
temp_l = lambda_[lambda_ < WL_THRESHOLD]
kind = 1
@@ -101,7 +103,6 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
t0 = material_dico["sellmeier"].get("t0", 273.15)
kind = material_dico["sellmeier"].get("kind", 1)
# Sellmeier equation
chi = np.zeros_like(lambda_) # = n^2 - 1
if kind == 1:
for b, c in zip(B, C):
@@ -120,6 +121,7 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
if pressure is not None:
chi *= pressure / P0
logger.debug(f"computed chi between {np.min(chi):.2e} and {np.max(chi):.2e}")
return chi

View File

@@ -3,11 +3,10 @@ import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Type
from typing_extensions import runtime
import numpy as np
from .. import env, initialize, io, utils
from .. import const, env, initialize, io, utils
from ..errors import IncompleteDataFolderError
from ..logger import get_logger
from . import pulse
@@ -290,7 +289,6 @@ class SequentialRK4IP(RK4IP):
save_data=False,
job_identifier="",
task_id=0,
n_percent=10,
):
self.pbars = pbars
super().__init__(
@@ -298,7 +296,6 @@ class SequentialRK4IP(RK4IP):
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
n_percent=n_percent,
)
def step_saved(self):
@@ -314,7 +311,6 @@ class MutliProcRK4IP(RK4IP):
save_data=False,
job_identifier="",
task_id=0,
n_percent=10,
):
self.worker_id = worker_id
self.p_queue = p_queue
@@ -323,7 +319,6 @@ class MutliProcRK4IP(RK4IP):
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
n_percent=n_percent,
)
def step_saved(self):
@@ -512,7 +507,10 @@ class MultiProcSimulations(Simulations, priority=1):
def __init__(self, param_seq: initialize.ParamSequence, task_id):
super().__init__(param_seq, task_id=task_id)
self.sim_jobs_per_node = max(1, os.cpu_count() // 2)
if param_seq.config.worker_num is not None:
self.sim_jobs_per_node = param_seq.config.worker_num
else:
self.sim_jobs_per_node = max(1, os.cpu_count() // 2)
self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node)
self.progress_queue = multiprocessing.Queue()
self.workers = [
@@ -656,6 +654,8 @@ class RaySimulations(Simulations, priority=2):
@property
def sim_jobs_total(self):
if self.param_seq.config.worker_num is not None:
return self.param_seq.config.worker_num
tot_cpus = sum([node.get("Resources", {}).get("CPU", 0) for node in ray.nodes()])
tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
return int(min(self.param_seq.num_sim, tot_cpus))
@@ -664,7 +664,6 @@ class RaySimulations(Simulations, priority=2):
def run_simulation_sequence(
*config_files: os.PathLike,
method=None,
final_name: str = None,
prev_sim_dir: os.PathLike = None,
):
prev = prev_sim_dir
@@ -674,6 +673,7 @@ def run_simulation_sequence(
prev = sim.sim_dir
path_trees = io.build_path_trees(sim.sim_dir)
final_name = env.get(const.OUTPUT_PATH)
if final_name is None:
final_name = path_trees[0][-1][0].parent.name + " merged"
@@ -687,6 +687,7 @@ def new_simulation(
) -> Simulations:
config_dict = io.load_toml(config_file)
logger = get_logger(__name__)
if prev_sim_dir is not None:
config_dict["prev_sim_dir"] = str(prev_sim_dir)
@@ -698,7 +699,7 @@ def new_simulation(
else:
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
print(f"{param_seq.name=}")
logger.info(f"running {param_seq.name}")
return Simulations.new(param_seq, task_id, method)

View File

@@ -422,7 +422,7 @@ def plot_spectrogram(
values = spec[ind_t][:, ind_f]
if f_range[2].type == "WL":
values = np.apply_along_axis(
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f))
units.to_WL, 1, values, params.repetition_rate, units.m(f_range[2].inv(new_f))
)
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
@@ -528,7 +528,7 @@ def plot_results_2D(
# make uniform if converting to wavelength
if plt_range.unit.type == "WL":
if is_spectrum:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
values = np.array(
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
)
@@ -648,7 +648,7 @@ def plot_results_1D(
# make uniform if converting to wavelength
if plt_range.unit.type == "WL":
if is_spectrum:
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
values = units.to_WL(values, params.repetition_rate, units.m.inv(params.w[ind]))
# change the resolution
if isinstance(spacing, float):
@@ -810,8 +810,8 @@ def plot_avg(
values *= yscaling
mean_values = np.mean(values, axis=0)
if plt_range.unit.type == "WL" and renormalize:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
mean_values = units.to_WL(mean_values, params.frep, x_axis)
values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
mean_values = units.to_WL(mean_values, params.repetition_rate, x_axis)
# change the resolution
if isinstance(spacing, float):

View File

@@ -8,13 +8,9 @@ from pathlib import Path
from typing import Tuple
import numpy as np
from toml import load
from scgenerator.utils.parameter import BareConfig
from ..initialize import validate_config_sequence
from ..io import Paths, load_config
from ..utils import count_variations
def primes(n):
@@ -90,7 +86,7 @@ def create_parser():
"--environment-setup",
required=False,
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file",
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info",
help="commands to run to setup the environement (default : activate the sc environment with conda)",
)
parser.add_argument(

View File

@@ -198,7 +198,7 @@ class Pulse(Sequence):
spec = np.load(self.path / SPECN_FN.format(i))
if self.__ensure_2d:
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.wl, self.params.frep)
spec = Spectrum(spec, self.wl, self.params.repetition_rate)
self.cache[i] = spec
return spec

View File

@@ -187,7 +187,7 @@ def count_variations(config: BareConfig) -> int:
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
variable_params_num is the number of distinct parameters that will vary."""
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
return sim_num
return int(sim_num)
def format_variable_list(l: List[Tuple[str, Any]]):

View File

@@ -339,7 +339,7 @@ class BareParams:
peak_power: float = Parameter(positive(float, int))
mean_power: float = Parameter(positive(float, int))
energy: float = Parameter(positive(float, int))
soliton_num: float = Parameter(positive(float, int))
soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
@@ -362,10 +362,10 @@ class BareParams:
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
interp_degree: int = Parameter(positive(int))
frep: float = Parameter(positive(float, int))
prev_sim_dir: str = Parameter(string)
readjust_wavelength: bool = Parameter(boolean)
recovery_last_stored: int = Parameter(non_negative(int))
worker_num: int = Parameter(positive(int))
# computed
field_0: np.ndarray = Parameter(type_checker(np.ndarray))

View File

@@ -12,7 +12,6 @@ class TestRecoveryParamSequence(unittest.TestCase):
def setUp(self):
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
self.conf = toml.load(TMP + "/initial_config.toml")
logger.DEFAULT_LEVEL = logger.logging.FATAL
io.set_data_folder(55, TMP)
def test_remaining_simulations_count(self):