made ray optional
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
numpy
|
||||
matplotlib
|
||||
scipy
|
||||
ray
|
||||
toml
|
||||
tqdm
|
||||
@@ -23,8 +23,6 @@ install_requires =
|
||||
numba
|
||||
matplotlib
|
||||
scipy
|
||||
ray
|
||||
send2trash
|
||||
toml
|
||||
tqdm
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import ray
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
ray = None
|
||||
|
||||
from scgenerator.physics.simulate import (
|
||||
run_simulation_sequence,
|
||||
@@ -84,6 +87,7 @@ def merge(args):
|
||||
|
||||
|
||||
def prep_ray(args):
|
||||
if ray:
|
||||
if args.start_ray:
|
||||
init_str = ray.init()
|
||||
elif not args.no_ray:
|
||||
|
||||
@@ -35,8 +35,8 @@ def _set_debug():
|
||||
|
||||
def get_logger(name=None):
|
||||
"""returns a logging.Logger instance. This function is there because if scgenerator
|
||||
is used with ray, workers are not aware of any configuration done with the logging
|
||||
and so it must be reconfigured.
|
||||
is used with some multiprocessing library, workers are not aware of any configuration done
|
||||
with the logging and so it must be reconfigured.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -1,25 +1,20 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from .. import initialize, io, utils, const, env
|
||||
from .. import env, initialize, io, utils
|
||||
from ..errors import IncompleteDataFolderError
|
||||
from ..logger import get_logger
|
||||
from . import pulse
|
||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||
|
||||
using_ray = False
|
||||
try:
|
||||
import ray
|
||||
|
||||
using_ray = True
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
ray = None
|
||||
|
||||
|
||||
class RK4IP:
|
||||
@@ -419,7 +414,7 @@ class Simulations:
|
||||
if isinstance(method, str):
|
||||
method = Simulations.simulation_methods_dict[method]
|
||||
return method(param_seq, task_id)
|
||||
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"] and using_ray:
|
||||
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
|
||||
return Simulations.get_best_method()(param_seq, task_id)
|
||||
else:
|
||||
return SequencialSimulations(param_seq, task_id)
|
||||
@@ -435,6 +430,8 @@ class Simulations:
|
||||
data_folder : str, optional
|
||||
path to the folder where data is saved, by default "scgenerator/"
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError(f"{self.__class__} is currently not available")
|
||||
self.logger = io.get_logger(__name__)
|
||||
self.id = int(task_id)
|
||||
|
||||
@@ -600,7 +597,9 @@ class RaySimulations(Simulations, priority=2):
|
||||
|
||||
@classmethod
|
||||
def is_available(cls):
|
||||
return using_ray and ray.is_initialized()
|
||||
if ray:
|
||||
return ray.is_initialized()
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -737,9 +736,4 @@ def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simula
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
ray.init()
|
||||
except NameError:
|
||||
pass
|
||||
config_file, *opts = sys.argv[1:]
|
||||
new_simulation(config_file, *opts)
|
||||
|
||||
Reference in New Issue
Block a user