This commit is contained in:
Benoît Sierro
2021-06-21 16:26:45 +02:00
parent 991b6b81b0
commit 1b02b3bb4a
5 changed files with 39 additions and 22 deletions

View File

@@ -147,9 +147,11 @@ def run_sim(args):
def merge(args): def merge(args):
path_trees = io.build_path_trees(Path(args.path)) path_trees = io.build_path_trees(Path(args.path))
if args.output_name is None: output = env.output_path()
args.output_name = path_trees[0][-1][0].parent.name + " merged" if output is None:
io.merge(args.output_name, path_trees) output = path_trees[0][-1][0].parent.name + " merged"
io.merge(output, path_trees)
def prep_ray(): def prep_ray():

View File

@@ -519,15 +519,15 @@ def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
""" """
previous = None previous = None
variables = set() variables = set()
num = 1 repeat = 1
for config in configs: for config in configs:
if (p := Path(config)).is_dir(): if (p := Path(config)).is_dir():
config = p / "initial_config.toml" config = p / "initial_config.toml"
dico = io.load_toml(config) dico = io.load_toml(config)
previous = Config.from_bare(override_config(dico, previous)) previous = Config.from_bare(override_config(dico, previous))
num *= previous.repeat repeat = previous.repeat
variables |= {(k, tuple(v)) for k, v in previous.variable.items()} variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
return previous, num * int(np.product([len(v) for k, v in variables if len(v) > 0])) return previous, repeat * int(np.product([len(v) for k, v in variables if len(v) > 0]))
def wspace(t, t_num=0): def wspace(t, t_num=0):

View File

@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Type
import numpy as np import numpy as np
from .. import const, env, initialize, io, utils from .. import env, initialize, io, utils
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
from . import pulse from . import pulse
@@ -668,7 +668,7 @@ def run_simulation_sequence(
prev = sim.sim_dir prev = sim.sim_dir
path_trees = io.build_path_trees(sim.sim_dir) path_trees = io.build_path_trees(sim.sim_dir)
final_name = env.get(const.OUTPUT_PATH) final_name = env.get(env.OUTPUT_PATH)
if final_name is None: if final_name is None:
final_name = path_trees[0][-1][0].parent.name + " merged" final_name = path_trees[0][-1][0].parent.name + " merged"

View File

@@ -33,19 +33,29 @@ class Spectrum(np.ndarray):
class Pulse(Sequence): class Pulse(Sequence):
def __init__(self, path: os.PathLike, ensure_2d=True): def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
"""load a data folder as a pulse
Parameters
----------
path : os.PathLike
path to the data (folder containing .npy files)
default_ind : int | Iterable[int], optional
default indices to be loaded, by default None
Raises
------
FileNotFoundError
path does not contain proper data
"""
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.path = Path(path) self.path = Path(path)
self.__ensure_2d = ensure_2d self.default_ind = default_ind
if not self.path.is_dir(): if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist") raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = None self.params = io.load_params(self.path / "params.toml")
try:
self.params = io.load_params(self.path / "params.toml")
except FileNotFoundError:
self.logger.info(f"parameters corresponding to {self.path} not found")
initialize.build_sim_grid_in_place(self.params) initialize.build_sim_grid_in_place(self.params)
@@ -173,8 +183,11 @@ class Pulse(Sequence):
# Check if file exists and assert how many z positions there are # Check if file exists and assert how many z positions there are
if ind is None: if ind is None:
ind = range(self.nmax) if self.default_ind is None:
elif isinstance(ind, int): ind = range(self.nmax)
else:
ind = self.default_ind
if isinstance(ind, int):
ind = [ind] ind = [ind]
# Load the spectra # Load the spectra
@@ -184,8 +197,10 @@ class Pulse(Sequence):
spectra = np.array(spectra) spectra = np.array(spectra)
self.logger.debug(f"all spectra from {self.path} successfully loaded") self.logger.debug(f"all spectra from {self.path} successfully loaded")
if len(ind) == 1:
return spectra return spectra[0]
else:
return spectra
def all_fields(self, ind=None): def all_fields(self, ind=None):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
@@ -196,8 +211,7 @@ class Pulse(Sequence):
if i in self.cache: if i in self.cache:
return self.cache[i] return self.cache[i]
spec = np.load(self.path / SPECN_FN.format(i)) spec = np.load(self.path / SPECN_FN.format(i))
if self.__ensure_2d: spec = np.atleast_2d(spec)
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.wl, self.params.repetition_rate) spec = Spectrum(spec, self.wl, self.params.repetition_rate)
self.cache[i] = spec self.cache[i] = spec
return spec return spec

View File

@@ -6,6 +6,7 @@ scgenerator module but some function may be used in any python program
import itertools import itertools
import multiprocessing import multiprocessing
import re
import threading import threading
from collections import abc from collections import abc
from copy import deepcopy from copy import deepcopy
@@ -195,7 +196,7 @@ def format_variable_list(l: List[Tuple[str, Any]]):
def branch_id(branch: Tuple[Path, ...]) -> str: def branch_id(branch: Tuple[Path, ...]) -> str:
return "".join("".join(b.name.split()[2:-2]) for b in branch) return "".join("".join(re.sub(r"id\d\S*num\d", "", b.name).split()[2:-2]) for b in branch)
def format_value(value) -> str: def format_value(value) -> str: