partway trough big data structure revamp

This commit is contained in:
Benoît Sierro
2021-09-28 13:28:00 +02:00
parent 695ac3bd73
commit 262a5b9701
26 changed files with 349 additions and 578 deletions

19
play.py
View File

@@ -4,22 +4,3 @@ import scgenerator as sc
import matplotlib.pyplot as plt
from pathlib import Path
from pprint import pprint
def _main():
print(os.getcwd())
for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"):
print(params.fiber_map)
def main():
drr = os.getcwd()
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
try:
_main()
finally:
os.chdir(drr)
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,10 @@
from . import math, utils
from . import math
from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum
from .utils import Paths, open_config, parameter
from .utils.parameter import Configuration, Parameters
from .utils.utils import PlotRange
from ._utils import Paths, open_config, parameter
from ._utils.parameter import Configuration, Parameters
from ._utils.utils import PlotRange
from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError

View File

@@ -25,7 +25,7 @@ import pkg_resources as pkg
import toml
from tqdm import tqdm
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__
from ..env import pbar_policy
from ..logger import get_logger
@@ -143,7 +143,8 @@ def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
return dico
def load_toml(descr: str) -> dict[str, Any]:
def load_toml(descr: os.PathLike) -> dict[str, Any]:
descr = str(descr)
if ":" in descr:
path, entry = descr.split(":", 1)
with open(path) as file:
@@ -188,6 +189,7 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]
params.setdefault("variable", {})
configs.append(loaded_config | params)
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1)))
return Path(final_path), configs
@@ -341,7 +343,7 @@ def merge_path_tree(
for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
z_arr.append(z)
spec_out_name = SPECN_FN.format(i)
spec_out_name = SPECN_FN1.format(i)
np.save(destination / spec_out_name, merged_spectra)
if z_callback is not None:
z_callback(i)

View File

@@ -14,14 +14,15 @@ from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequen
import numpy as np
from numpy.lib import isin
from scgenerator.utils import ensure_folder, variationer
from .. import math, utils
from .. import math
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
from ..utils.variationer import VariationDescriptor, Variationer
from .._utils.variationer import VariationDescriptor, Variationer
from .. import _utils as utils
from .. import env
from .utils import func_rewrite, _mock_function, get_arg_names
T = TypeVar("T")
@@ -312,13 +313,6 @@ class Parameter:
return f"{num_str} {unit}"
def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]:
if isinstance(d, dict):
return [(float(k), v) for k, v in d.items()]
else:
return [(float(k), v) for k, v in d]
@dataclass
class Parameters:
"""
@@ -432,15 +426,13 @@ class Parameters:
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], list[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator)
fiber_map: list[tuple[float, str]] = Parameter(
validator_list(type_checker(tuple)), converter=fiber_map_converter
)
num: int = Parameter(non_negative(int))
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string)
def prepare_for_dump(self) -> dict[str, Any]:
param = asdict(self)
param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])]
param = Parameters.strip_params_dict(param)
param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__
@@ -816,7 +808,9 @@ class Configuration:
self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
self.final_path = utils.ensure_folder(
self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
Path(env.get(env.OUTPUT_PATH, self.final_path)),
mkdir=False,
prevent_overwrite=not self.overwrite,
)
self.master_config = self.fiber_configs[0]
self.name = self.final_path.name
@@ -868,23 +862,8 @@ class Configuration:
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
for i in range(self.num_fibers):
for sim_config in self.iterate_single_fiber(i):
if i > 0:
sim_config.config["prev_data_dir"] = str(
self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
)
params = Parameters(**sim_config.config)
params.compute()
fiber_map = []
for j in range(i + 1):
this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
if j > 0:
prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
length = prev_conf["length"] + fiber_map[j - 1][0]
else:
length = 0.0
fiber_map.append((length, this_conf["name"]))
params.fiber_map = fiber_map
yield sim_config.descriptor, params
def iterate_single_fiber(
@@ -903,18 +882,21 @@ class Configuration:
__SimConfig
configuration obj
"""
sim_dict: dict[Path, self.__SimConfig] = {}
for descr in self.variationer.iterate(index):
cfg = descr.update_config(self.fiber_configs[index])
p = ensure_folder(
self.fiber_paths[index] / descr.formatted_descriptor(),
sim_dict: dict[Path, Configuration.__SimConfig] = {}
for descriptor in self.variationer.iterate(index):
cfg = descriptor.update_config(self.fiber_configs[index])
if index > 0:
cfg["prev_data_dir"] = str(
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
)
p = utils.ensure_folder(
self.fiber_paths[index] / descriptor.formatted_descriptor(True),
not self.overwrite,
False,
)
cfg["output_path"] = str(p)
sim_config = self.__SimConfig(descr, cfg, p)
sim_dict[p] = sim_config
self.all_configs[sim_config.descriptor.index] = sim_config
sim_config = self.__SimConfig(descriptor, cfg, p)
sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
while len(sim_dict) > 0:
for data_dir, sim_config in sim_dict.items():
task, config_dict = self.__decide(sim_config)
@@ -1001,9 +983,12 @@ class Configuration:
raise ValueError(f"Too many spectra in {data_dir}")
def save_parameters(self):
for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
os.makedirs(sim_dir, exist_ok=True)
utils.save_toml(sim_dir / f"initial_config.toml", config)
os.makedirs(self.final_path, exist_ok=True)
cfgs = [
cfg | dict(variable=self.variationer.all_dicts[i])
for i, cfg in enumerate(self.fiber_configs)
]
utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs))
@property
def first(self) -> Parameters:

View File

@@ -1,12 +1,17 @@
import inspect
import os
import re
from collections import defaultdict
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Callable
import numpy as np
from pydantic import BaseModel
from .._utils import load_toml, save_toml
from ..const import PARAM_FN, Z_FN
from ..physics.units import get_unit
@@ -144,3 +149,55 @@ def _mock_function(num_args: int, num_returns: int) -> Callable:
out_func = scope[func_name]
out_func.__module__ = "evaluator"
return out_func
def combine_simulations(path: Path, dest: Path = None):
"""combines raw simulations into one folder per branch
Parameters
----------
path : Path
source of the simulations (must contain u_xx directories)
dest : Path, optional
if given, moves the simulations to dest, by default None
"""
paths: dict[str, list[Path]] = defaultdict(list)
if dest is None:
dest = path
for p in path.glob("u_*b_*"):
if p.is_dir():
paths[p.name.split()[1]].append(p)
for l in paths.values():
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
for pulses in paths.values():
new_path = dest / update_path(pulses[0].name)
os.makedirs(new_path, exist_ok=True)
for num, pulse in enumerate(pulses):
params_ok = False
for file in pulse.glob("*"):
if file.name == PARAM_FN:
if not params_ok:
update_params(new_path, file)
params_ok = True
else:
file.unlink()
elif file.name == Z_FN:
file.rename(new_path / file.name)
else:
file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
pulse.rmdir()
def update_params(new_path: Path, file: Path):
params = load_toml(file)
if (p := params.get("prev_data_dir")) is not None:
p = Path(p)
params["prev_data_dir"] = str(p.parent / update_path(p.name))
params["output_path"] = str(new_path)
save_toml(new_path / PARAM_FN, params)
file.unlink()
def update_path(p: str) -> str:
return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)

View File

@@ -116,6 +116,7 @@ class VariationDescriptor(utils.HashableBaseModel):
index: tuple[tuple[int, ...], ...]
separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {}
__ids: dict[int, int] = {}
def __str__(self) -> str:
return self.formatted_descriptor(add_identifier=False)
@@ -173,7 +174,19 @@ class VariationDescriptor(utils.HashableBaseModel):
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
)
def update_config(self, cfg: dict[str, Any]):
def update_config(self, cfg: dict[str, Any]) -> dict[str, Any]:
"""updates a dictionary with the value of the descriptor
Parameters
----------
cfg : dict[str, Any]
dict to be updated
Returns
-------
dict[str, Any]
same as cfg but with key from the descriptor added/updated.
"""
return cfg | {k: v for k, v in self.raw_descr[-1]}
@property
@@ -188,17 +201,22 @@ class VariationDescriptor(utils.HashableBaseModel):
@property
def branch(self) -> "BranchDescriptor":
for i in reversed(range(len(self.raw_descr))):
for j in reversed(range(len(self.raw_descr[i]))):
if self.raw_descr[i][j][0] == "num":
del self.raw_descr[i][j]
return VariationDescriptor(
raw_descr=self.raw_descr, index=self.index, separator=self.separator
)
descr = []
ind = []
for i, l in enumerate(self.raw_descr):
descr.append([])
ind.append([])
for j, (k, v) in enumerate(l):
if k != "num":
descr[-1].append((k, v))
ind[-1].append(self.index[i][j])
return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)
@property
def identifier(self) -> str:
return "u_" + utils.to_62(hash(str(self.flat)))
unique_id = hash(str(self.flat))
self.__ids.setdefault(unique_id, len(self.__ids))
return "u_" + str(self.__ids[unique_id])
class BranchDescriptor(VariationDescriptor):
@@ -208,7 +226,7 @@ class BranchDescriptor(VariationDescriptor):
def identifier(self) -> str:
branch_id = hash(str(self.flat))
self.__ids.setdefault(branch_id, len(self.__ids))
return str(self.__ids[branch_id])
return "b_" + str(self.__ids[branch_id])
@validator("raw_descr")
def validate_raw_descr(cls, v):

