multi-pbar threading
This commit is contained in:
25
deploy.nu
Normal file
25
deploy.nu
Normal file
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env nu
|
||||
|
||||
let version = (open pyproject.toml).project.version
|
||||
let zipfile = $"scgenerator-($version).zip"
|
||||
let help_page = "scgenerator.html"
|
||||
|
||||
let filelist = (
|
||||
git ls-files |
|
||||
lines |
|
||||
where {not ($in | str ends-with ".afphoto")} |
|
||||
where {$in != "deploy.nu"}
|
||||
)
|
||||
|
||||
^zip $zipfile ...$filelist
|
||||
|
||||
open README.md |
|
||||
str replace -a __VERSION__ $version |
|
||||
pandoc --standalone --toc --template build/template.html |
|
||||
save -f $'build/($help_page)'
|
||||
|
||||
scp -O $zipfile $"fibnas:/volume1/web/($zipfile)"
|
||||
scp -O $zipfile $"fibnas:/volume1/web/scgenerator-latest.zip"
|
||||
scp -O $'build/($help_page)' $"fibnas:/volume1/web/($help_page)"
|
||||
|
||||
rm $zipfile
|
||||
0
examples/Optica_PM2000D/Optica_PM2000D.toml
Executable file → Normal file
0
examples/Optica_PM2000D/Optica_PM2000D.toml
Executable file → Normal file
0
examples/Travers/Travers.toml
Executable file → Normal file
0
examples/Travers/Travers.toml
Executable file → Normal file
@@ -31,6 +31,22 @@ def propagate_all(n):
|
||||
return spec, props
|
||||
|
||||
|
||||
def quick_test():
|
||||
t = sc.tspace(dt=1e-15, t_num=2048)
|
||||
spec_0 = np.fft.fft(10e3 * np.exp(-((t / 70e-15) ** 2)))
|
||||
spec = sc.Spectrum(
|
||||
[np.exp(2j * np.pi * np.random.rand()) + spec_0 for _ in range(20)],
|
||||
sc.wspace(t) + sc.units.nm_rads(800),
|
||||
t,
|
||||
)
|
||||
_, (top, bot) = plt.subplots(2, 1, constrained_layout=True, height_ratios=[1, 5], sharex=True)
|
||||
bot.plot(spec.wl_disp * 1e9, spec[0].wl_int)
|
||||
top.plot(spec.wl_disp * 1e9, spec.coherence())
|
||||
top.set_xlim(750, 850)
|
||||
bot.set_yscale("log")
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
n = 1
|
||||
spec, props = propagate_all(n)
|
||||
@@ -52,4 +68,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
quick_test()
|
||||
# main()
|
||||
|
||||
22
examples/show_multi_bar.py
Normal file
22
examples/show_multi_bar.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import scgenerator as sc
|
||||
|
||||
SIZE = 100
|
||||
|
||||
|
||||
def compute_stuff(num: int, pbar: sc.threading.Multibar):
|
||||
speed = random.random() * 5
|
||||
for i in pbar(range(SIZE), desc=f"num {num}"):
|
||||
time.sleep(0.05 * speed * random.random())
|
||||
if random.random() > 0.98:
|
||||
print(f"some text {i}")
|
||||
|
||||
|
||||
def main():
|
||||
sc.threading.apply_with_progress(compute_stuff, range(12), n_cpu=4, unpack=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "scgenerator"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
description = "Simulate nonlinear pulse propagation in optical fibers"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# ruff: noqa
|
||||
from scgenerator import math
|
||||
from scgenerator import math, threading
|
||||
from scgenerator.physics import units
|
||||
from scgenerator import io, noise, operators, plotting
|
||||
from scgenerator.helpers import *
|
||||
|
||||
@@ -1,19 +1,336 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import Queue
|
||||
from threading import Thread
|
||||
from threading import Lock, Thread
|
||||
from typing import NamedTuple, TypeVar, Callable, Any
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import sys
|
||||
|
||||
class Multibar(Thread):
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Command(Enum):
|
||||
INCREASE = auto()
|
||||
NEW_TOTAL = auto()
|
||||
FINISHED = auto()
|
||||
STOP_ALL = auto()
|
||||
NEW_NAME = auto()
|
||||
PRINT = auto()
|
||||
|
||||
|
||||
class PBarMessage(NamedTuple):
|
||||
bar_id: int
|
||||
command: Command
|
||||
data: int | float | str = 0
|
||||
|
||||
|
||||
class QueueStdOut:
|
||||
id: int
|
||||
queue: Queue
|
||||
bars: list[tqdm]
|
||||
|
||||
def __init__(self, bars: list[tqdm], queue: Queue):
|
||||
def __init__(self, id: int, queue: Queue):
|
||||
self.id = id
|
||||
self.queue = queue
|
||||
self.bars = bars
|
||||
|
||||
def write(self, s: str):
|
||||
self.queue.put(PBarMessage(self.id, Command.PRINT, s))
|
||||
|
||||
|
||||
class DelayedTqdm:
|
||||
current: int | float
|
||||
total: int | float | None
|
||||
position_getter: PositionGetter
|
||||
position: int | None
|
||||
bar: tqdm | None
|
||||
label: str | None
|
||||
|
||||
def __init__(self, position_getter: PositionGetter, total: int | float | None = None):
|
||||
self.current = 0
|
||||
self.total = total
|
||||
self.position_getter = position_getter
|
||||
self.position = None
|
||||
self.bar = None
|
||||
self.label = None
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
return self.bar is not None and self.bar.disable
|
||||
|
||||
def show(self):
|
||||
if self.bar is not None:
|
||||
return
|
||||
self.position = self.position_getter.book()
|
||||
if self.position is None:
|
||||
return
|
||||
self.bar = tqdm(
|
||||
total=self.total,
|
||||
position=self.position,
|
||||
desc=self.label,
|
||||
initial=self.current,
|
||||
leave=False,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
if self.bar is None or self.bar.disable:
|
||||
return
|
||||
self.bar.close()
|
||||
self.position_getter.free(self.position)
|
||||
|
||||
def update(self, amount: int | float = 1):
|
||||
self.show()
|
||||
self.current += amount
|
||||
if self.bar is not None:
|
||||
self.bar.update(amount)
|
||||
|
||||
def new_total(self, total: int | float):
|
||||
self.show()
|
||||
self.total = total
|
||||
if self.bar is not None:
|
||||
self.bar.total = total
|
||||
|
||||
def new_name(self, name: str):
|
||||
self.show()
|
||||
self.label = name
|
||||
if self.bar is not None:
|
||||
self.bar.set_description(self.label)
|
||||
|
||||
|
||||
class PositionGetter:
|
||||
def __init__(self, n_tot: int, n: int, offset: int = 0):
|
||||
self.remaining = n_tot
|
||||
self.offset = offset
|
||||
self.busy_at = [False for _ in range(n_tot)]
|
||||
|
||||
@property
|
||||
def max_pos(self) -> int:
|
||||
return max(self.busy_at)
|
||||
|
||||
def book(self) -> int:
|
||||
for pos, busy in enumerate(self.busy_at[: self.remaining]):
|
||||
if not busy:
|
||||
self.busy_at[pos] = True
|
||||
self.remaining -= 1
|
||||
return pos + self.offset
|
||||
|
||||
def free(self, pos: int):
|
||||
self.busy_at[pos - self.offset] = False
|
||||
|
||||
|
||||
class MultibarThread(Thread):
|
||||
queue: Queue
|
||||
main_pbar: tqdm
|
||||
sub_pbars: list[DelayedTqdm]
|
||||
|
||||
def __init__(
|
||||
self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self.queue = queue
|
||||
|
||||
pos_getter = PositionGetter(n_bars, n_show, offset=1)
|
||||
|
||||
self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)]
|
||||
self.main_pbar = tqdm(
|
||||
total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green"
|
||||
)
|
||||
self.started = False
|
||||
self.finished = False
|
||||
self.lock = Lock()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
try:
|
||||
self.close()
|
||||
except AttributeError:
|
||||
# maybe eager thread cleanup upon external error
|
||||
if (exc_type, exc_value, traceback) == (None, None, None):
|
||||
raise
|
||||
|
||||
def start(self):
|
||||
if not self.started:
|
||||
self.started = True
|
||||
super().start()
|
||||
return self
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
bar_id, amount = self.queue.get(True, None)
|
||||
self.bars[bar_id].update(amount)
|
||||
self.bars[0].update(amount)
|
||||
msg = self._get_next()
|
||||
if msg.command is Command.INCREASE:
|
||||
self.sub_pbars[msg.bar_id].update(msg.data)
|
||||
self.main_pbar.update(msg.data)
|
||||
elif msg.command is Command.NEW_TOTAL:
|
||||
pbar = self.sub_pbars[msg.bar_id].new_total(msg.data)
|
||||
self.main_pbar.total = sum((bar.total or 0) for bar in self.sub_pbars)
|
||||
elif msg.command is Command.NEW_NAME:
|
||||
self.sub_pbars[msg.bar_id].new_name(msg.data)
|
||||
elif msg.command is Command.PRINT:
|
||||
self.main_pbar.write(f"{msg.bar_id}: {msg.data}")
|
||||
self.main_pbar.refresh()
|
||||
elif msg.command is Command.FINISHED:
|
||||
pbar = self.sub_pbars[msg.bar_id]
|
||||
pbar.close()
|
||||
if all(bar.finished for bar in self.sub_pbars):
|
||||
self.main_pbar.close()
|
||||
elif msg.command is Command.STOP_ALL:
|
||||
self.close_all()
|
||||
with self.lock:
|
||||
self.finished = True
|
||||
return
|
||||
|
||||
def _get_next(self) -> PBarMessage:
|
||||
return self.queue.get(True, None)
|
||||
|
||||
def close_all(self):
|
||||
for bar in self.sub_pbars:
|
||||
bar.close()
|
||||
self.main_pbar.close()
|
||||
|
||||
def close(self):
|
||||
with self.lock:
|
||||
if not self.finished:
|
||||
self.queue.put(PBarMessage(0, Command.STOP_ALL))
|
||||
self.join()
|
||||
|
||||
|
||||
class Multibar:
|
||||
"""lives in a subprocess"""
|
||||
|
||||
id: int
|
||||
total: float | int
|
||||
current: float | int
|
||||
queue: Queue | None = None
|
||||
|
||||
def __init__(
|
||||
self, id: int, queue: Queue, total: int | float | None = None, start: int | float = 0
|
||||
):
|
||||
self.id = id
|
||||
self.total = total
|
||||
self.current = start
|
||||
self.queue = queue
|
||||
self.finished = False
|
||||
self.stdout = QueueStdOut(self.id, self.queue)
|
||||
|
||||
sys.stdout = self.stdout
|
||||
sys.stderr = self.stdout
|
||||
|
||||
print(sys.stdout, file=sys.__stdout__)
|
||||
|
||||
def update(self, amount: float | int = 1):
|
||||
self.check_start_finish()
|
||||
self.current += amount
|
||||
self.queue.put(PBarMessage(self.id, Command.INCREASE, amount))
|
||||
if self.total is not None and self.current >= self.total:
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
self.finished = True
|
||||
|
||||
def close(self):
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
|
||||
def __call__(
|
||||
self, iterable: Iterable[T], total: int | float | None = None, desc: str | None = None
|
||||
) -> Iterator[T]:
|
||||
self.check_start_finish()
|
||||
try:
|
||||
new_len = len(iterable)
|
||||
except TypeError:
|
||||
if total is None:
|
||||
raise
|
||||
new_len = total
|
||||
|
||||
if desc is not None:
|
||||
self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc))
|
||||
self.queue.put(PBarMessage(self.id, Command.NEW_TOTAL, new_len))
|
||||
self.total = total
|
||||
|
||||
iterable = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
yield next(iterable)
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception:
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
self.finished = True
|
||||
raise
|
||||
self.update()
|
||||
self.queue.put(PBarMessage(self.id, Command.FINISHED))
|
||||
self.finished = True
|
||||
|
||||
def check_start_finish(self):
|
||||
if self.finished:
|
||||
raise ValueError(f"{self.__class__.__name__} {self.id} has already finished")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultibarManager:
|
||||
thread: MultibarThread
|
||||
queue: Queue
|
||||
mbars: list[Multibar]
|
||||
|
||||
def __enter__(self):
|
||||
self.thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
try:
|
||||
self.thread.close()
|
||||
# self.queue.close()
|
||||
# self.queue.join_thread()
|
||||
except AttributeError:
|
||||
# maybe eager thread cleanup upon external error
|
||||
if (exc_type, exc_value, traceback) == (None, None, None):
|
||||
raise
|
||||
|
||||
def __getitem__(self, key: int) -> Multibar:
|
||||
if not self.thread.is_alive():
|
||||
raise RuntimeError("You must start the multibar thread before using the multibars")
|
||||
return self.mbars[key]
|
||||
|
||||
def __iter__(self) -> Iterator[Multibar]:
|
||||
yield from self.mbars
|
||||
|
||||
|
||||
def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread:
|
||||
manager = multiprocessing.Manager()
|
||||
queue = manager.Queue()
|
||||
thread = MultibarThread(n_tasks, n_show, queue, n_pertask)
|
||||
|
||||
mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)]
|
||||
return MultibarManager(thread, queue, mbars)
|
||||
|
||||
|
||||
def apply_with_progress(
|
||||
func: Callable[..., T],
|
||||
args: Iterable[Iterable[Any]],
|
||||
n_cpu: int | None = None,
|
||||
n_pertask: int | None = None,
|
||||
unpack: bool = True,
|
||||
) -> list[T]:
|
||||
args = list(args)
|
||||
n_cpu = n_cpu or multiprocessing.cpu_count()
|
||||
with (
|
||||
multiprocessing.Pool(n_cpu) as pool,
|
||||
multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars,
|
||||
):
|
||||
if unpack:
|
||||
all_args = ((*arg, bars[i]) for i, arg in enumerate(args))
|
||||
else:
|
||||
all_args = ((arg, bars[i]) for i, arg in enumerate(args))
|
||||
return pool.starmap(func, all_args)
|
||||
|
||||
Reference in New Issue
Block a user