partway trough big data structure revamp

This commit is contained in:
Benoît Sierro
2021-09-28 13:28:00 +02:00
parent 695ac3bd73
commit 262a5b9701
26 changed files with 349 additions and 578 deletions

19
play.py
View File

@@ -4,22 +4,3 @@ import scgenerator as sc
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from pprint import pprint from pprint import pprint
def _main():
print(os.getcwd())
for v_list, params in sc.Configuration("PM1550+PM2000D+PM1550/Pos30000.toml"):
print(params.fiber_map)
def main():
drr = os.getcwd()
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
try:
_main()
finally:
os.chdir(drr)
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,10 @@
from . import math, utils from . import math
from .math import abs2, argclosest, span from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum from .spectra import Pulse, Spectrum
from .utils import Paths, open_config, parameter from ._utils import Paths, open_config, parameter
from .utils.parameter import Configuration, Parameters from ._utils.parameter import Configuration, Parameters
from .utils.utils import PlotRange from ._utils.utils import PlotRange
from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError

View File

@@ -25,7 +25,7 @@ import pkg_resources as pkg
import toml import toml
from tqdm import tqdm from tqdm import tqdm
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN1, Z_FN, __version__
from ..env import pbar_policy from ..env import pbar_policy
from ..logger import get_logger from ..logger import get_logger
@@ -143,7 +143,8 @@ def resolve_loadfile_arg(dico: dict[str, Any]) -> dict[str, Any]:
return dico return dico
def load_toml(descr: str) -> dict[str, Any]: def load_toml(descr: os.PathLike) -> dict[str, Any]:
descr = str(descr)
if ":" in descr: if ":" in descr:
path, entry = descr.split(":", 1) path, entry = descr.split(":", 1)
with open(path) as file: with open(path) as file:
@@ -188,6 +189,7 @@ def load_config_sequence(path: os.PathLike) -> tuple[Path, list[dict[str, Any]]]
params.setdefault("variable", {}) params.setdefault("variable", {})
configs.append(loaded_config | params) configs.append(loaded_config | params)
configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"] configs[0]["variable"] = loaded_config.get("variable", {}) | configs[0]["variable"]
configs[0]["variable"]["num"] = list(range(configs[0].get("repeat", 1)))
return Path(final_path), configs return Path(final_path), configs
@@ -341,7 +343,7 @@ def merge_path_tree(
for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)): for i, (z, merged_spectra) in enumerate(merge_spectra(path_tree)):
z_arr.append(z) z_arr.append(z)
spec_out_name = SPECN_FN.format(i) spec_out_name = SPECN_FN1.format(i)
np.save(destination / spec_out_name, merged_spectra) np.save(destination / spec_out_name, merged_spectra)
if z_callback is not None: if z_callback is not None:
z_callback(i) z_callback(i)

View File

@@ -14,14 +14,15 @@ from typing import Any, Callable, Generator, Iterable, Literal, Optional, Sequen
import numpy as np import numpy as np
from numpy.lib import isin from numpy.lib import isin
from scgenerator.utils import ensure_folder, variationer
from .. import math, utils from .. import math
from ..const import PARAM_FN, PARAM_SEPARATOR, __version__ from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
from ..errors import EvaluatorError, NoDefaultError from ..errors import EvaluatorError, NoDefaultError
from ..logger import get_logger from ..logger import get_logger
from ..physics import fiber, materials, pulse, units from ..physics import fiber, materials, pulse, units
from ..utils.variationer import VariationDescriptor, Variationer from .._utils.variationer import VariationDescriptor, Variationer
from .. import _utils as utils
from .. import env
from .utils import func_rewrite, _mock_function, get_arg_names from .utils import func_rewrite, _mock_function, get_arg_names
T = TypeVar("T") T = TypeVar("T")
@@ -312,13 +313,6 @@ class Parameter:
return f"{num_str} {unit}" return f"{num_str} {unit}"
def fiber_map_converter(d: dict[str, str]) -> list[tuple[float, str]]:
if isinstance(d, dict):
return [(float(k), v) for k, v in d.items()]
else:
return [(float(k), v) for k, v in d]
@dataclass @dataclass
class Parameters: class Parameters:
""" """
@@ -432,15 +426,13 @@ class Parameters:
const_qty: np.ndarray = Parameter(type_checker(np.ndarray)) const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], list[float]] = Parameter(func_validator) beta_func: Callable[[float], list[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator) gamma_func: Callable[[float], float] = Parameter(func_validator)
fiber_map: list[tuple[float, str]] = Parameter(
validator_list(type_checker(tuple)), converter=fiber_map_converter num: int = Parameter(non_negative(int))
)
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string) version: str = Parameter(string)
def prepare_for_dump(self) -> dict[str, Any]: def prepare_for_dump(self) -> dict[str, Any]:
param = asdict(self) param = asdict(self)
param["fiber_map"] = [(str(z), n) for z, n in param.get("fiber_map", [])]
param = Parameters.strip_params_dict(param) param = Parameters.strip_params_dict(param)
param["datetime"] = datetime_module.datetime.now() param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__ param["version"] = __version__
@@ -816,7 +808,9 @@ class Configuration:
self.overwrite = overwrite self.overwrite = overwrite
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path) self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
self.final_path = utils.ensure_folder( self.final_path = utils.ensure_folder(
self.final_path, mkdir=False, prevent_overwrite=not self.overwrite Path(env.get(env.OUTPUT_PATH, self.final_path)),
mkdir=False,
prevent_overwrite=not self.overwrite,
) )
self.master_config = self.fiber_configs[0] self.master_config = self.fiber_configs[0]
self.name = self.final_path.name self.name = self.final_path.name
@@ -868,23 +862,8 @@ class Configuration:
def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]: def __iter__(self) -> Generator[tuple[VariationDescriptor, Parameters], None, None]:
for i in range(self.num_fibers): for i in range(self.num_fibers):
for sim_config in self.iterate_single_fiber(i): for sim_config in self.iterate_single_fiber(i):
if i > 0:
sim_config.config["prev_data_dir"] = str(
self.fiber_paths[i - 1] / sim_config.descriptor[:i].formatted_descriptor()
)
params = Parameters(**sim_config.config) params = Parameters(**sim_config.config)
params.compute() params.compute()
fiber_map = []
for j in range(i + 1):
this_conf = self.all_configs[sim_config.descriptor.index[: j + 1]].config
if j > 0:
prev_conf = self.all_configs[sim_config.descriptor.index[:j]].config
length = prev_conf["length"] + fiber_map[j - 1][0]
else:
length = 0.0
fiber_map.append((length, this_conf["name"]))
params.fiber_map = fiber_map
yield sim_config.descriptor, params yield sim_config.descriptor, params
def iterate_single_fiber( def iterate_single_fiber(
@@ -903,18 +882,21 @@ class Configuration:
__SimConfig __SimConfig
configuration obj configuration obj
""" """
sim_dict: dict[Path, self.__SimConfig] = {} sim_dict: dict[Path, Configuration.__SimConfig] = {}
for descr in self.variationer.iterate(index): for descriptor in self.variationer.iterate(index):
cfg = descr.update_config(self.fiber_configs[index]) cfg = descriptor.update_config(self.fiber_configs[index])
p = ensure_folder( if index > 0:
self.fiber_paths[index] / descr.formatted_descriptor(), cfg["prev_data_dir"] = str(
self.fiber_paths[index - 1] / descriptor[:index].formatted_descriptor(True)
)
p = utils.ensure_folder(
self.fiber_paths[index] / descriptor.formatted_descriptor(True),
not self.overwrite, not self.overwrite,
False, False,
) )
cfg["output_path"] = str(p) cfg["output_path"] = str(p)
sim_config = self.__SimConfig(descr, cfg, p) sim_config = self.__SimConfig(descriptor, cfg, p)
sim_dict[p] = sim_config sim_dict[p] = self.all_configs[sim_config.descriptor.index] = sim_config
self.all_configs[sim_config.descriptor.index] = sim_config
while len(sim_dict) > 0: while len(sim_dict) > 0:
for data_dir, sim_config in sim_dict.items(): for data_dir, sim_config in sim_dict.items():
task, config_dict = self.__decide(sim_config) task, config_dict = self.__decide(sim_config)
@@ -1001,9 +983,12 @@ class Configuration:
raise ValueError(f"Too many spectra in {data_dir}") raise ValueError(f"Too many spectra in {data_dir}")
def save_parameters(self): def save_parameters(self):
for config, sim_dir in zip(self.fiber_configs, self.fiber_paths): os.makedirs(self.final_path, exist_ok=True)
os.makedirs(sim_dir, exist_ok=True) cfgs = [
utils.save_toml(sim_dir / f"initial_config.toml", config) cfg | dict(variable=self.variationer.all_dicts[i])
for i, cfg in enumerate(self.fiber_configs)
]
utils.save_toml(self.final_path / f"initial_config.toml", dict(name=self.name, Fiber=cfgs))
@property @property
def first(self) -> Parameters: def first(self) -> Parameters:

View File

@@ -1,12 +1,17 @@
import inspect import inspect
import os
import re import re
from collections import defaultdict
from functools import cache from functools import cache
from pathlib import Path
from string import printable as str_printable from string import printable as str_printable
from typing import Callable from typing import Callable
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
from .._utils import load_toml, save_toml
from ..const import PARAM_FN, Z_FN
from ..physics.units import get_unit from ..physics.units import get_unit
@@ -144,3 +149,55 @@ def _mock_function(num_args: int, num_returns: int) -> Callable:
out_func = scope[func_name] out_func = scope[func_name]
out_func.__module__ = "evaluator" out_func.__module__ = "evaluator"
return out_func return out_func
def combine_simulations(path: Path, dest: Path = None):
"""combines raw simulations into one folder per branch
Parameters
----------
path : Path
source of the simulations (must contain u_xx directories)
dest : Path, optional
if given, moves the simulations to dest, by default None
"""
paths: dict[str, list[Path]] = defaultdict(list)
if dest is None:
dest = path
for p in path.glob("u_*b_*"):
if p.is_dir():
paths[p.name.split()[1]].append(p)
for l in paths.values():
l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0])
for pulses in paths.values():
new_path = dest / update_path(pulses[0].name)
os.makedirs(new_path, exist_ok=True)
for num, pulse in enumerate(pulses):
params_ok = False
for file in pulse.glob("*"):
if file.name == PARAM_FN:
if not params_ok:
update_params(new_path, file)
params_ok = True
else:
file.unlink()
elif file.name == Z_FN:
file.rename(new_path / file.name)
else:
file.rename(new_path / (file.stem + f"_{num}" + file.suffix))
pulse.rmdir()
def update_params(new_path: Path, file: Path):
params = load_toml(file)
if (p := params.get("prev_data_dir")) is not None:
p = Path(p)
params["prev_data_dir"] = str(p.parent / update_path(p.name))
params["output_path"] = str(new_path)
save_toml(new_path / PARAM_FN, params)
file.unlink()
def update_path(p: str) -> str:
return re.sub(r"( ?num [0-9]+)|(u_[0-9]+ )", "", p)

