diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index d9c9ae5..d33e0eb 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -22,7 +22,7 @@ from .const import PARAM_FN, __version__ from .errors import EvaluatorError, NoDefaultError from .logger import get_logger from .physics import fiber, materials, pulse, units -from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path +from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path_name from .variationer import VariationDescriptor, Variationer T = TypeVar("T") @@ -341,7 +341,7 @@ class Parameters(_AbstractParameters): prev_data_dir: str = Parameter(string) recovery_data_dir: str = Parameter(string) previous_config_file: str = Parameter(string) - output_path: str = Parameter(string, default="sc_data") + output_path: Path = Parameter(type_checker(Path), default=Path("sc_data"), converter=Path) # # fiber input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0) @@ -535,7 +535,7 @@ class Parameters(_AbstractParameters): @property def final_path(self) -> Path: if self.output_path is not None: - return Path(update_path(self.output_path)) + return self.output_path.parent / update_path_name(self.output_path.name) return None @@ -938,7 +938,7 @@ class Configuration: not self.overwrite, False, ) - cfg["output_path"] = str(p) + cfg["output_path"] = p sim_config = self.__SimConfig(descriptor, cfg, p) sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config while len(sim_dict) > 0: diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 83def58..9da68bf 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -3,6 +3,7 @@ import multiprocessing.connection import os import random from datetime import datetime +from dataclasses import dataclass from pathlib import Path from typing import Any, Generator, Type, Union @@ -21,12 +22,23 @@ except ModuleNotFoundError: ray = None +@dataclass +class CurrentState: + length: float + spectrum: np.ndarray + z: float + h: float + + @property + def z_ratio(self) -> float: + return self.z / self.length + + class RK4IP: def __init__( self, params: Parameters, save_data=False, - job_identifier="", task_id=0, ): """A 1D solver using 4th order Runge-Kutta in the interaction picture @@ -37,12 +49,10 @@ class RK4IP: parameters of the simulation save_data : bool, optional save calculated spectra to disk, by default False - job_identifier : str, optional - string identifying the parameter set, by default "" task_id : int, optional unique identifier of the session, by default 0 """ - self.set(params, save_data, job_identifier, task_id) + self.set(params, save_data, task_id) def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]: yield from self.irun() @@ -54,21 +64,19 @@ class RK4IP: self, params: Parameters, save_data=False, - job_identifier="", task_id=0, ): - - self.job_identifier = job_identifier + self.params = params self.id = task_id self.save_data = save_data if self.save_data: - self.data_dir = Path(params.output_path) + self.data_dir = params.output_path os.makedirs(self.data_dir, exist_ok=True) else: self.data_dir = None - self.logger = get_logger(self.job_identifier) + self.logger = get_logger(self.params.output_path) self.resuming = False self.w_c = params.w_c @@ -346,14 +354,12 @@ class SequentialRK4IP(RK4IP): params: Parameters, pbars: PBars, save_data=False, - job_identifier="", task_id=0, ): self.pbars = pbars super().__init__( params, save_data=save_data, - job_identifier=job_identifier, task_id=task_id, ) @@ -368,7 +374,6 @@ class MutliProcRK4IP(RK4IP): p_queue: multiprocessing.Queue, worker_id: int, save_data=False, - job_identifier="", task_id=0, ): self.worker_id = worker_id @@ -376,7 +381,6 @@ class MutliProcRK4IP(RK4IP): super().__init__( params, save_data=save_data, - job_identifier=job_identifier, task_id=task_id, ) @@ -394,7 +398,6 @@ class RayRK4IP(RK4IP): p_actor, worker_id: int, save_data=False, - job_identifier="", task_id=0, ): self.worker_id = worker_id @@ -402,13 +405,12 @@ class RayRK4IP(RK4IP): super().set( params, save_data=save_data, - job_identifier=job_identifier, task_id=task_id, ) def set_and_run(self, v): - params, p_actor, worker_id, save_data, job_identifier, task_id = v - self.set(params, p_actor, worker_id, save_data, job_identifier, task_id) + params, p_actor, worker_id, save_data, task_id = v + self.set(params, p_actor, worker_id, save_data, task_id) self.run() def step_saved(self): @@ -500,19 +502,16 @@ class Simulations: def _run_available(self): for variable, params in self.configuration: params.compute() - v_list_str = variable.formatted_descriptor(True) - utils.save_parameters(params.prepare_for_dump(), Path(params.output_path)) + utils.save_parameters(params.prepare_for_dump(), params.output_path) - self.new_sim(v_list_str, params) + self.new_sim(params) self.finish() - def new_sim(self, v_list_str: str, params: Parameters): + def new_sim(self, params: Parameters): """responsible to launch a new simulation Parameters ---------- - v_list_str : str - string that uniquely identifies the simulation as returned by utils.format_variable_list params : Parameters computed parameters """ @@ -545,13 +544,9 @@ class SequencialSimulations(Simulations, priority=0): ) self.configuration.skip_callback = lambda num: self.pbars.update(0, num) - def new_sim(self, v_list_str: str, params: Parameters): - self.logger.info( - f"{self.configuration.final_path} : launching simulation with {v_list_str}" - ) - SequentialRK4IP( - params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id - ).run() + def new_sim(self, params: Parameters): + self.logger.info(f"{self.configuration.final_path} : launching simulation") + SequentialRK4IP(params, self.pbars, save_data=True, task_id=self.id).run() def stop(self): pass @@ -597,8 +592,8 @@ class MultiProcSimulations(Simulations, priority=1): worker.start() super().run() - def new_sim(self, v_list_str: str, params: Parameters): - self.queue.put((v_list_str, params), block=True, timeout=None) + def new_sim(self, params: Parameters): + self.queue.put((params,), block=True, timeout=None) def finish(self): """0 means finished""" @@ -624,13 +619,12 @@ class MultiProcSimulations(Simulations, priority=1): if raw_data == 0: queue.task_done() return - v_list_str, params = raw_data + (params,) = raw_data MutliProcRK4IP( params, p_queue, worker_id, save_data=True, - job_identifier=v_list_str, task_id=task_id, ).run() queue.task_done() @@ -676,7 +670,7 @@ class RaySimulations(Simulations, priority=2): ) self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num)) - def new_sim(self, v_list_str: str, params: Parameters): + def new_sim(self, params: Parameters): while self.num_submitted >= self.sim_jobs_total: self.collect_1_job() @@ -688,15 +682,12 @@ class RaySimulations(Simulations, priority=2): self.p_actor, self.rolling_id + 1, True, - v_list_str, self.id, ), ) self.num_submitted += 1 - self.logger.info( - f"{self.configuration.final_path} : launching simulation with {v_list_str}" - ) + self.logger.info(f"{self.configuration.final_path} : launching simulation") def collect_1_job(self): ray.get(self.p_actor.update_pbars.remote()) diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index c81a68b..1e95d85 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -437,7 +437,7 @@ def combine_simulations(path: Path, dest: Path = None): for l in paths.values(): l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) for pulses in paths.values(): - new_path = dest / update_path(pulses[0].name) + new_path = dest / update_path_name(pulses[0].name) os.makedirs(new_path, exist_ok=True) for num, pulse in enumerate(pulses): params_ok = False @@ -461,8 +461,8 @@ def update_params(new_path: Path, file: Path): params = load_toml(file) if (p := params.get("prev_data_dir")) is not None: p = Path(p) - params["prev_data_dir"] = str(p.parent / update_path(p.name)) - params["output_path"] = str(new_path) + params["prev_data_dir"] = str(p.parent / update_path_name(p.name)) + params["output_path"] = new_path save_toml(new_path / PARAM_FN, params) file.unlink() @@ -495,7 +495,7 @@ def save_parameters( return file_path -def update_path(p: str) -> str: +def update_path_name(p: str) -> str: return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)