working converter

This commit is contained in:
Benoît Sierro
2021-10-08 16:36:53 +02:00
parent fa72a6e136
commit fb880077ed
4 changed files with 106 additions and 395 deletions

View File

@@ -27,84 +27,67 @@ def load_config_sequence(path: os.PathLike) -> tuple[list[Path], list[dict[str,
def convert_sim_folder(path: os.PathLike):
    """convert, in place, a legacy simulation folder to the one-folder-per-fiber layout

    The per-fiber config files found in `path` are merged into a single
    ``initial_config.toml``, every legacy multi-spectrum file is split into
    single-spectrum files stored in the folder of the matching new variation,
    and the legacy files, folders and config files are removed afterwards.

    Parameters
    ----------
    path : os.PathLike
        folder containing the legacy simulation data

    Raises
    ------
    FileNotFoundError
        a legacy variation folder expected from the configuration is missing.
        This is raised before any legacy data is moved or deleted.
    """
    path = Path(path).resolve()
    config_paths, configs = load_config_sequence(path)

    # merge all fiber configs into one master config that Configuration can load
    master_config_path = path / "initial_config.toml"
    master_config = dict(name=path.name, Fiber=configs)
    with open(master_config_path, "w") as f:
        toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
    configuration = Configuration(master_config_path, final_output_path=path)
    pbar = PBars(configuration.total_num_steps, "Converting")

    new_paths: dict[VariationDescriptor, Parameters] = dict(configuration)

    # map every legacy folder to each descriptor of its fiber chain, along with the
    # range of legacy spectrum indices that belongs to that fiber
    old2new: list[tuple[Path, VariationDescriptor, Parameters, tuple[int, int]]] = []
    for descriptor, _ in configuration.iterate_single_fiber(-1):
        old_path = path / descriptor.branch.formatted_descriptor()
        if not old_path.is_dir():
            raise FileNotFoundError(f"missing {old_path} from {path}. Aborting.")
        for d in descriptor.iter_parents():
            fiber_params = new_paths[d]
            z_num_start = sum(c["z_num"] for c in configs[: d.num_fibers - 1])
            # the range must end after this fiber's own z_num, not after the
            # z_num of the last fiber of the chain
            z_limits = (z_num_start, z_num_start + fiber_params.z_num)
            old2new.append((old_path, d, fiber_params, z_limits))

    processed_paths: set[Path] = set()
    processed_specs: set[VariationDescriptor] = set()

    for old_path, descr, new_params, (start_z, end_z) in old2new:
        # several legacy folders share the same first fibers: only the first
        # occurrence is converted, the other copies are simply deleted
        move_specs = descr not in processed_specs
        processed_specs.add(descr)
        if (parent := descr.parent) is not None:
            new_params.prev_data_dir = str(new_paths[parent].final_path)
        save_parameters(new_params.prepare_for_dump(), new_params.final_path)
        for spec_num in range(start_z, end_z):
            old_spec = old_path / SPECN_FN1.format(spec_num)
            if move_specs:
                _mv_specs(pbar, new_params, start_z, spec_num, old_spec)
            old_spec.unlink()
        if old_path not in processed_paths:
            (old_path / PARAM_FN).unlink()
            (old_path / Z_FN).unlink()
            processed_paths.add(old_path)

    # legacy folders are expected to be empty at this point
    for old_path in processed_paths:
        old_path.rmdir()

    for cp in config_paths:
        cp.unlink()
def _mv_specs(pbar: PBars, new_params: Parameters, start_z: int, spec_num: int, old_spec: Path):
    """split one legacy multi-spectrum file into single-spectrum files

    Parameters
    ----------
    pbar : PBars
        progress bar, updated once per spectrum written
    new_params : Parameters
        parameters of the new variation; its `final_path` is the destination folder
    start_z : int
        legacy index of the first spectrum of this fiber, subtracted from
        `spec_num` to get the index used in the new file names
    spec_num : int
        legacy index of the spectrum file being converted
    old_spec : Path
        path to the legacy spectrum file
    """
    os.makedirs(new_params.final_path, exist_ok=True)
    local_index = spec_num - start_z
    for sim_index, single_spec in enumerate(np.load(old_spec)):
        # the first spectrum of a file keeps the plain name, the following
        # ones also carry their position inside the legacy file
        if sim_index == 0:
            file_name = SPEC1_FN.format(local_index)
        else:
            file_name = SPEC1_FN_N.format(local_index, sim_index)
        np.save(new_params.final_path / file_name, single_spec)
        pbar.update()
def main():
    """command line entry point: convert the folder given as first argument"""
    if len(sys.argv) < 2:
        # exit with a readable message rather than a bare IndexError
        raise SystemExit(f"usage: {sys.argv[0]} <simulation folder>")
    convert_sim_folder(sys.argv[1])

View File

@@ -12,7 +12,18 @@ from copy import copy, deepcopy
from dataclasses import asdict, dataclass, fields from dataclasses import asdict, dataclass, fields
from functools import cache, lru_cache from functools import cache, lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequence, TypeVar, Union from typing import (
Any,
Callable,
Generator,
Iterable,
Iterator,
Literal,
Optional,
Sequence,
TypeVar,
Union,
)
import numpy as np import numpy as np
from numpy.lib import isin from numpy.lib import isin
@@ -523,7 +534,7 @@ class Parameters(_AbstractParameters):
@property
def final_path(self) -> Optional[Path]:
    """`output_path` passed through `update_path`, as a `Path`

    Returns
    -------
    Optional[Path]
        None when `output_path` is not set
    """
    if self.output_path is None:
        return None
    return Path(update_path(self.output_path))
@@ -820,22 +831,26 @@ class Configuration:
def __init__( def __init__(
self, self,
final_config_path: os.PathLike, config_path: os.PathLike,
overwrite: bool = True, overwrite: bool = True,
wait: bool = False, wait: bool = False,
skip_callback: Callable[[int], None] = None, skip_callback: Callable[[int], None] = None,
final_output_path: os.PathLike = None,
): ):
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.wait = wait self.wait = wait
self.overwrite = overwrite self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) self.final_path, self.fiber_configs = utils.load_config_sequence(config_path)
self.final_path = env.get(env.OUTPUT_PATH, self.final_path)
if final_output_path is not None:
self.final_path = final_output_path
self.final_path = utils.ensure_folder( self.final_path = utils.ensure_folder(
Path(env.get(env.OUTPUT_PATH, self.final_path)), Path(self.final_path),
mkdir=False, mkdir=False,
prevent_overwrite=not self.overwrite, prevent_overwrite=not self.overwrite,
) )
self.master_config = self.fiber_configs[0] self.master_config = self.fiber_configs[0].copy()
self.name = self.final_path.name self.name = self.final_path.name
self.z_num = 0 self.z_num = 0
self.total_num_steps = 0 self.total_num_steps = 0
@@ -874,7 +889,7 @@ class Configuration:
self.parallel = self.master_config.get("parallel", Parameters.parallel.default) self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
def __build_base_config(self):
    """build a config without variable parameters out of the master config

    Returns
    -------
    dict
        copy of the master config in which the "variable" section is replaced
        by the first value of each of its parameters
    """
    cfg = self.master_config.copy()
    vary = cfg.pop("variable", {})
    return cfg | {k: v[0] for k, v in vary.items()}
@@ -887,15 +902,11 @@ class Configuration:
if len(v) == 0: if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty") raise ValueError(f"variable parameter {k!r} must not be empty")
def __iter__(self) -> "Iterator[tuple[VariationDescriptor, Parameters]]":
    """iterate through the (descriptor, parameters) pairs of every fiber, in fiber order"""
    for i in range(self.num_fibers):
        yield from self.iterate_single_fiber(i)
def iterate_single_fiber( def iterate_single_fiber(self, index: int) -> Iterator[tuple[VariationDescriptor, Parameters]]:
self, index: int
) -> Generator["Configuration.__SimConfig", None, None]:
"""iterates through the parameters of only one fiber. It takes care of recovering partially """iterates through the parameters of only one fiber. It takes care of recovering partially
completed simulations, skipping complete ones and waiting for the previous fiber to finish completed simulations, skipping complete ones and waiting for the previous fiber to finish
@@ -909,6 +920,8 @@ class Configuration:
__SimConfig __SimConfig
configuration obj configuration obj
""" """
if index < 0:
index = self.num_fibers + index
sim_dict: dict[Path, Configuration.__SimConfig] = {} sim_dict: dict[Path, Configuration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index): for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index]) cfg = descriptor.update_config(self.fiber_configs[index])
@@ -929,7 +942,7 @@ class Configuration:
task, config_dict = self.__decide(sim_config) task, config_dict = self.__decide(sim_config)
if task == self.Action.RUN: if task == self.Action.RUN:
sim_dict.pop(data_dir) sim_dict.pop(data_dir)
yield sim_config yield sim_config.descriptor, Parameters(**sim_config.config)
if "recovery_last_stored" in config_dict and self.skip_callback is not None: if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"]) self.skip_callback(config_dict["recovery_last_stored"])
break break

