From 1dc90439fe8baffd1f41fb0601f3a58d8808055d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 17 Jun 2021 10:41:06 +0200 Subject: [PATCH] custom worker num, better logging, better config --- src/scgenerator/__init__.py | 1 + src/scgenerator/cli/cli.py | 70 ++++++++++--------- src/scgenerator/const.py | 40 ++++++++++- src/scgenerator/data/gas.toml | 26 +++++++ src/scgenerator/env.py | 51 ++++++++++---- src/scgenerator/initialize.py | 1 - src/scgenerator/logger.py | 51 +++----------- src/scgenerator/math.py | 14 ++-- src/scgenerator/physics/fiber.py | 28 +++----- src/scgenerator/physics/materials.py | 8 ++- src/scgenerator/physics/simulate.py | 19 ++--- src/scgenerator/plotting.py | 10 +-- src/scgenerator/scripts/slurm_submit.py | 6 +- src/scgenerator/spectra.py | 2 +- src/scgenerator/utils/__init__.py | 2 +- src/scgenerator/utils/parameter.py | 4 +- testing/long_tests/test_recovery_param_seq.py | 1 - 17 files changed, 193 insertions(+), 141 deletions(-) diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 827cad8..aa16c7d 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -4,3 +4,4 @@ from .math import abs2, argclosest, span from .physics import fiber, materials, pulse, simulate, units from .physics.simulate import RK4IP, new_simulation, resume_simulations from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram +from .spectra import Pulse diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index dbd9f2d..f4636c6 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -2,15 +2,19 @@ import argparse import os import random from pathlib import Path +from collections import ChainMap -from scgenerator import io -from scgenerator.physics.simulate import ( +from ray.worker import get + +from .. import io, env, const +from ..logger import get_logger +from ..physics.simulate import ( SequencialSimulations, resume_simulations, run_simulation_sequence, ) -from scgenerator.physics.fiber import dispersion_coefficients - +from ..physics.fiber import dispersion_coefficients +from pprint import pprint try: import ray @@ -18,29 +22,26 @@ except ImportError: ray = None +def set_env_variables(cmd_line_args: dict[str, str]): + cm = ChainMap(cmd_line_args, os.environ) + for env_key in const.global_config: + k = env_key.replace(const.ENVIRON_KEY_BASE, "").lower() + v = cm.get(k) + if v is not None: + os.environ[env_key] = str(v) + + def create_parser(): parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator") - subparsers = parser.add_subparsers(help="sub-command help") - parser.add_argument( - "--id", - type=int, - default=random.randint(0, 1e18), - help="Unique id of the session. Only useful when running several processes at the same time.", - ) - parser.add_argument( - "--start-ray", - action="store_true", - help="initialize ray (ray must be installed)", - ) - - parser.add_argument( - "--no-ray", - action="store_true", - help="force not to use ray", - ) - parser.add_argument("--output-name", "-o", help="path to the final output folder", default=None) + for key, args in const.global_config.items(): + names = ["--" + key.replace(const.ENVIRON_KEY_BASE, "").replace("_", "-").lower()] + if "short_name" in args: + names.append(args["short_name"]) + parser.add_argument( + *names, **{k: v for k, v in args.items() if k not in {"short_name", "type"}} + ) run_parser = subparsers.add_parser("run", help="run a simulation from a config file") run_parser.add_argument("configs", help="path(s) to the toml configuration file(s)", nargs="+") @@ -71,15 +72,19 @@ def create_parser(): def main(): parser = create_parser() args = parser.parse_args() + + set_env_variables({k: v for k, v in vars(args).items() if v is not None}) + args.func(args) - print(f"coef hits : {dispersion_coefficients.hits}, misses : {dispersion_coefficients.misses}") + logger = get_logger(__name__) + logger.info(f"dispersion cache : {dispersion_coefficients.cache_info()}") def run_sim(args): method = prep_ray(args) - run_simulation_sequence(*args.configs, method=method, final_name=args.output_name) + run_simulation_sequence(*args.configs, method=method) def merge(args): @@ -91,19 +96,20 @@ def merge(args): def prep_ray(args): + logger = get_logger(__name__) if ray: - if args.start_ray: + if env.get(const.START_RAY): init_str = ray.init() - elif not args.no_ray: + elif not env.get(const.NO_RAY): try: init_str = ray.init( address="auto", _redis_password=os.environ.get("redis_password", "caco1234"), ) - print(init_str) + logger.info(init_str) except ConnectionError as e: - print(e) - return SequencialSimulations if args.no_ray else None + logger.error(e) + return SequencialSimulations if env.get(const.NO_RAY) else None def resume_sim(args): @@ -111,9 +117,7 @@ def resume_sim(args): method = prep_ray(args) sim = resume_simulations(Path(args.sim_dir), method=method) sim.run() - run_simulation_sequence( - *args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name - ) + run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir) if __name__ == "__main__": diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index 7e24687..ead019e 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -1,6 +1,9 @@ __version__ = "0.1.0" +from typing import Any + + def pbar_format(worker_id: int): if worker_id == 0: return dict( @@ -17,12 +20,45 @@ def pbar_format(worker_id: int): ENVIRON_KEY_BASE = "SCGENERATOR_" -PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY" -LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY" TMP_FOLDER_KEY_BASE = ENVIRON_KEY_BASE + "SC_TMP_" PREFIX_KEY_BASE = ENVIRON_KEY_BASE + "PREFIX_" PARAM_SEPARATOR = " " +PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY" +LOG_FILE_LEVEL = ENVIRON_KEY_BASE + "LOG_FILE_LEVEL" +LOG_PRINT_LEVEL = ENVIRON_KEY_BASE + "LOG_PRINT_LEVEL" +START_RAY = ENVIRON_KEY_BASE + "START_RAY" +NO_RAY = ENVIRON_KEY_BASE + "NO_RAY" +OUTPUT_PATH = ENVIRON_KEY_BASE + "OUTPUT_PATH" + + +global_config: dict[str, dict[str, Any]] = { + LOG_FILE_LEVEL: dict( + help="minimum lvl of message to be saved in the log file", + choices=["critical", "error", "warning", "info", "debug"], + default=None, + type=str, + ), + LOG_PRINT_LEVEL: dict( + help="minimum lvl of message to be printed to the standard output", + choices=["critical", "error", "warning", "info", "debug"], + default="error", + type=str, + ), + PBAR_POLICY: dict( + help="what to do with progress pars (print them, make them a txt file or nothing), default is print", + choices=["print", "file", "both", "none"], + default=None, + type=str, + ), + START_RAY: dict(action="store_true", help="initialize ray (ray must be installed)", type=bool), + NO_RAY: dict(action="store_true", help="force not to use ray", type=bool), + OUTPUT_PATH: dict( + short_name="-o", help="path to the final output folder", default=None, type=str + ), +} + + SPEC1_FN = "spectrum_{}.npy" SPECN_FN = "spectra_{}.npy" Z_FN = "z.npy" diff --git a/src/scgenerator/data/gas.toml b/src/scgenerator/data/gas.toml index 977289a..1e61685 100644 --- a/src/scgenerator/data/gas.toml +++ b/src/scgenerator/data/gas.toml @@ -38,6 +38,14 @@ b = 3.978e-5 a = 0.425 b = 5.105e-5 +[air] +a = 0.1358 +b = 3.64e-5 + +[vacuum] +a = 0 +b = 0 + [air.sellmeier] B = [57921050000.0, 1679170000.0] C = [238018500000000.0, 57362000000000.0] @@ -45,6 +53,11 @@ P0 = 101325 T0 = 288.15 kind = 2 +[air.kerr] +P0 = 101325 +T0 = 293.15 +n2 = 3.01e-23 + [nitrogen.sellmeier] B = [32431570000.0] C = [144000000000000.0] @@ -181,3 +194,16 @@ P0 = 30400.0 T0 = 273.15 n2 = 5.8e-23 source = "Wahlstrand, J. K., Cheng, Y. H., & Milchberg, H. M. (2012). High field optical nonlinearity and the Kramers-Kronig relations. Physical review letters, 109(11), 113904." + +[vacuum.sellmeier] +B = [] +C = [] +P0 = 101325 +T0 = 273.15 +kind = 1 + +[vacuum.kerr] +P0 = 30400.0 +T0 = 273.15 +n2 = 0 +source = "none" diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 3cc445b..0f6ff82 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -1,7 +1,14 @@ import os -from typing import Dict, Literal, Optional, Set +from typing import Any, Dict, Literal, Optional, Set -from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE +from .const import ( + ENVIRON_KEY_BASE, + LOG_FILE_LEVEL, + LOG_PRINT_LEVEL, + PBAR_POLICY, + TMP_FOLDER_KEY_BASE, + global_config, +) def data_folder(task_id: int) -> Optional[str]: @@ -10,6 +17,14 @@ def data_folder(task_id: int) -> Optional[str]: return tmp +def get(key: str) -> Any: + str_value = os.environ.get(key) + try: + return global_config[key]["type"](str_value) + except (ValueError, KeyError): + return None + + def all_environ() -> Dict[str, str]: """returns a dictionary of all environment variables set by any instance of scgenerator""" d = dict(filter(lambda el: el[0].startswith(ENVIRON_KEY_BASE), os.environ.items())) @@ -17,7 +32,7 @@ def all_environ() -> Dict[str, str]: def pbar_policy() -> Set[Literal["print", "file"]]: - policy = os.getenv(PBAR_POLICY) + policy = get(PBAR_POLICY) if policy == "print" or policy is None: return {"print"} elif policy == "file": @@ -28,13 +43,23 @@ def pbar_policy() -> Set[Literal["print", "file"]]: return set() -def log_policy() -> Set[Literal["print", "file"]]: - policy = os.getenv(LOG_POLICY) - if policy == "print" or policy is None: - return {"print"} - elif policy == "file": - return {"file"} - elif policy == "both": - return {"file", "print"} - else: - return set() +def log_file_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]: + policy = get(LOG_FILE_LEVEL) + try: + policy = policy.lower() + if policy in {"critical", "error", "warning", "info", "debug"}: + return policy + except AttributeError: + pass + return None + + +def log_print_level() -> Set[Literal["critical", "error", "warning", "info", "debug"]]: + policy = get(LOG_PRINT_LEVEL) + try: + policy = policy.lower() + if policy in {"critical", "error", "warning", "info", "debug"}: + return policy + except AttributeError: + pass + return None diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 6a7f3ab..a3d726a 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -212,7 +212,6 @@ class Config(BareConfig): for param in [ "behaviors", "z_num", - "frep", "tolerated_error", "parallel", "repeat", diff --git a/src/scgenerator/logger.py b/src/scgenerator/logger.py index 50f5299..04de7fb 100644 --- a/src/scgenerator/logger.py +++ b/src/scgenerator/logger.py @@ -1,37 +1,16 @@ import logging -from .env import log_policy - -# class DebugOnlyFileHandler(logging.FileHandler): -# def __init__( -# self, filename, mode: str, encoding: Optional[str] = None, delay: bool = False -# ) -> None: -# super().__init__(filename, mode=mode, encoding=encoding, delay=delay) -# self.setLevel(logging.DEBUG) - -# def emit(self, record: logging.LogRecord) -> None: -# if not record.levelno == logging.DEBUG: -# return -# return super().emit(record) +from .env import log_file_level, log_print_level -DEFAULT_LEVEL = logging.INFO - -lvl_map = dict( +lvl_map: dict[str, int] = dict( debug=logging.DEBUG, info=logging.INFO, warning=logging.WARNING, error=logging.ERROR, - fatal=logging.FATAL, critical=logging.CRITICAL, ) -loggers = [] - - -def _set_debug(): - DEFAULT_LEVEL = logging.DEBUG - def get_logger(name=None): """returns a logging.Logger instance. This function is there because if scgenerator @@ -50,21 +29,10 @@ def get_logger(name=None): """ name = __name__ if name is None else name logger = logging.getLogger(name) - if name not in loggers: - loggers.append(logger) return configure_logger(logger) -# def set_level_all(lvl): -# _default_lvl = -# logging.basicConfig(level=lvl_map[lvl]) -# for logger in loggers: -# logger.setLevel(lvl_map[lvl]) -# for handler in logger.handlers: -# handler.setLevel(lvl_map[lvl]) - - -def configure_logger(logger): +def configure_logger(logger: logging.Logger): """configures a logging.Logger obj Parameters @@ -81,17 +49,20 @@ def configure_logger(logger): """ if not hasattr(logger, "already_configured"): - if "file" in log_policy(): + print_lvl = lvl_map.get(log_print_level()) + file_lvl = lvl_map.get(log_file_level()) + + if file_lvl is not None: formatter = logging.Formatter("{levelname}: {name}: {message}", style="{") file_handler1 = logging.FileHandler("scgenerator.log", "a+") file_handler1.setFormatter(formatter) - file_handler1.setLevel(logging.DEBUG) + file_handler1.setLevel(file_lvl) logger.addHandler(file_handler1) - if "print" in log_policy(): + if print_lvl is not None: stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) + stream_handler.setLevel(print_lvl) logger.addHandler(stream_handler) - logger.setLevel(logging.DEBUG) + logger.setLevel(min(print_lvl, file_lvl)) logger.already_configured = True return logger diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 61c1150..2581966 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -91,7 +91,7 @@ def u_nm(n, m): @np_cache -def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: +def ndft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: """creates the nfft matrix Parameters @@ -111,7 +111,7 @@ def nfft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: @np_cache -def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: +def indft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: """creates the nfft matrix Parameters @@ -126,10 +126,10 @@ def infft_matrix(t: np.ndarray, f: np.ndarray) -> np.ndarray: np.ndarray, shape = (m, n) multiply ~X(f) by this matrix to get x(t) """ - return np.linalg.pinv(nfft_matrix(t, f)) + return np.linalg.pinv(ndft_matrix(t, f)) -def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray: +def ndft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray: """computes the Fourier transform of an uneven signal Parameters @@ -146,10 +146,10 @@ def nfft(t: np.ndarray, s: np.ndarray, f: np.ndarray) -> np.ndarray: np.ndarray, shape = (m, ) amplitude at each frequency """ - return nfft_matrix(t, f) @ s + return ndft_matrix(t, f) @ s -def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: +def indft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: """computes the inverse Fourier transform of an uneven spectrum Parameters @@ -166,7 +166,7 @@ def infft(f: np.ndarray, a: np.ndarray, t: np.ndarray) -> np.ndarray: np.ndarray, shape = (m, ) amplitude at each point of t """ - return infft_matrix(t, f) @ a + return indft_matrix(t, f) @ a def make_uniform_2D(values, x_axis, y_axis, n=1024, method="linear"): diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index d08e529..08fd85b 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -130,7 +130,7 @@ def n_eff_marcatili(lambda_, n_gas_2, core_radius, he_mode=(1, 1)): lambda_ : ndarray, shape (n, ) wavelengths array (m) n_gas_2 : ndarray, shape (n, ) - refractive index of the gas as function of lambda_ + square of the refractive index of the gas as function of lambda_ core_radius : float inner radius of the capillary (m) he_mode : tuple, shape (2, ), optional @@ -455,7 +455,6 @@ def HCPCF_dispersion( Temperature of the material pressure : float constant pressure - FIXME tupple : a pressure gradient from pressure[0] to pressure[1] is computed Returns ------- @@ -692,15 +691,7 @@ def compute_dispersion(params: BareParams): ) else: - # Load material info - gas_name = params.gas_name - - if gas_name == "vacuum": - material_dico = None - else: - material_dico = toml.loads(io.Paths.gets("gas"))[gas_name] - - # compute dispersion + material_dico = toml.loads(io.Paths.gets("gas"))[params.gas_name] if params.dynamic_dispersion: return dynamic_HCPCF_dispersion( lambda_, @@ -716,9 +707,6 @@ def compute_dispersion(params: BareParams): params.interp_degree, ) else: - - # actually compute the dispersion - beta2 = HCPCF_dispersion( lambda_, material_dico, @@ -730,10 +718,14 @@ def compute_dispersion(params: BareParams): ) if material_dico is not None: - A_eff = 1.5 * params.core_radius ** 2 - n2 = mat.non_linear_refractive_index( - material_dico, params.pressure, params.temperature - ) + + A_eff = 1.5 * params.core_radius ** 2 if params.A_eff is None else params.A_eff + if params.n2 is None: + n2 = mat.non_linear_refractive_index( + material_dico, params.pressure, params.temperature + ) + else: + n2 = params.n2 gamma = gamma_parameter(n2, params.w0, A_eff) else: gamma = None diff --git a/src/scgenerator/physics/materials.py b/src/scgenerator/physics/materials.py index 2e38a14..d7832a2 100644 --- a/src/scgenerator/physics/materials.py +++ b/src/scgenerator/physics/materials.py @@ -39,7 +39,7 @@ def number_density_van_der_waals( ValueError : Since the Van der Waals equation is a cubic one, there could be more than one real, positive solution """ - logger = get_logger + logger = get_logger(__name__) if pressure == 0: return 0 @@ -73,7 +73,7 @@ def number_density_van_der_waals( s = f"Van der Waals eq with parameters P={pressure}, T={temperature}, a={a}, b={b}" s += f", There is more than one possible number density : {roots}." s += f", {np.min(roots)} was returned" - logger.info(s) + logger.warning(s) return np.min(roots) @@ -90,6 +90,8 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None): ---------- an array n(lambda_)^2 - 1 """ + logger = get_logger(__name__) + WL_THRESHOLD = 8.285e-6 temp_l = lambda_[lambda_ < WL_THRESHOLD] kind = 1 @@ -101,7 +103,6 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None): t0 = material_dico["sellmeier"].get("t0", 273.15) kind = material_dico["sellmeier"].get("kind", 1) - # Sellmeier equation chi = np.zeros_like(lambda_) # = n^2 - 1 if kind == 1: for b, c in zip(B, C): @@ -120,6 +121,7 @@ def sellmeier(lambda_, material_dico, pressure=None, temperature=None): if pressure is not None: chi *= pressure / P0 + logger.debug(f"computed chi between {np.min(chi):.2e} and {np.max(chi):.2e}") return chi diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index d2cb94f..08486cf 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -3,11 +3,10 @@ import os from datetime import datetime from pathlib import Path from typing import Dict, List, Tuple, Type -from typing_extensions import runtime import numpy as np -from .. import env, initialize, io, utils +from .. import const, env, initialize, io, utils from ..errors import IncompleteDataFolderError from ..logger import get_logger from . import pulse @@ -290,7 +289,6 @@ class SequentialRK4IP(RK4IP): save_data=False, job_identifier="", task_id=0, - n_percent=10, ): self.pbars = pbars super().__init__( @@ -298,7 +296,6 @@ class SequentialRK4IP(RK4IP): save_data=save_data, job_identifier=job_identifier, task_id=task_id, - n_percent=n_percent, ) def step_saved(self): @@ -314,7 +311,6 @@ class MutliProcRK4IP(RK4IP): save_data=False, job_identifier="", task_id=0, - n_percent=10, ): self.worker_id = worker_id self.p_queue = p_queue @@ -323,7 +319,6 @@ class MutliProcRK4IP(RK4IP): save_data=save_data, job_identifier=job_identifier, task_id=task_id, - n_percent=n_percent, ) def step_saved(self): @@ -512,7 +507,10 @@ class MultiProcSimulations(Simulations, priority=1): def __init__(self, param_seq: initialize.ParamSequence, task_id): super().__init__(param_seq, task_id=task_id) - self.sim_jobs_per_node = max(1, os.cpu_count() // 2) + if param_seq.config.worker_num is not None: + self.sim_jobs_per_node = param_seq.config.worker_num + else: + self.sim_jobs_per_node = max(1, os.cpu_count() // 2) self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node) self.progress_queue = multiprocessing.Queue() self.workers = [ @@ -656,6 +654,8 @@ class RaySimulations(Simulations, priority=2): @property def sim_jobs_total(self): + if self.param_seq.config.worker_num is not None: + return self.param_seq.config.worker_num tot_cpus = sum([node.get("Resources", {}).get("CPU", 0) for node in ray.nodes()]) tot_cpus = min(tot_cpus, self.max_concurrent_jobs) return int(min(self.param_seq.num_sim, tot_cpus)) @@ -664,7 +664,6 @@ class RaySimulations(Simulations, priority=2): def run_simulation_sequence( *config_files: os.PathLike, method=None, - final_name: str = None, prev_sim_dir: os.PathLike = None, ): prev = prev_sim_dir @@ -674,6 +673,7 @@ def run_simulation_sequence( prev = sim.sim_dir path_trees = io.build_path_trees(sim.sim_dir) + final_name = env.get(const.OUTPUT_PATH) if final_name is None: final_name = path_trees[0][-1][0].parent.name + " merged" @@ -687,6 +687,7 @@ def new_simulation( ) -> Simulations: config_dict = io.load_toml(config_file) + logger = get_logger(__name__) if prev_sim_dir is not None: config_dict["prev_sim_dir"] = str(prev_sim_dir) @@ -698,7 +699,7 @@ def new_simulation( else: param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict) - print(f"{param_seq.name=}") + logger.info(f"running {param_seq.name}") return Simulations.new(param_seq, task_id, method) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index 3de9cba..e30a782 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -422,7 +422,7 @@ def plot_spectrogram( values = spec[ind_t][:, ind_f] if f_range[2].type == "WL": values = np.apply_along_axis( - units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f)) + units.to_WL, 1, values, params.repetition_rate, units.m(f_range[2].inv(new_f)) ) values = np.apply_along_axis(make_uniform_1D, 1, values, new_f) @@ -528,7 +528,7 @@ def plot_results_2D( # make uniform if converting to wavelength if plt_range.unit.type == "WL": if is_spectrum: - values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) + values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis) values = np.array( [make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values] ) @@ -648,7 +648,7 @@ def plot_results_1D( # make uniform if converting to wavelength if plt_range.unit.type == "WL": if is_spectrum: - values = units.to_WL(values, params.frep, units.m.inv(params.w[ind])) + values = units.to_WL(values, params.repetition_rate, units.m.inv(params.w[ind])) # change the resolution if isinstance(spacing, float): @@ -810,8 +810,8 @@ def plot_avg( values *= yscaling mean_values = np.mean(values, axis=0) if plt_range.unit.type == "WL" and renormalize: - values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis) - mean_values = units.to_WL(mean_values, params.frep, x_axis) + values = np.apply_along_axis(units.to_WL, 1, values, params.repetition_rate, x_axis) + mean_values = units.to_WL(mean_values, params.repetition_rate, x_axis) # change the resolution if isinstance(spacing, float): diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 3d71605..70e05a8 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -8,13 +8,9 @@ from pathlib import Path from typing import Tuple import numpy as np -from toml import load - -from scgenerator.utils.parameter import BareConfig from ..initialize import validate_config_sequence from ..io import Paths, load_config -from ..utils import count_variations def primes(n): @@ -90,7 +86,7 @@ def create_parser(): "--environment-setup", required=False, default=f"source {os.path.expanduser('~/anaconda3/etc/profile.d/conda.sh')} && conda activate sc && " - "export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_POLICY=file", + "export SCGENERATOR_PBAR_POLICY=file && export SCGENERATOR_LOG_PRINT_LEVEL=none && export SCGENERATOR_LOG_FILE_LEVEL=info", help="commands to run to setup the environement (default : activate the sc environment with conda)", ) parser.add_argument( diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 6329256..9eb81ab 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -198,7 +198,7 @@ class Pulse(Sequence): spec = np.load(self.path / SPECN_FN.format(i)) if self.__ensure_2d: spec = np.atleast_2d(spec) - spec = Spectrum(spec, self.wl, self.params.frep) + spec = Spectrum(spec, self.wl, self.params.repetition_rate) self.cache[i] = spec return spec diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 7c81e99..35fa78e 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -187,7 +187,7 @@ def count_variations(config: BareConfig) -> int: """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and variable_params_num is the number of distinct parameters that will vary.""" sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat - return sim_num + return int(sim_num) def format_variable_list(l: List[Tuple[str, Any]]): diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index ccbccd8..c842639 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -339,7 +339,7 @@ class BareParams: peak_power: float = Parameter(positive(float, int)) mean_power: float = Parameter(positive(float, int)) energy: float = Parameter(positive(float, int)) - soliton_num: float = Parameter(positive(float, int)) + soliton_num: float = Parameter(non_negative(float, int)) quantum_noise: bool = Parameter(boolean) shape: str = Parameter(literal("gaussian", "sech")) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9)) @@ -362,10 +362,10 @@ class BareParams: lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9)) upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9)) interp_degree: int = Parameter(positive(int)) - frep: float = Parameter(positive(float, int)) prev_sim_dir: str = Parameter(string) readjust_wavelength: bool = Parameter(boolean) recovery_last_stored: int = Parameter(non_negative(int)) + worker_num: int = Parameter(positive(int)) # computed field_0: np.ndarray = Parameter(type_checker(np.ndarray)) diff --git a/testing/long_tests/test_recovery_param_seq.py b/testing/long_tests/test_recovery_param_seq.py index cc1896e..037bb22 100644 --- a/testing/long_tests/test_recovery_param_seq.py +++ b/testing/long_tests/test_recovery_param_seq.py @@ -12,7 +12,6 @@ class TestRecoveryParamSequence(unittest.TestCase): def setUp(self): shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP) self.conf = toml.load(TMP + "/initial_config.toml") - logger.DEFAULT_LEVEL = logger.logging.FATAL io.set_data_folder(55, TMP) def test_remaining_simulations_count(self):