file name still wrong in converter

This commit is contained in:
Benoît Sierro
2021-10-05 15:55:20 +02:00
parent fcaf872a26
commit fa72a6e136
8 changed files with 285 additions and 97 deletions

View File

@@ -2,9 +2,24 @@ from . import math
from .math import abs2, argclosest, span from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .plotting import (
mean_values_plot,
plot_spectrogram,
propagation_plot,
single_position_plot,
transform_2D_propagation,
transform_1D_values,
transform_mean_values,
get_extent,
)
from .spectra import Pulse, Spectrum, SimulationSeries from .spectra import Pulse, Spectrum, SimulationSeries
from ._utils import Paths, open_config, parameter from ._utils import Paths, open_config, parameter
from ._utils.parameter import Configuration, Parameters from ._utils.parameter import Configuration, Parameters
from ._utils.utils import PlotRange from ._utils.utils import PlotRange
from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError from ._utils.legacy import convert_sim_folder
from ._utils.variationer import (
Variationer,
VariationDescriptor,
VariationSpecsError,
DescriptorDict,
)

View File

@@ -8,8 +8,9 @@ import numpy as np
import toml import toml
from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
from .parameter import Parameters from .parameter import Configuration, Parameters
from .utils import fiber_folder, update_path, save_parameters from .utils import fiber_folder, save_parameters
from .pbar import PBars
from .variationer import VariationDescriptor, Variationer from .variationer import VariationDescriptor, Variationer
@@ -29,21 +30,32 @@ def convert_sim_folder(path: os.PathLike):
path = Path(path) path = Path(path)
config_paths, configs = load_config_sequence(path) config_paths, configs = load_config_sequence(path)
master_config = dict(name=path.name, Fiber=configs) master_config = dict(name=path.name, Fiber=configs)
with open(path / "initial_config.toml", "w") as f:
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
configuration = Configuration(path / "initial_config.toml")
new_fiber_paths: list[Path] = [ new_fiber_paths: list[Path] = [
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs) path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
] ]
for p in new_fiber_paths: for p in new_fiber_paths:
p.mkdir(exist_ok=True) p.mkdir(exist_ok=True)
var = Variationer(c["variable"] for c in configs) repeat = configs[0].get("repeat", 1)
paths: dict[Path, VariationDescriptor] = { pbar = PBars(configuration.total_num_steps, "Converting")
path / descr.branch.formatted_descriptor(): descr for descr in var.iterate()
old_paths: dict[Path, VariationDescriptor] = {
path / descr.branch.formatted_descriptor(): (descr, param.final_path)
for descr, param in configuration
} }
for p in paths:
# create map from old to new path
pprint(old_paths)
quit()
for p in old_paths:
if not p.is_dir(): if not p.is_dir():
raise FileNotFoundError(f"missing {p} from {path}") raise FileNotFoundError(f"missing {p} from {path}")
processed_paths: Set[Path] = set() processed_paths: Set[Path] = set()
for old_variation_path, descriptor in paths.items(): # fiberA=0, fiber B=0 for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0
vary_parts = old_variation_path.name.split("fiber")[1:] vary_parts = old_variation_path.name.split("fiber")[1:]
identifiers = [ identifiers = [
"".join("fiber" + el for el in vary_parts[: i + 1]).strip() "".join("fiber" + el for el in vary_parts[: i + 1]).strip()
@@ -72,6 +84,9 @@ def convert_sim_folder(path: os.PathLike):
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j), new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
spec1, spec1,
) )
pbar.update()
else:
pbar.update(value=repeat)
old_spec.unlink() old_spec.unlink()
if move: if move:
if i > 0: if i > 0:
@@ -88,8 +103,6 @@ def convert_sim_folder(path: os.PathLike):
for cp in config_paths: for cp in config_paths:
cp.unlink() cp.unlink()
with open(path / "initial_config.toml", "w") as f:
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
def main(): def main():

