switch to actor pool
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Type
|
||||
from typing_extensions import runtime
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,53 +26,29 @@ class RK4IP:
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
n_percent=10,
|
||||
):
|
||||
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sim_params : dict
|
||||
a flattened parameter dictionary containing :
|
||||
w_c : numpy.ndarray
|
||||
angular frequencies centered around 0 generated with scgenerator.initialize.wspace
|
||||
w0 : float
|
||||
central angular frequency of the pulse
|
||||
w_power_fact : numpy.ndarray
|
||||
precomputed factorial/peak_power operations on w_c (scgenerator.math.power_fact)
|
||||
spec_0 : numpy.ndarray
|
||||
initial spectral envelope as function of w_c
|
||||
z_targets : list
|
||||
target distances
|
||||
length : float
|
||||
length of the fiber
|
||||
beta : numpy.ndarray or Callable[[float], numpy.ndarray]
|
||||
beta coeficients (Taylor expansion of beta(w))
|
||||
gamma : float or Callable[[float], float]
|
||||
non-linear parameter
|
||||
t : numpy.ndarray
|
||||
time
|
||||
dt : float
|
||||
time resolution
|
||||
behaviors : list(str {'ss', 'raman', 'spm'})
|
||||
behaviors to include in the simulation given as a list of strings
|
||||
raman_type : str, optional
|
||||
type of raman modelisation if raman effect is present
|
||||
f_r, hr_w : (opt) arguments of delayed_raman_t (see there for infos)
|
||||
adapt_step_size : bool, optional
|
||||
if True (default), adapts the step size with conserved quantity methode
|
||||
error_ok : float
|
||||
tolerated relative error for the adaptive step size if adaptive
|
||||
step size is turned on, otherwise length of fixed steps in m
|
||||
params : Params
|
||||
parameters of the simulation
|
||||
save_data : bool, optional
|
||||
save calculated spectra to disk, by default False
|
||||
job_identifier : str, optional
|
||||
string identifying the parameter set, by default ""
|
||||
task_id : int, optional
|
||||
unique identifier of the session, by default 0
|
||||
n_percent : int, optional
|
||||
print/log progress update every n_percent, by default 10
|
||||
"""
|
||||
self.set(params, save_data, job_identifier, task_id)
|
||||
|
||||
def set(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
):
|
||||
|
||||
self.job_identifier = job_identifier
|
||||
self.id = task_id
|
||||
@@ -80,7 +57,6 @@ class RK4IP:
|
||||
self.sim_dir.mkdir(exist_ok=True)
|
||||
self.data_dir = self.sim_dir / self.job_identifier
|
||||
|
||||
self.n_percent = n_percent
|
||||
self.logger = get_logger(self.job_identifier)
|
||||
self.resuming = False
|
||||
self.save_data = save_data
|
||||
@@ -355,7 +331,10 @@ class MutliProcRK4IP(RK4IP):
|
||||
|
||||
|
||||
class RayRK4IP(RK4IP):
|
||||
def __init__(
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def set(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
p_actor,
|
||||
@@ -363,18 +342,21 @@ class RayRK4IP(RK4IP):
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
n_percent=10,
|
||||
):
|
||||
self.worker_id = worker_id
|
||||
self.p_actor = p_actor
|
||||
super().__init__(
|
||||
super().set(
|
||||
params,
|
||||
save_data=save_data,
|
||||
job_identifier=job_identifier,
|
||||
task_id=task_id,
|
||||
n_percent=n_percent,
|
||||
)
|
||||
|
||||
def set_and_run(self, v):
|
||||
params, p_actor, worker_id, save_data, job_identifier, task_id = v
|
||||
self.set(params, p_actor, worker_id, save_data, job_identifier, task_id)
|
||||
self.run()
|
||||
|
||||
def step_saved(self):
|
||||
self.p_actor.update.remote(self.worker_id, self.z / self.z_final)
|
||||
self.p_actor.update.remote(0)
|
||||
@@ -622,57 +604,53 @@ class RaySimulations(Simulations, priority=2):
|
||||
)
|
||||
)
|
||||
|
||||
self.propagator = ray.remote(RayRK4IP).options(
|
||||
override_environment_variables=env.all_environ()
|
||||
)
|
||||
self.propagator = ray.remote(RayRK4IP).options(runtime_env=dict(env_vars=env.all_environ()))
|
||||
|
||||
self.update_cluster_frequency = 3
|
||||
self.jobs = []
|
||||
self.actors = {}
|
||||
self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total))
|
||||
self.num_submitted = 0
|
||||
self.rolling_id = 0
|
||||
self.p_actor = (
|
||||
ray.remote(utils.ProgressBarActor)
|
||||
.options(override_environment_variables=env.all_environ())
|
||||
.options(runtime_env=dict(env_vars=env.all_environ()))
|
||||
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
||||
)
|
||||
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
while len(self.jobs) >= self.sim_jobs_total:
|
||||
self._collect_1_job()
|
||||
while self.num_submitted >= self.sim_jobs_total:
|
||||
self.collect_1_job()
|
||||
|
||||
self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total
|
||||
|
||||
new_actor = self.propagator.remote(
|
||||
self.pool.submit(
|
||||
lambda a, v: a.set_and_run.remote(v),
|
||||
(
|
||||
params,
|
||||
self.p_actor,
|
||||
self.rolling_id + 1,
|
||||
save_data=True,
|
||||
job_identifier=v_list_str,
|
||||
task_id=self.id,
|
||||
True,
|
||||
v_list_str,
|
||||
self.id,
|
||||
),
|
||||
)
|
||||
new_job = new_actor.run.remote()
|
||||
self.num_submitted += 1
|
||||
|
||||
self.actors[new_job.task_id()] = new_actor
|
||||
self.jobs.append(new_job)
|
||||
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
||||
|
||||
self.logger.info(
|
||||
f"{self.param_seq.name} : launching simulation with {v_list_str}, job : {self.jobs[-1].hex()}"
|
||||
)
|
||||
def collect_1_job(self):
|
||||
ray.get(self.p_actor.update_pbars.remote())
|
||||
try:
|
||||
self.pool.get_next_unordered(self.update_cluster_frequency)
|
||||
ray.get(self.p_actor.update_pbars.remote())
|
||||
self.num_submitted -= 1
|
||||
except TimeoutError:
|
||||
return
|
||||
|
||||
def finish(self):
|
||||
while len(self.jobs) > 0:
|
||||
self._collect_1_job()
|
||||
while self.num_submitted > 0:
|
||||
self.collect_1_job()
|
||||
ray.get(self.p_actor.close.remote())
|
||||
|
||||
def _collect_1_job(self):
|
||||
ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency)
|
||||
ray.get(self.p_actor.update_pbars.remote())
|
||||
if len(ready) == 0:
|
||||
return
|
||||
ray.get(ready)
|
||||
|
||||
del self.actors[ready[0].task_id()]
|
||||
|
||||
def stop(self):
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user