Files
scgenerator/src/scgenerator/spectra.py
2022-01-06 10:19:40 +01:00

402 lines
12 KiB
Python

from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from . import math
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from .logger import get_logger
from .parameter import Parameters
from .physics import pulse, units
from .physics.units import PlotRange
from .plotting import (
mean_values_plot,
propagation_plot,
single_position_plot,
transform_1D_values,
transform_2D_propagation,
)
from .utils import load_spectrum, simulations_list, load_toml
from .legacy import translate_parameters
class Spectrum(np.ndarray):
    """Array of complex spectral amplitudes bundled with its simulation parameters.

    The instance is a regular ``np.ndarray`` view carrying one extra attribute,
    ``params``, which is propagated to views and slices by ``__array_finalize__``.
    The properties below read the grids ``params.l``, ``params.w`` and ``params.t``
    as well as the ``params.ifft`` callable.
    """

    params: Parameters

    def __new__(cls, input_array, params: Parameters):
        """View ``input_array`` as a ``Spectrum`` and attach ``params`` to it."""
        # np.asarray + view: the input is reinterpreted as this subclass
        obj = np.asarray(input_array).view(cls)
        obj.params = params
        return obj

    def __array_finalize__(self, obj):
        """Inherit ``params`` from the array this instance was derived from."""
        # Nothing to inherit when there is no source array
        if obj is None:
            return
        # Falls back to None when the source is a plain ndarray without params
        self.params = getattr(obj, "params", None)

    def __getitem__(self, key) -> "Spectrum":
        return super().__getitem__(key)

    @property
    def wl_int(self):
        """Squared modulus converted with ``units.to_WL`` on the ``params.l`` grid."""
        return units.to_WL(math.abs2(self), self.params.l)

    @property
    def freq_int(self):
        """Squared modulus of the spectrum."""
        return math.abs2(self)

    @property
    def afreq_int(self):
        """Squared modulus of the spectrum (same value as ``freq_int``)."""
        return math.abs2(self)

    @property
    def time_int(self):
        """Squared modulus of ``params.ifft(self)``."""
        return math.abs2(self.params.ifft(self))

    def amplitude(self, unit):
        """Yield, for each spectrum, its amplitude sorted along the axis of ``unit``.

        Parameters
        ----------
        unit
            object exposing ``type`` (one of "WL", "FREQ", "AFREQ", "TIME") and a
            callable ``inv`` applied to ``params.w`` (or ``params.t`` for "TIME")

        Yields
        ------
        x : np.ndarray
            x axis sorted in increasing order
        y : np.ndarray
            amplitude (``wl_amp``, ``freq_amp``, ``afreq_amp`` or ``time_amp``)
            reordered along its last axis to match ``x``

        Raises
        ------
        KeyError
            if ``unit.type`` is not one of the four supported types
        """
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.params.w)
        else:
            x_axis = unit.inv(self.params.t)
        order = np.argsort(x_axis)

        # Dispatch on the *name* of the property: the ``*_amp`` members are
        # properties, so they must be read from each spectrum rather than called.
        amp_name = dict(
            WL="wl_amp",
            FREQ="freq_amp",
            AFREQ="afreq_amp",
            TIME="time_amp",
        )[unit.type]

        # Iterating a 1D array would yield scalars, so treat it as a single spectrum
        specs = (self,) if self.ndim == 1 else self
        for spec in specs:
            # Ellipsis indexing reorders the last axis whatever the rank of ``spec``
            yield x_axis[order], getattr(spec, amp_name)[..., order]

    @property
    def wl_amp(self):
        """Amplitude whose squared modulus is ``wl_int``, keeping the phase of ``self``.

        NOTE: elements that are exactly zero produce NaN because of the division
        by ``np.abs(self)``.
        """
        wl_modulus = np.sqrt(units.to_WL(math.abs2(self), self.params.l))
        return wl_modulus * self / np.abs(self)

    @property
    def freq_amp(self):
        """The spectrum itself."""
        return self

    @property
    def afreq_amp(self):
        """The spectrum itself (same value as ``freq_amp``)."""
        return self

    @property
    def time_amp(self):
        """Inverse FFT of the spectrum along its last axis."""
        # NOTE(review): time_int goes through params.ifft whereas this uses
        # np.fft.ifft directly - confirm both are meant to give the same field.
        return np.fft.ifft(self)

    @property
    def wl_max(self):
        """Value of ``params.l`` where ``wl_int`` is maximal.

        A scalar for a 1D spectrum, otherwise an array built by applying the
        same rule to every element along the first axis.
        """
        if self.ndim == 1:
            return self.params.l[np.argmax(self.wl_int, axis=-1)]
        return np.array([s.wl_max for s in self])

    def mask_wl(self, pos: float, width: float) -> Spectrum:
        """Multiply the spectrum by a gaussian window defined on ``params.l``.

        Parameters
        ----------
        pos : float
            center of the window, in the unit of ``params.l``
        width : float
            width of the window, scaled by ``pulse.fwhm_to_T0_fac["gaussian"]``
            (presumably a FWHM - confirm against ``pulse``)
        """
        scaled_width = pulse.fwhm_to_T0_fac["gaussian"] * width
        return self * np.exp(-(((self.params.l - pos) / scaled_width) ** 2))

    def measure(self) -> tuple[float, float, float]:
        """Return ``pulse.measure_field`` evaluated on ``params.t`` and ``time_amp``."""
        return pulse.measure_field(self.params.t, self.time_amp)
