diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index d906e3c..691ca7c 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -887,6 +887,7 @@ def uniform_axis( """ if new_axis_spec is None: new_axis_spec = "unity" + if isinstance(new_axis_spec, str) or callable(new_axis_spec): unit = units.get_unit(new_axis_spec) plt_range = PlotRange(unit.inv(axis.min()), unit.inv(axis.max()), new_axis_spec) @@ -896,6 +897,7 @@ def uniform_axis( plt_range = new_axis_spec else: raise TypeError(f"Don't know how to interpret {new_axis_spec}") + tmp_axis, ind, ext = sort_axis(axis, plt_range) values = np.atleast_2d(values) if np.allclose((diff := np.diff(tmp_axis))[0], diff): diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index dd8acc2..2ddfe5f 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -2,7 +2,7 @@ from __future__ import annotations import os from pathlib import Path -from typing import Any, Callable, Iterator, Optional, Union +from typing import Callable, Iterator, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -120,105 +120,84 @@ class Spectrum(np.ndarray): class SimulationSeries: + """ + SimulationsSeries are the interface the user should use to load and + interact with simulation data. The object loads each fiber of the simulation + into a separate object and exposes convenience methods to make the series behave + as a single fiber. + + It should be noted that the last spectrum of a fiber and the first one of the next + fibers are identical. Therefore, SimulationSeries object will return fewer datapoints + than when manually mergin the corresponding data. + + """ + path: Path + fibers: list[SimulatedFiber] params: Parameters - total_length: float - total_num_steps: int - previous: SimulationSeries = None - fiber_lengths: list[tuple[str, float]] + z_indices: list[tuple[int, int]] fiber_positions: list[tuple[str, float]] - z_inds: np.ndarray def __init__(self, path: os.PathLike): + """Create a SimulationSeries + + Parameters + ---------- + path : os.PathLike + path to the last fiber of the series + + Raises + ------ + FileNotFoundError + No simulation found in specified directory + """ self.logger = get_logger() for self.path in simulations_list(path): break else: raise FileNotFoundError(f"No simulation in {path}") - self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN))) - self.t = self.params.t - self.w = self.params.w - if self.params.prev_data_dir is not None: - self.previous = SimulationSeries(self.params.prev_data_dir) - self.total_length = self.accumulate_params("length") - self.total_num_steps = self.accumulate_params("z_num") - self.z_inds = np.arange(len(self.params.z_targets)) - self.z = self.params.z_targets - if self.previous is not None: - self.z += self.previous.params.z_targets[-1] - self.params.z_targets = np.concatenate((self.previous.z, self.params.z_targets)) - self.z_inds += self.previous.z_inds[-1] + 1 - self.fiber_lengths = self.all_params("length") - self.fiber_positions = [ - (this[0], following[1]) - for this, following in zip(self.fiber_lengths, [(None, 0.0)] + self.fiber_lengths) - ] + self.fibers = [SimulatedFiber(self.path)] + while (p := self.fibers[-1].params.prev_data_dir) is not None: + self.fibers.append(SimulatedFiber(p)) + self.fibers = self.fibers[::-1] - def all_params(self, key: str) -> list[tuple[str, Any]]: - """returns the value of a parameter for each fiber - - Parameters - ---------- - key : str - name of the parameter - - Returns - ------- - list[tuple[str, Any]] - list of (fiber_name, param_value) tuples - """ - return list(reversed(self._all_params(key, []))) - - def accumulate_params(self, key: str) -> Any: - """returns the sum of all the values a parameter takes. Useful to - get the total length of the fiber, the total number of steps, etc. - - Parameters - ---------- - key : str - name of the parameter - - Returns - ------- - Any - final sum - """ - return sum(el[1] for el in self.all_params(key)) + self.fiber_positions = [(self.fibers[0].params.name, 0.0)] + self.params = Parameters(**self.fibers[0].params.dump_dict(False, False)) + z_targets = list(self.params.z_targets) + self.z_indices = [(0, j) for j in range(self.params.z_num)] + for i, fiber in enumerate(self.fibers[1:]): + self.fiber_positions.append((fiber.params.name, z_targets[-1])) + z_targets += list(fiber.params.z_targets[1:] + z_targets[-1]) + self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)] + self.params.z_targets = np.array(z_targets) + self.params.length = self.params.z_targets[-1] + self.params.z_num = len(self.params.z_targets) def spectra( self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 ) -> Spectrum: + ... if z_descr is None: - out = [self.spectra(i, sim_ind) for i in range(self.total_num_steps)] + out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices] else: if isinstance(z_descr, (float, np.floating)): - return self.spectra(self.z_ind(z_descr), sim_ind) + fib_ind, z_ind = self.z_ind(z_descr) else: - z_ind = z_descr - if 0 <= z_ind < self.z_inds[0]: - return self.previous.spectra(z_ind, sim_ind) - elif z_ind < 0: - z_ind = self.total_num_steps + z_ind - if sim_ind is None: - out = [self._load_1(z_ind, i) for i in range(self.params.repeat)] - else: - out = self._load_1(z_ind) + fib_ind, z_ind = self.z_indices[z_descr] + out = self.fibers[fib_ind].spectra(z_ind, sim_ind) return Spectrum(out, self.params) - def z_ind(self, pos: float) -> int: - if self.z[0] <= pos <= self.z[-1]: - return self.z_inds[np.argmin(np.abs(self.z - pos))] - elif 0 <= pos < self.z[0]: - return self.previous.z_ind(pos) - else: - raise ValueError(f"cannot match z={pos} with max length of {self.total_length}") - def fields( self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 ) -> Spectrum: return self.params.ifft(self.spectra(z_descr, sim_ind)) - # Plotting + def z_ind(self, pos: float) -> tuple[int, int]: + if 0 <= pos <= self.params.length: + ind = np.argmin(np.abs(self.params.z_targets - pos)) + return self.z_indices[ind] + else: + raise ValueError(f"cannot match z={pos} with max length of {self.params.length}") def plot_2D( self, @@ -344,7 +323,65 @@ class SimulationSeries: spectra.append(tmp) return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1)) - # Private + # Magic methods + + def __iter__(self) -> Iterator[Spectrum]: + for i, j in self.z_indices: + yield self.fibers[i].spectra(j, None) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path})" + + def __eq__(self, other: SimulationSeries) -> bool: + return self.path == other.path and self.params == other.params + + def __contains__(self, fiber: SimulatedFiber) -> bool: + return fiber in self.fibers + + def __getitem__(self, key) -> Spectrum: + if isinstance(key, tuple): + return self.spectra(*key) + else: + return self.spectra(key, None) + + +class SimulatedFiber: + params: Parameters + t: np.ndarray + w: np.ndarray + + def __init__(self, path: os.PathLike): + self.path = Path(path) + self.params = Parameters(**translate_parameters(load_toml(self.path / PARAM_FN))) + self.t = self.params.t + self.w = self.params.w + self.z = self.params.z_targets + + def spectra( + self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0 + ) -> np.ndarray: + if z_descr is None: + out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)] + else: + if isinstance(z_descr, (float, np.floating)): + return self.spectra(self.z_ind(z_descr), sim_ind) + else: + z_ind = z_descr + + if z_ind < 0: + z_ind = self.params.z_num + z_ind + + if sim_ind is None: + out = [self._load_1(z_ind, i) for i in range(self.params.repeat)] + else: + out = self._load_1(z_ind) + return Spectrum(out, self.params) + + def z_ind(self, pos: float) -> int: + if 0 <= pos <= self.z[-1]: + return np.argmin(np.abs(self.z - pos)) + else: + raise ValueError(f"cannot match z={pos} with max length of {self.params.length}") def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray: """loads a spectrum file @@ -362,40 +399,9 @@ class SimulationSeries: loaded spectrum file """ if sim_ind > 0: - return load_spectrum(self.path / SPEC1_FN_N.format(z_ind - self.z_inds[0], sim_ind)) + return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind)) else: - return load_spectrum(self.path / SPEC1_FN.format(z_ind - self.z_inds[0])) - - def _all_params(self, key: str, l: list) -> list: - l.append((self.params.name, getattr(self.params, key))) - if self.previous is not None: - return self.previous._all_params(key, l) - return l - - # Magic methods - - def __iter__(self) -> Iterator[Spectrum]: - for i in range(self.total_num_steps): - yield self.spectra(i, None) + return load_spectrum(self.path / SPEC1_FN.format(z_ind)) def __repr__(self) -> str: - return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})" - - def __eq__(self, other: SimulationSeries) -> bool: - return ( - self.path == other.path - and self.params == other.params - and self.previous == other.previous - ) - - def __contains__(self, other: SimulationSeries) -> bool: - if other is self or other == self: - return True - if self.previous is not None: - return other in self.previous - - def __getitem__(self, key) -> Spectrum: - if isinstance(key, tuple): - return self.spectra(*key) - else: - return self.spectra(key, None) + return f"{self.__class__.__name__}(path={self.path})"