New SimulationSeries structure

This commit is contained in:
Benoît Sierro
2022-07-19 14:51:35 +02:00
parent 3ab20c219c
commit 1961ff9bfd
2 changed files with 117 additions and 109 deletions

View File

@@ -887,6 +887,7 @@ def uniform_axis(
""" """
if new_axis_spec is None: if new_axis_spec is None:
new_axis_spec = "unity" new_axis_spec = "unity"
if isinstance(new_axis_spec, str) or callable(new_axis_spec): if isinstance(new_axis_spec, str) or callable(new_axis_spec):
unit = units.get_unit(new_axis_spec) unit = units.get_unit(new_axis_spec)
plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec) plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec)
@@ -896,6 +897,7 @@ def uniform_axis(
plt_range = new_axis_spec plt_range = new_axis_spec
else: else:
raise TypeError(f"Don't know how to interpret {new_axis_spec}") raise TypeError(f"Don't know how to interpret {new_axis_spec}")
tmp_axis, ind, ext = sort_axis(axis, plt_range) tmp_axis, ind, ext = sort_axis(axis, plt_range)
values = np.atleast_2d(values) values = np.atleast_2d(values)
if np.allclose((diff := np.diff(tmp_axis))[0], diff): if np.allclose((diff := np.diff(tmp_axis))[0], diff):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterator, Optional, Union from typing import Callable, Iterator, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@@ -120,105 +120,84 @@ class Spectrum(np.ndarray):
class SimulationSeries: class SimulationSeries:
"""
SimulationsSeries are the interface the user should use to load and
interact with simulation data. The object loads each fiber of the simulation
into a separate object and exposes convenience methods to make the series behave
as a single fiber.
It should be noted that the last spectrum of a fiber and the first one of the next
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
than when manually mergin the corresponding data.
"""
path: Path path: Path
fibers: list[SimulatedFiber]
params: Parameters params: Parameters
total_length: float z_indices: list[tuple[int, int]]
total_num_steps: int
previous: SimulationSeries = None
fiber_lengths: list[tuple[str, float]]
fiber_positions: list[tuple[str, float]] fiber_positions: list[tuple[str, float]]
z_inds: np.ndarray
def __init__(self, path: os.PathLike): def __init__(self, path: os.PathLike):
"""Create a SimulationSeries
Parameters
----------
path : os.PathLike
path to the last fiber of the series
Raises
------
FileNotFoundError
No simulation found in specified directory
"""
self.logger = get_logger() self.logger = get_logger()
for self.path in simulations_list(path): for self.path in simulations_list(path):
break break
else: else:
raise FileNotFoundError(f"No simulation in {path}") raise FileNotFoundError(f"No simulation in {path}")
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN))) self.fibers = [SimulatedFiber(self.path)]
self.t = self.params.t while (p := self.fibers[-1].params.prev_data_dir) is not None:
self.w = self.params.w self.fibers.append(SimulatedFiber(p))
if self.params.prev_data_dir is not None: self.fibers = self.fibers[::-1]
self.previous = SimulationSeries(self.params.prev_data_dir)
self.total_length = self.accumulate_params("length")
self.total_num_steps = self.accumulate_params("z_num")
self.z_inds = np.arange(len(self.params.z_targets))
self.z = self.params.z_targets
if self.previous is not None:
self.z += self.previous.params.z_targets[-1]
self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets))
self.z_inds += self.previous.z_inds[-1] + 1
self.fiber_lengths = self.all_params("length")
self.fiber_positions = [
(this[0], following[1])
for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths)
]
def all_params(self, key: str) -> list[tuple[str, Any]]: self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
"""returns the value of a parameter for each fiber self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
z_targets = list(self.params.z_targets)
Parameters self.z_indices = [(0, j) for j in range(self.params.z_num)]
---------- for i, fiber in enumerate(self.fibers[1:]):
key : str self.fiber_positions.append((fiber.params.name, z_targets[-1]))
name of the parameter z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
Returns self.params.z_targets = np.array(z_targets)
------- self.params.length = self.params.z_targets[-1]
list[tuple[str, Any]] self.params.z_num = len(self.params.z_targets)
list of (fiber_name, param_value) tuples
"""
return list(reversed(self._all_params(key, [])))
def accumulate_params(self, key: str) -> Any:
"""returns the sum of all the values a parameter takes. Useful to
get the total length of the fiber, the total number of steps, etc.
Parameters
----------
key : str
name of the parameter
Returns
-------
Any
final sum
"""
return sum(el[1] for el in self.all_params(key))
def spectra( def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum: ) -> Spectrum:
...
if z_descr is None: if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)] out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
else: else:
if isinstance(z_descr, (float, np.floating)): if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind) fib_ind, z_ind = self.z_ind(z_descr)
else: else:
z_ind = z_descr fib_ind, z_ind = self.z_indices[z_descr]
if 0 <= z_ind < self.z_inds[0]: out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
return self.previous.spectra(z_ind, sim_ind)
elif z_ind < 0:
z_ind = self.total_num_steps + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
return Spectrum(out, self.params) return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if self.z[0] <= pos <= self.z[-1]:
return self.z_inds[np.argmin(np.abs(self.z - pos))]
elif 0 <= pos < self.z[0]:
return self.previous.z_ind(pos)
else:
raise ValueError(f"cannot match z={pos} with max length of {self.total_length}")
def fields( def fields(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum: ) -> Spectrum:
return self.params.ifft(self.spectra(z_descr, sim_ind)) return self.params.ifft(self.spectra(z_descr, sim_ind))
# Plotting def z_ind(self, pos: float) -> tuple[int, int]:
if 0 <= pos <= self.params.length:
ind = np.argmin(np.abs(self.params.z_targets - pos))
return self.z_indices[ind]
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def plot_2D( def plot_2D(
self, self,
@@ -344,7 +323,65 @@ class SimulationSeries:
spectra.append(tmp) spectra.append(tmp)
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1)) return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
# Private # Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i, j in self.z_indices:
yield self.fibers[i].spectra(j, None)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path})"
def __eq__(self, other: SimulationSeries) -> bool:
return self.path == other.path and self.params == other.params
def __contains__(self, fiber: SimulatedFiber) -> bool:
return fiber in self.fibers
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)
class SimulatedFiber:
params: Parameters
t: np.ndarray
w: np.ndarray
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
self.t = self.params.t
self.w = self.params.w
self.z = self.params.z_targets
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> np.ndarray:
if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
else:
if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind)
else:
z_ind = z_descr
if z_ind < 0:
z_ind = self.params.z_num + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if 0 <= pos <= self.z[-1]:
return np.argmin(np.abs(self.z - pos))
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray: def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
"""loads a spectrum file """loads a spectrum file
@@ -362,40 +399,9 @@ class SimulationSeries:
loaded spectrum file loaded spectrum file
""" """
if sim_ind > 0: if sim_ind > 0:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind)) return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
else: else:
return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0])) return load_spectrum(self.path / SPEC1_FN.format(z_ind))
def _all_params(self, key: str, l: list) -> list:
l.append((self.params.name, getattr(self.params, key)))
if self.previous is not None:
return self.previous._all_params(key, l)
return l
# Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i in range(self.total_num_steps):
yield self.spectra(i, None)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})" return f"{self.__class__.__name__}(path={self.path})"
def __eq__(self, other: SimulationSeries) -> bool:
return (
self.path == other.path
and self.params == other.params
and self.previous == other.previous
)
def __contains__(self, other: SimulationSeries) -> bool:
if other is self or other == self:
return True
if self.previous is not None:
return other in self.previous
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)