View File

@@ -24,7 +24,7 @@ from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger from ..logger import get_logger
from ..physics import fiber, materials, pulse, units from ..physics import fiber, materials, pulse, units
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path
T = TypeVar("T") T = TypeVar("T")
@@ -75,6 +75,7 @@ VALID_VARIABLE = {
"interpolation_degree", "interpolation_degree",
"ideal_gas", "ideal_gas",
"length", "length",
"num",
} }
MANDATORY_PARAMETERS = [ MANDATORY_PARAMETERS = [
@@ -519,6 +520,12 @@ class Parameters(_AbstractParameters):
return out return out
@property
def final_path(self) -> Path:
if self.output_path is not None:
return update_path(self.output_path)
return None
class Rule: class Rule:
def __init__( def __init__(
@@ -777,6 +784,7 @@ class Configuration:
""" """
fiber_configs: list[dict[str, Any]] fiber_configs: list[dict[str, Any]]
vary_dicts: list[dict[str, list]]
master_config: dict[str, Any] master_config: dict[str, Any]
fiber_paths: list[Path] fiber_paths: list[Path]
num_sim: int num_sim: int
@@ -814,9 +822,11 @@ class Configuration:
self, self,
final_config_path: os.PathLike, final_config_path: os.PathLike,
overwrite: bool = True, overwrite: bool = True,
wait: bool = False,
skip_callback: Callable[[int], None] = None, skip_callback: Callable[[int], None] = None,
): ):
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.wait = wait
self.overwrite = overwrite self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
@@ -842,7 +852,8 @@ class Configuration:
config.setdefault("name", Parameters.name.default) config.setdefault("name", Parameters.name.default)
self.z_num += config["z_num"] self.z_num += config["z_num"]
fiber_names.add(config["name"]) fiber_names.add(config["name"])
self.variationer.append(config.pop("variable")) vary_dict = config.pop("variable")
self.variationer.append(vary_dict)
self.fiber_paths.append( self.fiber_paths.append(
utils.ensure_folder( utils.ensure_folder(
self.final_path / fiber_folder(i, self.name, config["name"]), self.final_path / fiber_folder(i, self.name, config["name"]),
@@ -850,9 +861,11 @@ class Configuration:
prevent_overwrite=not self.overwrite, prevent_overwrite=not self.overwrite,
) )
) )
self.__validate_variable(config) self.__validate_variable(vary_dict)
self.num_fibers += 1 self.num_fibers += 1
Evaluator.evaluate_default(config, True) Evaluator.evaluate_default(
self.__build_base_config() | config | {k: v[0] for k, v in vary_dict.items()}, True
)
self.num_sim = self.variationer.var_num() self.num_sim = self.variationer.var_num()
self.total_num_steps = sum( self.total_num_steps = sum(
config["z_num"] * self.variationer.var_num(i) config["z_num"] * self.variationer.var_num(i)
@@ -860,8 +873,13 @@ class Configuration:
) )
self.parallel = self.master_config.get("parallel", Parameters.parallel.default) self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
def __validate_variable(self, config: dict[str, Any]): def __build_base_config(self):
for k, v in config.get("variable", {}).items(): cfg = self.fiber_configs[0].copy()
vary = cfg.pop("variable", {})
return cfg | {k: v[0] for k, v in vary.items()}
def __validate_variable(self, vary_dict: dict[str, list]):
for k, v in vary_dict.items():
p = getattr(Parameters, k) p = getattr(Parameters, k)
validator_list(p.validator)("variable " + k, v) validator_list(p.validator)("variable " + k, v)
if k not in VALID_VARIABLE: if k not in VALID_VARIABLE:
@@ -873,7 +891,6 @@ class Configuration:
for i in range(self.num_fibers): for i in range(self.num_fibers):
for sim_config in self.iterate_single_fiber(i): for sim_config in self.iterate_single_fiber(i):
params = Parameters(**sim_config.config) params = Parameters(**sim_config.config)
params.compute()
yield sim_config.descriptor, params yield sim_config.descriptor, params
def iterate_single_fiber( def iterate_single_fiber(
@@ -943,6 +960,8 @@ class Configuration:
config dictionary. The only key possibly modified is 'prev_data_dir', which config dictionary. The only key possibly modified is 'prev_data_dir', which
gets set if the simulation is partially completed gets set if the simulation is partially completed
""" """
if not self.wait:
return self.Action.RUN, sim_config.config
out_status, num = self.sim_status(sim_config.output_path, sim_config.config) out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
if out_status == self.State.COMPLETE: if out_status == self.State.COMPLETE:
return self.Action.SKIP, sim_config.config return self.Action.SKIP, sim_config.config

View File

@@ -1,9 +1,9 @@
import abc import multiprocessing
import os import os
import random import random
import multiprocessing
import threading import threading
import typing import typing
from collections import abc
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Iterable, Union from typing import Iterable, Union
@@ -24,7 +24,19 @@ class PBars:
head_kwargs=None, head_kwargs=None,
worker_kwargs=None, worker_kwargs=None,
) -> "PBars": ) -> "PBars":
"""creates a PBars obj
Parameters
----------
task : int | Iterable
if int : total length of the main task
if Iterable : behaves like tqdm
desc : str
description of the main task
num_sub_bars : int
number of sub-tasks
"""
self.id = random.randint(100000, 999999) self.id = random.randint(100000, 999999)
try: try:
self.width = os.get_terminal_size().columns self.width = os.get_terminal_size().columns

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
from string import printable as str_printable from string import printable as str_printable
from typing import Any, Callable from typing import Any, Callable, Iterator, Set
import numpy as np import numpy as np
import toml import toml
@@ -236,3 +236,24 @@ def update_path(p: str) -> str:
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str: def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name]) return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
def iter_simulations(path: os.PathLike) -> list[Path]:
"""finds simulations folders contained in a parent directory
Parameters
----------
path : os.PathLike
parent path
Yields
-------
Path
Absolute Path to the simulation folder
"""
paths: list[Path] = []
for pwd, _, files in os.walk(path):
if PARAM_FN in files:
paths.append(Path(pwd))
paths.sort(key=lambda el: el.parent.name)
return [p for p in paths if p.parent.name == paths[-1].parent.name]