class SimulationSeries:
    """Read access to the spectra of a simulation and of the simulations chained before it.

    When the loaded parameters define ``prev_data_dir``, the series stored there is
    loaded recursively as ``previous``. z indices (``z_inds``) and positions (``z``)
    of this series are then offset so that they are expressed relative to the
    whole chain, while ``params.z_targets`` is extended to cover the whole chain.
    """

    path: Path
    params: Parameters
    total_length: float
    total_num_steps: int
    previous: SimulationSeries = None
    fiber_lengths: list[tuple[str, float]]
    fiber_positions: list[tuple[str, float]]
    z_inds: np.ndarray

    def __init__(self, path: os.PathLike):
        """Load the first simulation found by ``simulations_list(path)``.

        Raises
        ------
        FileNotFoundError
            if ``simulations_list(path)`` yields nothing
        """
        self.logger = get_logger()
        try:
            self.path = next(iter(simulations_list(path)))
        except StopIteration:
            raise FileNotFoundError(f"No simulation in {path}") from None

        self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
        self.t = self.params.t
        self.w = self.params.w
        if self.params.prev_data_dir is not None:
            self.previous = SimulationSeries(self.params.prev_data_dir)

        self.total_length = self.accumulate_params("length")
        self.total_num_steps = self.accumulate_params("z_num")
        self.z_inds = np.arange(len(self.params.z_targets))
        self.z = self.params.z_targets
        if self.previous is not None:
            # previous.params.z_targets already spans every fiber before this one,
            # whereas previous.z only holds the segment of the previous fiber.
            upstream_z = self.previous.params.z_targets
            self.z = self.z + upstream_z[-1]
            self.params.z_targets = np.concatenate((upstream_z, self.z))
            self.z_inds += self.previous.z_inds[-1] + 1

        self.fiber_lengths = self.all_params("length")
        self.fiber_positions = self._start_positions(self.fiber_lengths)

    @staticmethod
    def _start_positions(fiber_lengths: list[tuple[str, float]]) -> list[tuple[str, float]]:
        """pairs each fiber name with the summed length of all the fibers before it"""
        positions = []
        start = 0.0
        for name, length in fiber_lengths:
            positions.append((name, start))
            start += length
        return positions

    def all_params(self, key: str) -> list[tuple[str, Any]]:
        """returns the value of a parameter for each fiber

        Parameters
        ----------
        key : str
            name of the parameter

        Returns
        -------
        list[tuple[str, Any]]
            list of (fiber_name, param_value) tuples, first fiber of the chain first
        """
        return list(reversed(self._all_params(key, [])))

    def accumulate_params(self, key: str) -> Any:
        """returns the sum of all the values a parameter takes. Useful to
        get the total length of the fiber, the total number of steps, etc.

        Parameters
        ----------
        key : str
            name of the parameter

        Returns
        -------
        Any
            final sum
        """
        return sum(el[1] for el in self.all_params(key))

    def spectra(
        self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
    ) -> Spectrum:
        """loads one or several spectra

        Parameters
        ----------
        z_descr : Union[float, int, None], optional
            position either as an index relative to the entire chain (int, negative
            values count from the end), a real position (float) or None to load
            every index in ``range(total_num_steps)``, by default None
        sim_ind : Optional[int], optional
            which simulation to load. None loads ``params.repeat`` of them, by default 0

        Returns
        -------
        Spectrum
            loaded spectra, carrying the parameters of the series that loaded them
        """
        if z_descr is None:
            every_step = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)]
            return Spectrum(every_step, self.params)

        if isinstance(z_descr, (float, np.floating)):
            return self.spectra(self.z_ind(z_descr), sim_ind)

        z_ind = z_descr
        # Resolve negative indices first so that they can also be delegated
        if z_ind < 0:
            z_ind = self.total_num_steps + z_ind
        if 0 <= z_ind < self.z_inds[0]:
            return self.previous.spectra(z_ind, sim_ind)

        if sim_ind is None:
            out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
        else:
            # sim_ind must be forwarded, otherwise simulation 0 is always loaded
            out = self._load_1(z_ind, sim_ind)
        return Spectrum(out, self.params)

    def z_ind(self, pos: float) -> int:
        """returns the z index (relative to the entire chain) stored closest to ``pos``

        Raises
        ------
        ValueError
            if ``pos`` cannot be matched by this series nor by a previous one
        """
        if self.z[0] <= pos <= self.z[-1]:
            return self.z_inds[np.argmin(np.abs(self.z - pos))]
        elif 0 <= pos < self.z[0] and self.previous is not None:
            return self.previous.z_ind(pos)
        else:
            raise ValueError(f"cannot match z={pos} with max length of {self.total_length}")

    def fields(
        self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
    ) -> Spectrum:
        """same as ``spectra`` but passed through ``params.ifft``"""
        return self.params.ifft(self.spectra(z_descr, sim_ind))

    # Plotting

    def plot_2D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        sim_ind: int = 0,
        **kwargs,
    ):
        """passes the values of every z index to ``propagation_plot`` with ``ax``"""
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, None, sim_ind)
        return propagation_plot(vals, plot_range, self.params, ax, **kwargs)

    def plot_values_2D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        sim_ind: int = 0,
        **kwargs,
    ):
        """passes the values of every z index to ``transform_2D_propagation``"""
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, None, sim_ind)
        return transform_2D_propagation(vals, plot_range, self.params, **kwargs)

    def plot_1D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: Union[int, float],
        sim_ind: int = 0,
        **kwargs,
    ):
        """passes the values at ``z_pos`` to ``single_position_plot`` with ``ax``"""
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
        return single_position_plot(vals, plot_range, self.params, ax, **kwargs)

    def plot_values_1D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        z_pos: Union[int, float],
        sim_ind: int = 0,
        **kwargs,
    ) -> tuple[np.ndarray, np.ndarray]:
        """gives the desired values already transformed according to the given range

        Parameters
        ----------
        left : float
            leftmost limit in unit
        right : float
            rightmost limit in unit
        unit : Union[Callable[[float], float], str]
            unit
        z_pos : Union[int, float]
            position either as an index (int) or a real position (float)
        sim_ind : Optional[int]
            which simulation to take when more than one are present

        Returns
        -------
        np.ndarray
            x axis
        np.ndarray
            y values
        """
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
        return transform_1D_values(vals, plot_range, self.params, **kwargs)

    def plot_mean(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: Union[int, float],
        **kwargs,
    ):
        """passes the values of all simulations at ``z_pos`` to ``mean_values_plot``"""
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, None)
        return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)

    def retrieve_plot_values(
        self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
    ):
        """returns ``fields`` when the unit of ``plot_range`` is of type "TIME",
        ``spectra`` otherwise
        """
        if plot_range.unit.type == "TIME":
            return self.fields(z_pos, sim_ind)
        return self.spectra(z_pos, sim_ind)

    def rin_propagation(
        self, left: float, right: float, unit: str
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """returns the RIN as function of unit and z

        Parameters
        ----------
        left : float
            left limit in unit
        right : float
            right limit in unit
        unit : str
            unit descriptor

        Returns
        -------
        x : np.ndarray, shape (nt,)
            x axis
        y : np.ndarray, shape (z_num, )
            y axis
        rin_prop : np.ndarray, shape (z_num, nt)
            RIN
        """
        transformed = []
        # Swap the first two axes so that the loop runs over the second axis of
        # the loaded array (the per-simulation axis when sim_ind is None)
        for spec in np.moveaxis(self.spectra(None, None), 1, 0):
            x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
            transformed.append(tmp)
        return x, z, pulse.rin_curve(np.moveaxis(transformed, 0, 1))

    # Private

    def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
        """loads a spectrum file

        Parameters
        ----------
        z_ind : int
            z_index relative to the entire simulation
        sim_ind : int, optional
            simulation index, used when repeated simulations with same parameters are ran, by default 0

        Returns
        -------
        np.ndarray
            loaded spectrum file
        """
        # File names are numbered relative to this series, not to the whole chain
        local_ind = z_ind - self.z_inds[0]
        if sim_ind > 0:
            return load_spectrum(self.path / SPEC1_FN_N.format(local_ind, sim_ind))
        return load_spectrum(self.path / SPEC1_FN.format(local_ind))

    def _all_params(self, key: str, l: list) -> list:
        """appends (name, value) of ``key`` to ``l`` for this series, then for every
        previous one, i.e. last fiber of the chain first
        """
        l.append((self.params.name, getattr(self.params, key)))
        if self.previous is not None:
            return self.previous._all_params(key, l)
        return l

    # Magic methods

    def __iter__(self) -> Iterator[Spectrum]:
        for i in range(self.total_num_steps):
            yield self.spectra(i, None)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"

    def __eq__(self, other: SimulationSeries) -> bool:
        # Let Python fall back to its default handling for unrelated types
        if not isinstance(other, SimulationSeries):
            return NotImplemented
        return (
            self.path == other.path
            and self.params == other.params
            and self.previous == other.previous
        )

    def __contains__(self, other: SimulationSeries) -> bool:
        """True if ``other`` is this series or any of the previous ones"""
        if other is self or other == self:
            return True
        if self.previous is not None:
            return other in self.previous
        return False

    def __getitem__(self, key) -> Spectrum:
        """``series[z]`` loads all simulations at ``z``, ``series[z, sim_ind]`` a single one"""
        if isinstance(key, tuple):
            return self.spectra(*key)
        return self.spectra(key, None)