New SimulationSeries structure

This commit is contained in:
Benoît Sierro
2022-07-19 14:51:35 +02:00
parent 3ab20c219c
commit 1961ff9bfd
2 changed files with 117 additions and 109 deletions

View File

@@ -887,6 +887,7 @@ def uniform_axis(
"""
if new_axis_spec is None:
new_axis_spec = "unity"
if isinstance(new_axis_spec, str) or callable(new_axis_spec):
unit = units.get_unit(new_axis_spec)
plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec)
@@ -896,6 +897,7 @@ def uniform_axis(
plt_range = new_axis_spec
else:
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
tmp_axis, ind, ext = sort_axis(axis, plt_range)
values = np.atleast_2d(values)
if np.allclose((diff := np.diff(tmp_axis))[0], diff):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Union
from typing import Callable, Iterator, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
@@ -120,105 +120,84 @@ class Spectrum(np.ndarray):
class SimulationSeries:
"""
SimulationsSeries are the interface the user should use to load and
interact with simulation data. The object loads each fiber of the simulation
into a separate object and exposes convenience methods to make the series behave
as a single fiber.
It should be noted that the last spectrum of a fiber and the first one of the next
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
than when manually mergin the corresponding data.
"""
path: Path
fibers: list[SimulatedFiber]
params: Parameters
total_length: float
total_num_steps: int
previous: SimulationSeries = None
fiber_lengths: list[tuple[str, float]]
z_indices: list[tuple[int, int]]
fiber_positions: list[tuple[str, float]]
z_inds: np.ndarray
def __init__(self, path: os.PathLike):
"""Create a SimulationSeries
Parameters
----------
path : os.PathLike
path to the last fiber of the series
Raises
------
FileNotFoundError
No simulation found in specified directory
"""
self.logger = get_logger()
for self.path in simulations_list(path):
break
else:
raise FileNotFoundError(f"No simulation in {path}")
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
self.t = self.params.t
self.w = self.params.w
if self.params.prev_data_dir is not None:
self.previous = SimulationSeries(self.params.prev_data_dir)
self.total_length = self.accumulate_params("length")
self.total_num_steps = self.accumulate_params("z_num")
self.z_inds = np.arange(len(self.params.z_targets))
self.z = self.params.z_targets
if self.previous is not None:
self.z += self.previous.params.z_targets[-1]
self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets))
self.z_inds += self.previous.z_inds[-1] + 1
self.fiber_lengths = self.all_params("length")
self.fiber_positions = [
(this[0], following[1])
for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths)
]
self.fibers = [SimulatedFiber(self.path)]
while (p := self.fibers[-1].params.prev_data_dir) is not None:
self.fibers.append(SimulatedFiber(p))
self.fibers = self.fibers[::-1]
def all_params(self, key: str) -> list[tuple[str, Any]]:
"""returns the value of a parameter for each fiber
Parameters
----------
key : str
name of the parameter
Returns
-------
list[tuple[str, Any]]
list of (fiber_name, param_value) tuples
"""
return list(reversed(self._all_params(key, [])))
def accumulate_params(self, key: str) -> Any:
"""returns the sum of all the values a parameter takes. Useful to
get the total length of the fiber, the total number of steps, etc.
Parameters
----------
key : str
name of the parameter
Returns
-------
Any
final sum
"""
return sum(el[1] for el in self.all_params(key))
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
z_targets = list(self.params.z_targets)
self.z_indices = [(0, j) for j in range(self.params.z_num)]
for i, fiber in enumerate(self.fibers[1:]):
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
self.params.z_targets = np.array(z_targets)
self.params.length = self.params.z_targets[-1]
self.params.z_num = len(self.params.z_targets)
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
...
if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)]
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
else:
if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind)
fib_ind, z_ind = self.z_ind(z_descr)
else:
z_ind = z_descr
if 0 <= z_ind < self.z_inds[0]:
return self.previous.spectra(z_ind, sim_ind)
elif z_ind < 0:
z_ind = self.total_num_steps + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
fib_ind, z_ind = self.z_indices[z_descr]
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if self.z[0] <= pos <= self.z[-1]:
return self.z_inds[np.argmin(np.abs(self.z - pos))]
elif 0 <= pos < self.z[0]:
return self.previous.z_ind(pos)
else:
raise ValueError(f"cannot match z={pos} with max length of {self.total_length}")
def fields(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
return self.params.ifft(self.spectra(z_descr, sim_ind))
# Plotting
def z_ind(self, pos: float) -> tuple[int, int]:
if 0 <= pos <= self.params.length:
ind = np.argmin(np.abs(self.params.z_targets - pos))
return self.z_indices[ind]
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def plot_2D(
self,
@@ -344,7 +323,65 @@ class SimulationSeries:
spectra.append(tmp)
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
# Private
# Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i, j in self.z_indices:
yield self.fibers[i].spectra(j, None)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path})"
def __eq__(self, other: SimulationSeries) -> bool:
return self.path == other.path and self.params == other.params
def __contains__(self, fiber: SimulatedFiber) -> bool:
return fiber in self.fibers
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)
class SimulatedFiber:
params: Parameters
t: np.ndarray
w: np.ndarray
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
self.t = self.params.t
self.w = self.params.w
self.z = self.params.z_targets
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> np.ndarray:
if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
else:
if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind)
else:
z_ind = z_descr
if z_ind < 0:
z_ind = self.params.z_num + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if 0 <= pos <= self.z[-1]:
return np.argmin(np.abs(self.z - pos))
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
"""loads a spectrum file
@@ -362,40 +399,9 @@ class SimulationSeries:
loaded spectrum file
"""
if sim_ind > 0:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind))
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
else:
return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0]))
def _all_params(self, key: str, l: list) -> list:
l.append((self.params.name, getattr(self.params, key)))
if self.previous is not None:
return self.previous._all_params(key, l)
return l
# Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i in range(self.total_num_steps):
yield self.spectra(i, None)
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
def __eq__(self, other: SimulationSeries) -> bool:
return (
self.path == other.path
and self.params == other.params
and self.previous == other.previous
)
def __contains__(self, other: SimulationSeries) -> bool:
if other is self or other == self:
return True
if self.previous is not None:
return other in self.previous
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)
return f"{self.__class__.__name__}(path={self.path})"