iterable RK4IP solver, also in parallel
This commit is contained in:
@@ -361,6 +361,11 @@ class ParamSequence:
|
|||||||
def count_variations(self) -> int:
|
def count_variations(self) -> int:
|
||||||
return count_variations(self.config)
|
return count_variations(self.config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first(self) -> Params:
|
||||||
|
for _, params in self:
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
class ContinuationParamSequence(ParamSequence):
|
class ContinuationParamSequence(ParamSequence):
|
||||||
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig):
|
def __init__(self, prev_sim_dir: os.PathLike, new_config: BareConfig):
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import multiprocessing.connection
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Type, Union
|
from typing import Dict, Generator, List, Tuple, Type, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -53,14 +54,18 @@ class RK4IP:
|
|||||||
|
|
||||||
self.job_identifier = job_identifier
|
self.job_identifier = job_identifier
|
||||||
self.id = task_id
|
self.id = task_id
|
||||||
|
self.save_data = save_data
|
||||||
|
|
||||||
|
if self.save_data:
|
||||||
self.sim_dir = io.get_sim_dir(self.id)
|
self.sim_dir = io.get_sim_dir(self.id)
|
||||||
self.sim_dir.mkdir(exist_ok=True)
|
self.sim_dir.mkdir(exist_ok=True)
|
||||||
self.data_dir = self.sim_dir / self.job_identifier
|
self.data_dir = self.sim_dir / self.job_identifier
|
||||||
|
else:
|
||||||
|
self.sim_dir = None
|
||||||
|
self.data_dir = None
|
||||||
|
|
||||||
self.logger = get_logger(self.job_identifier)
|
self.logger = get_logger(self.job_identifier)
|
||||||
self.resuming = False
|
self.resuming = False
|
||||||
self.save_data = save_data
|
|
||||||
|
|
||||||
self.w_c = params.w_c
|
self.w_c = params.w_c
|
||||||
self.w = params.w
|
self.w = params.w
|
||||||
@@ -144,9 +149,6 @@ class RK4IP:
|
|||||||
]
|
]
|
||||||
self.size_fac = 2 ** (1 / 5)
|
self.size_fac = 2 ** (1 / 5)
|
||||||
|
|
||||||
if self.save_data:
|
|
||||||
self._save_current_spectrum(0)
|
|
||||||
|
|
||||||
# Initial step size
|
# Initial step size
|
||||||
if self.adapt_step_size:
|
if self.adapt_step_size:
|
||||||
self.initial_h = (self.z_targets[0] - self.z) / 2
|
self.initial_h = (self.z_targets[0] - self.z) / 2
|
||||||
@@ -165,6 +167,16 @@ class RK4IP:
|
|||||||
self._save_data(self.cons_qty, f"cons_qty")
|
self._save_data(self.cons_qty, f"cons_qty")
|
||||||
self.step_saved()
|
self.step_saved()
|
||||||
|
|
||||||
|
def get_current_spectrum(self) -> tuple[int, np.ndarray]:
|
||||||
|
"""returns the current spectrum
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
spectrum
|
||||||
|
"""
|
||||||
|
return self.C_to_A_factor * self.current_spectrum
|
||||||
|
|
||||||
def _save_data(self, data: np.ndarray, name: str):
|
def _save_data(self, data: np.ndarray, name: str):
|
||||||
"""calls the appropriate method to save data
|
"""calls the appropriate method to save data
|
||||||
|
|
||||||
@@ -178,6 +190,24 @@ class RK4IP:
|
|||||||
io.save_data(data, self.data_dir, name)
|
io.save_data(data, self.data_dir, name)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
time_start = datetime.today()
|
||||||
|
|
||||||
|
for step, num, _ in self.irun():
|
||||||
|
if self.save_data:
|
||||||
|
self._save_current_spectrum(num)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"propagation finished in {} steps ({} seconds)".format(
|
||||||
|
step, (datetime.today() - time_start).total_seconds()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.save_data:
|
||||||
|
self._save_data(self.z_stored, "z.npy")
|
||||||
|
|
||||||
|
return self.stored_spectra
|
||||||
|
|
||||||
|
def irun(self):
|
||||||
|
|
||||||
# Print introduction
|
# Print introduction
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
@@ -189,7 +219,8 @@ class RK4IP:
|
|||||||
h_taken = self.initial_h
|
h_taken = self.initial_h
|
||||||
h_next_step = self.initial_h
|
h_next_step = self.initial_h
|
||||||
store = False # store a spectrum
|
store = False # store a spectrum
|
||||||
time_start = datetime.today()
|
|
||||||
|
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
|
||||||
|
|
||||||
while self.z < self.z_final:
|
while self.z < self.z_final:
|
||||||
h_taken, h_next_step, self.current_spectrum = self.take_step(
|
h_taken, h_next_step, self.current_spectrum = self.take_step(
|
||||||
@@ -205,8 +236,8 @@ class RK4IP:
|
|||||||
self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken))
|
self.logger.debug("{} steps, z = {:.4f}, h = {:.5g}".format(step, self.z, h_taken))
|
||||||
|
|
||||||
self.stored_spectra.append(self.current_spectrum)
|
self.stored_spectra.append(self.current_spectrum)
|
||||||
if self.save_data:
|
|
||||||
self._save_current_spectrum(len(self.stored_spectra) - 1)
|
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
|
||||||
|
|
||||||
self.z_stored.append(self.z)
|
self.z_stored.append(self.z)
|
||||||
del self.z_targets[0]
|
del self.z_targets[0]
|
||||||
@@ -225,17 +256,6 @@ class RK4IP:
|
|||||||
store = True
|
store = True
|
||||||
h_next_step = self.z_targets[0] - self.z
|
h_next_step = self.z_targets[0] - self.z
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
"propagation finished in {} steps ({} seconds)".format(
|
|
||||||
step, (datetime.today() - time_start).total_seconds()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.save_data:
|
|
||||||
self._save_data(self.z_stored, "z.npy")
|
|
||||||
|
|
||||||
return self.stored_spectra
|
|
||||||
|
|
||||||
def take_step(
|
def take_step(
|
||||||
self, step: int, h_next_step: float, current_spectrum: np.ndarray
|
self, step: int, h_next_step: float, current_spectrum: np.ndarray
|
||||||
) -> Tuple[float, float, np.ndarray]:
|
) -> Tuple[float, float, np.ndarray]:
|
||||||
@@ -731,5 +751,65 @@ def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simul
|
|||||||
return Simulations.new(param_seq, task_id, method)
|
return Simulations.new(param_seq, task_id, method)
|
||||||
|
|
||||||
|
|
||||||
|
def __parallel_RK4IP_worker(
|
||||||
|
worker_id: int,
|
||||||
|
msq_queue: multiprocessing.connection.Connection,
|
||||||
|
data_queue: multiprocessing.Queue,
|
||||||
|
params: utils.BareParams,
|
||||||
|
):
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.debug(f"workder {worker_id} started")
|
||||||
|
for out in RK4IP(params).irun():
|
||||||
|
logger.debug(f"worker {worker_id} waiting for msg")
|
||||||
|
msq_queue.recv()
|
||||||
|
logger.debug(f"worker {worker_id} got msg")
|
||||||
|
data_queue.put((worker_id, out))
|
||||||
|
logger.debug(f"worker {worker_id} sent data")
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_RK4IP(
|
||||||
|
config,
|
||||||
|
) -> Generator[
|
||||||
|
tuple[tuple[list[tuple[str, Any]], initialize.Params, int, int, np.ndarray], ...], None, None
|
||||||
|
]:
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
params = list(initialize.ParamSequence(config))
|
||||||
|
n = len(params)
|
||||||
|
z_num = params[0][1].z_num
|
||||||
|
|
||||||
|
cpu_no = multiprocessing.cpu_count()
|
||||||
|
if len(params) < cpu_no:
|
||||||
|
cpu_no = len(params)
|
||||||
|
|
||||||
|
pipes = [multiprocessing.Pipe(duplex=False) for i in range(n)]
|
||||||
|
data_queue = multiprocessing.Queue()
|
||||||
|
workers = [
|
||||||
|
multiprocessing.Process(target=__parallel_RK4IP_worker, args=(i, pipe[0], data_queue, p[1]))
|
||||||
|
for i, (pipe, p) in enumerate(zip(pipes, params))
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
[w.start() for w in workers]
|
||||||
|
logger.debug("pool started")
|
||||||
|
for i in range(z_num):
|
||||||
|
for q in pipes:
|
||||||
|
q[1].send(0)
|
||||||
|
logger.debug("msg sent")
|
||||||
|
computed_dict = {}
|
||||||
|
for j in range(n):
|
||||||
|
w_id, computed = data_queue.get()
|
||||||
|
computed_dict[w_id] = computed
|
||||||
|
computed_dict = list(computed_dict.items())
|
||||||
|
computed_dict.sort()
|
||||||
|
yield tuple((*p, *c) for p, c in zip(params, [el[1] for el in computed_dict]))
|
||||||
|
print("finished")
|
||||||
|
finally:
|
||||||
|
for w, cs in zip(workers, pipes):
|
||||||
|
w.join()
|
||||||
|
w.close()
|
||||||
|
cs[0].close()
|
||||||
|
cs[1].close()
|
||||||
|
data_queue.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user