View File

@@ -2,14 +2,17 @@ from math import prod
import itertools import itertools
from collections.abc import MutableMapping, Sequence from collections.abc import MutableMapping, Sequence
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Optional, Union from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union
import numpy as np import numpy as np
from pydantic import validator from pydantic import validator
from pydantic.main import BaseModel
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from . import utils from . import utils
T = TypeVar("T")
class VariationSpecsError(ValueError): class VariationSpecsError(ValueError):
pass pass
@@ -111,15 +114,15 @@ class Variationer:
return max(1, prod(prod(el) for el in self.all_indices[: index + 1])) return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
class VariationDescriptor(utils.HashableBaseModel): class VariationDescriptor(BaseModel):
raw_descr: tuple[tuple[tuple[str, Any], ...], ...] raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
index: tuple[tuple[int, ...], ...] index: tuple[tuple[int, ...], ...]
separator: str = "fiber" separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {} _format_registry: dict[str, Callable[..., str]] = {}
__ids: dict[int, int] = {} __ids: dict[int, int] = {}
def __str__(self) -> str: class Config:
return self.formatted_descriptor(add_identifier=False) allow_mutation = False
def formatted_descriptor(self, add_identifier=False) -> str: def formatted_descriptor(self, add_identifier=False) -> str:
"""formats a variable list into a str such that each simulation has a unique """formats a variable list into a str such that each simulation has a unique
@@ -183,6 +186,24 @@ class VariationDescriptor(utils.HashableBaseModel):
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
) )
def __str__(self) -> str:
return self.formatted_descriptor(add_identifier=False)
def __lt__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr < other.raw_descr
def __le__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr <= other.raw_descr
def __gt__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr > other.raw_descr
def __ge__(self, other: "VariationDescriptor") -> bool:
return self.raw_descr >= other.raw_descr
def __hash__(self) -> int:
return hash(self.raw_descr)
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]: def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
"""updates a dictionary with the value of the descriptor """updates a dictionary with the value of the descriptor
@@ -252,3 +273,34 @@ class BranchDescriptor(VariationDescriptor):
@validator("raw_descr") @validator("raw_descr")
def validate_raw_descr(cls, v): def validate_raw_descr(cls, v):
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v) return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
class DescriptorDict(Generic[T]):
def __init__(self, dico: dict[VariationDescriptor, T] = None):
self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
if dico is not None:
for k, v in dico.items():
self[k] = v
def __setitem__(self, key: VariationDescriptor, value: T):
if not isinstance(key, VariationDescriptor):
raise TypeError("key must be a VariationDescriptor instance")
self.dico[key.raw_descr] = (key, value)
def __getitem__(
self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
) -> T:
if isinstance(key, VariationDescriptor):
return self.dico[key.raw_descr][1]
else:
return self.dico[key][1]
def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
for k, v in self.dico.items():
yield k, v[1]
def keys(self) -> list[VariationDescriptor]:
return [v[0] for v in self.dico.values()]
def values(self) -> list[T]:
return [v[1] for v in self.dico.values()]

