switch to actor pool
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Type
|
from typing import Dict, List, Tuple, Type
|
||||||
|
from typing_extensions import runtime
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -25,53 +26,29 @@ class RK4IP:
|
|||||||
save_data=False,
|
save_data=False,
|
||||||
job_identifier="",
|
job_identifier="",
|
||||||
task_id=0,
|
task_id=0,
|
||||||
n_percent=10,
|
|
||||||
):
|
):
|
||||||
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
sim_params : dict
|
params : Params
|
||||||
a flattened parameter dictionary containing :
|
parameters of the simulation
|
||||||
w_c : numpy.ndarray
|
|
||||||
angular frequencies centered around 0 generated with scgenerator.initialize.wspace
|
|
||||||
w0 : float
|
|
||||||
central angular frequency of the pulse
|
|
||||||
w_power_fact : numpy.ndarray
|
|
||||||
precomputed factorial/peak_power operations on w_c (scgenerator.math.power_fact)
|
|
||||||
spec_0 : numpy.ndarray
|
|
||||||
initial spectral envelope as function of w_c
|
|
||||||
z_targets : list
|
|
||||||
target distances
|
|
||||||
length : float
|
|
||||||
length of the fiber
|
|
||||||
beta : numpy.ndarray or Callable[[float], numpy.ndarray]
|
|
||||||
beta coefficients (Taylor expansion of beta(w))
|
|
||||||
gamma : float or Callable[[float], float]
|
|
||||||
non-linear parameter
|
|
||||||
t : numpy.ndarray
|
|
||||||
time
|
|
||||||
dt : float
|
|
||||||
time resolution
|
|
||||||
behaviors : list(str {'ss', 'raman', 'spm'})
|
|
||||||
behaviors to include in the simulation given as a list of strings
|
|
||||||
raman_type : str, optional
|
|
||||||
type of Raman modeling if the Raman effect is present
|
|
||||||
f_r, hr_w : (opt) arguments of delayed_raman_t (see there for details)
|
|
||||||
adapt_step_size : bool, optional
|
|
||||||
if True (default), adapts the step size with the conserved quantity method
|
|
||||||
error_ok : float
|
|
||||||
tolerated relative error for the adaptive step size if adaptive
|
|
||||||
step size is turned on, otherwise length of fixed steps in m
|
|
||||||
save_data : bool, optional
|
save_data : bool, optional
|
||||||
save calculated spectra to disk, by default False
|
save calculated spectra to disk, by default False
|
||||||
job_identifier : str, optional
|
job_identifier : str, optional
|
||||||
string identifying the parameter set, by default ""
|
string identifying the parameter set, by default ""
|
||||||
task_id : int, optional
|
task_id : int, optional
|
||||||
unique identifier of the session, by default 0
|
unique identifier of the session, by default 0
|
||||||
n_percent : int, optional
|
|
||||||
print/log progress update every n_percent, by default 10
|
|
||||||
"""
|
"""
|
||||||
|
self.set(params, save_data, job_identifier, task_id)
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
params: initialize.Params,
|
||||||
|
save_data=False,
|
||||||
|
job_identifier="",
|
||||||
|
task_id=0,
|
||||||
|
):
|
||||||
|
|
||||||
self.job_identifier = job_identifier
|
self.job_identifier = job_identifier
|
||||||
self.id = task_id
|
self.id = task_id
|
||||||
@@ -80,7 +57,6 @@ class RK4IP:
|
|||||||
self.sim_dir.mkdir(exist_ok=True)
|
self.sim_dir.mkdir(exist_ok=True)
|
||||||
self.data_dir = self.sim_dir / self.job_identifier
|
self.data_dir = self.sim_dir / self.job_identifier
|
||||||
|
|
||||||
self.n_percent = n_percent
|
|
||||||
self.logger = get_logger(self.job_identifier)
|
self.logger = get_logger(self.job_identifier)
|
||||||
self.resuming = False
|
self.resuming = False
|
||||||
self.save_data = save_data
|
self.save_data = save_data
|
||||||
@@ -355,7 +331,10 @@ class MutliProcRK4IP(RK4IP):
|
|||||||
|
|
||||||
|
|
||||||
class RayRK4IP(RK4IP):
|
class RayRK4IP(RK4IP):
|
||||||
def __init__(
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set(
|
||||||
self,
|
self,
|
||||||
params: initialize.Params,
|
params: initialize.Params,
|
||||||
p_actor,
|
p_actor,
|
||||||
@@ -363,18 +342,21 @@ class RayRK4IP(RK4IP):
|
|||||||
save_data=False,
|
save_data=False,
|
||||||
job_identifier="",
|
job_identifier="",
|
||||||
task_id=0,
|
task_id=0,
|
||||||
n_percent=10,
|
|
||||||
):
|
):
|
||||||
self.worker_id = worker_id
|
self.worker_id = worker_id
|
||||||
self.p_actor = p_actor
|
self.p_actor = p_actor
|
||||||
super().__init__(
|
super().set(
|
||||||
params,
|
params,
|
||||||
save_data=save_data,
|
save_data=save_data,
|
||||||
job_identifier=job_identifier,
|
job_identifier=job_identifier,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
n_percent=n_percent,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_and_run(self, v):
|
||||||
|
params, p_actor, worker_id, save_data, job_identifier, task_id = v
|
||||||
|
self.set(params, p_actor, worker_id, save_data, job_identifier, task_id)
|
||||||
|
self.run()
|
||||||
|
|
||||||
def step_saved(self):
|
def step_saved(self):
|
||||||
self.p_actor.update.remote(self.worker_id, self.z / self.z_final)
|
self.p_actor.update.remote(self.worker_id, self.z / self.z_final)
|
||||||
self.p_actor.update.remote(0)
|
self.p_actor.update.remote(0)
|
||||||
@@ -622,57 +604,53 @@ class RaySimulations(Simulations, priority=2):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.propagator = ray.remote(RayRK4IP).options(
|
self.propagator = ray.remote(RayRK4IP).options(runtime_env=dict(env_vars=env.all_environ()))
|
||||||
override_environment_variables=env.all_environ()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.update_cluster_frequency = 3
|
self.update_cluster_frequency = 3
|
||||||
self.jobs = []
|
self.jobs = []
|
||||||
self.actors = {}
|
self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total))
|
||||||
|
self.num_submitted = 0
|
||||||
self.rolling_id = 0
|
self.rolling_id = 0
|
||||||
self.p_actor = (
|
self.p_actor = (
|
||||||
ray.remote(utils.ProgressBarActor)
|
ray.remote(utils.ProgressBarActor)
|
||||||
.options(override_environment_variables=env.all_environ())
|
.options(runtime_env=dict(env_vars=env.all_environ()))
|
||||||
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
||||||
)
|
)
|
||||||
|
|
||||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||||
while len(self.jobs) >= self.sim_jobs_total:
|
while self.num_submitted >= self.sim_jobs_total:
|
||||||
self._collect_1_job()
|
self.collect_1_job()
|
||||||
|
|
||||||
self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total
|
self.rolling_id = (self.rolling_id + 1) % self.sim_jobs_total
|
||||||
|
self.pool.submit(
|
||||||
new_actor = self.propagator.remote(
|
lambda a, v: a.set_and_run.remote(v),
|
||||||
params,
|
(
|
||||||
self.p_actor,
|
params,
|
||||||
self.rolling_id + 1,
|
self.p_actor,
|
||||||
save_data=True,
|
self.rolling_id + 1,
|
||||||
job_identifier=v_list_str,
|
True,
|
||||||
task_id=self.id,
|
v_list_str,
|
||||||
|
self.id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
new_job = new_actor.run.remote()
|
self.num_submitted += 1
|
||||||
|
|
||||||
self.actors[new_job.task_id()] = new_actor
|
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
||||||
self.jobs.append(new_job)
|
|
||||||
|
|
||||||
self.logger.info(
|
def collect_1_job(self):
|
||||||
f"{self.param_seq.name} : launching simulation with {v_list_str}, job : {self.jobs[-1].hex()}"
|
ray.get(self.p_actor.update_pbars.remote())
|
||||||
)
|
try:
|
||||||
|
self.pool.get_next_unordered(self.update_cluster_frequency)
|
||||||
|
ray.get(self.p_actor.update_pbars.remote())
|
||||||
|
self.num_submitted -= 1
|
||||||
|
except TimeoutError:
|
||||||
|
return
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
while len(self.jobs) > 0:
|
while self.num_submitted > 0:
|
||||||
self._collect_1_job()
|
self.collect_1_job()
|
||||||
ray.get(self.p_actor.close.remote())
|
ray.get(self.p_actor.close.remote())
|
||||||
|
|
||||||
def _collect_1_job(self):
|
|
||||||
ready, self.jobs = ray.wait(self.jobs, timeout=self.update_cluster_frequency)
|
|
||||||
ray.get(self.p_actor.update_pbars.remote())
|
|
||||||
if len(ready) == 0:
|
|
||||||
return
|
|
||||||
ray.get(ready)
|
|
||||||
|
|
||||||
del self.actors[ready[0].task_id()]
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user