Files
scgenerator/src/scgenerator/spectra.py
2021-08-30 13:30:16 +02:00

373 lines
10 KiB
Python

import os
from collections.abc import Sequence
from pathlib import Path
from typing import Callable, Dict, Iterable, Union
import matplotlib.pyplot as plt
import numpy as np
from . import math
from .const import SPECN_FN
from .logger import get_logger
from .physics import pulse, units
from .plotting import mean_values_plot, propagation_plot, single_position_plot
from .utils.parameter import Parameters, PlotRange
class Spectrum(np.ndarray):
    """Complex spectrum array that carries the simulation ``Parameters`` along.

    A thin ``np.ndarray`` subclass: views, slices and sub-arrays obtained by
    iteration keep a reference to the same ``params`` object (see
    ``__array_finalize__``). Per-spectrum helpers such as ``wl_int``,
    ``energy`` or ``time_amp`` can therefore read the grids (``params.l``,
    ``params.t``, ``params.w``) without them being passed explicitly.
    """

    params: Parameters

    def __new__(cls, input_array, params: Parameters):
        # View the input as this subclass (np.asarray does not copy an
        # existing ndarray) and attach the parameters to the new instance.
        obj = np.asarray(input_array).view(cls)
        obj.params = params
        return obj

    def __array_finalize__(self, obj):
        # Follows the "InfoArray" pattern from the NumPy subclassing guide:
        # copy ``params`` from the array this one was derived from (view,
        # slice, ...). ``obj is None`` only for a direct ndarray-constructor
        # call, where there is nothing to copy.
        if obj is None:
            return
        self.params = getattr(obj, "params", None)

    def __getitem__(self, key) -> "Spectrum":
        # Plain ndarray indexing; the override only documents that array
        # results are ``Spectrum`` instances (scalars are returned as-is).
        return super().__getitem__(key)

    def energy(self) -> Union[np.ndarray, float]:
        """Integrate the wavelength-domain intensity over positive wavelengths.

        Returns
        -------
        float | np.ndarray
            a scalar for a 1D spectrum; for higher dimensions, an array with
            one value per element along the first axis (computed recursively).
        """
        if self.ndim == 1:
            positive = np.argwhere(self.params.l > 0)[:, 0]
            # integrate with the abscissa in increasing order; a stable sort
            # keeps the original order on ties, like Python's ``sorted``
            m = positive[np.argsort(self.params.l[positive], kind="stable")]
            return np.trapz(self.wl_int[m], self.params.l[m])
        return np.array([s.energy() for s in self])

    def crop_wl(self, left: float, right: float) -> np.ndarray:
        """Return the boolean mask of ``params.l`` values within ``[left, right]``.

        Despite its name, this does not crop anything: index the data with the
        returned mask to do so.
        """
        return (self.params.l >= left) & (self.params.l <= right)

    @property
    def wl_int(self):
        """Intensity converted to the wavelength domain with ``units.to_WL``."""
        return units.to_WL(math.abs2(self), self.params.l)

    @property
    def freq_int(self):
        """Spectral intensity ``|self|**2`` (no conversion)."""
        return math.abs2(self)

    @property
    def afreq_int(self):
        """Spectral intensity ``|self|**2`` (no conversion)."""
        return math.abs2(self)

    @property
    def time_int(self):
        """Temporal intensity: ``|ifft(self)|**2`` along the last axis."""
        return math.abs2(np.fft.ifft(self))

    def amplitude(self, unit):
        """Yield ``(x_axis, amplitude)`` pairs, one per element along the first axis.

        Parameters
        ----------
        unit
            plotting unit; ``unit.type`` selects the representation ("WL",
            "FREQ", "AFREQ" or "TIME") and ``unit.inv`` converts the grid
            (``params.w`` or ``params.t``) into that unit.

        Yields
        ------
        tuple
            the converted axis sorted in increasing order, and the matching
            amplitude with its last axis reordered accordingly.

        Raises
        ------
        KeyError
            if ``unit.type`` is not one of the four supported types.
        """
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.params.w)
        else:
            x_axis = unit.inv(self.params.t)
        order = np.argsort(x_axis)
        # Name of the amplitude property to read on each sub-spectrum. The
        # properties must be evaluated on ``spec``: looking them up on
        # ``self`` would produce arrays, not callables.
        attr = dict(
            WL="wl_amp",
            FREQ="freq_amp",
            AFREQ="afreq_amp",
            TIME="time_amp",
        )[unit.type]
        for spec in self:
            # ``...`` reorders the last (grid) axis whether ``spec`` is 1D or 2D
            yield x_axis[order], getattr(spec, attr)[..., order]

    @property
    def wl_amp(self):
        """Complex amplitude with modulus ``sqrt(wl_int)`` and the spectrum's phase."""
        # exp(1j * angle) equals self / |self| wherever the spectrum is
        # nonzero, but stays finite at exact zeros, where the division
        # produced 0/0 = nan.
        return np.sqrt(self.wl_int) * np.exp(1j * np.angle(self))

    @property
    def freq_amp(self):
        """The spectrum itself."""
        return self

    @property
    def afreq_amp(self):
        """The spectrum itself."""
        return self

    @property
    def time_amp(self):
        """Temporal field: ``ifft(self)`` along the last axis."""
        return np.fft.ifft(self)

    @property
    def wl_max(self):
        """Wavelength (from ``params.l``) at which ``wl_int`` peaks.

        Returns a scalar for a 1D spectrum, otherwise an array with one value
        per element along the first axis.
        """
        if self.ndim == 1:
            return self.params.l[np.argmax(self.wl_int, axis=-1)]
        return np.array([s.wl_max for s in self])

    def mask_wl(self, pos: float, width: float) -> "Spectrum":
        """Multiply the spectrum by a Gaussian window in wavelength centred at ``pos``.

        The window is ``exp(-((l - pos) / (fac * width))**2)`` with
        ``fac = pulse.fwhm_to_T0_fac["gaussian"]``. Presumably this lets
        ``width`` be given as a FWHM; confirm against ``pulse``.
        """
        return self * np.exp(
            -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
        )

    def measure(self) -> tuple[float, float, float]:
        """Return ``pulse.measure_field`` applied to the temporal field on ``params.t``."""
        return pulse.measure_field(self.params.t, self.time_amp)
