partway through big data structure revamp
play.py: 19 changed lines
@@ -4,22 +4,3 @@ import scgenerator as sc
 import matplotlib.pyplot as plt
 from pathlib import Path
 from pprint import pprint
-
-
-def _main():
-    print(os.getcwd())
-    for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"):
-        print(params.fiber_map)
-
-
-def main():
-    drr = os.getcwd()
-    os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
-    try:
-        _main()
-    finally:
-        os.chdir(drr)
-
-
-if __name__ == "__main__":
-    main()
@@ -1,9 +1,10 @@
-from . import math, utils
+from . import math
 from .math import abs2, argclosest, span
 from .physics import fiber, materials, pulse, simulate, units
 from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
 from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
 from .spectra import Pulse, Spectrum
-from .utils import Paths, open_config, parameter
-from .utils.parameter import Configuration, Parameters
-from .utils.utils import PlotRange
+from ._utils import Paths, open_config, parameter
+from ._utils.parameter import Configuration, Parameters
+from ._utils.utils import PlotRange
+from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError
@@ -25,7 +25,7 @@ import pkg_resources as pkg
 import toml
 from tqdm import tqdm

-from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
+from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__
 from ..env import pbar_policy
 from ..logger import get_logger

@@ -143,7 +143,8 @@ def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
     return dico


-def load_toml(descr: str) -> dict[str, Any]:
+def load_toml(descr: os.PathLike) -> dict[str, Any]:
+    descr = str(descr)
     if ":" in descr:
         path, entry = descr.split(":", 1)
         with open(path) as file:
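For illustration, the widened load_toml signature accepts both call forms below; the file name and entry are invented, not taken from the repository:
    load_toml(Path("configs/base.toml"))     # os.PathLike is now accepted and normalised with str()
    load_toml("configs/base.toml:fiber2")    # the "path:entry" form still selects a single entry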
@@ -188,6 +189,7 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]
         params.setdefault("variable", {})
         configs.append(loaded_config | params)
     configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
+    configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1)))

     return Path(final_path), configs

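If this reading is right, the added line folds the repeat count into the variable grid, so a config with repeat = 3 and no explicit "num" now behaves as if it declared (illustration only):
    configs[0]["variable"]["num"] = [0, 1, 2]   # list(range(3)), one entry per repetition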
@@ -341,7 +343,7 @@ def merge_path_tree(

     for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
         z_arr.append(z)
-        spec_out_name = SPECN_FN.format(i)
+        spec_out_name = SPECN_FN1.format(i)
         np.save(destination / spec_out_name, merged_spectra)
         if z_callback is not None:
             z_callback(i)
@@ -14,14 +14,15 @@ from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequen

 import numpy as np
 from numpy.lib import isin
-from scgenerator.utils import ensure_folder, variationer

-from .. import math, utils
+from .. import math
 from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
 from ..errors import EvaluatorError, NoDefaultError
 from ..logger import get_logger
 from ..physics import fiber, materials, pulse, units
-from ..utils.variationer import VariationDescriptor, Variationer
+from .._utils.variationer import VariationDescriptor, Variationer
+from .. import _utils as utils
+from .. import env
 from .utils import func_rewrite, _mock_function, get_arg_names

 T = TypeVar("T")
@@ -312,13 +313,6 @@ class Parameter:
         return f"{num_str} {unit}"


-def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]:
-    if isinstance(d, dict):
-        return [(float(k), v) for k, v in d.items()]
-    else:
-        return [(float(k), v) for k, v in d]
-
-
 @dataclass
 class Parameters:
     """
@@ -432,15 +426,13 @@ class Parameters:
     const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
     beta_func: Callable[[float], list[float]] = Parameter(func_validator)
     gamma_func: Callable[[float], float] = Parameter(func_validator)
-    fiber_map: list[tuple[float, str]] = Parameter(
-        validator_list(type_checker(tuple)), converter=fiber_map_converter
-    )
+    num: int = Parameter(non_negative(int))
     datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
     version: str = Parameter(string)

     def prepare_for_dump(self) -> dict[str, Any]:
         param = asdict(self)
-        param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])]
         param = Parameters.strip_params_dict(param)
         param["datetime"] = datetime_module.datetime.now()
         param["version"] = __version__
@@ -816,7 +808,9 @@ class Configuration:
         self.overwrite = overwrite
         self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
         self.final_path = utils.ensure_folder(
-            self.final_path, mkdir=False, prevent_overwrite=not self.overwrite
+            Path(env.get(env.OUTPUT_PATH, self.final_path)),
+            mkdir=False,
+            prevent_overwrite=not self.overwrite,
         )
         self.master_config = self.fiber_configs[0]
         self.name = self.final_path.name
@@ -868,23 +862,8 @@ class Configuration:
     def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
         for i in range(self.num_fibers):
             for sim_config in self.iterate_single_fiber(i):
-
-                if i > 0:
-                    sim_config.config["prev_data_dir"] = str(
-                        self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
-                    )
                 params = Parameters(**sim_config.config)
                 params.compute()
-                fiber_map = []
-                for j in range(i + 1):
-                    this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
-                    if j > 0:
-                        prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
-                        length = prev_conf["length"] + fiber_map[j - 1][0]
-                    else:
-                        length = 0.0
-                    fiber_map.append((length, this_conf["name"]))
-                params.fiber_map = fiber_map
                 yield sim_config.descriptor, params

     def iterate_single_fiber(
@@ -903,18 +882,21 @@ class Configuration:
         __SimConfig
             configuration obj
         """
-        sim_dict: dict[Path, self.__SimConfig] = {}
-        for descr in self.variationer.iterate(index):
-            cfg = descr.update_config(self.fiber_configs[index])
-            p = ensure_folder(
-                self.fiber_paths[index] / descr.formatted_descriptor(),
+        sim_dict: dict[Path, Configuration.__SimConfig] = {}
+        for descriptor in self.variationer.iterate(index):
+            cfg = descriptor.update_config(self.fiber_configs[index])
+            if index > 0:
+                cfg["prev_data_dir"] = str(
+                    self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
+                )
+            p = utils.ensure_folder(
+                self.fiber_paths[index] / descriptor.formatted_descriptor(True),
                 not self.overwrite,
                 False,
             )
             cfg["output_path"] = str(p)
-            sim_config = self.__SimConfig(descr, cfg, p)
-            sim_dict[p] = sim_config
-            self.all_configs[sim_config.descriptor.index] = sim_config
+            sim_config = self.__SimConfig(descriptor, cfg, p)
+            sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
         while len(sim_dict) > 0:
             for data_dir, sim_config in sim_dict.items():
                 task, config_dict = self.__decide(sim_config)
@@ -1001,9 +983,12 @@ class Configuration:
             raise ValueError(f"Too many spectra in {data_dir}")

     def save_parameters(self):
-        for config, sim_dir in zip(self.fiber_configs, self.fiber_paths):
-            os.makedirs(sim_dir, exist_ok=True)
-            utils.save_toml(sim_dir / f"initial_config.toml", config)
+        os.makedirs(self.final_path, exist_ok=True)
+        cfgs = [
+            cfg | dict(variable=self.variationer.all_dicts[i])
+            for i, cfg in enumerate(self.fiber_configs)
+        ]
+        utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs))

     @property
     def first(self) -> Parameters:
@@ -1,12 +1,17 @@
 import inspect
+import os
 import re
+from collections import defaultdict
 from functools import cache
+from pathlib import Path
 from string import printable as str_printable
 from typing import Callable

 import numpy as np
 from pydantic import BaseModel

+from .._utils import load_toml, save_toml
+from ..const import PARAM_FN, Z_FN
 from ..physics.units import get_unit


@@ -144,3 +149,55 @@ def _mock_function(num_args: int, num_returns: int) -> Callable:
     out_func = scope[func_name]
     out_func.__module__ = "evaluator"
     return out_func
+
+
+def combine_simulations(path: Path, dest: Path = None):
+    """combines raw simulations into one folder per branch
+
+    Parameters
+    ----------
+    path : Path
+        source of the simulations (must contain u_xx directories)
+    dest : Path, optional
+        if given, moves the simulations to dest, by default None
+    """
+    paths: dict[str, list[Path]] = defaultdict(list)
+    if dest is None:
+        dest = path
+
+    for p in path.glob("u_*b_*"):
+        if p.is_dir():
+            paths[p.name.split()[1]].append(p)
+    for l in paths.values():
+        l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
+    for pulses in paths.values():
+        new_path = dest / update_path(pulses[0].name)
+        os.makedirs(new_path, exist_ok=True)
+        for num, pulse in enumerate(pulses):
+            params_ok = False
+            for file in pulse.glob("*"):
+                if file.name == PARAM_FN:
+                    if not params_ok:
+                        update_params(new_path, file)
+                        params_ok = True
+                    else:
+                        file.unlink()
+                elif file.name == Z_FN:
+                    file.rename(new_path / file.name)
+                else:
+                    file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
+            pulse.rmdir()
+
+
+def update_params(new_path: Path, file: Path):
+    params = load_toml(file)
+    if (p := params.get("prev_data_dir")) is not None:
+        p = Path(p)
+        params["prev_data_dir"] = str(p.parent / update_path(p.name))
+    params["output_path"] = str(new_path)
+    save_toml(new_path / PARAM_FN, params)
+    file.unlink()
+
+
+def update_path(p: str) -> str:
+    return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)
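update_path strips the per-run prefix and the per-repetition suffix from a data directory name, so every repetition of a branch lands in the same merged folder; a hypothetical example (the directory name is invented):
    update_path("u_3 b_1 wavelength 1550e-09 num 2")   # -> "b_1 wavelength 1550e-09"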
@@ -116,6 +116,7 @@ class VariationDescriptor(utils.HashableBaseModel):
     index: tuple[tuple[int, ...], ...]
     separator: str = "fiber"
     _format_registry: dict[str, Callable[..., str]] = {}
+    __ids: dict[int, int] = {}

     def __str__(self) -> str:
         return self.formatted_descriptor(add_identifier=False)
@@ -173,7 +174,19 @@ class VariationDescriptor(utils.HashableBaseModel):
             raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
         )

-    def update_config(self, cfg: dict[str, Any]):
+    def update_config(self, cfg: dict[str, Any]) -> dict[str, Any]:
+        """updates a dictionary with the value of the descriptor
+
+        Parameters
+        ----------
+        cfg : dict[str, Any]
+            dict to be updated
+
+        Returns
+        -------
+        dict[str, Any]
+            same as cfg but with key from the descriptor added/updated.
+        """
         return cfg | {k: v for k, v in self.raw_descr[-1]}

     @property
@@ -188,17 +201,22 @@ class VariationDescriptor(utils.HashableBaseModel):

     @property
     def branch(self) -> "BranchDescriptor":
-        for i in reversed(range(len(self.raw_descr))):
-            for j in reversed(range(len(self.raw_descr[i]))):
-                if self.raw_descr[i][j][0] == "num":
-                    del self.raw_descr[i][j]
-        return VariationDescriptor(
-            raw_descr=self.raw_descr, index=self.index, separator=self.separator
-        )
+        descr = []
+        ind = []
+        for i, l in enumerate(self.raw_descr):
+            descr.append([])
+            ind.append([])
+            for j, (k, v) in enumerate(l):
+                if k != "num":
+                    descr[-1].append((k, v))
+                    ind[-1].append(self.index[i][j])
+        return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)

     @property
     def identifier(self) -> str:
-        return "u_" + utils.to_62(hash(str(self.flat)))
+        unique_id = hash(str(self.flat))
+        self.__ids.setdefault(unique_id, len(self.__ids))
+        return "u_" + str(self.__ids[unique_id])


 class BranchDescriptor(VariationDescriptor):
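A reduced sketch of the new behaviour, assuming a single-fiber descriptor can be built the same way as the two-fiber ones in the new test file at the bottom of this commit: descriptors that differ only in their "num" entry collapse onto the same branch, while their per-run identifiers stay distinct.
    var1 = sc.VariationDescriptor(raw_descr=[[("num", 1), ("a", False)]], index=[[1, 0]])
    var2 = sc.VariationDescriptor(raw_descr=[[("num", 2), ("a", False)]], index=[[1, 0]])
    assert var1.branch.identifier == var2.branch.identifier   # e.g. "b_0"
    assert var1.identifier != var2.identifier                 # "u_..." ids remain per run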
@@ -208,7 +226,7 @@ class BranchDescriptor(VariationDescriptor):
     def identifier(self) -> str:
         branch_id = hash(str(self.flat))
         self.__ids.setdefault(branch_id, len(self.__ids))
-        return str(self.__ids[branch_id])
+        return "b_" + str(self.__ids[branch_id])

     @validator("raw_descr")
     def validate_raw_descr(cls, v):
@@ -1,14 +1,13 @@
 import argparse
 import os
 import re
-import subprocess
-import sys
 from collections import ChainMap
 from pathlib import Path

 import numpy as np

-from .. import const, env, scripts, utils
+from .. import const, env, scripts
+from .. import _utils as utils
 from ..logger import get_logger
 from ..physics.fiber import dispersion_coefficients
 from ..physics.simulate import SequencialSimulations, run_simulation
@@ -20,7 +20,8 @@ def pbar_format(worker_id: int):


 SPEC1_FN = "spectrum_{}.npy"
-SPECN_FN = "spectra_{}.npy"
+SPECN_FN1 = "spectra_{}.npy"
+SPEC1_FN_N = "spectrum_{}_{}.npy"
 Z_FN = "z.npy"
 PARAM_FN = "params.toml"
 PARAM_SEPARATOR = " "
@@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]:
     return tmp


-def get(key: str) -> Any:
+def get(key: str, default=None) -> Any:
     str_value = os.environ.get(key)
     if isinstance(str_value, str):
         try:
@@ -58,7 +58,7 @@ def get(key: str) -> Any:
             return t(str_value)
         except (ValueError, KeyError):
             pass
-    return None
+    return default


 def all_environ() -> Dict[str, str]:
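The added default lets call sites drop their None checks; the Configuration change earlier in this commit appears to rely on exactly that (sketch):
    Path(env.get(env.OUTPUT_PATH, self.final_path))   # falls back to final_path when the variable is unset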
@@ -3,7 +3,7 @@ from typing import Union
 import numpy as np
 from scipy.interpolate import griddata, interp1d
 from scipy.special import jn_zeros
-from .utils.cache import np_cache
+from ._utils.cache import np_cache

 pi = np.pi
 c = 299792458.0
@@ -10,8 +10,8 @@ from scipy.optimize import minimize_scalar

 from .. import math
 from . import fiber, materials, units, pulse
-from .. import utils
-from ..utils import cache
+from .. import _utils
+from .._utils import cache

 T = TypeVar("T")

@@ -7,9 +7,9 @@ from scipy.interpolate import interp1d

 from ..logger import get_logger

-from .. import utils
+from .. import _utils
 from ..math import abs2, argclosest, power_fact, u_nm
-from ..utils.cache import np_cache
+from .._utils.cache import np_cache
 from . import materials as mat
 from . import units
 from .units import c, pi
@@ -5,7 +5,7 @@ from scipy.integrate import cumulative_trapezoid

 from ..logger import get_logger
 from . import units
-from .. import utils
+from .. import _utils
 from .units import NA, c, kB, me, e, hbar


@@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline
 from scipy.optimize import minimize_scalar
 from scipy.optimize.optimize import OptimizeResult

-from scgenerator import utils
-
 from ..defaults import default_plotting
 from ..logger import get_logger
 from ..math import *
@@ -9,9 +9,11 @@ from typing import Any, Generator, Type, Union
 import numpy as np
 from send2trash import send2trash

-from .. import env, utils
+from .. import env
+from .. import _utils as utils
+from .._utils.utils import combine_simulations
 from ..logger import get_logger
-from ..utils.parameter import Configuration, Parameters
+from .._utils.parameter import Configuration, Parameters
 from . import pulse
 from .fiber import create_non_linear_op, fast_dispersion_op

@@ -718,17 +720,9 @@ def run_simulation(

     sim = new_simulation(config, method)
     sim.run()
-    path_trees = utils.build_path_trees(config.fiber_paths[-1])

-    final_name = env.get(env.OUTPUT_PATH)
-    if final_name is None:
-        final_name = config.final_path
-
-    utils.merge(final_name, path_trees)
-    try:
-        send2trash(config.fiber_paths)
-    except (PermissionError, OSError):
-        get_logger(__name__).error("Could not send temporary directories to trash")
+    for path in config.fiber_paths:
+        combine_simulations(path)


 def new_simulation(
@@ -2,7 +2,6 @@
 # For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
 # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...

-from dataclasses import dataclass
 from typing import Callable, TypeVar, Union

 import numpy as np
@@ -14,8 +14,8 @@ from .const import PARAM_SEPARATOR
 from .defaults import default_plotting as defaults
 from .math import abs2, span
 from .physics import pulse, units
-from .utils.parameter import Parameters
-from .utils.utils import PlotRange, sort_axis
+from ._utils.parameter import Parameters
+from ._utils.utils import PlotRange, sort_axis

 RangeType = tuple[float, float, Union[str, Callable]]
 NO_LIM = object()
@@ -12,12 +12,11 @@ from ..const import PARAM_FN, PARAM_SEPARATOR
 from ..physics import fiber, units
 from ..plotting import plot_setup
 from ..spectra import Pulse
-from ..utils import auto_crop, open_config, save_toml, translate_parameters
-from ..utils.parameter import (
+from .._utils import auto_crop, open_config, save_toml, translate_parameters
+from .._utils.parameter import (
     Configuration,
     Parameters,
 )
-from ..utils.variationer import VariationDescriptor


 def fingerprint(params: Parameters):
@@ -9,8 +9,8 @@ from typing import Tuple

 import numpy as np

-from ..utils import Paths
-from ..utils.parameter import Configuration
+from .._utils import Paths
+from .._utils.parameter import Configuration


 def primes(n):
@@ -1,13 +1,19 @@
+from __future__ import annotations
+
 import os
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Union

 import matplotlib.pyplot as plt
 import numpy as np
+from pydantic import BaseModel, DirectoryPath, root_validator

 from . import math
-from .const import SPECN_FN
+from ._utils import load_spectrum
+from ._utils.parameter import Parameters
+from ._utils.utils import PlotRange
+from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N
 from .logger import get_logger
 from .physics import pulse, units
 from .plotting import (
@@ -16,9 +22,87 @@ from .plotting import (
|
|||||||
single_position_plot,
|
single_position_plot,
|
||||||
transform_2D_propagation,
|
transform_2D_propagation,
|
||||||
)
|
)
|
||||||
from .utils.parameter import Parameters
|
|
||||||
from .utils.utils import PlotRange
|
|
||||||
from .utils import load_spectrum
|
class SimulationSeries:
|
||||||
|
path: Path
|
||||||
|
params: Parameters
|
||||||
|
total_length: float
|
||||||
|
total_num_steps: int
|
||||||
|
previous: SimulationSeries = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def __init__(self, path: os.PathLike):
|
||||||
|
self.path = Path(path)
|
||||||
|
self.params = Parameters.load(self.path / PARAM_FN)
|
||||||
|
if self.params.prev_data_dir is not None:
|
||||||
|
self.previous = SimulationSeries(self.params.prev_data_dir)
|
||||||
|
self.total_length = self.accumulate_params("length")
|
||||||
|
self.total_num_steps = self.accumulate_params("z_num")
|
||||||
|
|
||||||
|
def fiber_map(self):
|
||||||
|
lengths = self.all_params("length")
|
||||||
|
return [
|
||||||
|
(this[0], following[1]) for this, following in zip(lengths, [(None, 0.0)] + lengths)
|
||||||
|
]
|
||||||
|
|
||||||
|
def all_params(self, key: str) -> list[tuple[str, Any]]:
|
||||||
|
"""returns the value of a parameter for each fiber
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
name of the parameter
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[tuple[str, Any]]
|
||||||
|
list of (fiber_name, param_value) tuples
|
||||||
|
"""
|
||||||
|
return list(reversed(self._all_params(key, [])))
|
||||||
|
|
||||||
|
def accumulate_params(self, key: str) -> Any:
|
||||||
|
"""returns the sum of all the values a parameter takes. Useful to
|
||||||
|
get the total length of the fiber, the total number of steps, etc.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
key : str
|
||||||
|
name of the parameter
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Any
|
||||||
|
final sum
|
||||||
|
"""
|
||||||
|
return sum(el[1] for el in self.all_params(key))
|
||||||
|
|
||||||
|
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
||||||
|
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
||||||
|
|
||||||
|
def _all_params(self, key: str, l: list) -> list:
|
||||||
|
l.append((self.params.name, getattr(self.params, key)))
|
||||||
|
if self.previous is not None:
|
||||||
|
return self.previous._all_params(key, l)
|
||||||
|
return l
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
|
||||||
|
|
||||||
|
def __eq__(self, other: SimulationSeries) -> bool:
|
||||||
|
return (
|
||||||
|
self.path == other.path
|
||||||
|
and self.params == other.params
|
||||||
|
and self.previous == other.previous
|
||||||
|
)
|
||||||
|
|
||||||
|
def __contains__(self, other: SimulationSeries) -> bool:
|
||||||
|
if other is self or other == self:
|
||||||
|
return True
|
||||||
|
if self.previous is not None:
|
||||||
|
return other in self.previous
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
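The SimulationSeries introduced above chains data folders through params.prev_data_dir; a minimal sketch of its intended use, with an invented path, is:
    series = SimulationSeries("sims/b_0 wavelength 1550e-09")   # loads params.toml, then recurses into prev_data_dir
    total = series.accumulate_params("length")                   # sum of "length" over every fiber in the chain
    names = series.fiber_map()                                    # each fiber name paired with the preceding fiber's length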
@@ -129,6 +213,23 @@ class Spectrum(np.ndarray):


 class Pulse(Sequence):
+    path: Path
+    default_ind: Optional[int]
+    params: Parameters
+    z: np.ndarray
+    namx: int
+    t: np.ndarray
+    w: np.ndarray
+    w_order: np.ndarray
+
+    def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse":
+        try:
+            if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
+                return super().__new__(LegacyPulse)
+        except FileNotFoundError:
+            pass
+        return super().__new__(cls)
+
     def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
         """load a data folder as a pulse

@@ -144,36 +245,6 @@ class Pulse(Sequence):
|
|||||||
FileNotFoundError
|
FileNotFoundError
|
||||||
path does not contain proper data
|
path does not contain proper data
|
||||||
"""
|
"""
|
||||||
self.logger = get_logger(__name__)
|
|
||||||
self.path = Path(path)
|
|
||||||
self.default_ind = default_ind
|
|
||||||
|
|
||||||
if not self.path.is_dir():
|
|
||||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
|
||||||
|
|
||||||
self.params = Parameters.load(self.path / "params.toml")
|
|
||||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
|
||||||
if self.params.fiber_map is None:
|
|
||||||
self.params.fiber_map = [(0.0, self.params.name)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.z = np.load(os.path.join(path, "z.npy"))
|
|
||||||
except FileNotFoundError:
|
|
||||||
if self.params is not None:
|
|
||||||
self.z = self.params.z_targets
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
|
||||||
if self.nmax <= 0:
|
|
||||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
|
||||||
|
|
||||||
self.t = self.params.t
|
|
||||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
|
||||||
self.w_order = np.argsort(w)
|
|
||||||
self.w = w
|
|
||||||
self.wl = units.m.inv(self.w)
|
|
||||||
self.params.w = self.w
|
|
||||||
self.params.z_targets = self.z
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""
|
"""
|
||||||
@@ -190,73 +261,6 @@ class Pulse(Sequence):
|
|||||||
def __getitem__(self, key) -> Spectrum:
|
def __getitem__(self, key) -> Spectrum:
|
||||||
return self.all_spectra(key)
|
return self.all_spectra(key)
|
||||||
|
|
||||||
def intensity(self, unit):
|
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
|
||||||
x_axis = unit.inv(self.w)
|
|
||||||
else:
|
|
||||||
x_axis = unit.inv(self.t)
|
|
||||||
|
|
||||||
order = np.argsort(x_axis)
|
|
||||||
func = dict(
|
|
||||||
WL=self._to_wl_int,
|
|
||||||
FREQ=self._to_freq_int,
|
|
||||||
AFREQ=self._to_afreq_int,
|
|
||||||
TIME=self._to_time_int,
|
|
||||||
)[unit.type]
|
|
||||||
|
|
||||||
for spec in self:
|
|
||||||
yield x_axis[order], func(spec)[:, order]
|
|
||||||
|
|
||||||
def _to_wl_int(self, spectrum):
|
|
||||||
return units.to_WL(math.abs2(spectrum), spectrum.wl)
|
|
||||||
|
|
||||||
def _to_freq_int(self, spectrum):
|
|
||||||
return math.abs2(spectrum)
|
|
||||||
|
|
||||||
def _to_afreq_int(self, spectrum):
|
|
||||||
return math.abs2(spectrum)
|
|
||||||
|
|
||||||
def _to_time_int(self, spectrum):
|
|
||||||
return math.abs2(np.fft.ifft(spectrum))
|
|
||||||
|
|
||||||
def amplitude(self, unit):
|
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
|
||||||
x_axis = unit.inv(self.w)
|
|
||||||
else:
|
|
||||||
x_axis = unit.inv(self.t)
|
|
||||||
|
|
||||||
order = np.argsort(x_axis)
|
|
||||||
func = dict(
|
|
||||||
WL=self._to_wl_amp,
|
|
||||||
FREQ=self._to_freq_amp,
|
|
||||||
AFREQ=self._to_afreq_amp,
|
|
||||||
TIME=self._to_time_amp,
|
|
||||||
)[unit.type]
|
|
||||||
|
|
||||||
for spec in self:
|
|
||||||
yield x_axis[order], func(spec)[:, order]
|
|
||||||
|
|
||||||
def _to_wl_amp(self, spectrum):
|
|
||||||
return (
|
|
||||||
np.sqrt(
|
|
||||||
units.to_WL(
|
|
||||||
math.abs2(spectrum),
|
|
||||||
spectrum.wl,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* spectrum
|
|
||||||
/ np.abs(spectrum)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _to_freq_amp(self, spectrum):
|
|
||||||
return spectrum
|
|
||||||
|
|
||||||
def _to_afreq_amp(self, spectrum):
|
|
||||||
return spectrum
|
|
||||||
|
|
||||||
def _to_time_amp(self, spectrum):
|
|
||||||
return np.fft.ifft(spectrum)
|
|
||||||
|
|
||||||
def all_spectra(self, ind=None) -> Spectrum:
|
def all_spectra(self, ind=None) -> Spectrum:
|
||||||
"""
|
"""
|
||||||
loads the data already simulated.
|
loads the data already simulated.
|
||||||
@@ -305,12 +309,7 @@ class Pulse(Sequence):
         return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)

     def _load1(self, i: int):
-        if i < 0:
-            i = self.nmax + i
-        spec = load_spectrum(self.path / SPECN_FN.format(i))
-        spec = np.atleast_2d(spec)
-        spec = Spectrum(spec, self.params)
-        return spec
+        pass

     def plot_2D(
         self,
@@ -412,3 +411,46 @@ class Pulse(Sequence):
|
|||||||
index
|
index
|
||||||
"""
|
"""
|
||||||
return math.argclosest(self.z, z)
|
return math.argclosest(self.z, z)
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyPulse(Pulse):
|
||||||
|
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
||||||
|
print("old init called", path, default_ind)
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.path = Path(path)
|
||||||
|
self.default_ind = default_ind
|
||||||
|
|
||||||
|
if not self.path.is_dir():
|
||||||
|
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||||
|
|
||||||
|
self.params = Parameters.load(self.path / "params.toml")
|
||||||
|
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||||
|
if self.params.fiber_map is None:
|
||||||
|
self.params.fiber_map = [(0.0, self.params.name)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.z = np.load(os.path.join(path, "z.npy"))
|
||||||
|
except FileNotFoundError:
|
||||||
|
if self.params is not None:
|
||||||
|
self.z = self.params.z_targets
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
||||||
|
if self.nmax <= 0:
|
||||||
|
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||||
|
|
||||||
|
self.t = self.params.t
|
||||||
|
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||||
|
self.w_order = np.argsort(w)
|
||||||
|
self.w = w
|
||||||
|
self.wl = units.m.inv(self.w)
|
||||||
|
self.params.w = self.w
|
||||||
|
self.params.z_targets = self.z
|
||||||
|
|
||||||
|
def _load1(self, i: int):
|
||||||
|
if i < 0:
|
||||||
|
i = self.nmax + i
|
||||||
|
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
||||||
|
spec = np.atleast_2d(spec)
|
||||||
|
spec = Spectrum(spec, self.params)
|
||||||
|
return spec
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
import shutil
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import toml
|
|
||||||
from scgenerator import logger
|
|
||||||
from send2trash import send2trash
|
|
||||||
|
|
||||||
TMP = "testing/.tmp"
|
|
||||||
|
|
||||||
|
|
||||||
class TestRecoveryParamSequence(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
|
|
||||||
self.conf = toml.load(TMP + "/initial_config.toml")
|
|
||||||
io.set_data_folder(55, TMP)
|
|
||||||
|
|
||||||
def test_remaining_simulations_count(self):
|
|
||||||
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
|
|
||||||
self.assertEqual(5, len(param_seq))
|
|
||||||
|
|
||||||
def test_only_one_to_complete(self):
|
|
||||||
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
|
|
||||||
i = 0
|
|
||||||
for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq):
|
|
||||||
i += 1
|
|
||||||
self.assertEqual(expected, "recovery_last_stored" in params)
|
|
||||||
|
|
||||||
self.assertEqual(5, i)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
send2trash(TMP)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,216 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import toml
|
|
||||||
from scgenerator import defaults, utils, math
|
|
||||||
from scgenerator.errors import *
|
|
||||||
from scgenerator.physics import pulse, units
|
|
||||||
from scgenerator.utils.parameter import Config, Parameters
|
|
||||||
|
|
||||||
|
|
||||||
def load_conf(name):
|
|
||||||
with open("testing/configs/" + name + ".toml") as file:
|
|
||||||
conf = toml.load(file)
|
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
def conf_maker(folder):
|
|
||||||
def conf(name):
|
|
||||||
return load_conf(folder + "/" + name)
|
|
||||||
|
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
class TestParamSequence(unittest.TestCase):
|
|
||||||
def iterconf(self, files):
|
|
||||||
conf = conf_maker("param_sequence")
|
|
||||||
for path in files:
|
|
||||||
yield init.ParamSequence(conf(path))
|
|
||||||
|
|
||||||
def test_no_repeat_in_sub_folder_names(self):
|
|
||||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
|
||||||
l = []
|
|
||||||
s = []
|
|
||||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
|
||||||
self.assertNotIn(vary_list, l)
|
|
||||||
self.assertNotIn(utils.format_variable_list(vary_list), s)
|
|
||||||
l.append(vary_list)
|
|
||||||
s.append(utils.format_variable_list(vary_list))
|
|
||||||
|
|
||||||
def test_no_variations_yields_only_num_and_id(self):
|
|
||||||
for param_seq in self.iterconf(["no_variations"]):
|
|
||||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
|
||||||
self.assertEqual(vary_list[1][0], "num")
|
|
||||||
self.assertEqual(vary_list[0][0], "id")
|
|
||||||
self.assertEqual(2, len(vary_list))
|
|
||||||
|
|
||||||
|
|
||||||
class TestInitializeMethods(unittest.TestCase):
|
|
||||||
def test_validate_types(self):
|
|
||||||
conf = lambda s: load_conf("validate_types/" + s)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"):
|
|
||||||
init.Config(**conf("bad2"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(TypeError, "value must be of type <class 'float'>"):
|
|
||||||
init.Config(**conf("bad3"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
|
|
||||||
init.Config(**conf("bad4"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
TypeError, "'variable intensity_noise' value must be of type <class 'list'>"
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad5"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "'repeat' must be positive"):
|
|
||||||
init.Config(**conf("bad6"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
ValueError, "variable parameter 'intensity_noise' must not be empty"
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad7"))
|
|
||||||
|
|
||||||
self.assertIsNone(init.Config(**conf("good")).hr_w)
|
|
||||||
|
|
||||||
def test_ensure_consistency(self):
|
|
||||||
conf = lambda s: load_conf("ensure_consistency/" + s)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
MissingParameterError,
|
|
||||||
r"1 of '\['t0', 'width'\]' is required and no defaults have been set",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad1"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
MissingParameterError,
|
|
||||||
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad2"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
MissingParameterError,
|
|
||||||
r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad3"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
DuplicateParameterError,
|
|
||||||
r"got multiple values for parameter 'width'",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad4"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
MissingParameterError,
|
|
||||||
r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad5"))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
MissingParameterError,
|
|
||||||
r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set",
|
|
||||||
):
|
|
||||||
init.Config(**conf("bad6"))
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
{"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNone(init.Config(**conf("good4")).gamma)
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
{"raman_type": "agrawal"}.items(),
|
|
||||||
init.Config(**conf("good2")).__dict__.items(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
{"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
{"capillary_nested": 0, "capillary_resonance_strengths": []}.items(),
|
|
||||||
init.Config(**conf("good4")).__dict__.items(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
dict(he_mode=(1, 1)).items(),
|
|
||||||
init.Config(**conf("good5")).__dict__.items(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertLessEqual(
|
|
||||||
dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(),
|
|
||||||
init.Config(**conf("good5")).__dict__.items(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_conf_custom_field(self, path) -> Parameters:
|
|
||||||
|
|
||||||
conf = load_conf(path)
|
|
||||||
conf = Parameters(**conf)
|
|
||||||
init.build_sim_grid_in_place(conf)
|
|
||||||
return conf
|
|
||||||
|
|
||||||
def test_setup_custom_field(self):
|
|
||||||
d = np.load("testing/configs/custom_field/init_field.npz")
|
|
||||||
t = d["time"]
|
|
||||||
field = d["field"]
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/no_change")
|
|
||||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
|
||||||
conf
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4)
|
|
||||||
self.assertTrue(result)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/peak_power")
|
|
||||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
|
||||||
conf
|
|
||||||
)
|
|
||||||
conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0)
|
|
||||||
self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4)
|
|
||||||
self.assertTrue(result)
|
|
||||||
self.assertNotAlmostEqual(conf.wavelength, 1593e-9)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/mean_power")
|
|
||||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
|
||||||
conf
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4)
|
|
||||||
self.assertTrue(result)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/recover1")
|
|
||||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
|
||||||
conf
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0)
|
|
||||||
self.assertTrue(result)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/recover2")
|
|
||||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
|
||||||
conf
|
|
||||||
)
|
|
||||||
self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0)
|
|
||||||
self.assertTrue(result)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
|
||||||
result = Parameters(**conf)
|
|
||||||
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
|
|
||||||
|
|
||||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
|
||||||
conf.wavelength = 1593e-9
|
|
||||||
result = Parameters(**conf)
|
|
||||||
|
|
||||||
conf = load_conf("custom_field/wavelength_shift2")
|
|
||||||
conf = init.Config(**conf)
|
|
||||||
for target, (variable, config) in zip(
|
|
||||||
[1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf)
|
|
||||||
):
|
|
||||||
init.build_sim_grid_in_place(conf)
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target
|
|
||||||
)
|
|
||||||
print(config.wavelength, target)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
conf = conf_maker("validate_types")
|
|
||||||
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from scgenerator.physics.pulse import conform_pulse_params
|
|
||||||
|
|
||||||
|
|
||||||
class TestPulseMethods(unittest.TestCase):
|
|
||||||
def test_conform_pulse_params(self):
|
|
||||||
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6))
|
|
||||||
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6))
|
|
||||||
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6))
|
|
||||||
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6))
|
|
||||||
|
|
||||||
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6)))
|
|
||||||
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6)))
|
|
||||||
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6)))
|
|
||||||
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6)))
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
|
|
||||||
):
|
|
||||||
conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01)
|
|
||||||
with self.assertRaisesRegex(
|
|
||||||
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
|
|
||||||
):
|
|
||||||
conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6))
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6))
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6))
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import toml
|
|
||||||
from scgenerator import utils
|
|
||||||
|
|
||||||
|
|
||||||
def load_conf(name):
|
|
||||||
with open("testing/configs/" + name + ".toml") as file:
|
|
||||||
conf = toml.load(file)
|
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
def conf_maker(folder, val=True):
|
|
||||||
def conf(name):
|
|
||||||
if val:
|
|
||||||
return initialize.Config(**load_conf(folder + "/" + name))
|
|
||||||
else:
|
|
||||||
return initialize.Config(**load_conf(folder + "/" + name))
|
|
||||||
|
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
class TestUtilsMethods(unittest.TestCase):
|
|
||||||
def test_count_variations(self):
|
|
||||||
conf = conf_maker("count_variations")
|
|
||||||
|
|
||||||
for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]:
|
|
||||||
self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary")))
|
|
||||||
|
|
||||||
def test_format_value(self):
|
|
||||||
values = [
|
|
||||||
122e-6,
|
|
||||||
True,
|
|
||||||
["raman", "ss"],
|
|
||||||
np.arange(5),
|
|
||||||
1.123,
|
|
||||||
1.1230001,
|
|
||||||
0.002e122,
|
|
||||||
12.3456e-9,
|
|
||||||
]
|
|
||||||
s = [
|
|
||||||
"0.000122",
|
|
||||||
"True",
|
|
||||||
"raman-ss",
|
|
||||||
"0-1-2-3-4",
|
|
||||||
"1.123",
|
|
||||||
"1.1230001",
|
|
||||||
"2e+119",
|
|
||||||
"1.23456e-08",
|
|
||||||
]
|
|
||||||
|
|
||||||
for value, target in zip(values, s):
|
|
||||||
self.assertEqual(target, utils.format_value(value))
|
|
||||||
|
|
||||||
def test_override_config(self):
|
|
||||||
conf = conf_maker("override", False)
|
|
||||||
old = conf("initial_config")
|
|
||||||
new = conf("fiber2")
|
|
||||||
|
|
||||||
over = utils.override_config(vars(new), old)
|
|
||||||
self.assertNotIn("input_transmission", over.variable)
|
|
||||||
self.assertIsNone(over.input_transmission)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
testing/test_variationer.py: 54 added lines (new file)
@@ -0,0 +1,54 @@
+from pydantic import main
+import scgenerator as sc
+
+
+def test_descriptor():
+    # Same branch
+    var1 = sc.VariationDescriptor(
+        raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
+    )
+    var2 = sc.VariationDescriptor(
+        raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
+    )
+    assert var1.branch.identifier == "b_0"
+    assert var1.identifier != var1.branch.identifier
+    assert var1.identifier != var2.identifier
+    assert var1.branch.identifier == var2.branch.identifier
+
+    # different branch
+    var3 = sc.VariationDescriptor(
+        raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]]
+    )
+    assert var1.branch.identifier != var3.branch.identifier
+    assert var1.formatted_descriptor() != var2.formatted_descriptor()
+    assert var1.formatted_descriptor() != var3.formatted_descriptor()
+
+
+def test_variationer():
+    var = sc.Variationer(
+        [
+            dict(a=[1, 2], num=[0, 1, 2]),
+            [dict(b=["000", "111"], c=["a", "-1"])],
+            dict(),
+            dict(),
+            [dict(aaa=[True, False], bb=[1, 3])],
+        ]
+    )
+    assert var.var_num(0) == 6
+    assert var.var_num(1) == 12
+    assert var.var_num() == 24
+
+    cfg = dict(bb=None)
+    branches = set()
+    for descr in var.iterate():
+        assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1])
+        branches.add(descr.branch.identifier)
+    assert len(branches) == 8
+
+
+def main():
+    test_descriptor()
+
+
+if __name__ == "__main__":
+    main()