in the middle of sorting circular import

This commit is contained in:
Benoît Sierro
2021-08-27 10:59:11 +02:00
parent 6c869d5c6c
commit 5751a86e79
19 changed files with 603 additions and 732 deletions

View File

@@ -1,47 +0,0 @@
from typing import Callable
import inspect
import re
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
return scope[tmp_name]
def lol(a, b=None, c=None):
print(f"{a=}, {b=}, {c=}")
def main():
lol1 = func_rewrite(lol, ["c"])
print(inspect.getfullargspec(lol1))
lol2 = func_rewrite(lol, ["b"])
print(inspect.getfullargspec(lol2))
if __name__ == "__main__":
main()

22
play.py
View File

@@ -1,6 +1,18 @@
from tqdm import tqdm from scgenerator import Parameters
import time import os
import random
for i in tqdm(range(100), smoothing=0):
time.sleep(random.random()) def main():
cwd = os.getcwd()
try:
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations")
pa = Parameters.load("PM1550+PM2000D/PM1550_PM2000D raman_test/initial_config_0.toml")
print(pa)
finally:
os.chdir(cwd)
if __name__ == "__main__":
main()

View File

@@ -2,15 +2,15 @@ from . import initialize, io, math, utils
from .initialize import ( from .initialize import (
Config, Config,
ContinuationParamSequence, ContinuationParamSequence,
Params, Parameters,
ParamSequence, ParamSequence,
RecoveryParamSequence, RecoveryParamSequence,
) )
from .io import Paths, load_params, load_toml from .io import Paths, load_toml
from .math import abs2, argclosest, span from .math import abs2, argclosest, span
from .physics import fiber, materials, pulse, simulate, units from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, new_simulation, resume_simulations from .physics.simulate import RK4IP, new_simulation, resume_simulations
from .physics.units import PlotRange from .physics.units import PlotRange
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
from .spectra import Pulse, Spectrum from .spectra import Pulse, Spectrum
from .utils.parameter import BareConfig, BareParams from .utils.parameter import BareConfig, Parameters

View File

@@ -15,25 +15,12 @@ from .utils import override_config, required_simulations
from .utils.evaluator import Evaluator from .utils.evaluator import Evaluator
from .utils.parameter import ( from .utils.parameter import (
BareConfig, BareConfig,
BareParams, Parameters,
hc_model_specific_parameters, hc_model_specific_parameters,
mandatory_parameters, mandatory_parameters,
) )
@dataclass
class Params(BareParams):
@classmethod
def from_bare(cls, bare: BareParams):
param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
evaluator = Evaluator.default()
evaluator.set(**param_dict)
for p_name in mandatory_parameters:
evaluator.compute(p_name)
new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
return cls(**new_param_dict)
@dataclass @dataclass
class Config(BareConfig): class Config(BareConfig):
@classmethod @classmethod
@@ -222,11 +209,11 @@ class ParamSequence:
self.update_num_sim() self.update_num_sim()
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
"""iterates through all possible parameters, yielding a config as well as a flattened """iterates through all possible parameters, yielding a config as well as a flattened
computed parameters set each time""" computed parameters set each time"""
for variable_list, bare_params in required_simulations(self.config): for variable_list, params in required_simulations(self.config):
yield variable_list, Params.from_bare(bare_params) yield variable_list, params
def __len__(self): def __len__(self):
return self.num_sim return self.num_sim
@@ -259,19 +246,19 @@ class ContinuationParamSequence(ParamSequence):
new config new config
""" """
self.prev_sim_dir = Path(prev_sim_dir) self.prev_sim_dir = Path(prev_sim_dir)
self.bare_configs = io.load_config_sequence(new_config.previous_config_file) self.bare_configs = BareConfig.load_sequence(new_config.previous_config_file)
self.bare_configs.append(new_config) self.bare_configs.append(new_config)
self.bare_configs[0] = Config.from_bare(self.bare_configs[0]) self.bare_configs[0] = Config.from_bare(self.bare_configs[0])
final_config = utils.final_config_from_sequence(*self.bare_configs) final_config = utils.final_config_from_sequence(*self.bare_configs)
super().__init__(final_config) super().__init__(final_config)
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
"""iterates through all possible parameters, yielding a config as well as a flattened """iterates through all possible parameters, yielding a config as well as a flattened
computed parameters set each time""" computed parameters set each time"""
for variable_list, bare_params in required_simulations(*self.bare_configs): for variable_list, params in required_simulations(*self.bare_configs):
prev_data_dir = self.find_prev_data_dirs(variable_list)[0] prev_data_dir = self.find_prev_data_dirs(variable_list)[0]
bare_params.prev_data_dir = str(prev_data_dir.resolve()) params.prev_data_dir = str(prev_data_dir.resolve())
yield variable_list, Params.from_bare(bare_params) yield variable_list, params
def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]: def find_prev_data_dirs(self, new_variable_list: List[Tuple[str, Any]]) -> List[Path]:
"""finds the previous simulation data that this new config should start from """finds the previous simulation data that this new config should start from
@@ -324,7 +311,7 @@ class RecoveryParamSequence(ParamSequence):
self.prev_sim_dir = None self.prev_sim_dir = None
if self.config.prev_sim_dir is not None: if self.config.prev_sim_dir is not None:
self.prev_sim_dir = Path(self.config.prev_sim_dir) self.prev_sim_dir = Path(self.config.prev_sim_dir)
init_config = io.load_config(self.prev_sim_dir / "initial_config.toml") init_config = BareConfig.load(self.prev_sim_dir / "initial_config.toml")
self.prev_variable_lists = [ self.prev_variable_lists = [
( (
set(variable_list[1:]), set(variable_list[1:]),
@@ -357,17 +344,17 @@ class RecoveryParamSequence(ParamSequence):
self.num_steps += not_started * self.config.z_num self.num_steps += not_started * self.config.z_num
self.single_sim = self.num_sim == 1 self.single_sim = self.num_sim == 1
def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Params]]: def __iter__(self) -> Iterator[Tuple[List[Tuple[str, Any]], Parameters]]:
for variable_list, bare_params in required_simulations(self.config): for variable_list, params in required_simulations(self.config):
data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list) data_dir = io.get_sim_dir(self.id) / utils.format_variable_list(variable_list)
if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0: if not data_dir.is_dir() or io.find_last_spectrum_num(data_dir) == 0:
if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None: if (prev_data_dir := self.find_prev_data_dir(variable_list)) is not None:
bare_params.prev_data_dir = str(prev_data_dir) params.prev_data_dir = str(prev_data_dir)
yield variable_list, Params.from_bare(bare_params) yield variable_list, params
elif io.num_left_to_propagate(data_dir, self.config.z_num) != 0: elif io.num_left_to_propagate(data_dir, self.config.z_num) != 0:
yield variable_list, recover_params(bare_params, data_dir) yield variable_list, params + "Needs to rethink recovery procedure"
else: else:
continue continue
@@ -417,7 +404,7 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
""" """
previous = None previous = None
configs = io.load_config_sequence(*configs) configs = BareConfig.load_sequence(*configs)
for config in configs: for config in configs:
# if (p := Path(config)).is_dir(): # if (p := Path(config)).is_dir():
# config = p / "initial_config.toml" # config = p / "initial_config.toml"
@@ -487,22 +474,20 @@ def validate_config_sequence(*configs: os.PathLike) -> tuple[str, int]:
# raise TypeError("not enough parameter to determine time vector") # raise TypeError("not enough parameter to determine time vector")
def recover_params(params: BareParams, data_folder: Path) -> Params: # def recover_params(params: Parameters, data_folder: Path) -> Parameters:
params = Params.from_bare(params) # try:
try: # prev = Parameters.load(data_folder / "params.toml")
prev = io.load_params(data_folder / "params.toml") # except FileNotFoundError:
build_sim_grid_in_place(prev) # prev = Parameters()
except FileNotFoundError: # for k, v in filter(lambda el: el[1] is not None, vars(prev).items()):
prev = BareParams() # if getattr(params, k) is None:
for k, v in filter(lambda el: el[1] is not None, vars(prev).items()): # setattr(params, k, v)
if getattr(params, k) is None: # num, last_spectrum = io.load_last_spectrum(data_folder)
setattr(params, k, v) # params.spec_0 = last_spectrum
num, last_spectrum = io.load_last_spectrum(data_folder) # params.field_0 = np.fft.ifft(last_spectrum)
params.spec_0 = last_spectrum # params.recovery_last_stored = num
params.field_0 = np.fft.ifft(last_spectrum) # params.cons_qty = np.load(data_folder / "cons_qty.npy")
params.recovery_last_stored = num # return params
params.cons_qty = np.load(data_folder / "cons_qty.npy")
return params
# def build_sim_grid( # def build_sim_grid(

View File

@@ -1,20 +1,22 @@
from __future__ import annotations
import itertools import itertools
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Sequence, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Tuple
import numpy as np import numpy as np
import pkg_resources as pkg import pkg_resources as pkg
import toml import toml
from tqdm.std import Bar
from . import env, utils from scgenerator.utils.parameter import BareConfig
from . import env
from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ from .const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
from .env import TMP_FOLDER_KEY_BASE from .env import TMP_FOLDER_KEY_BASE
from .errors import IncompleteDataFolderError
from .logger import get_logger from .logger import get_logger
from .utils.parameter import BareConfig, BareParams, translate
PathTree = List[Tuple[Path, ...]] PathTree = List[Tuple[Path, ...]]
@@ -98,7 +100,9 @@ def save_toml(path: os.PathLike, dico):
return dico return dico
def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path: def save_parameters(
params: dict[str, Any], destination_dir: Path, file_name: str = "params.toml"
) -> Path:
"""saves a parameter dictionary. Note that is does remove some entries, particularly """saves a parameter dictionary. Note that is does remove some entries, particularly
those that take a lot of space ("t", "w", ...) those that take a lot of space ("t", "w", ...)
@@ -114,87 +118,17 @@ def save_parameters(params: BareParams, destination_dir: Path, file_name="params
Path Path
path to newly created the paramter file path to newly created the paramter file
""" """
param = params.prepare_for_dump()
file_path = destination_dir / file_name file_path = destination_dir / file_name
file_path.parent.mkdir(exist_ok=True) file_path.parent.mkdir(exist_ok=True)
# save toml of the simulation # save toml of the simulation
with open(file_path, "w") as file: with open(file_path, "w") as file:
toml.dump(param, file, encoder=toml.TomlNumpyEncoder()) toml.dump(params, file, encoder=toml.TomlNumpyEncoder())
return file_path return file_path
def load_params(path: os.PathLike) -> BareParams:
"""loads a parameters toml files and converts data to appropriate type
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
Parameters
----------
path : PathLike
path to the toml
Returns
----------
BareParams
params obj
"""
params = load_toml(path)
try:
return BareParams(**params)
except TypeError:
return BareParams(**dict(translate(p, v) for p, v in params.items()))
def load_config(path: os.PathLike) -> BareConfig:
"""loads a parameters toml files and converts data to appropriate type
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
Parameters
----------
path : PathLike
path to the toml
Returns
----------
BareParams
config obj
"""
config = load_toml(path)
return BareConfig(**config)
def load_config_sequence(*config_paths: os.PathLike) -> list[BareConfig]:
"""Loads a sequence of
Parameters
----------
config_paths : os.PathLike
either one path (the last config containing previous_config_file parameter)
or a list of config path in the order they have to be simulated
Returns
-------
list[BareConfig]
all loaded configs
"""
if config_paths[0] is None:
return []
all_configs = [load_config(config_paths[0])]
if len(config_paths) == 1:
while True:
if all_configs[0].previous_config_file is not None:
all_configs.insert(0, load_config(all_configs[0].previous_config_file))
else:
break
else:
for i, path in enumerate(config_paths[1:]):
all_configs.append(load_config(path))
all_configs[i + 1].previous_config_file = config_paths[i]
return all_configs
def load_material_dico(name: str) -> dict[str, Any]: def load_material_dico(name: str) -> dict[str, Any]:
"""loads a material dictionary """loads a material dictionary
Parameters Parameters
@@ -228,74 +162,6 @@ def get_data_dirs(sim_dir: Path) -> List[Path]:
return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()] return [p.resolve() for p in sim_dir.glob("*") if p.is_dir()]
def check_data_integrity(sub_folders: List[Path], init_z_num: int):
"""checks the integrity and completeness of a simulation data folder
Parameters
----------
path : str
path to the data folder
init_z_num : int
z_num as specified by the initial configuration file
Raises
------
IncompleteDataFolderError
raised if not all spectra are present in any folder
"""
for sub_folder in utils.PBars(sub_folders, "Checking integrity"):
if num_left_to_propagate(sub_folder, init_z_num) != 0:
raise IncompleteDataFolderError(
f"not enough spectra of the specified {init_z_num} found in {sub_folder}"
)
def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
"""checks if a propagation has completed
Parameters
----------
sub_folder : Path
path to the sub folder containing the spectra
init_z_num : int
number of z position to store as specified in the master config file
Returns
-------
bool
True if the propagation has completed
Raises
------
IncompleteDataFolderError
raised if init_z_num doesn't match that specified in the individual parameter file
"""
z_num = load_toml(sub_folder / "params.toml")["z_num"]
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
if z_num != init_z_num:
raise IncompleteDataFolderError(
f"initial config specifies {init_z_num} spectra per"
+ f" but the parameter file in {sub_folder} specifies {z_num}"
)
return z_num - num_spectra
def find_last_spectrum_num(data_dir: Path):
for num in itertools.count(1):
p_to_test = data_dir / SPEC1_FN.format(num)
if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
return num - 1
def load_last_spectrum(data_dir: Path) -> Tuple[int, np.ndarray]:
"""return the last spectrum stored in path as well as its id"""
num = find_last_spectrum_num(data_dir)
return num, np.load(data_dir / SPEC1_FN.format(num))
def update_appended_params(source: Path, destination: Path, z: Sequence): def update_appended_params(source: Path, destination: Path, z: Sequence):
z_num = len(z) z_num = len(z)
params = load_toml(source) params = load_toml(source)

View File

@@ -8,8 +8,7 @@ from typing import TypeVar
import numpy as np import numpy as np
from scipy.optimize import minimize_scalar from scipy.optimize import minimize_scalar
from ..io import load_material_dico from .. import math, io
from .. import math
from . import fiber, materials, units, pulse from . import fiber, materials, units, pulse
T = TypeVar("T") T = TypeVar("T")
@@ -54,7 +53,7 @@ def material_dispersion(
order = np.argsort(w) order = np.argsort(w)
material_dico = load_material_dico(material) material_dico = io.load_material_dico(material)
if ideal: if ideal:
n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1 n_gas_2 = materials.sellmeier(wavelengths, material_dico, pressure, temperature) + 1
else: else:

View File

@@ -2,10 +2,8 @@ from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, T
import numpy as np import numpy as np
from numpy.ma import core from numpy.ma import core
import toml
from numpy.fft import fft, ifft from numpy.fft import fft, ifft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from numpy.polynomial.polynomial import Polynomial
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
from ..logger import get_logger from ..logger import get_logger

View File

@@ -318,7 +318,8 @@ def L_sol(L_D):
def load_previous_spectrum(prev_data_dir: str) -> np.ndarray: def load_previous_spectrum(prev_data_dir: str) -> np.ndarray:
return io.load_last_spectrum(Path(prev_data_dir))[1] num = utils.find_last_spectrum_num(data_dir)
return np.load(data_dir / SPEC1_FN.format(num))
def load_field_file( def load_field_file(

View File

@@ -8,6 +8,7 @@ from typing import Dict, List, Tuple, Type, Union
import numpy as np import numpy as np
from .. import env, initialize, io, utils from .. import env, initialize, io, utils
from ..utils import Parameters, BareConfig
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from ..errors import IncompleteDataFolderError from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
@@ -23,29 +24,29 @@ except ModuleNotFoundError:
class RK4IP: class RK4IP:
def __init__( def __init__(
self, self,
params: initialize.Params, params: Parameters,
save_data=False, save_data=False,
job_identifier="", job_identifier="",
task_id=0, task_id=0,
): ):
"""A 1D solver using 4th order Runge-Kutta in the interaction picture """A 1D solver using 4th order Runge-Kutta in the interaction picture
Parameters
----------
Parameters Parameters
---------- parameters of the simulation
params : Params save_data : bool, optional
parameters of the simulation save calculated spectra to disk, by default False
save_data : bool, optional job_identifier : str, optional
save calculated spectra to disk, by default False string identifying the parameter set, by default ""
job_identifier : str, optional task_id : int, optional
string identifying the parameter set, by default "" unique identifier of the session, by default 0
task_id : int, optional
unique identifier of the session, by default 0
""" """
self.set(params, save_data, job_identifier, task_id) self.set(params, save_data, job_identifier, task_id)
def set( def set(
self, self,
params: initialize.Params, params: Parameters,
save_data=False, save_data=False,
job_identifier="", job_identifier="",
task_id=0, task_id=0,
@@ -306,7 +307,7 @@ class RK4IP:
class SequentialRK4IP(RK4IP): class SequentialRK4IP(RK4IP):
def __init__( def __init__(
self, self,
params: initialize.Params, params: Parameters,
pbars: utils.PBars, pbars: utils.PBars,
save_data=False, save_data=False,
job_identifier="", job_identifier="",
@@ -327,7 +328,7 @@ class SequentialRK4IP(RK4IP):
class MutliProcRK4IP(RK4IP): class MutliProcRK4IP(RK4IP):
def __init__( def __init__(
self, self,
params: initialize.Params, params: Parameters,
p_queue: multiprocessing.Queue, p_queue: multiprocessing.Queue,
worker_id: int, worker_id: int,
save_data=False, save_data=False,
@@ -353,7 +354,7 @@ class RayRK4IP(RK4IP):
def set( def set(
self, self,
params: initialize.Params, params: Parameters,
p_actor, p_actor,
worker_id: int, worker_id: int,
save_data=False, save_data=False,
@@ -445,7 +446,9 @@ class Simulations:
self.sim_dir = io.get_sim_dir( self.sim_dir = io.get_sim_dir(
self.id, path_if_new=Path(self.name + PARAM_SEPARATOR + "tmp") self.id, path_if_new=Path(self.name + PARAM_SEPARATOR + "tmp")
) )
io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml") io.save_parameters(
self.param_seq.config.prepare_for_dump(), self.sim_dir, file_name="initial_config.toml"
)
self.sim_jobs_per_node = 1 self.sim_jobs_per_node = 1
@@ -467,19 +470,19 @@ class Simulations:
def _run_available(self): def _run_available(self):
for variable, params in self.param_seq: for variable, params in self.param_seq:
v_list_str = utils.format_variable_list(variable) v_list_str = utils.format_variable_list(variable)
io.save_parameters(params, self.sim_dir / v_list_str) io.save_parameters(params.prepare_for_dump(), self.sim_dir / v_list_str)
self.new_sim(v_list_str, params) self.new_sim(v_list_str, params)
self.finish() self.finish()
def new_sim(self, v_list_str: str, params: initialize.Params): def new_sim(self, v_list_str: str, params: Parameters):
"""responsible to launch a new simulation """responsible to launch a new simulation
Parameters Parameters
---------- ----------
v_list_str : str v_list_str : str
string that uniquely identifies the simulation as returned by utils.format_variable_list string that uniquely identifies the simulation as returned by utils.format_variable_list
params : initialize.Params params : Parameters
computed parameters computed parameters
""" """
raise NotImplementedError() raise NotImplementedError()
@@ -507,7 +510,7 @@ class SequencialSimulations(Simulations, priority=0):
super().__init__(param_seq, task_id=task_id) super().__init__(param_seq, task_id=task_id)
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1) self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
def new_sim(self, v_list_str: str, params: initialize.Params): def new_sim(self, v_list_str: str, params: Parameters):
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}") self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
SequentialRK4IP( SequentialRK4IP(
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
@@ -556,7 +559,7 @@ class MultiProcSimulations(Simulations, priority=1):
worker.start() worker.start()
super().run() super().run()
def new_sim(self, v_list_str: str, params: initialize.Params): def new_sim(self, v_list_str: str, params: Parameters):
self.queue.put((v_list_str, params), block=True, timeout=None) self.queue.put((v_list_str, params), block=True, timeout=None)
def finish(self): def finish(self):
@@ -579,7 +582,7 @@ class MultiProcSimulations(Simulations, priority=1):
p_queue: multiprocessing.Queue, p_queue: multiprocessing.Queue,
): ):
while True: while True:
raw_data: Tuple[List[tuple], initialize.Params] = queue.get() raw_data: Tuple[List[tuple], Parameters] = queue.get()
if raw_data == 0: if raw_data == 0:
queue.task_done() queue.task_done()
return return
@@ -635,7 +638,7 @@ class RaySimulations(Simulations, priority=2):
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps) .remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
) )
def new_sim(self, v_list_str: str, params: initialize.Params): def new_sim(self, v_list_str: str, params: Parameters):
while self.num_submitted >= self.sim_jobs_total: while self.num_submitted >= self.sim_jobs_total:
self.collect_1_job() self.collect_1_job()
@@ -685,7 +688,7 @@ def run_simulation_sequence(
method=None, method=None,
prev_sim_dir: os.PathLike = None, prev_sim_dir: os.PathLike = None,
): ):
configs = io.load_config_sequence(*config_files) configs = BareConfig.load_sequence(*config_files)
prev = prev_sim_dir prev = prev_sim_dir
for config in configs: for config in configs:

View File

@@ -5,7 +5,7 @@
from typing import Callable, TypeVar, Union from typing import Callable, TypeVar, Union
from dataclasses import dataclass from dataclasses import dataclass
from ..utils.parameter import Parameter, boolean, type_checker from ..utils import parameter
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -183,10 +183,10 @@ def is_unit(name, value):
@dataclass @dataclass
class PlotRange: class PlotRange:
left: float = Parameter(type_checker(int, float)) left: float = parameter.Parameter(parameter.type_checker(int, float))
right: float = Parameter(type_checker(int, float)) right: float = parameter.Parameter(parameter.type_checker(int, float))
unit: Callable[[float], float] = Parameter(is_unit, converter=get_unit) unit: Callable[[float], float] = parameter.Parameter(is_unit, converter=get_unit)
conserved_quantity: bool = Parameter(boolean, default=True) conserved_quantity: bool = parameter.Parameter(parameter.boolean, default=True)
def __str__(self): def __str__(self):
return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}" return f"{self.left:.1f}-{self.right:.1f} {self.unit.__name__}"

View File

@@ -21,7 +21,7 @@ from . import io, math
from .defaults import default_plotting as defaults from .defaults import default_plotting as defaults
from .math import abs2, make_uniform_1D, span from .math import abs2, make_uniform_1D, span
from .physics import pulse, units from .physics import pulse, units
from .utils.parameter import BareConfig, BareParams from .utils.parameter import BareConfig, Parameters
RangeType = Tuple[float, float, Union[str, Callable]] RangeType = Tuple[float, float, Union[str, Callable]]
NO_LIM = object() NO_LIM = object()
@@ -263,7 +263,7 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
def propagation_plot( def propagation_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
ax: plt.Axes, ax: plt.Axes,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
vmin: float = None, vmin: float = None,
@@ -281,7 +281,7 @@ def propagation_plot(
raw values, either complex fields or complex spectra raw values, either complex fields or complex spectra
plt_range : Union[units.PlotRange, RangeType] plt_range : Union[units.PlotRange, RangeType]
time, wavelength or frequency range time, wavelength or frequency range
params : BareParams params : Parameters
parameters of the simulation parameters of the simulation
log : Union[int, float, bool, str], optional log : Union[int, float, bool, str], optional
what kind of log to apply, see apply_log for details. by default "1D" what kind of log to apply, see apply_log for details. by default "1D"
@@ -418,7 +418,7 @@ def plot_2D(
def transform_2D_propagation( def transform_2D_propagation(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
log: Union[int, float, bool, str] = "1D", log: Union[int, float, bool, str] = "1D",
skip: int = 1, skip: int = 1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -430,7 +430,7 @@ def transform_2D_propagation(
values to transform values to transform
plt_range : Union[units.PlotRange, RangeType] plt_range : Union[units.PlotRange, RangeType]
range range
params : BareParams params : Parameters
parameters of the simulation parameters of the simulation
log : Union[int, float, bool, str], optional log : Union[int, float, bool, str], optional
see apply_log, by default "1D" see apply_log, by default "1D"
@@ -469,7 +469,7 @@ def transform_2D_propagation(
def mean_values_plot( def mean_values_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
ax: plt.Axes, ax: plt.Axes,
log: Union[float, int, str, bool] = False, log: Union[float, int, str, bool] = False,
vmin: float = None, vmin: float = None,
@@ -511,7 +511,7 @@ def mean_values_plot(
def transform_mean_values( def transform_mean_values(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
log: Union[bool, int, float] = False, log: Union[bool, int, float] = False,
spacing: Union[int, float] = 1, spacing: Union[int, float] = 1,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -523,7 +523,7 @@ def transform_mean_values(
values to transform values to transform
plt_range : Union[units.PlotRange, RangeType] plt_range : Union[units.PlotRange, RangeType]
x axis specifications x axis specifications
params : BareParams params : Parameters
parameters of the simulation parameters of the simulation
log : Union[bool, int, float], optional log : Union[bool, int, float], optional
see transform_1D_values for details, by default False see transform_1D_values for details, by default False
@@ -637,7 +637,7 @@ def plot_mean(
def single_position_plot( def single_position_plot(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
ax: plt.Axes, ax: plt.Axes,
log: Union[str, int, float, bool] = False, log: Union[str, int, float, bool] = False,
vmin: float = None, vmin: float = None,
@@ -712,7 +712,7 @@ def plot_1D(
def transform_1D_values( def transform_1D_values(
values: np.ndarray, values: np.ndarray,
plt_range: Union[units.PlotRange, RangeType], plt_range: Union[units.PlotRange, RangeType],
params: BareParams, params: Parameters,
log: Union[int, float, bool] = False, log: Union[int, float, bool] = False,
spacing: Union[int, float] = 1, spacing: Union[int, float] = 1,
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[np.ndarray, np.ndarray]:
@@ -724,7 +724,7 @@ def transform_1D_values(
values to plot, may be complex values to plot, may be complex
plt_range : Union[units.PlotRange, RangeType] plt_range : Union[units.PlotRange, RangeType]
plot range specification, either (min, max, unit) or a PlotRange obj plot range specification, either (min, max, unit) or a PlotRange obj
params : BareParams params : Parameters
parameters of the simulations parameters of the simulations
log : Union[int, float, bool], optional log : Union[int, float, bool], optional
if True, will convert to dB relative to max. If a float or int, whill if True, will convert to dB relative to max. If a float or int, whill
@@ -767,7 +767,7 @@ def plot_spectrogram(
values: np.ndarray, values: np.ndarray,
x_range: RangeType, x_range: RangeType,
y_range: RangeType, y_range: RangeType,
params: BareParams, params: Parameters,
t_res: int = None, t_res: int = None,
gate_width: float = None, gate_width: float = None,
log: bool = "2D", log: bool = "2D",
@@ -790,7 +790,7 @@ def plot_spectrogram(
units : function to convert from the desired units to rad/s or to time. units : function to convert from the desired units to rad/s or to time.
common functions are already defined in scgenerator.physics.units common functions are already defined in scgenerator.physics.units
look there for more details look there for more details
params : BareParams params : Parameters
parameters of the simulations parameters of the simulations
log : bool, optional log : bool, optional
whether to compute the logarithm of the spectrogram whether to compute the logarithm of the spectrogram
@@ -954,7 +954,7 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
def prep_plot_axis( def prep_plot_axis(
values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: BareParams values: np.ndarray, plt_range: Union[units.PlotRange, RangeType], params: Parameters
) -> tuple[bool, np.ndarray, units.PlotRange]: ) -> tuple[bool, np.ndarray, units.PlotRange]:
is_spectrum = values.dtype == "complex" is_spectrum = values.dtype == "complex"
if not isinstance(plt_range, units.PlotRange): if not isinstance(plt_range, units.PlotRange):

View File

@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from ..utils.parameter import BareParams from ..utils.parameter import Parameters
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR
from ..initialize import ParamSequence from ..initialize import ParamSequence
@@ -19,7 +19,7 @@ from ..plotting import plot_setup
from .. import env, math from .. import env, math
def fingerprint(params: BareParams): def fingerprint(params: Parameters):
h1 = hash(params.field_0.tobytes()) h1 = hash(params.field_0.tobytes())
h2 = tuple(params.beta2_coefficients) h2 = tuple(params.beta2_coefficients)
return h1, h2 return h1, h2
@@ -160,7 +160,7 @@ def plot_1_dispersion(
right: plt.Axes, right: plt.Axes,
style: dict[str, Any], style: dict[str, Any],
lbl: list[str], lbl: list[str],
params: BareParams, params: Parameters,
loss: plt.Axes = None, loss: plt.Axes = None,
): ):
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients) beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta2_coefficients)
@@ -253,7 +253,7 @@ def finish_plot(fig, legend_axes, all_labels, params):
plt.show() plt.show()
def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams]]: def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters]]:
cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"]) cc = cycler(color=[f"C{i}" for i in range(10)]) * cycler(ls=["-", "--"])
pseq = ParamSequence(config_path) pseq = ParamSequence(config_path)
for style, (variables, params) in zip(cc, pseq): for style, (variables, params) in zip(cc, pseq):

View File

@@ -10,7 +10,8 @@ from typing import Tuple
import numpy as np import numpy as np
from ..initialize import validate_config_sequence from ..initialize import validate_config_sequence
from ..io import Paths, load_config from ..io import Paths
from ..utils.parameter import BareConfig
def primes(n): def primes(n):
@@ -127,7 +128,7 @@ def main():
) )
if args.command == "merge": if args.command == "merge":
final_name = load_config(Path(args.configs[0]) / "initial_config.toml").name final_name = BareConfig.load(Path(args.configs[0]) / "initial_config.toml").name
sim_num = "many" sim_num = "many"
args.nodes = 1 args.nodes = 1
args.cpus_per_node = 1 args.cpus_per_node = 1

View File

@@ -11,13 +11,13 @@ from .const import SPECN_FN
from .logger import get_logger from .logger import get_logger
from .physics import pulse, units from .physics import pulse, units
from .plotting import mean_values_plot, propagation_plot, single_position_plot from .plotting import mean_values_plot, propagation_plot, single_position_plot
from .utils.parameter import BareParams from .utils.parameter import Parameters
class Spectrum(np.ndarray): class Spectrum(np.ndarray):
params: BareParams params: Parameters
def __new__(cls, input_array, params: BareParams): def __new__(cls, input_array, params: Parameters):
# Input array is an already formed ndarray instance # Input array is an already formed ndarray instance
# We first cast to be our class type # We first cast to be our class type
obj = np.asarray(input_array).view(cls) obj = np.asarray(input_array).view(cls)
@@ -144,7 +144,7 @@ class Pulse(Sequence):
if not self.path.is_dir(): if not self.path.is_dir():
raise FileNotFoundError(f"Folder {self.path} does not exist") raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = io.load_params(self.path / "params.toml") self.params = Parameters.load(self.path / "params.toml")
initialize.build_sim_grid_in_place(self.params) initialize.build_sim_grid_in_place(self.params)

View File

@@ -16,15 +16,17 @@ from dataclasses import asdict, replace
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Iterator, TypeVar, Union from typing import Any, Iterable, Iterator, TypeVar, Union
from ..errors import IncompleteDataFolderError
import numpy as np import numpy as np
from numpy.lib.arraysetops import isin from numpy.lib.arraysetops import isin
from tqdm import tqdm from tqdm import tqdm
from .. import env from .. import env
from ..const import PARAM_SEPARATOR from ..const import PARAM_SEPARATOR, SPEC1_FN
from ..math import * from ..math import *
from .parameter import BareConfig, BareParams from .. import io
from .parameter import BareConfig, Parameters
T_ = TypeVar("T_") T_ = TypeVar("T_")
@@ -212,7 +214,7 @@ def format_value(name: str, value) -> str:
return str(value) return str(value)
elif isinstance(value, (float, int)): elif isinstance(value, (float, int)):
try: try:
return getattr(BareParams, name).display(value) return getattr(Parameters, name).display(value)
except AttributeError: except AttributeError:
return format(value, ".9g") return format(value, ".9g")
elif isinstance(value, (list, tuple, np.ndarray)): elif isinstance(value, (list, tuple, np.ndarray)):
@@ -226,7 +228,7 @@ def format_value(name: str, value) -> str:
def pretty_format_value(name: str, value) -> str: def pretty_format_value(name: str, value) -> str:
try: try:
return getattr(BareParams, name).display(value) return getattr(Parameters, name).display(value)
except AttributeError: except AttributeError:
return name + PARAM_SEPARATOR + str(value) return name + PARAM_SEPARATOR + str(value)
@@ -248,12 +250,74 @@ def pretty_format_from_sim_name(name: str) -> str:
out = [] out = []
for key, value in zip(s[::2], s[1::2]): for key, value in zip(s[::2], s[1::2]):
try: try:
out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))] out += [key.replace("_", " "), getattr(Parameters, key).display(float(value))]
except (AttributeError, ValueError): except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value) out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out) return PARAM_SEPARATOR.join(out)
def check_data_integrity(sub_folders: list[Path], init_z_num: int):
"""checks the integrity and completeness of a simulation data folder
Parameters
----------
path : str
path to the data folder
init_z_num : int
z_num as specified by the initial configuration file
Raises
------
IncompleteDataFolderError
raised if not all spectra are present in any folder
"""
for sub_folder in PBars(sub_folders, "Checking integrity"):
if num_left_to_propagate(sub_folder, init_z_num) != 0:
raise IncompleteDataFolderError(
f"not enough spectra of the specified {init_z_num} found in {sub_folder}"
)
def num_left_to_propagate(sub_folder: Path, init_z_num: int) -> int:
"""checks if a propagation has completed
Parameters
----------
sub_folder : Path
path to the sub folder containing the spectra
init_z_num : int
number of z position to store as specified in the master config file
Returns
-------
bool
True if the propagation has completed
Raises
------
IncompleteDataFolderError
raised if init_z_num doesn't match that specified in the individual parameter file
"""
z_num = io.load_toml(sub_folder / "params.toml")["z_num"]
num_spectra = find_last_spectrum_num(sub_folder) + 1 # because of zero-indexing
if z_num != init_z_num:
raise IncompleteDataFolderError(
f"initial config specifies {init_z_num} spectra per"
+ f" but the parameter file in {sub_folder} specifies {z_num}"
)
return z_num - num_spectra
def find_last_spectrum_num(data_dir: Path):
for num in itertools.count(1):
p_to_test = data_dir / SPEC1_FN.format(num)
if not p_to_test.is_file() or os.path.getsize(p_to_test) == 0:
return num - 1
def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]: def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]:
"""given a config with "variable" parameters, iterates through every possible combination, """given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary. yielding a a list of (parameter_name, value) tuples and a full config dictionary.
@@ -268,7 +332,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]] Iterator[tuple[list[tuple[str, Any]], dict[str, Any]]]
variable_list : a list of (name, value) tuple of parameter name and value that are variable. variable_list : a list of (name, value) tuple of parameter name and value that are variable.
params : a dict[str, Any] to be fed to Params params : a dict[str, Any] to be fed to Parameters
""" """
possible_keys = [] possible_keys = []
possible_ranges = [] possible_ranges = []
@@ -294,7 +358,7 @@ def variable_iterator(config: BareConfig) -> Iterator[tuple[list[tuple[str, Any]
def required_simulations( def required_simulations(
*configs: BareConfig, *configs: BareConfig,
) -> Iterator[tuple[list[tuple[str, Any]], BareParams]]: ) -> Iterator[tuple[list[tuple[str, Any]], Parameters]]:
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different """takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
parameter set and iterates through every single necessary simulation parameter set and iterates through every single necessary simulation
@@ -317,7 +381,7 @@ def required_simulations(
for j in range(configs[0].repeat or 1): for j in range(configs[0].repeat or 1):
variable_ind = [("id", i)] + variable_only + [("num", j)] variable_ind = [("id", i)] + variable_only + [("num", j)]
i += 1 i += 1
yield variable_ind, BareParams(**params_dict) yield variable_ind, Parameters(**params_dict)
def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]: def reduce_all_variable(all_variable: list[list[tuple[str, Any]]]) -> list[tuple[str, Any]]:

View File

@@ -1,392 +0,0 @@
import itertools
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, TypeVar, Union
import numpy as np
from .. import math
from ..logger import get_logger
from ..physics import fiber, materials, pulse, units
T = TypeVar("T")
import inspect
class EvaluatorError(Exception):
pass
class Rule:
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: list[str] = None,
priorities: Union[int, list[int]] = None,
conditions: dict[str, str] = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
if priorities is None:
priorities = [1] * len(targets)
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities]
self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func)
self.args = args
self.conditions = conditions or {}
def __repr__(self) -> str:
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
) -> list["Rule"]:
"""given a function that doesn't need all its keyword arguemtn specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func))
return rules
@dataclass
class EvalStat:
priority: float = np.inf
class Evaluator:
@classmethod
def default(cls) -> "Evaluator":
evaluator = cls()
evaluator.append(*default_rules)
return evaluator
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
self.__curent_lookup = set()
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
self.logger = get_logger(__name__)
def append(self, *rule: Rule):
for r in rule:
for t in r.targets:
if t is not None:
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def set(self, **params: Any):
self.params.update(params)
for k in params:
self.eval_stats[k].priority = np.inf
def reset(self):
self.params = {}
self.eval_stats = defaultdict(EvalStat)
def compute(self, target: str) -> Any:
"""computes a target
Parameters
----------
target : str
name of the target
Returns
-------
Any
return type of the target function
Raises
------
EvaluatorError
a cyclic dependence exists
KeyError
there is no saved rule for the target
"""
value = self.params.get(target)
if value is None:
if target in self.__curent_lookup:
raise EvaluatorError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, "
f"please provide a value for at least one variable in {self.__curent_lookup}"
)
else:
self.__curent_lookup.add(target)
if len(self.rules[target]) == 0:
raise EvaluatorError(f"no rule for {target}")
error = None
for ii, rule in enumerate(
filter(lambda r: self.validate_condition(r), reversed(self.rules[target]))
):
self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}")
try:
args = [self.compute(k) for k in rule.args]
returned_values = rule.func(*args)
if len(rule.targets) == 1:
returned_values = [returned_values]
for ((param_name, param_priority), returned_value) in zip(
rule.targets.items(), returned_values
):
if (
param_name == target
or param_name not in self.params
or self.eval_stats[param_name].priority < param_priority
):
self.logger.info(
f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
)
self.params[param_name] = returned_value
self.eval_stats[param_name] = param_priority
if param_name == target:
value = returned_value
break
except (EvaluatorError, KeyError) as e:
error = e
continue
if value is None and error is not None:
raise error
self.__curent_lookup.remove(target)
return value
def validate_condition(self, rule: Rule) -> bool:
return all(self.compute(k) == v for k, v in rule.conditions.items())
def __call__(self, target: str, args: list[str] = None):
"""creates a wrapper that adds decorated functions to the set of rules
Parameters
----------
target : str
name of the target
args : list[str], optional
list of name of arguments. Automatically deduced from function signature if
not provided, by default None
"""
def wrapper(func):
self.append(Rule(target, func, args))
return func
return wrapper
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
default_rules: list[Rule] = [
# Grid
*Rule.deduce(
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
math.build_sim_grid,
["time_window", "t_num", "dt"],
2,
),
# Pulse
Rule("spec_0", np.fft.fft, ["field_0"]),
Rule("field_0", np.fft.ifft, ["spec_0"]),
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
Rule(
["pre_field_0", "peak_power", "energy", "width"],
pulse.load_field_file,
[
"field_file",
"t",
"peak_power",
"energy",
"intensity_noise",
"noise_correlation",
"quantum_noise",
"w_c",
"w0",
"time_window",
"dt",
],
priorities=[2, 1, 1, 1],
),
Rule("pre_field_0", pulse.initial_field, priorities=1),
Rule(
"field_0",
pulse.add_shot_noise,
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
Rule("peak_power", pulse.soliton_num_to_peak_power),
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
Rule("energy", pulse.mean_power_to_energy),
Rule("t0", pulse.width_to_t0),
Rule("t0", pulse.soliton_num_to_t0),
Rule("width", pulse.t0_to_width),
Rule("soliton_num", pulse.soliton_num),
Rule("L_D", pulse.L_D),
Rule("L_NL", pulse.L_NL),
Rule("L_sol", pulse.L_sol),
# Fiber Dispersion
Rule("wl_for_disp", fiber.lambda_for_dispersion),
Rule("w_for_disp", units.m, ["wl_for_disp"]),
Rule(
"beta2_coefficients",
fiber.dispersion_coefficients,
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
),
Rule("beta2_arr", fiber.beta2),
Rule("beta2_arr", fiber.dispersion_from_coefficients),
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
Rule(
["wl_for_disp", "beta2_arr", "interpolation_range"],
fiber.load_custom_dispersion,
priorities=[2, 2, 2],
),
Rule("hr_w", fiber.delayed_raman_w),
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
Rule(
"n_eff",
fiber.n_eff_pcf,
["wl_for_disp", "pitch", "pitch_ratio"],
conditions=dict(model="pcf"),
),
Rule("capillary_spacing", fiber.HCARF_gap),
# Fiber nonlinearity
Rule("A_eff", fiber.A_eff_from_V),
Rule("A_eff", fiber.A_eff_from_diam),
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
Rule("A_eff_arr", fiber.load_custom_A_eff),
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
Rule(
"V_eff",
fiber.V_parameter_koshiba,
["wavelength", "pitch", "pitch_ratio"],
conditions=dict(model="pcf"),
),
Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]),
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
Rule("V_eff_arr", fiber.V_eff_marcuse),
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
# Fiber loss
Rule("alpha", fiber.compute_capillary_loss),
Rule("alpha", fiber.load_custom_loss),
# gas
Rule("n_gas_2", materials.n_gas_2),
]
def main():
import matplotlib.pyplot as plt
evalor = Evaluator()
evalor.append(*default_rules)
evalor.set(
**{
"length": 1,
"z_num": 128,
"wavelength": 1500e-9,
"interpolation_degree": 8,
"interpolation_range": (500e-9, 2200e-9),
"t_num": 16384,
"dt": 1e-15,
"shape": "gaussian",
"repetition_rate": 40e6,
"width": 30e-15,
"mean_power": 100e-3,
"n2": 2.4e-20,
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
"model": "pcf",
"quantum_noise": True,
"pitch": 1.2e-6,
"pitch_ratio": 0.5,
}
)
evalor.compute("z_targets")
print(evalor.params.keys())
print(evalor.params["l"][evalor.params["l"] > 0].min())
evalor.compute("spec_0")
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
plt.yscale("log")
plt.show()
print(evalor.compute("gamma"))
print(evalor.compute("beta2"))
from pprint import pprint
if __name__ == "__main__":
main()

View File

@@ -1,15 +1,21 @@
import datetime as datetime_module import datetime as datetime_module
import inspect
import itertools
import re
from collections import defaultdict
from copy import copy from copy import copy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from functools import lru_cache from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TypeVar from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
import os
import numpy as np import numpy as np
from tqdm.std import Bar
from .. import math
from ..const import __version__ from ..const import __version__
from ..logger import get_logger
# from .evaluator import Rule, Evaluator from .. import io
# from ..physics import pulse, fiber, materials from ..physics import fiber, materials, pulse, units
T = TypeVar("T") T = TypeVar("T")
@@ -365,10 +371,10 @@ mandatory_parameters = [
@dataclass @dataclass
class BareParams: class Parameters:
""" """
This class defines each valid parameter's name, type and valid value but doesn't provide This class defines each valid parameter's name, type and valid value. Initializing
any method to act on those. For that, use initialize.Params such an obj will automatically compute all possible parameters
""" """
# root # root
@@ -476,11 +482,25 @@ class BareParams:
def prepare_for_dump(self) -> Dict[str, Any]: def prepare_for_dump(self) -> Dict[str, Any]:
param = asdict(self) param = asdict(self)
param = BareParams.strip_params_dict(param) param = Parameters.strip_params_dict(param)
param["datetime"] = datetime_module.datetime.now() param["datetime"] = datetime_module.datetime.now()
param["version"] = __version__ param["version"] = __version__
return param return param
def __post_init__(self):
param_dict = {k: v for k, v in asdict(self).items() if v is not None}
evaluator = Evaluator.default()
evaluator.set(**param_dict)
for p_name in mandatory_parameters:
evaluator.compute(p_name)
for k, v in evaluator.params.items():
if k in param_dict:
setattr(self, k, v)
@classmethod
def load(cls, path: os.PathLike) -> "Parameters":
return cls(**io.load_toml(path))
@staticmethod @staticmethod
def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]: def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved """prepares a dictionary for serialization. Some keys may not be preserved
@@ -513,7 +533,7 @@ class BareParams:
if not isinstance(value, types): if not isinstance(value, types):
continue continue
if isinstance(value, dict): if isinstance(value, dict):
out[key] = BareParams.strip_params_dict(value) out[key] = Parameters.strip_params_dict(value)
elif isinstance(value, np.ndarray) and value.dtype == complex: elif isinstance(value, np.ndarray) and value.dtype == complex:
continue continue
else: else:
@@ -525,9 +545,382 @@ class BareParams:
return out return out
class EvaluatorError(Exception):
pass
class Rule:
def __init__(
self,
target: Union[str, list[Optional[str]]],
func: Callable,
args: list[str] = None,
priorities: Union[int, list[int]] = None,
conditions: dict[str, str] = None,
):
targets = list(target) if isinstance(target, (list, tuple)) else [target]
self.func = func
if priorities is None:
priorities = [1] * len(targets)
elif isinstance(priorities, (int, float, np.integer, np.floating)):
priorities = [priorities]
self.targets = dict(zip(targets, priorities))
if args is None:
args = get_arg_names(func)
self.args = args
self.conditions = conditions or {}
def __repr__(self) -> str:
return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})"
@classmethod
def deduce(
cls,
target: Union[str, list[Optional[str]]],
func: Callable,
kwarg_names: list[str],
n_var: int,
args_const: list[str] = None,
) -> list["Rule"]:
"""given a function that doesn't need all its keyword arguemtn specified, will
return a list of Rule obj, one for each combination of n_var specified kwargs
Parameters
----------
target : str | list[str | None]
name of the variable(s) that func returns
func : Callable
function to work with
kwarg_names : list[str]
list of all kwargs of the function to be used
n_var : int
how many shoulf be used per rule
arg_const : list[str], optional
override the name of the positional arguments
Returns
-------
list[Rule]
list of all possible rules
Example
-------
>> def lol(a, b=None, c=None):
pass
>> print(Rule.deduce(["d"], lol, ["b", "c"], 1))
[
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d0d0>, args=['a', 'b']),
Rule(targets={'d': 1}, func=<function lol_0 at 0x7f9bce31d160>, args=['a', 'c'])
]
"""
rules: list[cls] = []
for var_possibility in itertools.combinations(kwarg_names, n_var):
new_func = func_rewrite(func, list(var_possibility), args_const)
rules.append(cls(target, new_func))
return rules
@dataclass @dataclass
class BareConfig(BareParams): class EvalStat:
variable: dict = VariableParameter(BareParams) priority: float = np.inf
class Evaluator:
@classmethod
def default(cls) -> "Evaluator":
evaluator = cls()
evaluator.append(*default_rules)
return evaluator
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
self.__curent_lookup = set()
self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat)
self.logger = get_logger(__name__)
def append(self, *rule: Rule):
for r in rule:
for t in r.targets:
if t is not None:
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def set(self, **params: Any):
self.params.update(params)
for k in params:
self.eval_stats[k].priority = np.inf
def reset(self):
self.params = {}
self.eval_stats = defaultdict(EvalStat)
def compute(self, target: str) -> Any:
"""computes a target
Parameters
----------
target : str
name of the target
Returns
-------
Any
return type of the target function
Raises
------
EvaluatorError
a cyclic dependence exists
KeyError
there is no saved rule for the target
"""
value = self.params.get(target)
if value is None:
if target in self.__curent_lookup:
raise EvaluatorError(
"cyclic dependency detected : "
f"{target!r} seems to depend on itself, "
f"please provide a value for at least one variable in {self.__curent_lookup}"
)
else:
self.__curent_lookup.add(target)
if len(self.rules[target]) == 0:
raise EvaluatorError(f"no rule for {target}")
error = None
for ii, rule in enumerate(
filter(lambda r: self.validate_condition(r), reversed(self.rules[target]))
):
self.logger.debug(f"attempt {ii+1} to compute {target}, this time using {rule!r}")
try:
args = [self.compute(k) for k in rule.args]
returned_values = rule.func(*args)
if len(rule.targets) == 1:
returned_values = [returned_values]
for ((param_name, param_priority), returned_value) in zip(
rule.targets.items(), returned_values
):
if (
param_name == target
or param_name not in self.params
or self.eval_stats[param_name].priority < param_priority
):
self.logger.info(
f"computed {param_name}={returned_value} using {rule.func.__name__} from {rule.func.__module__}"
)
self.params[param_name] = returned_value
self.eval_stats[param_name] = param_priority
if param_name == target:
value = returned_value
break
except (EvaluatorError, KeyError) as e:
error = e
continue
if value is None and error is not None:
raise error
self.__curent_lookup.remove(target)
return value
def validate_condition(self, rule: Rule) -> bool:
return all(self.compute(k) == v for k, v in rule.conditions.items())
def __call__(self, target: str, args: list[str] = None):
"""creates a wrapper that adds decorated functions to the set of rules
Parameters
----------
target : str
name of the target
args : list[str], optional
list of name of arguments. Automatically deduced from function signature if
not provided, by default None
"""
def wrapper(func):
self.append(Rule(target, func, args))
return func
return wrapper
def get_arg_names(func: Callable) -> list[str]:
spec = inspect.getfullargspec(func)
args = spec.args
if spec.defaults is not None and len(spec.defaults) > 0:
args = args[: -len(spec.defaults)]
return args
def validate_arg_names(names: list[str]):
for n in names:
if re.match(r"^[^\s\-'\(\)\"\d][^\(\)\-\s'\"]*$", n) is None:
raise ValueError(f"{n} is an invalid parameter name")
def func_rewrite(func: Callable, kwarg_names: list[str], arg_names: list[str] = None):
if arg_names is None:
arg_names = get_arg_names(func)
else:
validate_arg_names(arg_names)
validate_arg_names(kwarg_names)
sign_arg_str = ", ".join(arg_names + kwarg_names)
call_arg_str = ", ".join(arg_names + [f"{s}={s}" for s in kwarg_names])
tmp_name = f"{func.__name__}_0"
func_str = f"def {tmp_name}({sign_arg_str}):\n return __func__({call_arg_str})"
scope = dict(__func__=func)
exec(func_str, scope)
out_func = scope[tmp_name]
out_func.__module__ = "evaluator"
return out_func
default_rules: list[Rule] = [
# Grid
*Rule.deduce(
["z_targets", "t", "time_window", "t_num", "dt", "w_c", "w0", "w", "w_power_fact", "l"],
math.build_sim_grid,
["time_window", "t_num", "dt"],
2,
),
# Pulse
Rule("spec_0", np.fft.fft, ["field_0"]),
Rule("field_0", np.fft.ifft, ["spec_0"]),
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
Rule(
["pre_field_0", "peak_power", "energy", "width"],
pulse.load_field_file,
[
"field_file",
"t",
"peak_power",
"energy",
"intensity_noise",
"noise_correlation",
"quantum_noise",
"w_c",
"w0",
"time_window",
"dt",
],
priorities=[2, 1, 1, 1],
),
Rule("pre_field_0", pulse.initial_field, priorities=1),
Rule(
"field_0",
pulse.add_shot_noise,
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
Rule("peak_power", pulse.soliton_num_to_peak_power),
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
Rule("energy", pulse.mean_power_to_energy),
Rule("t0", pulse.width_to_t0),
Rule("t0", pulse.soliton_num_to_t0),
Rule("width", pulse.t0_to_width),
Rule("soliton_num", pulse.soliton_num),
Rule("L_D", pulse.L_D),
Rule("L_NL", pulse.L_NL),
Rule("L_sol", pulse.L_sol),
# Fiber Dispersion
Rule("wl_for_disp", fiber.lambda_for_dispersion),
Rule("w_for_disp", units.m, ["wl_for_disp"]),
Rule(
"beta2_coefficients",
fiber.dispersion_coefficients,
["wl_for_disp", "beta2_arr", "w0", "interpolation_range", "interpolation_degree"],
),
Rule("beta2_arr", fiber.beta2),
Rule("beta2_arr", fiber.dispersion_from_coefficients),
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
Rule(
["wl_for_disp", "beta2_arr", "interpolation_range"],
fiber.load_custom_dispersion,
priorities=[2, 2, 2],
),
Rule("hr_w", fiber.delayed_raman_w),
Rule("n_eff", fiber.n_eff_hasan, conditions=dict(model="hasan")),
Rule("n_eff", fiber.n_eff_marcatili, conditions=dict(model="marcatili")),
Rule("n_eff", fiber.n_eff_marcatili_adjusted, conditions=dict(model="marcatili_adjusted")),
Rule(
"n_eff",
fiber.n_eff_pcf,
["wl_for_disp", "pitch", "pitch_ratio"],
conditions=dict(model="pcf"),
),
Rule("capillary_spacing", fiber.HCARF_gap),
# Fiber nonlinearity
Rule("A_eff", fiber.A_eff_from_V),
Rule("A_eff", fiber.A_eff_from_diam),
Rule("A_eff", fiber.A_eff_hasan, conditions=dict(model="hasan")),
Rule("A_eff", fiber.A_eff_from_gamma, priorities=-1),
Rule("A_eff_arr", fiber.A_eff_from_V, ["core_radius", "V_eff_arr"]),
Rule("A_eff_arr", fiber.load_custom_A_eff),
Rule("A_eff_arr", fiber.constant_A_eff_arr, priorities=-1),
Rule(
"V_eff",
fiber.V_parameter_koshiba,
["wavelength", "pitch", "pitch_ratio"],
conditions=dict(model="pcf"),
),
Rule("V_eff", fiber.V_eff_marcuse, ["wavelength", "core_radius", "numerical_aperture"]),
Rule("V_eff_arr", fiber.V_parameter_koshiba, conditions=dict(model="pcf")),
Rule("V_eff_arr", fiber.V_eff_marcuse),
Rule("gamma", lambda gamma_arr: gamma_arr[0]),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "A_eff_arr"]),
# Fiber loss
Rule("alpha", fiber.compute_capillary_loss),
Rule("alpha", fiber.load_custom_loss),
# gas
Rule("n_gas_2", materials.n_gas_2),
]
@dataclass
class BareConfig(Parameters):
variable: dict = VariableParameter(Parameters)
def __post_init__(self):
pass
@classmethod
def load(cls, path: os.PathLike) -> "BareConfig":
return cls(**io.load_toml(path))
@classmethod
def load_sequence(cls, *config_paths: os.PathLike) -> list["BareConfig"]:
"""Loads a sequence of
Parameters
----------
config_paths : os.PathLike
either one path (the last config containing previous_config_file parameter)
or a list of config path in the order they have to be simulated
Returns
-------
list[BareConfig]
all loaded configs
"""
if config_paths[0] is None:
return []
all_configs = [cls.load(config_paths[0])]
if len(config_paths) == 1:
while True:
if all_configs[0].previous_config_file is not None:
all_configs.insert(0, cls.load(all_configs[0].previous_config_file))
else:
break
else:
for i, path in enumerate(config_paths[1:]):
all_configs.append(cls.load(path))
all_configs[i + 1].previous_config_file = config_paths[i]
return all_configs
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -7,7 +7,7 @@ import toml
from scgenerator import defaults, utils, math from scgenerator import defaults, utils, math
from scgenerator.errors import * from scgenerator.errors import *
from scgenerator.physics import pulse, units from scgenerator.physics import pulse, units
from scgenerator.utils.parameter import BareConfig, BareParams from scgenerator.utils.parameter import BareConfig, Parameters
def load_conf(name): def load_conf(name):
@@ -143,10 +143,10 @@ class TestInitializeMethods(unittest.TestCase):
init.Config(**conf("good5")).__dict__.items(), init.Config(**conf("good5")).__dict__.items(),
) )
def setup_conf_custom_field(self, path) -> BareParams: def setup_conf_custom_field(self, path) -> Parameters:
conf = load_conf(path) conf = load_conf(path)
conf = BareParams(**conf) conf = Parameters(**conf)
init.build_sim_grid_in_place(conf) init.build_sim_grid_in_place(conf)
return conf return conf
@@ -192,12 +192,12 @@ class TestInitializeMethods(unittest.TestCase):
self.assertTrue(result) self.assertTrue(result)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
result = init.Params.from_bare(conf) result = Parameters(**conf)
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9) self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1") conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
conf.wavelength = 1593e-9 conf.wavelength = 1593e-9
result = init.Params.from_bare(conf) result = Parameters(**conf)
conf = load_conf("custom_field/wavelength_shift2") conf = load_conf("custom_field/wavelength_shift2")
conf = init.Config(**conf) conf = init.Config(**conf)

View File

@@ -1,12 +0,0 @@
import scgenerator as sc
from pathlib import Path
import os
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/")
root = Path("PM1550+PMHNLF+PM1550+PM2000")
confs = sc.io.load_config_sequence(root / "4_PM2000.toml")
final = sc.utils.final_config_from_sequence(*confs)
print(final)