This commit is contained in:
Benoît Sierro
2021-05-28 13:58:53 +02:00
parent ce648b12ff
commit ce9a11e16e
3 changed files with 49 additions and 38 deletions

View File

@@ -607,7 +607,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
params = _comform_custom_field(params)
# Initial field
elif "field_0" in params:
params = _validate_custom_init_field(params)
params = _evalutate_custom_field_equation(params)
params = _comform_custom_field(params)
else:
params = _update_pulse_parameters(params)
@@ -636,7 +636,6 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
def compute_subsequent_paramters(sim_folder: str, config: Dict[str, Any]) -> Dict[str, Any]:
params = compute_init_parameters(config)
spec = io.load_last_spectrum(sim_folder)[1]
params["field_0"] = np.fft.ifft(spec) * params["input_transmission"]
@@ -656,6 +655,7 @@ def _comform_custom_field(params):
params["width"], params["peak_power"], params["energy"] = pulse.measure_field(
params["t"], params["field_0"]
)
wl = params["wavelength"]
return params
@@ -678,10 +678,22 @@ def _update_pulse_parameters(params):
return params
def _validate_custom_init_field(params):
def _evalutate_custom_field_equation(params):
field_info = params["field_0"]
if isinstance(field_info, str):
field_0 = evaluate_field_equation(field_info, **params)
field_0 = eval(
field_info,
dict(
sin=np.sin,
cos=np.cos,
tan=np.tan,
exp=np.exp,
pi=np.pi,
sqrt=np.sqrt,
**params,
),
)
params["field_0"] = field_0
elif len(field_info) != params["t_num"]:
raise ValueError(
@@ -741,15 +753,20 @@ def _generate_sim_grid(params):
params["dt"] = t[1] - t[0]
params["t_num"] = len(t)
w_c = wspace(t)
params = _update_frequency_domain(params)
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
return params
def _update_frequency_domain(params):
w_c = wspace(params["t"])
w0 = units.m(params["wavelength"])
params["w0"] = w0
params["w_c"] = w_c
params["w"] = w_c + w0
params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)])
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
return params
@@ -784,18 +801,3 @@ def sanitize_z_targets(z_targets):
z_targets = [0] + z_targets
return z_targets
def evaluate_field_equation(eq, **kwargs):
return eval(
eq,
dict(
sin=np.sin,
cos=np.cos,
tan=np.tan,
exp=np.exp,
pi=np.pi,
sqrt=np.sqrt,
**kwargs,
),
)

View File

@@ -10,6 +10,7 @@ import toml
from send2trash import TrashPermissionError, send2trash
from tqdm import tqdm
from pathlib import Path
import itertools
from . import utils
from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE
@@ -313,7 +314,7 @@ def check_data_integrity(sub_folders: List[str], init_z_num: int):
def propagation_initiated(sub_folder) -> bool:
if os.path.isdir(sub_folder):
return find_last_spectrum_file(sub_folder) > 0
return find_last_spectrum_num(sub_folder) > 0
return False
@@ -339,7 +340,7 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
"""
params = load_toml(os.path.join(sub_folder, "params.toml"))
z_num = params["z_num"]
num_spectra = find_last_spectrum_file(sub_folder) + 1 # because of zero-indexing
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
if z_num != init_z_num:
raise IncompleteDataFolderError(
@@ -350,20 +351,21 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
return z_num - num_spectra
def find_last_spectrum_file(path: str):
num = 0
while True:
if os.path.isfile(os.path.join(path, f"spectrum_{num}.npy")):
num += 1
pass
else:
def find_last_spectrum_num(data_dir: Path):
for num in itertools.count():
if not (data_dir / f"spectrum_{num}.npy").is_file():
return num - 1
def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]:
def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
"""return the last spectrum stored in path as well as its id"""
num = find_last_spectrum_file(path)
return num, np.load(os.path.join(path, f"spectrum_{num}.npy"))
num = find_last_spectrum_num(data_dir)
return num, np.load(data_dir / f"spectrum_{num}.npy")
def last_spectrum_path(path: Path) -> Path:
num = find_last_spectrum_num(path)
return path / f"spectrum_{num}.npy"
def merge(paths: Union[str, List[str]], delete=False):

View File

@@ -1,7 +1,7 @@
import os
from collections.abc import Mapping, Sequence
from glob import glob
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple
import numpy as np
@@ -51,7 +51,7 @@ class Pulse(Sequence):
self.z = self.params["z_targets"]
else:
raise
self.cache: Dict[int, Spectrum] = {}
self.nmax = len(glob(os.path.join(self.path, "spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
@@ -183,11 +183,18 @@ class Pulse(Sequence):
return spectra
def all_fields(self, ind=None):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
def _load1(self, i: int):
if i in self.cache:
return self.cache[i]
spec = io.load_single_spectrum(self.path, i)
if self.__ensure_2d:
spec = np.atleast_2d(spec)
return Spectrum(spec, self.wl, self.params["frep"])
spec = Spectrum(spec, self.wl, self.params["frep"])
self.cache[i] = spec
return spec
class SpectraCollection(Mapping, Sequence):