View File

@@ -116,6 +116,7 @@ class VariationDescriptor(utils.HashableBaseModel):
index: tuple[tuple[int, ...], ...] index: tuple[tuple[int, ...], ...]
separator: str = "fiber" separator: str = "fiber"
_format_registry: dict[str, Callable[..., str]] = {} _format_registry: dict[str, Callable[..., str]] = {}
__ids: dict[int, int] = {}
def __str__(self) -> str: def __str__(self) -> str:
return self.formatted_descriptor(add_identifier=False) return self.formatted_descriptor(add_identifier=False)
@@ -173,7 +174,19 @@ class VariationDescriptor(utils.HashableBaseModel):
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
) )
def update_config(self, cfg: dict[str, Any]): def update_config(self, cfg: dict[str, Any]) -> dict[str, Any]:
"""updates a dictionary with the value of the descriptor
Parameters
----------
cfg : dict[str, Any]
dict to be updated
Returns
-------
dict[str, Any]
same as cfg but with key from the descriptor added/updated.
"""
return cfg | {k: v for k, v in self.raw_descr[-1]} return cfg | {k: v for k, v in self.raw_descr[-1]}
@property @property
@@ -188,17 +201,22 @@ class VariationDescriptor(utils.HashableBaseModel):
@property @property
def branch(self) -> "BranchDescriptor": def branch(self) -> "BranchDescriptor":
for i in reversed(range(len(self.raw_descr))): descr = []
for j in reversed(range(len(self.raw_descr[i]))): ind = []
if self.raw_descr[i][j][0] == "num": for i, l in enumerate(self.raw_descr):
del self.raw_descr[i][j] descr.append([])
return VariationDescriptor( ind.append([])
raw_descr=self.raw_descr, index=self.index, separator=self.separator for j, (k, v) in enumerate(l):
) if k != "num":
descr[-1].append((k, v))
ind[-1].append(self.index[i][j])
return BranchDescriptor(raw_descr=descr, index=ind, separator=self.separator)
@property @property
def identifier(self) -> str: def identifier(self) -> str:
return "u_" + utils.to_62(hash(str(self.flat))) unique_id = hash(str(self.flat))
self.__ids.setdefault(unique_id, len(self.__ids))
return "u_" + str(self.__ids[unique_id])
class BranchDescriptor(VariationDescriptor): class BranchDescriptor(VariationDescriptor):
@@ -208,7 +226,7 @@ class BranchDescriptor(VariationDescriptor):
def identifier(self) -> str: def identifier(self) -> str:
branch_id = hash(str(self.flat)) branch_id = hash(str(self.flat))
self.__ids.setdefault(branch_id, len(self.__ids)) self.__ids.setdefault(branch_id, len(self.__ids))
return str(self.__ids[branch_id]) return "b_" + str(self.__ids[branch_id])
@validator("raw_descr") @validator("raw_descr")
def validate_raw_descr(cls, v): def validate_raw_descr(cls, v):

View File

@@ -1,14 +1,13 @@
import argparse import argparse
import os import os
import re import re
import subprocess
import sys
from collections import ChainMap from collections import ChainMap
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from .. import const, env, scripts, utils from .. import const, env, scripts
from .. import _utils as utils
from ..logger import get_logger from ..logger import get_logger
from ..physics.fiber import dispersion_coefficients from ..physics.fiber import dispersion_coefficients
from ..physics.simulate import SequencialSimulations, run_simulation from ..physics.simulate import SequencialSimulations, run_simulation

View File

@@ -20,7 +20,8 @@ def pbar_format(worker_id: int):
SPEC1_FN = "spectrum_{}.npy" SPEC1_FN = "spectrum_{}.npy"
SPECN_FN = "spectra_{}.npy" SPECN_FN1 = "spectra_{}.npy"
SPEC1_FN_N = "spectrum_{}_{}.npy"
Z_FN = "z.npy" Z_FN = "z.npy"
PARAM_FN = "params.toml" PARAM_FN = "params.toml"
PARAM_SEPARATOR = " " PARAM_SEPARATOR = " "

View File

@@ -48,7 +48,7 @@ def data_folder(task_id: int) -> Optional[str]:
return tmp return tmp
def get(key: str) -> Any: def get(key: str, default=None) -> Any:
str_value = os.environ.get(key) str_value = os.environ.get(key)
if isinstance(str_value, str): if isinstance(str_value, str):
try: try:
@@ -58,7 +58,7 @@ def get(key: str) -> Any:
return t(str_value) return t(str_value)
except (ValueError, KeyError): except (ValueError, KeyError):
pass pass
return None return default
def all_environ() -> Dict[str, str]: def all_environ() -> Dict[str, str]:

View File

@@ -3,7 +3,7 @@ from typing import Union
import numpy as np import numpy as np
from scipy.interpolate import griddata, interp1d from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros from scipy.special import jn_zeros
from .utils.cache import np_cache from ._utils.cache import np_cache
pi = np.pi pi = np.pi
c = 299792458.0 c = 299792458.0

View File

@@ -10,8 +10,8 @@ from scipy.optimize import minimize_scalar
from .. import math from .. import math
from . import fiber, materials, units, pulse from . import fiber, materials, units, pulse
from .. import utils from .. import _utils
from ..utils import cache from .._utils import cache
T = TypeVar("T") T = TypeVar("T")

View File

@@ -7,9 +7,9 @@ from scipy.interpolate import interp1d
from ..logger import get_logger from ..logger import get_logger
from .. import utils from .. import _utils
from ..math import abs2, argclosest, power_fact, u_nm from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.cache import np_cache from .._utils.cache import np_cache
from . import materials as mat from . import materials as mat
from . import units from . import units
from .units import c, pi from .units import c, pi

View File

@@ -5,7 +5,7 @@ from scipy.integrate import cumulative_trapezoid
from ..logger import get_logger from ..logger import get_logger
from . import units from . import units
from .. import utils from .. import _utils
from .units import NA, c, kB, me, e, hbar from .units import NA, c, kB, me, e, hbar

View File

@@ -23,8 +23,6 @@ from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar from scipy.optimize import minimize_scalar
from scipy.optimize.optimize import OptimizeResult from scipy.optimize.optimize import OptimizeResult
from scgenerator import utils
from ..defaults import default_plotting from ..defaults import default_plotting
from ..logger import get_logger from ..logger import get_logger
from ..math import * from ..math import *

