misc
This commit is contained in:
@@ -607,7 +607,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
params = _comform_custom_field(params)
|
||||
# Initial field
|
||||
elif "field_0" in params:
|
||||
params = _validate_custom_init_field(params)
|
||||
params = _evalutate_custom_field_equation(params)
|
||||
params = _comform_custom_field(params)
|
||||
else:
|
||||
params = _update_pulse_parameters(params)
|
||||
@@ -636,7 +636,6 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def compute_subsequent_paramters(sim_folder: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
params = compute_init_parameters(config)
|
||||
spec = io.load_last_spectrum(sim_folder)[1]
|
||||
params["field_0"] = np.fft.ifft(spec) * params["input_transmission"]
|
||||
@@ -656,6 +655,7 @@ def _comform_custom_field(params):
|
||||
params["width"], params["peak_power"], params["energy"] = pulse.measure_field(
|
||||
params["t"], params["field_0"]
|
||||
)
|
||||
wl = params["wavelength"]
|
||||
return params
|
||||
|
||||
|
||||
@@ -678,10 +678,22 @@ def _update_pulse_parameters(params):
|
||||
return params
|
||||
|
||||
|
||||
def _validate_custom_init_field(params):
|
||||
def _evalutate_custom_field_equation(params):
|
||||
field_info = params["field_0"]
|
||||
if isinstance(field_info, str):
|
||||
field_0 = evaluate_field_equation(field_info, **params)
|
||||
field_0 = eval(
|
||||
field_info,
|
||||
dict(
|
||||
sin=np.sin,
|
||||
cos=np.cos,
|
||||
tan=np.tan,
|
||||
exp=np.exp,
|
||||
pi=np.pi,
|
||||
sqrt=np.sqrt,
|
||||
**params,
|
||||
),
|
||||
)
|
||||
|
||||
params["field_0"] = field_0
|
||||
elif len(field_info) != params["t_num"]:
|
||||
raise ValueError(
|
||||
@@ -741,15 +753,20 @@ def _generate_sim_grid(params):
|
||||
params["dt"] = t[1] - t[0]
|
||||
params["t_num"] = len(t)
|
||||
|
||||
w_c = wspace(t)
|
||||
params = _update_frequency_domain(params)
|
||||
|
||||
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _update_frequency_domain(params):
|
||||
w_c = wspace(params["t"])
|
||||
w0 = units.m(params["wavelength"])
|
||||
params["w0"] = w0
|
||||
params["w_c"] = w_c
|
||||
params["w"] = w_c + w0
|
||||
params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)])
|
||||
|
||||
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
|
||||
|
||||
return params
|
||||
|
||||
|
||||
@@ -784,18 +801,3 @@ def sanitize_z_targets(z_targets):
|
||||
z_targets = [0] + z_targets
|
||||
|
||||
return z_targets
|
||||
|
||||
|
||||
def evaluate_field_equation(eq, **kwargs):
|
||||
return eval(
|
||||
eq,
|
||||
dict(
|
||||
sin=np.sin,
|
||||
cos=np.cos,
|
||||
tan=np.tan,
|
||||
exp=np.exp,
|
||||
pi=np.pi,
|
||||
sqrt=np.sqrt,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ import toml
|
||||
from send2trash import TrashPermissionError, send2trash
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import itertools
|
||||
|
||||
from . import utils
|
||||
from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE
|
||||
@@ -313,7 +314,7 @@ def check_data_integrity(sub_folders: List[str], init_z_num: int):
|
||||
|
||||
def propagation_initiated(sub_folder) -> bool:
|
||||
if os.path.isdir(sub_folder):
|
||||
return find_last_spectrum_file(sub_folder) > 0
|
||||
return find_last_spectrum_num(sub_folder) > 0
|
||||
return False
|
||||
|
||||
|
||||
@@ -339,7 +340,7 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
|
||||
"""
|
||||
params = load_toml(os.path.join(sub_folder, "params.toml"))
|
||||
z_num = params["z_num"]
|
||||
num_spectra = find_last_spectrum_file(sub_folder) + 1 # because of zero-indexing
|
||||
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
|
||||
|
||||
if z_num != init_z_num:
|
||||
raise IncompleteDataFolderError(
|
||||
@@ -350,20 +351,21 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
|
||||
return z_num - num_spectra
|
||||
|
||||
|
||||
def find_last_spectrum_file(path: str):
|
||||
num = 0
|
||||
while True:
|
||||
if os.path.isfile(os.path.join(path, f"spectrum_{num}.npy")):
|
||||
num += 1
|
||||
pass
|
||||
else:
|
||||
def find_last_spectrum_num(data_dir: Path):
|
||||
for num in itertools.count():
|
||||
if not (data_dir / f"spectrum_{num}.npy").is_file():
|
||||
return num - 1
|
||||
|
||||
|
||||
def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]:
|
||||
def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
|
||||
"""return the last spectrum stored in path as well as its id"""
|
||||
num = find_last_spectrum_file(path)
|
||||
return num, np.load(os.path.join(path, f"spectrum_{num}.npy"))
|
||||
num = find_last_spectrum_num(data_dir)
|
||||
return num, np.load(data_dir / f"spectrum_{num}.npy")
|
||||
|
||||
|
||||
def last_spectrum_path(path: Path) -> Path:
|
||||
num = find_last_spectrum_num(path)
|
||||
return path / f"spectrum_{num}.npy"
|
||||
|
||||
|
||||
def merge(paths: Union[str, List[str]], delete=False):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from glob import glob
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -51,7 +51,7 @@ class Pulse(Sequence):
|
||||
self.z = self.params["z_targets"]
|
||||
else:
|
||||
raise
|
||||
|
||||
self.cache: Dict[int, Spectrum] = {}
|
||||
self.nmax = len(glob(os.path.join(self.path, "spectra_*.npy")))
|
||||
if self.nmax <= 0:
|
||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||
@@ -183,11 +183,18 @@ class Pulse(Sequence):
|
||||
|
||||
return spectra
|
||||
|
||||
def all_fields(self, ind=None):
|
||||
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
||||
|
||||
def _load1(self, i: int):
|
||||
if i in self.cache:
|
||||
return self.cache[i]
|
||||
spec = io.load_single_spectrum(self.path, i)
|
||||
if self.__ensure_2d:
|
||||
spec = np.atleast_2d(spec)
|
||||
return Spectrum(spec, self.wl, self.params["frep"])
|
||||
spec = Spectrum(spec, self.wl, self.params["frep"])
|
||||
self.cache[i] = spec
|
||||
return spec
|
||||
|
||||
|
||||
class SpectraCollection(Mapping, Sequence):
|
||||
|
||||
Reference in New Issue
Block a user