new param system

This commit is contained in:
Benoît Sierro
2021-06-10 12:49:09 +02:00
parent 4a401a5771
commit 74cb057dbe
23 changed files with 1257 additions and 1488 deletions

View File

@@ -0,0 +1,4 @@
from .cli import main
if __name__ == "__main__":
main()

View File

@@ -105,7 +105,7 @@ def prep_ray(args):
def resume_sim(args):
method = prep_ray(args)
sim = resume_simulations(args.sim_dir, method=method)
sim = resume_simulations(Path(args.sim_dir), method=method)
sim.run()
run_simulation_sequence(
*args.configs, method=method, prev_sim_dir=sim.sim_dir, final_name=args.output_name

View File

@@ -1,78 +0,0 @@
from .. import const
import toml
valid_commands = ["finish", "next"]
class Configurator:
def __init__(self, name):
self.config = dict(name=name, fiber=dict(), gas=dict(), pulse=dict(), simulation=dict())
def list_input(self):
answer = ""
while answer == "":
answer = input("Please enter a list of values (one per line)\n")
out = [self.process_input(answer)]
while answer != "":
answer = input()
out.append(self.process_input(answer))
return out[:-1]
def process_input(self, s):
try:
return int(s)
except ValueError:
pass
try:
return float(s)
except ValueError:
pass
return s
def accept(self, question, default=True):
question += " ([y]/n)" if default else " (y/[n])"
question += "\n"
inp = input(question)
yes_str = ["y", "yes"]
if default:
yes_str.append("")
return inp.lower() in yes_str
def print_current(self, config: dict):
print(toml.dumps(config))
def get(self, section, param_name):
question = f"Please enter a value for the parameter '{param_name}'\n"
valid = const.valid_param_types[section][param_name]
is_valid = False
value = None
while not is_valid:
answer = input(question)
if answer == "variable" and param_name in const.valid_variable[section]:
value = self.list_input()
print(value)
is_valid = all(valid(v) for v in value)
else:
value = self.process_input(answer)
is_valid = valid(value)
return value
def ask_next_command(self):
s = ""
raw_input = input(s).split(" ")
return raw_input[0], raw_input[1:]
def main(self):
editing = True
while editing:
command, args = self.ask_next_command()

View File

@@ -1,6 +1,3 @@
import numpy as np
from collections import namedtuple
__version__ = "0.1.0"
@@ -19,243 +16,6 @@ def pbar_format(worker_id: int):
)
#####
def in_range_excl(func, r):
def _in_range(n):
if not func(n):
return False
return n > r[0] and n < r[1]
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (exclusive) "
return _in_range
def in_range_incl(func, r):
def _in_range(n):
if not func(n):
return False
return n >= r[0] and n <= r[1]
_in_range.__doc__ = func.__doc__ + f" between {r[0]} and {r[1]} (inclusive)"
return _in_range
def num(n):
"""must be a single, real, non-negative number"""
return isinstance(n, (float, int)) and n >= 0
def integer(n):
"""must be a strictly positive integer"""
return isinstance(n, int) and n > 0
def boolean(b):
"""must be a boolean"""
return type(b) == bool
def behaviors(l):
"""must be a valid list of behaviors"""
for s in l:
if s.lower() not in ["spm", "raman", "ss"]:
return False
return True
def beta(l):
"""must be a valid beta array"""
for n in l:
if not isinstance(n, (float, int)):
return False
return True
def field_0(f):
return isinstance(f, (str, tuple, list, np.ndarray))
def he_mode(mode):
"""must be a valide HE mode"""
if not isinstance(mode, (list, tuple)):
return False
if not len(mode) == 2:
return False
for m in mode:
if not integer(m):
return False
return True
def fit_parameters(param):
"""must be a valide fitting parameter tuple of the mercatili_adjusted model"""
if not isinstance(param, (list, tuple)):
return False
if not len(param) == 2:
return False
for n in param:
if not integer(n):
return False
return True
def string(l=None):
if l is None:
def _string(s):
return isinstance(s, str)
_string.__doc__ = f"must be a str"
else:
def _string(s):
return isinstance(s, str) and s.lower() in l
_string.__doc__ = f"must be a str matching one of {l}"
return _string
def capillary_resonance_strengths(l):
"""must be a list of non-zero, real number"""
if not isinstance(l, (list, tuple)):
return False
for m in l:
if not num(m):
return False
return True
def capillary_nested(n):
"""must be a non negative integer"""
return isinstance(n, int) and n >= 0
valid_param_types = dict(
root=dict(
name=string(),
prev_data_dir=string(),
),
fiber=dict(
input_transmission=in_range_incl(num, (0, 1)),
gamma=num,
n2=num,
effective_mode_diameter=num,
A_eff=num,
pitch=in_range_excl(num, (0, 1e-3)),
pitch_ratio=in_range_excl(num, (0, 1)),
core_radius=in_range_excl(num, (0, 1e-3)),
he_mode=he_mode,
fit_parameters=fit_parameters,
beta=beta,
dispersion_file=string(),
model=string(["pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"]),
length=in_range_excl(num, (0, 1e9)),
capillary_num=integer,
capillary_outer_d=in_range_excl(num, (0, 1e-3)),
capillary_thickness=in_range_excl(num, (0, 1e-3)),
capillary_spacing=in_range_excl(num, (0, 1e-3)),
capillary_resonance_strengths=capillary_resonance_strengths,
capillary_nested=capillary_nested,
),
gas=dict(
gas_name=string(["vacuum", "helium", "air"]),
pressure=num,
temperature=num,
plasma_density=num,
),
pulse=dict(
field_0=field_0,
field_file=string(),
repetition_rate=num,
peak_power=num,
mean_power=num,
energy=num,
soliton_num=num,
quantum_noise=boolean,
shape=string(["gaussian", "sech"]),
wavelength=in_range_excl(num, (100e-9, 3000e-9)),
intensity_noise=in_range_incl(num, (0, 1)),
width=in_range_excl(num, (0, 1e-9)),
t0=in_range_excl(num, (0, 1e-9)),
),
simulation=dict(
behaviors=behaviors,
parallel=boolean,
raman_type=string(["measured", "agrawal", "stolen"]),
ideal_gas=boolean,
repeat=integer,
t_num=integer,
z_num=integer,
time_window=num,
dt=in_range_excl(num, (0, 5e-15)),
tolerated_error=in_range_excl(num, (1e-15, 1e-5)),
step_size=num,
lower_wavelength_interp_limit=in_range_excl(num, (100e-9, 3000e-9)),
upper_wavelength_interp_limit=in_range_excl(num, (100e-9, 5000e-9)),
frep=num,
prev_sim_dir=string(),
readjust_wavelength=boolean,
),
)
hc_model_specific_parameters = dict(
marcatili=["core_radius", "he_mode"],
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
hasan=[
"core_radius",
"capillary_num",
"capillary_thickness",
"capillary_resonance_strengths",
"capillary_nested",
"capillary_spacing",
"capillary_outer_d",
],
)
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
valid_variable = dict(
fiber=[
"beta",
"gamma",
"pitch",
"pitch_ratio",
"core_radius",
"capillary_num",
"capillary_outer_d",
"capillary_thickness",
"capillary_spacing",
"capillary_resonance_strengths",
"capillary_nested",
"he_mode",
"fit_parameters",
"input_transmission",
"n2",
],
gas=["pressure", "temperature", "gas_name", "plasma_density"],
pulse=[
"peak_power",
"mean_power",
"energy",
"quantum_noise",
"shape",
"wavelength",
"intensity_noise",
"width",
"soliton_num",
],
simulation=[
"behaviors",
"raman_type",
"tolerated_error",
"step_size",
"ideal_gas",
"readjust_wavelength",
],
)
ENVIRON_KEY_BASE = "SCGENERATOR_"
PBAR_POLICY = ENVIRON_KEY_BASE + "PBAR_POLICY"
LOG_POLICY = ENVIRON_KEY_BASE + "LOG_POLICY"

View File

@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
from .errors import MissingParameterError
from pathlib import Path
default_parameters = dict(
input_transmission=1.0,
@@ -28,6 +27,7 @@ default_parameters = dict(
upper_wavelength_interp_limit=1900e-9,
ideal_gas=False,
readjust_wavelength=False,
recovery_last_stored=0,
)
default_plotting = dict(
@@ -36,7 +36,7 @@ default_plotting = dict(
vmin=-40,
vmax=0,
vmax_with_headroom=2,
name="plot",
out_path=Path("plot"),
avg_main_to_coherence_ratio=4,
avg_line_labels=["individual values", "mean"],
muted_style=dict(linewidth=0.5, c=(0.8, 0.8, 0.8, 0.4)),
@@ -57,76 +57,3 @@ default_plotting = dict(
text_topright_style=dict(verticalalignment="top", horizontalalignment="right"),
text_topleft_style=dict(verticalalignment="top", horizontalalignment="left"),
)
def get(section_dict, param, **kwargs):
"""checks if param is in the parameter section dict and attempts to fill in a default value
Parameters
----------
section_dict : dict
the parameters section {fiber, pulse, simulation, root} sub-dictionary
param : str
the name of the parameter (dict key)
kwargs : any
key word arguments passed to the MissingParameterError constructor
Returns
-------
dict
the updated section_dict dictionary
Raises
------
MissingFiberParameterError
raised when a parameter is missing and no default exists
"""
# whether the parameter is in the right place and valid is checked elsewhere,
# here, we just make sure it is present.
if param not in section_dict and param not in section_dict.get("variable", {}):
try:
section_dict[param] = default_parameters[param]
# LOG
except KeyError:
raise MissingParameterError(param, **kwargs)
return section_dict
def get_fiber(section_dict, param, **kwargs):
"""wrapper for fiber parameters that depend on fiber model"""
return get(section_dict, param, fiber_model=section_dict["model"], **kwargs)
def get_multiple(section_dict, params, num, **kwargs):
"""similar to th get method but works with several parameters
Parameters
----------
section_dict : dict
the parameters section {fiber, pulse, simulation, root}, sub-dictionary
params : list of str
names of the required parameters
num : int
how many of the parameters in params are required
Returns
-------
dict
the updated section_dict
Raises
------
MissingParameterError
raised when not enough parameters are provided and no defaults exist
"""
gotten = 0
for param in params:
try:
section_dict = get(section_dict, param, **kwargs)
gotten += 1
except MissingParameterError:
pass
if gotten >= num:
return section_dict
raise MissingParameterError(params, num_required=num, **kwargs)

View File

@@ -1,11 +1,10 @@
import os
from pathlib import Path
from typing import Dict, Literal, Optional, Set
from .const import ENVIRON_KEY_BASE, PBAR_POLICY, LOG_POLICY, TMP_FOLDER_KEY_BASE
from .const import ENVIRON_KEY_BASE, LOG_POLICY, PBAR_POLICY, TMP_FOLDER_KEY_BASE
def data_folder(task_id: int) -> Optional[Path]:
def data_folder(task_id: int) -> Optional[str]:
idstr = str(int(task_id))
tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr)
return tmp

View File

@@ -34,18 +34,3 @@ class DuplicateParameterError(Exception):
class IncompleteDataFolderError(FileNotFoundError):
pass
# class MissingFiberParameterError(MissingParameterError):
# def __init__(self, param, model):
# self.param = param
# self.model = model
# super().__init__(
# f"'{self.param}' is a required parameter for fiber model '{self.model}' and no default value is set"
# )
# class MissingPulseParameterError(MissingParameterError):
# def __init__(self, param):
# self.param = param
# super().__init__(f"'{self.param}' is a required pulse parameter and no default value is set")

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
from dataclasses import asdict
import itertools
import os
import shutil
@@ -11,18 +12,17 @@ import toml
from . import env, utils
from .const import (
__version__,
ENVIRON_KEY_BASE,
PARAM_FN,
PARAM_SEPARATOR,
PBAR_POLICY,
SPEC1_FN,
SPECN_FN,
TMP_FOLDER_KEY_BASE,
Z_FN,
__version__,
)
from .errors import IncompleteDataFolderError
from .logger import get_logger
from .utils.parameter import BareConfig, BareParams
PathTree = List[Tuple[Path, ...]]
@@ -88,6 +88,10 @@ def load_toml(path: os.PathLike):
path = conform_toml_path(path)
with open(path, mode="r") as file:
dico = toml.load(file)
for section in ["simulation", "fiber", "pulse", "gas"]:
dico.update(dico.pop(section, {}))
return dico
@@ -99,52 +103,15 @@ def save_toml(path: os.PathLike, dico):
return dico
def serializable(val):
"""returns True if val is serializable into a Json file"""
types = (np.ndarray, float, int, str, list, tuple)
out = isinstance(val, types)
if isinstance(val, np.ndarray):
out &= val.dtype != "complex"
return out
def prepare_for_serialization(dico: Dict[str, Any]) -> Dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
(dropped due to no conversion available)
Parameters
----------
dico : dict
dictionary
"""
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
types = (np.ndarray, float, int, str, list, tuple, dict)
out = {}
for key, value in dico.items():
if key in forbiden_keys:
continue
if not isinstance(value, types):
continue
if isinstance(value, dict):
out[key] = prepare_for_serialization(value)
elif isinstance(value, np.ndarray) and value.dtype == complex:
continue
else:
out[key] = value
return out
def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
def save_parameters(params: BareParams, destination_dir: Path, file_name="params.toml") -> Path:
"""saves a parameter dictionary. Note that is does remove some entries, particularly
those that take a lot of space ("t", "w", ...)
Parameters
----------
param_dict : Dict[str, Any]
params : Dict[str, Any]
dictionary to save
data_dir : Path
destination_dir : Path
destination directory
Returns
@@ -152,12 +119,8 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
Path
path to newly created the paramter file
"""
param = param_dict.copy()
file_path = destination_dir / "params.toml"
param = prepare_for_serialization(param)
param["datetime"] = datetime.now()
param["version"] = __version__
param = params.prepare_for_dump()
file_path = destination_dir / file_name
file_path.parent.mkdir(exist_ok=True)
@@ -168,7 +131,7 @@ def save_parameters(param_dict: Dict[str, Any], destination_dir: Path) -> Path:
return file_path
def load_previous_parameters(path: os.PathLike):
def load_params(path: os.PathLike) -> BareParams:
"""loads a parameters toml files and converts data to appropriate type
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
@@ -179,15 +142,29 @@ def load_previous_parameters(path: os.PathLike):
Returns
----------
dict
flattened parameters dictionary
BareParams
params obj
"""
params = load_toml(path)
return BareParams(**params)
for k, v in params.items():
if isinstance(v, list) and isinstance(v[0], (float, int)):
params[k] = np.array(v)
return params
def load_config(path: os.PathLike) -> BareConfig:
"""loads a parameters toml files and converts data to appropriate type
It is advised to run initialize.build_sim_grid to recover some parameters that are not saved.
Parameters
----------
path : PathLike
path to the toml
Returns
----------
BareParams
config obj
"""
config = load_toml(path)
return BareConfig(**config)
def load_material_dico(name):

View File

@@ -1,7 +1,6 @@
import logging
from typing import Optional
from .env import log_policy
from .env import log_policy
# class DebugOnlyFileHandler(logging.FileHandler):
# def __init__(

View File

@@ -1,8 +1,8 @@
from typing import Type, Union
from typing import Union
import numpy as np
from scipy.interpolate import griddata, interp1d
from scipy.special import jn_zeros
from scipy.interpolate import interp1d, griddata
from numba import jit
def span(*vec):
@@ -54,7 +54,6 @@ def power_fact(x, n):
raise TypeError(f"type {type(x)} of x not supported.")
@jit(nopython=True)
def _power_fact_single(x, n):
result = 1.0
for k in range(n):
@@ -62,7 +61,6 @@ def _power_fact_single(x, n):
return result
@jit(nopython=True)
def _power_fact_array(x, n):
result = np.ones(len(x), dtype=np.float64)
for k in range(n):
@@ -70,7 +68,6 @@ def _power_fact_array(x, n):
return result
@jit(nopython=True)
def abs2(z: np.ndarray) -> np.ndarray:
return z.real ** 2 + z.imag ** 2

View File

@@ -1,36 +0,0 @@
class Parameter:
"""base class for parameters"""
all = dict(fiber=dict(), pulse=dict(), gas=dict(), simulation=dict())
help_message = "no help message lol"
def __init_subclass__(cls, section):
Parameter.all[section][cls.__name__.lower()] = cls
def __init__(self, s):
self.s = s
valid = True
try:
self.value = self._convert()
valid = self.valid()
except ValueError:
valid = False
if not valid:
raise ValueError(
f"{self.__class__.__name__} {self.__class__.help_message}. input : {self.s}"
)
def _convert(self):
value = self.conversion_func(self.s)
return value
class Wavelength(Parameter, section="pulse"):
help_message = "must be a strictly positive real number"
def valid(self):
return self.value > 0
def conversion_func(self, s: str) -> float:
return float(s)

View File

@@ -1,15 +1,14 @@
from typing import Any, Dict, List, Tuple
import numpy as np
from numpy.lib import disp
from numpy.lib.arraysetops import isin
import toml
from numba import jit
from numpy.fft import fft, ifft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from .. import io
from ..const import hc_model_specific_parameters
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareParams, hc_model_specific_parameters
from . import materials as mat
from . import units
from .units import c, pi
@@ -25,7 +24,7 @@ def lambda_for_dispersion():
return np.linspace(190e-9, 3000e-9, 4000)
def is_dynamic_dispersion(params):
def is_dynamic_dispersion(pressure=None):
"""tests if the parameter dictionary implies that the dispersion profile of the fiber changes with z
Parameters
@@ -38,8 +37,8 @@ def is_dynamic_dispersion(params):
bool : True if dispersion is supposed to change with z
"""
out = False
if "pressure" in params:
out |= isinstance(params["pressure"], (tuple, list)) and len(params["pressure"]) == 2
if pressure is not None:
out |= isinstance(pressure, (tuple, list)) and len(pressure) == 2
return out
@@ -483,7 +482,19 @@ def HCPCF_dispersion(
return beta2(w, n_eff)
def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
def dynamic_HCPCF_dispersion(
lambda_: np.ndarray,
pressure_values: List[float],
core_radius: float,
fiber_model: str,
model_params: Dict[str, Any],
temperature: float,
ideal_gas: bool,
w0: float,
interp_range: Tuple[float, float],
material_dico: Dict[str, Any],
deg,
):
"""returns functions for beta2 coefficients and gamma instead of static values
Parameters
@@ -504,25 +515,22 @@ def dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg):
in the fiber
"""
# store values because storing functions acts weird with dict
pressure_values = params["pressure"]
a = params["core_radius"]
fiber_model = params["fiber_model"]
model_params = {k: params[k] for k in hc_model_specific_parameters[fiber_model]}
temp = params["temperature"]
ideal_gas = params["ideal_gas"]
w0 = params["w0"]
interp_range = params["interp_range"]
A_eff = 1.5 * a ** 2
A_eff = 1.5 * core_radius ** 2
# defining function instead of storing every possilble value
pressure = lambda r: mat.pressure_from_gradient(r, *pressure_values)
beta2 = lambda r: HCPCF_dispersion(
lambda_, a, material_dico, fiber_model, model_params, pressure(r), temp, ideal_gas
lambda_,
core_radius,
material_dico,
fiber_model,
model_params,
pressure(r),
temperature,
ideal_gas,
)
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temp)
n2 = lambda r: mat.non_linear_refractive_index(material_dico, pressure(r), temperature)
ratio_range = np.linspace(0, 1, 256)
gamma_grid = np.array([gamma_parameter(n2(r), w0, A_eff) for r in ratio_range])
@@ -640,7 +648,7 @@ def PCF_dispersion(lambda_, pitch, ratio_d, w0=None, n2=None, A_eff=None):
return beta2, gamma
def dispersion_central(fiber_model, params, deg=8):
def compute_dispersion(params: BareParams, deg=8):
"""dispatch function depending on what type of fiber is used
Parameters
@@ -660,8 +668,8 @@ def dispersion_central(fiber_model, params, deg=8):
nonlinear parameter
"""
if "dispersion_file" in params:
disp_file = np.load(params["dispersion_file"])
if params.dispersion_file is not None:
disp_file = np.load(params.dispersion_file)
lambda_ = disp_file["wavelength"]
D = disp_file["dispersion"]
beta2 = D_to_beta2(D, lambda_)
@@ -669,21 +677,20 @@ def dispersion_central(fiber_model, params, deg=8):
else:
lambda_ = lambda_for_dispersion()
beta2 = np.zeros_like(lambda_)
fiber_model = fiber_model.lower()
if fiber_model == "pcf":
if params.model == "pcf":
beta2, gamma = PCF_dispersion(
lambda_,
params["pitch"],
params["pitch_ratio"],
w0=params["w0"],
n2=params.get("n2"),
A_eff=params.get("A_eff"),
params.pitch,
params.pitch_ratio,
w0=params.w0,
n2=params.n2,
A_eff=params.A_eff,
)
else:
# Load material info
gas_name = params["gas_name"]
gas_name = params.gas_name
if gas_name == "vacuum":
material_dico = None
@@ -691,8 +698,20 @@ def dispersion_central(fiber_model, params, deg=8):
material_dico = toml.loads(io.Paths.gets("gas"))[gas_name]
# compute dispersion
if params.get("dynamic_dispersion", False):
return dynamic_HCPCF_dispersion(lambda_, params, material_dico, deg)
if params.dynamic_dispersion:
return dynamic_HCPCF_dispersion(
lambda_,
params.pressure,
params.core_radius,
params.model,
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
params.temperature,
params.ideal_gas,
params.w0,
params.interp_range,
material_dico,
deg,
)
else:
# actually compute the dispersion
@@ -700,31 +719,31 @@ def dispersion_central(fiber_model, params, deg=8):
beta2 = HCPCF_dispersion(
lambda_,
material_dico,
fiber_model,
{k: params[k] for k in hc_model_specific_parameters[fiber_model]},
params["pressure"],
params["temperature"],
params["ideal_gas"],
params.model,
{k: getattr(params, k) for k in hc_model_specific_parameters[params.model]},
params.pressure,
params.temperature,
params.ideal_gas,
)
if material_dico is not None:
A_eff = 1.5 * params["core_radius"] ** 2
A_eff = 1.5 * params.core_radius ** 2
n2 = mat.non_linear_refractive_index(
material_dico, params["pressure"], params["temperature"]
material_dico, params.pressure, params.temperature
)
gamma = gamma_parameter(n2, params["w0"], A_eff)
gamma = gamma_parameter(n2, params.w0, A_eff)
else:
gamma = None
# add plasma if wanted
if params["plasma_density"] > 0:
beta2 += plasma_dispersion(lambda_, params["plasma_density"])
if params.plasma_density > 0:
beta2 += plasma_dispersion(lambda_, params.plasma_density)
beta2_coef = dispersion_coefficients(lambda_, beta2, params["w0"], params["interp_range"], deg)
beta2_coef = dispersion_coefficients(lambda_, beta2, params.w0, params.interp_range, deg)
if gamma is None:
if "A_eff" in params:
gamma = gamma_parameter(params.get("n2", 2.6e-20), params["w0"], params["A_eff"])
if params.A_eff is not None:
gamma = gamma_parameter(params.n2, params.w0, params.A_eff)
else:
gamma = 0
@@ -855,15 +874,13 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
"""
# Compute raman response function if necessary
f_r = 0.18
if "raman" in behaviors:
if "hr_w" == None:
raise TypeError("freq-dependent Raman response must be give")
else:
if f_r is None:
if raman_type in ["stolen", "measured"]:
f_r = 0.18
elif raman_type == "agrawal":
f_r = 0.245
if hr_w is None:
raise ValueError("freq-dependent Raman response must be give")
if f_r is None:
if raman_type == "agrawal":
f_r = 0.245
if "spm" in behaviors:
spm_part = lambda fi: (1 - f_r) * abs2(fi)
@@ -875,7 +892,6 @@ def create_non_linear_op(behaviors, w_c, w0, gamma, raman_type="stolen", f_r=Non
else:
raman_part = lambda fi: 0
spm_part = jit(spm_part, nopython=True)
ss_part = w_c / w0 if "ss" in behaviors else 0
if isinstance(gamma, (float, int)):
@@ -924,7 +940,6 @@ def fast_dispersion_op(w_c, beta_arr, power_fact_arr, where=slice(None)):
return -1j * out
@jit(nopython=True)
def _fast_disp_loop(dispersion: np.ndarray, beta_arr, power_fact_arr):
for k in range(len(beta_arr) - 1, -1, -1):
dispersion = dispersion + beta_arr[k] * power_fact_arr[k]

View File

@@ -1,7 +1,6 @@
import numpy as np
from ..logger import get_logger
from . import units
from .units import NA, c, kB

View File

@@ -11,6 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time
import itertools
import os
from pathlib import Path
from typing import Literal, Tuple
import matplotlib.pyplot as plt
@@ -18,13 +19,13 @@ import numpy as np
from numpy import pi
from numpy.fft import fft, fftshift, ifft
from scipy.interpolate import UnivariateSpline
from numba import jit
from .. import io
from ..defaults import default_plotting
from ..logger import get_logger
from ..plotting import plot_setup
from ..math import *
from ..plotting import plot_setup
from ..utils.parameter import BareParams
c = 299792458.0
hbar = 1.05457148e-34
@@ -205,6 +206,48 @@ def conform_pulse_params(
return width, t0, peak_power, energy, soliton_num
def setup_custom_field(params: BareParams) -> bool:
"""sets up a custom field function if necessary and returns
True if it did so, False otherwise
Parameters
----------
params : Dict[str, Any]
params dictionary
Returns
-------
bool
True if the field has been modified
"""
field_0 = width = peak_power = energy = None
did_set = True
if params.prev_data_dir is not None:
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
field_0 = np.fft.ifft(spec) * np.sqrt(params.input_transmission)
elif params.field_file is not None:
field_data = np.load(params.field_file)
field_interp = interp1d(
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
)
field_0 = field_interp(params.t)
field_0 = field_0 * modify_field_ratio(
params.t,
field_0,
params.peak_power,
params.energy,
params.intensity_noise,
)
width, peak_power, energy = measure_field(params.t, field_0)
else:
did_set = False
return did_set, width, peak_power, energy, field_0
def E0_to_P0(E0, t0, shape="gaussian"):
"""convert an initial total pulse energy to a pulse peak peak_power"""
return E0 / (t0 * P0T0_to_E0_fac[shape])
@@ -223,12 +266,10 @@ def gauss_pulse(t, t0, P0, offset=0):
return np.sqrt(P0) * np.exp(-(((t - offset) / t0) ** 2))
@jit(nopython=True)
def photon_number(spectrum, w, dw, gamma):
return np.sum(1 / gamma * abs2(spectrum) / w * dw)
@jit(nopython=True)
def pulse_energy(spectrum, w, dw, _):
return np.sum(abs2(spectrum) * dw)

View File

@@ -1,7 +1,8 @@
import multiprocessing
import os
from datetime import datetime
from typing import Any, Dict, List, Tuple, Type
from pathlib import Path
from typing import Dict, List, Tuple, Type
import numpy as np
@@ -18,7 +19,14 @@ except ModuleNotFoundError:
class RK4IP:
def __init__(self, sim_params, save_data=False, job_identifier="", task_id=0, n_percent=10):
def __init__(
self,
params: initialize.Params,
save_data=False,
job_identifier="",
task_id=0,
n_percent=10,
):
"""A 1D solver using 4th order Runge-Kutta in the interaction picture
Parameters
@@ -76,31 +84,29 @@ class RK4IP:
self.logger = get_logger(self.job_identifier)
self.resuming = False
self.save_data = save_data
self._extract_params(sim_params)
self._setup_functions()
self.starting_num = sim_params.get("recovery_last_stored", 0)
self._setup_sim_parameters()
def _extract_params(self, params):
self.w_c = params.pop("w_c")
self.w0 = params.pop("w0")
self.w_power_fact = params.pop("w_power_fact")
self.spec_0 = params.pop("spec_0")
self.z_targets = params.pop("z_targets")
self.z_final = params.pop("length")
self.beta = params.pop("beta_func", params.pop("beta"))
self.gamma = params.pop("gamma_func", params.pop("gamma"))
self.behaviors = params.pop("behaviors")
self.raman_type = params.pop("raman_type", "stolen")
self.f_r = params.pop("f_r", 0)
self.hr_w = params.pop("hr_w", None)
self.adapt_step_size = params.pop("adapt_step_size", True)
self.error_ok = params.pop("error_ok")
self.dynamic_dispersion = params.pop("dynamic_dispersion", False)
self.w_c = params.w_c
self.w0 = params.w0
self.w_power_fact = params.w_power_fact
self.spec_0 = params.spec_0
self.z_targets = params.z_targets
self.z_final = params.length
self.beta = params.beta_func if params.beta_func is not None else params.beta
self.gamma = params.gamma_func if params.gamma_func is not None else params.gamma
self.behaviors = params.behaviors
self.raman_type = params.raman_type
self.hr_w = params.hr_w
self.adapt_step_size = params.adapt_step_size
self.error_ok = params.error_ok
self.dynamic_dispersion = params.dynamic_dispersion
self.starting_num = params.recovery_last_stored
self._setup_functions()
self._setup_sim_parameters()
def _setup_functions(self):
self.N_func = create_non_linear_op(
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, self.f_r, self.hr_w
self.behaviors, self.w_c, self.w0, self.gamma, self.raman_type, hr_w=self.hr_w
)
if self.dynamic_dispersion:
@@ -303,7 +309,7 @@ class RK4IP:
class SequentialRK4IP(RK4IP):
def __init__(
self,
sim_params,
params: initialize.Params,
pbars: utils.PBars,
save_data=False,
job_identifier="",
@@ -312,7 +318,7 @@ class SequentialRK4IP(RK4IP):
):
self.pbars = pbars
super().__init__(
sim_params,
params,
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
@@ -326,7 +332,7 @@ class SequentialRK4IP(RK4IP):
class MutliProcRK4IP(RK4IP):
def __init__(
self,
sim_params,
params: initialize.Params,
p_queue: multiprocessing.Queue,
worker_id: int,
save_data=False,
@@ -337,7 +343,7 @@ class MutliProcRK4IP(RK4IP):
self.worker_id = worker_id
self.p_queue = p_queue
super().__init__(
sim_params,
params,
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
@@ -351,7 +357,7 @@ class MutliProcRK4IP(RK4IP):
class RayRK4IP(RK4IP):
def __init__(
self,
sim_params,
params: initialize.Params,
p_actor,
worker_id: int,
save_data=False,
@@ -362,7 +368,7 @@ class RayRK4IP(RK4IP):
self.worker_id = worker_id
self.p_actor = p_actor
super().__init__(
sim_params,
params,
save_data=save_data,
job_identifier=job_identifier,
task_id=task_id,
@@ -414,7 +420,7 @@ class Simulations:
if isinstance(method, str):
method = Simulations.simulation_methods_dict[method]
return method(param_seq, task_id)
elif param_seq.num_sim > 1 and param_seq["simulation", "parallel"]:
elif param_seq.num_sim > 1 and param_seq.config.parallel:
return Simulations.get_best_method()(param_seq, task_id)
else:
return SequencialSimulations(param_seq, task_id)
@@ -439,7 +445,7 @@ class Simulations:
self.name = self.param_seq.name
self.sim_dir = io.get_sim_dir(self.id, name_if_new=self.name)
io.save_toml(os.path.join(self.sim_dir, "initial_config.toml"), self.param_seq.config)
io.save_parameters(self.param_seq.config, self.sim_dir, file_name="initial_config.toml")
self.sim_jobs_per_node = 1
self.max_concurrent_jobs = np.inf
@@ -447,9 +453,7 @@ class Simulations:
@property
def finished_and_complete(self):
try:
io.check_data_integrity(
io.get_data_dirs(self.sim_dir), self.param_seq["simulation", "z_num"]
)
io.check_data_integrity(io.get_data_dirs(self.sim_dir), self.param_seq.config.z_num)
return True
except IncompleteDataFolderError:
return False
@@ -472,15 +476,15 @@ class Simulations:
self.new_sim(v_list_str, params)
self.finish()
def new_sim(self, v_list_str: str, params: dict):
def new_sim(self, v_list_str: str, params: initialize.Params):
"""responsible to launch a new simulation
Parameters
----------
v_list_str : str
string that uniquely identifies the simulation as returned by utils.format_variable_list
params : dict
a flattened parameter dictionary, as returned by initialize.compute_init_parameters
params : initialize.Params
computed parameters
"""
raise NotImplementedError()
@@ -507,7 +511,7 @@ class SequencialSimulations(Simulations, priority=0):
super().__init__(param_seq, task_id=task_id)
self.pbars = utils.PBars(self.param_seq.num_steps, "Simulating " + self.param_seq.name, 1)
def new_sim(self, v_list_str: str, params: Dict[str, Any]):
def new_sim(self, v_list_str: str, params: initialize.Params):
self.logger.info(f"{self.param_seq.name} : launching simulation with {v_list_str}")
SequentialRK4IP(
params, self.pbars, save_data=True, job_identifier=v_list_str, task_id=self.id
@@ -517,7 +521,7 @@ class SequencialSimulations(Simulations, priority=0):
pass
def finish(self):
pass
self.pbars.close()
class MultiProcSimulations(Simulations, priority=1):
@@ -553,7 +557,7 @@ class MultiProcSimulations(Simulations, priority=1):
worker.start()
super().run()
def new_sim(self, v_list_str: str, params: dict):
def new_sim(self, v_list_str: str, params: initialize.Params):
self.queue.put((v_list_str, params), block=True, timeout=None)
def finish(self):
@@ -576,7 +580,7 @@ class MultiProcSimulations(Simulations, priority=1):
p_queue: multiprocessing.Queue,
):
while True:
raw_data: Tuple[List[tuple], Dict[str, Any]] = queue.get()
raw_data: Tuple[List[tuple], initialize.Params] = queue.get()
if raw_data == 0:
queue.task_done()
return
@@ -635,7 +639,7 @@ class RaySimulations(Simulations, priority=2):
.remote(self.param_seq.name, self.sim_jobs_total, self.param_seq.num_steps)
)
def new_sim(self, v_list_str: str, params: dict):
def new_sim(self, v_list_str: str, params: initialize.Params):
while len(self.jobs) >= self.sim_jobs_total:
self._collect_1_job()
@@ -707,28 +711,27 @@ def new_simulation(
method: Type[Simulations] = None,
) -> Simulations:
config = io.load_toml(config_file)
config_dict = io.load_toml(config_file)
if prev_sim_dir is not None:
config.setdefault("simulation", {})
config["simulation"]["prev_sim_dir"] = str(prev_sim_dir)
config_dict["prev_sim_dir"] = str(prev_sim_dir)
task_id = np.random.randint(1e9, 1e12)
if prev_sim_dir is None:
param_seq = initialize.ParamSequence(config)
param_seq = initialize.ParamSequence(config_dict)
else:
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config)
param_seq = initialize.ContinuationParamSequence(prev_sim_dir, config_dict)
print(f"{param_seq.name=}")
return Simulations.new(param_seq, task_id, method)
def resume_simulations(sim_dir: str, method: Type[Simulations] = None) -> Simulations:
def resume_simulations(sim_dir: Path, method: Type[Simulations] = None) -> Simulations:
task_id = np.random.randint(1e9, 1e12)
config = io.load_toml(os.path.join(sim_dir, "initial_config.toml"))
config = io.load_toml(sim_dir / "initial_config.toml")
io.set_data_folder(task_id, sim_dir)
param_seq = initialize.RecoveryParamSequence(config, task_id)

View File

@@ -2,9 +2,10 @@
# For example, nm(X) means "I give the number X in nm, figure out the ang. freq."
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from numba.core.types.misc import Phantom
from typing import Callable, Union
import numpy as np
from numpy import isin, pi
from numpy import pi
c = 299792458.0
hbar = 1.05457148e-34
@@ -217,7 +218,7 @@ units_map = dict(
)
def get_unit(unit):
def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
if isinstance(unit, str):
return units_map[unit]
return unit

View File

@@ -1,55 +1,47 @@
import os
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from scgenerator.utils import variable_iterator
from scipy.interpolate import UnivariateSpline
from . import io, math
from .math import abs2, length, make_uniform_1D, span
from .physics import pulse, units
from .defaults import default_plotting as defaults
from .math import abs2, make_uniform_1D, span
from .physics import pulse, units
from .utils.parameter import BareParams
RangeType = Tuple[float, float, Union[str, Callable]]
def plot_setup(
folder_name=None,
file_name=None,
file_type="png",
figsize=defaults["figsize"],
params=None,
mode="default",
):
out_path: Path,
file_type: str = "png",
figsize: Tuple[float, float] = defaults["figsize"],
mode: Literal["default", "coherence", "coherence_T"] = "default",
) -> Tuple[Path, plt.Figure, Union[plt.Axes, Tuple[plt.Axes]]]:
"""It should return :
- a folder_name
- a file name
- a fig
- an axis
"""
file_name = defaults["name"] if file_name is None else file_name
out_path = defaults["name"] if out_path is None else out_path
plot_name = out_path.stem
out_dir = out_path.resolve().parent
if params is not None:
folder_name = params.get("plot.folder_name", folder_name)
file_name = params.get("plot.file_name", file_name)
file_type = params.get("plot.file_type", file_type)
figsize = params.get("plot.figsize", figsize)
file_name = plot_name + "." + file_type
out_path = out_dir / file_name
# ensure output folder_name exists
folder_name, file_name = (
os.path.split(file_name)
if folder_name is None
else (folder_name, os.path.split(file_name)[1])
)
folder_name = os.path.join(io.Paths.get("plots"), folder_name)
if not os.path.exists(os.path.abspath(folder_name)):
os.makedirs(os.path.abspath(folder_name))
os.makedirs(out_dir, exist_ok=True)
# ensure no overwrite
ind = 0
while os.path.exists(os.path.join(folder_name, file_name + "_" + str(ind) + "." + file_type)):
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists():
ind += 1
file_name = file_name + "_" + str(ind) + "." + file_type
if mode == "default":
fig, ax = plt.subplots(figsize=figsize)
@@ -78,7 +70,7 @@ def plot_setup(
else:
raise ValueError(f"mode {mode} not understood")
return folder_name, file_name, fig, ax
return full_path, fig, ax
def draw_across(ax1, xy1, ax2, xy2, clip_on=False, **kwargs):
@@ -297,9 +289,7 @@ def _finish_plot_2D(
folder_name = ""
if is_new_plot:
folder_name, file_name, fig, ax = plot_setup(
file_name=file_name, file_type=file_type, params=params
)
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
else:
fig = ax.get_figure()
@@ -345,8 +335,8 @@ def _finish_plot_2D(
cbar.ax.set_ylabel(cbar_label)
if is_new_plot:
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
print(f"plot saved in {os.path.join(folder_name, file_name)}")
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")
if cbar_label is not None:
return fig, ax, cbar.ax
else:
@@ -354,20 +344,20 @@ def _finish_plot_2D(
def plot_spectrogram(
values,
x_range,
y_range,
params,
t_res=None,
gate_width=None,
log=True,
vmin=None,
vmax=None,
cbar_label="normalized intensity (dB)",
file_type="png",
file_name=None,
cmap=None,
ax=None,
values: np.ndarray,
x_range: RangeType,
y_range: RangeType,
params: BareParams,
t_res: int = None,
gate_width: float = None,
log: bool = True,
vmin: float = None,
vmax: float = None,
cbar_label: str = "normalized intensity (dB)",
file_type: str = "png",
file_name: str = None,
cmap: str = None,
ax: plt.Axes = None,
):
"""Plots a spectrogram given a complex field in the time domain
Parameters
@@ -382,7 +372,7 @@ def plot_spectrogram(
units : function to convert from the desired units to rad/s or to time.
common functions are already defined in scgenerator.physics.units
look there for more details
params : dict
params : BareParams
parameters of the simulations
log : bool, optional
whether to compute the logarithm of the spectrogram
@@ -424,16 +414,16 @@ def plot_spectrogram(
t_win = 2 * np.max(t_range[2](np.abs(t_range[:2])))
spec_kwargs = dict(t_res=t_res, t_win=t_win, gate_width=gate_width, shift=False)
spec, new_t = pulse.spectrogram(
params["t"].copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
params.t.copy(), values, **{k: v for k, v in spec_kwargs.items() if v is not None}
)
# Crop and reoder axis
new_t, ind_t, _ = units.sort_axis(new_t, t_range)
new_f, ind_f, _ = units.sort_axis(params["w"], f_range)
new_f, ind_f, _ = units.sort_axis(params.w, f_range)
values = spec[ind_t][:, ind_f]
if f_range[2].type == "WL":
values = np.apply_along_axis(
units.to_WL, 1, values, params["frep"], units.m(f_range[2].inv(new_f))
units.to_WL, 1, values, params.frep, units.m(f_range[2].inv(new_f))
)
values = np.apply_along_axis(make_uniform_1D, 1, values, new_f)
@@ -463,19 +453,19 @@ def plot_spectrogram(
def plot_results_2D(
values,
plt_range,
params,
log="1D",
skip=16,
vmin=None,
vmax=None,
transpose=False,
cbar_label="normalized intensity (dB)",
file_type="png",
file_name=None,
cmap=None,
ax=None,
values: np.ndarray,
plt_range: RangeType,
params: BareParams,
log: Union[int, float, bool, str] = "1D",
skip: int = 16,
vmin: float = None,
vmax: float = None,
transpose: bool = False,
cbar_label: Optional[str] = "normalized intensity (dB)",
file_type: str = "png",
file_name: str = None,
cmap: str = None,
ax: plt.Axes = None,
):
"""
plots 2D arrays and automatically saves the plots, as well as returns it
@@ -540,27 +530,32 @@ def plot_results_2D(
# make uniform if converting to wavelength
if plt_range[2].type == "WL":
if is_spectrum:
values = np.apply_along_axis(units.to_WL, 1, values, params.get("frep", 1), x_axis)
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
values = np.array(
[make_uniform_1D(v, x_axis, n=len(x_axis), method="linear") for v in values]
)
z = params["z_targets"]
lim_diff = 1e-5 * np.max(z)
dz_s = np.diff(z)
lim_diff = 1e-5 * np.max(params.z_targets)
dz_s = np.diff(params.z_targets)
if not np.all(np.diff(dz_s) < lim_diff):
new_z = np.linspace(
*span(z), int(np.floor((np.max(z) - np.min(z)) / np.min(dz_s[dz_s > lim_diff])))
*span(params.z_targets),
int(
np.floor(
(np.max(params.z_targets) - np.min(params.z_targets))
/ np.min(dz_s[dz_s > lim_diff])
)
),
)
values = np.array(
[make_uniform_1D(v, z, n=len(new_z), method="linear") for v in values.T]
[make_uniform_1D(v, params.z_targets, n=len(new_z), method="linear") for v in values.T]
).T
z = new_z
params.z_targets = new_z
return _finish_plot_2D(
values,
x_axis,
plt_range[2].label,
z,
params.z_targets,
"propagation distance (m)",
log,
vmin,
@@ -576,20 +571,20 @@ def plot_results_2D(
def plot_results_1D(
values,
plt_range,
params,
log=False,
spacing=1,
vmin=None,
vmax=None,
ylabel=None,
yscaling=1,
file_type="pdf",
file_name=None,
ax=None,
line_label=None,
transpose=False,
values: np.ndarray,
plt_range: RangeType,
params: BareParams,
log: Union[str, int, float, bool] = False,
spacing: Union[int, float] = 1,
vmin: float = None,
vmax: float = None,
ylabel: str = None,
yscaling: float = 1,
file_type: str = "pdf",
file_name: str = None,
ax: plt.Axes = None,
line_label: str = None,
transpose: bool = False,
**line_kwargs,
):
"""
@@ -656,7 +651,7 @@ def plot_results_1D(
# make uniform if converting to wavelength
if plt_range[2].type == "WL":
if is_spectrum:
values = units.to_WL(values, params["frep"], units.m.inv(params["w"][ind]))
values = units.to_WL(values, params.frep, units.m.inv(params.w[ind]))
# change the resolution
if isinstance(spacing, float):
@@ -683,9 +678,7 @@ def plot_results_1D(
folder_name = ""
if is_new_plot:
folder_name, file_name, fig, ax = plot_setup(
file_name=file_name, file_type=file_type, params=params
)
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
else:
fig = ax.get_figure()
if transpose:
@@ -702,40 +695,40 @@ def plot_results_1D(
ax.set_xlabel(plt_range[2].label)
if is_new_plot:
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
print(f"plot saved in {os.path.join(folder_name, file_name)}")
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")
return fig, ax, x_axis, values
def _prep_plot(values, plt_range, params):
def _prep_plot(values: np.ndarray, plt_range: RangeType, params: BareParams):
is_spectrum = values.dtype == "complex"
plt_range = (*plt_range[:2], units.get_unit(plt_range[2]))
if plt_range[2].type in ["WL", "FREQ", "AFREQ"]:
x_axis = params["w"].copy()
x_axis = params.w.copy()
else:
x_axis = params["t"].copy()
x_axis = params.t.copy()
return is_spectrum, x_axis, plt_range
def plot_avg(
values,
plt_range,
params,
log=False,
spacing=1,
vmin=None,
vmax=None,
ylabel=None,
yscaling=1,
renormalize=True,
add_coherence=False,
file_type="png",
file_name=None,
ax=None,
line_labels=None,
legend=True,
legend_kwargs={},
transpose=False,
values: np.ndarray,
plt_range: RangeType,
params: BareParams,
log: Union[float, int, str, bool] = False,
spacing: Union[float, int] = 1,
vmin: float = None,
vmax: float = None,
ylabel: str = None,
yscaling: float = 1,
renormalize: bool = True,
add_coherence: bool = False,
file_type: str = "png",
file_name: str = None,
ax: plt.Axes = None,
line_labels: Tuple[str, str] = None,
legend: bool = True,
legend_kwargs: Dict[str, Any] = {},
transpose: bool = False,
):
"""
plots 1D arrays and there mean and automatically saves the plots, as well as returns it
@@ -817,8 +810,8 @@ def plot_avg(
values *= yscaling
mean_values = np.mean(values, axis=0)
if plt_range[2].type == "WL" and renormalize:
values = np.apply_along_axis(units.to_WL, 1, values, params["frep"], x_axis)
mean_values = units.to_WL(mean_values, params["frep"], x_axis)
values = np.apply_along_axis(units.to_WL, 1, values, params.frep, x_axis)
mean_values = units.to_WL(mean_values, params.frep, x_axis)
# change the resolution
if isinstance(spacing, float):
@@ -852,12 +845,12 @@ def plot_avg(
if is_new_plot:
if add_coherence:
mode = "coherence_T" if transpose else "coherence"
folder_name, file_name, fig, (top, bot) = plot_setup(
file_name=file_name, file_type=file_type, params=params, mode=mode
out_path, fig, (top, bot) = plot_setup(
out_path=Path(folder_name) / file_name, file_type=file_type, mode=mode
)
else:
folder_name, file_name, fig, top = plot_setup(
file_name=file_name, file_type=file_type, params=params
out_path, fig, top = plot_setup(
out_path=Path(folder_name) / file_name, file_type=file_type
)
bot = top
else:
@@ -923,8 +916,8 @@ def plot_avg(
top.legend(custom_lines, line_labels, **legend_kwargs)
if is_new_plot:
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
print(f"plot saved in {os.path.join(folder_name, file_name)}")
fig.savefig(out_path, bbox_inches="tight", dpi=200)
print(f"plot saved in {out_path}")
if top is bot:
return fig, top
@@ -984,46 +977,6 @@ def prepare_plot_1D(values, plt_range, x_axis, yscaling=1, spacing=1, frep=80e6)
return x_axis, np.squeeze(values)
def plot_dispersion_parameter(params, plt_range):
"""
Plots the dispersion parameter D as well as the beta2 parameter over the given range
"""
# TODO allow several curves, with legends, to be plotted
x_axis = np.linspace(*plt_range[:2], 1000)
w_axis = plt_range[2](x_axis)
if "disp_obj" in params:
D = params["disp_obj"].D_w(w_axis)
beta2 = params["disp_obj"].beta2_w(w_axis)
else:
print("no dispersion information given")
return
fig, (ax_D, ax_beta2) = plt.subplots(1, 2)
ax_D.plot(x_axis, 1e6 * D)
ax_D.plot(
x_axis,
0 * x_axis,
":",
c="k",
)
ax_D.set_xlabel(plt_range[2].label)
ax_D.set_ylabel(r"Dispersion parameter $D$ ($\frac{\mathrm{ps}}{\mathrm{nm\ km}}$)")
ax_beta2.plot(x_axis, 1e27 * beta2)
ax_beta2.plot(
x_axis,
0 * x_axis,
":",
c="k",
)
ax_beta2.set_xlabel(plt_range[2].label)
ax_beta2.set_ylabel(r"$\beta_2$ parameter ($\frac{\mathrm{ps}^2}{\mathrm{km}}$)")
plt.show()
def white_bottom_cmap(name, start=0, end=1, new_name="white_background", c_back=(1, 1, 1, 1)):
"""returns a new colormap based on "name" but that has a solid bacground (default=white)"""
top = plt.get_cmap(name, 1024)

View File

@@ -128,11 +128,11 @@ def main():
args.nodes, args.cpus_per_node = distribute(sim_num, args.nodes, args.cpus_per_node)
submit_path = Path(
"submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
"submit " + final_config.name + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
)
tmp_path = Path("submit tmp.sh")
job_name = f"supercontinuum {final_config['name']}"
job_name = f"supercontinuum {final_config.name}"
submit_sh = template.format(
job_name=job_name, configs_list=" ".join(f'"{c}"' for c in args.configs), **vars(args)
)

View File

@@ -1,16 +1,14 @@
import os
from collections.abc import Mapping, Sequence
from glob import glob
from typing import Any, Dict, List, Tuple
from collections.abc import Sequence
from pathlib import Path
from typing import Dict
import numpy as np
from scgenerator.const import SPECN_FN
from . import io, initialize, math
from .plotting import units
from . import initialize, io, math
from .const import SPECN_FN
from .logger import get_logger
from .plotting import units
class Spectrum(np.ndarray):
@@ -43,7 +41,7 @@ class Pulse(Sequence):
self.params = None
try:
self.params = io.load_previous_parameters(self.path / "params.toml")
self.params = io.load_params(self.path / "params.toml")
except FileNotFoundError:
self.logger.info(f"parameters corresponding to {self.path} not found")

View File

@@ -4,24 +4,23 @@ scgenerator module but some function may be used in any python program
"""
import collections
import itertools
import multiprocessing
import threading
import time
from collections import abc
from copy import deepcopy
from dataclasses import asdict, replace
from io import StringIO
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Tuple, TypeVar, Union
from typing import Any, Dict, Iterable, Iterator, List, Tuple, TypeVar, Union
import numpy as np
from tqdm import tqdm
from . import env
from .const import PARAM_SEPARATOR, valid_variable
from .math import *
from .. import env
from ..const import PARAM_SEPARATOR
from ..math import *
from .parameter import BareConfig, BareParams
T_ = TypeVar("T_")
@@ -177,18 +176,11 @@ def progress_worker(
pbars[0].update()
def count_variations(config: dict) -> Tuple[int, int]:
def count_variations(config: BareConfig) -> Tuple[int, int]:
"""returns (sim_num, variable_params_num) where sim_num is the total number of simulations required and
variable_params_num is the number of distinct parameters that will vary."""
sim_num = 1
variable_params_num = 0
for section_name in valid_variable:
for array in config.get(section_name, {}).get("variable", {}).values():
sim_num *= len(array)
variable_params_num += 1
sim_num *= config["simulation"].get("repeat", 1)
variable_params_num = len(config.variable)
sim_num = np.prod([len(l) for l in config.variable.values()]) * config.repeat
return sim_num, variable_params_num
@@ -217,49 +209,45 @@ def format_value(value):
return str(value)
def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
"""given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
Parameters
----------
config : dict
initial config dictionary
config : BareConfig
initial config obj
Yields
-------
Iterator[Tuple[List[Tuple[str, Any]], dict]]
Iterator[Tuple[List[Tuple[str, Any]], BareParams]]
variable_list : a list of (name, value) tuple of parameter name and value that are variable.
dict : a config dictionary for one simulation
params : a BareParams obj for one simulation
"""
indiv_config = deepcopy(config)
variable_dict = {
section_name: indiv_config.get(section_name, {}).pop("variable", {})
for section_name in valid_variable
}
possible_keys = []
possible_ranges = []
for section_name, section in variable_dict.items():
for key in section:
arr = variable_dict[section_name][key]
possible_keys.append((section_name, key))
possible_ranges.append(range(len(arr)))
for key, values in config.variable.items():
possible_keys.append(key)
possible_ranges.append(range(len(values)))
combinations = itertools.product(*possible_ranges)
for combination in combinations:
indiv_config = {}
variable_list = []
for i, key in enumerate(possible_keys):
parameter_value = variable_dict[key[0]][key[1]][combination[i]]
indiv_config[key[0]][key[1]] = parameter_value
variable_list.append((key[1], parameter_value))
yield variable_list, indiv_config
parameter_value = config.variable[key][combination[i]]
indiv_config[key] = parameter_value
variable_list.append((key, parameter_value))
param_dict = asdict(config)
param_dict.pop("variable")
param_dict.update(indiv_config)
yield variable_list, BareParams(**param_dict)
def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
def required_simulations(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
"""takes the output of `scgenerator.utils.variable_iterator` which is a new dict per different
parameter set and iterates through every single necessary simulation
@@ -273,48 +261,19 @@ def required_simulations(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]
dict : a config dictionary for one simulation
"""
i = 0 # unique sim id
for variable_only, full_config in variable_iterator(config):
for j in range(config["simulation"]["repeat"]):
for variable_only, bare_params in variable_iterator(config):
for j in range(config.repeat):
variable_ind = [("id", i)] + variable_only + [("num", j)]
i += 1
yield variable_ind, full_config
yield variable_ind, bare_params
def deep_update(d: Mapping, u: Mapping) -> dict:
for k, v in u.items():
if isinstance(v, collections.abc.Mapping):
d[k] = deep_update(d.get(k, {}), v)
else:
d[k] = v
return d
def override_config(new: Dict[str, Any], old: Dict[str, Any] = None) -> Dict[str, Any]:
def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
"""makes sure all the parameters set in new are there, leaves untouched parameters in old"""
if old is None:
return new
out = deepcopy(old)
for section_name, section in new.items():
if isinstance(section, Mapping):
for param_name, value in section.items():
if param_name == "variable" and isinstance(value, Mapping):
out[section_name].setdefault("variable", {})
for p, v in value.items():
# override previously unvariable param
if p in old[section_name]:
del out[section_name][p]
out[section_name]["variable"][p] = v
else:
# override previously variable param
if (
"variable" in old[section_name]
and isinstance(old[section_name]["variable"], Mapping)
and param_name in old[section_name]["variable"]
):
del out[section_name]["variable"][param_name]
if len(out[section_name]["variable"]) == 0:
del out[section_name["variable"]]
out[section_name][param_name] = value
else:
out[section_name] = section
return out
return BareConfig(**new)
variable = deepcopy(old.variable)
variable.update(new.pop("variable", {})) # add new variable
for k in new:
variable.pop(k) # remove old ones
return replace(old, variable=variable, **{k: None for k in variable}, **new)

View File

@@ -0,0 +1,442 @@
import datetime
from copy import copy
from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import numpy as np
from ..const import __version__
@lru_cache
def type_checker(*types):
def _type_checker_wrapper(validator, n=None):
if isinstance(validator, str) and n is not None:
_name = validator
validator = lambda *args: None
def _type_checker_wrapped(name, n):
if not isinstance(n, types):
raise TypeError(
f"{name!r} value must be of type {' or '.join(format(t) for t in types)} "
f"instead of {type(n)}"
)
validator(name, n)
if n is None:
return _type_checker_wrapped
else:
_type_checker_wrapped(_name, n)
return _type_checker_wrapper
@type_checker(str)
def string(name, n):
if len(n) == 0:
raise ValueError(f"{name!r} must not be empty")
def in_range_excl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
if n <= _min or n >= _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (exclusive)")
return _in_range
def in_range_incl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
if n < _min or n > _max:
raise ValueError(f"{name!r} must be between {_min} and {_max} (inclusive)")
return _in_range
def boolean(name, n):
if not n is True and not n is False:
raise ValueError(f"{name!r} must be True or False")
@lru_cache
def non_negative(*types):
@type_checker(*types)
def _non_negative(name, n):
if n < 0:
raise ValueError(f"{name!r} must be non negative")
return _non_negative
@lru_cache
def positive(*types):
@type_checker(*types)
def _positive(name, n):
if n <= 0:
raise ValueError(f"{name!r} must be positive")
return _positive
@type_checker(tuple, list)
def int_pair(name, t):
invalid = len(t) != 2
for m in t:
if invalid or not isinstance(m, int):
raise ValueError(f"{name!r} must be a list or a tuple of 2 int")
def literal(*l):
l = set(l)
@type_checker(str)
def _string(name, s):
if not s in l:
raise ValueError(f"{name!r} must be a str in {l}")
return _string
def validator_list(validator):
"""returns a new validator that applies validator to each el of an iterable"""
@type_checker(list, tuple)
def _list_validator(name, l):
for i, el in enumerate(l):
validator(name + f"[{i}]", el)
return _list_validator
def validator_or(*validators):
"""combines many validators and raises an exception only if all of them raise an exception"""
n = len(validators)
def _or_validator(name, value):
errors = []
for validator in validators:
try:
validator(name, value)
except (ValueError, TypeError) as e:
errors.append(e)
errors.sort(key=lambda el: isinstance(el, ValueError))
if len(errors) == n:
raise errors[-1]
return _or_validator
def validator_and(*validators):
def _and_validator(name, n):
for v in validators:
v(name, n)
return _and_validator
@type_checker(list, tuple, np.ndarray)
def num_list(name, l):
for i, el in enumerate(l):
type_checker(int, float)(name + f"[{i}]", el)
def func_validator(name, n):
if not callable(n):
raise TypeError(f"{name!r} must be callable")
class Parameter:
def __init__(self, validator, converter=None, default=None):
"""Single parameter
Parameters
----------
tpe : type
type of the paramter
validators : Callable[[str, Any], None]
signature : validator(name, value)
must raise a ValueError when value doesn't fit the criteria checked by
validator. name is passed to validator to be included in the error message
converter : Callable, optional
converts a valid value (for example, str.lower), by default None
default : callable, optional
factory function for a default value (for example, list), by default None
"""
self.validator = validator
self.converter = converter
self.default = default
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner):
if not instance:
return self
return instance.__dict__[self.name]
def __delete__(self, instance):
del instance.__dict__[self.name]
def __set__(self, instance, value):
if isinstance(value, Parameter):
defaut = None if self.default is None else copy(self.default)
instance.__dict__[self.name] = defaut
else:
if value is not None:
self.validator(self.name, value)
if self.converter is not None:
value = self.converter(value)
instance.__dict__[self.name] = value
class VariableParameter:
def __init__(self, parameterBase):
self.pbase = parameterBase
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner):
if not instance:
return self
return instance.__dict__[self.name]
def __delete__(self, instance):
del instance.__dict__[self.name]
def __set__(self, instance, value: dict):
if isinstance(value, VariableParameter):
value = {}
else:
for k, v in value.items():
if k not in valid_variable:
raise TypeError(f"{k!r} is not a valide variable parameter")
if len(v) == 0:
raise ValueError(f"variable parameter {k!r} must not be empty")
p = getattr(self.pbase, k)
for el in v:
p.validator(k, el)
instance.__dict__[self.name] = value
valid_variable = {
"beta",
"gamma",
"pitch",
"pitch_ratio",
"core_radius",
"capillary_num",
"capillary_outer_d",
"capillary_thickness",
"capillary_spacing",
"capillary_resonance_strengths",
"capillary_nested",
"he_mode",
"fit_parameters",
"input_transmission",
"n2",
"pressure",
"temperature",
"gas_name",
"plasma_density" "peak_power",
"mean_power",
"peak_power",
"energy",
"quantum_noise",
"shape",
"wavelength",
"intensity_noise",
"width",
"soliton_num",
"behaviors",
"raman_type",
"tolerated_error",
"step_size",
"ideal_gas",
"readjust_wavelength",
}
hc_model_specific_parameters = dict(
marcatili=["core_radius", "he_mode"],
marcatili_adjusted=["core_radius", "he_mode", "fit_parameters"],
hasan=[
"core_radius",
"capillary_num",
"capillary_thickness",
"capillary_resonance_strengths",
"capillary_nested",
"capillary_spacing",
"capillary_outer_d",
],
)
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
@dataclass
class BareParams:
"""
This class defines each valid parameter's name, type and valid value but doesn't provide
any method to act on those. For that, use initialize.Params
"""
# root
name: str = Parameter(string)
prev_data_dir: str = Parameter(string)
# # fiber
input_transmission: float = Parameter(in_range_incl(0, 1))
gamma: float = Parameter(non_negative(float, int))
n2: float = Parameter(non_negative(float, int))
effective_mode_diameter: float = Parameter(positive(float, int))
A_eff: float = Parameter(non_negative(float, int))
pitch: float = Parameter(in_range_excl(0, 1e-3))
pitch_ratio: float = Parameter(in_range_excl(0, 1))
core_radius: float = Parameter(in_range_excl(0, 1e-3))
he_mode: Tuple[int, int] = Parameter(int_pair)
fit_parameters: Tuple[int, int] = Parameter(int_pair)
beta: Iterable[float] = Parameter(num_list)
dispersion_file: str = Parameter(string)
model: str = Parameter(literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"))
length: float = Parameter(non_negative(float, int))
capillary_num: int = Parameter(positive(int))
capillary_outer_d: float = Parameter(in_range_excl(0, 1e-3))
capillary_thickness: float = Parameter(in_range_excl(0, 1e-3))
capillary_spacing: float = Parameter(in_range_excl(0, 1e-3))
capillary_resonance_strengths: Iterable[float] = Parameter(num_list)
capillary_nested: int = Parameter(non_negative(int))
# gas
gas_name: str = Parameter(literal("vacuum", "helium", "air"), converter=str.lower)
pressure: Union[float, Iterable[float]] = Parameter(
validator_or(non_negative(float, int), num_list)
)
temperature: float = Parameter(positive(float, int))
plasma_density: float = Parameter(non_negative(float, int))
# pulse
field_file: str = Parameter(string)
repetition_rate: float = Parameter(non_negative(float, int))
peak_power: float = Parameter(positive(float, int))
mean_power: float = Parameter(positive(float, int))
energy: float = Parameter(positive(float, int))
soliton_num: float = Parameter(positive(float, int))
quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9))
intensity_noise: float = Parameter(in_range_incl(0, 1))
width: float = Parameter(in_range_excl(0, 1e-9))
t0: float = Parameter(in_range_excl(0, 1e-9))
# simulation
behaviors: str = Parameter(validator_list(literal("spm", "raman", "ss")))
parallel: bool = Parameter(boolean)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
ideal_gas: bool = Parameter(boolean)
repeat: int = Parameter(positive(int))
t_num: int = Parameter(positive(int))
z_num: int = Parameter(positive(int))
time_window: float = Parameter(positive(float, int))
dt: float = Parameter(in_range_excl(0, 5e-15))
tolerated_error: float = Parameter(in_range_excl(1e-15, 1e-5))
step_size: float = Parameter(positive(float, int))
lower_wavelength_interp_limit: float = Parameter(in_range_incl(100e-9, 3000e-9))
upper_wavelength_interp_limit: float = Parameter(in_range_incl(200e-9, 5000e-9))
frep: float = Parameter(positive(float, int))
prev_sim_dir: str = Parameter(string)
readjust_wavelength: bool = Parameter(boolean)
recovery_last_stored: int = Parameter(non_negative(int))
# computed
field_0: np.ndarray = Parameter(type_checker(np.ndarray))
spec_0: np.ndarray = Parameter(type_checker(np.ndarray))
w: np.ndarray = Parameter(type_checker(np.ndarray))
w_c: np.ndarray = Parameter(type_checker(np.ndarray))
t: np.ndarray = Parameter(type_checker(np.ndarray))
L_D: float = Parameter(non_negative(float, int))
L_NL: float = Parameter(non_negative(float, int))
L_sol: float = Parameter(non_negative(float, int))
dynamic_dispersion: bool = Parameter(boolean)
adapt_step_size: bool = Parameter(boolean)
error_ok: float = Parameter(positive(float))
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
beta_func: Callable[[float], List[float]] = Parameter(func_validator)
gamma_func: Callable[[float], float] = Parameter(func_validator)
def prepare_for_dump(self) -> Dict[str, Any]:
param = asdict(self)
param = BareParams.strip_params_dict(param)
param["datetime"] = datetime.datetime.now()
param["version"] = __version__
return param
@staticmethod
def strip_params_dict(dico: Dict[str, Any]) -> Dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved
(dropped because they take a lot of space and can be exactly reconstructed)
Parameters
----------
dico : dict
dictionary
"""
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
types = (np.ndarray, float, int, str, list, tuple, dict)
out = {}
for key, value in dico.items():
if key in forbiden_keys:
continue
if not isinstance(value, types):
continue
if isinstance(value, dict):
out[key] = BareParams.strip_params_dict(value)
elif isinstance(value, np.ndarray) and value.dtype == complex:
continue
else:
out[key] = value
if "variable" in out and len(out["variable"]) == 0:
del out["variable"]
return out
@dataclass
class BareConfig(BareParams):
variable: dict = VariableParameter(BareParams)
if __name__ == "__main__":
numero = type_checker(int)
@numero
def natural_number(name, n):
if n < 0:
raise ValueError(f"{name!r} must be positive")
try:
numero("a", np.arange(45))
except Exception as e:
print(e)
try:
natural_number("b", -1)
except Exception as e:
print(e)
try:
natural_number("c", 1.0)
except Exception as e:
print(e)
try:
natural_number("d", 1)
print("success !")
except Exception as e:
print(e)

View File

@@ -0,0 +1,37 @@
from numba.core import config
from scgenerator.initialize import Config, Params, BareParams
from scgenerator.utils import variable_iterator, override_config
from scgenerator.io import load_toml
from pprint import pprint
from dataclasses import asdict
dico = load_toml("testing/configs/ensure_consistency/good2.toml")
out = dict(variable=dict())
for k, v in dico.items():
if isinstance(v, dict):
for kk, vv in v.items():
if kk == "variable":
for kkk, vvv in vv.items():
out["variable"][kkk] = vvv
else:
out[kk] = vv
pprint(out)
p = Config(**out)
print(p)
for l, c in variable_iterator(p):
print(l, c.width, c.intensity_noise)
print()
config2 = override_config(dict(width=1.2e-13, variable=dict(peak_power=[1e5, 2e5])), p)
print(
f"{config2.variable=}",
f"{config2.intensity_noise=}",
f"{config2.width=}",
f"{config2.peak_power=}",
)
par = BareParams()
print(all(v is None for v in vars(par).values()))