View File

@@ -9,9 +9,11 @@ from typing import Any, Generator, Type, Union
import numpy as np import numpy as np
from send2trash import send2trash from send2trash import send2trash
from .. import env, utils from .. import env
from .. import _utils as utils
from .._utils.utils import combine_simulations
from ..logger import get_logger from ..logger import get_logger
from ..utils.parameter import Configuration, Parameters from .._utils.parameter import Configuration, Parameters
from . import pulse from . import pulse
from .fiber import create_non_linear_op, fast_dispersion_op from .fiber import create_non_linear_op, fast_dispersion_op
@@ -718,17 +720,9 @@ def run_simulation(
sim = new_simulation(config, method) sim = new_simulation(config, method)
sim.run() sim.run()
path_trees = utils.build_path_trees(config.fiber_paths[-1])
final_name = env.get(env.OUTPUT_PATH) for path in config.fiber_paths:
if final_name is None: combine_simulations(path)
final_name = config.final_path
utils.merge(final_name, path_trees)
try:
send2trash(config.fiber_paths)
except (PermissionError, OSError):
get_logger(__name__).error("Could not send temporary directories to trash")
def new_simulation( def new_simulation(

View File

@@ -2,7 +2,6 @@
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq." # For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from dataclasses import dataclass
from typing import Callable, TypeVar, Union from typing import Callable, TypeVar, Union
import numpy as np import numpy as np

View File

@@ -14,8 +14,8 @@ from .const import PARAM_SEPARATOR
from .defaults import default_plotting as defaults from .defaults import default_plotting as defaults
from .math import abs2, span from .math import abs2, span
from .physics import pulse, units from .physics import pulse, units
from .utils.parameter import Parameters from ._utils.parameter import Parameters
from .utils.utils import PlotRange, sort_axis from ._utils.utils import PlotRange, sort_axis
RangeType = tuple[float, float, Union[str, Callable]] RangeType = tuple[float, float, Union[str, Callable]]
NO_LIM = object() NO_LIM = object()

View File

@@ -12,12 +12,11 @@ from ..const import PARAM_FN, PARAM_SEPARATOR
from ..physics import fiber, units from ..physics import fiber, units
from ..plotting import plot_setup from ..plotting import plot_setup
from ..spectra import Pulse from ..spectra import Pulse
from ..utils import auto_crop, open_config, save_toml, translate_parameters from .._utils import auto_crop, open_config, save_toml, translate_parameters
from ..utils.parameter import ( from .._utils.parameter import (
Configuration, Configuration,
Parameters, Parameters,
) )
from ..utils.variationer import VariationDescriptor
def fingerprint(params: Parameters): def fingerprint(params: Parameters):

View File

@@ -9,8 +9,8 @@ from typing import Tuple
import numpy as np import numpy as np
from ..utils import Paths from .._utils import Paths
from ..utils.parameter import Configuration from .._utils.parameter import Configuration
def primes(n): def primes(n):

View File

@@ -1,13 +1,19 @@
from __future__ import annotations
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterable, Optional, Union from typing import Any, Callable, Dict, Iterable, Optional, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from pydantic import BaseModel, DirectoryPath, root_validator
from . import math from . import math
from .const import SPECN_FN from ._utils import load_spectrum
from ._utils.parameter import Parameters
from ._utils.utils import PlotRange
from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N
from .logger import get_logger from .logger import get_logger
from .physics import pulse, units from .physics import pulse, units
from .plotting import ( from .plotting import (
@@ -16,9 +22,87 @@ from .plotting import (
single_position_plot, single_position_plot,
transform_2D_propagation, transform_2D_propagation,
) )
from .utils.parameter import Parameters
from .utils.utils import PlotRange
from .utils import load_spectrum class SimulationSeries:
path: Path
params: Parameters
total_length: float
total_num_steps: int
previous: SimulationSeries = None
class Config:
arbitrary_types_allowed = True
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters.load(self.path / PARAM_FN)
if self.params.prev_data_dir is not None:
self.previous = SimulationSeries(self.params.prev_data_dir)
self.total_length = self.accumulate_params("length")
self.total_num_steps = self.accumulate_params("z_num")
def fiber_map(self):
lengths = self.all_params("length")
return [
(this[0], following[1]) for this, following in zip(lengths, [(None, 0.0)] + lengths)
]
def all_params(self, key: str) -> list[tuple[str, Any]]:
"""returns the value of a parameter for each fiber
Parameters
----------
key : str
name of the parameter
Returns
-------
list[tuple[str, Any]]
list of (fiber_name, param_value) tuples
"""
return list(reversed(self._all_params(key, [])))
def accumulate_params(self, key: str) -> Any:
"""returns the sum of all the values a parameter takes. Useful to
get the total length of the fiber, the total number of steps, etc.
Parameters
----------
key : str
name of the parameter
Returns
-------
Any
final sum
"""
return sum(el[1] for el in self.all_params(key))
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
def _all_params(self, key: str, l: list) -> list:
l.append((self.params.name, getattr(self.params, key)))
if self.previous is not None:
return self.previous._all_params(key, l)
return l
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path}, previous={self.previous!r})"
def __eq__(self, other: SimulationSeries) -> bool:
return (
self.path == other.path
and self.params == other.params
and self.previous == other.previous
)
def __contains__(self, other: SimulationSeries) -> bool:
if other is self or other == self:
return True
if self.previous is not None:
return other in self.previous
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
@@ -129,6 +213,23 @@ class Spectrum(np.ndarray):
class Pulse(Sequence): class Pulse(Sequence):
path: Path
default_ind: Optional[int]
params: Parameters
z: np.ndarray
namx: int
t: np.ndarray
w: np.ndarray
w_order: np.ndarray
def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse":
try:
if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
return super().__new__(LegacyPulse)
except FileNotFoundError:
pass
return super().__new__(cls)
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
"""load a data folder as a pulse """load a data folder as a pulse
@@ -144,36 +245,6 @@ class Pulse(Sequence):
FileNotFoundError FileNotFoundError
path does not contain proper data path does not contain proper data
""" """
self.logger = get_logger(__name__)
self.path = Path(path)
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def __iter__(self): def __iter__(self):
""" """
@@ -190,73 +261,6 @@ class Pulse(Sequence):
def __getitem__(self, key) -> Spectrum: def __getitem__(self, key) -> Spectrum:
return self.all_spectra(key) return self.all_spectra(key)
def intensity(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_int,
FREQ=self._to_freq_int,
AFREQ=self._to_afreq_int,
TIME=self._to_time_int,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_int(self, spectrum):
return units.to_WL(math.abs2(spectrum), spectrum.wl)
def _to_freq_int(self, spectrum):
return math.abs2(spectrum)
def _to_afreq_int(self, spectrum):
return math.abs2(spectrum)
def _to_time_int(self, spectrum):
return math.abs2(np.fft.ifft(spectrum))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
WL=self._to_wl_amp,
FREQ=self._to_freq_amp,
AFREQ=self._to_afreq_amp,
TIME=self._to_time_amp,
)[unit.type]
for spec in self:
yield x_axis[order], func(spec)[:, order]
def _to_wl_amp(self, spectrum):
return (
np.sqrt(
units.to_WL(
math.abs2(spectrum),
spectrum.wl,
)
)
* spectrum
/ np.abs(spectrum)
)
def _to_freq_amp(self, spectrum):
return spectrum
def _to_afreq_amp(self, spectrum):
return spectrum
def _to_time_amp(self, spectrum):
return np.fft.ifft(spectrum)
def all_spectra(self, ind=None) -> Spectrum: def all_spectra(self, ind=None) -> Spectrum:
""" """
loads the data already simulated. loads the data already simulated.
@@ -305,12 +309,7 @@ class Pulse(Sequence):
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1) return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
def _load1(self, i: int): def _load1(self, i: int):
if i < 0: pass
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec
def plot_2D( def plot_2D(
self, self,
@@ -412,3 +411,46 @@ class Pulse(Sequence):
index index
""" """
return math.argclosest(self.z, z) return math.argclosest(self.z, z)
class LegacyPulse(Pulse):
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
print("old init called", path, default_ind)
self.logger = get_logger(__name__)
self.path = Path(path)
self.default_ind = default_ind
if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
if self.params.fiber_map is None:
self.params.fiber_map = [(0.0, self.params.name)]
try:
self.z = np.load(os.path.join(path, "z.npy"))
except FileNotFoundError:
if self.params is not None:
self.z = self.params.z_targets
else:
raise
self.nmax = len(list(self.path.glob("spectra_*.npy")))
if self.nmax <= 0:
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
self.t = self.params.t
w = math.wspace(self.t) + units.m(self.params.wavelength)
self.w_order = np.argsort(w)
self.w = w
self.wl = units.m.inv(self.w)
self.params.w = self.w
self.params.z_targets = self.z
def _load1(self, i: int):
if i < 0:
i = self.nmax + i
spec = load_spectrum(self.path / SPECN_FN1.format(i))
spec = np.atleast_2d(spec)
spec = Spectrum(spec, self.params)
return spec

View File

@@ -1,35 +0,0 @@
import shutil
import unittest
import toml
from scgenerator import logger
from send2trash import send2trash
TMP = "testing/.tmp"
class TestRecoveryParamSequence(unittest.TestCase):
def setUp(self):
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
self.conf = toml.load(TMP + "/initial_config.toml")
io.set_data_folder(55, TMP)
def test_remaining_simulations_count(self):
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
self.assertEqual(5, len(param_seq))
def test_only_one_to_complete(self):
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
i = 0
for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq):
i += 1
self.assertEqual(expected, "recovery_last_stored" in params)
self.assertEqual(5, i)
def tearDown(self):
send2trash(TMP)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,216 +0,0 @@
import unittest
from copy import deepcopy
import numpy as np
import toml
from scgenerator import defaults, utils, math
from scgenerator.errors import *
from scgenerator.physics import pulse, units
from scgenerator.utils.parameter import Config, Parameters
def load_conf(name):
with open("testing/configs/" + name + ".toml") as file:
conf = toml.load(file)
return conf
def conf_maker(folder):
def conf(name):
return load_conf(folder + "/" + name)
return conf
class TestParamSequence(unittest.TestCase):
def iterconf(self, files):
conf = conf_maker("param_sequence")
for path in files:
yield init.ParamSequence(conf(path))
def test_no_repeat_in_sub_folder_names(self):
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
l = []
s = []
for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertNotIn(vary_list, l)
self.assertNotIn(utils.format_variable_list(vary_list), s)
l.append(vary_list)
s.append(utils.format_variable_list(vary_list))
def test_no_variations_yields_only_num_and_id(self):
for param_seq in self.iterconf(["no_variations"]):
for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertEqual(vary_list[1][0], "num")
self.assertEqual(vary_list[0][0], "id")
self.assertEqual(2, len(vary_list))
class TestInitializeMethods(unittest.TestCase):
def test_validate_types(self):
conf = lambda s: load_conf("validate_types/" + s)
with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"):
init.Config(**conf("bad2"))
with self.assertRaisesRegex(TypeError, "value must be of type <class 'float'>"):
init.Config(**conf("bad3"))
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
init.Config(**conf("bad4"))
with self.assertRaisesRegex(
TypeError, "'variable intensity_noise' value must be of type <class 'list'>"
):
init.Config(**conf("bad5"))
with self.assertRaisesRegex(ValueError, "'repeat' must be positive"):
init.Config(**conf("bad6"))
with self.assertRaisesRegex(
ValueError, "variable parameter 'intensity_noise' must not be empty"
):
init.Config(**conf("bad7"))
self.assertIsNone(init.Config(**conf("good")).hr_w)
def test_ensure_consistency(self):
conf = lambda s: load_conf("ensure_consistency/" + s)
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['t0', 'width'\]' is required and no defaults have been set",
):
init.Config(**conf("bad1"))
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
):
init.Config(**conf("bad2"))
with self.assertRaisesRegex(
MissingParameterError,
r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set",
):
init.Config(**conf("bad3"))
with self.assertRaisesRegex(
DuplicateParameterError,
r"got multiple values for parameter 'width'",
):
init.Config(**conf("bad4"))
with self.assertRaisesRegex(
MissingParameterError,
r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set",
):
init.Config(**conf("bad5"))
with self.assertRaisesRegex(
MissingParameterError,
r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set",
):
init.Config(**conf("bad6"))
self.assertLessEqual(
{"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items()
)
self.assertIsNone(init.Config(**conf("good4")).gamma)
self.assertLessEqual(
{"raman_type": "agrawal"}.items(),
init.Config(**conf("good2")).__dict__.items(),
)
self.assertLessEqual(
{"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items()
)
self.assertLessEqual(
{"capillary_nested": 0, "capillary_resonance_strengths": []}.items(),
init.Config(**conf("good4")).__dict__.items(),
)
self.assertLessEqual(
dict(he_mode=(1, 1)).items(),
init.Config(**conf("good5")).__dict__.items(),
)
self.assertLessEqual(
dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(),
init.Config(**conf("good5")).__dict__.items(),
)
def setup_conf_custom_field(self, path) -> Parameters:
conf = load_conf(path)
conf = Parameters(**conf)
init.build_sim_grid_in_place(conf)
return conf
def test_setup_custom_field(self):
d = np.load("testing/configs/custom_field/init_field.npz")
t = d["time"]
field = d["field"]
conf = self.setup_conf_custom_field("custom_field/no_change")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/peak_power")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0)
self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4)
self.assertTrue(result)
self.assertNotAlmostEqual(conf.wavelength, 1593e-9)
conf = self.setup_conf_custom_field("custom_field/mean_power")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/recover1")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/recover2")
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
conf
)
self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0)
self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
result = Parameters(**conf)
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
conf.wavelength = 1593e-9
result = Parameters(**conf)
conf = load_conf("custom_field/wavelength_shift2")
conf = init.Config(**conf)
for target, (variable, config) in zip(
[1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf)
):
init.build_sim_grid_in_place(conf)
self.assertAlmostEqual(
units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target
)
print(config.wavelength, target)
if __name__ == "__main__":
conf = conf_maker("validate_types")
unittest.main()

