Working on better multi-fiber sim

This commit is contained in:
Benoît Sierro
2021-07-22 10:06:13 +02:00
parent 1a681d8df8
commit 873d02c60f
3 changed files with 52 additions and 21 deletions

View File

@@ -186,6 +186,9 @@ class PlotRange:
right: float = Parameter(type_checker(int, float))
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
def beta2_coef(beta):
fac = 1e27

View File

@@ -580,10 +580,9 @@ def plot_results_1D(
file_type: str = "pdf",
file_name: str = None,
ax: plt.Axes = None,
line_label: str = None,
transpose: bool = False,
**line_kwargs,
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
"""
Parameters
@@ -620,13 +619,15 @@ def plot_results_1D(
special name to give to the plot. A name is automatically assigned anyway
ax : matplotlib.axes._subplots.AxesSubplot object, optional
axis on which to draw the plot
line_label : str, optional
label of the line
transpose : bool, optional
transpose the plot
line_kwargs : to be passed to plt.plot
returns
fig, ax : matplotlib objects containing the plots
fig, ax, line : matplotlib objects containing the plots
x_axis : np.ndarray
new x axis array
ind : np.ndarray
corresponding indices on the old axis
example:
if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such:
>>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017")
@@ -674,19 +675,21 @@ def plot_results_1D(
is_new_plot = ax is None
folder_name = ""
if file_name is None:
file_name = params.name + str(plt_range)
if is_new_plot:
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
else:
fig = ax.get_figure()
if transpose:
ax.plot(values, x_axis, label=line_label, **line_kwargs)
(line,) = ax.plot(values, x_axis, **line_kwargs)
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
ax.set_xlim(vmax, vmin)
ax.set_xlabel(ylabel)
ax.set_ylabel(plt_range.unit.label)
else:
ax.plot(x_axis, values, label=line_label, **line_kwargs)
(line,) = ax.plot(x_axis, values, **line_kwargs)
ax.set_ylim(vmin, vmax)
ax.set_ylabel(ylabel)
ax.set_xlabel(plt_range.unit.label)
@@ -694,7 +697,7 @@ def plot_results_1D(
if is_new_plot:
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")
return fig, ax, x_axis, values
return fig, ax, line, x_axis, values
def _prep_plot(

View File

@@ -7,6 +7,8 @@ from matplotlib.pyplot import subplot
from dataclasses import replace
import numpy as np
from numpy.lib import utils
from numpy.lib.arraysetops import isin
from tqdm.std import Bar
from . import initialize, io, math
@@ -14,7 +16,7 @@ from .physics import units, pulse
from .const import SPECN_FN
from .logger import get_logger
from .plotting import plot_avg, plot_results_1D, plot_results_2D
from .utils.parameter import BareParams
from .utils.parameter import BareParams, validator_and
class Spectrum(np.ndarray):
@@ -184,7 +186,7 @@ class Pulse(Sequence):
return self.nmax
def __getitem__(self, key) -> Spectrum:
return self.all_spectra(ind=range(self.nmax)[key]).squeeze()
return self.all_spectra(key)
def intensity(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
@@ -253,7 +255,7 @@ class Pulse(Sequence):
def _to_time_amp(self, spectrum):
return np.fft.ifft(spectrum)
def all_spectra(self, ind=None) -> Spectrum:
def all_spectra(self, ind) -> Spectrum:
"""
loads the data already simulated.
defauft shape is (z_targets, n, nt)
@@ -280,6 +282,10 @@ class Pulse(Sequence):
ind = self.default_ind
if isinstance(ind, (int, np.integer)):
ind = [ind]
elif isinstance(ind, (float, np.floating)):
ind = [self.z_ind(ind)]
elif isinstance(ind[0], (float, np.floating)):
ind = [self.z_ind(ii) for ii in ind]
# Load the spectra
spectra = []
@@ -312,11 +318,11 @@ class Pulse(Sequence):
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: Union[int, Iterable[int]] = None,
z_pos: Union[int, Iterable[int]] = None,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
return plot_results_2D(vals, plt_range, self.params, **kwargs)
def plot_1D(
@@ -324,28 +330,47 @@ class Pulse(Sequence):
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
z_pos: int,
sim_ind: int = 0,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
return plot_results_1D(vals, plt_range, self.params, **kwargs)
def plot_avg(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_ind: int,
z_pos: int,
**kwargs,
):
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None))
return plot_avg(vals, plt_range, self.params, **kwargs)
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind):
plt_range = units.PlotRange(left, right, unit)
if plt_range.unit.type == "TIME":
vals = self.all_fields(ind=z_ind)[:, sim_ind]
vals = self.all_fields(ind=z_pos)
else:
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
vals = self.all_spectra(ind=z_pos)
if vals.ndim == 3:
vals = vals[:, sim_ind]
else:
vals = vals[sim_ind]
return plt_range, vals
def z_ind(self, z: float) -> int:
"""return the closest z index to the given target
Parameters
----------
z : float
target
Returns
-------
int
index
"""
return math.argclosest(self.z, z)