Working on better multi-fiber sim
This commit is contained in:
@@ -186,6 +186,9 @@ class PlotRange:
|
||||
right: float = Parameter(type_checker(int, float))
|
||||
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||
|
||||
|
||||
def beta2_coef(beta):
|
||||
fac = 1e27
|
||||
|
||||
@@ -580,10 +580,9 @@ def plot_results_1D(
|
||||
file_type: str = "pdf",
|
||||
file_name: str = None,
|
||||
ax: plt.Axes = None,
|
||||
line_label: str = None,
|
||||
transpose: bool = False,
|
||||
**line_kwargs,
|
||||
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
|
||||
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -620,13 +619,15 @@ def plot_results_1D(
|
||||
special name to give to the plot. A name is automatically assigned anyway
|
||||
ax : matplotlib.axes._subplots.AxesSubplot object, optional
|
||||
axis on which to draw the plot
|
||||
line_label : str, optional
|
||||
label of the line
|
||||
transpose : bool, optional
|
||||
transpose the plot
|
||||
line_kwargs : to be passed to plt.plot
|
||||
returns
|
||||
fig, ax : matplotlib objects containing the plots
|
||||
fig, ax, line : matplotlib objects containing the plots
|
||||
x_axis : np.ndarray
|
||||
new x axis array
|
||||
ind : np.ndarray
|
||||
corresponding indices on the old axis
|
||||
example:
|
||||
if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such:
|
||||
>>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017")
|
||||
@@ -674,19 +675,21 @@ def plot_results_1D(
|
||||
is_new_plot = ax is None
|
||||
|
||||
folder_name = ""
|
||||
if file_name is None:
|
||||
file_name = params.name + str(plt_range)
|
||||
if is_new_plot:
|
||||
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
if transpose:
|
||||
ax.plot(values, x_axis, label=line_label, **line_kwargs)
|
||||
(line,) = ax.plot(values, x_axis, **line_kwargs)
|
||||
ax.yaxis.tick_right()
|
||||
ax.yaxis.set_label_position("right")
|
||||
ax.set_xlim(vmax, vmin)
|
||||
ax.set_xlabel(ylabel)
|
||||
ax.set_ylabel(plt_range.unit.label)
|
||||
else:
|
||||
ax.plot(x_axis, values, label=line_label, **line_kwargs)
|
||||
(line,) = ax.plot(x_axis, values, **line_kwargs)
|
||||
ax.set_ylim(vmin, vmax)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_xlabel(plt_range.unit.label)
|
||||
@@ -694,7 +697,7 @@ def plot_results_1D(
|
||||
if is_new_plot:
|
||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||
print(f"plot saved in {out_path}")
|
||||
return fig, ax, x_axis, values
|
||||
return fig, ax, line, x_axis, values
|
||||
|
||||
|
||||
def _prep_plot(
|
||||
|
||||
@@ -7,6 +7,8 @@ from matplotlib.pyplot import subplot
|
||||
from dataclasses import replace
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib import utils
|
||||
from numpy.lib.arraysetops import isin
|
||||
from tqdm.std import Bar
|
||||
|
||||
from . import initialize, io, math
|
||||
@@ -14,7 +16,7 @@ from .physics import units, pulse
|
||||
from .const import SPECN_FN
|
||||
from .logger import get_logger
|
||||
from .plotting import plot_avg, plot_results_1D, plot_results_2D
|
||||
from .utils.parameter import BareParams
|
||||
from .utils.parameter import BareParams, validator_and
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
@@ -184,7 +186,7 @@ class Pulse(Sequence):
|
||||
return self.nmax
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
return self.all_spectra(ind=range(self.nmax)[key]).squeeze()
|
||||
return self.all_spectra(key)
|
||||
|
||||
def intensity(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
@@ -253,7 +255,7 @@ class Pulse(Sequence):
|
||||
def _to_time_amp(self, spectrum):
|
||||
return np.fft.ifft(spectrum)
|
||||
|
||||
def all_spectra(self, ind=None) -> Spectrum:
|
||||
def all_spectra(self, ind) -> Spectrum:
|
||||
"""
|
||||
loads the data already simulated.
|
||||
defauft shape is (z_targets, n, nt)
|
||||
@@ -280,6 +282,10 @@ class Pulse(Sequence):
|
||||
ind = self.default_ind
|
||||
if isinstance(ind, (int, np.integer)):
|
||||
ind = [ind]
|
||||
elif isinstance(ind, (float, np.floating)):
|
||||
ind = [self.z_ind(ind)]
|
||||
elif isinstance(ind[0], (float, np.floating)):
|
||||
ind = [self.z_ind(ii) for ii in ind]
|
||||
|
||||
# Load the spectra
|
||||
spectra = []
|
||||
@@ -312,11 +318,11 @@ class Pulse(Sequence):
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: Union[int, Iterable[int]] = None,
|
||||
z_pos: Union[int, Iterable[int]] = None,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
|
||||
return plot_results_2D(vals, plt_range, self.params, **kwargs)
|
||||
|
||||
def plot_1D(
|
||||
@@ -324,28 +330,47 @@ class Pulse(Sequence):
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: int,
|
||||
z_pos: int,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
||||
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
|
||||
return plot_results_1D(vals, plt_range, self.params, **kwargs)
|
||||
|
||||
def plot_avg(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_ind: int,
|
||||
z_pos: int,
|
||||
**kwargs,
|
||||
):
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
|
||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None))
|
||||
return plot_avg(vals, plt_range, self.params, **kwargs)
|
||||
|
||||
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
|
||||
def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind):
|
||||
plt_range = units.PlotRange(left, right, unit)
|
||||
if plt_range.unit.type == "TIME":
|
||||
vals = self.all_fields(ind=z_ind)[:, sim_ind]
|
||||
vals = self.all_fields(ind=z_pos)
|
||||
else:
|
||||
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
|
||||
vals = self.all_spectra(ind=z_pos)
|
||||
if vals.ndim == 3:
|
||||
vals = vals[:, sim_ind]
|
||||
else:
|
||||
vals = vals[sim_ind]
|
||||
return plt_range, vals
|
||||
|
||||
def z_ind(self, z: float) -> int:
|
||||
"""return the closest z index to the given target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : float
|
||||
target
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
index
|
||||
"""
|
||||
return math.argclosest(self.z, z)
|
||||
|
||||
Reference in New Issue
Block a user