file name still wrong in converter
This commit is contained in:
@@ -2,9 +2,24 @@ from . import math
|
||||
from .math import abs2, argclosest, span
|
||||
from .physics import fiber, materials, pulse, simulate, units
|
||||
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
|
||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||
from .plotting import (
|
||||
mean_values_plot,
|
||||
plot_spectrogram,
|
||||
propagation_plot,
|
||||
single_position_plot,
|
||||
transform_2D_propagation,
|
||||
transform_1D_values,
|
||||
transform_mean_values,
|
||||
get_extent,
|
||||
)
|
||||
from .spectra import Pulse, Spectrum, SimulationSeries
|
||||
from ._utils import Paths, open_config, parameter
|
||||
from ._utils.parameter import Configuration, Parameters
|
||||
from ._utils.utils import PlotRange
|
||||
from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError
|
||||
from ._utils.legacy import convert_sim_folder
|
||||
from ._utils.variationer import (
|
||||
Variationer,
|
||||
VariationDescriptor,
|
||||
VariationSpecsError,
|
||||
DescriptorDict,
|
||||
)
|
||||
|
||||
@@ -8,8 +8,9 @@ import numpy as np
|
||||
import toml
|
||||
|
||||
from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
|
||||
from .parameter import Parameters
|
||||
from .utils import fiber_folder, update_path, save_parameters
|
||||
from .parameter import Configuration, Parameters
|
||||
from .utils import fiber_folder, save_parameters
|
||||
from .pbar import PBars
|
||||
from .variationer import VariationDescriptor, Variationer
|
||||
|
||||
|
||||
@@ -29,21 +30,32 @@ def convert_sim_folder(path: os.PathLike):
|
||||
path = Path(path)
|
||||
config_paths, configs = load_config_sequence(path)
|
||||
master_config = dict(name=path.name, Fiber=configs)
|
||||
with open(path / "initial_config.toml", "w") as f:
|
||||
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
||||
configuration = Configuration(path / "initial_config.toml")
|
||||
new_fiber_paths: list[Path] = [
|
||||
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
|
||||
]
|
||||
for p in new_fiber_paths:
|
||||
p.mkdir(exist_ok=True)
|
||||
var = Variationer(c["variable"] for c in configs)
|
||||
repeat = configs[0].get("repeat", 1)
|
||||
|
||||
paths: dict[Path, VariationDescriptor] = {
|
||||
path / descr.branch.formatted_descriptor(): descr for descr in var.iterate()
|
||||
pbar = PBars(configuration.total_num_steps, "Converting")
|
||||
|
||||
old_paths: dict[Path, VariationDescriptor] = {
|
||||
path / descr.branch.formatted_descriptor(): (descr, param.final_path)
|
||||
for descr, param in configuration
|
||||
}
|
||||
for p in paths:
|
||||
|
||||
# create map from old to new path
|
||||
|
||||
pprint(old_paths)
|
||||
quit()
|
||||
for p in old_paths:
|
||||
if not p.is_dir():
|
||||
raise FileNotFoundError(f"missing {p} from {path}")
|
||||
processed_paths: Set[Path] = set()
|
||||
for old_variation_path, descriptor in paths.items(): # fiberA=0, fiber B=0
|
||||
for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0
|
||||
vary_parts = old_variation_path.name.split("fiber")[1:]
|
||||
identifiers = [
|
||||
"".join("fiber" + el for el in vary_parts[: i + 1]).strip()
|
||||
@@ -72,6 +84,9 @@ def convert_sim_folder(path: os.PathLike):
|
||||
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
|
||||
spec1,
|
||||
)
|
||||
pbar.update()
|
||||
else:
|
||||
pbar.update(value=repeat)
|
||||
old_spec.unlink()
|
||||
if move:
|
||||
if i > 0:
|
||||
@@ -88,8 +103,6 @@ def convert_sim_folder(path: os.PathLike):
|
||||
|
||||
for cp in config_paths:
|
||||
cp.unlink()
|
||||
with open(path / "initial_config.toml", "w") as f:
|
||||
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
|
||||
from ..errors import EvaluatorError, NoDefaultError
|
||||
from ..logger import get_logger
|
||||
from ..physics import fiber, materials, pulse, units
|
||||
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names
|
||||
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -75,6 +75,7 @@ VALID_VARIABLE = {
|
||||
"interpolation_degree",
|
||||
"ideal_gas",
|
||||
"length",
|
||||
"num",
|
||||
}
|
||||
|
||||
MANDATORY_PARAMETERS = [
|
||||
@@ -519,6 +520,12 @@ class Parameters(_AbstractParameters):
|
||||
|
||||
return out
|
||||
|
||||
@property
|
||||
def final_path(self) -> Path:
|
||||
if self.output_path is not None:
|
||||
return update_path(self.output_path)
|
||||
return None
|
||||
|
||||
|
||||
class Rule:
|
||||
def __init__(
|
||||
@@ -777,6 +784,7 @@ class Configuration:
|
||||
"""
|
||||
|
||||
fiber_configs: list[dict[str, Any]]
|
||||
vary_dicts: list[dict[str, list]]
|
||||
master_config: dict[str, Any]
|
||||
fiber_paths: list[Path]
|
||||
num_sim: int
|
||||
@@ -814,9 +822,11 @@ class Configuration:
|
||||
self,
|
||||
final_config_path: os.PathLike,
|
||||
overwrite: bool = True,
|
||||
wait: bool = False,
|
||||
skip_callback: Callable[[int], None] = None,
|
||||
):
|
||||
self.logger = get_logger(__name__)
|
||||
self.wait = wait
|
||||
|
||||
self.overwrite = overwrite
|
||||
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
||||
@@ -842,7 +852,8 @@ class Configuration:
|
||||
config.setdefault("name", Parameters.name.default)
|
||||
self.z_num += config["z_num"]
|
||||
fiber_names.add(config["name"])
|
||||
self.variationer.append(config.pop("variable"))
|
||||
vary_dict = config.pop("variable")
|
||||
self.variationer.append(vary_dict)
|
||||
self.fiber_paths.append(
|
||||
utils.ensure_folder(
|
||||
self.final_path / fiber_folder(i, self.name, config["name"]),
|
||||
@@ -850,9 +861,11 @@ class Configuration:
|
||||
prevent_overwrite=not self.overwrite,
|
||||
)
|
||||
)
|
||||
self.__validate_variable(config)
|
||||
self.__validate_variable(vary_dict)
|
||||
self.num_fibers += 1
|
||||
Evaluator.evaluate_default(config, True)
|
||||
Evaluator.evaluate_default(
|
||||
self.__build_base_config() | config | {k: v[0] for k, v in vary_dict.items()}, True
|
||||
)
|
||||
self.num_sim = self.variationer.var_num()
|
||||
self.total_num_steps = sum(
|
||||
config["z_num"] * self.variationer.var_num(i)
|
||||
@@ -860,8 +873,13 @@ class Configuration:
|
||||
)
|
||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||
|
||||
def __validate_variable(self, config: dict[str, Any]):
|
||||
for k, v in config.get("variable", {}).items():
|
||||
def __build_base_config(self):
|
||||
cfg = self.fiber_configs[0].copy()
|
||||
vary = cfg.pop("variable", {})
|
||||
return cfg | {k: v[0] for k, v in vary.items()}
|
||||
|
||||
def __validate_variable(self, vary_dict: dict[str, list]):
|
||||
for k, v in vary_dict.items():
|
||||
p = getattr(Parameters, k)
|
||||
validator_list(p.validator)("variable " + k, v)
|
||||
if k not in VALID_VARIABLE:
|
||||
@@ -873,7 +891,6 @@ class Configuration:
|
||||
for i in range(self.num_fibers):
|
||||
for sim_config in self.iterate_single_fiber(i):
|
||||
params = Parameters(**sim_config.config)
|
||||
params.compute()
|
||||
yield sim_config.descriptor, params
|
||||
|
||||
def iterate_single_fiber(
|
||||
@@ -943,6 +960,8 @@ class Configuration:
|
||||
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
||||
gets set if the simulation is partially completed
|
||||
"""
|
||||
if not self.wait:
|
||||
return self.Action.RUN, sim_config.config
|
||||
out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
|
||||
if out_status == self.State.COMPLETE:
|
||||
return self.Action.SKIP, sim_config.config
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import abc
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import multiprocessing
|
||||
import threading
|
||||
import typing
|
||||
from collections import abc
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union
|
||||
@@ -24,7 +24,19 @@ class PBars:
|
||||
head_kwargs=None,
|
||||
worker_kwargs=None,
|
||||
) -> "PBars":
|
||||
"""creates a PBars obj
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : int | Iterable
|
||||
if int : total length of the main task
|
||||
if Iterable : behaves like tqdm
|
||||
desc : str
|
||||
description of the main task
|
||||
num_sub_bars : int
|
||||
number of sub-tasks
|
||||
|
||||
"""
|
||||
self.id = random.randint(100000, 999999)
|
||||
try:
|
||||
self.width = os.get_terminal_size().columns
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from string import printable as str_printable
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Iterator, Set
|
||||
|
||||
import numpy as np
|
||||
import toml
|
||||
@@ -236,3 +236,24 @@ def update_path(p: str) -> str:
|
||||
|
||||
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
|
||||
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
|
||||
|
||||
|
||||
def iter_simulations(path: os.PathLike) -> list[Path]:
|
||||
"""finds simulations folders contained in a parent directory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
parent path
|
||||
|
||||
Yields
|
||||
-------
|
||||
Path
|
||||
Absolute Path to the simulation folder
|
||||
"""
|
||||
paths: list[Path] = []
|
||||
for pwd, _, files in os.walk(path):
|
||||
if PARAM_FN in files:
|
||||
paths.append(Path(pwd))
|
||||
paths.sort(key=lambda el: el.parent.name)
|
||||
return [p for p in paths if p.parent.name == paths[-1].parent.name]
|
||||
|
||||
@@ -2,14 +2,17 @@ from math import prod
|
||||
import itertools
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import validator
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from ..const import PARAM_SEPARATOR
|
||||
from . import utils
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class VariationSpecsError(ValueError):
|
||||
pass
|
||||
@@ -111,15 +114,15 @@ class Variationer:
|
||||
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
||||
|
||||
|
||||
class VariationDescriptor(utils.HashableBaseModel):
|
||||
class VariationDescriptor(BaseModel):
|
||||
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
||||
index: tuple[tuple[int, ...], ...]
|
||||
separator: str = "fiber"
|
||||
_format_registry: dict[str, Callable[..., str]] = {}
|
||||
__ids: dict[int, int] = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.formatted_descriptor(add_identifier=False)
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
|
||||
def formatted_descriptor(self, add_identifier=False) -> str:
|
||||
"""formats a variable list into a str such that each simulation has a unique
|
||||
@@ -183,6 +186,24 @@ class VariationDescriptor(utils.HashableBaseModel):
|
||||
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.formatted_descriptor(add_identifier=False)
|
||||
|
||||
def __lt__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr < other.raw_descr
|
||||
|
||||
def __le__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr <= other.raw_descr
|
||||
|
||||
def __gt__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr > other.raw_descr
|
||||
|
||||
def __ge__(self, other: "VariationDescriptor") -> bool:
|
||||
return self.raw_descr >= other.raw_descr
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.raw_descr)
|
||||
|
||||
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
||||
"""updates a dictionary with the value of the descriptor
|
||||
|
||||
@@ -252,3 +273,34 @@ class BranchDescriptor(VariationDescriptor):
|
||||
@validator("raw_descr")
|
||||
def validate_raw_descr(cls, v):
|
||||
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
|
||||
|
||||
|
||||
class DescriptorDict(Generic[T]):
|
||||
def __init__(self, dico: dict[VariationDescriptor, T] = None):
|
||||
self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
|
||||
if dico is not None:
|
||||
for k, v in dico.items():
|
||||
self[k] = v
|
||||
|
||||
def __setitem__(self, key: VariationDescriptor, value: T):
|
||||
if not isinstance(key, VariationDescriptor):
|
||||
raise TypeError("key must be a VariationDescriptor instance")
|
||||
self.dico[key.raw_descr] = (key, value)
|
||||
|
||||
def __getitem__(
|
||||
self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
|
||||
) -> T:
|
||||
if isinstance(key, VariationDescriptor):
|
||||
return self.dico[key.raw_descr][1]
|
||||
else:
|
||||
return self.dico[key][1]
|
||||
|
||||
def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
|
||||
for k, v in self.dico.items():
|
||||
yield k, v[1]
|
||||
|
||||
def keys(self) -> list[VariationDescriptor]:
|
||||
return [v[0] for v in self.dico.values()]
|
||||
|
||||
def values(self) -> list[T]:
|
||||
return [v[1] for v in self.dico.values()]
|
||||
|
||||
@@ -491,6 +491,7 @@ class Simulations:
|
||||
|
||||
def _run_available(self):
|
||||
for variable, params in self.configuration:
|
||||
params.compute()
|
||||
v_list_str = variable.formatted_descriptor(True)
|
||||
save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
||||
|
||||
@@ -718,7 +719,7 @@ def run_simulation(
|
||||
config_file: os.PathLike,
|
||||
method: Union[str, Type[Simulations]] = None,
|
||||
):
|
||||
config = Configuration(config_file)
|
||||
config = Configuration(config_file, wait=True)
|
||||
|
||||
sim = new_simulation(config, method)
|
||||
sim.run()
|
||||
@@ -760,6 +761,8 @@ def parallel_RK4IP(
|
||||
]:
|
||||
logger = get_logger(__name__)
|
||||
params = list(Configuration(config))
|
||||
for _, param in params:
|
||||
param.compute()
|
||||
n = len(params)
|
||||
z_num = params[0][1].z_num
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
|
||||
@@ -12,8 +13,8 @@ from pydantic import BaseModel, DirectoryPath, root_validator
|
||||
from . import math
|
||||
from ._utils import load_spectrum
|
||||
from ._utils.parameter import Parameters
|
||||
from ._utils.utils import PlotRange
|
||||
from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N, SPEC1_FN
|
||||
from ._utils.utils import PlotRange, iter_simulations
|
||||
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1
|
||||
from .logger import get_logger
|
||||
from .physics import pulse, units
|
||||
from .plotting import (
|
||||
@@ -131,11 +132,10 @@ class SimulationSeries:
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
self.logger = get_logger()
|
||||
path = Path(path)
|
||||
subdirs = [el for el in path.glob("*") if (el / PARAM_FN).exists()]
|
||||
while not (path / PARAM_FN).exists() and len(subdirs) == 1:
|
||||
path = subdirs[0]
|
||||
self.path = path
|
||||
for self.path in iter_simulations(path):
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(f"No simulation in {path}")
|
||||
self.params = Parameters.load(self.path / PARAM_FN)
|
||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||
self.t = self.params.t
|
||||
@@ -356,24 +356,22 @@ class SimulationSeries:
|
||||
|
||||
|
||||
class Pulse(Sequence):
|
||||
path: Path
|
||||
default_ind: Optional[int]
|
||||
params: Parameters
|
||||
z: np.ndarray
|
||||
namx: int
|
||||
t: np.ndarray
|
||||
w: np.ndarray
|
||||
w_order: np.ndarray
|
||||
def __new__(cls, path: os.PathLike):
|
||||
warnings.warn(
|
||||
"You are using the legacy version of the pulse loader. "
|
||||
"Please consider updating your data with scgenerator.convert_sim_folder "
|
||||
"and loading data with the SimulationSeries class"
|
||||
)
|
||||
if (Path(path) / SPECN_FN1.format(0)).exists():
|
||||
return LegacyPulse(path)
|
||||
return SimulationSeries(path)
|
||||
|
||||
def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse":
|
||||
try:
|
||||
if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
|
||||
return super().__new__(LegacyPulse)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return super().__new__(cls)
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
||||
|
||||
class LegacyPulse(Sequence):
|
||||
def __init__(self, path: os.PathLike):
|
||||
"""load a data folder as a pulse
|
||||
|
||||
Parameters
|
||||
@@ -388,6 +386,35 @@ class Pulse(Sequence):
|
||||
FileNotFoundError
|
||||
path does not contain proper data
|
||||
"""
|
||||
self.logger = get_logger(__name__)
|
||||
self.path = Path(path)
|
||||
|
||||
if not self.path.is_dir():
|
||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||
|
||||
self.params = Parameters.load(self.path / "params.toml")
|
||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||
if self.params.fiber_map is None:
|
||||
self.params.fiber_map = [(0.0, self.params.name)]
|
||||
|
||||
try:
|
||||
self.z = np.load(os.path.join(path, "z.npy"))
|
||||
except FileNotFoundError:
|
||||
if self.params is not None:
|
||||
self.z = self.params.z_targets
|
||||
else:
|
||||
raise
|
||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
||||
if self.nmax <= 0:
|
||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||
|
||||
self.t = self.params.t
|
||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||
self.w_order = np.argsort(w)
|
||||
self.w = w
|
||||
self.wl = units.m.inv(self.w)
|
||||
self.params.w = self.w
|
||||
self.params.z_targets = self.z
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
@@ -404,6 +431,73 @@ class Pulse(Sequence):
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
return self.all_spectra(key)
|
||||
|
||||
def intensity(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
WL=self._to_wl_int,
|
||||
FREQ=self._to_freq_int,
|
||||
AFREQ=self._to_afreq_int,
|
||||
TIME=self._to_time_int,
|
||||
)[unit.type]
|
||||
|
||||
for spec in self:
|
||||
yield x_axis[order], func(spec)[:, order]
|
||||
|
||||
def _to_wl_int(self, spectrum):
|
||||
return units.to_WL(math.abs2(spectrum), spectrum.wl)
|
||||
|
||||
def _to_freq_int(self, spectrum):
|
||||
return math.abs2(spectrum)
|
||||
|
||||
def _to_afreq_int(self, spectrum):
|
||||
return math.abs2(spectrum)
|
||||
|
||||
def _to_time_int(self, spectrum):
|
||||
return math.abs2(np.fft.ifft(spectrum))
|
||||
|
||||
def amplitude(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
WL=self._to_wl_amp,
|
||||
FREQ=self._to_freq_amp,
|
||||
AFREQ=self._to_afreq_amp,
|
||||
TIME=self._to_time_amp,
|
||||
)[unit.type]
|
||||
|
||||
for spec in self:
|
||||
yield x_axis[order], func(spec)[:, order]
|
||||
|
||||
def _to_wl_amp(self, spectrum):
|
||||
return (
|
||||
np.sqrt(
|
||||
units.to_WL(
|
||||
math.abs2(spectrum),
|
||||
spectrum.wl,
|
||||
)
|
||||
)
|
||||
* spectrum
|
||||
/ np.abs(spectrum)
|
||||
)
|
||||
|
||||
def _to_freq_amp(self, spectrum):
|
||||
return spectrum
|
||||
|
||||
def _to_afreq_amp(self, spectrum):
|
||||
return spectrum
|
||||
|
||||
def _to_time_amp(self, spectrum):
|
||||
return np.fft.ifft(spectrum)
|
||||
|
||||
def all_spectra(self, ind=None) -> Spectrum:
|
||||
"""
|
||||
loads the data already simulated.
|
||||
@@ -425,10 +519,7 @@ class Pulse(Sequence):
|
||||
# Check if file exists and assert how many z positions there are
|
||||
|
||||
if ind is None:
|
||||
if self.default_ind is None:
|
||||
ind = range(self.nmax)
|
||||
else:
|
||||
ind = self.default_ind
|
||||
if isinstance(ind, (int, np.integer)):
|
||||
ind = [ind]
|
||||
elif isinstance(ind, (float, np.floating)):
|
||||
@@ -452,7 +543,12 @@ class Pulse(Sequence):
|
||||
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
||||
|
||||
def _load1(self, i: int):
|
||||
pass
|
||||
if i < 0:
|
||||
i = self.nmax + i
|
||||
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
||||
spec = np.atleast_2d(spec)
|
||||
spec = Spectrum(spec, self.params)
|
||||
return spec
|
||||
|
||||
def plot_2D(
|
||||
self,
|
||||
@@ -554,46 +650,3 @@ class Pulse(Sequence):
|
||||
index
|
||||
"""
|
||||
return math.argclosest(self.z, z)
|
||||
|
||||
|
||||
class LegacyPulse(Pulse):
|
||||
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
||||
print("old init called", path, default_ind)
|
||||
self.logger = get_logger(__name__)
|
||||
self.path = Path(path)
|
||||
self.default_ind = default_ind
|
||||
|
||||
if not self.path.is_dir():
|
||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||
|
||||
self.params = Parameters.load(self.path / "params.toml")
|
||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||
if self.params.fiber_map is None:
|
||||
self.params.fiber_map = [(0.0, self.params.name)]
|
||||
|
||||
try:
|
||||
self.z = np.load(os.path.join(path, "z.npy"))
|
||||
except FileNotFoundError:
|
||||
if self.params is not None:
|
||||
self.z = self.params.z_targets
|
||||
else:
|
||||
raise
|
||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
||||
if self.nmax <= 0:
|
||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||
|
||||
self.t = self.params.t
|
||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||
self.w_order = np.argsort(w)
|
||||
self.w = w
|
||||
self.wl = units.m.inv(self.w)
|
||||
self.params.w = self.w
|
||||
self.params.z_targets = self.z
|
||||
|
||||
def _load1(self, i: int):
|
||||
if i < 0:
|
||||
i = self.nmax + i
|
||||
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
||||
spec = np.atleast_2d(spec)
|
||||
spec = Spectrum(spec, self.params)
|
||||
return spec
|
||||
|
||||
Reference in New Issue
Block a user