made ray optional

This commit is contained in:
Benoît Sierro
2021-06-07 16:31:53 +02:00
parent f89d082bca
commit a124ccba40
5 changed files with 28 additions and 33 deletions

View File

@@ -1,6 +1,5 @@
numpy numpy
matplotlib matplotlib
scipy scipy
ray
toml toml
tqdm tqdm

View File

@@ -23,8 +23,6 @@ install_requires =
numba numba
matplotlib matplotlib
scipy scipy
ray
send2trash
toml toml
tqdm tqdm

View File

@@ -3,7 +3,10 @@ import os
from pathlib import Path from pathlib import Path
import random import random
try:
import ray import ray
except ImportError:
ray = None
from scgenerator.physics.simulate import ( from scgenerator.physics.simulate import (
run_simulation_sequence, run_simulation_sequence,
@@ -84,6 +87,7 @@ def merge(args):
def prep_ray(args): def prep_ray(args):
if ray:
if args.start_ray: if args.start_ray:
init_str = ray.init() init_str = ray.init()
elif not args.no_ray: elif not args.no_ray:

View File

@@ -35,8 +35,8 @@ def _set_debug():
def get_logger(name=None): def get_logger(name=None):
"""returns a logging.Logger instance. This function is there because if scgenerator """returns a logging.Logger instance. This function is there because if scgenerator
is used with ray, workers are not aware of any configuration done with the logging is used with some multiprocessing library, workers are not aware of any configuration done
and so it must be reconfigured. with the logging and so it must be reconfigured.
Parameters Parameters
---------- ----------

View File

@@ -1,25 +1,20 @@
import multiprocessing import multiprocessing
import os import os
import sys
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Tuple, Type from typing import Any, Dict, List, Tuple, Type
import numpy as np import numpy as np
from tqdm import tqdm
from .. import initialize, io, utils, const, env from .. import env, initialize, io, utils
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
from . import pulse from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op from .fiber import create_non_linear_op, fast_dispersion_op
using_ray = False
try: try:
import ray import ray
using_ray = True
except ModuleNotFoundError: except ModuleNotFoundError:
pass ray = None
class RK4IP: class RK4IP:
@@ -419,7 +414,7 @@ class Simulations:
if isinstance(method, str): if isinstance(method, str):
method = Simulations.simulation_methods_dict[method] method = Simulations.simulation_methods_dict[method]
return method(param_seq, task_id) return method(param_seq, task_id)
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"] and using_ray: elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
return Simulations.get_best_method()(param_seq, task_id) return Simulations.get_best_method()(param_seq, task_id)
else: else:
return SequencialSimulations(param_seq, task_id) return SequencialSimulations(param_seq, task_id)
@@ -435,6 +430,8 @@ class Simulations:
data_folder : str, optional data_folder : str, optional
path to the folder where data is saved, by default "scgenerator/" path to the folder where data is saved, by default "scgenerator/"
""" """
if not self.is_available():
raise RuntimeError(f"{self.__class__} is currently not available")
self.logger = io.get_logger(__name__) self.logger = io.get_logger(__name__)
self.id = int(task_id) self.id = int(task_id)
@@ -600,7 +597,9 @@ class RaySimulations(Simulations, priority=2):
@classmethod @classmethod
def is_available(cls): def is_available(cls):
return using_ray and ray.is_initialized() if ray:
return ray.is_initialized()
return False
def __init__( def __init__(
self, self,
@@ -737,9 +736,4 @@ def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simula
if __name__ == "__main__": if __name__ == "__main__":
try:
ray.init()
except NameError:
pass pass
config_file, *opts = sys.argv[1:]
new_simulation(config_file, *opts)