misc
This commit is contained in:
@@ -147,9 +147,11 @@ def run_sim(args):
|
|||||||
def merge(args):
|
def merge(args):
|
||||||
path_trees = io.build_path_trees(Path(args.path))
|
path_trees = io.build_path_trees(Path(args.path))
|
||||||
|
|
||||||
if args.output_name is None:
|
output = env.output_path()
|
||||||
args.output_name = path_trees[0][-1][0].parent.name + " merged"
|
if output is None:
|
||||||
io.merge(args.output_name, path_trees)
|
output = path_trees[0][-1][0].parent.name + " merged"
|
||||||
|
|
||||||
|
io.merge(output, path_trees)
|
||||||
|
|
||||||
|
|
||||||
def prep_ray():
|
def prep_ray():
|
||||||
|
|||||||
@@ -519,15 +519,15 @@ def validate_config_sequence(*configs: os.PathLike) -> Tuple[Config, int]:
|
|||||||
"""
|
"""
|
||||||
previous = None
|
previous = None
|
||||||
variables = set()
|
variables = set()
|
||||||
num = 1
|
repeat = 1
|
||||||
for config in configs:
|
for config in configs:
|
||||||
if (p := Path(config)).is_dir():
|
if (p := Path(config)).is_dir():
|
||||||
config = p / "initial_config.toml"
|
config = p / "initial_config.toml"
|
||||||
dico = io.load_toml(config)
|
dico = io.load_toml(config)
|
||||||
previous = Config.from_bare(override_config(dico, previous))
|
previous = Config.from_bare(override_config(dico, previous))
|
||||||
num *= previous.repeat
|
repeat = previous.repeat
|
||||||
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
|
variables |= {(k, tuple(v)) for k, v in previous.variable.items()}
|
||||||
return previous, num * int(np.product([len(v) for k, v in variables if len(v) > 0]))
|
return previous, repeat * int(np.product([len(v) for k, v in variables if len(v) > 0]))
|
||||||
|
|
||||||
|
|
||||||
def wspace(t, t_num=0):
|
def wspace(t, t_num=0):
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Type
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .. import const, env, initialize, io, utils
|
from .. import env, initialize, io, utils
|
||||||
from ..errors import IncompleteDataFolderError
|
from ..errors import IncompleteDataFolderError
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from . import pulse
|
from . import pulse
|
||||||
@@ -668,7 +668,7 @@ def run_simulation_sequence(
|
|||||||
prev = sim.sim_dir
|
prev = sim.sim_dir
|
||||||
path_trees = io.build_path_trees(sim.sim_dir)
|
path_trees = io.build_path_trees(sim.sim_dir)
|
||||||
|
|
||||||
final_name = env.get(const.OUTPUT_PATH)
|
final_name = env.get(env.OUTPUT_PATH)
|
||||||
if final_name is None:
|
if final_name is None:
|
||||||
final_name = path_trees[0][-1][0].parent.name + " merged"
|
final_name = path_trees[0][-1][0].parent.name + " merged"
|
||||||
|
|
||||||
|
|||||||
@@ -33,19 +33,29 @@ class Spectrum(np.ndarray):
|
|||||||
|
|
||||||
|
|
||||||
class Pulse(Sequence):
|
class Pulse(Sequence):
|
||||||
def __init__(self, path: os.PathLike, ensure_2d=True):
|
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
||||||
|
"""load a data folder as a pulse
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : os.PathLike
|
||||||
|
path to the data (folder containing .npy files)
|
||||||
|
default_ind : int | Iterable[int], optional
|
||||||
|
default indices to be loaded, by default None
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
path does not contain proper data
|
||||||
|
"""
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.path = Path(path)
|
self.path = Path(path)
|
||||||
self.__ensure_2d = ensure_2d
|
self.default_ind = default_ind
|
||||||
|
|
||||||
if not self.path.is_dir():
|
if not self.path.is_dir():
|
||||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||||
|
|
||||||
self.params = None
|
|
||||||
try:
|
|
||||||
self.params = io.load_params(self.path / "params.toml")
|
self.params = io.load_params(self.path / "params.toml")
|
||||||
except FileNotFoundError:
|
|
||||||
self.logger.info(f"parameters corresponding to {self.path} not found")
|
|
||||||
|
|
||||||
initialize.build_sim_grid_in_place(self.params)
|
initialize.build_sim_grid_in_place(self.params)
|
||||||
|
|
||||||
@@ -173,8 +183,11 @@ class Pulse(Sequence):
|
|||||||
# Check if file exists and assert how many z positions there are
|
# Check if file exists and assert how many z positions there are
|
||||||
|
|
||||||
if ind is None:
|
if ind is None:
|
||||||
|
if self.default_ind is None:
|
||||||
ind = range(self.nmax)
|
ind = range(self.nmax)
|
||||||
elif isinstance(ind, int):
|
else:
|
||||||
|
ind = self.default_ind
|
||||||
|
if isinstance(ind, int):
|
||||||
ind = [ind]
|
ind = [ind]
|
||||||
|
|
||||||
# Load the spectra
|
# Load the spectra
|
||||||
@@ -184,7 +197,9 @@ class Pulse(Sequence):
|
|||||||
spectra = np.array(spectra)
|
spectra = np.array(spectra)
|
||||||
|
|
||||||
self.logger.debug(f"all spectra from {self.path} successfully loaded")
|
self.logger.debug(f"all spectra from {self.path} successfully loaded")
|
||||||
|
if len(ind) == 1:
|
||||||
|
return spectra[0]
|
||||||
|
else:
|
||||||
return spectra
|
return spectra
|
||||||
|
|
||||||
def all_fields(self, ind=None):
|
def all_fields(self, ind=None):
|
||||||
@@ -196,7 +211,6 @@ class Pulse(Sequence):
|
|||||||
if i in self.cache:
|
if i in self.cache:
|
||||||
return self.cache[i]
|
return self.cache[i]
|
||||||
spec = np.load(self.path / SPECN_FN.format(i))
|
spec = np.load(self.path / SPECN_FN.format(i))
|
||||||
if self.__ensure_2d:
|
|
||||||
spec = np.atleast_2d(spec)
|
spec = np.atleast_2d(spec)
|
||||||
spec = Spectrum(spec, self.wl, self.params.repetition_rate)
|
spec = Spectrum(spec, self.wl, self.params.repetition_rate)
|
||||||
self.cache[i] = spec
|
self.cache[i] = spec
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ scgenerator module but some function may be used in any python program
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
from collections import abc
|
from collections import abc
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -195,7 +196,7 @@ def format_variable_list(l: List[Tuple[str, Any]]):
|
|||||||
|
|
||||||
|
|
||||||
def branch_id(branch: Tuple[Path, ...]) -> str:
|
def branch_id(branch: Tuple[Path, ...]) -> str:
|
||||||
return "".join("".join(b.name.split()[2:-2]) for b in branch)
|
return "".join("".join(re.sub(r"id\d\S*num\d", "", b.name).split()[2:-2]) for b in branch)
|
||||||
|
|
||||||
|
|
||||||
def format_value(value) -> str:
|
def format_value(value) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user