This commit is contained in:
Benoît Sierro
2021-05-28 13:58:53 +02:00
parent ce648b12ff
commit ce9a11e16e
3 changed files with 49 additions and 38 deletions

View File

@@ -607,7 +607,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
params = _comform_custom_field(params) params = _comform_custom_field(params)
# Initial field # Initial field
elif "field_0" in params: elif "field_0" in params:
params = _validate_custom_init_field(params) params = _evalutate_custom_field_equation(params)
params = _comform_custom_field(params) params = _comform_custom_field(params)
else: else:
params = _update_pulse_parameters(params) params = _update_pulse_parameters(params)
@@ -636,7 +636,6 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
def compute_subsequent_paramters(sim_folder: str, config: Dict[str, Any]) -> Dict[str, Any]: def compute_subsequent_paramters(sim_folder: str, config: Dict[str, Any]) -> Dict[str, Any]:
params = compute_init_parameters(config) params = compute_init_parameters(config)
spec = io.load_last_spectrum(sim_folder)[1] spec = io.load_last_spectrum(sim_folder)[1]
params["field_0"] = np.fft.ifft(spec) * params["input_transmission"] params["field_0"] = np.fft.ifft(spec) * params["input_transmission"]
@@ -656,6 +655,7 @@ def _comform_custom_field(params):
params["width"], params["peak_power"], params["energy"] = pulse.measure_field( params["width"], params["peak_power"], params["energy"] = pulse.measure_field(
params["t"], params["field_0"] params["t"], params["field_0"]
) )
wl = params["wavelength"]
return params return params
@@ -678,10 +678,22 @@ def _update_pulse_parameters(params):
return params return params
def _validate_custom_init_field(params): def _evalutate_custom_field_equation(params):
field_info = params["field_0"] field_info = params["field_0"]
if isinstance(field_info, str): if isinstance(field_info, str):
field_0 = evaluate_field_equation(field_info, **params) field_0 = eval(
field_info,
dict(
sin=np.sin,
cos=np.cos,
tan=np.tan,
exp=np.exp,
pi=np.pi,
sqrt=np.sqrt,
**params,
),
)
params["field_0"] = field_0 params["field_0"] = field_0
elif len(field_info) != params["t_num"]: elif len(field_info) != params["t_num"]:
raise ValueError( raise ValueError(
@@ -741,15 +753,20 @@ def _generate_sim_grid(params):
params["dt"] = t[1] - t[0] params["dt"] = t[1] - t[0]
params["t_num"] = len(t) params["t_num"] = len(t)
w_c = wspace(t) params = _update_frequency_domain(params)
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
return params
def _update_frequency_domain(params):
w_c = wspace(params["t"])
w0 = units.m(params["wavelength"]) w0 = units.m(params["wavelength"])
params["w0"] = w0 params["w0"] = w0
params["w_c"] = w_c params["w_c"] = w_c
params["w"] = w_c + w0 params["w"] = w_c + w0
params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)]) params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)])
params["z_targets"] = np.linspace(0, params["length"], params["z_num"])
return params return params
@@ -784,18 +801,3 @@ def sanitize_z_targets(z_targets):
z_targets = [0] + z_targets z_targets = [0] + z_targets
return z_targets return z_targets
def evaluate_field_equation(eq, **kwargs):
return eval(
eq,
dict(
sin=np.sin,
cos=np.cos,
tan=np.tan,
exp=np.exp,
pi=np.pi,
sqrt=np.sqrt,
**kwargs,
),
)

View File

@@ -10,6 +10,7 @@ import toml
from send2trash import TrashPermissionError, send2trash from send2trash import TrashPermissionError, send2trash
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
import itertools
from . import utils from . import utils
from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE
@@ -313,7 +314,7 @@ def check_data_integrity(sub_folders: List[str], init_z_num: int):
def propagation_initiated(sub_folder) -> bool: def propagation_initiated(sub_folder) -> bool:
if os.path.isdir(sub_folder): if os.path.isdir(sub_folder):
return find_last_spectrum_file(sub_folder) > 0 return find_last_spectrum_num(sub_folder) > 0
return False return False
@@ -339,7 +340,7 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
""" """
params = load_toml(os.path.join(sub_folder, "params.toml")) params = load_toml(os.path.join(sub_folder, "params.toml"))
z_num = params["z_num"] z_num = params["z_num"]
num_spectra = find_last_spectrum_file(sub_folder) + 1 # because of zero-indexing num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
if z_num != init_z_num: if z_num != init_z_num:
raise IncompleteDataFolderError( raise IncompleteDataFolderError(
@@ -350,20 +351,21 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int:
return z_num - num_spectra return z_num - num_spectra
def find_last_spectrum_file(path: str): def find_last_spectrum_num(data_dir: Path):
num = 0 for num in itertools.count():
while True: if not (data_dir / f"spectrum_{num}.npy").is_file():
if os.path.isfile(os.path.join(path, f"spectrum_{num}.npy")):
num += 1
pass
else:
return num - 1 return num - 1
def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]: def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
"""return the last spectrum stored in path as well as its id""" """return the last spectrum stored in path as well as its id"""
num = find_last_spectrum_file(path) num = find_last_spectrum_num(data_dir)
return num, np.load(os.path.join(path, f"spectrum_{num}.npy")) return num, np.load(data_dir / f"spectrum_{num}.npy")
def last_spectrum_path(path: Path) -> Path:
num = find_last_spectrum_num(path)
return path / f"spectrum_{num}.npy"
def merge(paths: Union[str, List[str]], delete=False): def merge(paths: Union[str, List[str]], delete=False):

View File

@@ -1,7 +1,7 @@
import os import os
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from glob import glob from glob import glob
from typing import Any, List, Tuple from typing import Any, Dict, List, Tuple
import numpy as np import numpy as np
@@ -51,7 +51,7 @@ class Pulse(Sequence):
self.z = self.params["z_targets"] self.z = self.params["z_targets"]
else: else:
raise raise
self.cache: Dict[int, Spectrum] = {}
self.nmax = len(glob(os.path.join(self.path, "spectra_*.npy"))) self.nmax = len(glob(os.path.join(self.path, "spectra_*.npy")))
if self.nmax <= 0: if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
@@ -183,11 +183,18 @@ class Pulse(Sequence):
return spectra return spectra
def all_fields(self, ind=None):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
def _load1(self, i: int): def _load1(self, i: int):
if i in self.cache:
return self.cache[i]
spec = io.load_single_spectrum(self.path, i) spec = io.load_single_spectrum(self.path, i)
if self.__ensure_2d: if self.__ensure_2d:
spec = np.atleast_2d(spec) spec = np.atleast_2d(spec)
return Spectrum(spec, self.wl, self.params["frep"]) spec = Spectrum(spec, self.wl, self.params["frep"])
self.cache[i] = spec
return spec
class SpectraCollection(Mapping, Sequence): class SpectraCollection(Mapping, Sequence):