diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 7fc9aea..c7edc87 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -370,7 +370,15 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None): destination_path = final_sim_path.parent / new_name destination_path.mkdir(exist_ok=True) - for sim_path in tqdm(list(final_sim_path.glob("id*num*")), position=0, desc="Appending"): + sim_paths = list(final_sim_path.glob("id*num*")) + pbars = utils.PBars.auto( + len(sim_paths), + 0, + head_kwargs=dict(desc="Appending"), + worker_kwargs=dict(desc=""), + ) + + for sim_path in sim_paths: path_tree = [sim_path] sim_name = sim_path.name appended_sim_path = destination_path / sim_name @@ -384,7 +392,9 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None): z: List[np.ndarray] = [] z_num = 0 last_z = 0 - for path in tqdm(list(reversed(path_tree)), position=1, leave=False): + paths_r = list(reversed(path_tree)) + + for path in paths_r: curr_z_num = load_toml(path / "params.toml")["z_num"] for i in range(curr_z_num): shutil.copy( @@ -398,6 +408,7 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None): z_arr = np.concatenate(z) update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr) np.save(appended_sim_path / "z.npy", z_arr) + pbars.update(0) update_appended_params( final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr @@ -437,7 +448,7 @@ def merge_same_simulations(path: Path, delete=True): check_data_integrity(sub_folders, z_num) sim_num, param_num = utils.count_variations(config) - pbar = utils.PBars(tqdm(total=sim_num * z_num, desc="Merging data", ncols=100)) + pbar = utils.PBars.auto(sim_num * z_num, head_kwargs=dict(desc="Merging data")) spectra = [] for z_id in range(z_num): diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 1cf4e5f..b3f2bb9 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -545,7 +545,7 @@ class MultiProcSimulations(Simulations, priority=1): ] self.p_worker = multiprocessing.Process( target=utils.progress_worker, - args=(self.param_seq.num_steps, self.progress_queue), + args=(self.sim_jobs_per_node, self.param_seq.num_steps, self.progress_queue), ) self.p_worker.start() @@ -629,28 +629,9 @@ class RaySimulations(Simulations, priority=2): self.jobs = [] self.actors = {} self.rolling_id = 0 - self.p_actor = ray.remote(utils.ProgressBarActor).remote(self.sim_jobs_total) - self.p_bars = utils.PBars( - [ - tqdm( - total=self.param_seq.num_steps, - unit="step", - desc="Simulating", - smoothing=0, - ncols=100, - ) - ] + self.p_actor = ray.remote(utils.ProgressBarActor).remote( + self.sim_jobs_total, self.param_seq.num_steps ) - for i in range(1, self.sim_jobs_total + 1): - self.p_bars.append( - tqdm( - total=1, - desc=f"Worker {i}", - position=i, - ncols=100, - bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", - ) - ) def new_sim(self, variable_list: List[tuple], params: dict): while len(self.jobs) >= self.sim_jobs_total: @@ -679,11 +660,11 @@ class RaySimulations(Simulations, priority=2): def finish(self): while len(self.jobs) > 0: self._collect_1_job() - self.p_bars.close() + ray.get(self.p_actor.close.remote()) def _collect_1_job(self): ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency) - self.update_pbars() + ray.get(self.p_actor.update_pbars.remote()) if len(ready) == 0: return ray.get(ready) @@ -699,12 +680,6 @@ class RaySimulations(Simulations, priority=2): tot_cpus = min(tot_cpus, self.max_concurrent_jobs) return int(min(self.param_seq.num_sim, tot_cpus)) - def update_pbars(self): - counters = ray.get(self.p_actor.wait_for_update.remote()) - for counter, pbar in zip(counters, self.p_bars): - pbar.update(counter - pbar.n) - self.p_bars.print() - def run_simulation_sequence( *config_files: os.PathLike, diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index bcf9901..849361f 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -13,7 +13,6 @@ import multiprocessing import socket import os from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union -from asyncio import Event from io import StringIO import numpy as np @@ -22,7 +21,7 @@ from copy import deepcopy from tqdm import tqdm -from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, pbar_format, HUSH_PROGRESS +from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRESS from .logger import get_logger from .math import * @@ -32,6 +31,29 @@ from .math import * class PBars: + @classmethod + def auto( + cls, num_tot: int, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None + ) -> "PBars": + if head_kwargs is None: + head_kwargs = dict(unit="step", desc="Simulating", smoothing=0) + if worker_kwargs is None: + worker_kwargs = dict( + total=1, + desc="Worker {worker_id}", + bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]", + ) + + if os.getenv(HUSH_PROGRESS) is not None: + head_kwargs["file"] = worker_kwargs["file"] = StringIO() + p = cls([tqdm(total=num_tot, ncols=100, **head_kwargs)]) + for i in range(1, num_sub_bars + 1): + kwargs = {k: v for k, v in worker_kwargs.items()} + if "desc" in kwargs: + kwargs["desc"] = kwargs["desc"].format(worker_id=i) + p.append(tqdm(position=i, ncols=100, **kwargs)) + return p + def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None: if isinstance(pbars, tqdm): self.pbars = [pbars] @@ -54,14 +76,22 @@ class PBars: def __getitem__(self, key): return self.pbars[key] - def update(self): - for pbar in self: - pbar.update() + def update(self, i=None, value=1): + if i is None: + for pbar in self.pbars[1:]: + pbar.update(value) + else: + self.pbars[i].update(value) + self.pbars[0].update() self.print() def append(self, pbar: tqdm): self.pbars.append(pbar) + def reset(self, i): + self.pbars[i].update(-self.pbars[i].n) + self.print() + def close(self): for pbar in self.pbars: pbar.close() @@ -127,13 +157,9 @@ class ProgressTracker: class ProgressBarActor: - counter: int - delta: int - event: Event - - def __init__(self, num_workers: int) -> None: + def __init__(self, num_workers: int, num_steps: int) -> None: self.counters = [0 for _ in range(num_workers + 1)] - self.event = Event() + self.p_bars = PBars.auto(num_steps, num_workers) def update(self, worker_id: int, rel_pos: float = None) -> None: """update a counter @@ -150,21 +176,17 @@ class ProgressBarActor: self.counters[worker_id] += 1 else: self.counters[worker_id] = rel_pos - self.event.set() - async def wait_for_update(self) -> List[float]: - """Blocking call. + def update_pbars(self): + for counter, pbar in zip(self.counters, self.p_bars): + pbar.update(counter - pbar.n) + self.p_bars.print() - Waits until somebody calls `update`, then returns a tuple of - the number of updates since the last call to - `wait_for_update`, and the total number of completed items. - """ - await self.event.wait() - self.event.clear() - return self.counters + def close(self): + self.p_bars.close() -def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue): +def progress_worker(num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue): """keeps track of progress on a separate thread Parameters @@ -176,22 +198,15 @@ def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue): Literal[0] : stop the worker and close the progress bars Tuple[int, float] : worker id and relative progress between 0 and 1 """ - kwargs = {} - if os.getenv(HUSH_PROGRESS) is not None: - kwargs = dict(file=StringIO()) - pbars: Dict[int, tqdm] = {} - with tqdm(total=num_steps, desc="Simulating", unit="step", position=0, **kwargs) as tq: - while True: - raw = progress_queue.get() - if raw == 0: - for pbar in pbars.values(): - pbar.close() - return - i, rel_pos = raw - if i not in pbars: - pbars[i] = tqdm(**pbar_format(i), **kwargs) - pbars[i].update(rel_pos - pbars[i].n) - tq.update() + pbars = PBars.auto(num_steps, num_workers) + while True: + raw = progress_queue.get() + if raw == 0: + pbars.close() + return + i, rel_pos = raw + pbars[i].update(rel_pos - pbars[i].n) + pbars[0].update() def count_variations(config: dict) -> Tuple[int, int]: