New SimulationSeries structure
This commit is contained in:
@@ -887,6 +887,7 @@ def uniform_axis(
|
||||
"""
|
||||
if new_axis_spec is None:
|
||||
new_axis_spec = "unity"
|
||||
|
||||
if isinstance(new_axis_spec, str) or callable(new_axis_spec):
|
||||
unit = units.get_unit(new_axis_spec)
|
||||
plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec)
|
||||
@@ -896,6 +897,7 @@ def uniform_axis(
|
||||
plt_range = new_axis_spec
|
||||
else:
|
||||
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
|
||||
|
||||
tmp_axis, ind, ext = sort_axis(axis, plt_range)
|
||||
values = np.atleast_2d(values)
|
||||
if np.allclose((diff := np.diff(tmp_axis))[0], diff):
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterator, Optional, Union
|
||||
from typing import Callable, Iterator, Optional, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@@ -120,105 +120,84 @@ class Spectrum(np.ndarray):
|
||||
|
||||
|
||||
class SimulationSeries:
|
||||
"""
|
||||
SimulationsSeries are the interface the user should use to load and
|
||||
interact with simulation data. The object loads each fiber of the simulation
|
||||
into a separate object and exposes convenience methods to make the series behave
|
||||
as a single fiber.
|
||||
|
||||
It should be noted that the last spectrum of a fiber and the first one of the next
|
||||
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
|
||||
than when manually mergin the corresponding data.
|
||||
|
||||
"""
|
||||
|
||||
path: Path
|
||||
fibers: list[SimulatedFiber]
|
||||
params: Parameters
|
||||
total_length: float
|
||||
total_num_steps: int
|
||||
previous: SimulationSeries = None
|
||||
fiber_lengths: list[tuple[str, float]]
|
||||
z_indices: list[tuple[int, int]]
|
||||
fiber_positions: list[tuple[str, float]]
|
||||
z_inds: np.ndarray
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
"""Create a SimulationSeries
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
path to the last fiber of the series
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
No simulation found in specified directory
|
||||
"""
|
||||
self.logger = get_logger()
|
||||
for self.path in simulations_list(path):
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(f"No simulation in {path}")
|
||||
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
|
||||
self.t = self.params.t
|
||||
self.w = self.params.w
|
||||
if self.params.prev_data_dir is not None:
|
||||
self.previous = SimulationSeries(self.params.prev_data_dir)
|
||||
self.total_length = self.accumulate_params("length")
|
||||
self.total_num_steps = self.accumulate_params("z_num")
|
||||
self.z_inds = np.arange(len(self.params.z_targets))
|
||||
self.z = self.params.z_targets
|
||||
if self.previous is not None:
|
||||
self.z += self.previous.params.z_targets[-1]
|
||||
self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets))
|
||||
self.z_inds += self.previous.z_inds[-1] + 1
|
||||
self.fiber_lengths = self.all_params("length")
|
||||
self.fiber_positions = [
|
||||
(this[0], following[1])
|
||||
for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths)
|
||||
]
|
||||
self.fibers = [SimulatedFiber(self.path)]
|
||||
while (p := self.fibers[-1].params.prev_data_dir) is not None:
|
||||
self.fibers.append(SimulatedFiber(p))
|
||||
self.fibers = self.fibers[::-1]
|
||||
|
||||
def all_params(self, key: str) -> list[tuple[str, Any]]:
|
||||
"""returns the value of a parameter for each fiber
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
name of the parameter
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[tuple[str, Any]]
|
||||
list of (fiber_name, param_value) tuples
|
||||
"""
|
||||
return list(reversed(self._all_params(key, [])))
|
||||
|
||||
def accumulate_params(self, key: str) -> Any:
|
||||
"""returns the sum of all the values a parameter takes. Useful to
|
||||
get the total length of the fiber, the total number of steps, etc.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
name of the parameter
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
final sum
|
||||
"""
|
||||
return sum(el[1] for el in self.all_params(key))
|
||||
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
|
||||
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
|
||||
z_targets = list(self.params.z_targets)
|
||||
self.z_indices = [(0, j) for j in range(self.params.z_num)]
|
||||
for i, fiber in enumerate(self.fibers[1:]):
|
||||
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
|
||||
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
|
||||
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
|
||||
self.params.z_targets = np.array(z_targets)
|
||||
self.params.length = self.params.z_targets[-1]
|
||||
self.params.z_num = len(self.params.z_targets)
|
||||
|
||||
def spectra(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> Spectrum:
|
||||
...
|
||||
if z_descr is None:
|
||||
out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)]
|
||||
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
|
||||
else:
|
||||
if isinstance(z_descr, (float, np.floating)):
|
||||
return self.spectra(self.z_ind(z_descr), sim_ind)
|
||||
fib_ind, z_ind = self.z_ind(z_descr)
|
||||
else:
|
||||
z_ind = z_descr
|
||||
if 0 <= z_ind < self.z_inds[0]:
|
||||
return self.previous.spectra(z_ind, sim_ind)
|
||||
elif z_ind < 0:
|
||||
z_ind = self.total_num_steps + z_ind
|
||||
if sim_ind is None:
|
||||
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
||||
else:
|
||||
out = self._load_1(z_ind)
|
||||
fib_ind, z_ind = self.z_indices[z_descr]
|
||||
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
|
||||
return Spectrum(out, self.params)
|
||||
|
||||
def z_ind(self, pos: float) -> int:
|
||||
if self.z[0] <= pos <= self.z[-1]:
|
||||
return self.z_inds[np.argmin(np.abs(self.z - pos))]
|
||||
elif 0 <= pos < self.z[0]:
|
||||
return self.previous.z_ind(pos)
|
||||
else:
|
||||
raise ValueError(f"cannot match z={pos} with max length of {self.total_length}")
|
||||
|
||||
def fields(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> Spectrum:
|
||||
return self.params.ifft(self.spectra(z_descr, sim_ind))
|
||||
|
||||
# Plotting
|
||||
def z_ind(self, pos: float) -> tuple[int, int]:
|
||||
if 0 <= pos <= self.params.length:
|
||||
ind = np.argmin(np.abs(self.params.z_targets - pos))
|
||||
return self.z_indices[ind]
|
||||
else:
|
||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||
|
||||
def plot_2D(
|
||||
self,
|
||||
@@ -344,7 +323,65 @@ class SimulationSeries:
|
||||
spectra.append(tmp)
|
||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
||||
|
||||
# Private
|
||||
# Magic methods
|
||||
|
||||
def __iter__(self) -> Iterator[Spectrum]:
|
||||
for i, j in self.z_indices:
|
||||
yield self.fibers[i].spectra(j, None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path})"
|
||||
|
||||
def __eq__(self, other: SimulationSeries) -> bool:
|
||||
return self.path == other.path and self.params == other.params
|
||||
|
||||
def __contains__(self, fiber: SimulatedFiber) -> bool:
|
||||
return fiber in self.fibers
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
if isinstance(key, tuple):
|
||||
return self.spectra(*key)
|
||||
else:
|
||||
return self.spectra(key, None)
|
||||
|
||||
|
||||
class SimulatedFiber:
|
||||
params: Parameters
|
||||
t: np.ndarray
|
||||
w: np.ndarray
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
self.path = Path(path)
|
||||
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
|
||||
self.t = self.params.t
|
||||
self.w = self.params.w
|
||||
self.z = self.params.z_targets
|
||||
|
||||
def spectra(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> np.ndarray:
|
||||
if z_descr is None:
|
||||
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
|
||||
else:
|
||||
if isinstance(z_descr, (float, np.floating)):
|
||||
return self.spectra(self.z_ind(z_descr), sim_ind)
|
||||
else:
|
||||
z_ind = z_descr
|
||||
|
||||
if z_ind < 0:
|
||||
z_ind = self.params.z_num + z_ind
|
||||
|
||||
if sim_ind is None:
|
||||
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
||||
else:
|
||||
out = self._load_1(z_ind)
|
||||
return Spectrum(out, self.params)
|
||||
|
||||
def z_ind(self, pos: float) -> int:
|
||||
if 0 <= pos <= self.z[-1]:
|
||||
return np.argmin(np.abs(self.z - pos))
|
||||
else:
|
||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||
|
||||
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
||||
"""loads a spectrum file
|
||||
@@ -362,40 +399,9 @@ class SimulationSeries:
|
||||
loaded spectrum file
|
||||
"""
|
||||
if sim_ind > 0:
|
||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind))
|
||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
||||
else:
|
||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0]))
|
||||
|
||||
def _all_params(self, key: str, l: list) -> list:
|
||||
l.append((self.params.name, getattr(self.params, key)))
|
||||
if self.previous is not None:
|
||||
return self.previous._all_params(key, l)
|
||||
return l
|
||||
|
||||
# Magic methods
|
||||
|
||||
def __iter__(self) -> Iterator[Spectrum]:
|
||||
for i in range(self.total_num_steps):
|
||||
yield self.spectra(i, None)
|
||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
|
||||
|
||||
def __eq__(self, other: SimulationSeries) -> bool:
|
||||
return (
|
||||
self.path == other.path
|
||||
and self.params == other.params
|
||||
and self.previous == other.previous
|
||||
)
|
||||
|
||||
def __contains__(self, other: SimulationSeries) -> bool:
|
||||
if other is self or other == self:
|
||||
return True
|
||||
if self.previous is not None:
|
||||
return other in self.previous
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
if isinstance(key, tuple):
|
||||
return self.spectra(*key)
|
||||
else:
|
||||
return self.spectra(key, None)
|
||||
return f"{self.__class__.__name__}(path={self.path})"
|
||||
|
||||
Reference in New Issue
Block a user