class Pulse(Sequence):
    """Lazily loaded view of a simulation data folder.

    Behaves as a ``Sequence``: ``len(pulse)`` is the number of
    ``spectra_*.npy`` files found in the folder, and ``pulse[key]`` loads the
    corresponding spectra through :meth:`all_spectra`. Each loaded file is
    kept in ``self.cache``.
    """

    def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
        """Load a data folder as a pulse.

        Parameters
        ----------
        path : os.PathLike
            path to the data (folder containing ``params.toml`` and .npy files)
        default_ind : int | Iterable[int], optional
            indices loaded by :meth:`all_spectra` when it is called without
            ``ind``, by default None (all spectra)

        Raises
        ------
        FileNotFoundError
            the folder does not exist, contains no ``spectra_*.npy`` file, or
            has no ``z.npy`` while no parameters could be loaded
        """
        self.logger = get_logger(__name__)
        self.path = Path(path)
        self.default_ind = default_ind

        if not self.path.is_dir():
            raise FileNotFoundError(f"Folder {self.path} does not exist")

        self.params = Parameters.load(self.path / "params.toml")

        try:
            self.z = np.load(self.path / "z.npy")
        except FileNotFoundError:
            # fall back on the positions recorded in the parameters
            if self.params is not None:
                self.z = self.params.z_targets
            else:
                raise

        self.cache: Dict[int, Spectrum] = {}
        self.nmax = sum(1 for _ in self.path.glob("spectra_*.npy"))
        if self.nmax <= 0:
            raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")

        self.t = self.params.t
        w = math.wspace(self.t) + units.m(self.params.wavelength)
        self.w_order = np.argsort(w)
        self.w = w
        self.wl = units.m.inv(self.w)
        # NOTE: mutates the loaded Parameters object, which is shared with
        # every Spectrum created from this pulse
        self.params.w = self.w
        self.params.z_targets = self.z

    def __iter__(self):
        """Yield every stored spectrum file in index order (one ``_load1`` each).

        Similar to ``all_spectra`` but works as an iterator.
        """
        self.logger.debug(f"iterating through {self.path}")
        for i in range(self.nmax):
            yield self._load1(i)

    def __len__(self):
        """Number of spectrum files in the folder."""
        return self.nmax

    def __getitem__(self, key) -> Spectrum:
        """Delegate to :meth:`all_spectra` (int, float z position, slice or list)."""
        return self.all_spectra(key)

    def _sorted_axis(self, unit):
        """Return the frequency or time grid converted to ``unit``, and its sorting permutation."""
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.w)
        else:
            x_axis = unit.inv(self.t)
        return x_axis, np.argsort(x_axis)

    def intensity(self, unit):
        """Yield ``(x_axis, intensity)`` for every stored spectrum file.

        ``unit.type`` ("WL", "FREQ", "AFREQ" or "TIME") selects the
        representation. The axis is sorted in increasing order and the last
        axis of the intensity is reordered to match.
        """
        x_axis, order = self._sorted_axis(unit)
        func = dict(
            WL=self._to_wl_int,
            FREQ=self._to_freq_int,
            AFREQ=self._to_afreq_int,
            TIME=self._to_time_int,
        )[unit.type]
        for spec in self:
            yield x_axis[order], func(spec)[..., order]

    def _to_wl_int(self, spectrum):
        # Spectrum has no ``wl`` attribute: use this pulse's wavelength grid,
        # the same one the x axis of ``intensity``/``amplitude`` derives from.
        return units.to_WL(math.abs2(spectrum), self.wl)

    def _to_freq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_afreq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_time_int(self, spectrum):
        return math.abs2(np.fft.ifft(spectrum))

    def amplitude(self, unit):
        """Yield ``(x_axis, amplitude)`` for every stored spectrum file.

        Same axis handling as :meth:`intensity`, but yields complex amplitudes.
        """
        x_axis, order = self._sorted_axis(unit)
        func = dict(
            WL=self._to_wl_amp,
            FREQ=self._to_freq_amp,
            AFREQ=self._to_afreq_amp,
            TIME=self._to_time_amp,
        )[unit.type]
        for spec in self:
            yield x_axis[order], func(spec)[..., order]

    def _to_wl_amp(self, spectrum):
        # modulus sqrt(wavelength intensity) with the spectrum's phase;
        # exp(1j * angle) avoids the 0/0 = nan of spectrum / |spectrum|
        return np.sqrt(self._to_wl_int(spectrum)) * np.exp(1j * np.angle(spectrum))

    def _to_freq_amp(self, spectrum):
        return spectrum

    def _to_afreq_amp(self, spectrum):
        return spectrum

    def _to_time_amp(self, spectrum):
        return np.fft.ifft(spectrum)

    def all_spectra(self, ind=None) -> Spectrum:
        """Load already simulated spectra.

        Parameters
        ----------
        ind : int | float | slice | Iterable[int | float], optional
            which spectrum files to load. Ints are file indices (negative
            values count from the end), floats are z positions mapped to the
            closest index with :meth:`z_ind`, and a slice selects from
            ``range(len(self))``. If None, ``default_ind`` is used, or all
            spectra when that is also None.

        Returns
        -------
        Spectrum
            array of shape ``(len(ind), *file_shape)``. Each file is loaded
            with ``np.atleast_2d``, so this is typically
            ``(nz, n_simulations, nt)``. When a single index is selected, the
            leading axis is dropped.
        """
        self.logger.debug(f"opening {self.path}")

        if ind is None:
            ind = range(self.nmax) if self.default_ind is None else self.default_ind

        if isinstance(ind, slice):
            # Sequence protocol: pulse[a:b:c]
            ind = range(self.nmax)[ind]

        if isinstance(ind, (int, np.integer)):
            ind = [ind]
        elif isinstance(ind, (float, np.floating)):
            ind = [self.z_ind(ind)]
        else:
            # materialize generic iterables; convert each float z position
            # to the closest index individually
            ind = [
                self.z_ind(ii) if isinstance(ii, (float, np.floating)) else ii for ii in ind
            ]

        spectra = Spectrum([self._load1(i) for i in ind], self.params)
        self.logger.debug(f"all spectra from {self.path} successfully loaded")
        if len(ind) == 1:
            return spectra[0]
        return spectra

    def all_fields(self, ind=None):
        """Temporal fields: ``ifft`` of :meth:`all_spectra` along the last axis."""
        return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)

    def _load1(self, i: int):
        """Load spectrum file ``i`` (negative counts from the end), with caching.

        The cached object itself is returned, so in-place modifications by the
        caller affect later loads of the same index.
        """
        if i < 0:
            i = self.nmax + i
        if i in self.cache:
            return self.cache[i]
        spec = np.load(self.path / SPECN_FN.format(i))
        spec = np.atleast_2d(spec)
        spec = Spectrum(spec, self.params)
        self.cache[i] = spec
        return spec

    def plot_2D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: Union[int, Iterable[int]] = None,
        sim_ind: int = 0,
        **kwargs,
    ):
        """Draw the propagation of simulation ``sim_ind`` over ``z_pos`` with ``propagation_plot``."""
        plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
        return propagation_plot(vals, plt_range, self.params, ax, **kwargs)

    def plot_1D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: int,
        sim_ind: int = 0,
        **kwargs,
    ):
        """Draw simulation ``sim_ind`` at ``z_pos`` with ``single_position_plot``."""
        plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, sim_ind)
        return single_position_plot(vals, plt_range, self.params, ax, **kwargs)

    def plot_mean(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: int,
        **kwargs,
    ):
        """Draw all simulations at ``z_pos`` with ``mean_values_plot``."""
        plt_range, vals = self.retrieve_plot_values(left, right, unit, z_pos, slice(None))
        return mean_values_plot(vals, plt_range, self.params, ax, **kwargs)

    def retrieve_plot_values(self, left, right, unit, z_pos, sim_ind):
        """Build the ``PlotRange`` and select the values to plot.

        Fields are loaded for a "TIME" unit, spectra otherwise. ``sim_ind``
        indexes axis 1 when several z positions were loaded (3D result),
        axis 0 otherwise.
        """
        plt_range = PlotRange(left, right, unit)
        if plt_range.unit.type == "TIME":
            vals = self.all_fields(ind=z_pos)
        else:
            vals = self.all_spectra(ind=z_pos)

        if vals.ndim == 3:
            vals = vals[:, sim_ind]
        else:
            vals = vals[sim_ind]
        return plt_range, vals

    def z_ind(self, z: float) -> int:
        """Return the index of the stored z position closest to ``z``.

        Parameters
        ----------
        z : float
            target position

        Returns
        -------
        int
            index into ``self.z``
        """
        return math.argclosest(self.z, z)