diff --git a/examples/show_multi_bar.py b/examples/show_multi_bar.py index 3399cb9..142411d 100644 --- a/examples/show_multi_bar.py +++ b/examples/show_multi_bar.py @@ -1,6 +1,8 @@ import random import time +import numpy as np + import scgenerator as sc SIZE = 100 @@ -8,11 +10,15 @@ SIZE = 100 def compute_stuff(num: int, pbar: sc.threading.Multibar): speed = random.random() * 5 + out = 0 for i in pbar(range(SIZE), desc=f"num {num}"): - time.sleep(0.05 * speed * random.random()) + # time.sleep(0.01 * speed * random.random()) + out += np.abs(np.subtract.outer(np.random.rand(1 << 13), np.random.rand(1 << 13))).min() + if i == 32: + pbar.print(f"reached 32 in {num}") # if random.random() > 0.98: # print(f"some text {i}") - return num + return out def main(): diff --git a/pyproject.toml b/pyproject.toml index 2a4cdc0..b092b2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.4.4" +version = "0.4.5" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/threading.py b/src/scgenerator/threading.py index eb11f77..4c17398 100644 --- a/src/scgenerator/threading.py +++ b/src/scgenerator/threading.py @@ -98,6 +98,7 @@ class DelayedTqdm: self.label = name if self.bar is not None: self.bar.set_description(self.label) + self.bar.refresh() class PositionGetter: @@ -127,7 +128,12 @@ class MultibarThread(Thread): sub_pbars: list[DelayedTqdm] def __init__( - self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None + self, + n_bars: int, + n_show: int, + queue: Queue, + default_total: int | float | None = None, + desc: str | None = None, ): super().__init__() self.queue = queue @@ -136,7 +142,10 @@ class MultibarThread(Thread): self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)] self.main_pbar = tqdm( - total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green" + total=sum((bar.total or 0) for bar in self.sub_pbars), + position=0, + colour="green", + desc=desc, ) self.started = False self.finished = False @@ -233,8 +242,6 @@ class Multibar: def close(self): self.queue.put(PBarMessage(self.id, Command.FINISHED)) - # sys.stdout = sys.__stdout__ - # sys.stderr = sys.__stderr__ def __enter__(self): return self @@ -272,6 +279,12 @@ class Multibar: self.queue.put(PBarMessage(self.id, Command.FINISHED)) self.finished = True + def set_description(self, desc: str): + self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc)) + + def print(self, msg: str): + self.queue.put(PBarMessage(self.id, Command.PRINT, msg)) + def check_start_finish(self): if self.finished: raise ValueError(f"{self.__class__.__name__} {self.id} has already finished") @@ -306,10 +319,12 @@ class MultibarManager: yield from self.mbars -def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread: +def multibar( + n_tasks: int, n_show: int, n_pertask: int | float | None = None, desc: str | None = None +) -> MultibarThread: manager = multiprocessing.Manager() queue = manager.Queue() - thread = MultibarThread(n_tasks, n_show, queue, n_pertask) + thread = MultibarThread(n_tasks, n_show, queue, n_pertask, desc) mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)] return MultibarManager(thread, queue, mbars) @@ -321,12 +336,13 @@ def apply_with_progress( n_cpu: int | None = None, n_pertask: int | None = None, unpack: bool = True, + desc: str | None = None, ) -> list[T]: args = list(args) n_cpu = n_cpu or multiprocessing.cpu_count() with ( multiprocessing.Pool(n_cpu) as pool, - multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars, + multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars, ): if unpack: all_args = ((*arg, bars[i]) for i, arg in enumerate(args)) diff --git a/tests/test_io_handlers.py b/tests/test_io_handlers.py index b84e65b..dad107f 100644 --- a/tests/test_io_handlers.py +++ b/tests/test_io_handlers.py @@ -152,6 +152,11 @@ def test_zip_bundle(tmp_path: Path): assert prop3.parameters.effective_area_file.path == new_aeff_path.name assert prop3.parameters.effective_area_file.prefix == "zip" + assert ( + DataFile(None, new_disp_path, None).load_bytes() + == prop3.parameters.dispersion_file.load_bytes() + ) + with ZipFile(tmp_path / "file3.zip", "r") as zfile: with zfile.open(new_aeff_path.name) as file: assert file.read() == new_aeff_path.read_bytes()