From 739db77dafe41b946b41bc1837fe5592e11758b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Mon, 7 Jun 2021 12:05:31 +0200 Subject: [PATCH] PBars improvements --- src/scgenerator/initialize.py | 16 ++---- src/scgenerator/io.py | 20 +++---- src/scgenerator/physics/simulate.py | 5 +- src/scgenerator/utils.py | 81 +++++++++++++++++------------ 4 files changed, 60 insertions(+), 62 deletions(-) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 9bf761e..b29cd85 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -118,24 +118,14 @@ class RecoveryParamSequence(ParamSequence): started = self.num_sim sub_folders = io.get_data_dirs(io.get_sim_dir(self.id)) - pbar_store = utils.PBars( - tqdm( - total=len(sub_folders), - desc="Initial recovery process", - unit="sim", - ncols=100, - ) - ) - - for sub_folder in sub_folders: + for sub_folder in utils.PBars( + sub_folders, "Initial recovery", head_kwargs=dict(unit="sim") + ): num_left = io.num_left_to_propagate(sub_folder, z_num) if num_left == 0: self.num_sim -= 1 self.num_steps += num_left started -= 1 - pbar_store.update() - - pbar_store.close() self.num_steps += started * z_num self.single_sim = self.num_sim == 1 diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 43e1b66..4f0d36a 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -238,7 +238,8 @@ def check_data_integrity(sub_folders: List[Path], init_z_num: int): IncompleteDataFolderError raised if not all spectra are present in any folder """ - for sub_folder in sub_folders: + + for sub_folder in utils.PBars(sub_folders, "Checking integrity"): if num_left_to_propagate(sub_folder, init_z_num) != 0: raise IncompleteDataFolderError( f"not enough spectra of the specified {init_z_num} found in {sub_folder}" @@ -306,12 +307,11 @@ def build_path_trees(sim_dir: Path) -> List[PathTree]: sim_dir = sim_dir.resolve() path_branches: List[Tuple[Path, ...]] = [] to_check = list(sim_dir.glob("id*num*")) - pbar = utils.PBars.auto(len(to_check), desc="Building path trees") - for branch in map(build_path_branch, to_check): - if branch is not None: - path_branches.append(branch) - pbar.update() - pbar.close() + with utils.PBars(len(to_check), desc="Building path trees") as pbar: + for branch in map(build_path_branch, to_check): + if branch is not None: + path_branches.append(branch) + pbar.update() path_trees = group_path_branches(path_branches) return path_trees @@ -428,13 +428,9 @@ def merge(destination: os.PathLike, path_trees: List[PathTree] = None): destination / f"initial_config_{i}.toml", ) - pbar = utils.PBars.auto(len(path_trees), desc="Merging") - for path_tree in path_trees: + for path_tree in utils.PBars(path_trees, desc="Merging"): iden = PARAM_SEPARATOR.join(path_tree[-1][0].name.split()[2:-2]) merge_path_tree(path_tree, destination / iden) - pbar.update() - - pbar.close() def sim_dirs(path_trees: List[PathTree]) -> Generator[Path, None, None]: diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 3f6cf87..f2c2e16 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -325,7 +325,6 @@ class SequentialRK4IP(RK4IP): ) def step_saved(self): - self.pbars.update(0) self.pbars.update(1, self.z / self.z_final - self.pbars[1].n) @@ -509,9 +508,7 @@ class SequencialSimulations(Simulations, priority=0): def __init__(self, param_seq: initialize.ParamSequence, task_id): super().__init__(param_seq, task_id=task_id) - self.pbars = utils.PBars.auto( - self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1 - ) + self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1) def new_sim(self, v_list_str: str, params: Dict[str, Any]): self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}") diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index 94ab6d9..36a3668 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -6,32 +6,44 @@ scgenerator module but some function may be used in any python program import collections -import datetime as dt import itertools -import logging import multiprocessing +import threading +import time +from collections import abc from copy import deepcopy from io import StringIO from pathlib import Path -import threading -from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union -import time +from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar, Union import numpy as np -import ray from tqdm import tqdm from . import env from .const import PARAM_SEPARATOR, valid_variable -from .logger import get_logger from .math import * +T_ = TypeVar("T_") + class PBars: - @classmethod - def auto( - cls, num_tot: int, desc: str, num_sub_bars: int = 0, head_kwargs=None, worker_kwargs=None + def __init__( + self, + task: Union[int, Iterable[T_]], + desc: str, + num_sub_bars: int = 0, + head_kwargs=None, + worker_kwargs=None, ) -> "PBars": + + if isinstance(task, abc.Iterable): + self.iterator: Iterable[T_] = iter(task) + self.num_tot: int = len(task) + else: + self.num_tot: int = task + self.iterator = None + + self.policy = env.pbar_policy() if head_kwargs is None: head_kwargs = dict() if worker_kwargs is None: @@ -43,21 +55,13 @@ class PBars: if "print" not in env.pbar_policy(): head_kwargs["file"] = worker_kwargs["file"] = StringIO() head_kwargs["desc"] = desc - p = cls([tqdm(total=num_tot, ncols=100, ascii=False, **head_kwargs)]) + self.pbars = [tqdm(total=self.num_tot, ncols=100, ascii=False, **head_kwargs)] for i in range(1, num_sub_bars + 1): kwargs = {k: v for k, v in worker_kwargs.items()} if "desc" in kwargs: kwargs["desc"] = kwargs["desc"].format(worker_id=i) - p.append(tqdm(position=i, ncols=100, ascii=False, **kwargs)) - return p - - def __init__(self, pbars: Union[tqdm, List[tqdm]]) -> None: - self.policy = env.pbar_policy() - self.print_path = Path("progress " + pbars[0].desc).resolve() - if isinstance(pbars, tqdm): - self.pbars = [pbars] - else: - self.pbars = pbars + self.append(tqdm(position=i, ncols=100, ascii=False, **kwargs)) + self.print_path = Path("progress " + self.pbars[0].desc).resolve() self.open = True if "file" in self.policy: self.thread = threading.Thread(target=self.print_worker, daemon=True) @@ -80,7 +84,16 @@ class PBars: self.print() def __iter__(self): - yield from self.pbars + with self as pb: + for thing in self.iterator: + yield thing + pb.update() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() def __getitem__(self, key): return self.pbars[key] @@ -89,7 +102,7 @@ class PBars: if i is None: for pbar in self.pbars[1:]: pbar.update(value) - else: + elif i > 0: self.pbars[i].update(value) self.pbars[0].update() @@ -112,7 +125,7 @@ class PBars: class ProgressBarActor: def __init__(self, name: str, num_workers: int, num_steps: int) -> None: self.counters = [0 for _ in range(num_workers + 1)] - self.p_bars = PBars.auto( + self.p_bars = PBars( num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") ) @@ -155,15 +168,17 @@ def progress_worker( Literal[0] : stop the worker and close the progress bars Tuple[int, float] : worker id and relative progress between 0 and 1 """ - pbars = PBars.auto(num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step")) - while True: - raw = progress_queue.get() - if raw == 0: - pbars.close() - return - i, rel_pos = raw - pbars[i].update(rel_pos - pbars[i].n) - pbars[0].update() + with PBars( + num_steps, "Simulating " + name, num_workers, head_kwargs=dict(unit="step") + ) as pbars: + while True: + raw = progress_queue.get() + if raw == 0: + return + i, rel_pos = raw + print(i) + pbars[i].update(rel_pos - pbars[i].n) + pbars[0].update() def count_variations(config: dict) -> Tuple[int, int]: