better pbars

Benoît Sierro
2021-05-31 16:16:49 +02:00
parent e40dc3ce2c
commit 04f3ac4b38
3 changed files with 72 additions and 71 deletions

View File

@@ -370,7 +370,15 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
     destination_path = final_sim_path.parent / new_name
     destination_path.mkdir(exist_ok=True)
-    for sim_path in tqdm(list(final_sim_path.glob("id*num*")), position=0, desc="Appending"):
+    sim_paths = list(final_sim_path.glob("id*num*"))
+    pbars = utils.PBars.auto(
+        len(sim_paths),
+        0,
+        head_kwargs=dict(desc="Appending"),
+        worker_kwargs=dict(desc=""),
+    )
+    for sim_path in sim_paths:
         path_tree = [sim_path]
         sim_name = sim_path.name
         appended_sim_path = destination_path / sim_name
@@ -384,7 +392,9 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
         z: List[np.ndarray] = []
         z_num = 0
         last_z = 0
-        for path in tqdm(list(reversed(path_tree)), position=1, leave=False):
+        paths_r = list(reversed(path_tree))
+        for path in paths_r:
             curr_z_num = load_toml(path / "params.toml")["z_num"]
             for i in range(curr_z_num):
                 shutil.copy(
@@ -398,6 +408,7 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
         z_arr = np.concatenate(z)
         update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr)
         np.save(appended_sim_path / "z.npy", z_arr)
+        pbars.update(0)
     update_appended_params(
         final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr
@@ -437,7 +448,7 @@ def merge_same_simulations(path: Path, delete=True):
     check_data_integrity(sub_folders, z_num)
     sim_num, param_num = utils.count_variations(config)
-    pbar = utils.PBars(tqdm(total=sim_num * z_num, desc="Merging data", ncols=100))
+    pbar = utils.PBars.auto(sim_num * z_num, head_kwargs=dict(desc="Merging data"))
     spectra = []
     for z_id in range(z_num):
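
The hunks above replace direct tqdm construction with the new utils.PBars.auto factory and per-item pbars.update(0) calls. A minimal usage sketch follows; it is illustrative only: the items list and the work loop are placeholders, and the import of the project's utils module is assumed rather than shown in this commit.

    # Illustrative sketch, not part of the commit: driving a head-only PBars
    # the way append_and_merge and merge_same_simulations do above.
    # `utils` stands for this project's utils module (import path assumed).

    items = ["a", "b", "c"]  # placeholder for sim_paths / merge jobs
    pbars = utils.PBars.auto(
        len(items),                         # total of the head ("Appending") bar
        0,                                  # no per-worker sub-bars
        head_kwargs=dict(desc="Appending"),
    )
    for item in items:
        ...                                 # per-item work (placeholder)
        pbars.update(0)                     # per-item progress, as in append_and_merge above
    pbars.close()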

View File

@@ -545,7 +545,7 @@ class MultiProcSimulations(Simulations, priority=1):
         ]
         self.p_worker = multiprocessing.Process(
             target=utils.progress_worker,
-            args=(self.param_seq.num_steps, self.progress_queue),
+            args=(self.sim_jobs_per_node, self.param_seq.num_steps, self.progress_queue),
         )
         self.p_worker.start()
@@ -629,28 +629,9 @@ class RaySimulations(Simulations, priority=2):
         self.jobs = []
         self.actors = {}
         self.rolling_id = 0
-        self.p_actor = ray.remote(utils.ProgressBarActor).remote(self.sim_jobs_total)
-        self.p_bars = utils.PBars(
-            [
-                tqdm(
-                    total=self.param_seq.num_steps,
-                    unit="step",
-                    desc="Simulating",
-                    smoothing=0,
-                    ncols=100,
-                )
-            ]
-        )
-        for i in range(1, self.sim_jobs_total + 1):
-            self.p_bars.append(
-                tqdm(
-                    total=1,
-                    desc=f"Worker {i}",
-                    position=i,
-                    ncols=100,
-                    bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
-                )
-            )
+        self.p_actor = ray.remote(utils.ProgressBarActor).remote(
+            self.sim_jobs_total, self.param_seq.num_steps
+        )

     def new_sim(self, variable_list: List[tuple], params: dict):
         while len(self.jobs) >= self.sim_jobs_total:
@@ -679,11 +660,11 @@ class RaySimulations(Simulations, priority=2):
     def finish(self):
         while len(self.jobs) > 0:
             self._collect_1_job()
-        self.p_bars.close()
+        ray.get(self.p_actor.close.remote())

     def _collect_1_job(self):
         ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency)
-        self.update_pbars()
+        ray.get(self.p_actor.update_pbars.remote())
         if len(ready) == 0:
             return
         ray.get(ready)
@@ -699,12 +680,6 @@ class RaySimulations(Simulations, priority=2):
         tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
         return int(min(self.param_seq.num_sim, tot_cpus))

-    def update_pbars(self):
-        counters = ray.get(self.p_actor.wait_for_update.remote())
-        for counter, pbar in zip(counters, self.p_bars):
-            pbar.update(counter - pbar.n)
-        self.p_bars.print()

 def run_simulation_sequence(
     *config_files: os.PathLike,
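
With these changes, RaySimulations no longer builds its own tqdm bars: the ProgressBarActor owns a PBars instance, workers report into the actor, and the driver asks the actor to redraw (update_pbars) and to close. The sketch below shows the assumed flow; run_one_sim, the step counts, ray.init() and the utils import are illustrative placeholders, while ProgressBarActor, update, update_pbars and close come from the diff.

    # Sketch of the assumed flow, not code from the repository.
    # `utils` stands for this project's utils module (import path assumed).
    import ray

    ray.init()
    num_workers, num_steps = 4, 100
    p_actor = ray.remote(utils.ProgressBarActor).remote(num_workers, num_steps)

    @ray.remote
    def run_one_sim(worker_id: int, actor) -> None:
        for _ in range(num_steps // num_workers):
            ...                             # one propagation step (placeholder)
            actor.update.remote(worker_id)  # report progress for this worker

    jobs = [run_one_sim.remote(i, p_actor) for i in range(1, num_workers + 1)]
    while jobs:
        ready, jobs = ray.wait(jobs, timeout=1.0)
        ray.get(p_actor.update_pbars.remote())  # redraw the bars from the counters
    ray.get(p_actor.close.remote())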

View File

@@ -13,7 +13,6 @@ import multiprocessing
 import socket
 import os
 from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
-from asyncio import Event
 from io import StringIO

 import numpy as np
@@ -22,7 +21,7 @@ from copy import deepcopy
 from tqdm import tqdm

-from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, pbar_format, HUSH_PROGRESS
+from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRESS
 from .logger import get_logger
 from .math import *
@@ -32,6 +31,29 @@ from .math import *
 class PBars:
+    @classmethod
+    def auto(
+        cls, num_tot: int, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None
+    ) -> "PBars":
+        if head_kwargs is None:
+            head_kwargs = dict(unit="step", desc="Simulating", smoothing=0)
+        if worker_kwargs is None:
+            worker_kwargs = dict(
+                total=1,
+                desc="Worker {worker_id}",
+                bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
+            )
+        if os.getenv(HUSH_PROGRESS) is not None:
+            head_kwargs["file"] = worker_kwargs["file"] = StringIO()
+        p = cls([tqdm(total=num_tot, ncols=100, **head_kwargs)])
+        for i in range(1, num_sub_bars + 1):
+            kwargs = {k: v for k, v in worker_kwargs.items()}
+            if "desc" in kwargs:
+                kwargs["desc"] = kwargs["desc"].format(worker_id=i)
+            p.append(tqdm(position=i, ncols=100, **kwargs))
+        return p
+
     def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None:
         if isinstance(pbars, tqdm):
             self.pbars = [pbars]
@@ -54,14 +76,22 @@ class PBars:
     def __getitem__(self, key):
         return self.pbars[key]

-    def update(self):
-        for pbar in self:
-            pbar.update()
+    def update(self, i=None, value=1):
+        if i is None:
+            for pbar in self.pbars[1:]:
+                pbar.update(value)
+        else:
+            self.pbars[i].update(value)
+        self.pbars[0].update()
         self.print()

     def append(self, pbar: tqdm):
         self.pbars.append(pbar)

+    def reset(self, i):
+        self.pbars[i].update(-self.pbars[i].n)
+        self.print()
+
     def close(self):
         for pbar in self.pbars:
             pbar.close()
@@ -127,13 +157,9 @@ class ProgressTracker:
 class ProgressBarActor:
-    counter: int
-    delta: int
-    event: Event
-
-    def __init__(self, num_workers: int) -> None:
+    def __init__(self, num_workers: int, num_steps: int) -> None:
         self.counters = [0 for _ in range(num_workers + 1)]
-        self.event = Event()
+        self.p_bars = PBars.auto(num_steps, num_workers)

     def update(self, worker_id: int, rel_pos: float = None) -> None:
         """update a counter
@@ -150,21 +176,17 @@ class ProgressBarActor:
             self.counters[worker_id] += 1
         else:
             self.counters[worker_id] = rel_pos
-        self.event.set()

-    async def wait_for_update(self) -> List[float]:
-        """Blocking call.
-
-        Waits until somebody calls `update`, then returns a tuple of
-        the number of updates since the last call to
-        `wait_for_update`, and the total number of completed items.
-        """
-        await self.event.wait()
-        self.event.clear()
-        return self.counters
+    def update_pbars(self):
+        for counter, pbar in zip(self.counters, self.p_bars):
+            pbar.update(counter - pbar.n)
+        self.p_bars.print()
+
+    def close(self):
+        self.p_bars.close()


-def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue):
+def progress_worker(num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue):
     """keeps track of progress on a separate thread

     Parameters
@@ -176,22 +198,15 @@ def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue):
         Literal[0] : stop the worker and close the progress bars
         Tuple[int, float] : worker id and relative progress between 0 and 1
     """
-    kwargs = {}
-    if os.getenv(HUSH_PROGRESS) is not None:
-        kwargs = dict(file=StringIO())
-    pbars: Dict[int, tqdm] = {}
-    with tqdm(total=num_steps, desc="Simulating", unit="step", position=0, **kwargs) as tq:
-        while True:
-            raw = progress_queue.get()
-            if raw == 0:
-                for pbar in pbars.values():
-                    pbar.close()
-                return
-            i, rel_pos = raw
-            if i not in pbars:
-                pbars[i] = tqdm(**pbar_format(i), **kwargs)
-            pbars[i].update(rel_pos - pbars[i].n)
-            tq.update()
+    pbars = PBars.auto(num_steps, num_workers)
+    while True:
+        raw = progress_queue.get()
+        if raw == 0:
+            pbars.close()
+            return
+        i, rel_pos = raw
+        pbars[i].update(rel_pos - pbars[i].n)
+        pbars[0].update()


 def count_variations(config: dict) -> Tuple[int, int]:
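
For the multiprocessing path, progress_worker now takes the number of workers up front and owns a single PBars instead of lazily creating per-worker tqdm bars. The queue protocol is unchanged and documented in the docstring: tuples of (worker_id, relative_progress) and the literal 0 to shut down. Below is a minimal sketch of that protocol; fake_worker, the step counts, and the import of progress_worker from the project's utils module are assumptions for illustration.

    # Sketch only: exercising the progress_worker queue protocol described above.
    # `progress_worker` is assumed to be imported from this project's utils module.
    import multiprocessing

    def fake_worker(worker_id: int, steps: int, queue: multiprocessing.Queue) -> None:
        for step in range(1, steps + 1):
            ...                                   # one unit of work (placeholder)
            queue.put((worker_id, step / steps))  # relative progress in [0, 1]

    if __name__ == "__main__":
        num_workers, num_steps = 2, 50
        steps_per_worker = num_steps // num_workers
        queue = multiprocessing.Queue()
        p_worker = multiprocessing.Process(
            target=progress_worker,               # wired up as in MultiProcSimulations
            args=(num_workers, num_steps, queue),
        )
        p_worker.start()
        workers = [
            multiprocessing.Process(target=fake_worker, args=(i, steps_per_worker, queue))
            for i in range(1, num_workers + 1)
        ]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        queue.put(0)                              # tell progress_worker to close the bars
        p_worker.join()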