From ce9a11e16e7fa429845c5e5ce0caaf04d4b34b24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Fri, 28 May 2021 13:58:53 +0200 Subject: [PATCH] misc --- src/scgenerator/initialize.py | 48 ++++++++++++++++++----------------- src/scgenerator/io.py | 26 ++++++++++--------- src/scgenerator/spectra.py | 13 +++++++--- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 80c714b..be9ba4d 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -607,7 +607,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: params = _comform_custom_field(params) # Initial field elif "field_0" in params: - params = _validate_custom_init_field(params) + params = _evalutate_custom_field_equation(params) params = _comform_custom_field(params) else: params = _update_pulse_parameters(params) @@ -636,7 +636,6 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: def compute_subsequent_paramters(sim_folder: str, config: Dict[str, Any]) -> Dict[str, Any]: - params = compute_init_parameters(config) spec = io.load_last_spectrum(sim_folder)[1] params["field_0"] = np.fft.ifft(spec) * params["input_transmission"] @@ -656,6 +655,7 @@ def _comform_custom_field(params): params["width"], params["peak_power"], params["energy"] = pulse.measure_field( params["t"], params["field_0"] ) + wl = params["wavelength"] return params @@ -678,10 +678,22 @@ def _update_pulse_parameters(params): return params -def _validate_custom_init_field(params): +def _evalutate_custom_field_equation(params): field_info = params["field_0"] if isinstance(field_info, str): - field_0 = evaluate_field_equation(field_info, **params) + field_0 = eval( + field_info, + dict( + sin=np.sin, + cos=np.cos, + tan=np.tan, + exp=np.exp, + pi=np.pi, + sqrt=np.sqrt, + **params, + ), + ) + params["field_0"] = field_0 elif len(field_info) != params["t_num"]: raise ValueError( @@ -741,15 +753,20 @@ def _generate_sim_grid(params): params["dt"] = t[1] - t[0] params["t_num"] = len(t) - w_c = wspace(t) + params = _update_frequency_domain(params) + + params["z_targets"] = np.linspace(0, params["length"], params["z_num"]) + + return params + + +def _update_frequency_domain(params): + w_c = wspace(params["t"]) w0 = units.m(params["wavelength"]) params["w0"] = w0 params["w_c"] = w_c params["w"] = w_c + w0 params["w_power_fact"] = np.array([power_fact(w_c, k) for k in range(2, 11)]) - - params["z_targets"] = np.linspace(0, params["length"], params["z_num"]) - return params @@ -784,18 +801,3 @@ def sanitize_z_targets(z_targets): z_targets = [0] + z_targets return z_targets - - -def evaluate_field_equation(eq, **kwargs): - return eval( - eq, - dict( - sin=np.sin, - cos=np.cos, - tan=np.tan, - exp=np.exp, - pi=np.pi, - sqrt=np.sqrt, - **kwargs, - ), - ) diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index d44ab8f..b2df02a 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -10,6 +10,7 @@ import toml from send2trash import TrashPermissionError, send2trash from tqdm import tqdm from pathlib import Path +import itertools from . import utils from .const import ENVIRON_KEY_BASE, PARAM_SEPARATOR, PREFIX_KEY_BASE, TMP_FOLDER_KEY_BASE @@ -313,7 +314,7 @@ def check_data_integrity(sub_folders: List[str], init_z_num: int): def propagation_initiated(sub_folder) -> bool: if os.path.isdir(sub_folder): - return find_last_spectrum_file(sub_folder) > 0 + return find_last_spectrum_num(sub_folder) > 0 return False @@ -339,7 +340,7 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int: """ params = load_toml(os.path.join(sub_folder, "params.toml")) z_num = params["z_num"] - num_spectra = find_last_spectrum_file(sub_folder) + 1 # because of zero-indexing + num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing if z_num != init_z_num: raise IncompleteDataFolderError( @@ -350,20 +351,21 @@ def num_left_to_propagate(sub_folder: str, init_z_num: int) -> int: return z_num - num_spectra -def find_last_spectrum_file(path: str): - num = 0 - while True: - if os.path.isfile(os.path.join(path, f"spectrum_{num}.npy")): - num += 1 - pass - else: +def find_last_spectrum_num(data_dir: Path): + for num in itertools.count(): + if not (data_dir / f"spectrum_{num}.npy").is_file(): return num - 1 -def load_last_spectrum(path: str) -> Tuple[int, np.ndarray]: +def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]: """return the last spectrum stored in path as well as its id""" - num = find_last_spectrum_file(path) - return num, np.load(os.path.join(path, f"spectrum_{num}.npy")) + num = find_last_spectrum_num(data_dir) + return num, np.load(data_dir / f"spectrum_{num}.npy") + + +def last_spectrum_path(path: Path) -> Path: + num = find_last_spectrum_num(path) + return path / f"spectrum_{num}.npy" def merge(paths: Union[str, List[str]], delete=False): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index a45596b..52b7dce 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -1,7 +1,7 @@ import os from collections.abc import Mapping, Sequence from glob import glob -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import numpy as np @@ -51,7 +51,7 @@ class Pulse(Sequence): self.z = self.params["z_targets"] else: raise - + self.cache: Dict[int, Spectrum] = {} self.nmax = len(glob(os.path.join(self.path, "spectra_*.npy"))) if self.nmax <= 0: raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") @@ -183,11 +183,18 @@ class Pulse(Sequence): return spectra + def all_fields(self, ind=None): + return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) + def _load1(self, i: int): + if i in self.cache: + return self.cache[i] spec = io.load_single_spectrum(self.path, i) if self.__ensure_2d: spec = np.atleast_2d(spec) - return Spectrum(spec, self.wl, self.params["frep"]) + spec = Spectrum(spec, self.wl, self.params["frep"]) + self.cache[i] = spec + return spec class SpectraCollection(Mapping, Sequence):