working converter
This commit is contained in:
@@ -27,84 +27,67 @@ def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str,
|
|||||||
|
|
||||||
|
|
||||||
def convert_sim_folder(path: os.PathLike):
|
def convert_sim_folder(path: os.PathLike):
|
||||||
path = Path(path)
|
path = Path(path).resolve()
|
||||||
config_paths, configs = load_config_sequence(path)
|
config_paths, configs = load_config_sequence(path)
|
||||||
master_config = dict(name=path.name, Fiber=configs)
|
master_config = dict(name=path.name, Fiber=configs)
|
||||||
with open(path / "initial_config.toml", "w") as f:
|
with open(path / "initial_config.toml", "w") as f:
|
||||||
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
||||||
configuration = Configuration(path / "initial_config.toml")
|
configuration = Configuration(path / "initial_config.toml", final_output_path=path)
|
||||||
new_fiber_paths: list[Path] = [
|
|
||||||
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
|
|
||||||
]
|
|
||||||
for p in new_fiber_paths:
|
|
||||||
p.mkdir(exist_ok=True)
|
|
||||||
repeat = configs[0].get("repeat", 1)
|
|
||||||
|
|
||||||
pbar = PBars(configuration.total_num_steps, "Converting")
|
pbar = PBars(configuration.total_num_steps, "Converting")
|
||||||
|
|
||||||
old_paths: dict[Path, VariationDescriptor] = {
|
new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)
|
||||||
path / descr.branch.formatted_descriptor(): (descr, param.final_path)
|
old_paths: Set[Path] = set()
|
||||||
for descr, param in configuration
|
old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = []
|
||||||
}
|
for descriptor, params in configuration.iterate_single_fiber(-1):
|
||||||
|
old_path = path / descriptor.branch.formatted_descriptor()
|
||||||
|
if not Path(old_path).is_dir():
|
||||||
|
raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.")
|
||||||
|
old_paths.add(old_path)
|
||||||
|
for d in descriptor.iter_parents():
|
||||||
|
z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1])
|
||||||
|
z_limits = (z_num_start, z_num_start + params.z_num)
|
||||||
|
old2new.append((old_path, d, new_paths[d], z_limits))
|
||||||
|
|
||||||
# create map from old to new path
|
|
||||||
|
|
||||||
pprint(old_paths)
|
|
||||||
quit()
|
|
||||||
for p in old_paths:
|
|
||||||
if not p.is_dir():
|
|
||||||
raise FileNotFoundError(f"missing {p} from {path}")
|
|
||||||
processed_paths: Set[Path] = set()
|
processed_paths: Set[Path] = set()
|
||||||
for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0
|
processed_specs: Set[VariationDescriptor] = set()
|
||||||
vary_parts = old_variation_path.name.split("fiber")[1:]
|
|
||||||
identifiers = [
|
|
||||||
"".join("fiber" + el for el in vary_parts[: i + 1]).strip()
|
|
||||||
for i in range(len(vary_parts))
|
|
||||||
]
|
|
||||||
cum_z_num = 0
|
|
||||||
for i, (fiber_path, new_identifier) in enumerate(zip(new_fiber_paths, identifiers)):
|
|
||||||
config = descriptor.update_config(configs[i], i)
|
|
||||||
new_variation_path = fiber_path / new_identifier
|
|
||||||
z_num = config["z_num"]
|
|
||||||
move = new_variation_path not in processed_paths
|
|
||||||
os.makedirs(new_variation_path, exist_ok=True)
|
|
||||||
processed_paths.add(new_variation_path)
|
|
||||||
|
|
||||||
for spec_num in range(cum_z_num, cum_z_num + z_num):
|
for old_path, descr, new_params, (start_z, end_z) in old2new:
|
||||||
old_spec = old_variation_path / SPECN_FN1.format(spec_num)
|
move_specs = descr not in processed_specs
|
||||||
if move:
|
processed_specs.add(descr)
|
||||||
spec_data = np.load(old_spec)
|
if (parent := descr.parent) is not None:
|
||||||
for j, spec1 in enumerate(spec_data):
|
new_params.prev_data_dir = str(new_paths[parent].final_path)
|
||||||
if j == 0:
|
save_parameters(new_params.prepare_for_dump(), new_params.final_path)
|
||||||
np.save(
|
for spec_num in range(start_z, end_z):
|
||||||
new_variation_path / SPEC1_FN.format(spec_num - cum_z_num), spec1
|
old_spec = old_path / SPECN_FN1.format(spec_num)
|
||||||
)
|
if move_specs:
|
||||||
else:
|
_mv_specs(pbar, new_params, start_z, spec_num, old_spec)
|
||||||
np.save(
|
|
||||||
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
|
|
||||||
spec1,
|
|
||||||
)
|
|
||||||
pbar.update()
|
|
||||||
else:
|
|
||||||
pbar.update(value=repeat)
|
|
||||||
old_spec.unlink()
|
old_spec.unlink()
|
||||||
if move:
|
if old_path not in processed_paths:
|
||||||
if i > 0:
|
(old_path / PARAM_FN).unlink()
|
||||||
config["prev_data_dir"] = str(
|
(old_path / Z_FN).unlink()
|
||||||
(new_fiber_paths[i - 1] / identifiers[i - 1]).resolve()
|
processed_paths.add(old_path)
|
||||||
)
|
|
||||||
params = Parameters(**config)
|
for old_path in processed_paths:
|
||||||
params.compute()
|
old_path.rmdir()
|
||||||
save_parameters(params.prepare_for_dump(), new_variation_path)
|
|
||||||
cum_z_num += z_num
|
|
||||||
(old_variation_path / PARAM_FN).unlink()
|
|
||||||
(old_variation_path / Z_FN).unlink()
|
|
||||||
old_variation_path.rmdir()
|
|
||||||
|
|
||||||
for cp in config_paths:
|
for cp in config_paths:
|
||||||
cp.unlink()
|
cp.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path):
|
||||||
|
os.makedirs(new_params.final_path, exist_ok=True)
|
||||||
|
spec_data = np.load(old_spec)
|
||||||
|
for j, spec1 in enumerate(spec_data):
|
||||||
|
if j == 0:
|
||||||
|
np.save(new_params.final_path / SPEC1_FN.format(spec_num - start_z), spec1)
|
||||||
|
else:
|
||||||
|
np.save(
|
||||||
|
new_params.final_path / SPEC1_FN_N.format(spec_num - start_z, j),
|
||||||
|
spec1,
|
||||||
|
)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
convert_sim_folder(sys.argv[1])
|
convert_sim_folder(sys.argv[1])
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,18 @@ from copy import copy, deepcopy
|
|||||||
from dataclasses import asdict, dataclass, fields
|
from dataclasses import asdict, dataclass, fields
|
||||||
from functools import cache, lru_cache
|
from functools import cache, lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generator,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.lib import isin
|
from numpy.lib import isin
|
||||||
@@ -523,7 +534,7 @@ class Parameters(_AbstractParameters):
|
|||||||
@property
|
@property
|
||||||
def final_path(self) -> Path:
|
def final_path(self) -> Path:
|
||||||
if self.output_path is not None:
|
if self.output_path is not None:
|
||||||
return update_path(self.output_path)
|
return Path(update_path(self.output_path))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -820,22 +831,26 @@ class Configuration:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
final_config_path: os.PathLike,
|
config_path: os.PathLike,
|
||||||
overwrite: bool = True,
|
overwrite: bool = True,
|
||||||
wait: bool = False,
|
wait: bool = False,
|
||||||
skip_callback: Callable[[int], None] = None,
|
skip_callback: Callable[[int], None] = None,
|
||||||
|
final_output_path: os.PathLike = None,
|
||||||
):
|
):
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.wait = wait
|
self.wait = wait
|
||||||
|
|
||||||
self.overwrite = overwrite
|
self.overwrite = overwrite
|
||||||
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
|
||||||
|
self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
|
||||||
|
if final_output_path is not None:
|
||||||
|
self.final_path = final_output_path
|
||||||
self.final_path = utils.ensure_folder(
|
self.final_path = utils.ensure_folder(
|
||||||
Path(env.get(env.OUTPUT_PATH, self.final_path)),
|
Path(self.final_path),
|
||||||
mkdir=False,
|
mkdir=False,
|
||||||
prevent_overwrite=not self.overwrite,
|
prevent_overwrite=not self.overwrite,
|
||||||
)
|
)
|
||||||
self.master_config = self.fiber_configs[0]
|
self.master_config = self.fiber_configs[0].copy()
|
||||||
self.name = self.final_path.name
|
self.name = self.final_path.name
|
||||||
self.z_num = 0
|
self.z_num = 0
|
||||||
self.total_num_steps = 0
|
self.total_num_steps = 0
|
||||||
@@ -874,7 +889,7 @@ class Configuration:
|
|||||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||||
|
|
||||||
def __build_base_config(self):
|
def __build_base_config(self):
|
||||||
cfg = self.fiber_configs[0].copy()
|
cfg = self.master_config.copy()
|
||||||
vary = cfg.pop("variable", {})
|
vary = cfg.pop("variable", {})
|
||||||
return cfg | {k: v[0] for k, v in vary.items()}
|
return cfg | {k: v[0] for k, v in vary.items()}
|
||||||
|
|
||||||
@@ -887,15 +902,11 @@ class Configuration:
|
|||||||
if len(v) == 0:
|
if len(v) == 0:
|
||||||
raise ValueError(f"variable parameter {k!r} must not be empty")
|
raise ValueError(f"variable parameter {k!r} must not be empty")
|
||||||
|
|
||||||
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
|
def __iter__(self) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||||
for i in range(self.num_fibers):
|
for i in range(self.num_fibers):
|
||||||
for sim_config in self.iterate_single_fiber(i):
|
yield from self.iterate_single_fiber(i)
|
||||||
params = Parameters(**sim_config.config)
|
|
||||||
yield sim_config.descriptor, params
|
|
||||||
|
|
||||||
def iterate_single_fiber(
|
def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
|
||||||
self, index: int
|
|
||||||
) -> Generator["Configuration.__SimConfig", None, None]:
|
|
||||||
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
"""iterates through the parameters of only one fiber. It takes care of recovering partially
|
||||||
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
completed simulations, skipping complete ones and waiting for the previous fiber to finish
|
||||||
|
|
||||||
@@ -909,6 +920,8 @@ class Configuration:
|
|||||||
__SimConfig
|
__SimConfig
|
||||||
configuration obj
|
configuration obj
|
||||||
"""
|
"""
|
||||||
|
if index < 0:
|
||||||
|
index = self.num_fibers + index
|
||||||
sim_dict: dict[Path, Configuration.__SimConfig] = {}
|
sim_dict: dict[Path, Configuration.__SimConfig] = {}
|
||||||
for descriptor in self.variationer.iterate(index):
|
for descriptor in self.variationer.iterate(index):
|
||||||
cfg = descriptor.update_config(self.fiber_configs[index])
|
cfg = descriptor.update_config(self.fiber_configs[index])
|
||||||
@@ -929,7 +942,7 @@ class Configuration:
|
|||||||
task, config_dict = self.__decide(sim_config)
|
task, config_dict = self.__decide(sim_config)
|
||||||
if task == self.Action.RUN:
|
if task == self.Action.RUN:
|
||||||
sim_dict.pop(data_dir)
|
sim_dict.pop(data_dir)
|
||||||
yield sim_config
|
yield sim_config.descriptor, Parameters(**sim_config.config)
|
||||||
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
if "recovery_last_stored" in config_dict and self.skip_callback is not None:
|
||||||
self.skip_callback(config_dict["recovery_last_stored"])
|
self.skip_callback(config_dict["recovery_last_stored"])
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -121,6 +121,19 @@ class VariationDescriptor(BaseModel):
|
|||||||
_format_registry: dict[str, Callable[..., str]] = {}
|
_format_registry: dict[str, Callable[..., str]] = {}
|
||||||
__ids: dict[int, int] = {}
|
__ids: dict[int, int] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
||||||
|
"""register a function that formats a particular parameter
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
p_name : str
|
||||||
|
name of the parameter
|
||||||
|
func : Callable[..., str]
|
||||||
|
function that takes as single argument the value of the parameter and returns a string
|
||||||
|
"""
|
||||||
|
cls._format_registry[p_name] = func
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
allow_mutation = False
|
allow_mutation = False
|
||||||
|
|
||||||
@@ -152,19 +165,6 @@ class VariationDescriptor(BaseModel):
|
|||||||
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_formatter(cls, p_name: str, func: Callable[..., str]):
|
|
||||||
"""register a function that formats a particular parameter
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
p_name : str
|
|
||||||
name of the parameter
|
|
||||||
func : Callable[..., str]
|
|
||||||
function that takes as single argument the value of the parameter and returns a string
|
|
||||||
"""
|
|
||||||
cls._format_registry[p_name] = func
|
|
||||||
|
|
||||||
def format_value(self, name: str, value) -> str:
|
def format_value(self, name: str, value) -> str:
|
||||||
if value is True or value is False:
|
if value is True or value is False:
|
||||||
return str(value)
|
return str(value)
|
||||||
@@ -201,9 +201,15 @@ class VariationDescriptor(BaseModel):
|
|||||||
def __ge__(self, other: "VariationDescriptor") -> bool:
|
def __ge__(self, other: "VariationDescriptor") -> bool:
|
||||||
return self.raw_descr >= other.raw_descr
|
return self.raw_descr >= other.raw_descr
|
||||||
|
|
||||||
|
def __eq__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return self.raw_descr == other.raw_descr
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(self.raw_descr)
|
return hash(self.raw_descr)
|
||||||
|
|
||||||
|
def __contains__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return all(el in self.raw_descr for el in other.raw_descr)
|
||||||
|
|
||||||
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
||||||
"""updates a dictionary with the value of the descriptor
|
"""updates a dictionary with the value of the descriptor
|
||||||
|
|
||||||
@@ -223,6 +229,11 @@ class VariationDescriptor(BaseModel):
|
|||||||
out_cfg.pop("variable", None)
|
out_cfg.pop("variable", None)
|
||||||
return out_cfg | {k: v for k, v in self.raw_descr[index]}
|
return out_cfg | {k: v for k, v in self.raw_descr[index]}
|
||||||
|
|
||||||
|
def iter_parents(self) -> Iterator["VariationDescriptor"]:
|
||||||
|
if (p := self.parent) is not None:
|
||||||
|
yield from p.iter_parents()
|
||||||
|
yield self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flat(self) -> list[tuple[str, Any]]:
|
def flat(self) -> list[tuple[str, Any]]:
|
||||||
out = []
|
out = []
|
||||||
@@ -260,6 +271,10 @@ class VariationDescriptor(BaseModel):
|
|||||||
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
|
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_fibers(self) -> int:
|
||||||
|
return len(self.raw_descr)
|
||||||
|
|
||||||
|
|
||||||
class BranchDescriptor(VariationDescriptor):
|
class BranchDescriptor(VariationDescriptor):
|
||||||
__ids: dict[int, int] = {}
|
__ids: dict[int, int] = {}
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
|
from typing import Any, Callable, Iterator, Optional, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, DirectoryPath, root_validator
|
|
||||||
|
|
||||||
from . import math
|
from . import math
|
||||||
from ._utils import load_spectrum
|
from ._utils import load_spectrum
|
||||||
from ._utils.parameter import Parameters
|
from ._utils.parameter import Parameters
|
||||||
from ._utils.utils import PlotRange, iter_simulations
|
from ._utils.utils import PlotRange, iter_simulations
|
||||||
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1
|
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .physics import pulse, units
|
from .physics import pulse, units
|
||||||
from .plotting import (
|
from .plotting import (
|
||||||
@@ -111,7 +108,7 @@ class Spectrum(np.ndarray):
|
|||||||
return self.params.l[np.argmax(self.wl_int, axis=-1)]
|
return self.params.l[np.argmax(self.wl_int, axis=-1)]
|
||||||
return np.array([s.wl_max for s in self])
|
return np.array([s.wl_max for s in self])
|
||||||
|
|
||||||
def mask_wl(self, pos: float, width: float) -> "Spectrum":
|
def mask_wl(self, pos: float, width: float) -> Spectrum:
|
||||||
return self * np.exp(
|
return self * np.exp(
|
||||||
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
||||||
)
|
)
|
||||||
@@ -353,300 +350,3 @@ class SimulationSeries:
|
|||||||
return self.spectra(*key)
|
return self.spectra(*key)
|
||||||
else:
|
else:
|
||||||
return self.spectra(key, None)
|
return self.spectra(key, None)
|
||||||
|
|
||||||
|
|
||||||
class Pulse(Sequence):
|
|
||||||
def __new__(cls, path: os.PathLike):
|
|
||||||
warnings.warn(
|
|
||||||
"You are using the legacy version of the pulse loader. "
|
|
||||||
"Please consider updating your data with scgenerator.convert_sim_folder "
|
|
||||||
"and loading data with the SimulationSeries class"
|
|
||||||
)
|
|
||||||
if (Path(path) / SPECN_FN1.format(0)).exists():
|
|
||||||
return LegacyPulse(path)
|
|
||||||
return SimulationSeries(path)
|
|
||||||
|
|
||||||
def __getitem__(self, key) -> Spectrum:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class LegacyPulse(Sequence):
|
|
||||||
def __init__(self, path: os.PathLike):
|
|
||||||
"""load a data folder as a pulse
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : os.PathLike
|
|
||||||
path to the data (folder containing .npy files)
|
|
||||||
default_ind : int | Iterable[int], optional
|
|
||||||
default indices to be loaded, by default None
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
path does not contain proper data
|
|
||||||
"""
|
|
||||||
self.logger = get_logger(__name__)
|
|
||||||
self.path = Path(path)
|
|
||||||
|
|
||||||
if not self.path.is_dir():
|
|
||||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
|
||||||
|
|
||||||
self.params = Parameters.load(self.path / "params.toml")
|
|
||||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
|
||||||
if self.params.fiber_map is None:
|
|
||||||
self.params.fiber_map = [(0.0, self.params.name)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.z = np.load(os.path.join(path, "z.npy"))
|
|
||||||
except FileNotFoundError:
|
|
||||||
if self.params is not None:
|
|
||||||
self.z = self.params.z_targets
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
|
||||||
if self.nmax <= 0:
|
|
||||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
|
||||||
|
|
||||||
self.t = self.params.t
|
|
||||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
|
||||||
self.w_order = np.argsort(w)
|
|
||||||
self.w = w
|
|
||||||
self.wl = units.m.inv(self.w)
|
|
||||||
self.params.w = self.w
|
|
||||||
self.params.z_targets = self.z
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
"""
|
|
||||||
similar to all_spectra but works as an iterator
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.logger.debug(f"iterating through {self.path}")
|
|
||||||
for i in range(self.nmax):
|
|
||||||
yield self._load1(i)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.nmax
|
|
||||||
|
|
||||||
def __getitem__(self, key) -> Spectrum:
|
|
||||||
return self.all_spectra(key)
|
|
||||||
|
|
||||||
def intensity(self, unit):
|
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
|
||||||
x_axis = unit.inv(self.w)
|
|
||||||
else:
|
|
||||||
x_axis = unit.inv(self.t)
|
|
||||||
|
|
||||||
order = np.argsort(x_axis)
|
|
||||||
func = dict(
|
|
||||||
WL=self._to_wl_int,
|
|
||||||
FREQ=self._to_freq_int,
|
|
||||||
AFREQ=self._to_afreq_int,
|
|
||||||
TIME=self._to_time_int,
|
|
||||||
)[unit.type]
|
|
||||||
|
|
||||||
for spec in self:
|
|
||||||
yield x_axis[order], func(spec)[:, order]
|
|
||||||
|
|
||||||
def _to_wl_int(self, spectrum):
|
|
||||||
return units.to_WL(math.abs2(spectrum), spectrum.wl)
|
|
||||||
|
|
||||||
def _to_freq_int(self, spectrum):
|
|
||||||
return math.abs2(spectrum)
|
|
||||||
|
|
||||||
def _to_afreq_int(self, spectrum):
|
|
||||||
return math.abs2(spectrum)
|
|
||||||
|
|
||||||
def _to_time_int(self, spectrum):
|
|
||||||
return math.abs2(np.fft.ifft(spectrum))
|
|
||||||
|
|
||||||
def amplitude(self, unit):
|
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
|
||||||
x_axis = unit.inv(self.w)
|
|
||||||
else:
|
|
||||||
x_axis = unit.inv(self.t)
|
|
||||||
|
|
||||||
order = np.argsort(x_axis)
|
|
||||||
func = dict(
|
|
||||||
WL=self._to_wl_amp,
|
|
||||||
FREQ=self._to_freq_amp,
|
|
||||||
AFREQ=self._to_afreq_amp,
|
|
||||||
TIME=self._to_time_amp,
|
|
||||||
)[unit.type]
|
|
||||||
|
|
||||||
for spec in self:
|
|
||||||
yield x_axis[order], func(spec)[:, order]
|
|
||||||
|
|
||||||
def _to_wl_amp(self, spectrum):
|
|
||||||
return (
|
|
||||||
np.sqrt(
|
|
||||||
units.to_WL(
|
|
||||||
math.abs2(spectrum),
|
|
||||||
spectrum.wl,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* spectrum
|
|
||||||
/ np.abs(spectrum)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _to_freq_amp(self, spectrum):
|
|
||||||
return spectrum
|
|
||||||
|
|
||||||
def _to_afreq_amp(self, spectrum):
|
|
||||||
return spectrum
|
|
||||||
|
|
||||||
def _to_time_amp(self, spectrum):
|
|
||||||
return np.fft.ifft(spectrum)
|
|
||||||
|
|
||||||
def all_spectra(self, ind=None) -> Spectrum:
|
|
||||||
"""
|
|
||||||
loads the data already simulated.
|
|
||||||
defauft shape is (z_targets, n, nt)
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
ind : int or list of int
|
|
||||||
if only certain spectra are desired
|
|
||||||
Returns
|
|
||||||
----------
|
|
||||||
spectra : array of shape (nz, m, nt)
|
|
||||||
array of complex spectra (pulse at nz positions consisting
|
|
||||||
of nm simulation on a nt size grid)
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.logger.debug(f"opening {self.path}")
|
|
||||||
|
|
||||||
# Check if file exists and assert how many z positions there are
|
|
||||||
|
|
||||||
if ind is None:
|
|
||||||
ind = range(self.nmax)
|
|
||||||
if isinstance(ind, (int, np.integer)):
|
|
||||||
ind = [ind]
|
|
||||||
elif isinstance(ind, (float, np.floating)):
|
|
||||||
ind = [self.z_ind(ind)]
|
|
||||||
elif isinstance(ind[0], (float, np.floating)):
|
|
||||||
ind = [self.z_ind(ii) for ii in ind]
|
|
||||||
|
|
||||||
# Load the spectra
|
|
||||||
spectra = []
|
|
||||||
for i in ind:
|
|
||||||
spectra.append(self._load1(i))
|
|
||||||
spectra = Spectrum(spectra, self.params)
|
|
||||||
|
|
||||||
self.logger.debug(f"all spectra from {self.path} successfully loaded")
|
|
||||||
if len(ind) == 1:
|
|
||||||
return spectra[0]
|
|
||||||
else:
|
|
||||||
return spectra
|
|
||||||
|
|
||||||
def all_fields(self, ind=None):
|
|
||||||
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
|
||||||
|
|
||||||
def _load1(self, i: int):
|
|
||||||
if i < 0:
|
|
||||||
i = self.nmax + i
|
|
||||||
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
|
||||||
spec = np.atleast_2d(spec)
|
|
||||||
spec = Spectrum(spec, self.params)
|
|
||||||
return spec
|
|
||||||
|
|
||||||
def plot_2D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
z_pos: Union[int, Iterable[int]] = None,
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
|
||||||
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def plot_1D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
z_pos: int,
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
|
||||||
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def plot_mean(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
z_pos: int,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, slice(None))
|
|
||||||
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def retrieve_plot_values(
|
|
||||||
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
|
|
||||||
):
|
|
||||||
|
|
||||||
if plot_range.unit.type == "TIME":
|
|
||||||
vals = self.all_fields(ind=z_pos)
|
|
||||||
else:
|
|
||||||
vals = self.all_spectra(ind=z_pos)
|
|
||||||
|
|
||||||
if sim_ind is None:
|
|
||||||
return vals
|
|
||||||
elif z_pos is None:
|
|
||||||
return vals[:, sim_ind]
|
|
||||||
else:
|
|
||||||
return vals[sim_ind]
|
|
||||||
|
|
||||||
def rin_propagation(
|
|
||||||
self, left: float, right: float, unit: str
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
"""returns the RIN as function of unit and z
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
left : float
|
|
||||||
left limit in unit
|
|
||||||
right : float
|
|
||||||
right limit in unit
|
|
||||||
unit : str
|
|
||||||
unit descriptor
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
x : np.ndarray, shape (nt,)
|
|
||||||
x axis
|
|
||||||
y : np.ndarray, shape (z_num, )
|
|
||||||
y axis
|
|
||||||
rin_prop : np.ndarray, shape (z_num, nt)
|
|
||||||
RIN
|
|
||||||
"""
|
|
||||||
spectra = []
|
|
||||||
for spec in np.moveaxis(self.all_spectra(), 1, 0):
|
|
||||||
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
|
|
||||||
spectra.append(tmp)
|
|
||||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
|
||||||
|
|
||||||
def z_ind(self, z: float) -> int:
|
|
||||||
"""return the closest z index to the given target
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
z : float
|
|
||||||
target
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
int
|
|
||||||
index
|
|
||||||
"""
|
|
||||||
return math.argclosest(self.z, z)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user