many changes

This commit is contained in:
Benoît Sierro
2022-12-12 16:08:03 +01:00
parent 256ff1d36e
commit aa225a0820
7 changed files with 163 additions and 74 deletions

View File

@@ -1,13 +1,13 @@
# flake8: noqa # flake8: noqa
from . import math, operators from scgenerator import math, operators
from .evaluator import Evaluator from scgenerator.evaluator import Evaluator
from .legacy import convert_sim_folder from scgenerator.legacy import convert_sim_folder
from .math import abs2, argclosest, normalized, span, tspace, wspace from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace
from .parameter import FileConfiguration, Parameters from scgenerator.parameter import FileConfiguration, Parameters
from .physics import fiber, materials, pulse, simulate, units, plasma from scgenerator.physics import fiber, materials, pulse, simulate, units, plasma
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from scgenerator.physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .physics.units import PlotRange from scgenerator.physics.units import PlotRange
from .plotting import ( from scgenerator.plotting import (
get_extent, get_extent,
mean_values_plot, mean_values_plot,
plot_spectrogram, plot_spectrogram,
@@ -17,6 +17,12 @@ from .plotting import (
transform_2D_propagation, transform_2D_propagation,
transform_mean_values, transform_mean_values,
) )
from .spectra import SimulationSeries, Spectrum from scgenerator.spectra import SimulationSeries, Spectrum
from .utils import Paths, _open_config, open_single_config, simulations_list from scgenerator.utils import Paths, _open_config, open_single_config, simulations_list
from .variationer import DescriptorDict, VariationDescriptor, Variationer, VariationSpecsError from scgenerator.variationer import (
DescriptorDict,
VariationDescriptor,
Variationer,
VariationSpecsError,
)
from scgenerator.helpers import *

View File

@@ -310,11 +310,12 @@ default_rules: list[Rule] = [
Rule(["fft", "ifft"], utils.fft_functions, priorities=1), Rule(["fft", "ifft"], utils.fft_functions, priorities=1),
# Pulse # Pulse
Rule("field_0", pulse.finalize_pulse), Rule("field_0", pulse.finalize_pulse),
Rule(["input_time", "input_field"], pulse.load_custom_field),
Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4), Rule("spec_0", utils.load_previous_spectrum, ["recovery_data_dir"], priorities=4),
Rule("spec_0", utils.load_previous_spectrum, priorities=3), Rule("spec_0", utils.load_previous_spectrum, priorities=3),
*Rule.deduce( *Rule.deduce(
["pre_field_0", "peak_power", "energy", "width"], ["pre_field_0", "peak_power", "energy", "width"],
pulse.load_and_adjust_field_file, pulse.adjust_custom_field,
["energy", "peak_power"], ["energy", "peak_power"],
1, 1,
priorities=[2, 1, 1, 1], priorities=[2, 1, 1, 1],

View File

@@ -0,0 +1,53 @@
"""
series of helper functions
"""
from scgenerator.physics.materials import n_gas_2
from scgenerator.physics.fiber import n_eff_marcatili, beta2, beta2_to_D
from scgenerator.physics.units import c
import numpy as np
__all__ = ["capillary_dispersion"]
def capillary_dispersion(
wl: np.ndarray, radius: float, gas_name: str, pressure=None, temperature=None
) -> np.ndarray:
"""computes the dispersion (beta2) of a capillary tube
Parameters
----------
wl : np.ndarray
wavelength in m
radius : float
core radius in m
gas_name : str
gas name (case insensitive)
pressure : float, optional
pressure in Pa (multiply mbar by 100 to get Pa), by default atm pressure
temperature : float, optional
temperature in K, by default 20°C
Returns
-------
np.ndarray
D parameter
"""
wl = extend_axis(wl)
if pressure is None:
pressure = 101325
if temperature is None:
temperature = 293.15
n = n_eff_marcatili(wl, n_gas_2(wl, gas_name.lower(), pressure, temperature, False), radius)
w = 2 * np.pi * c / wl
return beta2(w, n)[2:-2]
def extend_axis(wl):
dwl_left = wl[1] - wl[0]
dwl_right = wl[-1] - wl[-2]
wl = np.concatenate(
([wl[0] - 2 * dwl_left, wl[0] - dwl_left], wl, [wl[-1] + dwl_right, wl[-1] + 2 * dwl_right])
)
return wl

View File

@@ -342,6 +342,8 @@ class Parameters:
# pulse # pulse
field_file: str = Parameter(string) field_file: str = Parameter(string)
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
repetition_rate: float = Parameter( repetition_rate: float = Parameter(
non_negative(float, int), display_info=(1e-3, "kHz"), default=40e6 non_negative(float, int), display_info=(1e-3, "kHz"), default=40e6
) )

View File

@@ -160,13 +160,13 @@ def initial_field_envelope(t: np.ndarray, shape: str, t0: float, peak_power: flo
def modify_field_ratio( def modify_field_ratio(
t: np.ndarray, t: np.ndarray,
field: np.ndarray, pre_field_0: np.ndarray,
target_power: float = None, peak_power: float = None,
target_energy: float = None, energy: float = None,
intensity_noise: float = None, intensity_noise: float = None,
noise_correlation: float = 0, noise_correlation: float = 0,
) -> float: ) -> float:
"""multiply a field by this number to get the desired effects """multiply a field by this number to get the desired specifications
Parameters Parameters
---------- ----------
@@ -185,10 +185,10 @@ def modify_field_ratio(
ratio (multiply field by this number) ratio (multiply field by this number)
""" """
ratio = 1 ratio = 1
if target_energy is not None: if energy is not None:
ratio *= np.sqrt(target_energy / np.trapz(math.abs2(field), t)) ratio *= np.sqrt(energy / np.trapz(math.abs2(pre_field_0), t))
elif target_power is not None: elif peak_power is not None:
ratio *= np.sqrt(target_power / math.abs2(field).max()) ratio *= np.sqrt(peak_power / math.abs2(pre_field_0).max())
if intensity_noise is not None: if intensity_noise is not None:
d_int, _ = technical_noise(intensity_noise, noise_correlation) d_int, _ = technical_noise(intensity_noise, noise_correlation)
@@ -351,15 +351,16 @@ def L_sol(L_D):
return pi / 2 * L_D return pi / 2 * L_D
def load_and_adjust_field_file( def adjust_custom_field(
field_file: str, input_time: np.ndarray,
input_field: np.ndarray,
t: np.ndarray, t: np.ndarray,
intensity_noise: float, intensity_noise: float,
noise_correlation: float, noise_correlation: float,
energy: float = None, energy: float = None,
peak_power: float = None, peak_power: float = None,
) -> np.ndarray: ) -> np.ndarray:
field_0 = load_field_file(field_file, t) field_0 = interp_custom_field(input_time, input_field, t)
if energy is not None: if energy is not None:
curr_energy = np.trapz(math.abs2(field_0), t) curr_energy = np.trapz(math.abs2(field_0), t)
field_0 = field_0 * np.sqrt(energy / curr_energy) field_0 = field_0 * np.sqrt(energy / curr_energy)
@@ -367,7 +368,7 @@ def load_and_adjust_field_file(
ratio = np.sqrt(peak_power / math.abs2(field_0).max()) ratio = np.sqrt(peak_power / math.abs2(field_0).max())
field_0 = field_0 * ratio field_0 = field_0 * ratio
else: else:
raise ValueError(f"Not enough parameters specified to load {field_file} correctly") raise ValueError("Not enough parameters specified to load custom field correctly")
field_0 = field_0 * modify_field_ratio( field_0 = field_0 * modify_field_ratio(
t, field_0, peak_power, energy, intensity_noise, noise_correlation t, field_0, peak_power, energy, intensity_noise, noise_correlation
@@ -376,15 +377,19 @@ def load_and_adjust_field_file(
return field_0, peak_power, energy, width return field_0, peak_power, energy, width
def load_field_file(field_file: str, t: np.ndarray) -> np.ndarray: def interp_custom_field(
field_data = np.load(field_file) input_time: np.ndarray, input_field: np.ndarray, t: np.ndarray
field_interp = interp1d( ) -> np.ndarray:
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) field_interp = interp1d(input_time, input_field, bounds_error=False, fill_value=(0, 0))
)
field_0 = field_interp(t) field_0 = field_interp(t)
return field_0 return field_0
def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
field_data = np.load(field_file)
return field_data["time"], field_data["field"]
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float: def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
""" """
finds a new wavelength parameter such that the maximum of the spectrum corresponding finds a new wavelength parameter such that the maximum of the spectrum corresponding

View File

@@ -4,16 +4,17 @@ from typing import Any, Callable, Literal, Optional, Union
import matplotlib.gridspec as gs import matplotlib.gridspec as gs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numba
import numpy as np import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from scipy.interpolate import UnivariateSpline from matplotlib.transforms import offset_copy
from scipy.interpolate.interpolate import interp1d from scipy.interpolate import UnivariateSpline, interp1d
from . import math from . import math
from .const import PARAM_SEPARATOR from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults from .defaults import default_plotting as defaults
from .math import abs2, span, linear_interp_2d from .math import abs2, linear_interp_2d, span
from .parameter import Parameters from .parameter import Parameters
from .physics import pulse, units from .physics import pulse, units
from .physics.units import PlotRange, sort_axis from .physics.units import PlotRange, sort_axis
@@ -49,7 +50,7 @@ def plot_setup(
file_type: str = "png", file_type: str = "png",
figsize: tuple[float, float] = defaults["figsize"], figsize: tuple[float, float] = defaults["figsize"],
mode: Literal["default", "coherence", "coherence_T"] = "default", mode: Literal["default", "coherence", "coherence_T"] = "default",
) -> tuple[Path, plt.Figure, Union[plt.Axes, tuple[plt.Axes]]]: ) -> tuple[Path, Figure, Union[Axes, tuple[Axes]]]:
out_path = defaults["name"] if out_path is None else out_path out_path = defaults["name"] if out_path is None else out_path
out_path = Path(out_path) out_path = Path(out_path)
plot_name = out_path.name.replace(f".{file_type}", "") plot_name = out_path.name.replace(f".{file_type}", "")
@@ -260,7 +261,7 @@ def propagation_plot(
x_axis: np.ndarray = None, x_axis: np.ndarray = None,
y_axis: np.ndarray = None, y_axis: np.ndarray = None,
params: Parameters = None, params: Parameters = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, ax: Union[Axes, tuple[Axes, Axes]] = None,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
renormalize: bool = False, renormalize: bool = False,
vmin: float = None, vmin: float = None,
@@ -269,7 +270,7 @@ def propagation_plot(
skip: int = 1, skip: int = 1,
cbar_label: Optional[str] = "normalized intensity (dB)", cbar_label: Optional[str] = "normalized intensity (dB)",
cmap: str = None, cmap: str = None,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: ) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
"""transforms and plots a 2D propagation """transforms and plots a 2D propagation
Parameters Parameters
@@ -294,7 +295,7 @@ def propagation_plot(
label of the colorbar. No colorbar is drawn if this is set to None, by default "normalized intensity (dB)" label of the colorbar. No colorbar is drawn if this is set to None, by default "normalized intensity (dB)"
cmap : str, optional cmap : str, optional
colormap, by default None colormap, by default None
ax : plt.Axes, optional ax : Axes, optional
Axes obj on which to draw, by default None Axes obj on which to draw, by default None
""" """
@@ -325,7 +326,7 @@ def plot_2D(
values: np.ndarray, values: np.ndarray,
x_axis: np.ndarray, x_axis: np.ndarray,
y_axis: np.ndarray, y_axis: np.ndarray,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, ax: Union[Axes, tuple[Axes, Axes]] = None,
x_label: str = None, x_label: str = None,
y_label: str = None, y_label: str = None,
vmin: float = None, vmin: float = None,
@@ -333,7 +334,7 @@ def plot_2D(
transpose: bool = False, transpose: bool = False,
cmap: str = None, cmap: str = None,
cbar_label: str = "", cbar_label: str = "",
) -> Union[tuple[plt.Axes, plt.Axes], plt.Axes]: ) -> Union[tuple[Axes, Axes], Axes]:
"""plots given 2D values in a standard """plots given 2D values in a standard
Parameters Parameters
@@ -344,7 +345,7 @@ def plot_2D(
x axis x axis
y_axis : np.ndarray, shape (m,) y_axis : np.ndarray, shape (m,)
y axis y axis
ax : Union[plt.Axes, tuple[plt.Axes, plt.Axes]] ax : Union[Axes, tuple[Axes, Axes]]
the ax on which to draw, or a tuple (ax, cbar_ax) where cbar_ax is the ax for the color bar the ax on which to draw, or a tuple (ax, cbar_ax) where cbar_ax is the ax for the color bar
x_label : str, optional x_label : str, optional
x label x label
@@ -363,7 +364,7 @@ def plot_2D(
Returns Returns
------- -------
Union[tuple[plt.Axes, plt.Axes], plt.Axes] Union[tuple[Axes, Axes], Axes]
ax if no color bar is drawn, a tuple (ax, cbar_ax) otherwise ax if no color bar is drawn, a tuple (ax, cbar_ax) otherwise
""" """
# apply log transform if required # apply log transform if required
@@ -493,7 +494,7 @@ def mean_values_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
params: Parameters, params: Parameters,
ax: plt.Axes, ax: Axes,
log: Union[float, int, str, bool] = False, log: Union[float, int, str, bool] = False,
vmin: float = None, vmin: float = None,
vmax: float = None, vmax: float = None,
@@ -598,7 +599,7 @@ def plot_mean(
values: np.ndarray, values: np.ndarray,
mean_values: np.ndarray, mean_values: np.ndarray,
x_axis: np.ndarray, x_axis: np.ndarray,
ax: plt.Axes, ax: Axes,
x_label: str = None, x_label: str = None,
y_label: str = None, y_label: str = None,
line_labels: tuple[str, str] = None, line_labels: tuple[str, str] = None,
@@ -618,7 +619,7 @@ def plot_mean(
values to plot values to plot
x_axis : np.ndarray, shape (n,) x_axis : np.ndarray, shape (n,)
corresponding x axis corresponding x axis
ax : plt.Axes ax : Axes
ax on which to plot ax on which to plot
x_label : str, optional x_label : str, optional
x label, by default None x label, by default None
@@ -662,7 +663,7 @@ def single_position_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[PlotRange, RangeType], plt_range: Union[PlotRange, RangeType],
x_axis: np.ndarray = None, x_axis: np.ndarray = None,
ax: plt.Axes = None, ax: Axes = None,
params: Parameters = None, params: Parameters = None,
log: Union[str, int, float, bool] = False, log: Union[str, int, float, bool] = False,
vmin: float = None, vmin: float = None,
@@ -672,7 +673,7 @@ def single_position_plot(
renormalize: bool = False, renormalize: bool = False,
y_label: str = None, y_label: str = None,
**line_kwargs, **line_kwargs,
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]: ) -> tuple[Figure, Axes, plt.Line2D, np.ndarray, np.ndarray]:
x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing) x_axis, values = transform_1D_values(values, plt_range, x_axis, params, log, spacing)
if renormalize: if renormalize:
values = values / values.max() values = values / values.max()
@@ -689,7 +690,7 @@ def single_position_plot(
def plot_1D( def plot_1D(
values: np.ndarray, values: np.ndarray,
x_axis: np.ndarray, x_axis: np.ndarray,
ax: Optional[plt.Axes], ax: Optional[Axes],
x_label: str = None, x_label: str = None,
y_label: str = None, y_label: str = None,
vmin: float = None, vmin: float = None,
@@ -705,7 +706,7 @@ def plot_1D(
values to plot values to plot
x_axis : np.ndarray, shape (n,) x_axis : np.ndarray, shape (n,)
corresponding x axis corresponding x axis
ax : plt.Axes, ax : Axes,
ax on which to plot ax on which to plot
x_label : str, optional x_label : str, optional
x label x label
@@ -810,7 +811,7 @@ def plot_spectrogram(
vmax: float = 0, vmax: float = 0,
cbar_label: str = "normalized intensity (dB)", cbar_label: str = "normalized intensity (dB)",
cmap: str = None, cmap: str = None,
ax: Union[plt.Axes, tuple[plt.Axes, plt.Axes]] = None, ax: Union[Axes, tuple[Axes, Axes]] = None,
): ):
"""Plots a spectrogram given a complex field in the time domain """Plots a spectrogram given a complex field in the time domain
Parameters Parameters
@@ -1050,7 +1051,7 @@ def arrowstyle(direction=1, color="white"):
def measure_and_annotate_fwhm( def measure_and_annotate_fwhm(
ax: plt.Axes, ax: Axes,
t: np.ndarray, t: np.ndarray,
field: np.ndarray, field: np.ndarray,
side: Literal["left", "right"] = "right", side: Literal["left", "right"] = "right",
@@ -1062,7 +1063,7 @@ def measure_and_annotate_fwhm(
Parameters Parameters
---------- ----------
ax : plt.Axes ax : Axes
ax on which to plot ax on which to plot
t : np.ndarray, shape (n,) t : np.ndarray, shape (n,)
time in s time in s
@@ -1093,25 +1094,46 @@ def measure_and_annotate_fwhm(
def annotate_fwhm( def annotate_fwhm(
ax, left, right, arrow_label, v_max=1, side="right", arrow_length_pts=20.0, arrow_props=None ax: Axes,
left,
right,
arrow_label,
v_max=1,
side="right",
arrow_length_pts=20.0,
arrow_props=None,
color=None,
**annotate_kwargs,
): ):
arrow_dict = dict(arrowstyle="->") arrow_dict = dict(arrowstyle="->")
if color:
arrow_dict |= dict(color=color)
annotate_kwargs |= dict(color=color)
text_kwargs = dict(ha="right" if side == "left" else "left", va="center") | annotate_kwargs
if arrow_props is not None: if arrow_props is not None:
arrow_dict |= arrow_props arrow_dict |= arrow_props
ax.annotate( txt = {}
"" if side == "right" else arrow_label, txt["left"] = ax.annotate(
"",
(left, v_max / 2), (left, v_max / 2),
xytext=(-arrow_length_pts, 0), xytext=(-arrow_length_pts, 0),
ha="right",
va="center",
textcoords="offset points", textcoords="offset points",
arrowprops=arrow_dict, arrowprops=arrow_dict,
) )
ax.annotate( txt["right"] = ax.annotate(
"" if side == "left" else arrow_label, "",
(right, v_max / 2), (right, v_max / 2),
xytext=(arrow_length_pts, 0), xytext=(arrow_length_pts, 0),
textcoords="offset points", textcoords="offset points",
arrowprops=arrow_dict, arrowprops=arrow_dict,
va="center",
) )
if side == "right":
offset = arrow_length_pts
x, y = (right, v_max / 2)
else:
offset = -arrow_length_pts
x, y = (left, v_max / 2)
trans = offset_copy(
ax.transData, ax.get_figure(), offset, 0, "points"
)
ax.text(x, y, arrow_label, transform=trans, **text_kwargs)