From aae35e2b63428aa84a18f379ef68b1b47d3426a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 12 May 2021 11:04:53 +0200 Subject: [PATCH] swtiched to tqdm progress bars --- play.py | 6 + sc-DEBUG.log | 0 src/scgenerator/defaults.py | 2 +- src/scgenerator/initialize.py | 30 +++- src/scgenerator/io.py | 13 +- src/scgenerator/logger.py | 25 ++- src/scgenerator/math.py | 25 ++- src/scgenerator/physics/fiber.py | 58 ++++--- src/scgenerator/physics/pulse.py | 3 + src/scgenerator/physics/simulate.py | 231 +++++++++++++++++++++++----- src/scgenerator/utils.py | 82 +++++++++- 11 files changed, 393 insertions(+), 82 deletions(-) create mode 100644 play.py create mode 100644 sc-DEBUG.log diff --git a/play.py b/play.py new file mode 100644 index 0000000..0fb33a2 --- /dev/null +++ b/play.py @@ -0,0 +1,6 @@ +from tqdm import tqdm +import time +import random + +for i in tqdm(range(100), smoothing=0): + time.sleep(random.random()) diff --git a/sc-DEBUG.log b/sc-DEBUG.log new file mode 100644 index 0000000..e69de29 diff --git a/src/scgenerator/defaults.py b/src/scgenerator/defaults.py index 0197e55..a5a345f 100644 --- a/src/scgenerator/defaults.py +++ b/src/scgenerator/defaults.py @@ -23,7 +23,7 @@ default_parameters = dict( parallel=False, repeat=1, tolerated_error=1e-11, - lower_wavelength_interp_limit=0, + lower_wavelength_interp_limit=300e-9, upper_wavelength_interp_limit=1900e-9, ideal_gas=False, ) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 5a4aae7..838e978 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -4,6 +4,7 @@ from typing import Any, Iterator, List, Tuple import numpy as np from numpy import pi +from tqdm import tqdm from . import defaults, io, utils from .const import hc_model_specific_parameters, valid_param_types, valid_variable @@ -18,9 +19,10 @@ class ParamSequence(Mapping): def __init__(self, config): self.config = validate(config) self.name = self.config["name"] + self.logger = get_logger(__name__) self.num_sim, self.num_variable = count_variations(self.config) - self.num_steps = self.num_sim * self.config["simulation", "z_num"] + self.num_steps = self.num_sim * self.config["simulation"]["z_num"] self.single_sim = self.num_sim == 1 def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: @@ -44,11 +46,31 @@ class RecoveryParamSequence(ParamSequence): super().__init__(config) self.id = task_id self.num_steps = 0 - for sub_folder in io.get_data_subfolders(io.get_data_folder(self.id)): - num_left = io.num_left_to_propagate(sub_folder, config["simulation"]["z_num"]) + + z_num = config["simulation"]["z_num"] + started = self.num_sim + sub_folders = io.get_data_subfolders(io.get_data_folder(self.id)) + + pbar_store = utils.PBars( + tqdm( + total=len(sub_folders), + desc="Initial recovery process", + unit="sim", + ncols=100, + ) + ) + + for sub_folder in sub_folders: + num_left = io.num_left_to_propagate(sub_folder, z_num) if num_left == 0: self.num_sim -= 1 self.num_steps += num_left + started -= 1 + pbar_store.update() + + pbar_store.close() + + self.num_steps += started * z_num self.single_sim = self.num_sim == 1 def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: @@ -579,7 +601,7 @@ def _generate_sim_grid(params): params["w0"] = w0 params["w_c"] = w_c params["w"] = w_c + w0 - params["w_power_fact"] = [power_fact(w_c, k) for k in range(2, 11)] + params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)]) params["z_targets"] = np.linspace(0, params["length"], params["z_num"]) diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 4bb825a..e0d60b9 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -6,12 +6,13 @@ from typing import Any, Dict, Iterable, List, Tuple import numpy as np import pkg_resources as pkg -from ray import util import toml +from ray import util from send2trash import TrashPermissionError, send2trash +from tqdm import tqdm from . import utils -from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE, ENVIRON_KEY_BASE +from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE from .errors import IncompleteDataFolderError from .logger import get_logger @@ -381,8 +382,7 @@ def merge_same_simulations(path: str): base_folders.add(base_folder) sim_num, param_num = utils.count_variations(config) - pt = utils.ProgressTracker(sim_num * z_num, logger=logger, prefix="merging data : ") - print(f"{pt.max=}") + pbar = utils.PBars(tqdm(total=sim_num * z_num, desc="merging data")) spectra = [] for z_id in range(z_num): @@ -395,7 +395,7 @@ def merge_same_simulations(path: str): in_path = os.path.join(path, utils.format_variable_list(variable_and_ind)) spectra.append(np.load(os.path.join(in_path, f"spectrum_{z_id}.npy"))) - pt.update() + pbar.update() # write new files only once all those from one parameter set are collected if repeat_id == max_repeat_id: @@ -411,6 +411,7 @@ def merge_same_simulations(path: str): os.path.join(in_path, file_name), os.path.join(out_path, ""), ) + pbar.close() try: for sub_folder in sub_folders: @@ -528,4 +529,4 @@ def ensure_folder(name, i=0, suffix="", prevent_overwrite=True): def _end_of_path_tree(path): out = path == os.path.abspath(os.sep) out |= path == "" - return out \ No newline at end of file + return out diff --git a/src/scgenerator/logger.py b/src/scgenerator/logger.py index a8d333e..9c592a0 100644 --- a/src/scgenerator/logger.py +++ b/src/scgenerator/logger.py @@ -1,4 +1,19 @@ import logging +from typing import Optional + + +class DebugOnlyFileHandler(logging.FileHandler): + def __init__( + self, filename, mode: str, encoding: Optional[str] = None, delay: bool = False + ) -> None: + super().__init__(filename, mode=mode, encoding=encoding, delay=delay) + self.setLevel(logging.DEBUG) + + def emit(self, record: logging.LogRecord) -> None: + if not record.levelno == logging.DEBUG: + return + return super().emit(record) + DEFAULT_LEVEL = logging.INFO @@ -66,13 +81,17 @@ def configure_logger(logger): """ if not hasattr(logger, "already_configured"): formatter = logging.Formatter("{levelname}: {name}: {message}", style="{") - file_handler1 = logging.FileHandler("sc-DEBUG.log", "a+") + file_handler1 = DebugOnlyFileHandler("sc-DEBUG.log", "a+") file_handler1.setFormatter(formatter) - file_handler1.setLevel(logging.DEBUG) logger.addHandler(file_handler1) + file_handler2 = logging.FileHandler("sc-INFO.log", "a+") + file_handler2.setFormatter(formatter) + file_handler2.setLevel(logging.INFO) + logger.addHandler(file_handler2) + stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.INFO) + stream_handler.setLevel(logging.WARNING) logger.addHandler(stream_handler) logger.setLevel(logging.DEBUG) diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 4219f96..a966d28 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -1,7 +1,8 @@ -from typing import Type +from typing import Type, Union import numpy as np from scipy.special import jn_zeros from scipy.interpolate import interp1d, griddata +from numba import jit def span(*vec): @@ -21,7 +22,7 @@ def span(*vec): return out -def argclosest(array, target): +def argclosest(array: np.ndarray, target: Union[float, int]): """returns the index/indices corresponding to the closest matches of target in array""" min_dist = np.inf index = None @@ -45,21 +46,31 @@ def power_fact(x, n): returns x ^ n / n! """ if isinstance(x, (int, float)): - x = float(x) - result = 1.0 + return _power_fact_single(x, n) elif isinstance(x, np.ndarray): - if x.dtype == int: - x = np.array(x, dtype=float) - result = np.ones(len(x)) + return _power_fact_array(x, n) else: raise TypeError(f"type {type(x)} of x not supported.") + +@jit(nopython=True) +def _power_fact_single(x, n): + result = 1.0 for k in range(n): result = result * x / (n - k) return result +@jit(nopython=True) +def _power_fact_array(x, n): + result = np.ones(len(x), dtype=np.float64) + for k in range(n): + result = result * x / (n - k) + return result + + +@jit(nopython=True) def abs2(z): return z.real ** 2 + z.imag ** 2 diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index 6c03e04..79231b0 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -1,5 +1,7 @@ import numpy as np +from numpy.lib.arraysetops import isin import toml +from numba import jit from numpy.fft import fft, ifft from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from scipy.interpolate import interp1d @@ -841,33 +843,37 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non elif raman_type == "agrawal": f_r = 0.245 - # Define the non linear operator - def N_func(spectrum, r=0): - field = ifft(spectrum) + if "spm" in behaviors: + spm_part = lambda fi: (1 - f_r) * abs2(fi) + else: + spm_part = lambda fi: 0 - ss_part = w_c / w0 if "ss" in behaviors else 0 - spm_part = (1 - f_r) * abs2(field) if "spm" in behaviors else 0 - raman_part = f_r * ifft(hr_w * fft(abs2(field))) if "raman" in behaviors else 0 - raman_noise_part = 1j * 0 - if isinstance(gamma, (float, int)): + if "raman" in behaviors: + raman_part = lambda fi: f_r * ifft(hr_w * fft(abs2(fi))) + else: + raman_part = lambda fi: 0 + + spm_part = jit(spm_part, nopython=True) + ss_part = w_c / w0 if "ss" in behaviors else 0 + + if isinstance(gamma, (float, int)): + + def N_func(spectrum: np.ndarray, r=0): + field = ifft(spectrum) + return -1j * gamma * (1 + ss_part) * fft(field * (spm_part(field) + raman_part(field))) + + else: + + def N_func(spectrum: np.ndarray, r=0): + field = ifft(spectrum) return ( - -1j - * gamma - * (1 + ss_part) - * fft(field * (spm_part + raman_part) + raman_noise_part) - ) - else: - return ( - -1j - * gamma(r) - * (1 + ss_part) - * fft(field * (spm_part + raman_part) + raman_noise_part) + -1j * gamma(r) * (1 + ss_part) * fft(field * (spm_part(field) + raman_part(field))) ) return N_func -def fast_dispersion_op(w_c, beta_arr, power_fact, where=slice(None)): +def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)): """ dispersive operator @@ -888,10 +894,7 @@ def fast_dispersion_op(w_c, beta_arr, power_fact, where=slice(None)): dispersive component """ - dispersion = np.zeros_like(w_c) - - for k, beta in reversed(list(enumerate(beta_arr))): - dispersion = dispersion + beta * power_fact[k] + dispersion = _fast_disp_loop(np.zeros_like(w_c), beta_arr, power_fact_arr) out = np.zeros_like(dispersion) out[where] = dispersion[where] @@ -899,6 +902,13 @@ def fast_dispersion_op(w_c, beta_arr, power_fact, where=slice(None)): return -1j * out +@jit(nopython=True) +def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr): + for k in range(len(beta_arr) - 1, -1, -1): + dispersion = dispersion + beta_arr[k] * power_fact_arr[k] + return dispersion + + def dispersion_op(w_c, beta_arr, where=None): """ dispersive operator diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 1a93077..2bf30a8 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -17,6 +17,7 @@ import numpy as np from numpy import pi from numpy.fft import fft, fftshift, ifft from scipy.interpolate import UnivariateSpline +from numba import jit from ..defaults import default_plotting @@ -184,10 +185,12 @@ def gauss_pulse(t, t0, P0, offset=0): return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2)) +@jit(nopython=True) def photon_number(spectrum, w, dw, gamma): return np.sum(1 / gamma * abs2(spectrum) / w * dw) +@jit(nopython=True) def pulse_energy(spectrum, w, dw, _): return np.sum(abs2(spectrum) * dw) diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index e691a1d..3dd41e4 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -1,9 +1,12 @@ +import multiprocessing import os import sys from datetime import datetime -from typing import List, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type import numpy as np +from numba import jit +from tqdm import tqdm from .. import initialize, io, utils from ..errors import IncompleteDataFolderError @@ -68,11 +71,13 @@ class RK4IP: print/log progress update every n_percent, by default 10 """ + self.set_new_params(sim_params, save_data, job_identifier, task_id, n_percent) + + def set_new_params(self, sim_params, save_data, job_identifier, task_id, n_percent): self.job_identifier = job_identifier self.id = task_id self.n_percent = n_percent self.logger = get_logger(self.job_identifier) - self.resuming = False self.save_data = save_data self._extract_params(sim_params) @@ -101,6 +106,7 @@ class RK4IP: self.N_func = create_non_linear_op( self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, self.f_r, self.hr_w ) + if self.dynamic_dispersion: self.disp = lambda r: fast_dispersion_op(self.w_c, self.beta(r), self.w_power_fact) else: @@ -127,10 +133,6 @@ class RK4IP: self.z = self.z_targets.pop(0) self.z_stored = list(self.z_targets.copy()[0 : self.starting_num + 1]) - self.progress_tracker = utils.ProgressTracker( - self.z_final, percent_incr=self.n_percent, logger=self.logger - ) - # Setup initial values for every physical quantity that we want to track self.current_spectrum = self.spec_0.copy() self.stored_spectra = self.starting_num * [None] + [self.current_spectrum.copy()] @@ -183,7 +185,6 @@ class RK4IP: self.logger.debug( "Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0]) ) - self.progress_tracker.set(self.z) # Start of the integration step = 1 @@ -203,16 +204,14 @@ class RK4IP: # Whether the current spectrum has to be stored depends on previous step if store: - self.progress_tracker.suffix = " ({} steps). z = {:.4f}, h = {:.5g}".format( - step, self.z, h_taken - ) - self.progress_tracker.set(self.z) + self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken)) self.stored_spectra.append(self.current_spectrum) if self.save_data: self._save_current_spectrum(len(self.stored_spectra) - 1) self.z_stored.append(self.z) + self.step_saved() del self.z_targets[0] # reset the constant step size after a spectrum is stored @@ -229,7 +228,7 @@ class RK4IP: store = True h_next_step = self.z_targets[0] - self.z - self.logger.debug( + self.logger.info( "propagation finished in {} steps ({} seconds)".format( step, (datetime.today() - time_start).total_seconds() ) @@ -301,6 +300,60 @@ class RK4IP: h_next_step = h return h, h_next_step, new_spectrum + def step_saved(self): + pass + + +class MutliProcRK4IP(RK4IP): + def __init__( + self, + sim_params, + p_queue: multiprocessing.Queue, + worker_id: int, + save_data=False, + job_identifier="", + task_id=0, + n_percent=10, + ): + super().__init__( + sim_params, + save_data=save_data, + job_identifier=job_identifier, + task_id=task_id, + n_percent=n_percent, + ) + self.worker_id = worker_id + self.p_queue = p_queue + + def step_saved(self): + self.p_queue.put((self.worker_id, self.z / self.z_final)) + + +class RayRK4IP(RK4IP): + def __init__( + self, + sim_params, + p_actor, + worker_id: int, + save_data=False, + job_identifier="", + task_id=0, + n_percent=10, + ): + super().__init__( + sim_params, + save_data=save_data, + job_identifier=job_identifier, + task_id=task_id, + n_percent=n_percent, + ) + self.worker_id = worker_id + self.p_actor = p_actor + + def step_saved(self): + self.p_actor.update.remote(self.worker_id, self.z / self.z_final) + self.p_actor.update.remote(0) + class Simulations: """The recommended way to run simulations. @@ -343,8 +396,6 @@ class Simulations: self.sim_jobs_per_node = 1 self.max_concurrent_jobs = np.inf - self.propagator = RK4IP - @property def finished_and_complete(self): try: @@ -360,12 +411,6 @@ class Simulations: def update(self, param_seq: initialize.ParamSequence): self.param_seq = param_seq - self.progress_tracker = utils.ProgressTracker( - self.param_seq.num_steps, - percent_incr=1, - logger=self.logger, - prefix="Overall : ", - ) def run(self): self._run_available() @@ -416,16 +461,10 @@ class Simulations: class SequencialSimulations(Simulations, available=True, priority=0): - def new_sim(self, variable_list: List[tuple], params: dict): + def new_sim(self, variable_list: List[tuple], params: Dict[str, Any]): v_list_str = utils.format_variable_list(variable_list) self.logger.info(f"launching simulation with {v_list_str}") - self.propagator( - params, - save_data=True, - job_identifier=v_list_str, - task_id=self.id, - ).run() - self.progress_tracker.update(self.param_seq["simulation", "z_num"]) + RK4IP(params, save_data=True, job_identifier=v_list_str, task_id=self.id).run() def stop(self): pass @@ -434,7 +473,94 @@ class SequencialSimulations(Simulations, available=True, priority=0): pass -class RaySimulations(Simulations, available=using_ray, priority=1): +class MultiProcSimulations(Simulations, available=True, priority=1): + def __init__(self, param_seq: initialize.ParamSequence, task_id, data_folder): + super().__init__(param_seq, task_id=task_id, data_folder=data_folder) + self.sim_jobs_per_node = max(1, os.cpu_count() // 2) + self.queue = multiprocessing.JoinableQueue(self.sim_jobs_per_node) + self.progress_queue = multiprocessing.Queue() + self.workers = [ + multiprocessing.Process( + target=MultiProcSimulations.worker, + args=(self.id, i + 1, self.queue, self.progress_queue), + ) + for i in range(self.sim_jobs_per_node) + ] + self.p_worker = multiprocessing.Process( + target=MultiProcSimulations.progress_worker, + args=(self.param_seq.num_steps, self.progress_queue), + ) + self.p_worker.start() + + def run(self): + for worker in self.workers: + worker.start() + super().run() + + def new_sim(self, variable_list: List[tuple], params: dict): + self.queue.put((variable_list, params), block=True, timeout=None) + + def finish(self): + """0 means finished""" + for worker in self.workers: + self.queue.put(0) + for worker in self.workers: + worker.join() + self.queue.join() + self.progress_queue.put(0) + + def stop(self): + self.finish() + + @staticmethod + def worker( + task_id, + worker_id: int, + queue: multiprocessing.JoinableQueue, + p_queue: multiprocessing.Queue, + ): + while True: + raw_data: Tuple[List[tuple], Dict[str, Any]] = queue.get() + if raw_data == 0: + queue.task_done() + return + variable_list, params = raw_data + v_list_str = utils.format_variable_list(variable_list) + MutliProcRK4IP( + params, + p_queue, + worker_id, + save_data=True, + job_identifier=v_list_str, + task_id=task_id, + ).run() + queue.task_done() + + @staticmethod + def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue): + pbars: Dict[int, tqdm] = {} + with tqdm(total=num_steps, desc="Simulating", unit="step", position=0) as tq: + while True: + raw = progress_queue.get() + if raw == 0: + for pbar in pbars.values(): + pbar.close() + return + i, rel_pos = raw + if i not in pbars: + pbars[i] = tqdm( + total=1, + desc=f"Worker {i}", + position=i, + bar_format="{l_bar}{bar}" + "|[{elapsed}<{remaining}, " + "{rate_fmt}{postfix}]", + ) + pbars[i].update(rel_pos - pbars[i].n) + tq.update() + + +class RaySimulations(Simulations, available=using_ray, priority=2): """runs simulation with the help of the ray module. ray must be initialized before creating an instance of RaySimulations""" def __init__( @@ -456,7 +582,7 @@ class RaySimulations(Simulations, available=using_ray, priority=1): ) ) - self.propagator = ray.remote(RK4IP).options( + self.propagator = ray.remote(RayRK4IP).options( override_environment_variables=io.get_all_environ() ) self.sim_jobs_per_node = min( @@ -465,15 +591,44 @@ class RaySimulations(Simulations, available=using_ray, priority=1): self.update_cluster_frequency = 3 self.jobs = [] self.actors = {} + self.rolling_id = 0 + self.p_actor = ray.remote(utils.ProgressBarActor).remote(self.sim_jobs_total) + self.p_bars = utils.PBars( + [ + tqdm( + total=self.param_seq.num_steps, + unit="step", + desc="Simulating", + smoothing=0, + ncols=100, + ) + ] + ) + for i in range(1, self.sim_jobs_total + 1): + self.p_bars.append( + tqdm( + total=1, + desc=f"Worker {i}", + position=i, + ncols=100, + bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", + ) + ) def new_sim(self, variable_list: List[tuple], params: dict): while len(self.jobs) >= self.sim_jobs_total: self._collect_1_job() + self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total v_list_str = utils.format_variable_list(variable_list) new_actor = self.propagator.remote( - params, save_data=True, job_identifier=v_list_str, task_id=self.id + params, + self.p_actor, + self.rolling_id + 1, + save_data=True, + job_identifier=v_list_str, + task_id=self.id, ) new_job = new_actor.run.remote() @@ -485,11 +640,11 @@ class RaySimulations(Simulations, available=using_ray, priority=1): def finish(self): while len(self.jobs) > 0: self._collect_1_job() + self.p_bars.close() def _collect_1_job(self): ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency) - self.progress_tracker.update(self.param_seq["simulation", "z_num"]) - + self.update_pbars() if len(ready) == 0: return ray.get(ready) @@ -505,6 +660,12 @@ class RaySimulations(Simulations, available=using_ray, priority=1): tot_cpus = min(tot_cpus, self.max_concurrent_jobs) return int(min(self.param_seq.num_sim, tot_cpus)) + def update_pbars(self): + counters = ray.get(self.p_actor.wait_for_update.remote()) + for counter, pbar in zip(counters, self.p_bars): + pbar.update(counter - pbar.n) + self.p_bars.print() + def new_simulations( config_file: str, @@ -535,7 +696,7 @@ def _new_simulations( task_id, data_folder, Method: Type[Simulations], -): +) -> Simulations: if Method is not None: return Method(param_seq, task_id, data_folder=data_folder) elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"] and using_ray: @@ -550,4 +711,4 @@ if __name__ == "__main__": except NameError: pass config_file, *opts = sys.argv[1:] - new_simulations(config_file, *opts) \ No newline at end of file + new_simulations(config_file, *opts) diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index eab18c0..25f0e0e 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -8,13 +8,17 @@ scgenerator module but some function may be used in any python program import datetime as dt import itertools import logging +import re import socket from typing import Any, Callable, Iterator, List, Tuple, Union +from asyncio import Event import numpy as np import ray from copy import deepcopy +from tqdm import tqdm + from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable from .logger import get_logger from .math import * @@ -24,13 +28,49 @@ from .math import * # XXX ############################################ +class PBars: + def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None: + if isinstance(pbars, tqdm): + self.pbars = [pbars] + else: + self.pbars = pbars + self.logger = get_logger(__name__) + + def print(self): + if len(self.pbars) > 1: + s = [""] + else: + s = [] + for pbar in self.pbars: + s.append(str(pbar)) + self.logger.info("\n".join(s)) + + def __iter__(self): + yield from self.pbars + + def __getitem__(self, key): + return self.pbars[key] + + def update(self): + for pbar in self: + pbar.update() + self.print() + + def append(self, pbar: tqdm): + self.pbars.append(pbar) + + def close(self): + for pbar in self.pbars: + pbar.close() + + class ProgressTracker: def __init__( self, max: Union[int, float], prefix: str = "", suffix: str = "", - logger: logging.Logger = get_logger(), + logger: logging.Logger = None, auto_print: bool = True, percent_incr: Union[int, float] = 5, default_update: Union[int, float] = 1, @@ -44,7 +84,7 @@ class ProgressTracker: self.next_percent = percent_incr self.percent_incr = percent_incr self.default_update = default_update - self.logger = logger + self.logger = logger if logger is not None else get_logger() def _update(self): if self.auto_print and self.current / self.max >= self.next_percent / 100: @@ -83,6 +123,44 @@ class ProgressTracker: return "{}/{}".format(self.current, self.max) +class ProgressBarActor: + counter: int + delta: int + event: Event + + def __init__(self, num_workers: int) -> None: + self.counters = [0 for _ in range(num_workers + 1)] + self.event = Event() + + def update(self, worker_id: int, rel_pos: float = None) -> None: + """update a counter + + Parameters + ---------- + worker_id : int + id of the worker + rel_pos : float, optional + if None, increase the counter by one, if set, will set + the counter to the specified value (instead of incrementing it), by default None + """ + if rel_pos is None: + self.counters[worker_id] += 1 + else: + self.counters[worker_id] = rel_pos + self.event.set() + + async def wait_for_update(self) -> List[float]: + """Blocking call. + + Waits until somebody calls `update`, then returns a tuple of + the number of updates since the last call to + `wait_for_update`, and the total number of completed items. + """ + await self.event.wait() + self.event.clear() + return self.counters + + def count_variations(config: dict) -> Tuple[int, int]: """returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and variable_params_num is the number of distinct parameters that will vary."""