View File

@@ -121,6 +121,19 @@ class VariationDescriptor(BaseModel):
_format_registry: dict[str, Callable[..., str]] = {} _format_registry: dict[str, Callable[..., str]] = {}
__ids: dict[int, int] = {} __ids: dict[int, int] = {}
@classmethod
def register_formatter(cls, p_name: str, func: Callable[..., str]):
    """associate a formatting function with a parameter name

    The function is stored in the class-level format registry under `p_name`,
    replacing any function previously stored under that name.

    Parameters
    ----------
    p_name : str
        name of the parameter
    func : Callable[..., str]
        function that takes the value of the parameter as its single argument
        and returns a string
    """
    registry = cls._format_registry
    registry[p_name] = func
class Config: class Config:
allow_mutation = False allow_mutation = False
@@ -152,19 +165,6 @@ class VariationDescriptor(BaseModel):
self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name
) )
@classmethod
def register_formatter(cls, p_name: str, func: Callable[..., str]):
    """store `func` in the class-level format registry under `p_name`

    Parameters
    ----------
    p_name : str
        name of the parameter
    func : Callable[..., str]
        function that takes as single argument the value of the parameter
        and returns a string
    """
    cls._format_registry.__setitem__(p_name, func)
def format_value(self, name: str, value) -> str: def format_value(self, name: str, value) -> str:
if value is True or value is False: if value is True or value is False:
return str(value) return str(value)
@@ -201,9 +201,15 @@ class VariationDescriptor(BaseModel):
def __ge__(self, other: "VariationDescriptor") -> bool:
    """order descriptors by comparing their `raw_descr`"""
    return self.raw_descr >= other.raw_descr
def __eq__(self, other: object) -> bool:
    """two descriptors are equal when their `raw_descr` are equal

    Objects without a `raw_descr` attribute are not comparable: NotImplemented
    is returned so that Python falls back on its default handling instead of
    raising AttributeError (e.g. during a dict lookup with mixed key types).
    """
    try:
        other_descr = other.raw_descr
    except AttributeError:
        return NotImplemented
    return self.raw_descr == other_descr
def __hash__(self) -> int:
    """hash of `raw_descr`, consistent with the equality defined on `raw_descr`"""
    return hash(self.raw_descr)
def __contains__(self, other: "VariationDescriptor") -> bool:
    """tell whether every element of `other.raw_descr` is also in `self.raw_descr`"""
    own_elements = self.raw_descr
    for element in other.raw_descr:
        if element not in own_elements:
            return False
    return True
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]: def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
"""updates a dictionary with the value of the descriptor """updates a dictionary with the value of the descriptor
@@ -223,6 +229,11 @@ class VariationDescriptor(BaseModel):
out_cfg.pop("variable", None) out_cfg.pop("variable", None)
return out_cfg | {k: v for k, v in self.raw_descr[index]} return out_cfg | {k: v for k, v in self.raw_descr[index]}
def iter_parents(self) -> Iterator["VariationDescriptor"]:
    """iterate through the chain of parents, oldest ancestor first and `self` last"""
    lineage = [self]
    # walk up until a descriptor has no parent, then yield from the top down
    while (ancestor := lineage[-1].parent) is not None:
        lineage.append(ancestor)
    yield from reversed(lineage)
@property @property
def flat(self) -> list[tuple[str, Any]]: def flat(self) -> list[tuple[str, Any]]:
out = [] out = []
@@ -260,6 +271,10 @@ class VariationDescriptor(BaseModel):
raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator raw_descr=self.raw_descr[:-1], index=self.index[:-1], separator=self.separator
) )
@property
def num_fibers(self) -> int:
    """number of entries in `raw_descr`"""
    descr = self.raw_descr
    return len(descr)
class BranchDescriptor(VariationDescriptor): class BranchDescriptor(VariationDescriptor):
__ids: dict[int, int] = {} __ids: dict[int, int] = {}

View File

