This commit is contained in:
Benoît Sierro
2021-06-21 16:26:45 +02:00
parent 991b6b81b0
commit 1b02b3bb4a
5 changed files with 39 additions and 22 deletions

View File

@@ -147,9 +147,11 @@ def run_sim(args):
def merge(args):
path_trees = io.build_path_trees(Path(args.path))
if args.output_name is None:
args.output_name = path_trees[0][-1][0].parent.name + " merged"
io.merge(args.output_name, path_trees)
output = env.output_path()
if output is None:
output = path_trees[0][-1][0].parent.name + " merged"
io.merge(output, path_trees)
def prep_ray():

View File

@@ -519,15 +519,15 @@ def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
"""
previous = None
variables = set()
num = 1
repeat = 1
for config in configs:
if (p := Path(config)).is_dir():
config = p / "initial_config.toml"
dico = io.load_toml(config)
previous = Config.from_bare(override_config(dico, previous))
num *= previous.repeat
repeat = previous.repeat
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
return previous, num * int(np.product([len(v) for k, v in variables if len(v) > 0]))
return previous, repeat * int(np.product([len(v) for k, v in variables if len(v) > 0]))
def wspace(t, t_num=0):

View File

@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Type
import numpy as np
from .. import const, env, initialize, io, utils
from .. import env, initialize, io, utils
from ..errors import IncompleteDataFolderError
from ..logger import get_logger
from . import pulse
@@ -668,7 +668,7 @@ def run_simulation_sequence(
prev = sim.sim_dir
path_trees = io.build_path_trees(sim.sim_dir)
final_name = env.get(const.OUTPUT_PATH)
final_name = env.get(env.OUTPUT_PATH)
if final_name is None:
final_name = path_trees[0][-1][0].parent.name + " merged"

View File

@@ -33,19 +33,29 @@ class Spectrum(np.ndarray):
class Pulse(Sequence):
def __init__(self, path: os.PathLike, ensure_2d=True):
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
"""load a data folder as a pulse
Parameters
----------
path : os.PathLike
path to the data (folder containing .npy files)
default_ind : int | Iterable[int], optional
default indices to be loaded, by default None
Raises
------
FileNotFoundError
path does not contain proper data
"""
self.logger = get_logger(__name__)
self.path = Path(path)
self.__ensure_2d = ensure_2d
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = None
try:
self.params = io.load_params(self.path / "params.toml")
except FileNotFoundError:
self.logger.info(f"parameters corresponding to {self.path} not found")
self.params = io.load_params(self.path / "params.toml")
initialize.build_sim_grid_in_place(self.params)
@@ -173,8 +183,11 @@ class Pulse(Sequence):
# Check if file exists and assert how many z positions there are
if ind is None:
ind = range(self.nmax)
elif isinstance(ind, int):
if self.default_ind is None:
ind = range(self.nmax)
else:
ind = self.default_ind
if isinstance(ind, int):
ind = [ind]
# Load the spectra
@@ -184,8 +197,10 @@ class Pulse(Sequence):
spectra = np.array(spectra)
self.logger.debug(f"all spectra from {self.path} successfully loaded")
return spectra
if len(ind) == 1:
return spectra[0]
else:
return spectra
def all_fields(self, ind=None):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
@@ -196,8 +211,7 @@ class Pulse(Sequence):
if i in self.cache:
return self.cache[i]
spec = np.load(self.path / SPECN_FN.format(i))
if self.__ensure_2d:
spec = np.atleast_2d(spec)
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.wl, self.params.repetition_rate)
self.cache[i] = spec
return spec

View File

@@ -6,6 +6,7 @@ scgenerator module but some function may be used in any python program
import itertools
import multiprocessing
import re
import threading
from collections import abc
from copy import deepcopy
@@ -195,7 +196,7 @@ def format_variable_list(l: List[Tuple[str, Any]]):
def branch_id(branch: Tuple[Path, ...]) -> str:
return "".join("".join(b.name.split()[2:-2]) for b in branch)
return "".join("".join(re.sub(r"id\d\S*num\d", "", b.name).split()[2:-2]) for b in branch)
def format_value(value) -> str: