in the middle of sorting circular import
This commit is contained in:
@@ -1,47 +0,0 @@
|
||||
from typing import Callable
|
||||
import inspect
|
||||
import re
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
return scope[tmp_name]
|
||||
|
||||
|
||||
def lol(a, b=None, c=None):
|
||||
print(f"{a=}, {b=}, {c=}")
|
||||
|
||||
|
||||
def main():
|
||||
lol1 = func_rewrite(lol, ["c"])
|
||||
print(inspect.getfullargspec(lol1))
|
||||
lol2 = func_rewrite(lol, ["b"])
|
||||
print(inspect.getfullargspec(lol2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
22
play.py
22
play.py
@@ -1,6 +1,18 @@
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import random
|
||||
from scgenerator import Parameters
|
||||
import os
|
||||
|
||||
for i in tqdm(range(100), smoothing=0):
|
||||
time.sleep(random.random())
|
||||
|
||||
def main():
|
||||
cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
|
||||
|
||||
pa = Parameters.load("PM1550+PM2000D/PM1550_PM2000D raman_test/initial_config_0.toml")
|
||||
|
||||
print(pa)
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -2,15 +2,15 @@ from . import initialize, io, math, utils
|
||||
from .initialize import (
|
||||
Config,
|
||||
ContinuationParamSequence,
|
||||
Params,
|
||||
Parameters,
|
||||
ParamSequence,
|
||||
RecoveryParamSequence,
|
||||
)
|
||||
from .io import Paths, load_params, load_toml
|
||||
from .io import Paths, load_toml
|
||||
from .math import abs2, argclosest, span
|
||||
from .physics import fiber, materials, pulse, simulate, units
|
||||
from .physics.simulate import RK4IP, new_simulation, resume_simulations
|
||||
from .physics.units import PlotRange
|
||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
||||
from .spectra import Pulse, Spectrum
|
||||
from .utils.parameter import BareConfig, BareParams
|
||||
from .utils.parameter import BareConfig, Parameters
|
||||
|
||||
@@ -15,25 +15,12 @@ from .utils import override_config, required_simulations
|
||||
from .utils.evaluator import Evaluator
|
||||
from .utils.parameter import (
|
||||
BareConfig,
|
||||
BareParams,
|
||||
Parameters,
|
||||
hc_model_specific_parameters,
|
||||
mandatory_parameters,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params(BareParams):
|
||||
@classmethod
|
||||
def from_bare(cls, bare: BareParams):
|
||||
param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
|
||||
evaluator = Evaluator.default()
|
||||
evaluator.set(**param_dict)
|
||||
for p_name in mandatory_parameters:
|
||||
evaluator.compute(p_name)
|
||||
new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
|
||||
return cls(**new_param_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config(BareConfig):
|
||||
@classmethod
|
||||
@@ -222,11 +209,11 @@ class ParamSequence:
|
||||
|
||||
self.update_num_sim()
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
|
||||
"""iterates through all possible parameters, yielding a config as well as a flattened
|
||||
computed parameters set each time"""
|
||||
for variable_list, bare_params in required_simulations(self.config):
|
||||
yield variable_list, Params.from_bare(bare_params)
|
||||
for variable_list, params in required_simulations(self.config):
|
||||
yield variable_list, params
|
||||
|
||||
def __len__(self):
|
||||
return self.num_sim
|
||||
@@ -259,19 +246,19 @@ class ContinuationParamSequence(ParamSequence):
|
||||
new config
|
||||
"""
|
||||
self.prev_sim_dir = Path(prev_sim_dir)
|
||||
self.bare_configs = io.load_config_sequence(new_config.previous_config_file)
|
||||
self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file)
|
||||
self.bare_configs.append(new_config)
|
||||
self.bare_configs[0] = Config.from_bare(self.bare_configs[0])
|
||||
final_config = utils.final_config_from_sequence(*self.bare_configs)
|
||||
super().__init__(final_config)
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
|
||||
"""iterates through all possible parameters, yielding a config as well as a flattened
|
||||
computed parameters set each time"""
|
||||
for variable_list, bare_params in required_simulations(*self.bare_configs):
|
||||
for variable_list, params in required_simulations(*self.bare_configs):
|
||||
prev_data_dir = self.find_prev_data_dirs(variable_list)[0]
|
||||
bare_params.prev_data_dir = str(prev_data_dir.resolve())
|
||||
yield variable_list, Params.from_bare(bare_params)
|
||||
params.prev_data_dir = str(prev_data_dir.resolve())
|
||||
yield variable_list, params
|
||||
|
||||
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
|
||||
"""finds the previous simulation data that this new config should start from
|
||||
@@ -324,7 +311,7 @@ class RecoveryParamSequence(ParamSequence):
|
||||
self.prev_sim_dir = None
|
||||
if self.config.prev_sim_dir is not None:
|
||||
self.prev_sim_dir = Path(self.config.prev_sim_dir)
|
||||
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml")
|
||||
init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml")
|
||||
self.prev_variable_lists = [
|
||||
(
|
||||
set(variable_list[1:]),
|
||||
@@ -357,17 +344,17 @@ class RecoveryParamSequence(ParamSequence):
|
||||
self.num_steps += not_started * self.config.z_num
|
||||
self.single_sim = self.num_sim == 1
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]:
|
||||
for variable_list, bare_params in required_simulations(self.config):
|
||||
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
|
||||
for variable_list, params in required_simulations(self.config):
|
||||
|
||||
data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list)
|
||||
|
||||
if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0:
|
||||
if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None:
|
||||
bare_params.prev_data_dir = str(prev_data_dir)
|
||||
yield variable_list, Params.from_bare(bare_params)
|
||||
params.prev_data_dir = str(prev_data_dir)
|
||||
yield variable_list, params
|
||||
elif io.num_left_to_propagate(data_dir, self.config.z_num) != 0:
|
||||
yield variable_list, recover_params(bare_params, data_dir)
|
||||
yield variable_list, params + "Needs to rethink recovery procedure"
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -417,7 +404,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
|
||||
"""
|
||||
|
||||
previous = None
|
||||
configs = io.load_config_sequence(*configs)
|
||||
configs = BareConfig.load_sequence(*configs)
|
||||
for config in configs:
|
||||
# if (p := Path(config)).is_dir():
|
||||
# config = p / "initial_config.toml"
|
||||
@@ -487,22 +474,20 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
|
||||
# raise TypeError("not enough parameter to determine time vector")
|
||||
|
||||
|
||||
def recover_params(params: BareParams, data_folder: Path) -> Params:
|
||||
params = Params.from_bare(params)
|
||||
try:
|
||||
prev = io.load_params(data_folder / "params.toml")
|
||||
build_sim_grid_in_place(prev)
|
||||
except FileNotFoundError:
|
||||
prev = BareParams()
|
||||
for k, v in filter(lambda el: el[1] is not None, vars(prev).items()):
|
||||
if getattr(params, k) is None:
|
||||
setattr(params, k, v)
|
||||
num, last_spectrum = io.load_last_spectrum(data_folder)
|
||||
params.spec_0 = last_spectrum
|
||||
params.field_0 = np.fft.ifft(last_spectrum)
|
||||
params.recovery_last_stored = num
|
||||
params.cons_qty = np.load(data_folder / "cons_qty.npy")
|
||||
return params
|
||||
# def recover_params(params: Parameters, data_folder: Path) -> Parameters:
|
||||
# try:
|
||||
# prev = Parameters.load(data_folder / "params.toml")
|
||||
# except FileNotFoundError:
|
||||
# prev = Parameters()
|
||||
# for k, v in filter(lambda el: el[1] is not None, vars(prev).items()):
|
||||
# if getattr(params, k) is None:
|
||||
# setattr(params, k, v)
|
||||
# num, last_spectrum = io.load_last_spectrum(data_folder)
|
||||
# params.spec_0 = last_spectrum
|
||||
# params.field_0 = np.fft.ifft(last_spectrum)
|
||||
# params.recovery_last_stored = num
|
||||
# params.cons_qty = np.load(data_folder / "cons_qty.npy")
|
||||
# return params
|
||||
|
||||
|
||||
# def build_sim_grid(
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generator, List, Sequence, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pkg_resources as pkg
|
||||
import toml
|
||||
from tqdm.std import Bar
|
||||
|
||||
from . import env, utils
|
||||
from scgenerator.utils.parameter import BareConfig
|
||||
|
||||
from . import env
|
||||
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
|
||||
from .env import TMP_FOLDER_KEY_BASE
|
||||
|
||||
from .errors import IncompleteDataFolderError
|
||||
from .logger import get_logger
|
||||
from .utils.parameter import BareConfig, BareParams, translate
|
||||
|
||||
PathTree = List[Tuple[Path, ...]]
|
||||
|
||||
@@ -98,7 +100,9 @@ def save_toml(path: os.PathLike, dico):
|
||||
return dico
|
||||
|
||||
|
||||
def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path:
|
||||
def save_parameters(
|
||||
params: dict[str, Any], destination_dir: Path, file_name: str = "params.toml"
|
||||
) -> Path:
|
||||
"""saves a parameter dictionary. Note that is does remove some entries, particularly
|
||||
those that take a lot of space ("t", "w", ...)
|
||||
|
||||
@@ -114,87 +118,17 @@ def save_parameters(params: BareParams, destination_dir: Path, file_name="params
|
||||
Path
|
||||
path to newly created the paramter file
|
||||
"""
|
||||
param = params.prepare_for_dump()
|
||||
file_path = destination_dir / file_name
|
||||
|
||||
file_path.parent.mkdir(exist_ok=True)
|
||||
|
||||
# save toml of the simulation
|
||||
with open(file_path, "w") as file:
|
||||
toml.dump(param, file, encoder=toml.TomlNumpyEncoder())
|
||||
toml.dump(params, file, encoder=toml.TomlNumpyEncoder())
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
def load_params(path: os.PathLike) -> BareParams:
|
||||
"""loads a parameters toml files and converts data to appropriate type
|
||||
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
path to the toml
|
||||
|
||||
Returns
|
||||
----------
|
||||
BareParams
|
||||
params obj
|
||||
"""
|
||||
params = load_toml(path)
|
||||
try:
|
||||
return BareParams(**params)
|
||||
except TypeError:
|
||||
return BareParams(**dict(translate(p, v) for p, v in params.items()))
|
||||
|
||||
|
||||
def load_config(path: os.PathLike) -> BareConfig:
|
||||
"""loads a parameters toml files and converts data to appropriate type
|
||||
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
path to the toml
|
||||
|
||||
Returns
|
||||
----------
|
||||
BareParams
|
||||
config obj
|
||||
"""
|
||||
config = load_toml(path)
|
||||
return BareConfig(**config)
|
||||
|
||||
|
||||
def load_config_sequence(*config_paths: os.PathLike) -> list[BareConfig]:
|
||||
"""Loads a sequence of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_paths : os.PathLike
|
||||
either one path (the last config containing previous_config_file parameter)
|
||||
or a list of config path in the order they have to be simulated
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[BareConfig]
|
||||
all loaded configs
|
||||
"""
|
||||
if config_paths[0] is None:
|
||||
return []
|
||||
all_configs = [load_config(config_paths[0])]
|
||||
if len(config_paths) == 1:
|
||||
while True:
|
||||
if all_configs[0].previous_config_file is not None:
|
||||
all_configs.insert(0, load_config(all_configs[0].previous_config_file))
|
||||
else:
|
||||
break
|
||||
else:
|
||||
for i, path in enumerate(config_paths[1:]):
|
||||
all_configs.append(load_config(path))
|
||||
all_configs[i + 1].previous_config_file = config_paths[i]
|
||||
return all_configs
|
||||
|
||||
|
||||
def load_material_dico(name: str) -> dict[str, Any]:
|
||||
"""loads a material dictionary
|
||||
Parameters
|
||||
@@ -228,74 +162,6 @@ def get_data_dirs(sim_dir: Path) -> List[Path]:
|
||||
return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()]
|
||||
|
||||
|
||||
def check_data_integrity(sub_folders: List[Path], init_z_num: int):
|
||||
"""checks the integrity and completeness of a simulation data folder
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
path to the data folder
|
||||
init_z_num : int
|
||||
z_num as specified by the initial configuration file
|
||||
|
||||
Raises
|
||||
------
|
||||
IncompleteDataFolderError
|
||||
raised if not all spectra are present in any folder
|
||||
"""
|
||||
|
||||
for sub_folder in utils.PBars(sub_folders, "Checking integrity"):
|
||||
if num_left_to_propagate(sub_folder, init_z_num) != 0:
|
||||
raise IncompleteDataFolderError(
|
||||
f"not enough spectra of the specified {init_z_num} found in {sub_folder}"
|
||||
)
|
||||
|
||||
|
||||
def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
|
||||
"""checks if a propagation has completed
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sub_folder : Path
|
||||
path to the sub folder containing the spectra
|
||||
init_z_num : int
|
||||
number of z position to store as specified in the master config file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the propagation has completed
|
||||
|
||||
Raises
|
||||
------
|
||||
IncompleteDataFolderError
|
||||
raised if init_z_num doesn't match that specified in the individual parameter file
|
||||
"""
|
||||
z_num = load_toml(sub_folder / "params.toml")["z_num"]
|
||||
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
|
||||
|
||||
if z_num != init_z_num:
|
||||
raise IncompleteDataFolderError(
|
||||
f"initial config specifies {init_z_num} spectra per"
|
||||
+ f" but the parameter file in {sub_folder} specifies {z_num}"
|
||||
)
|
||||
|
||||
return z_num - num_spectra
|
||||
|
||||
|
||||
def find_last_spectrum_num(data_dir: Path):
|
||||
for num in itertools.count(1):
|
||||
p_to_test = data_dir / SPEC1_FN.format(num)
|
||||
if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
|
||||
return num - 1
|
||||
|
||||
|
||||
def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
|
||||
"""return the last spectrum stored in path as well as its id"""
|
||||
num = find_last_spectrum_num(data_dir)
|
||||
return num, np.load(data_dir / SPEC1_FN.format(num))
|
||||
|
||||
|
||||
def update_appended_params(source: Path, destination: Path, z: Sequence):
|
||||
z_num = len(z)
|
||||
params = load_toml(source)
|
||||
|
||||
@@ -8,8 +8,7 @@ from typing import TypeVar
|
||||
import numpy as np
|
||||
from scipy.optimize import minimize_scalar
|
||||
|
||||
from ..io import load_material_dico
|
||||
from .. import math
|
||||
from .. import math, io
|
||||
from . import fiber, materials, units, pulse
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -54,7 +53,7 @@ def material_dispersion(
|
||||
|
||||
order = np.argsort(w)
|
||||
|
||||
material_dico = load_material_dico(material)
|
||||
material_dico = io.load_material_dico(material)
|
||||
if ideal:
|
||||
n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1
|
||||
else:
|
||||
|
||||
@@ -2,10 +2,8 @@ from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, T
|
||||
|
||||
import numpy as np
|
||||
from numpy.ma import core
|
||||
import toml
|
||||
from numpy.fft import fft, ifft
|
||||
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||
from numpy.polynomial.polynomial import Polynomial
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
@@ -318,7 +318,8 @@ def L_sol(L_D):
|
||||
|
||||
|
||||
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
|
||||
return io.load_last_spectrum(Path(prev_data_dir))[1]
|
||||
num = utils.find_last_spectrum_num(data_dir)
|
||||
return np.load(data_dir / SPEC1_FN.format(num))
|
||||
|
||||
|
||||
def load_field_file(
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Dict, List, Tuple, Type, Union
|
||||
import numpy as np
|
||||
|
||||
from .. import env, initialize, io, utils
|
||||
from ..utils import Parameters, BareConfig
|
||||
from ..const import PARAM_SEPARATOR
|
||||
from ..errors import IncompleteDataFolderError
|
||||
from ..logger import get_logger
|
||||
@@ -23,29 +24,29 @@ except ModuleNotFoundError:
|
||||
class RK4IP:
|
||||
def __init__(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
params: Parameters,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
):
|
||||
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
|
||||
|
||||
Parameters
|
||||
----------
|
||||
Parameters
|
||||
----------
|
||||
params : Params
|
||||
parameters of the simulation
|
||||
save_data : bool, optional
|
||||
save calculated spectra to disk, by default False
|
||||
job_identifier : str, optional
|
||||
string identifying the parameter set, by default ""
|
||||
task_id : int, optional
|
||||
unique identifier of the session, by default 0
|
||||
parameters of the simulation
|
||||
save_data : bool, optional
|
||||
save calculated spectra to disk, by default False
|
||||
job_identifier : str, optional
|
||||
string identifying the parameter set, by default ""
|
||||
task_id : int, optional
|
||||
unique identifier of the session, by default 0
|
||||
"""
|
||||
self.set(params, save_data, job_identifier, task_id)
|
||||
|
||||
def set(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
params: Parameters,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
task_id=0,
|
||||
@@ -306,7 +307,7 @@ class RK4IP:
|
||||
class SequentialRK4IP(RK4IP):
|
||||
def __init__(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
params: Parameters,
|
||||
pbars: utils.PBars,
|
||||
save_data=False,
|
||||
job_identifier="",
|
||||
@@ -327,7 +328,7 @@ class SequentialRK4IP(RK4IP):
|
||||
class MutliProcRK4IP(RK4IP):
|
||||
def __init__(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
params: Parameters,
|
||||
p_queue: multiprocessing.Queue,
|
||||
worker_id: int,
|
||||
save_data=False,
|
||||
@@ -353,7 +354,7 @@ class RayRK4IP(RK4IP):
|
||||
|
||||
def set(
|
||||
self,
|
||||
params: initialize.Params,
|
||||
params: Parameters,
|
||||
p_actor,
|
||||
worker_id: int,
|
||||
save_data=False,
|
||||
@@ -445,7 +446,9 @@ class Simulations:
|
||||
self.sim_dir = io.get_sim_dir(
|
||||
self.id, path_if_new=Path(self.name + PARAM_SEPARATOR + "tmp")
|
||||
)
|
||||
io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml")
|
||||
io.save_parameters(
|
||||
self.param_seq.config.prepare_for_dump(), self.sim_dir, file_name="initial_config.toml"
|
||||
)
|
||||
|
||||
self.sim_jobs_per_node = 1
|
||||
|
||||
@@ -467,19 +470,19 @@ class Simulations:
|
||||
def _run_available(self):
|
||||
for variable, params in self.param_seq:
|
||||
v_list_str = utils.format_variable_list(variable)
|
||||
io.save_parameters(params, self.sim_dir / v_list_str)
|
||||
io.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
|
||||
|
||||
self.new_sim(v_list_str, params)
|
||||
self.finish()
|
||||
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
def new_sim(self, v_list_str: str, params: Parameters):
|
||||
"""responsible to launch a new simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v_list_str : str
|
||||
string that uniquely identifies the simulation as returned by utils.format_variable_list
|
||||
params : initialize.Params
|
||||
params : Parameters
|
||||
computed parameters
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
@@ -507,7 +510,7 @@ class SequencialSimulations(Simulations, priority=0):
|
||||
super().__init__(param_seq, task_id=task_id)
|
||||
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
|
||||
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
def new_sim(self, v_list_str: str, params: Parameters):
|
||||
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
|
||||
SequentialRK4IP(
|
||||
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
|
||||
@@ -556,7 +559,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
worker.start()
|
||||
super().run()
|
||||
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
def new_sim(self, v_list_str: str, params: Parameters):
|
||||
self.queue.put((v_list_str, params), block=True, timeout=None)
|
||||
|
||||
def finish(self):
|
||||
@@ -579,7 +582,7 @@ class MultiProcSimulations(Simulations, priority=1):
|
||||
p_queue: multiprocessing.Queue,
|
||||
):
|
||||
while True:
|
||||
raw_data: Tuple[List[tuple], initialize.Params] = queue.get()
|
||||
raw_data: Tuple[List[tuple], Parameters] = queue.get()
|
||||
if raw_data == 0:
|
||||
queue.task_done()
|
||||
return
|
||||
@@ -635,7 +638,7 @@ class RaySimulations(Simulations, priority=2):
|
||||
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
|
||||
)
|
||||
|
||||
def new_sim(self, v_list_str: str, params: initialize.Params):
|
||||
def new_sim(self, v_list_str: str, params: Parameters):
|
||||
while self.num_submitted >= self.sim_jobs_total:
|
||||
self.collect_1_job()
|
||||
|
||||
@@ -685,7 +688,7 @@ def run_simulation_sequence(
|
||||
method=None,
|
||||
prev_sim_dir: os.PathLike = None,
|
||||
):
|
||||
configs = io.load_config_sequence(*config_files)
|
||||
configs = BareConfig.load_sequence(*config_files)
|
||||
|
||||
prev = prev_sim_dir
|
||||
for config in configs:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Callable, TypeVar, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..utils.parameter import Parameter, boolean, type_checker
|
||||
from ..utils import parameter
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
|
||||
@@ -183,10 +183,10 @@ def is_unit(name, value):
|
||||
|
||||
@dataclass
|
||||
class PlotRange:
|
||||
left: float = Parameter(type_checker(int, float))
|
||||
right: float = Parameter(type_checker(int, float))
|
||||
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit)
|
||||
conserved_quantity: bool = Parameter(boolean, default=True)
|
||||
left: float = parameter.Parameter(parameter.type_checker(int, float))
|
||||
right: float = parameter.Parameter(parameter.type_checker(int, float))
|
||||
unit: Callable[[float], float] = parameter.Parameter(is_unit, converter=get_unit)
|
||||
conserved_quantity: bool = parameter.Parameter(parameter.boolean, default=True)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"
|
||||
|
||||
@@ -21,7 +21,7 @@ from . import io, math
|
||||
from .defaults import default_plotting as defaults
|
||||
from .math import abs2, make_uniform_1D, span
|
||||
from .physics import pulse, units
|
||||
from .utils.parameter import BareConfig, BareParams
|
||||
from .utils.parameter import BareConfig, Parameters
|
||||
|
||||
RangeType = Tuple[float, float, Union[str, Callable]]
|
||||
NO_LIM = object()
|
||||
@@ -263,7 +263,7 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
|
||||
def propagation_plot(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
ax: plt.Axes,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
vmin: float = None,
|
||||
@@ -281,7 +281,7 @@ def propagation_plot(
|
||||
raw values, either complex fields or complex spectra
|
||||
plt_range : Union[units.PlotRange, RangeType]
|
||||
time, wavelength or frequency range
|
||||
params : BareParams
|
||||
params : Parameters
|
||||
parameters of the simulation
|
||||
log : Union[int, float, bool, str], optional
|
||||
what kind of log to apply, see apply_log for details. by default "1D"
|
||||
@@ -418,7 +418,7 @@ def plot_2D(
|
||||
def transform_2D_propagation(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
log: Union[int, float, bool, str] = "1D",
|
||||
skip: int = 1,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -430,7 +430,7 @@ def transform_2D_propagation(
|
||||
values to transform
|
||||
plt_range : Union[units.PlotRange, RangeType]
|
||||
range
|
||||
params : BareParams
|
||||
params : Parameters
|
||||
parameters of the simulation
|
||||
log : Union[int, float, bool, str], optional
|
||||
see apply_log, by default "1D"
|
||||
@@ -469,7 +469,7 @@ def transform_2D_propagation(
|
||||
def mean_values_plot(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
ax: plt.Axes,
|
||||
log: Union[float, int, str, bool] = False,
|
||||
vmin: float = None,
|
||||
@@ -511,7 +511,7 @@ def mean_values_plot(
|
||||
def transform_mean_values(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
log: Union[bool, int, float] = False,
|
||||
spacing: Union[int, float] = 1,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
@@ -523,7 +523,7 @@ def transform_mean_values(
|
||||
values to transform
|
||||
plt_range : Union[units.PlotRange, RangeType]
|
||||
x axis specifications
|
||||
params : BareParams
|
||||
params : Parameters
|
||||
parameters of the simulation
|
||||
log : Union[bool, int, float], optional
|
||||
see transform_1D_values for details, by default False
|
||||
@@ -637,7 +637,7 @@ def plot_mean(
|
||||
def single_position_plot(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
ax: plt.Axes,
|
||||
log: Union[str, int, float, bool] = False,
|
||||
vmin: float = None,
|
||||
@@ -712,7 +712,7 @@ def plot_1D(
|
||||
def transform_1D_values(
|
||||
values: np.ndarray,
|
||||
plt_range: Union[units.PlotRange, RangeType],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
log: Union[int, float, bool] = False,
|
||||
spacing: Union[int, float] = 1,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
@@ -724,7 +724,7 @@ def transform_1D_values(
|
||||
values to plot, may be complex
|
||||
plt_range : Union[units.PlotRange, RangeType]
|
||||
plot range specification, either (min, max, unit) or a PlotRange obj
|
||||
params : BareParams
|
||||
params : Parameters
|
||||
parameters of the simulations
|
||||
log : Union[int, float, bool], optional
|
||||
if True, will convert to dB relative to max. If a float or int, whill
|
||||
@@ -767,7 +767,7 @@ def plot_spectrogram(
|
||||
values: np.ndarray,
|
||||
x_range: RangeType,
|
||||
y_range: RangeType,
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
t_res: int = None,
|
||||
gate_width: float = None,
|
||||
log: bool = "2D",
|
||||
@@ -790,7 +790,7 @@ def plot_spectrogram(
|
||||
units : function to convert from the desired units to rad/s or to time.
|
||||
common functions are already defined in scgenerator.physics.units
|
||||
look there for more details
|
||||
params : BareParams
|
||||
params : Parameters
|
||||
parameters of the simulations
|
||||
log : bool, optional
|
||||
whether to compute the logarithm of the spectrogram
|
||||
@@ -954,7 +954,7 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
|
||||
|
||||
|
||||
def prep_plot_axis(
|
||||
values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: BareParams
|
||||
values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: Parameters
|
||||
) -> tuple[bool, np.ndarray, units.PlotRange]:
|
||||
is_spectrum = values.dtype == "complex"
|
||||
if not isinstance(plt_range, units.PlotRange):
|
||||
|
||||
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils.parameter import BareParams
|
||||
from ..utils.parameter import Parameters
|
||||
from ..const import PARAM_SEPARATOR
|
||||
|
||||
from ..initialize import ParamSequence
|
||||
@@ -19,7 +19,7 @@ from ..plotting import plot_setup
|
||||
from .. import env, math
|
||||
|
||||
|
||||
def fingerprint(params: BareParams):
|
||||
def fingerprint(params: Parameters):
|
||||
h1 = hash(params.field_0.tobytes())
|
||||
h2 = tuple(params.beta2_coefficients)
|
||||
return h1, h2
|
||||
@@ -160,7 +160,7 @@ def plot_1_dispersion(
|
||||
right: plt.Axes,
|
||||
style: dict[str, Any],
|
||||
lbl: list[str],
|
||||
params: BareParams,
|
||||
params: Parameters,
|
||||
loss: plt.Axes = None,
|
||||
):
|
||||
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients)
|
||||
@@ -253,7 +253,7 @@ def finish_plot(fig, legend_axes, all_labels, params):
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams]]:
|
||||
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
|
||||
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
|
||||
pseq = ParamSequence(config_path)
|
||||
for style, (variables, params) in zip(cc, pseq):
|
||||
|
||||
@@ -10,7 +10,8 @@ from typing import Tuple
|
||||
import numpy as np
|
||||
|
||||
from ..initialize import validate_config_sequence
|
||||
from ..io import Paths, load_config
|
||||
from ..io import Paths
|
||||
from ..utils.parameter import BareConfig
|
||||
|
||||
|
||||
def primes(n):
|
||||
@@ -127,7 +128,7 @@ def main():
|
||||
)
|
||||
|
||||
if args.command == "merge":
|
||||
final_name = load_config(Path(args.configs[0]) / "initial_config.toml").name
|
||||
final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name
|
||||
sim_num = "many"
|
||||
args.nodes = 1
|
||||
args.cpus_per_node = 1
|
||||
|
||||
@@ -11,13 +11,13 @@ from .const import SPECN_FN
|
||||
from .logger import get_logger
|
||||
from .physics import pulse, units
|
||||
from .plotting import mean_values_plot, propagation_plot, single_position_plot
|
||||
from .utils.parameter import BareParams
|
||||
from .utils.parameter import Parameters
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
params: BareParams
|
||||
params: Parameters
|
||||
|
||||
def __new__(cls, input_array, params: BareParams):
|
||||
def __new__(cls, input_array, params: Parameters):
|
||||
# Input array is an already formed ndarray instance
|
||||
# We first cast to be our class type
|
||||
obj = np.asarray(input_array).view(cls)
|
||||
@@ -144,7 +144,7 @@ class Pulse(Sequence):
|
||||
if not self.path.is_dir():
|
||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||
|
||||
self.params = io.load_params(self.path / "params.toml")
|
||||
self.params = Parameters.load(self.path / "params.toml")
|
||||
|
||||
initialize.build_sim_grid_in_place(self.params)
|
||||
|
||||
|
||||
@@ -16,15 +16,17 @@ from dataclasses import asdict, replace
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Iterator, TypeVar, Union
|
||||
from ..errors import IncompleteDataFolderError
|
||||
|
||||
import numpy as np
|
||||
from numpy.lib.arraysetops import isin
|
||||
from tqdm import tqdm
|
||||
|
||||
from .. import env
|
||||
from ..const import PARAM_SEPARATOR
|
||||
from ..const import PARAM_SEPARATOR, SPEC1_FN
|
||||
from ..math import *
|
||||
from .parameter import BareConfig, BareParams
|
||||
from .. import io
|
||||
from .parameter import BareConfig, Parameters
|
||||
|
||||
T_ = TypeVar("T_")
|
||||
|
||||
@@ -212,7 +214,7 @@ def format_value(name: str, value) -> str:
|
||||
return str(value)
|
||||
elif isinstance(value, (float, int)):
|
||||
try:
|
||||
return getattr(BareParams, name).display(value)
|
||||
return getattr(Parameters, name).display(value)
|
||||
except AttributeError:
|
||||
return format(value, ".9g")
|
||||
elif isinstance(value, (list, tuple, np.ndarray)):
|
||||
@@ -226,7 +228,7 @@ def format_value(name: str, value) -> str:
|
||||
|
||||
def pretty_format_value(name: str, value) -> str:
|
||||
try:
|
||||
return getattr(BareParams, name).display(value)
|
||||
return getattr(Parameters, name).display(value)
|
||||
except AttributeError:
|
||||
return name + PARAM_SEPARATOR + str(value)
|
||||
|
||||
@@ -248,12 +250,74 @@ def pretty_format_from_sim_name(name: str) -> str:
|
||||
out = []
|
||||
for key, value in zip(s[::2], s[1::2]):
|
||||
try:
|
||||
out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))]
|
||||
out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
|
||||
except (AttributeError, ValueError):
|
||||
out.append(key + PARAM_SEPARATOR + value)
|
||||
return PARAM_SEPARATOR.join(out)
|
||||
|
||||
|
||||
def check_data_integrity(sub_folders: list[Path], init_z_num: int):
|
||||
"""checks the integrity and completeness of a simulation data folder
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
path to the data folder
|
||||
init_z_num : int
|
||||
z_num as specified by the initial configuration file
|
||||
|
||||
Raises
|
||||
------
|
||||
IncompleteDataFolderError
|
||||
raised if not all spectra are present in any folder
|
||||
"""
|
||||
|
||||
for sub_folder in PBars(sub_folders, "Checking integrity"):
|
||||
if num_left_to_propagate(sub_folder, init_z_num) != 0:
|
||||
raise IncompleteDataFolderError(
|
||||
f"not enough spectra of the specified {init_z_num} found in {sub_folder}"
|
||||
)
|
||||
|
||||
|
||||
def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
|
||||
"""checks if a propagation has completed
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sub_folder : Path
|
||||
path to the sub folder containing the spectra
|
||||
init_z_num : int
|
||||
number of z position to store as specified in the master config file
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the propagation has completed
|
||||
|
||||
Raises
|
||||
------
|
||||
IncompleteDataFolderError
|
||||
raised if init_z_num doesn't match that specified in the individual parameter file
|
||||
"""
|
||||
z_num = io.load_toml(sub_folder / "params.toml")["z_num"]
|
||||
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
|
||||
|
||||
if z_num != init_z_num:
|
||||
raise IncompleteDataFolderError(
|
||||
f"initial config specifies {init_z_num} spectra per"
|
||||
+ f" but the parameter file in {sub_folder} specifies {z_num}"
|
||||
)
|
||||
|
||||
return z_num - num_spectra
|
||||
|
||||
|
||||
def find_last_spectrum_num(data_dir: Path):
|
||||
for num in itertools.count(1):
|
||||
p_to_test = data_dir / SPEC1_FN.format(num)
|
||||
if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
|
||||
return num - 1
|
||||
|
||||
|
||||
def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
|
||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||
@@ -268,7 +332,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
|
||||
Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]
|
||||
variable_list : a list of (name, value) tuple of parameter name and value that are variable.
|
||||
|
||||
params : a dict[str, Any] to be fed to Params
|
||||
params : a dict[str, Any] to be fed to Parameters
|
||||
"""
|
||||
possible_keys = []
|
||||
possible_ranges = []
|
||||
@@ -294,7 +358,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
|
||||
|
||||
def required_simulations(
|
||||
*configs: BareConfig,
|
||||
) -> Iterator[tuple[list[tuple[str, Any]], BareParams]]:
|
||||
) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
|
||||
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
|
||||
parameter set and iterates through every single necessary simulation
|
||||
|
||||
@@ -317,7 +381,7 @@ def required_simulations(
|
||||
for j in range(configs[0].repeat or 1):
|
||||
variable_ind = [("id", i)] + variable_only + [("num", j)]
|
||||
i += 1
|
||||
yield variable_ind, BareParams(**params_dict)
|
||||
yield variable_ind, Parameters(**params_dict)
|
||||
|
||||
|
||||
def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]:
|
||||
|
||||
@@ -1,392 +0,0 @@
|
||||
import itertools
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import math
|
||||
from ..logger import get_logger
|
||||
from ..physics import fiber, materials, pulse, units
|
||||
|
||||
T = TypeVar("T")
|
||||
import inspect
|
||||
|
||||
|
||||
class EvaluatorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Rule:
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
args: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
conditions: dict[str, str] = None,
|
||||
):
|
||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||
self.func = func
|
||||
if priorities is None:
|
||||
priorities = [1] * len(targets)
|
||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||
priorities = [priorities]
|
||||
self.targets = dict(zip(targets, priorities))
|
||||
if args is None:
|
||||
args = get_arg_names(func)
|
||||
self.args = args
|
||||
self.conditions = conditions or {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
|
||||
|
||||
@classmethod
|
||||
def deduce(
|
||||
cls,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
kwarg_names: list[str],
|
||||
n_var: int,
|
||||
args_const: list[str] = None,
|
||||
) -> list["Rule"]:
|
||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str | list[str | None]
|
||||
name of the variable(s) that func returns
|
||||
func : Callable
|
||||
function to work with
|
||||
kwarg_names : list[str]
|
||||
list of all kwargs of the function to be used
|
||||
n_var : int
|
||||
how many shoulf be used per rule
|
||||
arg_const : list[str], optional
|
||||
override the name of the positional arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Rule]
|
||||
list of all possible rules
|
||||
|
||||
Example
|
||||
-------
|
||||
>> def lol(a, b=None, c=None):
|
||||
pass
|
||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
||||
[
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
||||
]
|
||||
"""
|
||||
rules: list[cls] = []
|
||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
||||
|
||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||
|
||||
rules.append(cls(target, new_func))
|
||||
return rules
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalStat:
|
||||
priority: float = np.inf
|
||||
|
||||
|
||||
class Evaluator:
|
||||
@classmethod
|
||||
def default(cls) -> "Evaluator":
|
||||
evaluator = cls()
|
||||
evaluator.append(*default_rules)
|
||||
return evaluator
|
||||
|
||||
def __init__(self):
|
||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.params = {}
|
||||
self.__curent_lookup = set()
|
||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def append(self, *rule: Rule):
|
||||
for r in rule:
|
||||
for t in r.targets:
|
||||
if t is not None:
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
|
||||
def set(self, **params: Any):
|
||||
self.params.update(params)
|
||||
for k in params:
|
||||
self.eval_stats[k].priority = np.inf
|
||||
|
||||
def reset(self):
|
||||
self.params = {}
|
||||
self.eval_stats = defaultdict(EvalStat)
|
||||
|
||||
def compute(self, target: str) -> Any:
|
||||
"""computes a target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
return type of the target function
|
||||
|
||||
Raises
|
||||
------
|
||||
EvaluatorError
|
||||
a cyclic dependence exists
|
||||
KeyError
|
||||
there is no saved rule for the target
|
||||
"""
|
||||
value = self.params.get(target)
|
||||
if value is None:
|
||||
if target in self.__curent_lookup:
|
||||
raise EvaluatorError(
|
||||
"cyclic dependency detected : "
|
||||
f"{target!r} seems to depend on itself, "
|
||||
f"please provide a value for at least one variable in {self.__curent_lookup}"
|
||||
)
|
||||
else:
|
||||
self.__curent_lookup.add(target)
|
||||
|
||||
if len(self.rules[target]) == 0:
|
||||
raise EvaluatorError(f"no rule for {target}")
|
||||
|
||||
error = None
|
||||
for ii, rule in enumerate(
|
||||
filter(lambda r: self.validate_condition(r), reversed(self.rules[target]))
|
||||
):
|
||||
self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}")
|
||||
try:
|
||||
args = [self.compute(k) for k in rule.args]
|
||||
returned_values = rule.func(*args)
|
||||
if len(rule.targets) == 1:
|
||||
returned_values = [returned_values]
|
||||
for ((param_name, param_priority), returned_value) in zip(
|
||||
rule.targets.items(), returned_values
|
||||
):
|
||||
if (
|
||||
param_name == target
|
||||
or param_name not in self.params
|
||||
or self.eval_stats[param_name].priority < param_priority
|
||||
):
|
||||
self.logger.info(
|
||||
f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
|
||||
)
|
||||
self.params[param_name] = returned_value
|
||||
self.eval_stats[param_name] = param_priority
|
||||
if param_name == target:
|
||||
value = returned_value
|
||||
break
|
||||
except (EvaluatorError, KeyError) as e:
|
||||
error = e
|
||||
continue
|
||||
|
||||
if value is None and error is not None:
|
||||
raise error
|
||||
|
||||
self.__curent_lookup.remove(target)
|
||||
return value
|
||||
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||
|
||||
def __call__(self, target: str, args: list[str] = None):
|
||||
"""creates a wrapper that adds decorated functions to the set of rules
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
args : list[str], optional
|
||||
list of name of arguments. Automatically deduced from function signature if
|
||||
not provided, by default None
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
self.append(Rule(target, func, args))
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
out_func = scope[tmp_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
default_rules: list[Rule] = [
|
||||
# Grid
|
||||
*Rule.deduce(
|
||||
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
|
||||
math.build_sim_grid,
|
||||
["time_window", "t_num", "dt"],
|
||||
2,
|
||||
),
|
||||
# Pulse
|
||||
Rule("spec_0", np.fft.fft, ["field_0"]),
|
||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
|
||||
Rule(
|
||||
["pre_field_0", "peak_power", "energy", "width"],
|
||||
pulse.load_field_file,
|
||||
[
|
||||
"field_file",
|
||||
"t",
|
||||
"peak_power",
|
||||
"energy",
|
||||
"intensity_noise",
|
||||
"noise_correlation",
|
||||
"quantum_noise",
|
||||
"w_c",
|
||||
"w0",
|
||||
"time_window",
|
||||
"dt",
|
||||
],
|
||||
priorities=[2, 1, 1, 1],
|
||||
),
|
||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||
Rule(
|
||||
"field_0",
|
||||
pulse.add_shot_noise,
|
||||
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
|
||||
),
|
||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||
Rule("energy", pulse.mean_power_to_energy),
|
||||
Rule("t0", pulse.width_to_t0),
|
||||
Rule("t0", pulse.soliton_num_to_t0),
|
||||
Rule("width", pulse.t0_to_width),
|
||||
Rule("soliton_num", pulse.soliton_num),
|
||||
Rule("L_D", pulse.L_D),
|
||||
Rule("L_NL", pulse.L_NL),
|
||||
Rule("L_sol", pulse.L_sol),
|
||||
# Fiber Dispersion
|
||||
Rule("wl_for_disp", fiber.lambda_for_dispersion),
|
||||
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
||||
Rule(
|
||||
"beta2_coefficients",
|
||||
fiber.dispersion_coefficients,
|
||||
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
|
||||
),
|
||||
Rule("beta2_arr", fiber.beta2),
|
||||
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
||||
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
||||
Rule(
|
||||
["wl_for_disp", "beta2_arr", "interpolation_range"],
|
||||
fiber.load_custom_dispersion,
|
||||
priorities=[2, 2, 2],
|
||||
),
|
||||
Rule("hr_w", fiber.delayed_raman_w),
|
||||
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
|
||||
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
|
||||
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
|
||||
Rule(
|
||||
"n_eff",
|
||||
fiber.n_eff_pcf,
|
||||
["wl_for_disp", "pitch", "pitch_ratio"],
|
||||
conditions=dict(model="pcf"),
|
||||
),
|
||||
Rule("capillary_spacing", fiber.HCARF_gap),
|
||||
# Fiber nonlinearity
|
||||
Rule("A_eff", fiber.A_eff_from_V),
|
||||
Rule("A_eff", fiber.A_eff_from_diam),
|
||||
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
|
||||
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
|
||||
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
|
||||
Rule("A_eff_arr", fiber.load_custom_A_eff),
|
||||
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
||||
Rule(
|
||||
"V_eff",
|
||||
fiber.V_parameter_koshiba,
|
||||
["wavelength", "pitch", "pitch_ratio"],
|
||||
conditions=dict(model="pcf"),
|
||||
),
|
||||
Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]),
|
||||
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
||||
Rule("V_eff_arr", fiber.V_eff_marcuse),
|
||||
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
|
||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
|
||||
# Fiber loss
|
||||
Rule("alpha", fiber.compute_capillary_loss),
|
||||
Rule("alpha", fiber.load_custom_loss),
|
||||
# gas
|
||||
Rule("n_gas_2", materials.n_gas_2),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
evalor = Evaluator()
|
||||
evalor.append(*default_rules)
|
||||
evalor.set(
|
||||
**{
|
||||
"length": 1,
|
||||
"z_num": 128,
|
||||
"wavelength": 1500e-9,
|
||||
"interpolation_degree": 8,
|
||||
"interpolation_range": (500e-9, 2200e-9),
|
||||
"t_num": 16384,
|
||||
"dt": 1e-15,
|
||||
"shape": "gaussian",
|
||||
"repetition_rate": 40e6,
|
||||
"width": 30e-15,
|
||||
"mean_power": 100e-3,
|
||||
"n2": 2.4e-20,
|
||||
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
|
||||
"model": "pcf",
|
||||
"quantum_noise": True,
|
||||
"pitch": 1.2e-6,
|
||||
"pitch_ratio": 0.5,
|
||||
}
|
||||
)
|
||||
evalor.compute("z_targets")
|
||||
print(evalor.params.keys())
|
||||
print(evalor.params["l"][evalor.params["l"] > 0].min())
|
||||
evalor.compute("spec_0")
|
||||
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
|
||||
plt.yscale("log")
|
||||
plt.show()
|
||||
print(evalor.compute("gamma"))
|
||||
print(evalor.compute("beta2"))
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,15 +1,21 @@
|
||||
import datetime as datetime_module
|
||||
import inspect
|
||||
import itertools
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TypeVar
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm.std import Bar
|
||||
|
||||
from .. import math
|
||||
from ..const import __version__
|
||||
|
||||
# from .evaluator import Rule, Evaluator
|
||||
# from ..physics import pulse, fiber, materials
|
||||
from ..logger import get_logger
|
||||
from .. import io
|
||||
from ..physics import fiber, materials, pulse, units
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -365,10 +371,10 @@ mandatory_parameters = [
|
||||
|
||||
|
||||
@dataclass
|
||||
class BareParams:
|
||||
class Parameters:
|
||||
"""
|
||||
This class defines each valid parameter's name, type and valid value but doesn't provide
|
||||
any method to act on those. For that, use initialize.Params
|
||||
This class defines each valid parameter's name, type and valid value. Initializing
|
||||
such an obj will automatically compute all possible parameters
|
||||
"""
|
||||
|
||||
# root
|
||||
@@ -476,11 +482,25 @@ class BareParams:
|
||||
|
||||
def prepare_for_dump(self) -> Dict[str, Any]:
|
||||
param = asdict(self)
|
||||
param = BareParams.strip_params_dict(param)
|
||||
param = Parameters.strip_params_dict(param)
|
||||
param["datetime"] = datetime_module.datetime.now()
|
||||
param["version"] = __version__
|
||||
return param
|
||||
|
||||
def __post_init__(self):
|
||||
param_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
evaluator = Evaluator.default()
|
||||
evaluator.set(**param_dict)
|
||||
for p_name in mandatory_parameters:
|
||||
evaluator.compute(p_name)
|
||||
for k, v in evaluator.params.items():
|
||||
if k in param_dict:
|
||||
setattr(self, k, v)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> "Parameters":
|
||||
return cls(**io.load_toml(path))
|
||||
|
||||
@staticmethod
|
||||
def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""prepares a dictionary for serialization. Some keys may not be preserved
|
||||
@@ -513,7 +533,7 @@ class BareParams:
|
||||
if not isinstance(value, types):
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
out[key] = BareParams.strip_params_dict(value)
|
||||
out[key] = Parameters.strip_params_dict(value)
|
||||
elif isinstance(value, np.ndarray) and value.dtype == complex:
|
||||
continue
|
||||
else:
|
||||
@@ -525,9 +545,382 @@ class BareParams:
|
||||
return out
|
||||
|
||||
|
||||
class EvaluatorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Rule:
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
args: list[str] = None,
|
||||
priorities: Union[int, list[int]] = None,
|
||||
conditions: dict[str, str] = None,
|
||||
):
|
||||
targets = list(target) if isinstance(target, (list, tuple)) else [target]
|
||||
self.func = func
|
||||
if priorities is None:
|
||||
priorities = [1] * len(targets)
|
||||
elif isinstance(priorities, (int, float, np.integer, np.floating)):
|
||||
priorities = [priorities]
|
||||
self.targets = dict(zip(targets, priorities))
|
||||
if args is None:
|
||||
args = get_arg_names(func)
|
||||
self.args = args
|
||||
self.conditions = conditions or {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
|
||||
|
||||
@classmethod
|
||||
def deduce(
|
||||
cls,
|
||||
target: Union[str, list[Optional[str]]],
|
||||
func: Callable,
|
||||
kwarg_names: list[str],
|
||||
n_var: int,
|
||||
args_const: list[str] = None,
|
||||
) -> list["Rule"]:
|
||||
"""given a function that doesn't need all its keyword arguemtn specified, will
|
||||
return a list of Rule obj, one for each combination of n_var specified kwargs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str | list[str | None]
|
||||
name of the variable(s) that func returns
|
||||
func : Callable
|
||||
function to work with
|
||||
kwarg_names : list[str]
|
||||
list of all kwargs of the function to be used
|
||||
n_var : int
|
||||
how many shoulf be used per rule
|
||||
arg_const : list[str], optional
|
||||
override the name of the positional arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Rule]
|
||||
list of all possible rules
|
||||
|
||||
Example
|
||||
-------
|
||||
>> def lol(a, b=None, c=None):
|
||||
pass
|
||||
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
|
||||
[
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
|
||||
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
|
||||
]
|
||||
"""
|
||||
rules: list[cls] = []
|
||||
for var_possibility in itertools.combinations(kwarg_names, n_var):
|
||||
|
||||
new_func = func_rewrite(func, list(var_possibility), args_const)
|
||||
|
||||
rules.append(cls(target, new_func))
|
||||
return rules
|
||||
|
||||
|
||||
@dataclass
|
||||
class BareConfig(BareParams):
|
||||
variable: dict = VariableParameter(BareParams)
|
||||
class EvalStat:
|
||||
priority: float = np.inf
|
||||
|
||||
|
||||
class Evaluator:
|
||||
@classmethod
|
||||
def default(cls) -> "Evaluator":
|
||||
evaluator = cls()
|
||||
evaluator.append(*default_rules)
|
||||
return evaluator
|
||||
|
||||
def __init__(self):
|
||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.params = {}
|
||||
self.__curent_lookup = set()
|
||||
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
def append(self, *rule: Rule):
|
||||
for r in rule:
|
||||
for t in r.targets:
|
||||
if t is not None:
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
|
||||
def set(self, **params: Any):
|
||||
self.params.update(params)
|
||||
for k in params:
|
||||
self.eval_stats[k].priority = np.inf
|
||||
|
||||
def reset(self):
|
||||
self.params = {}
|
||||
self.eval_stats = defaultdict(EvalStat)
|
||||
|
||||
def compute(self, target: str) -> Any:
|
||||
"""computes a target
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
return type of the target function
|
||||
|
||||
Raises
|
||||
------
|
||||
EvaluatorError
|
||||
a cyclic dependence exists
|
||||
KeyError
|
||||
there is no saved rule for the target
|
||||
"""
|
||||
value = self.params.get(target)
|
||||
if value is None:
|
||||
if target in self.__curent_lookup:
|
||||
raise EvaluatorError(
|
||||
"cyclic dependency detected : "
|
||||
f"{target!r} seems to depend on itself, "
|
||||
f"please provide a value for at least one variable in {self.__curent_lookup}"
|
||||
)
|
||||
else:
|
||||
self.__curent_lookup.add(target)
|
||||
|
||||
if len(self.rules[target]) == 0:
|
||||
raise EvaluatorError(f"no rule for {target}")
|
||||
|
||||
error = None
|
||||
for ii, rule in enumerate(
|
||||
filter(lambda r: self.validate_condition(r), reversed(self.rules[target]))
|
||||
):
|
||||
self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}")
|
||||
try:
|
||||
args = [self.compute(k) for k in rule.args]
|
||||
returned_values = rule.func(*args)
|
||||
if len(rule.targets) == 1:
|
||||
returned_values = [returned_values]
|
||||
for ((param_name, param_priority), returned_value) in zip(
|
||||
rule.targets.items(), returned_values
|
||||
):
|
||||
if (
|
||||
param_name == target
|
||||
or param_name not in self.params
|
||||
or self.eval_stats[param_name].priority < param_priority
|
||||
):
|
||||
self.logger.info(
|
||||
f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
|
||||
)
|
||||
self.params[param_name] = returned_value
|
||||
self.eval_stats[param_name] = param_priority
|
||||
if param_name == target:
|
||||
value = returned_value
|
||||
break
|
||||
except (EvaluatorError, KeyError) as e:
|
||||
error = e
|
||||
continue
|
||||
|
||||
if value is None and error is not None:
|
||||
raise error
|
||||
|
||||
self.__curent_lookup.remove(target)
|
||||
return value
|
||||
|
||||
def validate_condition(self, rule: Rule) -> bool:
|
||||
return all(self.compute(k) == v for k, v in rule.conditions.items())
|
||||
|
||||
def __call__(self, target: str, args: list[str] = None):
|
||||
"""creates a wrapper that adds decorated functions to the set of rules
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : str
|
||||
name of the target
|
||||
args : list[str], optional
|
||||
list of name of arguments. Automatically deduced from function signature if
|
||||
not provided, by default None
|
||||
"""
|
||||
|
||||
def wrapper(func):
|
||||
self.append(Rule(target, func, args))
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_arg_names(func: Callable) -> list[str]:
|
||||
spec = inspect.getfullargspec(func)
|
||||
args = spec.args
|
||||
if spec.defaults is not None and len(spec.defaults) > 0:
|
||||
args = args[: -len(spec.defaults)]
|
||||
return args
|
||||
|
||||
|
||||
def validate_arg_names(names: list[str]):
|
||||
for n in names:
|
||||
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
|
||||
raise ValueError(f"{n} is an invalid parameter name")
|
||||
|
||||
|
||||
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
|
||||
if arg_names is None:
|
||||
arg_names = get_arg_names(func)
|
||||
else:
|
||||
validate_arg_names(arg_names)
|
||||
validate_arg_names(kwarg_names)
|
||||
sign_arg_str = ", ".join(arg_names + kwarg_names)
|
||||
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
|
||||
tmp_name = f"{func.__name__}_0"
|
||||
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
|
||||
scope = dict(__func__=func)
|
||||
exec(func_str, scope)
|
||||
out_func = scope[tmp_name]
|
||||
out_func.__module__ = "evaluator"
|
||||
return out_func
|
||||
|
||||
|
||||
default_rules: list[Rule] = [
|
||||
# Grid
|
||||
*Rule.deduce(
|
||||
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
|
||||
math.build_sim_grid,
|
||||
["time_window", "t_num", "dt"],
|
||||
2,
|
||||
),
|
||||
# Pulse
|
||||
Rule("spec_0", np.fft.fft, ["field_0"]),
|
||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
|
||||
Rule(
|
||||
["pre_field_0", "peak_power", "energy", "width"],
|
||||
pulse.load_field_file,
|
||||
[
|
||||
"field_file",
|
||||
"t",
|
||||
"peak_power",
|
||||
"energy",
|
||||
"intensity_noise",
|
||||
"noise_correlation",
|
||||
"quantum_noise",
|
||||
"w_c",
|
||||
"w0",
|
||||
"time_window",
|
||||
"dt",
|
||||
],
|
||||
priorities=[2, 1, 1, 1],
|
||||
),
|
||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||
Rule(
|
||||
"field_0",
|
||||
pulse.add_shot_noise,
|
||||
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
|
||||
),
|
||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||
Rule("energy", pulse.mean_power_to_energy),
|
||||
Rule("t0", pulse.width_to_t0),
|
||||
Rule("t0", pulse.soliton_num_to_t0),
|
||||
Rule("width", pulse.t0_to_width),
|
||||
Rule("soliton_num", pulse.soliton_num),
|
||||
Rule("L_D", pulse.L_D),
|
||||
Rule("L_NL", pulse.L_NL),
|
||||
Rule("L_sol", pulse.L_sol),
|
||||
# Fiber Dispersion
|
||||
Rule("wl_for_disp", fiber.lambda_for_dispersion),
|
||||
Rule("w_for_disp", units.m, ["wl_for_disp"]),
|
||||
Rule(
|
||||
"beta2_coefficients",
|
||||
fiber.dispersion_coefficients,
|
||||
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
|
||||
),
|
||||
Rule("beta2_arr", fiber.beta2),
|
||||
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
||||
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
||||
Rule(
|
||||
["wl_for_disp", "beta2_arr", "interpolation_range"],
|
||||
fiber.load_custom_dispersion,
|
||||
priorities=[2, 2, 2],
|
||||
),
|
||||
Rule("hr_w", fiber.delayed_raman_w),
|
||||
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
|
||||
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
|
||||
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
|
||||
Rule(
|
||||
"n_eff",
|
||||
fiber.n_eff_pcf,
|
||||
["wl_for_disp", "pitch", "pitch_ratio"],
|
||||
conditions=dict(model="pcf"),
|
||||
),
|
||||
Rule("capillary_spacing", fiber.HCARF_gap),
|
||||
# Fiber nonlinearity
|
||||
Rule("A_eff", fiber.A_eff_from_V),
|
||||
Rule("A_eff", fiber.A_eff_from_diam),
|
||||
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
|
||||
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
|
||||
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
|
||||
Rule("A_eff_arr", fiber.load_custom_A_eff),
|
||||
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
|
||||
Rule(
|
||||
"V_eff",
|
||||
fiber.V_parameter_koshiba,
|
||||
["wavelength", "pitch", "pitch_ratio"],
|
||||
conditions=dict(model="pcf"),
|
||||
),
|
||||
Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]),
|
||||
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
|
||||
Rule("V_eff_arr", fiber.V_eff_marcuse),
|
||||
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
|
||||
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
|
||||
# Fiber loss
|
||||
Rule("alpha", fiber.compute_capillary_loss),
|
||||
Rule("alpha", fiber.load_custom_loss),
|
||||
# gas
|
||||
Rule("n_gas_2", materials.n_gas_2),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BareConfig(Parameters):
|
||||
variable: dict = VariableParameter(Parameters)
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> "BareConfig":
|
||||
return cls(**io.load_toml(path))
|
||||
|
||||
@classmethod
|
||||
def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]:
|
||||
"""Loads a sequence of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_paths : os.PathLike
|
||||
either one path (the last config containing previous_config_file parameter)
|
||||
or a list of config path in the order they have to be simulated
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[BareConfig]
|
||||
all loaded configs
|
||||
"""
|
||||
if config_paths[0] is None:
|
||||
return []
|
||||
all_configs = [cls.load(config_paths[0])]
|
||||
if len(config_paths) == 1:
|
||||
while True:
|
||||
if all_configs[0].previous_config_file is not None:
|
||||
all_configs.insert(0, cls.load(all_configs[0].previous_config_file))
|
||||
else:
|
||||
break
|
||||
else:
|
||||
for i, path in enumerate(config_paths[1:]):
|
||||
all_configs.append(cls.load(path))
|
||||
all_configs[i + 1].previous_config_file = config_paths[i]
|
||||
return all_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,7 +7,7 @@ import toml
|
||||
from scgenerator import defaults, utils, math
|
||||
from scgenerator.errors import *
|
||||
from scgenerator.physics import pulse, units
|
||||
from scgenerator.utils.parameter import BareConfig, BareParams
|
||||
from scgenerator.utils.parameter import BareConfig, Parameters
|
||||
|
||||
|
||||
def load_conf(name):
|
||||
@@ -143,10 +143,10 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
init.Config(**conf("good5")).__dict__.items(),
|
||||
)
|
||||
|
||||
def setup_conf_custom_field(self, path) -> BareParams:
|
||||
def setup_conf_custom_field(self, path) -> Parameters:
|
||||
|
||||
conf = load_conf(path)
|
||||
conf = BareParams(**conf)
|
||||
conf = Parameters(**conf)
|
||||
init.build_sim_grid_in_place(conf)
|
||||
return conf
|
||||
|
||||
@@ -192,12 +192,12 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
||||
result = init.Params.from_bare(conf)
|
||||
result = Parameters(**conf)
|
||||
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
||||
conf.wavelength = 1593e-9
|
||||
result = init.Params.from_bare(conf)
|
||||
result = Parameters(**conf)
|
||||
|
||||
conf = load_conf("custom_field/wavelength_shift2")
|
||||
conf = init.Config(**conf)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import scgenerator as sc
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/")
|
||||
|
||||
root = Path("PM1550+PMHNLF+PM1550+PM2000")
|
||||
|
||||
confs = sc.io.load_config_sequence(root / "4_PM2000.toml")
|
||||
final = sc.utils.final_config_from_sequence(*confs)
|
||||
|
||||
print(final)
|
||||
Reference in New Issue
Block a user