fixed multibars
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import scgenerator as sc
|
||||
|
||||
SIZE = 100
|
||||
@@ -8,11 +10,15 @@ SIZE = 100
|
||||
|
||||
def compute_stuff(num: int, pbar: sc.threading.Multibar):
|
||||
speed = random.random() * 5
|
||||
out = 0
|
||||
for i in pbar(range(SIZE), desc=f"num {num}"):
|
||||
time.sleep(0.05 * speed * random.random())
|
||||
# time.sleep(0.01 * speed * random.random())
|
||||
out += np.abs(np.subtract.outer(np.random.rand(1 << 13), np.random.rand(1 << 13))).min()
|
||||
if i == 32:
|
||||
pbar.print(f"reached 32 in {num}")
|
||||
# if random.random() > 0.98:
|
||||
# print(f"some text {i}")
|
||||
return num
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "scgenerator"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||
|
||||
@@ -98,6 +98,7 @@ class DelayedTqdm:
|
||||
self.label = name
|
||||
if self.bar is not None:
|
||||
self.bar.set_description(self.label)
|
||||
self.bar.refresh()
|
||||
|
||||
|
||||
class PositionGetter:
|
||||
@@ -127,7 +128,12 @@ class MultibarThread(Thread):
|
||||
sub_pbars: list[DelayedTqdm]
|
||||
|
||||
def __init__(
|
||||
self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None
|
||||
self,
|
||||
n_bars: int,
|
||||
n_show: int,
|
||||
queue: Queue,
|
||||
default_total: int | float | None = None,
|
||||
desc: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
@@ -136,7 +142,10 @@ class MultibarThread(Thread):
|
||||
|
||||
self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)]
|
||||
self.main_pbar = tqdm(
|
||||
total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green"
|
||||
total=sum((bar.total or 0) for bar in self.sub_pbars),
|
||||
position=0,
|
||||
colour="green",
|
||||
desc=desc,
|
||||
)
|
||||
self.started = False
|
||||
self.finished = False
|
||||
@@ -233,8 +242,6 @@ class Multibar:
|
||||
|
||||
def close(self):
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
# sys.stdout = sys.__stdout__
|
||||
# sys.stderr = sys.__stderr__
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -272,6 +279,12 @@ class Multibar:
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
self.finished = True
|
||||
|
||||
def set_description(self, desc: str):
|
||||
self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc))
|
||||
|
||||
def print(self, msg: str):
|
||||
self.queue.put(PBarMessage(self.id, Command.PRINT, msg))
|
||||
|
||||
def check_start_finish(self):
|
||||
if self.finished:
|
||||
raise ValueError(f"{self.__class__.__name__} {self.id} has already finished")
|
||||
@@ -306,10 +319,12 @@ class MultibarManager:
|
||||
yield from self.mbars
|
||||
|
||||
|
||||
def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread:
|
||||
def multibar(
|
||||
n_tasks: int, n_show: int, n_pertask: int | float | None = None, desc: str | None = None
|
||||
) -> MultibarThread:
|
||||
manager = multiprocessing.Manager()
|
||||
queue = manager.Queue()
|
||||
thread = MultibarThread(n_tasks, n_show, queue, n_pertask)
|
||||
thread = MultibarThread(n_tasks, n_show, queue, n_pertask, desc)
|
||||
|
||||
mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)]
|
||||
return MultibarManager(thread, queue, mbars)
|
||||
@@ -321,12 +336,13 @@ def apply_with_progress(
|
||||
n_cpu: int | None = None,
|
||||
n_pertask: int | None = None,
|
||||
unpack: bool = True,
|
||||
desc: str | None = None,
|
||||
) -> list[T]:
|
||||
args = list(args)
|
||||
n_cpu = n_cpu or multiprocessing.cpu_count()
|
||||
with (
|
||||
multiprocessing.Pool(n_cpu) as pool,
|
||||
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars,
|
||||
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars,
|
||||
):
|
||||
if unpack:
|
||||
all_args = ((*arg, bars[i]) for i, arg in enumerate(args))
|
||||
|
||||
@@ -152,6 +152,11 @@ def test_zip_bundle(tmp_path: Path):
|
||||
assert prop3.parameters.effective_area_file.path == new_aeff_path.name
|
||||
assert prop3.parameters.effective_area_file.prefix == "zip"
|
||||
|
||||
assert (
|
||||
DataFile(None, new_disp_path, None).load_bytes()
|
||||
== prop3.parameters.dispersion_file.load_bytes()
|
||||
)
|
||||
|
||||
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
|
||||
with zfile.open(new_aeff_path.name) as file:
|
||||
assert file.read() == new_aeff_path.read_bytes()
|
||||
|
||||
Reference in New Issue
Block a user