diff --git a/deploy b/deploy old mode 100755 new mode 100644 diff --git a/deploy.nu b/deploy.nu new file mode 100644 index 0000000..57bd666 --- /dev/null +++ b/deploy.nu @@ -0,0 +1,25 @@ +#!/usr/bin/env nu + +let version = (open pyproject.toml).project.version +let zipfile = $"scgenerator-($version).zip" +let help_page = "scgenerator.html" + +let filelist = ( + git ls-files | + lines | + where {not ($in | str ends-with ".afphoto")} | + where {$in != "deploy.nu"} +) + +^zip $zipfile ...$filelist + +open README.md | + str replace -a __VERSION__ $version | + pandoc --standalone --toc --template build/template.html | + save -f $'build/($help_page)' + +scp -O $zipfile $"fibnas:/volume1/web/($zipfile)" +scp -O $zipfile $"fibnas:/volume1/web/scgenerator-latest.zip" +scp -O $'build/($help_page)' $"fibnas:/volume1/web/($help_page)" + +rm $zipfile diff --git a/examples/Optica_PM2000D/Optica_PM2000D.toml b/examples/Optica_PM2000D/Optica_PM2000D.toml old mode 100755 new mode 100644 diff --git a/examples/Travers/Travers.toml b/examples/Travers/Travers.toml old mode 100755 new mode 100644 diff --git a/examples/compute_coherence.py b/examples/compute_coherence.py index 3eaef86..ad3d804 100644 --- a/examples/compute_coherence.py +++ b/examples/compute_coherence.py @@ -31,6 +31,22 @@ def propagate_all(n): return spec, props +def quick_test(): + t = sc.tspace(dt=1e-15, t_num=2048) + spec_0 = np.fft.fft(10e3 * np.exp(-((t / 70e-15) ** 2))) + spec = sc.Spectrum( + [np.exp(2j * np.pi * np.random.rand()) + spec_0 for _ in range(20)], + sc.wspace(t) + sc.units.nm_rads(800), + t, + ) + _, (top, bot) = plt.subplots(2, 1, constrained_layout=True, height_ratios=[1, 5], sharex=True) + bot.plot(spec.wl_disp * 1e9, spec[0].wl_int) + top.plot(spec.wl_disp * 1e9, spec.coherence()) + top.set_xlim(750, 850) + bot.set_yscale("log") + plt.show() + + def main(): n = 1 spec, props = propagate_all(n) @@ -52,4 +68,5 @@ def main(): if __name__ == "__main__": - main() + quick_test() + # main() diff --git a/examples/show_multi_bar.py b/examples/show_multi_bar.py new file mode 100644 index 0000000..4ed2d91 --- /dev/null +++ b/examples/show_multi_bar.py @@ -0,0 +1,22 @@ +import random +import time + +import scgenerator as sc + +SIZE = 100 + + +def compute_stuff(num: int, pbar: sc.threading.Multibar): + speed = random.random() * 5 + for i in pbar(range(SIZE), desc=f"num {num}"): + time.sleep(0.05 * speed * random.random()) + if random.random() > 0.98: + print(f"some text {i}") + + +def main(): + sc.threading.apply_with_progress(compute_stuff, range(12), n_cpu=4, unpack=False) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index f860e8a..8a100b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.4.1" +version = "0.4.2" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 7eac434..abe2950 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -1,5 +1,5 @@ # ruff: noqa -from scgenerator import math +from scgenerator import math, threading from scgenerator.physics import units from scgenerator import io, noise, operators, plotting from scgenerator.helpers import * diff --git a/src/scgenerator/threading.py b/src/scgenerator/threading.py index 16e049c..f6e295e 100644 --- a/src/scgenerator/threading.py +++ b/src/scgenerator/threading.py @@ -1,19 +1,336 @@ +from __future__ import annotations + +import multiprocessing +from collections.abc import Iterable, Iterator +from dataclasses import dataclass +from enum import Enum, auto from multiprocessing import Queue -from threading import Thread +from threading import Lock, Thread +from typing import NamedTuple, TypeVar, Callable, Any from tqdm import tqdm +import sys -class Multibar(Thread): +T = TypeVar("T") + + +class Command(Enum): + INCREASE = auto() + NEW_TOTAL = auto() + FINISHED = auto() + STOP_ALL = auto() + NEW_NAME = auto() + PRINT = auto() + + +class PBarMessage(NamedTuple): + bar_id: int + command: Command + data: int | float | str = 0 + + +class QueueStdOut: + id: int queue: Queue - bars: list[tqdm] - def __init__(self, bars: list[tqdm], queue: Queue): + def __init__(self, id: int, queue: Queue): + self.id = id self.queue = queue - self.bars = bars + + def write(self, s: str): + self.queue.put(PBarMessage(self.id, Command.PRINT, s)) + + +class DelayedTqdm: + current: int | float + total: int | float | None + position_getter: PositionGetter + position: int | None + bar: tqdm | None + label: str | None + + def __init__(self, position_getter: PositionGetter, total: int | float | None = None): + self.current = 0 + self.total = total + self.position_getter = position_getter + self.position = None + self.bar = None + self.label = None + + @property + def finished(self) -> bool: + return self.bar is not None and self.bar.disable + + def show(self): + if self.bar is not None: + return + self.position = self.position_getter.book() + if self.position is None: + return + self.bar = tqdm( + total=self.total, + position=self.position, + desc=self.label, + initial=self.current, + leave=False, + ) + + def close(self): + if self.bar is None or self.bar.disable: + return + self.bar.close() + self.position_getter.free(self.position) + + def update(self, amount: int | float = 1): + self.show() + self.current += amount + if self.bar is not None: + self.bar.update(amount) + + def new_total(self, total: int | float): + self.show() + self.total = total + if self.bar is not None: + self.bar.total = total + + def new_name(self, name: str): + self.show() + self.label = name + if self.bar is not None: + self.bar.set_description(self.label) + + +class PositionGetter: + def __init__(self, n_tot: int, n: int, offset: int = 0): + self.remaining = n_tot + self.offset = offset + self.busy_at = [False for _ in range(n_tot)] + + @property + def max_pos(self) -> int: + return max(self.busy_at) + + def book(self) -> int: + for pos, busy in enumerate(self.busy_at[: self.remaining]): + if not busy: + self.busy_at[pos] = True + self.remaining -= 1 + return pos + self.offset + + def free(self, pos: int): + self.busy_at[pos - self.offset] = False + + +class MultibarThread(Thread): + queue: Queue + main_pbar: tqdm + sub_pbars: list[DelayedTqdm] + + def __init__( + self, n_bars: int, n_show: int, queue: Queue, default_total: int | float | None = None + ): + super().__init__() + self.queue = queue + + pos_getter = PositionGetter(n_bars, n_show, offset=1) + + self.sub_pbars = [DelayedTqdm(pos_getter, total=default_total) for _ in range(n_bars)] + self.main_pbar = tqdm( + total=sum((bar.total or 0) for bar in self.sub_pbars), position=0, colour="green" + ) + self.started = False + self.finished = False + self.lock = Lock() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + self.close() + except AttributeError: + # maybe eager thread cleanup upon external error + if (exc_type, exc_value, traceback) == (None, None, None): + raise + + def start(self): + if not self.started: + self.started = True + super().start() + return self def run(self): while True: - bar_id, amount = self.queue.get(True, None) - self.bars[bar_id].update(amount) - self.bars[0].update(amount) + msg = self._get_next() + if msg.command is Command.INCREASE: + self.sub_pbars[msg.bar_id].update(msg.data) + self.main_pbar.update(msg.data) + elif msg.command is Command.NEW_TOTAL: + pbar = self.sub_pbars[msg.bar_id].new_total(msg.data) + self.main_pbar.total = sum((bar.total or 0) for bar in self.sub_pbars) + elif msg.command is Command.NEW_NAME: + self.sub_pbars[msg.bar_id].new_name(msg.data) + elif msg.command is Command.PRINT: + self.main_pbar.write(f"{msg.bar_id}: {msg.data}") + self.main_pbar.refresh() + elif msg.command is Command.FINISHED: + pbar = self.sub_pbars[msg.bar_id] + pbar.close() + if all(bar.finished for bar in self.sub_pbars): + self.main_pbar.close() + elif msg.command is Command.STOP_ALL: + self.close_all() + with self.lock: + self.finished = True + return + + def _get_next(self) -> PBarMessage: + return self.queue.get(True, None) + + def close_all(self): + for bar in self.sub_pbars: + bar.close() + self.main_pbar.close() + + def close(self): + with self.lock: + if not self.finished: + self.queue.put(PBarMessage(0, Command.STOP_ALL)) + self.join() + + +class Multibar: + """lives in a subprocess""" + + id: int + total: float | int + current: float | int + queue: Queue | None = None + + def __init__( + self, id: int, queue: Queue, total: int | float | None = None, start: int | float = 0 + ): + self.id = id + self.total = total + self.current = start + self.queue = queue + self.finished = False + self.stdout = QueueStdOut(self.id, self.queue) + + sys.stdout = self.stdout + sys.stderr = self.stdout + + print(sys.stdout, file=sys.__stdout__) + + def update(self, amount: float | int = 1): + self.check_start_finish() + self.current += amount + self.queue.put(PBarMessage(self.id, Command.INCREASE, amount)) + if self.total is not None and self.current >= self.total: + self.queue.put(PBarMessage(self.id, Command.FINISHED)) + self.finished = True + + def close(self): + self.queue.put(PBarMessage(self.id, Command.FINISHED)) + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __call__( + self, iterable: Iterable[T], total: int | float | None = None, desc: str | None = None + ) -> Iterator[T]: + self.check_start_finish() + try: + new_len = len(iterable) + except TypeError: + if total is None: + raise + new_len = total + + if desc is not None: + self.queue.put(PBarMessage(self.id, Command.NEW_NAME, desc)) + self.queue.put(PBarMessage(self.id, Command.NEW_TOTAL, new_len)) + self.total = total + + iterable = iter(iterable) + while True: + try: + yield next(iterable) + except StopIteration: + break + except Exception: + self.queue.put(PBarMessage(self.id, Command.FINISHED)) + self.finished = True + raise + self.update() + self.queue.put(PBarMessage(self.id, Command.FINISHED)) + self.finished = True + + def check_start_finish(self): + if self.finished: + raise ValueError(f"{self.__class__.__name__} {self.id} has already finished") + + +@dataclass +class MultibarManager: + thread: MultibarThread + queue: Queue + mbars: list[Multibar] + + def __enter__(self): + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + self.thread.close() + # self.queue.close() + # self.queue.join_thread() + except AttributeError: + # maybe eager thread cleanup upon external error + if (exc_type, exc_value, traceback) == (None, None, None): + raise + + def __getitem__(self, key: int) -> Multibar: + if not self.thread.is_alive(): + raise RuntimeError("You must start the multibar thread before using the multibars") + return self.mbars[key] + + def __iter__(self) -> Iterator[Multibar]: + yield from self.mbars + + +def multibar(n_tasks: int, n_show: int, n_pertask: int | float | None = None) -> MultibarThread: + manager = multiprocessing.Manager() + queue = manager.Queue() + thread = MultibarThread(n_tasks, n_show, queue, n_pertask) + + mbars = [Multibar(i, queue, n_pertask) for i in range(n_tasks)] + return MultibarManager(thread, queue, mbars) + + +def apply_with_progress( + func: Callable[..., T], + args: Iterable[Iterable[Any]], + n_cpu: int | None = None, + n_pertask: int | None = None, + unpack: bool = True, +) -> list[T]: + args = list(args) + n_cpu = n_cpu or multiprocessing.cpu_count() + with ( + multiprocessing.Pool(n_cpu) as pool, + multibar(n_tasks=len(args), n_show=n_cpu, n_pertask=n_pertask) as bars, + ): + if unpack: + all_args = ((*arg, bars[i]) for i, arg in enumerate(args)) + else: + all_args = ((arg, bars[i]) for i, arg in enumerate(args)) + return pool.starmap(func, all_args)