View File

@@ -1,41 +0,0 @@
import unittest
from scgenerator.physics.pulse import conform_pulse_params
class TestPulseMethods(unittest.TestCase):
def test_conform_pulse_params(self):
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6))
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6))
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6))
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6))
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6)))
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6)))
with self.assertRaisesRegex(
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
):
conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01)
with self.assertRaisesRegex(
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
):
conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01)
self.assertEqual(
5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6))
)
self.assertEqual(
5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6))
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,67 +0,0 @@
import unittest
import numpy as np
import toml
from scgenerator import utils
def load_conf(name):
with open("testing/configs/" + name + ".toml") as file:
conf = toml.load(file)
return conf
def conf_maker(folder, val=True):
def conf(name):
if val:
return initialize.Config(**load_conf(folder + "/" + name))
else:
return initialize.Config(**load_conf(folder + "/" + name))
return conf
class TestUtilsMethods(unittest.TestCase):
def test_count_variations(self):
conf = conf_maker("count_variations")
for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]:
self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary")))
def test_format_value(self):
values = [
122e-6,
True,
["raman", "ss"],
np.arange(5),
1.123,
1.1230001,
0.002e122,
12.3456e-9,
]
s = [
"0.000122",
"True",
"raman-ss",
"0-1-2-3-4",
"1.123",
"1.1230001",
"2e+119",
"1.23456e-08",
]
for value, target in zip(values, s):
self.assertEqual(target, utils.format_value(value))
def test_override_config(self):
conf = conf_maker("override", False)
old = conf("initial_config")
new = conf("fiber2")
over = utils.override_config(vars(new), old)
self.assertNotIn("input_transmission", over.variable)
self.assertIsNone(over.input_transmission)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,54 @@
from pydantic import main
import scgenerator as sc
def test_descriptor():
# Same branch
var1 = sc.VariationDescriptor(
raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
)
var2 = sc.VariationDescriptor(
raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
)
assert var1.branch.identifier == "b_0"
assert var1.identifier != var1.branch.identifier
assert var1.identifier != var2.identifier
assert var1.branch.identifier == var2.branch.identifier
# different branch
var3 = sc.VariationDescriptor(
raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]]
)
assert var1.branch.identifier != var3.branch.identifier
assert var1.formatted_descriptor() != var2.formatted_descriptor()
assert var1.formatted_descriptor() != var3.formatted_descriptor()
def test_variationer():
var = sc.Variationer(
[
dict(a=[1, 2], num=[0, 1, 2]),
[dict(b=["000", "111"], c=["a", "-1"])],
dict(),
dict(),
[dict(aaa=[True, False], bb=[1, 3])],
]
)
assert var.var_num(0) == 6
assert var.var_num(1) == 12
assert var.var_num() == 24
cfg = dict(bb=None)
branches = set()
for descr in var.iterate():
assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1])
branches.add(descr.branch.identifier)
assert len(branches) == 8
def main():
test_descriptor()
if __name__ == "__main__":
main()