made ray optional
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
scipy
|
scipy
|
||||||
ray
|
|
||||||
toml
|
toml
|
||||||
tqdm
|
tqdm
|
||||||
@@ -23,8 +23,6 @@ install_requires =
|
|||||||
numba
|
numba
|
||||||
matplotlib
|
matplotlib
|
||||||
scipy
|
scipy
|
||||||
ray
|
|
||||||
send2trash
|
|
||||||
toml
|
toml
|
||||||
tqdm
|
tqdm
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
try:
|
||||||
import ray
|
import ray
|
||||||
|
except ImportError:
|
||||||
|
ray = None
|
||||||
|
|
||||||
from scgenerator.physics.simulate import (
|
from scgenerator.physics.simulate import (
|
||||||
run_simulation_sequence,
|
run_simulation_sequence,
|
||||||
@@ -84,6 +87,7 @@ def merge(args):
|
|||||||
|
|
||||||
|
|
||||||
def prep_ray(args):
|
def prep_ray(args):
|
||||||
|
if ray:
|
||||||
if args.start_ray:
|
if args.start_ray:
|
||||||
init_str = ray.init()
|
init_str = ray.init()
|
||||||
elif not args.no_ray:
|
elif not args.no_ray:
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ def _set_debug():
|
|||||||
|
|
||||||
def get_logger(name=None):
|
def get_logger(name=None):
|
||||||
"""returns a logging.Logger instance. This function is there because if scgenerator
|
"""returns a logging.Logger instance. This function is there because if scgenerator
|
||||||
is used with ray, workers are not aware of any configuration done with the logging
|
is used with some multiprocessing library, workers are not aware of any configuration done
|
||||||
and so it must be reconfigured.
|
with the logging and so it must be reconfigured.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|||||||
@@ -1,25 +1,20 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Tuple, Type
|
from typing import Any, Dict, List, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .. import initialize, io, utils, const, env
|
from .. import env, initialize, io, utils
|
||||||
from ..errors import IncompleteDataFolderError
|
from ..errors import IncompleteDataFolderError
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from . import pulse
|
from . import pulse
|
||||||
from .fiber import create_non_linear_op, fast_dispersion_op
|
from .fiber import create_non_linear_op, fast_dispersion_op
|
||||||
|
|
||||||
using_ray = False
|
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray
|
||||||
|
|
||||||
using_ray = True
|
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
ray = None
|
||||||
|
|
||||||
|
|
||||||
class RK4IP:
|
class RK4IP:
|
||||||
@@ -419,7 +414,7 @@ class Simulations:
|
|||||||
if isinstance(method, str):
|
if isinstance(method, str):
|
||||||
method = Simulations.simulation_methods_dict[method]
|
method = Simulations.simulation_methods_dict[method]
|
||||||
return method(param_seq, task_id)
|
return method(param_seq, task_id)
|
||||||
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"] and using_ray:
|
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
|
||||||
return Simulations.get_best_method()(param_seq, task_id)
|
return Simulations.get_best_method()(param_seq, task_id)
|
||||||
else:
|
else:
|
||||||
return SequencialSimulations(param_seq, task_id)
|
return SequencialSimulations(param_seq, task_id)
|
||||||
@@ -435,6 +430,8 @@ class Simulations:
|
|||||||
data_folder : str, optional
|
data_folder : str, optional
|
||||||
path to the folder where data is saved, by default "scgenerator/"
|
path to the folder where data is saved, by default "scgenerator/"
|
||||||
"""
|
"""
|
||||||
|
if not self.is_available():
|
||||||
|
raise RuntimeError(f"{self.__class__} is currently not available")
|
||||||
self.logger = io.get_logger(__name__)
|
self.logger = io.get_logger(__name__)
|
||||||
self.id = int(task_id)
|
self.id = int(task_id)
|
||||||
|
|
||||||
@@ -600,7 +597,9 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_available(cls):
|
def is_available(cls):
|
||||||
return using_ray and ray.is_initialized()
|
if ray:
|
||||||
|
return ray.is_initialized()
|
||||||
|
return False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -737,9 +736,4 @@ def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simula
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
|
||||||
ray.init()
|
|
||||||
except NameError:
|
|
||||||
pass
|
pass
|
||||||
config_file, *opts = sys.argv[1:]
|
|
||||||
new_simulation(config_file, *opts)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user