fixed multibars
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import scgenerator as sc
|
import scgenerator as sc
|
||||||
|
|
||||||
SIZE = 100
|
SIZE = 100
|
||||||
@@ -8,11 +10,15 @@ SIZE = 100
|
|||||||
|
|
||||||
def compute_stuff(num: int, pbar: sc.threading.Multibar):
|
def compute_stuff(num: int, pbar: sc.threading.Multibar):
|
||||||
speed = random.random() * 5
|
speed = random.random() * 5
|
||||||
|
out = 0
|
||||||
for i in pbar(range(SIZE), desc=f"num {num}"):
|
for i in pbar(range(SIZE), desc=f"num {num}"):
|
||||||
time.sleep(0.05 * speed * random.random())
|
# time.sleep(0.01 * speed * random.random())
|
||||||
|
out += np.abs(np.subtract.outer(np.random.rand(1 << 13), np.random.rand(1 << 13))).min()
|
||||||
|
if i == 32:
|
||||||
|
pbar.print(f"reached 32 in {num}")
|
||||||
# if random.random() > 0.98:
|
# if random.random() > 0.98:
|
||||||
# print(f"some text {i}")
|
# print(f"some text {i}")
|
||||||
return num
|
return out
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "scgenerator"
|
name = "scgenerator"
|
||||||
version = "0.4.4"
|
version = "0.4.5"
|
||||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class DelayedTqdm:
|
|||||||
self.label = name
|
self.label = name
|
||||||
if self.bar is not None:
|
if self.bar is not None:
|
||||||
self.bar.set_description(self.label)
|
self.bar.set_description(self.label)
|
||||||
|
self.bar.refresh()
|
||||||
|
|
||||||
|
|
||||||
class PositionGetter:
|
class PositionGetter:
|
||||||
@@ -127,7 +128,12 @@ class MultibarThread(Thread):
|
|||||||
sub_pbars: list[DelayedTqdm]
|
sub_pbars: list[DelayedTqdm]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None
|
self,
|
||||||
|
n_bars: int,
|
||||||
|
n_show: int,
|
||||||
|
queue: Queue,
|
||||||
|
default_total: int | float | None = None,
|
||||||
|
desc: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
@@ -136,7 +142,10 @@ class MultibarThread(Thread):
|
|||||||
|
|
||||||
self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)]
|
self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)]
|
||||||
self.main_pbar = tqdm(
|
self.main_pbar = tqdm(
|
||||||
total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green"
|
total=sum((bar.total or 0) for bar in self.sub_pbars),
|
||||||
|
position=0,
|
||||||
|
colour="green",
|
||||||
|
desc=desc,
|
||||||
)
|
)
|
||||||
self.started = False
|
self.started = False
|
||||||
self.finished = False
|
self.finished = False
|
||||||
@@ -233,8 +242,6 @@ class Multibar:
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||||
# sys.stdout = sys.__stdout__
|
|
||||||
# sys.stderr = sys.__stderr__
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -272,6 +279,12 @@ class Multibar:
|
|||||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||||
self.finished = True
|
self.finished = True
|
||||||
|
|
||||||
|
def set_description(self, desc: str):
|
||||||
|
self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc))
|
||||||
|
|
||||||
|
def print(self, msg: str):
|
||||||
|
self.queue.put(PBarMessage(self.id, Command.PRINT, msg))
|
||||||
|
|
||||||
def check_start_finish(self):
|
def check_start_finish(self):
|
||||||
if self.finished:
|
if self.finished:
|
||||||
raise ValueError(f"{self.__class__.__name__} {self.id} has already finished")
|
raise ValueError(f"{self.__class__.__name__} {self.id} has already finished")
|
||||||
@@ -306,10 +319,12 @@ class MultibarManager:
|
|||||||
yield from self.mbars
|
yield from self.mbars
|
||||||
|
|
||||||
|
|
||||||
def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread:
|
def multibar(
|
||||||
|
n_tasks: int, n_show: int, n_pertask: int | float | None = None, desc: str | None = None
|
||||||
|
) -> MultibarThread:
|
||||||
manager = multiprocessing.Manager()
|
manager = multiprocessing.Manager()
|
||||||
queue = manager.Queue()
|
queue = manager.Queue()
|
||||||
thread = MultibarThread(n_tasks, n_show, queue, n_pertask)
|
thread = MultibarThread(n_tasks, n_show, queue, n_pertask, desc)
|
||||||
|
|
||||||
mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)]
|
mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)]
|
||||||
return MultibarManager(thread, queue, mbars)
|
return MultibarManager(thread, queue, mbars)
|
||||||
@@ -321,12 +336,13 @@ def apply_with_progress(
|
|||||||
n_cpu: int | None = None,
|
n_cpu: int | None = None,
|
||||||
n_pertask: int | None = None,
|
n_pertask: int | None = None,
|
||||||
unpack: bool = True,
|
unpack: bool = True,
|
||||||
|
desc: str | None = None,
|
||||||
) -> list[T]:
|
) -> list[T]:
|
||||||
args = list(args)
|
args = list(args)
|
||||||
n_cpu = n_cpu or multiprocessing.cpu_count()
|
n_cpu = n_cpu or multiprocessing.cpu_count()
|
||||||
with (
|
with (
|
||||||
multiprocessing.Pool(n_cpu) as pool,
|
multiprocessing.Pool(n_cpu) as pool,
|
||||||
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars,
|
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars,
|
||||||
):
|
):
|
||||||
if unpack:
|
if unpack:
|
||||||
all_args = ((*arg, bars[i]) for i, arg in enumerate(args))
|
all_args = ((*arg, bars[i]) for i, arg in enumerate(args))
|
||||||
|
|||||||
@@ -152,6 +152,11 @@ def test_zip_bundle(tmp_path: Path):
|
|||||||
assert prop3.parameters.effective_area_file.path == new_aeff_path.name
|
assert prop3.parameters.effective_area_file.path == new_aeff_path.name
|
||||||
assert prop3.parameters.effective_area_file.prefix == "zip"
|
assert prop3.parameters.effective_area_file.prefix == "zip"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
DataFile(None, new_disp_path, None).load_bytes()
|
||||||
|
== prop3.parameters.dispersion_file.load_bytes()
|
||||||
|
)
|
||||||
|
|
||||||
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
|
with ZipFile(tmp_path / "file3.zip", "r") as zfile:
|
||||||
with zfile.open(new_aeff_path.name) as file:
|
with zfile.open(new_aeff_path.name) as file:
|
||||||
assert file.read() == new_aeff_path.read_bytes()
|
assert file.read() == new_aeff_path.read_bytes()
|
||||||
|
|||||||
Reference in New Issue
Block a user