compute function written
This commit is contained in:
@@ -6,31 +6,32 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
|
||||
from . import io, utils
|
||||
from .defaults import default_parameters
|
||||
from .errors import *
|
||||
from .logger import get_logger
|
||||
from .math import power_fact
|
||||
from .physics import fiber, pulse, units
|
||||
from .utils import override_config, required_simulations, evaluator
|
||||
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
||||
|
||||
global_evaluator = evaluator.Evaluator()
|
||||
from .utils import override_config, required_simulations
|
||||
from .utils.evaluator import Evaluator
|
||||
from .utils.parameter import (
|
||||
BareConfig,
|
||||
BareParams,
|
||||
hc_model_specific_parameters,
|
||||
mandatory_parameters,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Params(BareParams):
|
||||
@classmethod
|
||||
def from_bare(cls, bare: BareParams):
|
||||
return cls(**asdict(bare))
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute()
|
||||
|
||||
def compute(self):
|
||||
logger = get_logger(__name__)
|
||||
param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
|
||||
evaluator = Evaluator.default()
|
||||
evaluator.set(**param_dict)
|
||||
for p_name in mandatory_parameters:
|
||||
evaluator.compute(p_name)
|
||||
new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
|
||||
return cls(**new_param_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,6 @@ from ..logger import get_logger
|
||||
|
||||
from .. import io
|
||||
from ..math import abs2, argclosest, power_fact, u_nm
|
||||
from ..utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
||||
from ..utils.cache import np_cache
|
||||
from . import materials as mat
|
||||
from . import units
|
||||
|
||||
@@ -30,7 +30,6 @@ from ..defaults import default_plotting
|
||||
from ..logger import get_logger
|
||||
from ..math import *
|
||||
from ..plotting import plot_setup
|
||||
from ..utils.parameter import BareParams
|
||||
from . import units
|
||||
|
||||
c = 299792458.0
|
||||
@@ -343,54 +342,6 @@ def load_field_file(
|
||||
return field_0, peak_power, energy, width
|
||||
|
||||
|
||||
def setup_custom_field(params: BareParams) -> bool:
|
||||
"""sets up a custom field function if necessary and returns
|
||||
True if it did so, False otherwise
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : Dict[str, Any]
|
||||
params dictionary
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the field has been modified
|
||||
"""
|
||||
field_0 = params.field_0
|
||||
width = params.width
|
||||
peak_power = params.peak_power
|
||||
energy = params.energy
|
||||
|
||||
did_set = True
|
||||
|
||||
if params.prev_data_dir is not None:
|
||||
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
|
||||
field_0 = np.fft.ifft(spec)
|
||||
elif params.field_file is not None:
|
||||
field_data = np.load(params.field_file)
|
||||
field_interp = interp1d(
|
||||
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
||||
)
|
||||
field_0 = field_interp(params.t)
|
||||
|
||||
field_0 = field_0 * modify_field_ratio(
|
||||
params.t,
|
||||
field_0,
|
||||
params.peak_power,
|
||||
params.energy,
|
||||
params.intensity_noise,
|
||||
)
|
||||
width, peak_power, energy = measure_field(params.t, field_0)
|
||||
else:
|
||||
did_set = False
|
||||
|
||||
if did_set:
|
||||
field_0 = field_0 * np.sqrt(params.input_transmission)
|
||||
|
||||
return did_set, width, peak_power, energy, field_0
|
||||
|
||||
|
||||
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
|
||||
"""
|
||||
finds a new wavelength parameter such that the maximum of the spectrum corresponding
|
||||
@@ -481,6 +432,14 @@ def shot_noise(w_c, w0, T, dt):
|
||||
return out
|
||||
|
||||
|
||||
def add_shot_noise(
|
||||
field_0: np.ndarray, quantum_noise: bool, w_c: bool, w0: float, time_window: float, dt: float
|
||||
) -> np.ndarray:
|
||||
if quantum_noise:
|
||||
field_0 = field_0 + shot_noise(w_c, w0, time_window, dt)
|
||||
return field_0
|
||||
|
||||
|
||||
def mean_phase(spectra):
|
||||
"""computes the mean phase of spectra
|
||||
Parameter
|
||||
|
||||
@@ -80,7 +80,7 @@ class RK4IP:
|
||||
self.raman_type = params.raman_type
|
||||
self.hr_w = params.hr_w
|
||||
self.adapt_step_size = params.adapt_step_size
|
||||
self.error_ok = params.error_ok
|
||||
self.error_ok = params.tolerated_error
|
||||
self.dynamic_dispersion = params.dynamic_dispersion
|
||||
self.starting_num = params.recovery_last_stored
|
||||
|
||||
|
||||
@@ -97,6 +97,12 @@ class EvalStat:
|
||||
|
||||
|
||||
class Evaluator:
|
||||
@classmethod
|
||||
def default(cls) -> "Evaluator":
|
||||
evaluator = cls()
|
||||
evaluator.append(*default_rules)
|
||||
return evaluator
|
||||
|
||||
def __init__(self):
|
||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||
self.params = {}
|
||||
@@ -111,7 +117,7 @@ class Evaluator:
|
||||
self.rules[t].append(r)
|
||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||
|
||||
def update(self, **params: Any):
|
||||
def set(self, **params: Any):
|
||||
self.params.update(params)
|
||||
for k in params:
|
||||
self.eval_stats[k].priority = np.inf
|
||||
@@ -256,9 +262,29 @@ default_rules: list[Rule] = [
|
||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
|
||||
Rule(
|
||||
["field_0", "peak_power", "energy", "width"], pulse.load_field_file, priorities=[2, 1, 1, 1]
|
||||
["pre_field_0", "peak_power", "energy", "width"],
|
||||
pulse.load_field_file,
|
||||
[
|
||||
"field_file",
|
||||
"t",
|
||||
"peak_power",
|
||||
"energy",
|
||||
"intensity_noise",
|
||||
"noise_correlation",
|
||||
"quantum_noise",
|
||||
"w_c",
|
||||
"w0",
|
||||
"time_window",
|
||||
"dt",
|
||||
],
|
||||
priorities=[2, 1, 1, 1],
|
||||
),
|
||||
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||
Rule(
|
||||
"field_0",
|
||||
pulse.add_shot_noise,
|
||||
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
|
||||
),
|
||||
Rule("field_0", pulse.initial_field, priorities=1),
|
||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||
@@ -329,7 +355,7 @@ def main():
|
||||
|
||||
evalor = Evaluator()
|
||||
evalor.append(*default_rules)
|
||||
evalor.update(
|
||||
evalor.set(
|
||||
**{
|
||||
"length": 1,
|
||||
"z_num": 128,
|
||||
@@ -343,8 +369,9 @@ def main():
|
||||
"width": 30e-15,
|
||||
"mean_power": 100e-3,
|
||||
"n2": 2.4e-20,
|
||||
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_max.npz",
|
||||
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
|
||||
"model": "pcf",
|
||||
"quantum_noise": True,
|
||||
"pitch": 1.2e-6,
|
||||
"pitch_ratio": 0.5,
|
||||
}
|
||||
@@ -354,6 +381,7 @@ def main():
|
||||
print(evalor.params["l"][evalor.params["l"] > 0].min())
|
||||
evalor.compute("spec_0")
|
||||
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
|
||||
plt.yscale("log")
|
||||
plt.show()
|
||||
print(evalor.compute("gamma"))
|
||||
print(evalor.compute("beta2"))
|
||||
|
||||
@@ -342,6 +342,27 @@ hc_model_specific_parameters = dict(
|
||||
)
|
||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
||||
|
||||
mandatory_parameters = [
|
||||
"name",
|
||||
"w_c",
|
||||
"w",
|
||||
"w0",
|
||||
"w_power_fact",
|
||||
"alpha",
|
||||
"spec_0",
|
||||
"z_targets",
|
||||
"length",
|
||||
"beta2_coefficients",
|
||||
"gamma_arr",
|
||||
"behaviors",
|
||||
"raman_type",
|
||||
"hr_w",
|
||||
"adapt_step_size",
|
||||
"tollerated_error",
|
||||
"dynamic_dispersion",
|
||||
"recovery_last_stored",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BareParams:
|
||||
@@ -445,7 +466,6 @@ class BareParams:
|
||||
L_sol: float = Parameter(non_negative(float, int))
|
||||
dynamic_dispersion: bool = Parameter(boolean)
|
||||
adapt_step_size: bool = Parameter(boolean)
|
||||
error_ok: float = Parameter(positive(float))
|
||||
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
|
||||
Reference in New Issue
Block a user