working converter
This commit is contained in:
@@ -27,84 +27,67 @@ def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str,
|
||||
|
||||
|
||||
def convert_sim_folder(path: os.PathLike):
|
||||
path = Path(path)
|
||||
path = Path(path).resolve()
|
||||
config_paths, configs = load_config_sequence(path)
|
||||
master_config = dict(name=path.name, Fiber=configs)
|
||||
with open(path / "initial_config.toml", "w") as f:
|
||||
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
||||
configuration = Configuration(path / "initial_config.toml")
|
||||
new_fiber_paths: list[Path] = [
|
||||
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
|
||||
]
|
||||
for p in new_fiber_paths:
|
||||
p.mkdir(exist_ok=True)
|
||||
repeat = configs[0].get("repeat", 1)
|
||||
|
||||
configuration = Configuration(path / "initial_config.toml", final_output_path=path)
|
||||
pbar = PBars(configuration.total_num_steps, "Converting")
|
||||
|
||||
old_paths: dict[Path, VariationDescriptor] = {
|
||||
path / descr.branch.formatted_descriptor(): (descr, param.final_path)
|
||||
for descr, param in configuration
|
||||
}
|
||||
new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)
|
||||
old_paths: Set[Path] = set()
|
||||
old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = []
|
||||
for descriptor, params in configuration.iterate_single_fiber(-1):
|
||||
old_path = path / descriptor.branch.formatted_descriptor()
|
||||
if not Path(old_path).is_dir():
|
||||
raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.")
|
||||
old_paths.add(old_path)
|
||||
for d in descriptor.iter_parents():
|
||||
z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1])
|
||||
z_limits = (z_num_start, z_num_start + params.z_num)
|
||||
old2new.append((old_path, d, new_paths[d], z_limits))
|
||||
|
||||
# create map from old to new path
|
||||
|
||||
pprint(old_paths)
|
||||
quit()
|
||||
for p in old_paths:
|
||||
if not p.is_dir():
|
||||
raise FileNotFoundError(f"missing {p} from {path}")
|
||||
processed_paths: Set[Path] = set()
|
||||
for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0
|
||||
vary_parts = old_variation_path.name.split("fiber")[1:]
|
||||
identifiers = [
|
||||
"".join("fiber" + el for el in vary_parts[: i + 1]).strip()
|
||||
for i in range(len(vary_parts))
|
||||
]
|
||||
cum_z_num = 0
|
||||
for i, (fiber_path, new_identifier) in enumerate(zip(new_fiber_paths, identifiers)):
|
||||
config = descriptor.update_config(configs[i], i)
|
||||
new_variation_path = fiber_path / new_identifier
|
||||
z_num = config["z_num"]
|
||||
move = new_variation_path not in processed_paths
|
||||
os.makedirs(new_variation_path, exist_ok=True)
|
||||
processed_paths.add(new_variation_path)
|
||||
processed_specs: Set[VariationDescriptor] = set()
|
||||
|
||||
for spec_num in range(cum_z_num, cum_z_num + z_num):
|
||||
old_spec = old_variation_path / SPECN_FN1.format(spec_num)
|
||||
if move:
|
||||
spec_data = np.load(old_spec)
|
||||
for j, spec1 in enumerate(spec_data):
|
||||
if j == 0:
|
||||
np.save(
|
||||
new_variation_path / SPEC1_FN.format(spec_num - cum_z_num), spec1
|
||||
)
|
||||
else:
|
||||
np.save(
|
||||
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
|
||||
spec1,
|
||||
)
|
||||
pbar.update()
|
||||
else:
|
||||
pbar.update(value=repeat)
|
||||
old_spec.unlink()
|
||||
if move:
|
||||
if i > 0:
|
||||
config["prev_data_dir"] = str(
|
||||
(new_fiber_paths[i - 1] / identifiers[i - 1]).resolve()
|
||||
)
|
||||
params = Parameters(**config)
|
||||
params.compute()
|
||||
save_parameters(params.prepare_for_dump(), new_variation_path)
|
||||
cum_z_num += z_num
|
||||
(old_variation_path / PARAM_FN).unlink()
|
||||
(old_variation_path / Z_FN).unlink()
|
||||
old_variation_path.rmdir()
|
||||
for old_path, descr, new_params, (start_z, end_z) in old2new:
|
||||
move_specs = descr not in processed_specs
|
||||
processed_specs.add(descr)
|
||||
if (parent := descr.parent) is not None:
|
||||
new_params.prev_data_dir = str(new_paths[parent].final_path)
|
||||
save_parameters(new_params.prepare_for_dump(), new_params.final_path)
|
||||
for spec_num in range(start_z, end_z):
|
||||
old_spec = old_path / SPECN_FN1.format(spec_num)
|
||||
if move_specs:
|
||||
_mv_specs(pbar, new_params, start_z, spec_num, old_spec)
|
||||
old_spec.unlink()
|
||||
if old_path not in processed_paths:
|
||||
(old_path / PARAM_FN).unlink()
|
||||
(old_path / Z_FN).unlink()
|
||||
processed_paths.add(old_path)
|
||||
|
||||
for old_path in processed_paths:
|
||||
old_path.rmdir()
|
||||
|
||||
for cp in config_paths:
|
||||
cp.unlink()
|
||||
|
||||
|
||||
def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path):
|
||||
os.makedirs(new_params.final_path, exist_ok=True)
|
||||
spec_data = np.load(old_spec)
|
||||
for j, spec1 in enumerate(spec_data):
|
||||
if j == 0:
|
||||
np.save(new_params.final_path / SPEC1_FN.format(spec_num - start_z), spec1)
|
||||
else:
|
||||
np.save(
|
||||
new_params.final_path / SPEC1_FN_N.format(spec_num - start_z, j),
|
||||
spec1,
|
||||
)
|
||||
pbar.update()
|
||||
|
||||
|
||||
def main():
|
||||
convert_sim_folder(sys.argv[1])
|
||||
|
||||
|
||||
@@ -12,7 +12,18 @@ from copy import copy, deepcopy
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from functools import cache, lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib import isin
|
||||
@@ -523,7 +534,7 @@ class Parameters(_AbstractParameters):
|
||||
@property
|
||||
def final_path(self) -> Path:
|
||||
if self.output_path is not None:
|
||||
return update_path(self.output_path)
|
||||
return Path(update_path(self.output_path))
|
||||
return None
|
||||
|
||||
|
||||
@@ -820,22 +831,26 @@ class Configuration:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
final_config_path: os.PathLike,
|
||||
config_path: os.PathLike,
|
||||
overwrite: bool = True,
|
||||
wait: bool = False,
|
||||
skip_callback: Callable[[int], None] = None,
|
||||
final_output_path: os.PathLike = None,
|
||||
):
|
||||
self.logger = get_logger(__name__)
|
||||
self.wait = wait
|
||||
|
||||
self.overwrite = overwrite
|
||||
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
||||
self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
|
||||
self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
|
||||
if final_output_path is not None:
|
||||
self.final_path = final_output_path
|
||||
self.final_path = utils.ensure_folder(
|
||||
Path(env.get(env.OUTPUT_PATH, self.final_path)),
|
||||
Path(self.final_path),
|
||||
mkdir=False,
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
self.master_config = self.fiber_configs[0]
|
||||
self.master_config = self.fiber_configs[0].copy()
|
||||
self.name = self.final_path.name
|
||||
self.z_num = 0
|
||||
self.total_num_steps = 0
|
||||
@@ -874,7 +889,7 @@ class Configuration:
|
||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||
|
||||
def __build_base_config(self):
|
||||
cfg = self.fiber_configs[0].copy()
|
||||
cfg = self.master_config.copy()
|
||||
vary = cfg.pop("variable", {})
|
||||
return cfg | {k: v[0] for k, v in vary.items()}
|
||||
|
||||
@@ -887,15 +902,11 @@ class Configuration:
|
||||
if len(v) == 0:
|
||||
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||
|
||||
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
|
||||
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||
for i in range(self.num_fibers):
|
||||
for sim_config in self.iterate_single_fiber(i):
|
||||
params = Parameters(**sim_config.config)
|
||||
yield sim_config.descriptor, params
|
||||
yield from self.iterate_single_fiber(i)
|
||||
|
||||
def iterate_single_fiber(
|
||||
self, index: int
|
||||
) -> Generator["Configuration.__SimConfig", None, None]:
|
||||
def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
||||
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
||||
|
||||
@@ -909,6 +920,8 @@ class Configuration:
|
||||
__SimConfig
|
||||
configuration obj
|
||||
"""
|
||||
if index < 0:
|
||||
index = self.num_fibers + index
|
||||
sim_dict: dict[Path, Configuration.__SimConfig] = {}
|
||||
for descriptor in self.variationer.iterate(index):
|
||||
cfg = descriptor.update_config(self.fiber_configs[index])
|
||||
@@ -929,7 +942,7 @@ class Configuration:
|
||||
task, config_dict = self.__decide(sim_config)
|
||||
if task == self.Action.RUN:
|
||||
sim_dict.pop(data_dir)
|
||||
yield sim_config
|
||||
yield sim_config.descriptor, Parameters(**sim_config.config)
|
||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||
self.skip_callback(config_dict["recovery_last_stored"])
|
||||
break
|
||||
|
||||
@@ -121,6 +121,19 @@ class VariationDescriptor(BaseModel):
|
||||
_format_registry: dict[str, Callable[..., str]] = {}
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
@classmethod
|
||||
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||
"""register a function that formats a particular parameter
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p_name : str
|
||||
name of the parameter
|
||||
func : Callable[..., str]
|
||||
function that takes as single argument the value of the parameter and returns a string
|
||||
"""
|
||||
cls._format_registry[p_name] = func
|
||||
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
|
||||
@@ -152,19 +165,6 @@ class VariationDescriptor(BaseModel):
|
||||
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||
"""register a function that formats a particular parameter
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p_name : str
|
||||
name of the parameter
|
||||
func : Callable[..., str]
|
||||
function that takes as single argument the value of the parameter and returns a string
|
||||
"""
|
||||
cls._format_registry[p_name] = func
|
||||
|
||||
def format_value(self, name: str, value) -> str:
|
||||
if value is True or value is False:
|
||||
return str(value)
|
||||
@@ -201,9 +201,15 @@ class VariationDescriptor(BaseModel):
|
||||
def __ge__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr >= other.raw_descr
|
||||
|
||||
def __eq__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr == other.raw_descr
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.raw_descr)
|
||||
|
||||
def __contains__(self, other: "VariationDescriptor") -> bool:
|
||||
return all(el in self.raw_descr for el in other.raw_descr)
|
||||
|
||||
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
||||
"""updates a dictionary with the value of the descriptor
|
||||
|
||||
@@ -223,6 +229,11 @@ class VariationDescriptor(BaseModel):
|
||||
out_cfg.pop("variable", None)
|
||||
return out_cfg | {k: v for k, v in self.raw_descr[index]}
|
||||
|
||||
def iter_parents(self) -> Iterator["VariationDescriptor"]:
|
||||
if (p := self.parent) is not None:
|
||||
yield from p.iter_parents()
|
||||
yield self
|
||||
|
||||
@property
|
||||
def flat(self) -> list[tuple[str, Any]]:
|
||||
out = []
|
||||
@@ -260,6 +271,10 @@ class VariationDescriptor(BaseModel):
|
||||
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
|
||||
)
|
||||
|
||||
@property
|
||||
def num_fibers(self) -> int:
|
||||
return len(self.raw_descr)
|
||||
|
||||
|
||||
class BranchDescriptor(VariationDescriptor):
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
|
||||
from typing import Any, Callable, Iterator, Optional, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, DirectoryPath, root_validator
|
||||
|
||||
from . import math
|
||||
from ._utils import load_spectrum
|
||||
from ._utils.parameter import Parameters
|
||||
from ._utils.utils import PlotRange, iter_simulations
|
||||
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1
|
||||
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
|
||||
from .logger import get_logger
|
||||
from .physics import pulse, units
|
||||
from .plotting import (
|
||||
@@ -111,7 +108,7 @@ class Spectrum(np.ndarray):
|
||||
return self.params.l[np.argmax(self.wl_int, axis=-1)]
|
||||
return np.array([s.wl_max for s in self])
|
||||
|
||||
def mask_wl(self, pos: float, width: float) -> "Spectrum":
|
||||
def mask_wl(self, pos: float, width: float) -> Spectrum:
|
||||
return self * np.exp(
|
||||
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
||||
)
|
||||
@@ -353,300 +350,3 @@ class SimulationSeries:
|
||||
return self.spectra(*key)
|
||||
else:
|
||||
return self.spectra(key, None)
|
||||
|
||||
|
||||
class Pulse(Sequence):
|
||||
def __new__(cls, path: os.PathLike):
|
||||
warnings.warn(
|
||||
"You are using the legacy version of the pulse loader. "
|
||||
"Please consider updating your data with scgenerator.convert_sim_folder "
|
||||
"and loading data with the SimulationSeries class"
|
||||
)
|
||||
if (Path(path) / SPECN_FN1.format(0)).exists():
|
||||
return LegacyPulse(path)
|
||||
return SimulationSeries(path)
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LegacyPulse(Sequence):
|
||||
def __init__(self, path: os.PathLike):
|
||||
"""load a data folder as a pulse
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
path to the data (folder containing .npy files)
|
||||
default_ind : int | Iterable[int], optional
|
||||
default indices to be loaded, by default None
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
path does not contain proper data
|
||||
"""
|
||||
self.logger = get_logger(__name__)
|
||||
self.path = Path(path)
|
||||
|
||||
if not self.path.is_dir():
|
||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||
|
||||
self.params = Parameters.load(self.path / "params.toml")
|
||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||
if self.params.fiber_map is None:
|
||||
self.params.fiber_map = [(0.0, self.params.name)]
|
||||
|
||||
try:
|
||||
self.z = np.load(os.path.join(path, "z.npy"))
|
||||
except FileNotFoundError:
|
||||
if self.params is not None:
|
||||
self.z = self.params.z_targets
|
||||
else:
|
||||
raise
|
||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
||||
if self.nmax <= 0:
|
||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||
|
||||
self.t = self.params.t
|
||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||
self.w_order = np.argsort(w)
|
||||
self.w = w
|
||||
self.wl = units.m.inv(self.w)
|
||||
self.params.w = self.w
|
||||
self.params.z_targets = self.z
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
similar to all_spectra but works as an iterator
|
||||
"""
|
||||
|
||||
self.logger.debug(f"iterating through {self.path}")
|
||||
for i in range(self.nmax):
|
||||
yield self._load1(i)
|
||||
|
||||
def __len__(self):
|
||||
return self.nmax
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
return self.all_spectra(key)
|
||||
|
||||
def intensity(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
WL=self._to_wl_int,
|
||||
FREQ=self._to_freq_int,
|
||||
AFREQ=self._to_afreq_int,
|
||||
TIME=self._to_time_int,
|
||||
)[unit.type]
|
||||
|
||||
for spec in self:
|
||||
yield x_axis[order], func(spec)[:, order]
|
||||
|
||||
def _to_wl_int(self, spectrum):
|
||||
return units.to_WL(math.abs2(spectrum), spectrum.wl)
|
||||
|
||||
def _to_freq_int(self, spectrum):
|
||||
return math.abs2(spectrum)
|
||||
|
||||
def _to_afreq_int(self, spectrum):
|
||||
return math.abs2(spectrum)
|
||||
|
||||
def _to_time_int(self, spectrum):
|
||||
return math.abs2(np.fft.ifft(spectrum))
|
||||
|
||||
def amplitude(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
WL=self._to_wl_amp,
|
||||
FREQ=self._to_freq_amp,
|
||||
AFREQ=self._to_afreq_amp,
|
||||
TIME=self._to_time_amp,
|
||||
)[unit.type]
|
||||
|
||||
for spec in self:
|
||||
yield x_axis[order], func(spec)[:, order]
|
||||
|
||||
def _to_wl_amp(self, spectrum):
|
||||
return (
|
||||
np.sqrt(
|
||||
units.to_WL(
|
||||
math.abs2(spectrum),
|
||||
spectrum.wl,
|
||||
)
|
||||
)
|
||||
* spectrum
|
||||
/ np.abs(spectrum)
|
||||
)
|
||||
|
||||
def _to_freq_amp(self, spectrum):
|
||||
return spectrum
|
||||
|
||||
def _to_afreq_amp(self, spectrum):
|
||||
return spectrum
|
||||
|
||||
def _to_time_amp(self, spectrum):
|
||||
return np.fft.ifft(spectrum)
|
||||
|
||||
def all_spectra(self, ind=None) -> Spectrum:
|
||||
"""
|
||||
loads the data already simulated.
|
||||
defauft shape is (z_targets, n, nt)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ind : int or list of int
|
||||
if only certain spectra are desired
|
||||
Returns
|
||||
----------
|
||||
spectra : array of shape (nz, m, nt)
|
||||
array of complex spectra (pulse at nz positions consisting
|
||||
of nm simulation on a nt size grid)
|
||||
"""
|
||||
|
||||
self.logger.debug(f"opening {self.path}")
|
||||
|
||||
# Check if file exists and assert how many z positions there are
|
||||
|
||||
if ind is None:
|
||||
ind = range(self.nmax)
|
||||
if isinstance(ind, (int, np.integer)):
|
||||
ind = [ind]
|
||||
elif isinstance(ind, (float, np.floating)):
|
||||
ind = [self.z_ind(ind)]
|
||||
elif isinstance(ind[0], (float, np.floating)):
|
||||
ind = [self.z_ind(ii) for ii in ind]
|
||||
|
||||
# Load the spectra
|
||||
spectra = []
|
||||
for i in ind:
|
||||
spectra.append(self._load1(i))
|
||||
spectra = Spectrum(spectra, self.params)
|
||||
|
||||
self.logger.debug(f"all spectra from {self.path} successfully loaded")
|
||||
if len(ind) == 1:
|
||||
return spectra[0]
|
||||
else:
|
||||
return spectra
|
||||
|
||||
def all_fields(self, ind=None):
|
||||
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
||||
|
||||
def _load1(self, i: int):
|
||||
if i < 0:
|
||||
i = self.nmax + i
|
||||
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
||||
spec = np.atleast_2d(spec)
|
||||
spec = Spectrum(spec, self.params)
|
||||
return spec
|
||||
|
||||
def plot_2D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
z_pos: Union[int, Iterable[int]] = None,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
||||
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
|
||||
def plot_1D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
z_pos: int,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
||||
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
|
||||
def plot_mean(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
z_pos: int,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, slice(None))
|
||||
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
|
||||
def retrieve_plot_values(
|
||||
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
|
||||
):
|
||||
|
||||
if plot_range.unit.type == "TIME":
|
||||
vals = self.all_fields(ind=z_pos)
|
||||
else:
|
||||
vals = self.all_spectra(ind=z_pos)
|
||||
|
||||
if sim_ind is None:
|
||||
return vals
|
||||
elif z_pos is None:
|
||||
return vals[:, sim_ind]
|
||||
else:
|
||||
return vals[sim_ind]
|
||||
|
||||
def rin_propagation(
|
||||
self, left: float, right: float, unit: str
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""returns the RIN as function of unit and z
|
||||
|
||||
Parameters
|
||||
----------
|
||||
left : float
|
||||
left limit in unit
|
||||
right : float
|
||||
right limit in unit
|
||||
unit : str
|
||||
unit descriptor
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : np.ndarray, shape (nt,)
|
||||
x axis
|
||||
y : np.ndarray, shape (z_num, )
|
||||
y axis
|
||||
rin_prop : np.ndarray, shape (z_num, nt)
|
||||
RIN
|
||||
"""
|
||||
spectra = []
|
||||
for spec in np.moveaxis(self.all_spectra(), 1, 0):
|
||||
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
|
||||
spectra.append(tmp)
|
||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
||||
|
||||
def z_ind(self, z: float) -> int:
|
||||
"""return the closest z index to the given target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : float
|
||||
target
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
index
|
||||
"""
|
||||
return math.argclosest(self.z, z)
|
||||
|
||||
Reference in New Issue
Block a user