many changes

This commit is contained in:
Benoît Sierro
2022-12-12 16:08:03 +01:00
parent 256ff1d36e
commit aa225a0820
7 changed files with 163 additions and 74 deletions

View File

@@ -1,13 +1,13 @@
# flake8: noqa
from . import math, operators
from .evaluator import Evaluator
from .legacy import convert_sim_folder
from .math import abs2, argclosest, normalized, span, tspace, wspace
from .parameter import FileConfiguration, Parameters
from .physics import fiber, materials, pulse, simulate, units, plasma
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange
from .plotting import (
from scgenerator import math, operators
from scgenerator.evaluator import Evaluator
from scgenerator.legacy import convert_sim_folder
from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
from scgenerator.parameter import FileConfiguration, Parameters
from scgenerator.physics import fiber, materials, pulse, simulate, units, plasma
from scgenerator.physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from scgenerator.physics.units import PlotRange
from scgenerator.plotting import (
get_extent,
mean_values_plot,
plot_spectrogram,
@@ -17,6 +17,12 @@ from .plotting import (
transform_2D_propagation,
transform_mean_values,
)
from .spectra import SimulationSeries, Spectrum
from .utils import Paths, _open_config, open_single_config, simulations_list
from .variationer import DescriptorDict, VariationDescriptor, Variationer, VariationSpecsError
from scgenerator.spectra import SimulationSeries, Spectrum
from scgenerator.utils import Paths, _open_config, open_single_config, simulations_list
from scgenerator.variationer import (
DescriptorDict,
VariationDescriptor,
Variationer,
VariationSpecsError,
)
from scgenerator.helpers import *

View File

@@ -310,11 +310,12 @@ default_rules: list[Rule] = [
Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
# Pulse
Rule("field_0", pulse.finalize_pulse),
Rule(["input_time", "input_field"], pulse.load_custom_field),
Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4),
Rule("spec_0", utils.load_previous_spectrum, priorities=3),
*Rule.deduce(
["pre_field_0", "peak_power", "energy", "width"],
pulse.load_and_adjust_field_file,
pulse.adjust_custom_field,
["energy", "peak_power"],
1,
priorities=[2, 1, 1, 1],

View File

@@ -0,0 +1,53 @@
"""
series of helper functions
"""
from scgenerator.physics.materials import n_gas_2
from scgenerator.physics.fiber import n_eff_marcatili, beta2, beta2_to_D
from scgenerator.physics.units import c
import numpy as np
__all__ = ["capillary_dispersion"]
def capillary_dispersion(
wl: np.ndarray, radius: float, gas_name: str, pressure=None, temperature=None
) -> np.ndarray:
"""computes the dispersion (beta2) of a capillary tube
Parameters
----------
wl : np.ndarray
wavelength in m
radius : float
core radius in m
gas_name : str
gas name (case insensitive)
pressure : float, optional
pressure in Pa (multiply mbar by 100 to get Pa), by default atm pressure
temperature : float, optional
temperature in K, by default 20°C
Returns
-------
np.ndarray
D parameter
"""
wl = extend_axis(wl)
if pressure is None:
pressure = 101325
if temperature is None:
temperature = 293.15
n = n_eff_marcatili(wl, n_gas_2(wl, gas_name.lower(), pressure, temperature, False), radius)
w = 2 * np.pi * c / wl
return beta2(w, n)[2:-2]
def extend_axis(wl):
dwl_left = wl[1] - wl[0]
dwl_right = wl[-1] - wl[-2]
wl = np.concatenate(
([wl[0] - 2 * dwl_left, wl[0] - dwl_left], wl, [wl[-1] + dwl_right, wl[-1] + 2 * dwl_right])
)
return wl

View File

@@ -342,6 +342,8 @@ class Parameters:
# pulse
field_file: str = Parameter(string)
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
repetition_rate: float = Parameter(
non_negative(float, int), display_info=(1e-3, "kHz"), default=40e6
)

View File

@@ -160,13 +160,13 @@ def initial_field_envelope(t: np.ndarray, shape: str, t0: float, peak_power: flo
def modify_field_ratio(
t: np.ndarray,
field: np.ndarray,
target_power: float = None,
target_energy: float = None,
pre_field_0: np.ndarray,
peak_power: float = None,
energy: float = None,
intensity_noise: float = None,
noise_correlation: float = 0,
) -> float:
"""multiply a field by this number to get the desired effects
"""multiply a field by this number to get the desired specifications
Parameters
----------
@@ -185,10 +185,10 @@ def modify_field_ratio(
ratio (multiply field by this number)
"""
ratio = 1
if target_energy is not None:
ratio *= np.sqrt(target_energy / np.trapz(math.abs2(field), t))
elif target_power is not None:
ratio *= np.sqrt(target_power / math.abs2(field).max())
if energy is not None:
ratio *= np.sqrt(energy / np.trapz(math.abs2(pre_field_0), t))
elif peak_power is not None:
ratio *= np.sqrt(peak_power / math.abs2(pre_field_0).max())
if intensity_noise is not None:
d_int, _ = technical_noise(intensity_noise, noise_correlation)
@@ -351,15 +351,16 @@ def L_sol(L_D):
return pi / 2 * L_D
def load_and_adjust_field_file(
field_file: str,
def adjust_custom_field(
input_time: np.ndarray,
input_field: np.ndarray,
t: np.ndarray,
intensity_noise: float,
noise_correlation: float,
energy: float = None,
peak_power: float = None,
) -> np.ndarray:
field_0 = load_field_file(field_file, t)
field_0 = interp_custom_field(input_time, input_field, t)
if energy is not None:
curr_energy = np.trapz(math.abs2(field_0), t)
field_0 = field_0 * np.sqrt(energy / curr_energy)
@@ -367,7 +368,7 @@ def load_and_adjust_field_file(
ratio = np.sqrt(peak_power / math.abs2(field_0).max())
field_0 = field_0 * ratio
else:
raise ValueError(f"Not enough parameters specified to load {field_file} correctly")
raise ValueError("Not enough parameters specified to load custom field correctly")
field_0 = field_0 * modify_field_ratio(
t, field_0, peak_power, energy, intensity_noise, noise_correlation
@@ -376,15 +377,19 @@ def load_and_adjust_field_file(
return field_0, peak_power, energy, width
def load_field_file(field_file: str, t: np.ndarray) -> np.ndarray:
field_data = np.load(field_file)
field_interp = interp1d(
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
)
def interp_custom_field(
input_time: np.ndarray, input_field: np.ndarray, t: np.ndarray
) -> np.ndarray:
field_interp = interp1d(input_time, input_field, bounds_error=False, fill_value=(0, 0))
field_0 = field_interp(t)
return field_0
def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
field_data = np.load(field_file)
return field_data["time"], field_data["field"]
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
"""
finds a new wavelength parameter such that the maximum of the spectrum corresponding

View File

@@ -4,16 +4,17 @@ from typing import Any, Callable, Literal, Optional, Union
import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
import numba
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.colors import ListedColormap
from scipy.interpolate import UnivariateSpline
from scipy.interpolate.interpolate import interp1d
from matplotlib.transforms import offset_copy
from scipy.interpolate import UnivariateSpline, interp1d
from . import math
from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults
from .math import abs2, span, linear_interp_2d
from .math import abs2, linear_interp_2d, span
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange, sort_axis
@@ -49,7 +50,7 @@ def plot_setup(
file_type: str = "png",
figsize: tuple[float, float] = defaults["figsize"],
mode: Literal["default", "coherence", "coherence_T"] = "default",
) -> tuple[Path, plt.Figure, Union[plt.Axes, tuple[plt.Axes]]]:
) -> tuple[Path, Figure, Union[Axes, tuple[Axes]]]:
out_path = defaults["name"] if out_path is None else out_path
out_path = Path(out_path)
plot_name = out_path.name.replace(f".{file_type}", "")
@@ -260,7 +261,7 @@ def propagation_plot(
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
params: Parameters = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
ax: Union[Axes, tuple[Axes, Axes]] = None,
log: Union[int, float, bool, str] = "1D",
renormalize: bool = False,
vmin: float = None,
@@ -269,7 +270,7 @@ def propagation_plot(
skip: int = 1,
cbar_label: Optional[str] = "normalized intensity (dB)",
cmap: str = None,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
"""transforms and plots a 2D propagation
Parameters
@@ -294,7 +295,7 @@ def propagation_plot(
label of the colorbar. No colorbar is drawn if this is set to None, by default "normalized intensity (dB)"
cmap : str, optional
colormap, by default None
ax : plt.Axes, optional
ax : Axes, optional
Axes obj on which to draw, by default None
"""
@@ -325,7 +326,7 @@ def plot_2D(
values: np.ndarray,
x_axis: np.ndarray,
y_axis: np.ndarray,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
ax: Union[Axes, tuple[Axes, Axes]] = None,
x_label: str = None,
y_label: str = None,
vmin: float = None,
@@ -333,7 +334,7 @@ def plot_2D(
transpose: bool = False,
cmap: str = None,
cbar_label: str = "",
) -> Union[tuple[plt.Axes, plt.Axes], plt.Axes]:
) -> Union[tuple[Axes, Axes], Axes]:
"""plots given 2D values in a standard
Parameters
@@ -344,7 +345,7 @@ def plot_2D(
x axis
y_axis : np.ndarray, shape (m,)
y axis
ax : Union[plt.Axes, tuple[plt.Axes, plt.Axes]]
ax : Union[Axes, tuple[Axes, Axes]]
the ax on which to draw, or a tuple (ax, cbar_ax) where cbar_ax is the ax for the color bar
x_label : str, optional
x label
@@ -363,7 +364,7 @@ def plot_2D(
Returns
-------
Union[tuple[plt.Axes, plt.Axes], plt.Axes]
Union[tuple[Axes, Axes], Axes]
ax if no color bar is drawn, a tuple (ax, cbar_ax) otherwise
"""
# apply log transform if required
@@ -493,7 +494,7 @@ def mean_values_plot(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
params: Parameters,
ax: plt.Axes,
ax: Axes,
log: Union[float, int, str, bool] = False,
vmin: float = None,
vmax: float = None,
@@ -598,7 +599,7 @@ def plot_mean(
values: np.ndarray,
mean_values: np.ndarray,
x_axis: np.ndarray,
ax: plt.Axes,
ax: Axes,
x_label: str = None,
y_label: str = None,
line_labels: tuple[str, str] = None,
@@ -618,7 +619,7 @@ def plot_mean(
values to plot
x_axis : np.ndarray, shape (n,)
corresponding x axis
ax : plt.Axes
ax : Axes
ax on which to plot
x_label : str, optional
x label, by default None
@@ -662,7 +663,7 @@ def single_position_plot(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
x_axis: np.ndarray = None,
ax: plt.Axes = None,
ax: Axes = None,
params: Parameters = None,
log: Union[str, int, float, bool] = False,
vmin: float = None,
@@ -672,7 +673,7 @@ def single_position_plot(
renormalize: bool = False,
y_label: str = None,
**line_kwargs,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
if renormalize:
values = values / values.max()
@@ -689,7 +690,7 @@ def single_position_plot(
def plot_1D(
values: np.ndarray,
x_axis: np.ndarray,
ax: Optional[plt.Axes],
ax: Optional[Axes],
x_label: str = None,
y_label: str = None,
vmin: float = None,
@@ -705,7 +706,7 @@ def plot_1D(
values to plot
x_axis : np.ndarray, shape (n,)
corresponding x axis
ax : plt.Axes,
ax : Axes,
ax on which to plot
x_label : str, optional
x label
@@ -810,7 +811,7 @@ def plot_spectrogram(
vmax: float = 0,
cbar_label: str = "normalized intensity (dB)",
cmap: str = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None,
ax: Union[Axes, tuple[Axes, Axes]] = None,
):
"""Plots a spectrogram given a complex field in the time domain
Parameters
@@ -1050,7 +1051,7 @@ def arrowstyle(direction=1, color="white"):
def measure_and_annotate_fwhm(
ax: plt.Axes,
ax: Axes,
t: np.ndarray,
field: np.ndarray,
side: Literal["left", "right"] = "right",
@@ -1062,7 +1063,7 @@ def measure_and_annotate_fwhm(
Parameters
----------
ax : plt.Axes
ax : Axes
ax on which to plot
t : np.ndarray, shape (n,)
time in s
@@ -1093,25 +1094,46 @@ def measure_and_annotate_fwhm(
def annotate_fwhm(
ax, left, right, arrow_label, v_max=1, side="right", arrow_length_pts=20.0, arrow_props=None
ax: Axes,
left,
right,
arrow_label,
v_max=1,
side="right",
arrow_length_pts=20.0,
arrow_props=None,
color=None,
**annotate_kwargs,
):
arrow_dict = dict(arrowstyle="->")
if color:
arrow_dict |= dict(color=color)
annotate_kwargs |= dict(color=color)
text_kwargs = dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
if arrow_props is not None:
arrow_dict |= arrow_props
ax.annotate(
"" if side == "right" else arrow_label,
txt = {}
txt["left"] = ax.annotate(
"",
(left, v_max / 2),
xytext=(-arrow_length_pts, 0),
ha="right",
va="center",
textcoords="offset points",
arrowprops=arrow_dict,
)
ax.annotate(
"" if side == "left" else arrow_label,
txt["right"] = ax.annotate(
"",
(right, v_max / 2),
xytext=(arrow_length_pts, 0),
textcoords="offset points",
arrowprops=arrow_dict,
va="center",
)
if side == "right":
offset = arrow_length_pts
x, y = (right, v_max / 2)
else:
offset = -arrow_length_pts
x, y = (left, v_max / 2)
trans = offset_copy(
ax.transData, ax.get_figure(), offset, 0, "points"
)
ax.text(x, y, arrow_label, transform=trans, **text_kwargs)