diff --git a/examples/playground_threading.py b/examples/playground_threading.py new file mode 100644 index 0000000..501ddff --- /dev/null +++ b/examples/playground_threading.py @@ -0,0 +1,33 @@ +from tqdm import tqdm +import scgenerator as sc +import numpy as np +import time + +s = "The file to read. File-like objects must support the".split() +size = 5 + + +def do_stuff(name: str, stuff: int, pbar=tqdm): + speed = np.random.rand() + for i in pbar(range(size)): + time.sleep(speed) + return np.arange(size * 4).reshape(size, 4) * len(name) + stuff + + +def main(): + shape = (len(s), size, 4) + out = np.zeros(shape) + out_control = np.zeros(shape) + + args = [(el, i) for i, el in enumerate(s)] + + for i, result in sc.threading.apply_with_progress_iter(do_stuff, args, 2): + print(i, result) + out[i] = result + for i, arg in enumerate(args): + out_control[i] = do_stuff(*arg) + assert np.all(out == out_control) + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/threading.py b/src/scgenerator/threading.py index 6412298..19c4755 100644 --- a/src/scgenerator/threading.py +++ b/src/scgenerator/threading.py @@ -345,3 +345,48 @@ def apply_with_progress( else: all_args = ((arg, bars[i]) for i, arg in enumerate(args)) return pool.starmap(func, all_args) + + +def _unpack_fn(args): + func, ind, *args = args + return ind, func(*args) + + +def apply_with_progress_iter( + func: Callable[..., T], + args: Iterable[Iterable[Any]], + n_cpu: int | None = None, + n_pertask: int | None = None, + unpack: bool = True, + desc: str | None = None, +) -> Iterator[tuple[int, T]]: + args = list(args) + n_cpu = n_cpu or multiprocessing.cpu_count() + with ( + multiprocessing.Pool(n_cpu) as pool, + multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars, + ): + if unpack: + all_args = ((func, i, *arg, bars[i]) for i, arg in enumerate(args)) + else: + all_args = ((func, i, arg, bars[i]) for i, arg in enumerate(args)) + yield from pool.imap_unordered(_unpack_fn, all_args) + + +def apply_with_progress_fast( + pool: multiprocessing.Pool, + func: Callable[..., T], + args: Iterable[Iterable[Any]], + n_cpu: int | None = None, + n_pertask: int | None = None, + unpack: bool = True, + desc: str | None = None, +) -> list[T]: + args = list(args) + n_cpu = n_cpu or multiprocessing.cpu_count() + with multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask, desc=desc) as bars: + if unpack: + all_args = ((*arg, bars[i]) for i, arg in enumerate(args)) + else: + all_args = ((arg, bars[i]) for i, arg in enumerate(args)) + return pool.starmap(func, all_args)