custom worker num, better logging, better config
This commit is contained in:
@@ -4,3 +4,4 @@ from .math import abs2, argclosest, span
|
||||
from .physics import fiber, materials, pulse, simulate, units
|
||||
from .physics.simulate import RK4IP, new_simulation, resume_simulations
|
||||
from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram
|
||||
from .spectra import Pulse
|
||||
|
||||
@@ -2,15 +2,19 @@ import argparse
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from collections import ChainMap
|
||||
|
||||
from scgenerator import io
|
||||
from scgenerator.physics.simulate import (
|
||||
from ray.worker import get
|
||||
|
||||
from .. import io, env, const
|
||||
from ..logger import get_logger
|
||||
from ..physics.simulate import (
|
||||
SequencialSimulations,
|
||||
resume_simulations,
|
||||
run_simulation_sequence,
|
||||
)
|
||||
from scgenerator.physics.fiber import dispersion_coefficients
|
||||
|
||||
from ..physics.fiber import dispersion_coefficients
|
||||
from pprint import pprint
|
||||
|
||||
try:
|
||||
import ray
|
||||
@@ -18,29 +22,26 @@ except ImportError:
|
||||
ray = None
|
||||
|
||||
|
||||
def set_env_variables(cmd_line_args: dict[str, str]):
|
||||
cm = ChainMap(cmd_line_args, os.environ)
|
||||
for env_key in const.global_config:
|
||||
k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower()
|
||||
v = cm.get(k)
|
||||
if v is not None:
|
||||
os.environ[env_key] = str(v)
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator")
|
||||
|
||||
subparsers = parser.add_subparsers(help="sub-command help")
|
||||
|
||||
for key, args in const.global_config.items():
|
||||
names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()]
|
||||
if "short_name" in args:
|
||||
names.append(args["short_name"])
|
||||
parser.add_argument(
|
||||
"--id",
|
||||
type=int,
|
||||
default=random.randint(0, 1e18),
|
||||
help="Unique id of the session. Only useful when running several processes at the same time.",
|
||||
*names, **{k: v for k, v in args.items() if k not in {"short_name", "type"}}
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-ray",
|
||||
action="store_true",
|
||||
help="initialize ray (ray must be installed)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-ray",
|
||||
action="store_true",
|
||||
help="force not to use ray",
|
||||
)
|
||||
parser.add_argument("--output-name", "-o", help="path to the final output folder", default=None)
|
||||
|
||||
run_parser = subparsers.add_parser("run", help="run a simulation from a config file")
|
||||
run_parser.add_argument("configs", help="path(s) to the toml configuration file(s)", nargs="+")
|
||||
@@ -71,15 +72,19 @@ def create_parser():
|
||||
def main():
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
set_env_variables({k: v for k, v in vars(args).items() if v is not None})
|
||||
|
||||
args.func(args)
|
||||
|
||||
print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}")
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"dispersion cache : {dispersion_coefficients.cache_info()}")
|
||||
|
||||
|
||||
def run_sim(args):
|
||||
|
||||
method = prep_ray(args)
|
||||
run_simulation_sequence(*args.configs, method=method, final_name=args.output_name)
|
||||
run_simulation_sequence(*args.configs, method=method)
|
||||
|
||||
|
||||
def merge(args):
|
||||
@@ -91,19 +96,20 @@ def merge(args):
|
||||
|
||||
|
||||
def prep_ray(args):
|
||||
logger = get_logger(__name__)
|
||||
if ray:
|
||||
if args.start_ray:
|
||||
if env.get(const.START_RAY):
|
||||
init_str = ray.init()
|
||||
elif not args.no_ray:
|
||||
elif not env.get(const.NO_RAY):
|
||||
try:
|
||||
init_str = ray.init(
|
||||
address="auto",
|
||||
_redis_password=os.environ.get("redis_password", "caco1234"),
|
||||
)
|
||||
print(init_str)
|
||||
logger.info(init_str)
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return SequencialSimulations if args.no_ray else None
|
||||
logger.error(e)
|
||||
return SequencialSimulations if env.get(const.NO_RAY) else None
|
||||
|
||||
|
||||
def resume_sim(args):
|
||||
@@ -111,9 +117,7 @@ def resume_sim(args):
|
||||
method = prep_ray(args)
|
||||
sim = resume_simulations(Path(args.sim_dir), method=method)
|
||||
sim.run()
|
||||
run_simulation_sequence(
|
||||
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name
|
||||
)
|
||||
run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
__version__ = "0.1.0"
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def pbar_format(worker_id: int):
|
||||
if worker_id == 0:
|
||||
return dict(
|
||||
@@ -17,12 +20,45 @@ def pbar_format(worker_id: int):
|
||||
|
||||
|
||||
ENVIRON_KEY_BASE = "SCGENERATOR_"
|
||||
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
|
||||
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"
|
||||
TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_"
|
||||
PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_"
|
||||
PARAM_SEPARATOR = " "
|
||||
|
||||
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
|
||||
LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL"
|
||||
LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL"
|
||||
START_RAY = ENVIRON_KEY_BASE + "START_RAY"
|
||||
NO_RAY = ENVIRON_KEY_BASE + "NO_RAY"
|
||||
OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH"
|
||||
|
||||
|
||||
global_config: dict[str, dict[str, Any]] = {
|
||||
LOG_FILE_LEVEL: dict(
|
||||
help="minimum lvl of message to be saved in the log file",
|
||||
choices=["critical", "error", "warning", "info", "debug"],
|
||||
default=None,
|
||||
type=str,
|
||||
),
|
||||
LOG_PRINT_LEVEL: dict(
|
||||
help="minimum lvl of message to be printed to the standard output",
|
||||
choices=["critical", "error", "warning", "info", "debug"],
|
||||
default="error",
|
||||
type=str,
|
||||
),
|
||||
PBAR_POLICY: dict(
|
||||
help="what to do with progress pars (print them, make them a txt file or nothing), default is print",
|
||||
choices=["print", "file", "both", "none"],
|
||||
default=None,
|
||||
type=str,
|
||||
),
|
||||
START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool),
|
||||
NO_RAY: dict(action="store_true", help="force not to use ray", type=bool),
|
||||
OUTPUT_PATH: dict(
|
||||
short_name="-o", help="path to the final output folder", default=None, type=str
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SPEC1_FN = "spectrum_{}.npy"
|
||||
SPECN_FN = "spectra_{}.npy"
|
||||
Z_FN = "z.npy"
|
||||
|
||||
@@ -38,6 +38,14 @@ b = 3.978e-5
|
||||
a = 0.425
|
||||
b = 5.105e-5
|
||||
|
||||
[air]
|
||||
a = 0.1358
|
||||
b = 3.64e-5
|
||||
|
||||
[vacuum]
|
||||
a = 0
|
||||
b = 0
|
||||
|
||||
[air.sellmeier]
|
||||
B = [57921050000.0, 1679170000.0]
|
||||
C = [238018500000000.0, 57362000000000.0]
|
||||
@@ -45,6 +53,11 @@ P0 = 101325
|
||||
T0 = 288.15
|
||||
kind = 2
|
||||
|
||||
[air.kerr]
|
||||
P0 = 101325
|
||||
T0 = 293.15
|
||||
n2 = 3.01e-23
|
||||
|
||||
[nitrogen.sellmeier]
|
||||
B = [32431570000.0]
|
||||
C = [144000000000000.0]
|
||||
@@ -181,3 +194,16 @@ P0 = 30400.0
|
||||
T0 = 273.15
|
||||
n2 = 5.8e-23
|
||||
source = "Wahlstrand, J. K., Cheng, Y. H., & Milchberg, H. M. (2012). High field optical nonlinearity and the Kramers-Kronig relations. Physical review letters, 109(11), 113904."
|
||||
|
||||
[vacuum.sellmeier]
|
||||
B = []
|
||||
C = []
|
||||
P0 = 101325
|
||||
T0 = 273.15
|
||||
kind = 1
|
||||
|
||||
[vacuum.kerr]
|
||||
P0 = 30400.0
|
||||
T0 = 273.15
|
||||
n2 = 0
|
||||
source = "none"
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import os
|
||||
from typing import Dict, Literal, Optional, Set
|
||||
from typing import Any, Dict, Literal, Optional, Set
|
||||
|
||||
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE
|
||||
from .const import (
|
||||
ENVIRON_KEY_BASE,
|
||||
LOG_FILE_LEVEL,
|
||||
LOG_PRINT_LEVEL,
|
||||
PBAR_POLICY,
|
||||
TMP_FOLDER_KEY_BASE,
|
||||
global_config,
|
||||
)
|
||||
|
||||
|
||||
def data_folder(task_id: int) -> Optional[str]:
|
||||
@@ -10,6 +17,14 @@ def data_folder(task_id: int) -> Optional[str]:
|
||||
return tmp
|
||||
|
||||
|
||||
def get(key: str) -> Any:
|
||||
str_value = os.environ.get(key)
|
||||
try:
|
||||
return global_config[key]["type"](str_value)
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def all_environ() -> Dict[str, str]:
|
||||
"""returns a dictionary of all environment variables set by any instance of scgenerator"""
|
||||
d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items()))
|
||||
@@ -17,7 +32,7 @@ def all_environ() -> Dict[str, str]:
|
||||
|
||||
|
||||
def pbar_policy() -> Set[Literal["print", "file"]]:
|
||||
policy = os.getenv(PBAR_POLICY)
|
||||
policy = get(PBAR_POLICY)
|
||||
if policy == "print" or policy is None:
|
||||
return {"print"}
|
||||
elif policy == "file":
|
||||
@@ -28,13 +43,23 @@ def pbar_policy() -> Set[Literal["print", "file"]]:
|
||||
return set()
|
||||
|
||||
|
||||
def log_policy() -> Set[Literal["print", "file"]]:
|
||||
policy = os.getenv(LOG_POLICY)
|
||||
if policy == "print" or policy is None:
|
||||
return {"print"}
|
||||
elif policy == "file":
|
||||
return {"file"}
|
||||
elif policy == "both":
|
||||
return {"file", "print"}
|
||||
else:
|
||||
return set()
|
||||
def log_file_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
|
||||
policy = get(LOG_FILE_LEVEL)
|
||||
try:
|
||||
policy = policy.lower()
|
||||
if policy in {"critical", "error", "warning", "info", "debug"}:
|
||||
return policy
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def log_print_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]:
|
||||
policy = get(LOG_PRINT_LEVEL)
|
||||
try:
|
||||
policy = policy.lower()
|
||||
if policy in {"critical", "error", "warning", "info", "debug"}:
|
||||
return policy
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
@@ -212,7 +212,6 @@ class Config(BareConfig):
|
||||
for param in [
|
||||
"behaviors",
|
||||
"z_num",
|
||||
"frep",
|
||||
"tolerated_error",
|
||||
"parallel",
|
||||
"repeat",
|
||||
|
||||
@@ -1,37 +1,16 @@
|
||||
import logging
|
||||
|
||||
from .env import log_policy
|
||||
|
||||
# class DebugOnlyFileHandler(logging.FileHandler):
|
||||
# def __init__(
|
||||
# self, filename, mode: str, encoding: Optional[str] = None, delay: bool = False
|
||||
# ) -> None:
|
||||
# super().__init__(filename, mode=mode, encoding=encoding, delay=delay)
|
||||
# self.setLevel(logging.DEBUG)
|
||||
|
||||
# def emit(self, record: logging.LogRecord) -> None:
|
||||
# if not record.levelno == logging.DEBUG:
|
||||
# return
|
||||
# return super().emit(record)
|
||||
from .env import log_file_level, log_print_level
|
||||
|
||||
|
||||
DEFAULT_LEVEL = logging.INFO
|
||||
|
||||
lvl_map = dict(
|
||||
lvl_map: dict[str, int] = dict(
|
||||
debug=logging.DEBUG,
|
||||
info=logging.INFO,
|
||||
warning=logging.WARNING,
|
||||
error=logging.ERROR,
|
||||
fatal=logging.FATAL,
|
||||
critical=logging.CRITICAL,
|
||||
)
|
||||
|
||||
loggers = []
|
||||
|
||||
|
||||
def _set_debug():
|
||||
DEFAULT_LEVEL = logging.DEBUG
|
||||
|
||||
|
||||
def get_logger(name=None):
|
||||
"""returns a logging.Logger instance. This function is there because if scgenerator
|
||||
@@ -50,21 +29,10 @@ def get_logger(name=None):
|
||||
"""
|
||||
name = __name__ if name is None else name
|
||||
logger = logging.getLogger(name)
|
||||
if name not in loggers:
|
||||
loggers.append(logger)
|
||||
return configure_logger(logger)
|
||||
|
||||
|
||||
# def set_level_all(lvl):
|
||||
# _default_lvl =
|
||||
# logging.basicConfig(level=lvl_map[lvl])
|
||||
# for logger in loggers:
|
||||
# logger.setLevel(lvl_map[lvl])
|
||||
# for handler in logger.handlers:
|
||||
# handler.setLevel(lvl_map[lvl])
|
||||
|
||||
|
||||
def configure_logger(logger):
|
||||
def configure_logger(logger: logging.Logger):
|
||||
"""configures a logging.Logger obj
|
||||
|
||||
Parameters
|
||||
@@ -81,17 +49,20 @@ def configure_logger(logger):
|
||||
"""
|
||||
if not hasattr(logger, "already_configured"):
|
||||
|
||||
if "file" in log_policy():
|
||||
print_lvl = lvl_map.get(log_print_level())
|
||||
file_lvl = lvl_map.get(log_file_level())
|
||||
|
||||
if file_lvl is not None:
|
||||
formatter = logging.Formatter("{levelname}: {name}: {message}", style="{")
|
||||
file_handler1 = logging.FileHandler("scgenerator.log", "a+")
|
||||
file_handler1.setFormatter(formatter)
|
||||
file_handler1.setLevel(logging.DEBUG)
|
||||
file_handler1.setLevel(file_lvl)
|
||||
logger.addHandler(file_handler1)
|
||||
if "print" in log_policy():
|
||||
if print_lvl is not None:
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setLevel(logging.INFO)
|
||||
stream_handler.setLevel(print_lvl)
|
||||
logger.addHandler(stream_handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.setLevel(min(print_lvl, file_lvl))
|
||||
logger.already_configured = True
|
||||
return logger
|
||||
|
||||
@@ -91,7 +91,7 @@ def u_nm(n, m):
|
||||
|
||||
|
||||
@np_cache
|
||||
def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
def ndft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""creates the nfft matrix
|
||||
|
||||
Parameters
|
||||
@@ -111,7 +111,7 @@ def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
@np_cache
|
||||
def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
def indft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""creates the nfft matrix
|
||||
|
||||
Parameters
|
||||
@@ -126,10 +126,10 @@ def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
np.ndarray, shape = (m, n)
|
||||
multiply ~X(f) by this matrix to get x(t)
|
||||
"""
|
||||
return np.linalg.pinv(nfft_matrix(t, f))
|
||||
return np.linalg.pinv(ndft_matrix(t, f))
|
||||
|
||||
|
||||
def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
def ndft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
"""computes the Fourier transform of an uneven signal
|
||||
|
||||
Parameters
|
||||
@@ -146,10 +146,10 @@ def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray:
|
||||
np.ndarray, shape = (m, )
|
||||
amplitude at each frequency
|
||||
"""
|
||||
return nfft_matrix(t, f) @ s
|
||||
return ndft_matrix(t, f) @ s
|
||||
|
||||
|
||||
def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||
def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||
"""computes the inverse Fourier transform of an uneven spectrum
|
||||
|
||||
Parameters
|
||||
@@ -166,7 +166,7 @@ def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray:
|
||||
np.ndarray, shape = (m, )
|
||||
amplitude at each point of t
|
||||
"""
|
||||
return infft_matrix(t, f) @ a
|
||||
return indft_matrix(t, f) @ a
|
||||
|
||||
|
||||
def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"):
|
||||
|
||||
@@ -130,7 +130,7 @@ def n_eff_marcatili(lambda_, n_gas_2, core_radius, he_mode=(1, 1)):
|
||||
lambda_ : ndarray, shape (n, )
|
||||
wavelengths array (m)
|
||||
n_gas_2 : ndarray, shape (n, )
|
||||
refractive index of the gas as function of lambda_
|
||||
square of the refractive index of the gas as function of lambda_
|
||||
core_radius : float
|
||||
inner radius of the capillary (m)
|
||||
he_mode : tuple, shape (2, ), optional
|
||||
@@ -455,7 +455,6 @@ def HCPCF_dispersion(
|
||||
Temperature of the material
|
||||
pressure : float
|
||||
constant pressure
|
||||
FIXME tupple : a pressure gradient from pressure[0] to pressure[1] is computed
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -692,15 +691,7 @@ def compute_dispersion(params: BareParams):
|
||||
)
|
||||
|
||||
else:
|
||||
# Load material info
|
||||
gas_name = params.gas_name
|
||||
|
||||
if gas_name == "vacuum":
|
||||
material_dico = None
|
||||
else:
|
||||
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
|
||||
|
||||
# compute dispersion
|
||||
material_dico = toml.loads(io.Paths.gets("gas"))[params.gas_name]
|
||||
if params.dynamic_dispersion:
|
||||
return dynamic_HCPCF_dispersion(
|
||||
lambda_,
|
||||
@@ -716,9 +707,6 @@ def compute_dispersion(params: BareParams):
|
||||
params.interp_degree,
|
||||
)
|
||||
else:
|
||||
|
||||
# actually compute the dispersion
|
||||
|
||||
beta2 = HCPCF_dispersion(
|
||||
lambda_,
|
||||
material_dico,
|
||||
@@ -730,10 +718,14 @@ def compute_dispersion(params: BareParams):
|
||||
)
|
||||
|
||||
if material_dico is not None:
|
||||
A_eff = 1.5 * params.core_radius ** 2
|
||||
|
||||
A_eff = 1.5 * params.core_radius ** 2 if params.A_eff is None else params.A_eff
|
||||
if params.n2 is None:
|
||||
n2 = mat.non_linear_refractive_index(
|
||||
material_dico, params.pressure, params.temperature
|
||||
)
|
||||
else:
|
||||
n2 = params.n2
|
||||
gamma = gamma_parameter(n2, params.w0, A_eff)
|
||||
else:
|
||||
gamma = None
|
||||
|
||||
@@ -39,7 +39,7 @@ def number_density_van_der_waals(
|
||||
ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution
|
||||
"""
|
||||
|
||||
logger = get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if pressure == 0:
|
||||
return 0
|
||||
@@ -73,7 +73,7 @@ def number_density_van_der_waals(
|
||||
s = f"Van der Waals eq with parameters P={pressure}, T={temperature}, a={a}, b={b}"
|
||||
s += f", There is more than one possible number density : {roots}."
|
||||
s += f", {np.min(roots)} was returned"
|
||||
logger.info(s)
|
||||
logger.warning(s)
|
||||
return np.min(roots)
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
|
||||
----------
|
||||
an array n(lambda_)^2 - 1
|
||||
"""
|
||||
logger = get_logger(__name__)
|
||||
|
||||
WL_THRESHOLD = 8.285e-6
|
||||
temp_l = lambda_[lambda_ < WL_THRESHOLD]
|
||||
kind = 1
|
||||
@@ -101,7 +103,6 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
|
||||
t0 = material_dico["sellmeier"].get("t0", 273.15)
|
||||
kind = material_dico["sellmeier"].get("kind", 1)
|
||||
|
||||
# Sellmeier equation
|
||||
chi = np.zeros_like(lambda_) # = n^2 - 1
|
||||
if kind == 1:
|
||||
for b, c in zip(B, C):
|
||||
@@ -120,6 +121,7 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None):
|
||||
if pressure is not None:
|
||||
chi *= pressure / P0
|
||||
|
||||
logger.debug(f"computed chi between {np.min(chi):.2e} and {np.max(chi):.2e}")
|
||||
return chi
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,10 @@ import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing_extensions import runtime
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import env, initialize, io, utils
|
||||
from .. import const, env, initialize, io, utils
|
||||
from ..errors import IncompleteDataFolderError
|
||||
from ..logger import get_logger
|
||||
from . import pulse
|
||||
@@ -290,7 +289,6 @@ class SequentialRK4IP(RK4IP):
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
n_percent=10,
|
||||
):
|
||||
self.pbars = pbars
|
||||
super().__init__(
|
||||
@@ -298,7 +296,6 @@ class SequentialRK4IP(RK4IP):
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
n_percent=n_percent,
|
||||
)
|
||||
|
||||
def step_saved(self):
|
||||
@@ -314,7 +311,6 @@ class MutliProcRK4IP(RK4IP):
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
n_percent=10,
|
||||
):
|
||||
self.worker_id = worker_id
|
||||
self.p_queue = p_queue
|
||||
@@ -323,7 +319,6 @@ class MutliProcRK4IP(RK4IP):
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
n_percent=n_percent,
|
||||
)
|
||||
|
||||
def step_saved(self):
|
||||
@@ -512,6 +507,9 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
|
||||
def __init__(self, param_seq: initialize.ParamSequence, task_id):
|
||||
super().__init__(param_seq, task_id=task_id)
|
||||
if param_seq.config.worker_num is not None:
|
||||
self.sim_jobs_per_node = param_seq.config.worker_num
|
||||
else:
|
||||
self.sim_jobs_per_node = max(1, os.cpu_count() // 2)
|
||||
self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node)
|
||||
self.progress_queue = multiprocessing.Queue()
|
||||
@@ -656,6 +654,8 @@ class RaySimulations(Simulations, priority=2):
|
||||
|
||||
@property
|
||||
def sim_jobs_total(self):
|
||||
if self.param_seq.config.worker_num is not None:
|
||||
return self.param_seq.config.worker_num
|
||||
tot_cpus = sum([node.get("Resources", {}).get("CPU", 0) for node in ray.nodes()])
|
||||
tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
|
||||
return int(min(self.param_seq.num_sim, tot_cpus))
|
||||
@@ -664,7 +664,6 @@ class RaySimulations(Simulations, priority=2):
|
||||
def run_simulation_sequence(
|
||||
*config_files: os.PathLike,
|
||||
method=None,
|
||||
final_name: str = None,
|
||||
prev_sim_dir: os.PathLike = None,
|
||||
):
|
||||
prev = prev_sim_dir
|
||||
@@ -674,6 +673,7 @@ def run_simulation_sequence(
|
||||
prev = sim.sim_dir
|
||||
path_trees = io.build_path_trees(sim.sim_dir)
|
||||
|
||||
final_name = env.get(const.OUTPUT_PATH)
|
||||
if final_name is None:
|
||||
final_name = path_trees[0][-1][0].parent.name + " merged"
|
||||
|
||||
@@ -687,6 +687,7 @@ def new_simulation(
|
||||
) -> Simulations:
|
||||
|
||||
config_dict = io.load_toml(config_file)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
if prev_sim_dir is not None:
|
||||
config_dict["prev_sim_dir"] = str(prev_sim_dir)
|
||||
@@ -698,7 +699,7 @@ def new_simulation(
|
||||
else:
|
||||
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
|
||||
|
||||
print(f"{param_seq.name=}")
|
||||
logger.info(f"running {param_seq.name}")
|
||||
|
||||
return Simulations.new(param_seq, task_id, method)
|
||||
|
||||
|
||||
@@ -422,7 +422,7 @@ def plot_spectrogram(
|
||||
values = spec[ind_t][:, ind_f]
|
||||
if f_range[2].type == "WL":
|
||||
values = np.apply_along_axis(
|
||||
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f))
|
||||
units.to_WL, 1, values, params.repetition_rate, units.m(f_range[2].inv(new_f))
|
||||
)
|
||||
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
|
||||
|
||||
@@ -528,7 +528,7 @@ def plot_results_2D(
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range.unit.type == "WL":
|
||||
if is_spectrum:
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
|
||||
values = np.array(
|
||||
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
|
||||
)
|
||||
@@ -648,7 +648,7 @@ def plot_results_1D(
|
||||
# make uniform if converting to wavelength
|
||||
if plt_range.unit.type == "WL":
|
||||
if is_spectrum:
|
||||
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
|
||||
values = units.to_WL(values, params.repetition_rate, units.m.inv(params.w[ind]))
|
||||
|
||||
# change the resolution
|
||||
if isinstance(spacing, float):
|
||||
@@ -810,8 +810,8 @@ def plot_avg(
|
||||
values *= yscaling
|
||||
mean_values = np.mean(values, axis=0)
|
||||
if plt_range.unit.type == "WL" and renormalize:
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
|
||||
mean_values = units.to_WL(mean_values, params.frep, x_axis)
|
||||
values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis)
|
||||
mean_values = units.to_WL(mean_values, params.repetition_rate, x_axis)
|
||||
|
||||
# change the resolution
|
||||
if isinstance(spacing, float):
|
||||
|
||||
@@ -8,13 +8,9 @@ from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from toml import load
|
||||
|
||||
from scgenerator.utils.parameter import BareConfig
|
||||
|
||||
from ..initialize import validate_config_sequence
|
||||
from ..io import Paths, load_config
|
||||
from ..utils import count_variations
|
||||
|
||||
|
||||
def primes(n):
|
||||
@@ -90,7 +86,7 @@ def create_parser():
|
||||
"--environment-setup",
|
||||
required=False,
|
||||
default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && "
|
||||
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file",
|
||||
"export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info",
|
||||
help="commands to run to setup the environement (default : activate the sc environment with conda)",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -198,7 +198,7 @@ class Pulse(Sequence):
|
||||
spec = np.load(self.path / SPECN_FN.format(i))
|
||||
if self.__ensure_2d:
|
||||
spec = np.atleast_2d(spec)
|
||||
spec = Spectrum(spec, self.wl, self.params.frep)
|
||||
spec = Spectrum(spec, self.wl, self.params.repetition_rate)
|
||||
self.cache[i] = spec
|
||||
return spec
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ def count_variations(config: BareConfig) -> int:
|
||||
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
|
||||
variable_params_num is the number of distinct parameters that will vary."""
|
||||
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
|
||||
return sim_num
|
||||
return int(sim_num)
|
||||
|
||||
|
||||
def format_variable_list(l: List[Tuple[str, Any]]):
|
||||
|
||||
@@ -339,7 +339,7 @@ class BareParams:
|
||||
peak_power: float = Parameter(positive(float, int))
|
||||
mean_power: float = Parameter(positive(float, int))
|
||||
energy: float = Parameter(positive(float, int))
|
||||
soliton_num: float = Parameter(positive(float, int))
|
||||
soliton_num: float = Parameter(non_negative(float, int))
|
||||
quantum_noise: bool = Parameter(boolean)
|
||||
shape: str = Parameter(literal("gaussian", "sech"))
|
||||
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
|
||||
@@ -362,10 +362,10 @@ class BareParams:
|
||||
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
|
||||
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
|
||||
interp_degree: int = Parameter(positive(int))
|
||||
frep: float = Parameter(positive(float, int))
|
||||
prev_sim_dir: str = Parameter(string)
|
||||
readjust_wavelength: bool = Parameter(boolean)
|
||||
recovery_last_stored: int = Parameter(non_negative(int))
|
||||
worker_num: int = Parameter(positive(int))
|
||||
|
||||
# computed
|
||||
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
|
||||
@@ -12,7 +12,6 @@ class TestRecoveryParamSequence(unittest.TestCase):
|
||||
def setUp(self):
|
||||
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
|
||||
self.conf = toml.load(TMP + "/initial_config.toml")
|
||||
logger.DEFAULT_LEVEL = logger.logging.FATAL
|
||||
io.set_data_folder(55, TMP)
|
||||
|
||||
def test_remaining_simulations_count(self):
|
||||
|
||||
Reference in New Issue
Block a user