View File

@@ -1,14 +1,13 @@
import argparse
import os
import re
import subprocess
import sys
from collections import ChainMap
from pathlib import Path
import numpy as np
from .. import const, env, scripts, utils
from .. import const, env, scripts
from .. import _utils as utils
from ..logger import get_logger
from ..physics.fiber import dispersion_coefficients
from ..physics.simulate import SequencialSimulations, run_simulation

View File

@@ -20,7 +20,8 @@ def pbar_format(worker_id: int):
SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy"
SPECN_FN1 = "spectra_{}.npy"
SPEC1_FN_N = "spectrum_{}_{}.npy"
Z_FN = "z.npy"
PARAM_FN = "params.toml"
PARAM_SEPARATOR = " "

View File

@@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]:
return tmp
def get(key: str) -> Any:
def get(key: str, default=None) -> Any:
str_value = os.environ.get(key)
if isinstance(str_value, str):
try:
@@ -58,7 +58,7 @@ def get(key: str) -> Any:
return t(str_value)
except (ValueError, KeyError):
pass
return None
return default
def all_environ() -> Dict[str, str]:

View File

@@ -3,7 +3,7 @@ from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from .utils.cache import np_cache
from ._utils.cache import np_cache
pi = np.pi
c = 299792458.0

View File

@@ -10,8 +10,8 @@ from scipy.optimize import minimize_scalar
from .. import math
from . import fiber, materials, units, pulse
from .. import utils
from ..utils import cache
from .. import _utils
from .._utils import cache
T = TypeVar("T")

View File

@@ -7,9 +7,9 @@ from scipy.interpolate import interp1d
from ..logger import get_logger
from .. import utils
from .. import _utils
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.cache import np_cache
from .._utils.cache import np_cache
from . import materials as mat
from . import units
from .units import c, pi

View File

@@ -5,7 +5,7 @@ from scipy.integrate import cumulative_trapezoid
from ..logger import get_logger
from . import units
from .. import utils
from .. import _utils
from .units import NA, c, kB, me, e, hbar

View File

@@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar
from scipy.optimize.optimize import OptimizeResult
from scgenerator import utils
from ..defaults import default_plotting
from ..logger import get_logger
from ..math import *

View File

@@ -9,9 +9,11 @@ from typing import Any, Generator, Type, Union
import numpy as np
from send2trash import send2trash
from .. import env, utils
from .. import env
from .. import _utils as utils
from .._utils.utils import combine_simulations
from ..logger import get_logger
from ..utils.parameter import Configuration, Parameters
from .._utils.parameter import Configuration, Parameters
from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op
@@ -718,17 +720,9 @@ def run_simulation(
sim = new_simulation(config, method)
sim.run()
path_trees = utils.build_path_trees(config.fiber_paths[-1])
final_name = env.get(env.OUTPUT_PATH)
if final_name is None:
final_name = config.final_path
utils.merge(final_name, path_trees)
try:
send2trash(config.fiber_paths)
except (PermissionError, OSError):
get_logger(__name__).error("Could not send temporary directories to trash")
for path in config.fiber_paths:
combine_simulations(path)
def new_simulation(

View File

@@ -2,7 +2,6 @@
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from dataclasses import dataclass
from typing import Callable, TypeVar, Union
import numpy as np

View File

@@ -14,8 +14,8 @@ from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults
from .math import abs2, span
from .physics import pulse, units
from .utils.parameter import Parameters
from .utils.utils import PlotRange, sort_axis
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange, sort_axis
RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object()

View File

@@ -12,12 +12,11 @@ from ..const import PARAM_FN, PARAM_SEPARATOR
from ..physics import fiber, units
from ..plotting import plot_setup
from ..spectra import Pulse
from ..utils import auto_crop, open_config, save_toml, translate_parameters
from ..utils.parameter import (
from .._utils import auto_crop, open_config, save_toml, translate_parameters
from .._utils.parameter import (
Configuration,
Parameters,
)
from ..utils.variationer import VariationDescriptor
def fingerprint(params: Parameters):

View File

@@ -9,8 +9,8 @@ from typing import Tuple
import numpy as np
from ..utils import Paths
from ..utils.parameter import Configuration
from .._utils import Paths
from .._utils.parameter import Configuration
def primes(n):

View File

@@ -1,13 +1,19 @@
from __future__ import annotations
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from pydantic import BaseModel, DirectoryPath, root_validator
from . import math
from .const import SPECN_FN
from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange
from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N
from .logger import get_logger
from .physics import pulse, units
from .plotting import (
@@ -16,9 +22,87 @@ from .plotting import (
single_position_plot,
transform_2D_propagation,
)
from .utils.parameter import Parameters
from .utils.utils import PlotRange
from .utils import load_spectrum
class SimulationSeries:
path: Path
params: Parameters
total_length: float
total_num_steps: int
previous: SimulationSeries = None
class Config:
arbitrary_types_allowed = True
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters.load(self.path / PARAM_FN)
if self.params.prev_data_dir is not None:
self.previous = SimulationSeries(self.params.prev_data_dir)
self.total_length = self.accumulate_params("length")
self.total_num_steps = self.accumulate_params("z_num")
def fiber_map(self):
lengths = self.all_params("length")
return [
(this[0], following[1]) for this, following in zip(lengths, [(None, 0.0)] + lengths)
]
def all_params(self, key: str) -> list[tuple[str, Any]]:
"""returns the value of a parameter for each fiber
Parameters
----------
key : str
name of the parameter
Returns
-------
list[tuple[str, Any]]
list of (fiber_name, param_value) tuples
"""
return list(reversed(self._all_params(key, [])))
def accumulate_params(self, key: str) -> Any:
"""returns the sum of all the values a parameter takes. Useful to
get the total length of the fiber, the total number of steps, etc.
Parameters
----------
key : str
name of the parameter
Returns
-------
Any
final sum
"""
return sum(el[1] for el in self.all_params(key))
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
def _all_params(self, key: str, l: list) -> list:
l.append((self.params.name, getattr(self.params, key)))
if self.previous is not None:
return self.previous._all_params(key, l)
return l
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
def __eq__(self, other: SimulationSeries) -> bool:
return (
self.path == other.path
and self.params == other.params
and self.previous == other.previous
)
def __contains__(self, other: SimulationSeries) -> bool:
if other is self or other == self:
return True
if self.previous is not None:
return other in self.previous
class Spectrum(np.ndarray):
@@ -129,6 +213,23 @@ class Spectrum(np.ndarray):
class Pulse(Sequence):
path: Path
default_ind: Optional[int]
params: Parameters
z: np.ndarray
namx: int
t: np.ndarray
w: np.ndarray
w_order: np.ndarray
def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse":
try:
if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
return super().__new__(LegacyPulse)
except FileNotFoundError:
pass
return super().__new__(cls)
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
"""load a data folder as a pulse
@@ -144,36 +245,6 @@ class Pulse(Sequence):
FileNotFoundError
path does not contain proper data
"""
self.logger = get_logger(__name__)
self.path = Path(path)
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def __iter__(self):
"""
@@ -190,73 +261,6 @@ class Pulse(Sequence):
def __getitem__(self, key) -> Spectrum:
return self.all_spectra(key)
def intensity(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_int,
FREQ=self._to_freq_int,
AFREQ=self._to_afreq_int,
TIME=self._to_time_int,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_int(self, spectrum):
return units.to_WL(math.abs2(spectrum), spectrum.wl)
def _to_freq_int(self, spectrum):
return math.abs2(spectrum)
def _to_afreq_int(self, spectrum):
return math.abs2(spectrum)
def _to_time_int(self, spectrum):
return math.abs2(np.fft.ifft(spectrum))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_amp,
FREQ=self._to_freq_amp,
AFREQ=self._to_afreq_amp,
TIME=self._to_time_amp,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_amp(self, spectrum):
return (
np.sqrt(
units.to_WL(
math.abs2(spectrum),
spectrum.wl,
)
)
* spectrum
/ np.abs(spectrum)
)
def _to_freq_amp(self, spectrum):
return spectrum
def _to_afreq_amp(self, spectrum):
return spectrum
def _to_time_amp(self, spectrum):
return np.fft.ifft(spectrum)
def all_spectra(self, ind=None) -> Spectrum:
"""
loads the data already simulated.
@@ -305,12 +309,7 @@ class Pulse(Sequence):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
def _load1(self, i: int):
if i < 0:
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec
pass
def plot_2D(
self,
@@ -412,3 +411,46 @@ class Pulse(Sequence):
index
"""
return math.argclosest(self.z, z)
class LegacyPulse(Pulse):
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
print("old init called", path, default_ind)
self.logger = get_logger(__name__)
self.path = Path(path)
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def _load1(self, i: int):
if i < 0:
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN1.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec

View File

@@ -1,35 +0,0 @@
import shutil
import unittest
import toml
from scgenerator import logger
from send2trash import send2trash
TMP = "testing/.tmp"
class TestRecoveryParamSequence(unittest.TestCase):
def setUp(self):
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
self.conf = toml.load(TMP + "/initial_config.toml")
io.set_data_folder(55, TMP)
def test_remaining_simulations_count(self):
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
self.assertEqual(5, len(param_seq))
def test_only_one_to_complete(self):
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
i = 0
for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq):
i += 1
self.assertEqual(expected, "recovery_last_stored" in params)
self.assertEqual(5, i)
def tearDown(self):
send2trash(TMP)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,216 +0,0 @@
import unittest
from copy import deepcopy
import numpy as np
import toml
from scgenerator import defaults, utils, math
from scgenerator.errors import *
from scgenerator.physics import pulse, units
from scgenerator.utils.parameter import Config, Parameters
def load_conf(name):
with open("testing/configs/" + name + ".toml") as file:
conf = toml.load(file)
return conf
def conf_maker(folder):
def conf(name):
return load_conf(folder + "/" + name)
return conf
class TestParamSequence(unittest.TestCase):
def iterconf(self, files):
conf = conf_maker("param_sequence")
for path in files:
yield init.ParamSequence(conf(path))
def test_no_repeat_in_sub_folder_names(self):
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
l = []
s = []
for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertNotIn(vary_list, l)
self.assertNotIn(utils.format_variable_list(vary_list), s)
l.append(vary_list)
s.append(utils.format_variable_list(vary_list))
def test_no_variations_yields_only_num_and_id(self):
for param_seq in self.iterconf(["no_variations"]):
for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertEqual(vary_list[1][0], "num")
self.assertEqual(vary_list[0][0], "id")
self.assertEqual(2, len(vary_list))
class TestInitializeMethods(unittest.TestCase):
def test_validate_types(self):
conf = lambda s: load_conf("validate_types/" + s)
with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"):
init.Config(**conf("bad2"))
with self.assertRaisesRegex(TypeError, "value must be of type <class 'float'>"):
init.Config(**conf("bad3"))
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
init.Config(**conf("bad4"))
with self.assertRaisesRegex(
TypeError, "'variable intensity_noise' value must be of type <class 'list'>"
):
init.Config(**conf("bad5"))
with self.assertRaisesRegex(ValueError, "'repeat' must be positive"):
init.Config(**conf("bad6"))
with self.assertRaisesRegex(
ValueError, "variable parameter 'intensity_noise' must not be empty"
):
init.Config(**conf("bad7"))
self.assertIsNone(init.Config(**conf("good")).hr_w)
def test_ensure_consistency(self):
conf = lambda s: load_conf("ensure_consistency/" + s)
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['t0', 'width'\]' is required and no defaults have been set",
):
init.Config(**conf("bad1"))
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
):
init.Config(**conf("bad2"))
with self.assertRaisesRegex(
MissingParameterError,
r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set",
):
init.Config(**conf("bad3"))
with self.assertRaisesRegex(
DuplicateParameterError,
r"got multiple values for parameter 'width'",
):
init.Config(**conf("bad4"))
with self.assertRaisesRegex(
MissingParameterError,
r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set",
):
init.Config(**conf("bad5"))
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set",
):
init.Config(**conf("bad6"))
self.assertLessEqual(
{"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items()
)
self.assertIsNone(init.Config(**conf("good4")).gamma)
self.assertLessEqual(
{"raman_type": "agrawal"}.items(),
init.Config(**conf("good2")).__dict__.items(),
)
self.assertLessEqual(
{"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items()
)
self.assertLessEqual(
{"capillary_nested": 0, "capillary_resonance_strengths": []}.items(),
init.Config(**conf("good4")).__dict__.items(),
)
self.assertLessEqual(
dict(he_mode=(1, 1)).items(),
init.Config(**conf("good5")).__dict__.items(),
)
self.assertLessEqual(
dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(),
init.Config(**conf("good5")).__dict__.items(),
)
def setup_conf_custom_field(self, path) -> Parameters:
conf = load_conf(path)
conf = Parameters(**conf)
init.build_sim_grid_in_place(conf)
return conf
def test_setup_custom_field(self):
d = np.load("testing/configs/custom_field/init_field.npz")
t = d["time"]
field = d["field"]
conf = self.setup_conf_custom_field("custom_field/no_change")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/peak_power")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0)
self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4)
self.assertTrue(result)
self.assertNotAlmostEqual(conf.wavelength, 1593e-9)
conf = self.setup_conf_custom_field("custom_field/mean_power")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/recover1")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/recover2")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
result = Parameters(**conf)
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
conf.wavelength = 1593e-9
result = Parameters(**conf)
conf = load_conf("custom_field/wavelength_shift2")
conf = init.Config(**conf)
for target, (variable, config) in zip(
[1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf)
):
init.build_sim_grid_in_place(conf)
self.assertAlmostEqual(
units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target
)
print(config.wavelength, target)
if __name__ == "__main__":
conf = conf_maker("validate_types")
unittest.main()

View File

@@ -1,41 +0,0 @@
import unittest
from scgenerator.physics.pulse import conform_pulse_params
class TestPulseMethods(unittest.TestCase):
def test_conform_pulse_params(self):
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6))
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6))
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6))
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6))
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6)))
with self.assertRaisesRegex(
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
):
conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01)
with self.assertRaisesRegex(
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
):
conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01)
self.assertEqual(
5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6))
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,67 +0,0 @@
import unittest
import numpy as np
import toml
from scgenerator import utils
def load_conf(name):
with open("testing/configs/" + name + ".toml") as file:
conf = toml.load(file)
return conf
def conf_maker(folder, val=True):
def conf(name):
if val:
return initialize.Config(**load_conf(folder + "/" + name))
else:
return initialize.Config(**load_conf(folder + "/" + name))
return conf
class TestUtilsMethods(unittest.TestCase):
def test_count_variations(self):
conf = conf_maker("count_variations")
for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]:
self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary")))
def test_format_value(self):
values = [
122e-6,
True,
["raman", "ss"],
np.arange(5),
1.123,
1.1230001,
0.002e122,
12.3456e-9,
]
s = [
"0.000122",
"True",
"raman-ss",
"0-1-2-3-4",
"1.123",
"1.1230001",
"2e+119",
"1.23456e-08",
]
for value, target in zip(values, s):
self.assertEqual(target, utils.format_value(value))
def test_override_config(self):
conf = conf_maker("override", False)
old = conf("initial_config")
new = conf("fiber2")
over = utils.override_config(vars(new), old)
self.assertNotIn("input_transmission", over.variable)
self.assertIsNone(over.input_transmission)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,54 @@
from pydantic import main
import scgenerator as sc
def test_descriptor():
# Same branch
var1 = sc.VariationDescriptor(
raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
)
var2 = sc.VariationDescriptor(
raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
)
assert var1.branch.identifier == "b_0"
assert var1.identifier != var1.branch.identifier
assert var1.identifier != var2.identifier
assert var1.branch.identifier == var2.branch.identifier
# different branch
var3 = sc.VariationDescriptor(
raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]]
)
assert var1.branch.identifier != var3.branch.identifier
assert var1.formatted_descriptor() != var2.formatted_descriptor()
assert var1.formatted_descriptor() != var3.formatted_descriptor()
def test_variationer():
var = sc.Variationer(
[
dict(a=[1, 2], num=[0, 1, 2]),
[dict(b=["000", "111"], c=["a", "-1"])],
dict(),
dict(),
[dict(aaa=[True, False], bb=[1, 3])],
]
)
assert var.var_num(0) == 6
assert var.var_num(1) == 12
assert var.var_num() == 24
cfg = dict(bb=None)
branches = set()
for descr in var.iterate():
assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1])
branches.add(descr.branch.identifier)
assert len(branches) == 8
def main():
test_descriptor()
if __name__ == "__main__":
main()