compute function written

This commit is contained in:
Benoît Sierro
2021-08-27 09:28:45 +02:00
parent 4bcf5add60
commit 6c869d5c6c
6 changed files with 78 additions and 71 deletions

View File

@@ -6,31 +6,32 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
from collections import defaultdict
import numpy as np
from numpy import pi
from . import io, utils
from .defaults import default_parameters
from .errors import *
from .logger import get_logger
from .math import power_fact
from .physics import fiber, pulse, units
from .utils import override_config, required_simulations, evaluator
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
global_evaluator = evaluator.Evaluator()
from .utils import override_config, required_simulations
from .utils.evaluator import Evaluator
from .utils.parameter import (
BareConfig,
BareParams,
hc_model_specific_parameters,
mandatory_parameters,
)
@dataclass
class Params(BareParams):
@classmethod
def from_bare(cls, bare: BareParams):
return cls(**asdict(bare))
def __post_init__(self):
self.compute()
def compute(self):
logger = get_logger(__name__)
param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
evaluator = Evaluator.default()
evaluator.set(**param_dict)
for p_name in mandatory_parameters:
evaluator.compute(p_name)
new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
return cls(**new_param_dict)
@dataclass

View File

@@ -12,7 +12,6 @@ from ..logger import get_logger
from .. import io
from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
from ..utils.cache import np_cache
from . import materials as mat
from . import units

View File

@@ -30,7 +30,6 @@ from ..defaults import default_plotting
from ..logger import get_logger
from ..math import *
from ..plotting import plot_setup
from ..utils.parameter import BareParams
from . import units
c = 299792458.0
@@ -343,54 +342,6 @@ def load_field_file(
return field_0, peak_power, energy, width
def setup_custom_field(params: BareParams) -> bool:
"""sets up a custom field function if necessary and returns
True if it did so, False otherwise
Parameters
----------
params : Dict[str, Any]
params dictionary
Returns
-------
bool
True if the field has been modified
"""
field_0 = params.field_0
width = params.width
peak_power = params.peak_power
energy = params.energy
did_set = True
if params.prev_data_dir is not None:
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
field_0 = np.fft.ifft(spec)
elif params.field_file is not None:
field_data = np.load(params.field_file)
field_interp = interp1d(
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
)
field_0 = field_interp(params.t)
field_0 = field_0 * modify_field_ratio(
params.t,
field_0,
params.peak_power,
params.energy,
params.intensity_noise,
)
width, peak_power, energy = measure_field(params.t, field_0)
else:
did_set = False
if did_set:
field_0 = field_0 * np.sqrt(params.input_transmission)
return did_set, width, peak_power, energy, field_0
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
"""
finds a new wavelength parameter such that the maximum of the spectrum corresponding
@@ -481,6 +432,14 @@ def shot_noise(w_c, w0, T, dt):
return out
def add_shot_noise(
field_0: np.ndarray, quantum_noise: bool, w_c: bool, w0: float, time_window: float, dt: float
) -> np.ndarray:
if quantum_noise:
field_0 = field_0 + shot_noise(w_c, w0, time_window, dt)
return field_0
def mean_phase(spectra):
"""computes the mean phase of spectra
Parameter

View File

@@ -80,7 +80,7 @@ class RK4IP:
self.raman_type = params.raman_type
self.hr_w = params.hr_w
self.adapt_step_size = params.adapt_step_size
self.error_ok = params.error_ok
self.error_ok = params.tolerated_error
self.dynamic_dispersion = params.dynamic_dispersion
self.starting_num = params.recovery_last_stored

View File

@@ -97,6 +97,12 @@ class EvalStat:
class Evaluator:
@classmethod
def default(cls) -> "Evaluator":
evaluator = cls()
evaluator.append(*default_rules)
return evaluator
def __init__(self):
self.rules: dict[str, list[Rule]] = defaultdict(list)
self.params = {}
@@ -111,7 +117,7 @@ class Evaluator:
self.rules[t].append(r)
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
def update(self, **params: Any):
def set(self, **params: Any):
self.params.update(params)
for k in params:
self.eval_stats[k].priority = np.inf
@@ -256,9 +262,29 @@ default_rules: list[Rule] = [
Rule("field_0", np.fft.ifft, ["spec_0"]),
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
Rule(
["field_0", "peak_power", "energy", "width"], pulse.load_field_file, priorities=[2, 1, 1, 1]
["pre_field_0", "peak_power", "energy", "width"],
pulse.load_field_file,
[
"field_file",
"t",
"peak_power",
"energy",
"intensity_noise",
"noise_correlation",
"quantum_noise",
"w_c",
"w0",
"time_window",
"dt",
],
priorities=[2, 1, 1, 1],
),
Rule("pre_field_0", pulse.initial_field, priorities=1),
Rule(
"field_0",
pulse.add_shot_noise,
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
),
Rule("field_0", pulse.initial_field, priorities=1),
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
Rule("peak_power", pulse.soliton_num_to_peak_power),
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
@@ -329,7 +355,7 @@ def main():
evalor = Evaluator()
evalor.append(*default_rules)
evalor.update(
evalor.set(
**{
"length": 1,
"z_num": 128,
@@ -343,8 +369,9 @@ def main():
"width": 30e-15,
"mean_power": 100e-3,
"n2": 2.4e-20,
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_max.npz",
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
"model": "pcf",
"quantum_noise": True,
"pitch": 1.2e-6,
"pitch_ratio": 0.5,
}
@@ -354,6 +381,7 @@ def main():
print(evalor.params["l"][evalor.params["l"] > 0].min())
evalor.compute("spec_0")
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
plt.yscale("log")
plt.show()
print(evalor.compute("gamma"))
print(evalor.compute("beta2"))

View File

@@ -342,6 +342,27 @@ hc_model_specific_parameters = dict(
)
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
mandatory_parameters = [
"name",
"w_c",
"w",
"w0",
"w_power_fact",
"alpha",
"spec_0",
"z_targets",
"length",
"beta2_coefficients",
"gamma_arr",
"behaviors",
"raman_type",
"hr_w",
"adapt_step_size",
"tollerated_error",
"dynamic_dispersion",
"recovery_last_stored",
]
@dataclass
class BareParams:
@@ -445,7 +466,6 @@ class BareParams:
L_sol: float = Parameter(non_negative(float, int))
dynamic_dispersion: bool = Parameter(boolean)
adapt_step_size: bool = Parameter(boolean)
error_ok: float = Parameter(positive(float))
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))