diff --git a/src/scgenerator/defaults.py b/src/scgenerator/defaults.py deleted file mode 100644 index b55add0..0000000 --- a/src/scgenerator/defaults.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path - -import matplotlib.pyplot as plt - -default_plotting = dict( - figsize=(10, 7), - interpolation_2D="bicubic", - vmin=-40, - vmax=0, - vmax_with_headroom=2, - out_path=Path("plot"), - avg_main_to_coherence_ratio=4, - avg_line_labels=["individual values", "mean"], - muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)), - highlighted_style=dict(c="red"), - color_cycle=plt.rcParams["axes.prop_cycle"].by_key()["color"], - light_color=(1, 1, 1, 0.7), - markers=["*", "+", ".", "D", "x", "d", "v", "s", "1", "^"], - cmap="viridis", - label_quality_factor=r"$F_\mathrm{Q}$", - label_mean_g12=r"$\langle | g_{12} |\rangle$", - label_g12=r"|$g_{12}$|", - label_z="propagation distance z (m)", - label_fwhm=r"$T_\mathrm{FWHM}$ (fs)", - label_wb_distance=r"$L_\mathrm{WB}$", - label_t_jitter="timing jitter (fs)", - label_fwhm_noise="FWHM noise (%)", - label_int_noise="RIN (%)", - text_topright_style=dict(verticalalignment="top", horizontalalignment="right"), - text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"), -) diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index da08da8..bcaca9a 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -36,4 +36,32 @@ class Config(BaseSettings): log_file: Path | None = None +default_plotting = dict( + figsize=(10, 7), + interpolation_2D="bicubic", + vmin=-40, + vmax=0, + vmax_with_headroom=2, + out_path=Path("plot"), + avg_main_to_coherence_ratio=4, + avg_line_labels=["individual values", "mean"], + muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)), + highlighted_style=dict(c="red"), + color_cycle=plt.rcParams["axes.prop_cycle"].by_key()["color"], + light_color=(1, 1, 1, 0.7), + markers=["*", "+", ".", "D", "x", "d", "v", "s", "1", "^"], + cmap="viridis", + label_quality_factor=r"$F_\mathrm{Q}$", + label_mean_g12=r"$\langle | g_{12} |\rangle$", + label_g12=r"|$g_{12}$|", + label_z="propagation distance z (m)", + label_fwhm=r"$T_\mathrm{FWHM}$ (fs)", + label_wb_distance=r"$L_\mathrm{WB}$", + label_t_jitter="timing jitter (fs)", + label_fwhm_noise="FWHM noise (%)", + label_int_noise="RIN (%)", + text_topright_style=dict(verticalalignment="top", horizontalalignment="right"), + text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"), +) + CONFIG = Config() diff --git a/src/scgenerator/evaluator.py b/src/scgenerator/evaluator.py index 50c5086..c440efd 100644 --- a/src/scgenerator/evaluator.py +++ b/src/scgenerator/evaluator.py @@ -9,10 +9,10 @@ from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Union import numpy as np -from scgenerator import io, math, operators, utils +from scgenerator import io, math, operators from scgenerator.const import INF, MANDATORY_PARAMETERS +from scgenerator.logger import get_logger from scgenerator.physics import fiber, materials, plasma, pulse, units -from scgenerator.utils import get_logger class ErrorRecord(NamedTuple): @@ -464,7 +464,7 @@ default_rules: list[Rule] = [ Rule("w0_ind", math.argclosest, ["w", "w0"]), Rule("w_num", len, ["w"]), Rule("dw", lambda w: w[1] - w[0]), - Rule(["fft", "ifft"], utils.fft_functions, priorities=1), + Rule(["fft", "ifft"], math.fft_functions, priorities=1), Rule("wavelength_window", lambda dt, wavelength: (math.min_wl_from_dt(dt, wavelength), 8e-6)), Rule("wavelength_window", fiber.valid_wavelength_window), Rule("dispersion_ind", fiber.dispersion_indices), diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 3b6a9d3..cf690be 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -6,9 +6,10 @@ import json import os import re from dataclasses import dataclass +from functools import cache from io import BytesIO from pathlib import Path -from typing import BinaryIO, Protocol, Sequence +from typing import Any, BinaryIO, Protocol, Sequence from zipfile import ZipFile import numpy as np @@ -67,6 +68,23 @@ def custom_decode_hook(obj): return obj +@cache +def load_material_dico(name: str) -> dict[str, Any]: + """ + loads a material dictionary + + Parameters + ---------- + name : str + name of the material + + Returns + ------- + material_dico : dict + """ + return json.loads(data_file("materials.json").read_text())[name] + + class PropagationIOHandler(Protocol): def __len__(self) -> int: ... diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index fd24f90..f91338c 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -6,7 +6,7 @@ import math import warnings from dataclasses import dataclass from functools import cache -from typing import Sequence +from typing import Callable, Sequence import numba import numpy as np @@ -17,6 +17,15 @@ pi = np.pi c = 299792458.0 +def fft_functions( + full_field: bool, +) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: + if full_field: + return np.fft.rfft, np.fft.irfft + else: + return np.fft.fft, np.fft.ifft + + def expm1_int(y: np.ndarray, dx: float) -> np.ndarray: """evaluates 1 - exp( -∫func(y(x))dx ) from x=-inf to x""" return -np.expm1(-cumulative_simpson(y) * dx) diff --git a/src/scgenerator/physics/fiber.py b/src/scgenerator/physics/fiber.py index f077e7e..b560159 100644 --- a/src/scgenerator/physics/fiber.py +++ b/src/scgenerator/physics/fiber.py @@ -377,7 +377,7 @@ def V_parameter_koshiba(l: np.ndarray, pcf_pitch: float, pcf_pitch_ratio: float) n_co = 1.45 r_eff = pcf_pitch / np.sqrt(3) pi2a = pipi * r_eff - A, B = saitoh_paramters(pcf_pitch_ratio) + A, _ = saitoh_paramters(pcf_pitch_ratio) V = A[0] + A[1] / (1 + A[2] * np.exp(A[3] * ratio_l)) diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 6b4136e..5d4f2b9 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -24,7 +24,7 @@ from scipy.optimize import minimize_scalar from scipy.optimize._optimize import OptimizeResult from scgenerator import math -from scgenerator.defaults import default_plotting +from scgenerator.env import default_plotting from scgenerator.io import DataFile from scgenerator.physics import units @@ -186,8 +186,6 @@ def modify_field_ratio( pre_field_0: np.ndarray, peak_power: float = None, energy: float = None, - intensity_noise: float = None, - noise_correlation: float = 0, ) -> float: """ multiply a field by this number to get the desired specifications @@ -735,7 +733,7 @@ def avg_g12(values: np.ndarray): avg_values = np.mean(math.abs2(values), axis=0) coherence = g12(values) - return np.sum(coherence * avg_values) / np.sum(avg_values) + return np.sum(coherence * avg_values, axis=-1) / np.sum(avg_values, axis=-1) def fwhm_ind(values: np.ndarray, mam=None): diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 71e2d95..ff3e350 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -349,11 +349,6 @@ def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]: return unit -def is_unit(name, value): - if not hasattr(get_unit(value), "inv"): - raise TypeError("invalid unit specified") - - def beta2_coef(beta2_coefficients): fac = 1e27 out = np.zeros_like(beta2_coefficients) diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index e204347..c502157 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -14,12 +14,12 @@ from scipy.interpolate import UnivariateSpline, interp1d from scgenerator import math from scgenerator.const import PARAM_SEPARATOR -from scgenerator.defaults import default_plotting as defaults +from scgenerator.env import default_plotting as defaults from scgenerator.math import abs2, linear_interp_2d, span from scgenerator.parameter import Parameters from scgenerator.physics import pulse, units from scgenerator.physics.units import PlotRange, sort_axis -from scgenerator.spectra import Propagation, Spectrum +from scgenerator.spectra import Spectrum RangeType = tuple[float, float, Union[str, Callable]] NO_LIM = object() diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 4abce9d..9c95ab8 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -150,6 +150,13 @@ class Spectrum(np.ndarray): def measure(self) -> tuple[float, float, float]: return pulse.measure_field(self.t, self.time_amp) + def coherence(self, axis: int = 0) -> np.ndarray: + """ + returns the coherence of the spectrum, computed by collapsing axis `axis` and aligned + on the wavelength grid + """ + return pulse.g12(self, axis)[..., self.wl_order] + freq_int = afreq_int freq_amp = afreq_amp diff --git a/src/scgenerator/threading.py b/src/scgenerator/threading.py index e505eb0..16e049c 100644 --- a/src/scgenerator/threading.py +++ b/src/scgenerator/threading.py @@ -1,11 +1,14 @@ -from multiprocessing import Process, Queue +from multiprocessing import Queue from threading import Thread + from tqdm import tqdm + class Multibar(Thread): queue: Queue - bars:list[tqdm] - def __init__(self, bars:list[tqdm], queue:Queue): + bars: list[tqdm] + + def __init__(self, bars: list[tqdm], queue: Queue): self.queue = queue self.bars = bars @@ -14,5 +17,3 @@ class Multibar(Thread): bar_id, amount = self.queue.get(True, None) self.bars[bar_id].update(amount) self.bars[0].update(amount) - - diff --git a/src/scgenerator/transform.py b/src/scgenerator/transform.py deleted file mode 100644 index 1ca8c45..0000000 --- a/src/scgenerator/transform.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np -import scgenerator.math as math -import scgenerator.physics.units as units - - -def normalize_range( - axis: np.ndarray, _range: tuple | units.PlotRange | None, num: int -) -> tuple[units.PlotRange, np.ndarray]: - if _range is None: - _range = units.PlotRange(axis.min(), axis.max(), units.no_unit) - elif not isinstance(_range, units.PlotRange): - _range = units.PlotRange(*_range) - new_axis = np.linspace(_range[0], _range[1], num) - return _range, new_axis - - -def prop_2d( - values: np.ndarray, - h_axis: np.ndarray, - v_axis: np.ndarray, - h_range: tuple | units.PlotRange | None = None, - v_range: tuple | units.PlotRange | None = None, - h_num: int = 1024, - v_num: int = 1024, - z_lim: tuple[float, float] | None = None, -): - if values.ndim != 2: - raise TypeError("prop_2d can only transform 2d data") - if np.iscomplexobj(values): - values = math.abs2(values) - - horizontal_range, horizontal = normalize_range(h_axis, h_range, h_num) - vertical_range, vertical = normalize_range(v_axis, v_range, v_num) - - values = math.interp_2d( - h_axis, v_axis, values, horizontal_range.unit(horizontal), vertical_range.unit(vertical) - ) - - if horizontal_range.must_correct_wl: - values = np.apply_along_axis( - lambda x: units.to_WL(x, horizontal_range.unit.to.m(horizontal)), 1, values - ) - elif vertical_range.must_correct_wl: - values = np.apply_along_axis( - lambda x: units.to_WL(x, vertical_range.unit.to.m(vertical)), 0, values - ) - - return horizontal, vertical, values diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py deleted file mode 100644 index 03151a9..0000000 --- a/src/scgenerator/utils.py +++ /dev/null @@ -1,297 +0,0 @@ -""" -This files includes utility functions designed more or less to be used specifically with the -scgenerator module but some function may be used in any python program - -""" -from __future__ import annotations - -import itertools -import json -import os -import re -import tomllib -from functools import cache, lru_cache -from pathlib import Path -from string import printable as str_printable -from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Union - -import numpy as np - -from scgenerator import io -from scgenerator.const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN -from scgenerator.logger import get_logger - -T_ = TypeVar("T_") - - -def conform_variable_entry(d) -> list[dict[str, list]]: - if isinstance(d, MutableMapping): - d = [{k: v} for k, v in d.items()] - return d - - -def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: - prev_data_dir = Path(prev_data_dir) - num = find_last_spectrum_num(prev_data_dir) - return load_spectrum(prev_data_dir / SPEC1_FN.format(num)) - - -@lru_cache(20000) -def load_spectrum(file: os.PathLike) -> np.ndarray: - return np.load(file) - - -def conform_toml_path(path: os.PathLike) -> Path: - path: str = str(path) - if not path.lower().endswith(".toml"): - path = path + ".toml" - return Path(path) - - -def open_single_config(path: os.PathLike) -> dict[str, Any]: - d = _open_config(path) - f = d.pop("Fiber", [{}])[0] - return d | f - - -def _open_config(path: os.PathLike): - """ - returns a dictionary parsed from the specified toml file - This also handle having a 'INCLUDE' argument that will fill - otherwise unspecified keys with what's in the INCLUDE file(s) - """ - - path = conform_toml_path(path) - dico = resolve_loadfile_arg(load_toml(path)) - - if "Fiber" not in dico: - dico = dict(name=path.name, Fiber=[dico]) - - resolve_relative_paths(dico, path.parent) - - return dico - - -def resolve_relative_paths(d: dict[str, Any], root: os.PathLike | None = None): - root = Path(root) if root is not None else Path.cwd() - for k, v in d.items(): - if isinstance(v, MutableMapping): - resolve_relative_paths(v, root) - elif not isinstance(v, str) and isinstance(v, Sequence): - for el in v: - if isinstance(el, MutableMapping): - resolve_relative_paths(el, root) - elif "file" in k: - d[k] = str(root / v) - - -def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]: - if (f_list := dico.pop("INCLUDE", None)) is not None: - if isinstance(f_list, str): - f_list = [f_list] - for to_load in f_list: - loaded = load_toml(to_load) - for k, v in loaded.items(): - if k not in dico and k not in dico.get("variable", {}): - dico[k] = v - for k, v in dico.items(): - if isinstance(v, MutableMapping): - dico[k] = resolve_loadfile_arg(v) - elif isinstance(v, Sequence): - for i, vv in enumerate(v): - if isinstance(vv, MutableMapping): - dico[k][i] = resolve_loadfile_arg(vv) - return dico - - -def load_toml(descr: os.PathLike) -> dict[str, Any]: - descr = str(descr) - if ":" in descr: - path, entry = descr.split(":", 1) - with open(path, "rb") as file: - return tomllib.load(file)[entry] - else: - with open(descr, "rb") as file: - return tomllib.load(file) - - -def load_flat(descr: os.PathLike) -> dict[str, Any]: - with open(descr, "rb") as file: - d = tomllib.load(file) - if "Fiber" in d: - for fib in d["Fiber"]: - for k, v in fib.items(): - d[k] = v - break - return d - - -@cache -def load_material_dico(name: str) -> dict[str, Any]: - """ - loads a material dictionary - - Parameters - ---------- - name : str - name of the material - - Returns - ------- - material_dico : dict - """ - return json.loads(io.data_file("materials.json").read_text())[name] - - -def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str): - """ - saves numpy array to disk - - Parameters - ---------- - data : Union[np.ndarray, MutableMapping] - data to save - file_name : str - file name - task_id : int - id that uniquely identifies the process - identifier : str, optional - identifier in the main data folder of the task, by default "" - """ - path = data_dir / file_name - if isinstance(data, np.ndarray): - np.save(path, data) - elif isinstance(data, MutableMapping): - np.savez(path, **data) - get_logger(__name__).debug(f"saved data in {path}") - return - - -def ensure_folder(path: Path, prevent_overwrite: bool = True, mkdir=True) -> Path: - """ - ensure a folder exists and doesn't overwrite anything if required - - Parameters - ---------- - path : Path - desired path - prevent_overwrite : bool, optional - whether to create a new directory when one already exists, by default True - - Returns - ------- - Path - final path - """ - - path = path.resolve() - - # is path root ? - if len(path.parts) < 2: - return path - - # is a part of path an existing *file* ? - parts = path.parts - path = Path(path.root) - for part in parts: - if path.is_file(): - path = ensure_folder(path, mkdir=mkdir, prevent_overwrite=False) - path /= part - - folder_name = path.name - - for i in itertools.count(): - if not path.is_file() and (not prevent_overwrite or not path.is_dir()): - if mkdir: - path.mkdir(exist_ok=True) - return path - path = path.parent / (folder_name + f"_{i}") - - -def branch_id(branch: Path) -> tuple[int, int]: - sim_match = branch.resolve().parent.name.split()[0] - if sim_match.isdigit(): - s_int = int(sim_match) - else: - s_int = 0 - branch_match = re.search(r"(?<=b_)[0-9]+", branch.name) - if branch_match is None: - b_int = 0 - else: - b_int = int(branch_match[0]) - return s_int, b_int - - -def find_last_spectrum_num(data_dir: Path): - for num in itertools.count(1): - p_to_test = data_dir / SPEC1_FN.format(num) - if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0: - return num - 1 - - -def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray: - threshold = y.min() + rel_thr * (y.max() - y.min()) - above_threshold = y > threshold - ind = np.argsort(x) - valid_ind = [ - np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k - ] - ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0] - width = len(ind_above) - return np.concatenate( - ( - np.arange(max(ind_above[0] - width, 0), ind_above[0]), - ind_above, - np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)), - ) - ) - - -def to_62(i: int) -> str: - arr = [] - if i == 0: - return "0" - i = abs(i) - while i: - i, value = divmod(i, 62) - arr.append(str_printable[value]) - return "".join(reversed(arr)) - - -def fft_functions( - full_field: bool, -) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: - if full_field: - return np.fft.rfft, np.fft.irfft - else: - return np.fft.fft, np.fft.ifft - - -def update_path_name(p: str) -> str: - return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p) - - -def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str: - return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name]) - - -def simulations_list(path: os.PathLike) -> list[Path]: - """ - finds simulations folders contained in a parent directory - - Parameters - ---------- - path : os.PathLike - parent path - - Returns - ------- - list[Path] - Absolute Path to the simulation folder - """ - paths: list[Path] = [] - for pwd, _, files in os.walk(path): - if PARAM_FN in files and SPEC1_FN.format(0) in files: - paths.append(Path(pwd)) - paths.sort(key=branch_id) - return [p for p in paths if p.parent.name == paths[-1].parent.name]