better pbars
This commit is contained in:
@@ -370,7 +370,15 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
|
|||||||
destination_path = final_sim_path.parent / new_name
|
destination_path = final_sim_path.parent / new_name
|
||||||
destination_path.mkdir(exist_ok=True)
|
destination_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
for sim_path in tqdm(list(final_sim_path.glob("id*num*")), position=0, desc="Appending"):
|
sim_paths = list(final_sim_path.glob("id*num*"))
|
||||||
|
pbars = utils.PBars.auto(
|
||||||
|
len(sim_paths),
|
||||||
|
0,
|
||||||
|
head_kwargs=dict(desc="Appending"),
|
||||||
|
worker_kwargs=dict(desc=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
for sim_path in sim_paths:
|
||||||
path_tree = [sim_path]
|
path_tree = [sim_path]
|
||||||
sim_name = sim_path.name
|
sim_name = sim_path.name
|
||||||
appended_sim_path = destination_path / sim_name
|
appended_sim_path = destination_path / sim_name
|
||||||
@@ -384,7 +392,9 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
|
|||||||
z: List[np.ndarray] = []
|
z: List[np.ndarray] = []
|
||||||
z_num = 0
|
z_num = 0
|
||||||
last_z = 0
|
last_z = 0
|
||||||
for path in tqdm(list(reversed(path_tree)), position=1, leave=False):
|
paths_r = list(reversed(path_tree))
|
||||||
|
|
||||||
|
for path in paths_r:
|
||||||
curr_z_num = load_toml(path / "params.toml")["z_num"]
|
curr_z_num = load_toml(path / "params.toml")["z_num"]
|
||||||
for i in range(curr_z_num):
|
for i in range(curr_z_num):
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
@@ -398,6 +408,7 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None):
|
|||||||
z_arr = np.concatenate(z)
|
z_arr = np.concatenate(z)
|
||||||
update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr)
|
update_appended_params(sim_path / "params.toml", appended_sim_path / "params.toml", z_arr)
|
||||||
np.save(appended_sim_path / "z.npy", z_arr)
|
np.save(appended_sim_path / "z.npy", z_arr)
|
||||||
|
pbars.update(0)
|
||||||
|
|
||||||
update_appended_params(
|
update_appended_params(
|
||||||
final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr
|
final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr
|
||||||
@@ -437,7 +448,7 @@ def merge_same_simulations(path: Path, delete=True):
|
|||||||
check_data_integrity(sub_folders, z_num)
|
check_data_integrity(sub_folders, z_num)
|
||||||
|
|
||||||
sim_num, param_num = utils.count_variations(config)
|
sim_num, param_num = utils.count_variations(config)
|
||||||
pbar = utils.PBars(tqdm(total=sim_num * z_num, desc="Merging data", ncols=100))
|
pbar = utils.PBars.auto(sim_num * z_num, head_kwargs=dict(desc="Merging data"))
|
||||||
|
|
||||||
spectra = []
|
spectra = []
|
||||||
for z_id in range(z_num):
|
for z_id in range(z_num):
|
||||||
|
|||||||
@@ -545,7 +545,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
|||||||
]
|
]
|
||||||
self.p_worker = multiprocessing.Process(
|
self.p_worker = multiprocessing.Process(
|
||||||
target=utils.progress_worker,
|
target=utils.progress_worker,
|
||||||
args=(self.param_seq.num_steps, self.progress_queue),
|
args=(self.sim_jobs_per_node, self.param_seq.num_steps, self.progress_queue),
|
||||||
)
|
)
|
||||||
self.p_worker.start()
|
self.p_worker.start()
|
||||||
|
|
||||||
@@ -629,28 +629,9 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
self.jobs = []
|
self.jobs = []
|
||||||
self.actors = {}
|
self.actors = {}
|
||||||
self.rolling_id = 0
|
self.rolling_id = 0
|
||||||
self.p_actor = ray.remote(utils.ProgressBarActor).remote(self.sim_jobs_total)
|
self.p_actor = ray.remote(utils.ProgressBarActor).remote(
|
||||||
self.p_bars = utils.PBars(
|
self.sim_jobs_total, self.param_seq.num_steps
|
||||||
[
|
|
||||||
tqdm(
|
|
||||||
total=self.param_seq.num_steps,
|
|
||||||
unit="step",
|
|
||||||
desc="Simulating",
|
|
||||||
smoothing=0,
|
|
||||||
ncols=100,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
for i in range(1, self.sim_jobs_total + 1):
|
|
||||||
self.p_bars.append(
|
|
||||||
tqdm(
|
|
||||||
total=1,
|
|
||||||
desc=f"Worker {i}",
|
|
||||||
position=i,
|
|
||||||
ncols=100,
|
|
||||||
bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def new_sim(self, variable_list: List[tuple], params: dict):
|
def new_sim(self, variable_list: List[tuple], params: dict):
|
||||||
while len(self.jobs) >= self.sim_jobs_total:
|
while len(self.jobs) >= self.sim_jobs_total:
|
||||||
@@ -679,11 +660,11 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
def finish(self):
|
def finish(self):
|
||||||
while len(self.jobs) > 0:
|
while len(self.jobs) > 0:
|
||||||
self._collect_1_job()
|
self._collect_1_job()
|
||||||
self.p_bars.close()
|
ray.get(self.p_actor.close.remote())
|
||||||
|
|
||||||
def _collect_1_job(self):
|
def _collect_1_job(self):
|
||||||
ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency)
|
ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency)
|
||||||
self.update_pbars()
|
ray.get(self.p_actor.update_pbars.remote())
|
||||||
if len(ready) == 0:
|
if len(ready) == 0:
|
||||||
return
|
return
|
||||||
ray.get(ready)
|
ray.get(ready)
|
||||||
@@ -699,12 +680,6 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
|
tot_cpus = min(tot_cpus, self.max_concurrent_jobs)
|
||||||
return int(min(self.param_seq.num_sim, tot_cpus))
|
return int(min(self.param_seq.num_sim, tot_cpus))
|
||||||
|
|
||||||
def update_pbars(self):
|
|
||||||
counters = ray.get(self.p_actor.wait_for_update.remote())
|
|
||||||
for counter, pbar in zip(counters, self.p_bars):
|
|
||||||
pbar.update(counter - pbar.n)
|
|
||||||
self.p_bars.print()
|
|
||||||
|
|
||||||
|
|
||||||
def run_simulation_sequence(
|
def run_simulation_sequence(
|
||||||
*config_files: os.PathLike,
|
*config_files: os.PathLike,
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import multiprocessing
|
|||||||
import socket
|
import socket
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
|
from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union
|
||||||
from asyncio import Event
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -22,7 +21,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, pbar_format, HUSH_PROGRESS
|
from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRESS
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .math import *
|
from .math import *
|
||||||
|
|
||||||
@@ -32,6 +31,29 @@ from .math import *
|
|||||||
|
|
||||||
|
|
||||||
class PBars:
|
class PBars:
|
||||||
|
@classmethod
|
||||||
|
def auto(
|
||||||
|
cls, num_tot: int, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None
|
||||||
|
) -> "PBars":
|
||||||
|
if head_kwargs is None:
|
||||||
|
head_kwargs = dict(unit="step", desc="Simulating", smoothing=0)
|
||||||
|
if worker_kwargs is None:
|
||||||
|
worker_kwargs = dict(
|
||||||
|
total=1,
|
||||||
|
desc="Worker {worker_id}",
|
||||||
|
bar_format="{l_bar}{bar}" "|[{elapsed}<{remaining}, " "{rate_fmt}{postfix}]",
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.getenv(HUSH_PROGRESS) is not None:
|
||||||
|
head_kwargs["file"] = worker_kwargs["file"] = StringIO()
|
||||||
|
p = cls([tqdm(total=num_tot, ncols=100, **head_kwargs)])
|
||||||
|
for i in range(1, num_sub_bars + 1):
|
||||||
|
kwargs = {k: v for k, v in worker_kwargs.items()}
|
||||||
|
if "desc" in kwargs:
|
||||||
|
kwargs["desc"] = kwargs["desc"].format(worker_id=i)
|
||||||
|
p.append(tqdm(position=i, ncols=100, **kwargs))
|
||||||
|
return p
|
||||||
|
|
||||||
def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None:
|
def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None:
|
||||||
if isinstance(pbars, tqdm):
|
if isinstance(pbars, tqdm):
|
||||||
self.pbars = [pbars]
|
self.pbars = [pbars]
|
||||||
@@ -54,14 +76,22 @@ class PBars:
|
|||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self.pbars[key]
|
return self.pbars[key]
|
||||||
|
|
||||||
def update(self):
|
def update(self, i=None, value=1):
|
||||||
for pbar in self:
|
if i is None:
|
||||||
pbar.update()
|
for pbar in self.pbars[1:]:
|
||||||
|
pbar.update(value)
|
||||||
|
else:
|
||||||
|
self.pbars[i].update(value)
|
||||||
|
self.pbars[0].update()
|
||||||
self.print()
|
self.print()
|
||||||
|
|
||||||
def append(self, pbar: tqdm):
|
def append(self, pbar: tqdm):
|
||||||
self.pbars.append(pbar)
|
self.pbars.append(pbar)
|
||||||
|
|
||||||
|
def reset(self, i):
|
||||||
|
self.pbars[i].update(-self.pbars[i].n)
|
||||||
|
self.print()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
for pbar in self.pbars:
|
for pbar in self.pbars:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
@@ -127,13 +157,9 @@ class ProgressTracker:
|
|||||||
|
|
||||||
|
|
||||||
class ProgressBarActor:
|
class ProgressBarActor:
|
||||||
counter: int
|
def __init__(self, num_workers: int, num_steps: int) -> None:
|
||||||
delta: int
|
|
||||||
event: Event
|
|
||||||
|
|
||||||
def __init__(self, num_workers: int) -> None:
|
|
||||||
self.counters = [0 for _ in range(num_workers + 1)]
|
self.counters = [0 for _ in range(num_workers + 1)]
|
||||||
self.event = Event()
|
self.p_bars = PBars.auto(num_steps, num_workers)
|
||||||
|
|
||||||
def update(self, worker_id: int, rel_pos: float = None) -> None:
|
def update(self, worker_id: int, rel_pos: float = None) -> None:
|
||||||
"""update a counter
|
"""update a counter
|
||||||
@@ -150,21 +176,17 @@ class ProgressBarActor:
|
|||||||
self.counters[worker_id] += 1
|
self.counters[worker_id] += 1
|
||||||
else:
|
else:
|
||||||
self.counters[worker_id] = rel_pos
|
self.counters[worker_id] = rel_pos
|
||||||
self.event.set()
|
|
||||||
|
|
||||||
async def wait_for_update(self) -> List[float]:
|
def update_pbars(self):
|
||||||
"""Blocking call.
|
for counter, pbar in zip(self.counters, self.p_bars):
|
||||||
|
pbar.update(counter - pbar.n)
|
||||||
|
self.p_bars.print()
|
||||||
|
|
||||||
Waits until somebody calls `update`, then returns a tuple of
|
def close(self):
|
||||||
the number of updates since the last call to
|
self.p_bars.close()
|
||||||
`wait_for_update`, and the total number of completed items.
|
|
||||||
"""
|
|
||||||
await self.event.wait()
|
|
||||||
self.event.clear()
|
|
||||||
return self.counters
|
|
||||||
|
|
||||||
|
|
||||||
def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue):
|
def progress_worker(num_workers: int, num_steps: int, progress_queue: multiprocessing.Queue):
|
||||||
"""keeps track of progress on a separate thread
|
"""keeps track of progress on a separate thread
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -176,22 +198,15 @@ def progress_worker(num_steps: int, progress_queue: multiprocessing.Queue):
|
|||||||
Literal[0] : stop the worker and close the progress bars
|
Literal[0] : stop the worker and close the progress bars
|
||||||
Tuple[int, float] : worker id and relative progress between 0 and 1
|
Tuple[int, float] : worker id and relative progress between 0 and 1
|
||||||
"""
|
"""
|
||||||
kwargs = {}
|
pbars = PBars.auto(num_steps, num_workers)
|
||||||
if os.getenv(HUSH_PROGRESS) is not None:
|
while True:
|
||||||
kwargs = dict(file=StringIO())
|
raw = progress_queue.get()
|
||||||
pbars: Dict[int, tqdm] = {}
|
if raw == 0:
|
||||||
with tqdm(total=num_steps, desc="Simulating", unit="step", position=0, **kwargs) as tq:
|
pbars.close()
|
||||||
while True:
|
return
|
||||||
raw = progress_queue.get()
|
i, rel_pos = raw
|
||||||
if raw == 0:
|
pbars[i].update(rel_pos - pbars[i].n)
|
||||||
for pbar in pbars.values():
|
pbars[0].update()
|
||||||
pbar.close()
|
|
||||||
return
|
|
||||||
i, rel_pos = raw
|
|
||||||
if i not in pbars:
|
|
||||||
pbars[i] = tqdm(**pbar_format(i), **kwargs)
|
|
||||||
pbars[i].update(rel_pos - pbars[i].n)
|
|
||||||
tq.update()
|
|
||||||
|
|
||||||
|
|
||||||
def count_variations(config: dict) -> Tuple[int, int]:
|
def count_variations(config: dict) -> Tuple[int, int]:
|
||||||
|
|||||||
Reference in New Issue
Block a user