Working on better multi-fiber sim
This commit is contained in:
@@ -186,6 +186,9 @@ class PlotRange:
|
|||||||
right: float = Parameter(type_checker(int, float))
|
right: float = Parameter(type_checker(int, float))
|
||||||
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||||
|
|
||||||
|
|
||||||
def beta2_coef(beta):
|
def beta2_coef(beta):
|
||||||
fac = 1e27
|
fac = 1e27
|
||||||
|
|||||||
@@ -580,10 +580,9 @@ def plot_results_1D(
|
|||||||
file_type: str = "pdf",
|
file_type: str = "pdf",
|
||||||
file_name: str = None,
|
file_name: str = None,
|
||||||
ax: plt.Axes = None,
|
ax: plt.Axes = None,
|
||||||
line_label: str = None,
|
|
||||||
transpose: bool = False,
|
transpose: bool = False,
|
||||||
**line_kwargs,
|
**line_kwargs,
|
||||||
) -> tuple[plt.Figure, plt.Axes, np.ndarray, np.ndarray]:
|
) -> tuple[plt.Figure, plt.Axes, plt.Line2D, np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -620,13 +619,15 @@ def plot_results_1D(
|
|||||||
special name to give to the plot. A name is automatically assigned anyway
|
special name to give to the plot. A name is automatically assigned anyway
|
||||||
ax : matplotlib.axes._subplots.AxesSubplot object, optional
|
ax : matplotlib.axes._subplots.AxesSubplot object, optional
|
||||||
axis on which to draw the plot
|
axis on which to draw the plot
|
||||||
line_label : str, optional
|
|
||||||
label of the line
|
|
||||||
transpose : bool, optional
|
transpose : bool, optional
|
||||||
transpose the plot
|
transpose the plot
|
||||||
line_kwargs : to be passed to plt.plot
|
line_kwargs : to be passed to plt.plot
|
||||||
returns
|
returns
|
||||||
fig, ax : matplotlib objects containing the plots
|
fig, ax, line : matplotlib objects containing the plots
|
||||||
|
x_axis : np.ndarray
|
||||||
|
new x axis array
|
||||||
|
ind : np.ndarray
|
||||||
|
corresponding indices on the old axis
|
||||||
example:
|
example:
|
||||||
if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such:
|
if spectra is a (m, n, nt) array, one can plot a spectrum evolution as such:
|
||||||
>>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017")
|
>>> fig, ax = plot_results_2D(spectra[:, -1], (600, 1600, nm), log=True, "Heidt2017")
|
||||||
@@ -674,19 +675,21 @@ def plot_results_1D(
|
|||||||
is_new_plot = ax is None
|
is_new_plot = ax is None
|
||||||
|
|
||||||
folder_name = ""
|
folder_name = ""
|
||||||
|
if file_name is None:
|
||||||
|
file_name = params.name + str(plt_range)
|
||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||||
else:
|
else:
|
||||||
fig = ax.get_figure()
|
fig = ax.get_figure()
|
||||||
if transpose:
|
if transpose:
|
||||||
ax.plot(values, x_axis, label=line_label, **line_kwargs)
|
(line,) = ax.plot(values, x_axis, **line_kwargs)
|
||||||
ax.yaxis.tick_right()
|
ax.yaxis.tick_right()
|
||||||
ax.yaxis.set_label_position("right")
|
ax.yaxis.set_label_position("right")
|
||||||
ax.set_xlim(vmax, vmin)
|
ax.set_xlim(vmax, vmin)
|
||||||
ax.set_xlabel(ylabel)
|
ax.set_xlabel(ylabel)
|
||||||
ax.set_ylabel(plt_range.unit.label)
|
ax.set_ylabel(plt_range.unit.label)
|
||||||
else:
|
else:
|
||||||
ax.plot(x_axis, values, label=line_label, **line_kwargs)
|
(line,) = ax.plot(x_axis, values, **line_kwargs)
|
||||||
ax.set_ylim(vmin, vmax)
|
ax.set_ylim(vmin, vmax)
|
||||||
ax.set_ylabel(ylabel)
|
ax.set_ylabel(ylabel)
|
||||||
ax.set_xlabel(plt_range.unit.label)
|
ax.set_xlabel(plt_range.unit.label)
|
||||||
@@ -694,7 +697,7 @@ def plot_results_1D(
|
|||||||
if is_new_plot:
|
if is_new_plot:
|
||||||
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
fig.savefig(out_path, bbox_inches="tight", dpi=200)
|
||||||
print(f"plot saved in {out_path}")
|
print(f"plot saved in {out_path}")
|
||||||
return fig, ax, x_axis, values
|
return fig, ax, line, x_axis, values
|
||||||
|
|
||||||
|
|
||||||
def _prep_plot(
|
def _prep_plot(
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ from matplotlib.pyplot import subplot
|
|||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy.lib import utils
|
||||||
|
from numpy.lib.arraysetops import isin
|
||||||
from tqdm.std import Bar
|
from tqdm.std import Bar
|
||||||
|
|
||||||
from . import initialize, io, math
|
from . import initialize, io, math
|
||||||
@@ -14,7 +16,7 @@ from .physics import units, pulse
|
|||||||
from .const import SPECN_FN
|
from .const import SPECN_FN
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .plotting import plot_avg, plot_results_1D, plot_results_2D
|
from .plotting import plot_avg, plot_results_1D, plot_results_2D
|
||||||
from .utils.parameter import BareParams
|
from .utils.parameter import BareParams, validator_and
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
@@ -184,7 +186,7 @@ class Pulse(Sequence):
|
|||||||
return self.nmax
|
return self.nmax
|
||||||
|
|
||||||
def __getitem__(self, key) -> Spectrum:
|
def __getitem__(self, key) -> Spectrum:
|
||||||
return self.all_spectra(ind=range(self.nmax)[key]).squeeze()
|
return self.all_spectra(key)
|
||||||
|
|
||||||
def intensity(self, unit):
|
def intensity(self, unit):
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||||
@@ -253,7 +255,7 @@ class Pulse(Sequence):
|
|||||||
def _to_time_amp(self, spectrum):
|
def _to_time_amp(self, spectrum):
|
||||||
return np.fft.ifft(spectrum)
|
return np.fft.ifft(spectrum)
|
||||||
|
|
||||||
def all_spectra(self, ind=None) -> Spectrum:
|
def all_spectra(self, ind) -> Spectrum:
|
||||||
"""
|
"""
|
||||||
loads the data already simulated.
|
loads the data already simulated.
|
||||||
defauft shape is (z_targets, n, nt)
|
defauft shape is (z_targets, n, nt)
|
||||||
@@ -280,6 +282,10 @@ class Pulse(Sequence):
|
|||||||
ind = self.default_ind
|
ind = self.default_ind
|
||||||
if isinstance(ind, (int, np.integer)):
|
if isinstance(ind, (int, np.integer)):
|
||||||
ind = [ind]
|
ind = [ind]
|
||||||
|
elif isinstance(ind, (float, np.floating)):
|
||||||
|
ind = [self.z_ind(ind)]
|
||||||
|
elif isinstance(ind[0], (float, np.floating)):
|
||||||
|
ind = [self.z_ind(ii) for ii in ind]
|
||||||
|
|
||||||
# Load the spectra
|
# Load the spectra
|
||||||
spectra = []
|
spectra = []
|
||||||
@@ -312,11 +318,11 @@ class Pulse(Sequence):
|
|||||||
left: float,
|
left: float,
|
||||||
right: float,
|
right: float,
|
||||||
unit: Union[Callable[[float], float], str],
|
unit: Union[Callable[[float], float], str],
|
||||||
z_ind: Union[int, Iterable[int]] = None,
|
z_pos: Union[int, Iterable[int]] = None,
|
||||||
sim_ind: int = 0,
|
sim_ind: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
|
||||||
return plot_results_2D(vals, plt_range, self.params, **kwargs)
|
return plot_results_2D(vals, plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
def plot_1D(
|
def plot_1D(
|
||||||
@@ -324,28 +330,47 @@ class Pulse(Sequence):
|
|||||||
left: float,
|
left: float,
|
||||||
right: float,
|
right: float,
|
||||||
unit: Union[Callable[[float], float], str],
|
unit: Union[Callable[[float], float], str],
|
||||||
z_ind: int,
|
z_pos: int,
|
||||||
sim_ind: int = 0,
|
sim_ind: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, sim_ind)
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
|
||||||
return plot_results_1D(vals[0], plt_range, self.params, **kwargs)
|
return plot_results_1D(vals, plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
def plot_avg(
|
def plot_avg(
|
||||||
self,
|
self,
|
||||||
left: float,
|
left: float,
|
||||||
right: float,
|
right: float,
|
||||||
unit: Union[Callable[[float], float], str],
|
unit: Union[Callable[[float], float], str],
|
||||||
z_ind: int,
|
z_pos: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_ind, slice(None))
|
plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None))
|
||||||
return plot_avg(vals, plt_range, self.params, **kwargs)
|
return plot_avg(vals, plt_range, self.params, **kwargs)
|
||||||
|
|
||||||
def retrieve_plot_values(self, left, right, unit, z_ind, sim_ind):
|
def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind):
|
||||||
plt_range = units.PlotRange(left, right, unit)
|
plt_range = units.PlotRange(left, right, unit)
|
||||||
if plt_range.unit.type == "TIME":
|
if plt_range.unit.type == "TIME":
|
||||||
vals = self.all_fields(ind=z_ind)[:, sim_ind]
|
vals = self.all_fields(ind=z_pos)
|
||||||
else:
|
else:
|
||||||
vals = self.all_spectra(ind=z_ind)[:, sim_ind]
|
vals = self.all_spectra(ind=z_pos)
|
||||||
|
if vals.ndim == 3:
|
||||||
|
vals = vals[:, sim_ind]
|
||||||
|
else:
|
||||||
|
vals = vals[sim_ind]
|
||||||
return plt_range, vals
|
return plt_range, vals
|
||||||
|
|
||||||
|
def z_ind(self, z: float) -> int:
|
||||||
|
"""return the closest z index to the given target
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
z : float
|
||||||
|
target
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
index
|
||||||
|
"""
|
||||||
|
return math.argclosest(self.z, z)
|
||||||
|
|||||||
Reference in New Issue
Block a user