fixed multibars

This commit is contained in:
2024-02-15 15:58:31 +01:00
parent 866f8cd2ff
commit a0d597723c
4 changed files with 37 additions and 10 deletions

View File

@@ -1,6 +1,8 @@
import random
import time
import numpy as np
import scgenerator as sc
SIZE = 100
@@ -8,11 +10,15 @@ SIZE = 100
def compute_stuff(num: int, pbar: sc.threading.Multibar):
speed = random.random() * 5
out = 0
for i in pbar(range(SIZE), desc=f"num {num}"):
time.sleep(0.05 * speed * random.random())
# time.sleep(0.01 * speed * random.random())
out += np.abs(np.subtract.outer(np.random.rand(1 << 13), np.random.rand(1 << 13))).min()
if i == 32:
pbar.print(f"reached 32 in {num}")
# if random.random() > 0.98:
# print(f"some text {i}")
return num
return out
def main():

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "scgenerator"
version = "0.4.4"
version = "0.4.5"
description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -98,6 +98,7 @@ class DelayedTqdm:
self.label = name
if self.bar is not None:
self.bar.set_description(self.label)
self.bar.refresh()
class PositionGetter:
@@ -127,7 +128,12 @@ class MultibarThread(Thread):
sub_pbars: list[DelayedTqdm]
def __init__(
self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None
self,
n_bars: int,
n_show: int,
queue: Queue,
default_total: int | float | None = None,
desc: str | None = None,
):
super().__init__()
self.queue = queue
@@ -136,7 +142,10 @@ class MultibarThread(Thread):
self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)]
self.main_pbar = tqdm(
total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green"
total=sum((bar.total or 0) for bar in self.sub_pbars),
position=0,
colour="green",
desc=desc,
)
self.started = False
self.finished = False
@@ -233,8 +242,6 @@ class Multibar:
def close(self):
self.queue.put(PBarMessage(self.id, Command.FINISHED))
# sys.stdout = sys.__stdout__
# sys.stderr = sys.__stderr__
def __enter__(self):
return self
@@ -272,6 +279,12 @@ class Multibar:
self.queue.put(PBarMessage(self.id, Command.FINISHED))
self.finished = True
def set_description(self, desc: str):
self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc))
def print(self, msg: str):
self.queue.put(PBarMessage(self.id, Command.PRINT, msg))
def check_start_finish(self):
if self.finished:
raise ValueError(f"{self.__class__.__name__} {self.id} has already finished")
@@ -306,10 +319,12 @@ class MultibarManager:
yield from self.mbars
def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread:
def multibar(
n_tasks: int, n_show: int, n_pertask: int | float | None = None, desc: str | None = None
) -> MultibarThread:
manager = multiprocessing.Manager()
queue = manager.Queue()
thread = MultibarThread(n_tasks, n_show, queue, n_pertask)
thread = MultibarThread(n_tasks, n_show, queue, n_pertask, desc)
mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)]
return MultibarManager(thread, queue, mbars)
@@ -321,12 +336,13 @@ def apply_with_progress(
n_cpu: int | None = None,
n_pertask: int | None = None,
unpack: bool = True,
desc: str | None = None,
) -> list[T]:
args = list(args)
n_cpu = n_cpu or multiprocessing.cpu_count()
with (
multiprocessing.Pool(n_cpu) as pool,
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars,
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars,
):
if unpack:
all_args = ((*arg, bars[i]) for i, arg in enumerate(args))

View File

@@ -152,6 +152,11 @@ def test_zip_bundle(tmp_path: Path):
assert prop3.parameters.effective_area_file.path == new_aeff_path.name
assert prop3.parameters.effective_area_file.prefix == "zip"
assert (
DataFile(None, new_disp_path, None).load_bytes()
== prop3.parameters.dispersion_file.load_bytes()
)
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()