custom worker num, better logging, better config

This commit is contained in:
Benoît Sierro
2021-06-17 10:41:06 +02:00
parent 0108617b8e
commit 1dc90439fe
17 changed files with 193 additions and 141 deletions

View File

@@ -4,3 +4,4 @@ from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, new_simulation, resume_simulations from .physics.simulate import RK4IP, new_simulation, resume_simulations
from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram
from .spectra import Pulse

View File

@@ -2,15 +2,19 @@ import argparse
import os import os
import random import random
from pathlib import Path from pathlib import Path
from collections import ChainMap
from scgenerator import io from ray.worker import get
from scgenerator.physics.simulate import (
from .. import io, env, const
from ..logger import get_logger
from ..physics.simulate import (
SequencialSimulations, SequencialSimulations,
resume_simulations, resume_simulations,
run_simulation_sequence, run_simulation_sequence,
) )
from scgenerator.physics.fiber import dispersion_coefficients from ..physics.fiber import dispersion_coefficients
from pprint import pprint
try: try:
import ray import ray
@@ -18,29 +22,26 @@ except ImportError:
ray = None ray = None
def set_env_variables(cmd_line_args: dict[str, str]):
cm = ChainMap(cmd_line_args, os.environ)
for env_key in const.global_config:
k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower()
v = cm.get(k)
if v is not None:
os.environ[env_key] = str(v)
def create_parser(): def create_parser():
parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator") parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator")
subparsers = parser.add_subparsers(help="sub-command help") subparsers = parser.add_subparsers(help="sub-command help")
for key, args in const.global_config.items():
names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()]
if "short_name" in args:
names.append(args["short_name"])
parser.add_argument( parser.add_argument(
"--id", *names, **{k: v for k, v in args.items() if k not in {"short_name", "type"}}
type=int,
default=random.randint(0, 1e18),
help="Unique id of the session. Only useful when running several processes at the same time.",
) )
parser.add_argument(
"--start-ray",
action="store_true",
help="initialize ray (ray must be installed)",
)
parser.add_argument(
"--no-ray",
action="store_true",
help="force not to use ray",
)
parser.add_argument("--output-name", "-o", help="path to the final output folder", default=None)
run_parser = subparsers.add_parser("run", help="run a simulation from a config file") run_parser = subparsers.add_parser("run", help="run a simulation from a config file")
run_parser.add_argument("configs", help="path(s) to the toml configuration file(s)", nargs="+") run_parser.add_argument("configs", help="path(s) to the toml configuration file(s)", nargs="+")
@@ -71,15 +72,19 @@ def create_parser():
def main(): def main():
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
set_env_variables({k: v for k, v in vars(args).items() if v is not None})
args.func(args) args.func(args)
print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}") logger = get_logger(__name__)
logger.info(f"dispersion cache : {dispersion_coefficients.cache_info()}")
def run_sim(args): def run_sim(args):
method = prep_ray(args) method = prep_ray(args)
run_simulation_sequence(*args.configs, method=method, final_name=args.output_name) run_simulation_sequence(*args.configs, method=method)
def merge(args): def merge(args):
@@ -91,19 +96,20 @@ def merge(args):
def prep_ray(args): def prep_ray(args):
logger = get_logger(__name__)
if ray: if ray:
if args.start_ray: if env.get(const.START_RAY):
init_str = ray.init() init_str = ray.init()
elif not args.no_ray: elif not env.get(const.NO_RAY):
try: try:
init_str = ray.init( init_str = ray.init(
address="auto", address="auto",
_redis_password=os.environ.get("redis_password", "caco1234"), _redis_password=os.environ.get("redis_password", "caco1234"),
) )
print(init_str) logger.info(init_str)
except ConnectionError as e: except ConnectionError as e:
print(e) logger.error(e)
return SequencialSimulations if args.no_ray else None return SequencialSimulations if env.get(const.NO_RAY) else None
def resume_sim(args): def resume_sim(args):
@@ -111,9 +117,7 @@ def resume_sim(args):
method = prep_ray(args) method = prep_ray(args)
sim = resume_simulations(Path(args.sim_dir), method=method) sim = resume_simulations(Path(args.sim_dir), method=method)
sim.run() sim.run()
run_simulation_sequence( run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir)
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,6 +1,9 @@
__version__ = "0.1.0" __version__ = "0.1.0"
from typing import Any
def pbar_format(worker_id: int): def pbar_format(worker_id: int):
if worker_id == 0: if worker_id == 0:
return dict( return dict(
@@ -17,12 +20,45 @@ def pbar_format(worker_id: int):
ENVIRON_KEY_BASE = "SCGENERATOR_" ENVIRON_KEY_BASE = "SCGENERATOR_"
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_" TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_" PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
PARAM_SEPARATOR = " " PARAM_SEPARATOR = " "
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL"
LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL"
START_RAY = ENVIRON_KEY_BASE + "START_RAY"
NO_RAY = ENVIRON_KEY_BASE + "NO_RAY"
OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH"
global_config: dict[str, dict[str, Any]] = {
LOG_FILE_LEVEL: dict(
help="minimum lvl of message to be saved in the log file",
choices=["critical", "error", "warning", "info", "debug"],
default=None,
type=str,
),
LOG_PRINT_LEVEL: dict(
help="minimum lvl of message to be printed to the standard output",
choices=["critical", "error", "warning", "info", "debug"],
default="error",
type=str,
),
PBAR_POLICY: dict(
help="what to do with progress pars (print them, make them a txt file or nothing), default is print",
choices=["print", "file", "both", "none"],
default=None,
type=str,
),
START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool),
NO_RAY: dict(action="store_true", help="force not to use ray", type=bool),
OUTPUT_PATH: dict(
short_name="-o", help="path to the final output folder", default=None, type=str
),
}
SPEC1_FN = "spectrum_{}.npy" SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy" SPECN_FN = "spectra_{}.npy"
Z_FN = "z.npy" Z_FN = "z.npy"

View File

@@ -38,6 +38,14 @@ b = 3.978e-5
a = 0.425 a = 0.425
b = 5.105e-5 b = 5.105e-5
[air]
a = 0.1358
b = 3.64e-5
[vacuum]
a = 0
b = 0
[air.sellmeier] [air.sellmeier]
B = [57921050000.0, 1679170000.0] B = [57921050000.0, 1679170000.0]
C = [238018500000000.0, 57362000000000.0] C = [238018500000000.0, 57362000000000.0]
@@ -45,6 +53,11 @@ P0 = 101325
T0 = 288.15 T0 = 288.15
kind = 2 kind = 2
[air.kerr]
P0 = 101325
T0 = 293.15
n2 = 3.01e-23
[nitrogen.sellmeier] [nitrogen.sellmeier]
B = [32431570000.0] B = [32431570000.0]
C = [144000000000000.0] C = [144000000000000.0]
@@ -181,3 +194,16 @@ P0 = 30400.0
T0 = 273.15 T0 = 273.15
n2 = 5.8e-23 n2 = 5.8e-23
source = "Wahlstrand, J. K., Cheng, Y. H., & Milchberg, H. M. (2012). High field optical nonlinearity and the Kramers-Kronig relations. Physical review letters, 109(11), 113904." source = "Wahlstrand, J. K., Cheng, Y. H., & Milchberg, H. M. (2012). High field optical nonlinearity and the Kramers-Kronig relations. Physical review letters, 109(11), 113904."
[vacuum.sellmeier]
B = []
C = []
P0 = 101325
T0 = 273.15
kind = 1
[vacuum.kerr]
P0 = 30400.0
T0 = 273.15
n2 = 0
source = "none"

View File

@@ -1,7 +1,14 @@
import os import os
from typing import Dict, Literal, Optional, Set from typing import Any, Dict, Literal, Optional, Set
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE from .const import (
ENVIRON_KEY_BASE,
LOG_FILE_LEVEL,
LOG_PRINT_LEVEL,
PBAR_POLICY,
TMP_FOLDER_KEY_BASE,
global_config,
)
def data_folder(task_id: int) -> Optional[str]: def data_folder(task_id: int) -> Optional[str]:
@@ -10,6 +17,14 @@ def data_folder(task_id: int) -> Optional[str]:
return tmp return tmp
def get(key: str) -> Any:
str_value = os.environ.get(key)
try:
return global_config[key]["type"](str_value)
except (ValueError, KeyError):
return None
def all_environ() -> Dict[str, str]: def all_environ() -> Dict[str, str]:
"""returns a dictionary of all environment variables set by any instance of scgenerator""" """returns a dictionary of all environment variables set by any instance of scgenerator"""
d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items())) d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
@@ -17,7 +32,7 @@ def all_environ() -> Dict[str, str]:
def pbar_policy() -> Set[Literal["print", "file"]]: def pbar_policy() -> Set[Literal["print", "file"]]:
policy = os.getenv(PBAR_POLICY) policy = get(PBAR_POLICY)
if policy == "print" or policy is None: if policy == "print" or policy is None:
return {"print"} return {"print"}
elif policy == "file": elif policy == "file":
@@ -28,13 +43,23 @@ def pbar_policy() -> Set[Literal["print", "file"]]:
return set() return set()
def log_policy() -> Set[Literal["print", "file"]]: def log_file_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
policy = os.getenv(LOG_POLICY) policy = get(LOG_FILE_LEVEL)
if policy == "print" or policy is None: try:
return {"print"} policy = policy.lower()
elif policy == "file": if policy in {"critical", "error", "warning", "info", "debug"}:
return {"file"} return policy
elif policy == "both": except AttributeError:
return {"file", "print"} pass
else: return None
return set()
def log_print_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
policy = get(LOG_PRINT_LEVEL)
try:
policy = policy.lower()
if policy in {"critical", "error", "warning", "info", "debug"}:
return policy
except AttributeError:
pass
return None

View File

@@ -212,7 +212,6 @@ class Config(BareConfig):
for param in [ for param in [
"behaviors", "behaviors",
"z_num", "z_num",
"frep",
"tolerated_error", "tolerated_error",
"parallel", "parallel",
"repeat", "repeat",

View File

@@ -1,37 +1,16 @@
import logging import logging
from .env import log_policy from .env import log_file_level, log_print_level
# class DebugOnlyFileHandler(logging.FileHandler):
# def __init__(
# self, filename, mode: str, encoding: Optional[str] = None, delay: bool = False
# ) -> None:
# super().__init__(filename, mode=mode, encoding=encoding, delay=delay)
# self.setLevel(logging.DEBUG)
# def emit(self, record: logging.LogRecord) -> None:
# if not record.levelno == logging.DEBUG:
# return
# return super().emit(record)
DEFAULT_LEVEL = logging.INFO lvl_map: dict[str, int] = dict(
lvl_map = dict(
debug=logging.DEBUG, debug=logging.DEBUG,
info=logging.INFO, info=logging.INFO,
warning=logging.WARNING, warning=logging.WARNING,
error=logging.ERROR, error=logging.ERROR,
fatal=logging.FATAL,
critical=logging.CRITICAL, critical=logging.CRITICAL,
) )
loggers = []
def _set_debug():
DEFAULT_LEVEL = logging.DEBUG
def get_logger(name=None): def get_logger(name=None):
"""returns a logging.Logger instance. This function is there because if scgenerator """returns a logging.Logger instance. This function is there because if scgenerator
@@ -50,21 +29,10 @@ def get_logger(name=None):
""" """
name = __name__ if name is None else name name = __name__ if name is None else name
logger = logging.getLogger(name) logger = logging.getLogger(name)
if name not in loggers:
loggers.append(logger)
return configure_logger(logger) return configure_logger(logger)
# def set_level_all(lvl): def configure_logger(logger: logging.Logger):
# _default_lvl =
# logging.basicConfig(level=lvl_map[lvl])
# for logger in loggers:
# logger.setLevel(lvl_map[lvl])
# for handler in logger.handlers:
# handler.setLevel(lvl_map[lvl])
def configure_logger(logger):
"""configures a logging.Logger obj """configures a logging.Logger obj
Parameters Parameters
@@ -81,17 +49,20 @@ def configure_logger(logger):
""" """
if not hasattr(logger, "already_configured"): if not hasattr(logger, "already_configured"):
if "file" in log_policy(): print_lvl = lvl_map.get(log_print_level())
file_lvl = lvl_map.get(log_file_level())
if file_lvl is not None:
formatter = logging.Formatter("{levelname}: {name}: {message}", style="{") formatter = logging.Formatter("{levelname}: {name}: {message}", style="{")
file_handler1 = logging.FileHandler("scgenerator.log", "a+") file_handler1 = logging.FileHandler("scgenerator.log", "a+")
file_handler1.setFormatter(formatter) file_handler1.setFormatter(formatter)
file_handler1.setLevel(logging.DEBUG) file_handler1.setLevel(file_lvl)
logger.addHandler(file_handler1) logger.addHandler(file_handler1)
if "print" in log_policy(): if print_lvl is not None:
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO) stream_handler.setLevel(print_lvl)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
logger.setLevel(logging.DEBUG)
logger.setLevel(min(print_lvl, file_lvl))
logger.already_configured = True logger.already_configured = True
return logger return logger

View File

@@ -91,7 +91,7 @@ def u_nm(n, m):
@np_cache @np_cache
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: def ndft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix """creates the nfft matrix
Parameters Parameters
@@ -111,7 +111,7 @@ def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
@np_cache @np_cache
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: def indft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
"""creates the nfft matrix """creates the nfft matrix
Parameters Parameters
@@ -126,10 +126,10 @@ def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, n) np.ndarray, shape = (m, n)
multiply ~X(f) by this matrix to get x(t) multiply ~X(f) by this matrix to get x(t)
""" """
return np.linalg.pinv(nfft_matrix(t, f)) return np.linalg.pinv(ndft_matrix(t, f))
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray: def ndft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
"""computes the Fourier transform of an uneven signal """computes the Fourier transform of an uneven signal
Parameters Parameters
@@ -146,10 +146,10 @@ def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, ) np.ndarray, shape = (m, )
amplitude at each frequency amplitude at each frequency
""" """
return nfft_matrix(t, f) @ s return ndft_matrix(t, f) @ s
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
"""computes the inverse Fourier transform of an uneven spectrum """computes the inverse Fourier transform of an uneven spectrum
Parameters Parameters
@@ -166,7 +166,7 @@ def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
np.ndarray, shape = (m, ) np.ndarray, shape = (m, )
amplitude at each point of t amplitude at each point of t
""" """
return infft_matrix(t, f) @ a return indft_matrix(t, f) @ a
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"): def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):

View File

@@ -130,7 +130,7 @@ def n_eff_marcatili(lambda_, n_gas_2, core_radius, he_mode=(1, 1)):
lambda_ : ndarray, shape (n, ) lambda_ : ndarray, shape (n, )
wavelengths array (m) wavelengths array (m)
n_gas_2 : ndarray, shape (n, ) n_gas_2 : ndarray, shape (n, )
refractive index of the gas as function of lambda_ square of the refractive index of the gas as function of lambda_
core_radius : float core_radius : float
inner radius of the capillary (m) inner radius of the capillary (m)
he_mode : tuple, shape (2, ), optional he_mode : tuple, shape (2, ), optional
@@ -455,7 +455,6 @@ def HCPCF_dispersion(
Temperature of the material Temperature of the material
pressure : float pressure : float
constant pressure constant pressure
FIXME tupple : a pressure gradient from pressure[0] to pressure[1] is computed
Returns Returns
------- -------
@@ -692,15 +691,7 @@ def compute_dispersion(params: BareParams):
) )
else: else:
# Load material info material_dico = toml.loads(io.Paths.gets("gas"))[params.gas_name]
gas_name = params.gas_name
if gas_name == "vacuum":
material_dico = None
else:
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
# compute dispersion
if params.dynamic_dispersion: if params.dynamic_dispersion:
return dynamic_HCPCF_dispersion( return dynamic_HCPCF_dispersion(
lambda_, lambda_,
@@ -716,9 +707,6 @@ def compute_dispersion(params: BareParams):
params.interp_degree, params.interp_degree,
) )
else: else:
# actually compute the dispersion
beta2 = HCPCF_dispersion( beta2 = HCPCF_dispersion(
lambda_, lambda_,
material_dico, material_dico,
@@ -730,10 +718,14 @@ def compute_dispersion(params: BareParams):
) )
if material_dico is not None: if material_dico is not None:
A_eff = 1.5 * params.core_radius ** 2
A_eff = 1.5 * params.core_radius ** 2 if params.A_eff is None else params.A_eff
if params.n2 is None:
n2 = mat.non_linear_refractive_index( n2 = mat.non_linear_refractive_index(
material_dico, params.pressure, params.temperature material_dico, params.pressure, params.temperature
) )
else:
n2 = params.n2
gamma = gamma_parameter(n2, params.w0, A_eff) gamma = gamma_parameter(n2, params.w0, A_eff)
else: else:
gamma = None gamma = None

View File

@@ -39,7 +39,7 @@ def number_density_van_der_waals(
ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution
""" """
logger = get_logger logger = get_logger(__name__)
if pressure == 0: if pressure == 0:
return 0 return 0
@@ -73,7 +73,7 @@ def number_density_van_der_waals(
s = f"Van der Waals eq with parameters P={pressure}, T={temperature}, a={a}, b={b}" s = f"Van der Waals eq with parameters P={pressure}, T={temperature}, a={a}, b={b}"
s += f", There is more than one possible number density : {roots}." s += f", There is more than one possible number density : {roots}."
s += f", {np.min(roots)} was returned" s += f", {np.min(roots)} was returned"
logger.info(s) logger.warning(s)
return np.min(roots) return np.min(roots)
@@ -90,6 +90,8 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
---------- ----------
an array n(lambda_)^2 - 1 an array n(lambda_)^2 - 1
""" """
logger = get_logger(__name__)
WL_THRESHOLD = 8.285e-6 WL_THRESHOLD = 8.285e-6
temp_l = lambda_[lambda_ < WL_THRESHOLD] temp_l = lambda_[lambda_ < WL_THRESHOLD]
kind = 1 kind = 1
@@ -101,7 +103,6 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
t0 = material_dico["sellmeier"].get("t0", 273.15) t0 = material_dico["sellmeier"].get("t0", 273.15)
kind = material_dico["sellmeier"].get("kind", 1) kind = material_dico["sellmeier"].get("kind", 1)
# Sellmeier equation
chi = np.zeros_like(lambda_) # = n^2 - 1 chi = np.zeros_like(lambda_) # = n^2 - 1
if kind == 1: if kind == 1:
for b, c in zip(B, C): for b, c in zip(B, C):
@@ -120,6 +121,7 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
if pressure is not None: if pressure is not None:
chi *= pressure / P0 chi *= pressure / P0
logger.debug(f"computed chi between {np.min(chi):.2e} and {np.max(chi):.2e}")
return chi return chi

View File

@@ -3,11 +3,10 @@ import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Type from typing import Dict, List, Tuple, Type
from typing_extensions import runtime
import numpy as np import numpy as np
from .. import env, initialize, io, utils from .. import const, env, initialize, io, utils
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
from . import pulse from . import pulse
@@ -290,7 +289,6 @@ class SequentialRK4IP(RK4IP):
save_data=False, save_data=False,
job_identifier="", job_identifier="",
task_id=0, task_id=0,
n_percent=10,
): ):
self.pbars = pbars self.pbars = pbars
super().__init__( super().__init__(
@@ -298,7 +296,6 @@ class SequentialRK4IP(RK4IP):
save_data=save_data, save_data=save_data,
job_identifier=job_identifier, job_identifier=job_identifier,
task_id=task_id, task_id=task_id,
n_percent=n_percent,
) )
def step_saved(self): def step_saved(self):
@@ -314,7 +311,6 @@ class MutliProcRK4IP(RK4IP):
save_data=False, save_data=False,
job_identifier="", job_identifier="",
task_id=0, task_id=0,
n_percent=10,
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.p_queue = p_queue self.p_queue = p_queue
@@ -323,7 +319,6 @@ class MutliProcRK4IP(RK4IP):
save_data=save_data, save_data=save_data,
job_identifier=job_identifier, job_identifier=job_identifier,
task_id=task_id, task_id=task_id,
n_percent=n_percent,
) )
def step_saved(self): def step_saved(self):
@@ -512,6 +507,9 @@ class MultiProcSimulations(Simulations, priority=1):
def __init__(self, param_seq: initialize.ParamSequence, task_id): def __init__(self, param_seq: initialize.ParamSequence, task_id):
super().__init__(param_seq, task_id=task_id) super().__init__(param_seq, task_id=task_id)
if param_seq.config.worker_num is not None:
self.sim_jobs_per_node = param_seq.config.worker_num
else:
self.sim_jobs_per_node = max(1, os.cpu_count() // 2) self.sim_jobs_per_node = max(1, os.cpu_count() // 2)
self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node) self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node)
self.progress_queue = multiprocessing.Queue() self.progress_queue = multiprocessing.Queue()
@@ -656,6 +654,8 @@ class RaySimulations(Simulations, priority=2):
@property @property
def sim_jobs_total(self): def sim_jobs_total(self):
if self.param_seq.config.worker_num is not None:
return self.param_seq.config.worker_num
tot_cpus = sum([node.get("Resources", {}).get("CPU", 0) for node in ray.nodes()]) tot_cpus = sum([node.get("Resources", {}).get("CPU", 0) for node in ray.nodes()])
tot_cpus = min(tot_cpus, self.max_concurrent_jobs) tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
return int(min(self.param_seq.num_sim, tot_cpus)) return int(min(self.param_seq.num_sim, tot_cpus))
@@ -664,7 +664,6 @@ class RaySimulations(Simulations, priority=2):
def run_simulation_sequence( def run_simulation_sequence(
*config_files: os.PathLike, *config_files: os.PathLike,
method=None, method=None,
final_name: str = None,
prev_sim_dir: os.PathLike = None, prev_sim_dir: os.PathLike = None,
): ):
prev = prev_sim_dir prev = prev_sim_dir
@@ -674,6 +673,7 @@ def run_simulation_sequence(
prev = sim.sim_dir prev = sim.sim_dir
path_trees = io.build_path_trees(sim.sim_dir) path_trees = io.build_path_trees(sim.sim_dir)
final_name = env.get(const.OUTPUT_PATH)
if final_name is None: if final_name is None:
final_name = path_trees[0][-1][0].parent.name + " merged" final_name = path_trees[0][-1][0].parent.name + " merged"
@@ -687,6 +687,7 @@ def new_simulation(
) -> Simulations: ) -> Simulations:
config_dict = io.load_toml(config_file) config_dict = io.load_toml(config_file)
logger = get_logger(__name__)
if prev_sim_dir is not None: if prev_sim_dir is not None:
config_dict["prev_sim_dir"] = str(prev_sim_dir) config_dict["prev_sim_dir"] = str(prev_sim_dir)
@@ -698,7 +699,7 @@ def new_simulation(
else: else:
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict) param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
print(f"{param_seq.name=}") logger.info(f"running {param_seq.name}")
return Simulations.new(param_seq, task_id, method) return Simulations.new(param_seq, task_id, method)

View File

@@ -422,7 +422,7 @@ def plot_spectrogram(
values = spec[ind_t][:, ind_f] values = spec[ind_t][:, ind_f]
if f_range[2].type == "WL": if f_range[2].type == "WL":
values = np.apply_along_axis( values = np.apply_along_axis(
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f)) units.to_WL, 1, values, params.repetition_rate, units.m(f_range[2].inv(new_f))
) )
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f) values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
@@ -528,7 +528,7 @@ def plot_results_2D(
# make uniform if converting to wavelength # make uniform if converting to wavelength
if plt_range.unit.type == "WL": if plt_range.unit.type == "WL":
if is_spectrum: if is_spectrum:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
values = np.array( values = np.array(
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values] [make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
) )
@@ -648,7 +648,7 @@ def plot_results_1D(
# make uniform if converting to wavelength # make uniform if converting to wavelength
if plt_range.unit.type == "WL": if plt_range.unit.type == "WL":
if is_spectrum: if is_spectrum:
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind])) values = units.to_WL(values, params.repetition_rate, units.m.inv(params.w[ind]))
# change the resolution # change the resolution
if isinstance(spacing, float): if isinstance(spacing, float):
@@ -810,8 +810,8 @@ def plot_avg(
values *= yscaling values *= yscaling
mean_values = np.mean(values, axis=0) mean_values = np.mean(values, axis=0)
if plt_range.unit.type == "WL" and renormalize: if plt_range.unit.type == "WL" and renormalize:
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
mean_values = units.to_WL(mean_values, params.frep, x_axis) mean_values = units.to_WL(mean_values, params.repetition_rate, x_axis)
# change the resolution # change the resolution
if isinstance(spacing, float): if isinstance(spacing, float):

View File

@@ -8,13 +8,9 @@ from pathlib import Path
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
from toml import load
from scgenerator.utils.parameter import BareConfig
from ..initialize import validate_config_sequence from ..initialize import validate_config_sequence
from ..io import Paths, load_config from ..io import Paths, load_config
from ..utils import count_variations
def primes(n): def primes(n):
@@ -90,7 +86,7 @@ def create_parser():
"--environment-setup", "--environment-setup",
required=False, required=False,
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && " default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file", "export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info",
help="commands to run to setup the environement (default : activate the sc environment with conda)", help="commands to run to setup the environement (default : activate the sc environment with conda)",
) )
parser.add_argument( parser.add_argument(

View File

@@ -198,7 +198,7 @@ class Pulse(Sequence):
spec = np.load(self.path / SPECN_FN.format(i)) spec = np.load(self.path / SPECN_FN.format(i))
if self.__ensure_2d: if self.__ensure_2d:
spec = np.atleast_2d(spec) spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.wl, self.params.frep) spec = Spectrum(spec, self.wl, self.params.repetition_rate)
self.cache[i] = spec self.cache[i] = spec
return spec return spec

View File

@@ -187,7 +187,7 @@ def count_variations(config: BareConfig) -> int:
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
variable_params_num is the number of distinct parameters that will vary.""" variable_params_num is the number of distinct parameters that will vary."""
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
return sim_num return int(sim_num)
def format_variable_list(l: List[Tuple[str, Any]]): def format_variable_list(l: List[Tuple[str, Any]]):

View File

@@ -339,7 +339,7 @@ class BareParams:
peak_power: float = Parameter(positive(float, int)) peak_power: float = Parameter(positive(float, int))
mean_power: float = Parameter(positive(float, int)) mean_power: float = Parameter(positive(float, int))
energy: float = Parameter(positive(float, int)) energy: float = Parameter(positive(float, int))
soliton_num: float = Parameter(positive(float, int)) soliton_num: float = Parameter(non_negative(float, int))
quantum_noise: bool = Parameter(boolean) quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech")) shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9)) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
@@ -362,10 +362,10 @@ class BareParams:
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9)) lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9)) upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
interp_degree: int = Parameter(positive(int)) interp_degree: int = Parameter(positive(int))
frep: float = Parameter(positive(float, int))
prev_sim_dir: str = Parameter(string) prev_sim_dir: str = Parameter(string)
readjust_wavelength: bool = Parameter(boolean) readjust_wavelength: bool = Parameter(boolean)
recovery_last_stored: int = Parameter(non_negative(int)) recovery_last_stored: int = Parameter(non_negative(int))
worker_num: int = Parameter(positive(int))
# computed # computed
field_0: np.ndarray = Parameter(type_checker(np.ndarray)) field_0: np.ndarray = Parameter(type_checker(np.ndarray))

View File

@@ -12,7 +12,6 @@ class TestRecoveryParamSequence(unittest.TestCase):
def setUp(self): def setUp(self):
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP) shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
self.conf = toml.load(TMP + "/initial_config.toml") self.conf = toml.load(TMP + "/initial_config.toml")
logger.DEFAULT_LEVEL = logger.logging.FATAL
io.set_data_folder(55, TMP) io.set_data_folder(55, TMP)
def test_remaining_simulations_count(self): def test_remaining_simulations_count(self):