View File

@@ -491,6 +491,7 @@ class Simulations:
def _run_available(self): def _run_available(self):
for variable, params in self.configuration: for variable, params in self.configuration:
params.compute()
v_list_str = variable.formatted_descriptor(True) v_list_str = variable.formatted_descriptor(True)
save_parameters(params.prepare_for_dump(), Path(params.output_path)) save_parameters(params.prepare_for_dump(), Path(params.output_path))
@@ -718,7 +719,7 @@ def run_simulation(
config_file: os.PathLike, config_file: os.PathLike,
method: Union[str, Type[Simulations]] = None, method: Union[str, Type[Simulations]] = None,
): ):
config = Configuration(config_file) config = Configuration(config_file, wait=True)
sim = new_simulation(config, method) sim = new_simulation(config, method)
sim.run() sim.run()
@@ -760,6 +761,8 @@ def parallel_RK4IP(
]: ]:
logger = get_logger(__name__) logger = get_logger(__name__)
params = list(Configuration(config)) params = list(Configuration(config))
for _, param in params:
param.compute()
n = len(params) n = len(params)
z_num = params[0][1].z_num z_num = params[0][1].z_num

View File

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
@@ -12,8 +13,8 @@ from pydantic import BaseModel, DirectoryPath, root_validator
from . import math from . import math
from ._utils import load_spectrum from ._utils import load_spectrum
from ._utils.parameter import Parameters from ._utils.parameter import Parameters
from ._utils.utils import PlotRange from ._utils.utils import PlotRange, iter_simulations
from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N, SPEC1_FN from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1
from .logger import get_logger from .logger import get_logger
from .physics import pulse, units from .physics import pulse, units
from .plotting import ( from .plotting import (
@@ -131,11 +132,10 @@ class SimulationSeries:
def __init__(self, path: os.PathLike): def __init__(self, path: os.PathLike):
self.logger = get_logger() self.logger = get_logger()
path = Path(path) for self.path in iter_simulations(path):
subdirs = [el for el in path.glob("*") if (el / PARAM_FN).exists()] break
while not (path / PARAM_FN).exists() and len(subdirs) == 1: else:
path = subdirs[0] raise FileNotFoundError(f"No simulation in {path}")
self.path = path
self.params = Parameters.load(self.path / PARAM_FN) self.params = Parameters.load(self.path / PARAM_FN)
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"]) self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
self.t = self.params.t self.t = self.params.t
@@ -356,24 +356,22 @@ class SimulationSeries:
class Pulse(Sequence): class Pulse(Sequence):
path: Path def __new__(cls, path: os.PathLike):
default_ind: Optional[int] warnings.warn(
params: Parameters "You are using the legacy version of the pulse loader. "
z: np.ndarray "Please consider updating your data with scgenerator.convert_sim_folder "
namx: int "and loading data with the SimulationSeries class"
t: np.ndarray )
w: np.ndarray if (Path(path) / SPECN_FN1.format(0)).exists():
w_order: np.ndarray return LegacyPulse(path)
return SimulationSeries(path)
def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse": def __getitem__(self, key) -> Spectrum:
try: raise NotImplementedError()
if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
return super().__new__(LegacyPulse)
except FileNotFoundError:
pass
return super().__new__(cls)
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
class LegacyPulse(Sequence):
def __init__(self, path: os.PathLike):
"""load a data folder as a pulse """load a data folder as a pulse
Parameters Parameters
@@ -388,6 +386,35 @@ class Pulse(Sequence):
FileNotFoundError FileNotFoundError
path does not contain proper data path does not contain proper data
""" """
self.logger = get_logger(__name__)
self.path = Path(path)
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def __iter__(self): def __iter__(self):
""" """
@@ -404,6 +431,73 @@ class Pulse(Sequence):
def __getitem__(self, key) -> Spectrum: def __getitem__(self, key) -> Spectrum:
return self.all_spectra(key) return self.all_spectra(key)
def intensity(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_int,
FREQ=self._to_freq_int,
AFREQ=self._to_afreq_int,
TIME=self._to_time_int,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_int(self, spectrum):
return units.to_WL(math.abs2(spectrum), spectrum.wl)
def _to_freq_int(self, spectrum):
return math.abs2(spectrum)
def _to_afreq_int(self, spectrum):
return math.abs2(spectrum)
def _to_time_int(self, spectrum):
return math.abs2(np.fft.ifft(spectrum))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_amp,
FREQ=self._to_freq_amp,
AFREQ=self._to_afreq_amp,
TIME=self._to_time_amp,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_amp(self, spectrum):
return (
np.sqrt(
units.to_WL(
math.abs2(spectrum),
spectrum.wl,
)
)
* spectrum
/ np.abs(spectrum)
)
def _to_freq_amp(self, spectrum):
return spectrum
def _to_afreq_amp(self, spectrum):
return spectrum
def _to_time_amp(self, spectrum):
return np.fft.ifft(spectrum)
def all_spectra(self, ind=None) -> Spectrum: def all_spectra(self, ind=None) -> Spectrum:
""" """
loads the data already simulated. loads the data already simulated.
@@ -425,10 +519,7 @@ class Pulse(Sequence):
# Check if file exists and assert how many z positions there are # Check if file exists and assert how many z positions there are
if ind is None: if ind is None:
if self.default_ind is None:
ind = range(self.nmax) ind = range(self.nmax)
else:
ind = self.default_ind
if isinstance(ind, (int, np.integer)): if isinstance(ind, (int, np.integer)):
ind = [ind] ind = [ind]
elif isinstance(ind, (float, np.floating)): elif isinstance(ind, (float, np.floating)):
@@ -452,7 +543,12 @@ class Pulse(Sequence):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
def _load1(self, i: int): def _load1(self, i: int):
pass if i < 0:
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN1.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec
def plot_2D( def plot_2D(
self, self,
@@ -554,46 +650,3 @@ class Pulse(Sequence):
index index
""" """
return math.argclosest(self.z, z) return math.argclosest(self.z, z)
class LegacyPulse(Pulse):
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
print("old init called", path, default_ind)
self.logger = get_logger(__name__)
self.path = Path(path)
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def _load1(self, i: int):
if i < 0:
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN1.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec