removed special Ray RK4IP

This commit is contained in:
Benoît Sierro
2021-05-07 10:50:32 +02:00
parent 9a68ebbac6
commit c737dea213
3 changed files with 30 additions and 26 deletions

View File

@@ -1,9 +1,10 @@
import argparse import argparse
import os import os
import random import random
import sys
import ray import ray
from scgenerator.physics.simulate import new_simulations, resume_simulations from scgenerator.physics.simulate import new_simulations, resume_simulations, SequencialSimulations
def create_parser(): def create_parser():
@@ -23,6 +24,12 @@ def create_parser():
help="assume no ray instance has been started beforehand", help="assume no ray instance has been started beforehand",
) )
parser.add_argument(
"--no-ray",
action="store_true",
help="force not to use ray",
)
run_parser = subparsers.add_parser("run", help="run a simulation from a config file") run_parser = subparsers.add_parser("run", help="run a simulation from a config file")
run_parser.add_argument("config", help="path to the toml configuration file") run_parser.add_argument("config", help="path to the toml configuration file")
@@ -48,15 +55,18 @@ def run_sim(args):
if args.start_ray: if args.start_ray:
init_str = ray.init() init_str = ray.init()
else: elif not args.no_ray:
init_str = ray.init( init_str = ray.init(
address="auto", address="auto",
_node_ip_address=os.environ.get("ip_head", "127.0.0.1").split(":")[0], _node_ip_address=os.environ.get("ip_head", "127.0.0.1").split(":")[0],
_redis_password=os.environ.get("redis_password", "caco1234"), _redis_password=os.environ.get("redis_password", "caco1234"),
) )
print(init_str) print(init_str)
sim = new_simulations(args.config, args.id) if args.no_ray:
sim = new_simulations(args.config, args.id, Method=SequencialSimulations)
else:
sim = new_simulations(args.config, args.id)
sim.run() sim.run()
@@ -71,7 +81,7 @@ def resume_sim(args):
_redis_password=os.environ.get("redis_password", "caco1234"), _redis_password=os.environ.get("redis_password", "caco1234"),
) )
print(init_str) print(init_str)
sim = resume_simulations(args.data_dir, args.id) sim = resume_simulations(args.data_dir, args.id)
sim.run() sim.run()

View File

@@ -381,7 +381,8 @@ def merge_same_simulations(path: str):
base_folders.add(base_folder) base_folders.add(base_folder)
sim_num, param_num = utils.count_variations(config) sim_num, param_num = utils.count_variations(config)
pt = utils.ProgressTracker(sim_num, logger=logger, prefix="merging data : ") pt = utils.ProgressTracker(sim_num * z_num, logger=logger, prefix="merging data : ")
print(f"{pt.max=}")
spectra = [] spectra = []
for z_id in range(z_num): for z_id in range(z_num):

View File

@@ -1,13 +1,14 @@
import os import os
import sys
from datetime import datetime from datetime import datetime
from typing import List, Tuple, Type from typing import List, Tuple, Type
import numpy as np import numpy as np
from .. import initialize, io, utils from .. import initialize, io, utils
from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
from . import pulse from . import pulse
from ..errors import IncompleteDataFolderError
from .fiber import create_non_linear_op, fast_dispersion_op from .fiber import create_non_linear_op, fast_dispersion_op
using_ray = False using_ray = False
@@ -301,23 +302,6 @@ class RK4IP:
return h, h_next_step, new_spectrum return h, h_next_step, new_spectrum
class RayRK4IP(RK4IP):
def __init__(
self, sim_params, data_queue, save_data=False, job_identifier="", task_id=0, n_percent=10
):
self.queue = data_queue
super().__init__(
sim_params,
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
n_percent=n_percent,
)
def _save_data(self, data: np.ndarray, name: str):
self.queue.put((name, self.job_identifier, data))
class Simulations: class Simulations:
"""The recommended way to run simulations. """The recommended way to run simulations.
New Simulations child classes can be written and must implement the following New Simulations child classes can be written and must implement the following
@@ -473,7 +457,7 @@ class RaySimulations(Simulations, available=using_ray, priority=1):
) )
) )
self.propagator = ray.remote(RayRK4IP).options( self.propagator = ray.remote(RK4IP).options(
override_environment_variables=io.get_all_environ() override_environment_variables=io.get_all_environ()
) )
self.sim_jobs_per_node = min( self.sim_jobs_per_node = min(
@@ -490,7 +474,7 @@ class RaySimulations(Simulations, available=using_ray, priority=1):
v_list_str = utils.format_variable_list(variable_list) v_list_str = utils.format_variable_list(variable_list)
new_actor = self.propagator.remote( new_actor = self.propagator.remote(
params, self.buffer.queue, save_data=True, job_identifier=v_list_str, task_id=self.id params, save_data=True, job_identifier=v_list_str, task_id=self.id
) )
new_job = new_actor.run.remote() new_job = new_actor.run.remote()
@@ -560,3 +544,12 @@ def _new_simulations(
return Simulations.get_best_method()(param_seq, task_id, data_folder=data_folder) return Simulations.get_best_method()(param_seq, task_id, data_folder=data_folder)
else: else:
return SequencialSimulations(param_seq, task_id, data_folder=data_folder) return SequencialSimulations(param_seq, task_id, data_folder=data_folder)
if __name__ == "__main__":
try:
ray.init()
except NameError:
pass
config_file, *opts = sys.argv[1:]
new_simulations(config_file, *opts)