Working on better multi-fiber sim

This commit is contained in:
Benoît Sierro
2021-07-22 10:06:13 +02:00
parent 1a681d8df8
commit 873d02c60f
3 changed files with 52 additions and 21 deletions

View File

@@ -186,6 +186,9 @@ class PlotRange:
right: float = Parameter(type_checker(int, float)) right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit) unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def beta2_coef(beta): def beta2_coef(beta):
fac = 1e27 fac = 1e27

View File

@@ -580,10 +580,9 @@ def plot_results_1D(
file_type: str = "pdf", file_type: str = "pdf",
file_name: str = None, file_name: str = None,
ax: plt.Axes = None, ax: plt.Axes = None,
line_label: str = None,
transpose: bool = False, transpose: bool = False,
**line_kwargs, **line_kwargs,
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]: ) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
""" """
Parameters Parameters
@@ -620,13 +619,15 @@ def plot_results_1D(
special name to give to the plot. A name is automatically assigned anyway special name to give to the plot. A name is automatically assigned anyway
ax : matplotlib.axes._subplots.AxesSubplot object, optional ax : matplotlib.axes._subplots.AxesSubplot object, optional
axis on which to draw the plot axis on which to draw the plot
line_label : str, optional
label of the line
transpose : bool, optional transpose : bool, optional
transpose the plot transpose the plot
line_kwargs : to be passed to plt.plot line_kwargs : to be passed to plt.plot
returns returns
fig, ax : matplotlib objects containing the plots fig, ax, line : matplotlib objects containing the plots
x_axis : np.ndarray
new x axis array
ind : np.ndarray
corresponding indices on the old axis
example: example:
if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such: if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such:
>>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017") >>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017")
@@ -674,19 +675,21 @@ def plot_results_1D(
is_new_plot = ax is None is_new_plot = ax is None
folder_name = "" folder_name = ""
if file_name is None:
file_name = params.name + str(plt_range)
if is_new_plot: if is_new_plot:
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
else: else:
fig = ax.get_figure() fig = ax.get_figure()
if transpose: if transpose:
ax.plot(values, x_axis, label=line_label, **line_kwargs) (line,) = ax.plot(values, x_axis, **line_kwargs)
ax.yaxis.tick_right() ax.yaxis.tick_right()
ax.yaxis.set_label_position("right") ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin) ax.set_xlim(vmax, vmin)
ax.set_xlabel(ylabel) ax.set_xlabel(ylabel)
ax.set_ylabel(plt_range.unit.label) ax.set_ylabel(plt_range.unit.label)
else: else:
ax.plot(x_axis, values, label=line_label, **line_kwargs) (line,) = ax.plot(x_axis, values, **line_kwargs)
ax.set_ylim(vmin, vmax) ax.set_ylim(vmin, vmax)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.set_xlabel(plt_range.unit.label) ax.set_xlabel(plt_range.unit.label)
@@ -694,7 +697,7 @@ def plot_results_1D(
if is_new_plot: if is_new_plot:
fig.savefig(out_path, bbox_inches="tight", dpi=200) fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}") print(f"plot saved in {out_path}")
return fig, ax, x_axis, values return fig, ax, line, x_axis, values
def _prep_plot( def _prep_plot(

View File

@@ -7,6 +7,8 @@ from matplotlib.pyplot import subplot
from dataclasses import replace from dataclasses import replace
import numpy as np import numpy as np
from numpy.lib import utils
from numpy.lib.arraysetops import isin
from tqdm.std import Bar from tqdm.std import Bar
from . import initialize, io, math from . import initialize, io, math
@@ -14,7 +16,7 @@ from .physics import units, pulse
from .const import SPECN_FN from .const import SPECN_FN
from .logger import get_logger from .logger import get_logger
from .plotting import plot_avg, plot_results_1D, plot_results_2D from .plotting import plot_avg, plot_results_1D, plot_results_2D
from .utils.parameter import BareParams from .utils.parameter import BareParams, validator_and
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
@@ -184,7 +186,7 @@ class Pulse(Sequence):
return self.nmax return self.nmax
def __getitem__(self, key) -> Spectrum: def __getitem__(self, key) -> Spectrum:
return self.all_spectra(ind=range(self.nmax)[key]).squeeze() return self.all_spectra(key)
def intensity(self, unit): def intensity(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]: if unit.type in ["WL", "FREQ", "AFREQ"]:
@@ -253,7 +255,7 @@ class Pulse(Sequence):
def _to_time_amp(self, spectrum): def _to_time_amp(self, spectrum):
return np.fft.ifft(spectrum) return np.fft.ifft(spectrum)
def all_spectra(self, ind=None) -> Spectrum: def all_spectra(self, ind) -> Spectrum:
""" """
loads the data already simulated. loads the data already simulated.
defauft shape is (z_targets, n, nt) defauft shape is (z_targets, n, nt)
@@ -280,6 +282,10 @@ class Pulse(Sequence):
ind = self.default_ind ind = self.default_ind
if isinstance(ind, (int, np.integer)): if isinstance(ind, (int, np.integer)):
ind = [ind] ind = [ind]
elif isinstance(ind, (float, np.floating)):
ind = [self.z_ind(ind)]
elif isinstance(ind[0], (float, np.floating)):
ind = [self.z_ind(ii) for ii in ind]
# Load the spectra # Load the spectra
spectra = [] spectra = []
@@ -312,11 +318,11 @@ class Pulse(Sequence):
left: float, left: float,
right: float, right: float,
unit: Union[Callable[[float], float], str], unit: Union[Callable[[float], float], str],
z_ind: Union[int, Iterable[int]] = None, z_pos: Union[int, Iterable[int]] = None,
sim_ind: int = 0, sim_ind: int = 0,
**kwargs, **kwargs,
): ):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
return plot_results_2D(vals, plt_range, self.params, **kwargs) return plot_results_2D(vals, plt_range, self.params, **kwargs)
def plot_1D( def plot_1D(
@@ -324,28 +330,47 @@ class Pulse(Sequence):
left: float, left: float,
right: float, right: float,
unit: Union[Callable[[float], float], str], unit: Union[Callable[[float], float], str],
z_ind: int, z_pos: int,
sim_ind: int = 0, sim_ind: int = 0,
**kwargs, **kwargs,
): ):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind) plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
return plot_results_1D(vals[0], plt_range, self.params, **kwargs) return plot_results_1D(vals, plt_range, self.params, **kwargs)
def plot_avg( def plot_avg(
self, self,
left: float, left: float,
right: float, right: float,
unit: Union[Callable[[float], float], str], unit: Union[Callable[[float], float], str],
z_ind: int, z_pos: int,
**kwargs, **kwargs,
): ):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None)) plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None))
return plot_avg(vals, plt_range, self.params, **kwargs) return plot_avg(vals, plt_range, self.params, **kwargs)
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind): def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind):
plt_range = units.PlotRange(left, right, unit) plt_range = units.PlotRange(left, right, unit)
if plt_range.unit.type == "TIME": if plt_range.unit.type == "TIME":
vals = self.all_fields(ind=z_ind)[:, sim_ind] vals = self.all_fields(ind=z_pos)
else: else:
vals = self.all_spectra(ind=z_ind)[:, sim_ind] vals = self.all_spectra(ind=z_pos)
if vals.ndim == 3:
vals = vals[:, sim_ind]
else:
vals = vals[sim_ind]
return plt_range, vals return plt_range, vals
def z_ind(self, z: float) -> int:
"""return the closest z index to the given target
Parameters
----------
z : float
target
Returns
-------
int
index
"""
return math.argclosest(self.z, z)