@@ -1,20 +1,17 @@
from __future__ import annotations from __future__ import annotations
import os import os
import warnings
from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union from typing import Any, Callable, Iterator, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from pydantic import BaseModel, DirectoryPath, root_validator
from . import math from . import math
from ._utils import load_spectrum from ._utils import load_spectrum
from ._utils.parameter import Parameters from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, iter_simulations from ._utils.utils import PlotRange, iter_simulations
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1 from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from .logger import get_logger from .logger import get_logger
from .physics import pulse, units from .physics import pulse, units
from .plotting import ( from .plotting import (
@@ -111,7 +108,7 @@ class Spectrum(np.ndarray):
return self.params.l[np.argmax(self.wl_int, axis=-1)] return self.params.l[np.argmax(self.wl_int, axis=-1)]
return np.array([s.wl_max for s in self]) return np.array([s.wl_max for s in self])
def mask_wl(self, pos: float, width: float) -> "Spectrum":
    """multiply the spectrum by a gaussian function of `self.params.l`

    Parameters
    ----------
    pos : float
        center of the gaussian, compared against `self.params.l`
    width : float
        width of the gaussian, scaled by `pulse.fwhm_to_T0_fac["gaussian"]`
        (presumably a FWHM, confirm against `pulse`)

    Returns
    -------
    Spectrum
        the masked spectrum
    """
    scaled_width = pulse.fwhm_to_T0_fac["gaussian"] * width
    return self * np.exp(-(((self.params.l - pos) / scaled_width) ** 2))
@@ -353,300 +350,3 @@ class SimulationSeries:
return self.spectra(*key) return self.spectra(*key)
else: else:
return self.spectra(key, None) return self.spectra(key, None)
class Pulse(Sequence):
    """legacy entry point that dispatches to the appropriate data loader

    Instantiating this class never returns a `Pulse`: a `LegacyPulse` is
    returned when `path` contains a legacy spectrum file of index 0, a
    `SimulationSeries` otherwise. A warning is emitted in both cases.
    """

    def __new__(cls, path: os.PathLike):
        warnings.warn(
            "You are using the legacy version of the pulse loader. "
            "Please consider updating your data with scgenerator.convert_sim_folder "
            "and loading data with the SimulationSeries class",
            # attribute the warning to the caller rather than to this module
            stacklevel=2,
        )
        if (Path(path) / SPECN_FN1.format(0)).exists():
            return LegacyPulse(path)
        return SimulationSeries(path)

    def __getitem__(self, key) -> Spectrum:
        raise NotImplementedError()
class LegacyPulse(Sequence):
    """loader for a data folder stored in the legacy layout

    Each legacy spectrum file holds the spectra of one stored position. Files
    are loaded lazily, one per requested position.
    """

    def __init__(self, path: os.PathLike):
        """load a data folder as a pulse

        Parameters
        ----------
        path : os.PathLike
            path to the data (folder containing .npy files)

        Raises
        ------
        FileNotFoundError
            the folder does not exist or contains no ``spectra_*.npy`` file
        """
        self.logger = get_logger(__name__)
        self.path = Path(path)

        if not self.path.is_dir():
            raise FileNotFoundError(f"Folder {self.path} does not exist")

        self.params = Parameters.load(self.path / "params.toml")
        self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
        if self.params.fiber_map is None:
            self.params.fiber_map = [(0.0, self.params.name)]

        try:
            self.z = np.load(self.path / "z.npy")
        except FileNotFoundError:
            # no stored positions: use the ones computed from the parameters
            self.z = self.params.z_targets

        self.nmax = len(list(self.path.glob("spectra_*.npy")))
        if self.nmax <= 0:
            raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")

        self.t = self.params.t
        w = math.wspace(self.t) + units.m(self.params.wavelength)
        self.w_order = np.argsort(w)
        self.w = w
        self.wl = units.m.inv(self.w)
        self.params.w = self.w
        self.params.z_targets = self.z

    def __iter__(self):
        """
        similar to all_spectra but works as an iterator
        """
        self.logger.debug(f"iterating through {self.path}")
        for i in range(self.nmax):
            yield self._load1(i)

    def __len__(self):
        return self.nmax

    def __getitem__(self, key) -> Spectrum:
        return self.all_spectra(key)

    def intensity(self, unit):
        """yield, for every stored position, the sorted x axis and the intensity in `unit`"""
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.w)
        else:
            x_axis = unit.inv(self.t)

        order = np.argsort(x_axis)
        func = dict(
            WL=self._to_wl_int,
            FREQ=self._to_freq_int,
            AFREQ=self._to_afreq_int,
            TIME=self._to_time_int,
        )[unit.type]

        for spec in self:
            yield x_axis[order], func(spec)[:, order]

    def _to_wl_int(self, spectrum):
        return units.to_WL(math.abs2(spectrum), spectrum.wl)

    def _to_freq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_afreq_int(self, spectrum):
        return math.abs2(spectrum)

    def _to_time_int(self, spectrum):
        return math.abs2(np.fft.ifft(spectrum))

    def amplitude(self, unit):
        """yield, for every stored position, the sorted x axis and the amplitude in `unit`"""
        if unit.type in ["WL", "FREQ", "AFREQ"]:
            x_axis = unit.inv(self.w)
        else:
            x_axis = unit.inv(self.t)

        order = np.argsort(x_axis)
        func = dict(
            WL=self._to_wl_amp,
            FREQ=self._to_freq_amp,
            AFREQ=self._to_afreq_amp,
            TIME=self._to_time_amp,
        )[unit.type]

        for spec in self:
            yield x_axis[order], func(spec)[:, order]

    def _to_wl_amp(self, spectrum):
        # NOTE(review): the division yields nan wherever the spectrum is exactly 0
        return (
            np.sqrt(
                units.to_WL(
                    math.abs2(spectrum),
                    spectrum.wl,
                )
            )
            * spectrum
            / np.abs(spectrum)
        )

    def _to_freq_amp(self, spectrum):
        return spectrum

    def _to_afreq_amp(self, spectrum):
        return spectrum

    def _to_time_amp(self, spectrum):
        return np.fft.ifft(spectrum)

    def all_spectra(self, ind=None) -> Spectrum:
        """
        loads the data already simulated.
        default shape is (z_targets, n, nt)

        Parameters
        ----------
        ind : int, float or sequence of int or float, optional
            if only certain spectra are desired. Floats are interpreted as
            positions and converted to the closest index with `z_ind`.
            By default, all positions are loaded

        Returns
        ----------
        spectra : array of shape (nz, m, nt)
            array of complex spectra (pulse at nz positions consisting
            of nm simulation on a nt size grid). When a single position is
            requested, the first axis is dropped
        """

        self.logger.debug(f"opening {self.path}")

        # normalize ind to a sequence of integer indices
        if ind is None:
            ind = range(self.nmax)
        if isinstance(ind, (int, np.integer)):
            ind = [ind]
        elif isinstance(ind, (float, np.floating)):
            ind = [self.z_ind(ind)]
        elif isinstance(ind[0], (float, np.floating)):
            ind = [self.z_ind(ii) for ii in ind]

        # Load the spectra
        spectra = Spectrum([self._load1(i) for i in ind], self.params)

        self.logger.debug(f"all spectra from {self.path} successfully loaded")
        if len(ind) == 1:
            return spectra[0]
        else:
            return spectra

    def all_fields(self, ind=None):
        """same as all_spectra, but inverse Fourier transformed along the last axis"""
        return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)

    def _load1(self, i: int):
        """load the spectrum file of index `i` (negative values count from the end)"""
        if i < 0:
            i = self.nmax + i
        spec = load_spectrum(self.path / SPECN_FN1.format(i))
        spec = np.atleast_2d(spec)
        spec = Spectrum(spec, self.params)
        return spec

    def plot_2D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: Union[int, Iterable[int]] = None,
        sim_ind: int = 0,
        **kwargs,
    ):
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
        return propagation_plot(vals, plot_range, self.params, ax, **kwargs)

    def plot_1D(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: int,
        sim_ind: int = 0,
        **kwargs,
    ):
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
        return single_position_plot(vals, plot_range, self.params, ax, **kwargs)

    def plot_mean(
        self,
        left: float,
        right: float,
        unit: Union[Callable[[float], float], str],
        ax: plt.Axes,
        z_pos: int,
        **kwargs,
    ):
        plot_range = PlotRange(left, right, unit)
        vals = self.retrieve_plot_values(plot_range, z_pos, slice(None))
        return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)

    def retrieve_plot_values(
        self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
    ):
        """load the fields (time units) or spectra (other units) to be plotted

        `sim_ind` selects along the simulation axis, which is the second one
        when all positions are loaded (`z_pos` is None) and the first one otherwise
        """
        if plot_range.unit.type == "TIME":
            vals = self.all_fields(ind=z_pos)
        else:
            vals = self.all_spectra(ind=z_pos)

        if sim_ind is None:
            return vals
        elif z_pos is None:
            return vals[:, sim_ind]
        else:
            return vals[sim_ind]

    def rin_propagation(
        self, left: float, right: float, unit: str
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """returns the RIN as function of unit and z

        Parameters
        ----------
        left : float
            left limit in unit
        right : float
            right limit in unit
        unit : str
            unit descriptor

        Returns
        -------
        x : np.ndarray, shape (nt,)
            x axis
        y : np.ndarray, shape (z_num, )
            y axis
        rin_prop : np.ndarray, shape (z_num, nt)
            RIN
        """
        spectra = []
        for spec in np.moveaxis(self.all_spectra(), 1, 0):
            x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
            spectra.append(tmp)
        return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))

    def z_ind(self, z: float) -> int:
        """return the closest z index to the given target

        Parameters
        ----------
        z : float
            target

        Returns
        -------
        int
            index
        """
        return math.argclosest(self.z, z)