New SimulationSeries structure
This commit is contained in:
@@ -887,6 +887,7 @@ def uniform_axis(
|
|||||||
"""
|
"""
|
||||||
if new_axis_spec is None:
|
if new_axis_spec is None:
|
||||||
new_axis_spec = "unity"
|
new_axis_spec = "unity"
|
||||||
|
|
||||||
if isinstance(new_axis_spec, str) or callable(new_axis_spec):
|
if isinstance(new_axis_spec, str) or callable(new_axis_spec):
|
||||||
unit = units.get_unit(new_axis_spec)
|
unit = units.get_unit(new_axis_spec)
|
||||||
plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec)
|
plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec)
|
||||||
@@ -896,6 +897,7 @@ def uniform_axis(
|
|||||||
plt_range = new_axis_spec
|
plt_range = new_axis_spec
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
|
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
|
||||||
|
|
||||||
tmp_axis, ind, ext = sort_axis(axis, plt_range)
|
tmp_axis, ind, ext = sort_axis(axis, plt_range)
|
||||||
values = np.atleast_2d(values)
|
values = np.atleast_2d(values)
|
||||||
if np.allclose((diff := np.diff(tmp_axis))[0], diff):
|
if np.allclose((diff := np.diff(tmp_axis))[0], diff):
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterator, Optional, Union
|
from typing import Callable, Iterator, Optional, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -120,105 +120,84 @@ class Spectrum(np.ndarray):
|
|||||||
|
|
||||||
|
|
||||||
class SimulationSeries:
|
class SimulationSeries:
|
||||||
|
"""
|
||||||
|
SimulationsSeries are the interface the user should use to load and
|
||||||
|
interact with simulation data. The object loads each fiber of the simulation
|
||||||
|
into a separate object and exposes convenience methods to make the series behave
|
||||||
|
as a single fiber.
|
||||||
|
|
||||||
|
It should be noted that the last spectrum of a fiber and the first one of the next
|
||||||
|
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
|
||||||
|
than when manually mergin the corresponding data.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
path: Path
|
path: Path
|
||||||
|
fibers: list[SimulatedFiber]
|
||||||
params: Parameters
|
params: Parameters
|
||||||
total_length: float
|
z_indices: list[tuple[int, int]]
|
||||||
total_num_steps: int
|
|
||||||
previous: SimulationSeries = None
|
|
||||||
fiber_lengths: list[tuple[str, float]]
|
|
||||||
fiber_positions: list[tuple[str, float]]
|
fiber_positions: list[tuple[str, float]]
|
||||||
z_inds: np.ndarray
|
|
||||||
|
|
||||||
def __init__(self, path: os.PathLike):
|
def __init__(self, path: os.PathLike):
|
||||||
|
"""Create a SimulationSeries
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : os.PathLike
|
||||||
|
path to the last fiber of the series
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
No simulation found in specified directory
|
||||||
|
"""
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
for self.path in simulations_list(path):
|
for self.path in simulations_list(path):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"No simulation in {path}")
|
raise FileNotFoundError(f"No simulation in {path}")
|
||||||
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
|
self.fibers = [SimulatedFiber(self.path)]
|
||||||
self.t = self.params.t
|
while (p := self.fibers[-1].params.prev_data_dir) is not None:
|
||||||
self.w = self.params.w
|
self.fibers.append(SimulatedFiber(p))
|
||||||
if self.params.prev_data_dir is not None:
|
self.fibers = self.fibers[::-1]
|
||||||
self.previous = SimulationSeries(self.params.prev_data_dir)
|
|
||||||
self.total_length = self.accumulate_params("length")
|
|
||||||
self.total_num_steps = self.accumulate_params("z_num")
|
|
||||||
self.z_inds = np.arange(len(self.params.z_targets))
|
|
||||||
self.z = self.params.z_targets
|
|
||||||
if self.previous is not None:
|
|
||||||
self.z += self.previous.params.z_targets[-1]
|
|
||||||
self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets))
|
|
||||||
self.z_inds += self.previous.z_inds[-1] + 1
|
|
||||||
self.fiber_lengths = self.all_params("length")
|
|
||||||
self.fiber_positions = [
|
|
||||||
(this[0], following[1])
|
|
||||||
for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths)
|
|
||||||
]
|
|
||||||
|
|
||||||
def all_params(self, key: str) -> list[tuple[str, Any]]:
|
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
|
||||||
"""returns the value of a parameter for each fiber
|
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
|
||||||
|
z_targets = list(self.params.z_targets)
|
||||||
Parameters
|
self.z_indices = [(0, j) for j in range(self.params.z_num)]
|
||||||
----------
|
for i, fiber in enumerate(self.fibers[1:]):
|
||||||
key : str
|
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
|
||||||
name of the parameter
|
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
|
||||||
|
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
|
||||||
Returns
|
self.params.z_targets = np.array(z_targets)
|
||||||
-------
|
self.params.length = self.params.z_targets[-1]
|
||||||
list[tuple[str, Any]]
|
self.params.z_num = len(self.params.z_targets)
|
||||||
list of (fiber_name, param_value) tuples
|
|
||||||
"""
|
|
||||||
return list(reversed(self._all_params(key, [])))
|
|
||||||
|
|
||||||
def accumulate_params(self, key: str) -> Any:
|
|
||||||
"""returns the sum of all the values a parameter takes. Useful to
|
|
||||||
get the total length of the fiber, the total number of steps, etc.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
key : str
|
|
||||||
name of the parameter
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Any
|
|
||||||
final sum
|
|
||||||
"""
|
|
||||||
return sum(el[1] for el in self.all_params(key))
|
|
||||||
|
|
||||||
def spectra(
|
def spectra(
|
||||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||||
) -> Spectrum:
|
) -> Spectrum:
|
||||||
|
...
|
||||||
if z_descr is None:
|
if z_descr is None:
|
||||||
out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)]
|
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
|
||||||
else:
|
else:
|
||||||
if isinstance(z_descr, (float, np.floating)):
|
if isinstance(z_descr, (float, np.floating)):
|
||||||
return self.spectra(self.z_ind(z_descr), sim_ind)
|
fib_ind, z_ind = self.z_ind(z_descr)
|
||||||
else:
|
else:
|
||||||
z_ind = z_descr
|
fib_ind, z_ind = self.z_indices[z_descr]
|
||||||
if 0 <= z_ind < self.z_inds[0]:
|
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
|
||||||
return self.previous.spectra(z_ind, sim_ind)
|
|
||||||
elif z_ind < 0:
|
|
||||||
z_ind = self.total_num_steps + z_ind
|
|
||||||
if sim_ind is None:
|
|
||||||
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
|
||||||
else:
|
|
||||||
out = self._load_1(z_ind)
|
|
||||||
return Spectrum(out, self.params)
|
return Spectrum(out, self.params)
|
||||||
|
|
||||||
def z_ind(self, pos: float) -> int:
|
|
||||||
if self.z[0] <= pos <= self.z[-1]:
|
|
||||||
return self.z_inds[np.argmin(np.abs(self.z - pos))]
|
|
||||||
elif 0 <= pos < self.z[0]:
|
|
||||||
return self.previous.z_ind(pos)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"cannot match z={pos} with max length of {self.total_length}")
|
|
||||||
|
|
||||||
def fields(
|
def fields(
|
||||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||||
) -> Spectrum:
|
) -> Spectrum:
|
||||||
return self.params.ifft(self.spectra(z_descr, sim_ind))
|
return self.params.ifft(self.spectra(z_descr, sim_ind))
|
||||||
|
|
||||||
# Plotting
|
def z_ind(self, pos: float) -> tuple[int, int]:
|
||||||
|
if 0 <= pos <= self.params.length:
|
||||||
|
ind = np.argmin(np.abs(self.params.z_targets - pos))
|
||||||
|
return self.z_indices[ind]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||||
|
|
||||||
def plot_2D(
|
def plot_2D(
|
||||||
self,
|
self,
|
||||||
@@ -344,7 +323,65 @@ class SimulationSeries:
|
|||||||
spectra.append(tmp)
|
spectra.append(tmp)
|
||||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
||||||
|
|
||||||
# Private
|
# Magic methods
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Spectrum]:
|
||||||
|
for i, j in self.z_indices:
|
||||||
|
yield self.fibers[i].spectra(j, None)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(path={self.path})"
|
||||||
|
|
||||||
|
def __eq__(self, other: SimulationSeries) -> bool:
|
||||||
|
return self.path == other.path and self.params == other.params
|
||||||
|
|
||||||
|
def __contains__(self, fiber: SimulatedFiber) -> bool:
|
||||||
|
return fiber in self.fibers
|
||||||
|
|
||||||
|
def __getitem__(self, key) -> Spectrum:
|
||||||
|
if isinstance(key, tuple):
|
||||||
|
return self.spectra(*key)
|
||||||
|
else:
|
||||||
|
return self.spectra(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
class SimulatedFiber:
|
||||||
|
params: Parameters
|
||||||
|
t: np.ndarray
|
||||||
|
w: np.ndarray
|
||||||
|
|
||||||
|
def __init__(self, path: os.PathLike):
|
||||||
|
self.path = Path(path)
|
||||||
|
self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN)))
|
||||||
|
self.t = self.params.t
|
||||||
|
self.w = self.params.w
|
||||||
|
self.z = self.params.z_targets
|
||||||
|
|
||||||
|
def spectra(
|
||||||
|
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||||
|
) -> np.ndarray:
|
||||||
|
if z_descr is None:
|
||||||
|
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
|
||||||
|
else:
|
||||||
|
if isinstance(z_descr, (float, np.floating)):
|
||||||
|
return self.spectra(self.z_ind(z_descr), sim_ind)
|
||||||
|
else:
|
||||||
|
z_ind = z_descr
|
||||||
|
|
||||||
|
if z_ind < 0:
|
||||||
|
z_ind = self.params.z_num + z_ind
|
||||||
|
|
||||||
|
if sim_ind is None:
|
||||||
|
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
||||||
|
else:
|
||||||
|
out = self._load_1(z_ind)
|
||||||
|
return Spectrum(out, self.params)
|
||||||
|
|
||||||
|
def z_ind(self, pos: float) -> int:
|
||||||
|
if 0 <= pos <= self.z[-1]:
|
||||||
|
return np.argmin(np.abs(self.z - pos))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||||
|
|
||||||
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
||||||
"""loads a spectrum file
|
"""loads a spectrum file
|
||||||
@@ -362,40 +399,9 @@ class SimulationSeries:
|
|||||||
loaded spectrum file
|
loaded spectrum file
|
||||||
"""
|
"""
|
||||||
if sim_ind > 0:
|
if sim_ind > 0:
|
||||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind))
|
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
||||||
else:
|
else:
|
||||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0]))
|
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
|
||||||
|
|
||||||
def _all_params(self, key: str, l: list) -> list:
|
|
||||||
l.append((self.params.name, getattr(self.params, key)))
|
|
||||||
if self.previous is not None:
|
|
||||||
return self.previous._all_params(key, l)
|
|
||||||
return l
|
|
||||||
|
|
||||||
# Magic methods
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Spectrum]:
|
|
||||||
for i in range(self.total_num_steps):
|
|
||||||
yield self.spectra(i, None)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
|
return f"{self.__class__.__name__}(path={self.path})"
|
||||||
|
|
||||||
def __eq__(self, other: SimulationSeries) -> bool:
|
|
||||||
return (
|
|
||||||
self.path == other.path
|
|
||||||
and self.params == other.params
|
|
||||||
and self.previous == other.previous
|
|
||||||
)
|
|
||||||
|
|
||||||
def __contains__(self, other: SimulationSeries) -> bool:
|
|
||||||
if other is self or other == self:
|
|
||||||
return True
|
|
||||||
if self.previous is not None:
|
|
||||||
return other in self.previous
|
|
||||||
|
|
||||||
def __getitem__(self, key) -> Spectrum:
|
|
||||||
if isinstance(key, tuple):
|
|
||||||
return self.spectra(*key)
|
|
||||||
else:
|
|
||||||
return self.spectra(key, None)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user