From bf868c46684cca8060e89f4d00bfd9166ba21245 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 18 Aug 2021 09:56:02 +0200 Subject: [PATCH] changed default data directory path --- src/scgenerator/io.py | 8 ++++---- src/scgenerator/physics/pulse.py | 2 +- src/scgenerator/physics/simulate.py | 7 +++++-- src/scgenerator/utils/parameter.py | 1 + 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 33561cc..10fc8c9 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -454,12 +454,12 @@ def sim_dirs(path_trees: List[PathTree]) -> Generator[Path, None, None]: yield p[0].parent -def get_sim_dir(task_id: int, name_if_new: str = "data") -> Path: - if name_if_new == "": - name_if_new = "data" +def get_sim_dir(task_id: int, path_if_new: Path = None) -> Path: + if path_if_new is None: + path_if_new = Path("scgenerator data") tmp = env.data_folder(task_id) if tmp is None: - tmp = ensure_folder(Path("scgenerator" + PARAM_SEPARATOR + name_if_new)) + tmp = ensure_folder(path_if_new) os.environ[TMP_FOLDER_KEY_BASE + str(task_id)] = str(tmp) tmp = Path(tmp).resolve() if not tmp.exists(): diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 2d6195a..8106282 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -739,7 +739,7 @@ def find_lobe_limits(x_axis, values, debug="", already_sorted=True): ) ax.legend() fig.savefig(out_path, bbox_inches="tight") - plt.close(fig) + plt.close() else: good_roots, left_lim, right_lim = _select_roots(d_spline, d_roots, dd_roots, fwhm_pos) diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 04f7eff..5c2837d 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -8,6 +8,7 @@ from typing import Dict, List, Tuple, Type, Union import numpy as np from .. import env, initialize, io, utils +from ..const import PARAM_SEPARATOR from ..errors import IncompleteDataFolderError from ..logger import get_logger from . import pulse @@ -438,7 +439,9 @@ class Simulations: self.update(param_seq) self.name = self.param_seq.name - self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name) + self.sim_dir = io.get_sim_dir( + self.id, path_if_new=Path(self.name + PARAM_SEPARATOR + "tmp") + ) io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml") self.sim_jobs_per_node = 1 @@ -690,7 +693,7 @@ def run_simulation_sequence( final_name = env.get(env.OUTPUT_PATH) if final_name is None: - final_name = path_trees[0][-1][0].parent.name + " merged" + final_name = config.name io.merge(final_name, path_trees) diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index 13e8983..b852a78 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -251,6 +251,7 @@ class VariableParameter: valid_variable = { "dispersion_file", "field_file", + "loss_file", "beta", "gamma", "pitch",