diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 90c1c20..7620af0 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -147,9 +147,11 @@ def run_sim(args): def merge(args): path_trees = io.build_path_trees(Path(args.path)) - if args.output_name is None: - args.output_name = path_trees[0][-1][0].parent.name + " merged" - io.merge(args.output_name, path_trees) + output = env.output_path() + if output is None: + output = path_trees[0][-1][0].parent.name + " merged" + + io.merge(output, path_trees) def prep_ray(): diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index bfe4d6d..255aaee 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -519,15 +519,15 @@ def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]: """ previous = None variables = set() - num = 1 + repeat = 1 for config in configs: if (p := Path(config)).is_dir(): config = p / "initial_config.toml" dico = io.load_toml(config) previous = Config.from_bare(override_config(dico, previous)) - num *= previous.repeat + repeat = previous.repeat variables |= {(k, tuple(v)) for k, v in previous.variable.items()} - return previous, num * int(np.product([len(v) for k, v in variables if len(v) > 0])) + return previous, repeat * int(np.product([len(v) for k, v in variables if len(v) > 0])) def wspace(t, t_num=0): diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 3e1778e..89060c6 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Type import numpy as np -from .. import const, env, initialize, io, utils +from .. import env, initialize, io, utils from ..errors import IncompleteDataFolderError from ..logger import get_logger from . import pulse @@ -668,7 +668,7 @@ def run_simulation_sequence( prev = sim.sim_dir path_trees = io.build_path_trees(sim.sim_dir) - final_name = env.get(const.OUTPUT_PATH) + final_name = env.get(env.OUTPUT_PATH) if final_name is None: final_name = path_trees[0][-1][0].parent.name + " merged" diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 9eb81ab..30ebdc2 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -33,19 +33,29 @@ class Spectrum(np.ndarray): class Pulse(Sequence): - def __init__(self, path: os.PathLike, ensure_2d=True): + def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): + """load a data folder as a pulse + + Parameters + ---------- + path : os.PathLike + path to the data (folder containing .npy files) + default_ind : int | Iterable[int], optional + default indices to be loaded, by default None + + Raises + ------ + FileNotFoundError + path does not contain proper data + """ self.logger = get_logger(__name__) self.path = Path(path) - self.__ensure_2d = ensure_2d + self.default_ind = default_ind if not self.path.is_dir(): raise FileNotFoundError(f"Folder {self.path} does not exist") - self.params = None - try: - self.params = io.load_params(self.path / "params.toml") - except FileNotFoundError: - self.logger.info(f"parameters corresponding to {self.path} not found") + self.params = io.load_params(self.path / "params.toml") initialize.build_sim_grid_in_place(self.params) @@ -173,8 +183,11 @@ class Pulse(Sequence): # Check if file exists and assert how many z positions there are if ind is None: - ind = range(self.nmax) - elif isinstance(ind, int): + if self.default_ind is None: + ind = range(self.nmax) + else: + ind = self.default_ind + if isinstance(ind, int): ind = [ind] # Load the spectra @@ -184,8 +197,10 @@ class Pulse(Sequence): spectra = np.array(spectra) self.logger.debug(f"all spectra from {self.path} successfully loaded") - - return spectra + if len(ind) == 1: + return spectra[0] + else: + return spectra def all_fields(self, ind=None): return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) @@ -196,8 +211,7 @@ class Pulse(Sequence): if i in self.cache: return self.cache[i] spec = np.load(self.path / SPECN_FN.format(i)) - if self.__ensure_2d: - spec = np.atleast_2d(spec) + spec = np.atleast_2d(spec) spec = Spectrum(spec, self.wl, self.params.repetition_rate) self.cache[i] = spec return spec diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index 5ed4857..9a2b1e0 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -6,6 +6,7 @@ scgenerator module but some function may be used in any python program import itertools import multiprocessing +import re import threading from collections import abc from copy import deepcopy @@ -195,7 +196,7 @@ def format_variable_list(l: List[Tuple[str, Any]]): def branch_id(branch: Tuple[Path, ...]) -> str: - return "".join("".join(b.name.split()[2:-2]) for b in branch) + return "".join("".join(re.sub(r"id\d\S*num\d", "", b.name).split()[2:-2]) for b in branch) def format_value(value) -> str: