iterable RK4IP solver, also in parallel

This commit is contained in:
Benoît Sierro
2021-08-27 08:37:58 +02:00
parent f8d2f53083
commit 5236beedd6
2 changed files with 107 additions and 22 deletions

View File

@@ -361,6 +361,11 @@ class ParamSequence:
def count_variations(self) -> int:
return count_variations(self.config)
@property
def first(self) -> Params:
for _, params in self:
return params
class ContinuationParamSequence(ParamSequence):
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig):

View File

@@ -1,9 +1,10 @@
import multiprocessing
import multiprocessing.connection
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Type, Union
from typing import Dict, Generator, List, Tuple, Type, Any
import numpy as np
@@ -53,14 +54,18 @@ class RK4IP:
self.job_identifier = job_identifier
self.id = task_id
self.save_data = save_data
self.sim_dir = io.get_sim_dir(self.id)
self.sim_dir.mkdir(exist_ok=True)
self.data_dir = self.sim_dir / self.job_identifier
if self.save_data:
self.sim_dir = io.get_sim_dir(self.id)
self.sim_dir.mkdir(exist_ok=True)
self.data_dir = self.sim_dir / self.job_identifier
else:
self.sim_dir = None
self.data_dir = None
self.logger = get_logger(self.job_identifier)
self.resuming = False
self.save_data = save_data
self.w_c = params.w_c
self.w = params.w
@@ -144,9 +149,6 @@ class RK4IP:
]
self.size_fac = 2 ** (1 / 5)
if self.save_data:
self._save_current_spectrum(0)
# Initial step size
if self.adapt_step_size:
self.initial_h = (self.z_targets[0] - self.z) / 2
@@ -165,6 +167,16 @@ class RK4IP:
self._save_data(self.cons_qty, f"cons_qty")
self.step_saved()
def get_current_spectrum(self) -> tuple[int, np.ndarray]:
"""returns the current spectrum
Returns
-------
np.ndarray
spectrum
"""
return self.C_to_A_factor * self.current_spectrum
def _save_data(self, data: np.ndarray, name: str):
"""calls the appropriate method to save data
@@ -178,6 +190,24 @@ class RK4IP:
io.save_data(data, self.data_dir, name)
def run(self):
time_start = datetime.today()
for step, num, _ in self.irun():
if self.save_data:
self._save_current_spectrum(num)
self.logger.info(
"propagation finished in {} steps ({} seconds)".format(
step, (datetime.today() - time_start).total_seconds()
)
)
if self.save_data:
self._save_data(self.z_stored, "z.npy")
return self.stored_spectra
def irun(self):
# Print introduction
self.logger.debug(
@@ -189,7 +219,8 @@ class RK4IP:
h_taken = self.initial_h
h_next_step = self.initial_h
store = False # store a spectrum
time_start = datetime.today()
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
while self.z < self.z_final:
h_taken, h_next_step, self.current_spectrum = self.take_step(
@@ -205,8 +236,8 @@ class RK4IP:
self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken))
self.stored_spectra.append(self.current_spectrum)
if self.save_data:
self._save_current_spectrum(len(self.stored_spectra) - 1)
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
self.z_stored.append(self.z)
del self.z_targets[0]
@@ -225,17 +256,6 @@ class RK4IP:
store = True
h_next_step = self.z_targets[0] - self.z
self.logger.info(
"propagation finished in {} steps ({} seconds)".format(
step, (datetime.today() - time_start).total_seconds()
)
)
if self.save_data:
self._save_data(self.z_stored, "z.npy")
return self.stored_spectra
def take_step(
self, step: int, h_next_step: float, current_spectrum: np.ndarray
) -> Tuple[float, float, np.ndarray]:
@@ -731,5 +751,65 @@ def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simul
return Simulations.new(param_seq, task_id, method)
def __parallel_RK4IP_worker(
worker_id: int,
msq_queue: multiprocessing.connection.Connection,
data_queue: multiprocessing.Queue,
params: utils.BareParams,
):
logger = get_logger(__name__)
logger.debug(f"workder {worker_id} started")
for out in RK4IP(params).irun():
logger.debug(f"worker {worker_id} waiting for msg")
msq_queue.recv()
logger.debug(f"worker {worker_id} got msg")
data_queue.put((worker_id, out))
logger.debug(f"worker {worker_id} sent data")
def parallel_RK4IP(
config,
) -> Generator[
tuple[tuple[list[tuple[str, Any]], initialize.Params, int, int, np.ndarray], ...], None, None
]:
logger = get_logger(__name__)
params = list(initialize.ParamSequence(config))
n = len(params)
z_num = params[0][1].z_num
cpu_no = multiprocessing.cpu_count()
if len(params) < cpu_no:
cpu_no = len(params)
pipes = [multiprocessing.Pipe(duplex=False) for i in range(n)]
data_queue = multiprocessing.Queue()
workers = [
multiprocessing.Process(target=__parallel_RK4IP_worker, args=(i, pipe[0], data_queue, p[1]))
for i, (pipe, p) in enumerate(zip(pipes, params))
]
try:
[w.start() for w in workers]
logger.debug("pool started")
for i in range(z_num):
for q in pipes:
q[1].send(0)
logger.debug("msg sent")
computed_dict = {}
for j in range(n):
w_id, computed = data_queue.get()
computed_dict[w_id] = computed
computed_dict = list(computed_dict.items())
computed_dict.sort()
yield tuple((*p, *c) for p, c in zip(params, [el[1] for el in computed_dict]))
print("finished")
finally:
for w, cs in zip(workers, pipes):
w.join()
w.close()
cs[0].close()
cs[1].close()
data_queue.close()
if __name__ == "__main__":
pass