compute function written
This commit is contained in:
@@ -6,31 +6,32 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy import pi
|
|
||||||
|
|
||||||
from . import io, utils
|
from . import io, utils
|
||||||
from .defaults import default_parameters
|
from .defaults import default_parameters
|
||||||
from .errors import *
|
from .errors import *
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .math import power_fact
|
from .utils import override_config, required_simulations
|
||||||
from .physics import fiber, pulse, units
|
from .utils.evaluator import Evaluator
|
||||||
from .utils import override_config, required_simulations, evaluator
|
from .utils.parameter import (
|
||||||
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
BareConfig,
|
||||||
|
BareParams,
|
||||||
global_evaluator = evaluator.Evaluator()
|
hc_model_specific_parameters,
|
||||||
|
mandatory_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params(BareParams):
|
class Params(BareParams):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bare(cls, bare: BareParams):
|
def from_bare(cls, bare: BareParams):
|
||||||
return cls(**asdict(bare))
|
param_dict = {k: v for k, v in asdict(bare).items() if v is not None}
|
||||||
|
evaluator = Evaluator.default()
|
||||||
def __post_init__(self):
|
evaluator.set(**param_dict)
|
||||||
self.compute()
|
for p_name in mandatory_parameters:
|
||||||
|
evaluator.compute(p_name)
|
||||||
def compute(self):
|
new_param_dict = {k: v for k, v in evaluator.params.items() if k in param_dict}
|
||||||
logger = get_logger(__name__)
|
return cls(**new_param_dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from ..logger import get_logger
|
|||||||
|
|
||||||
from .. import io
|
from .. import io
|
||||||
from ..math import abs2, argclosest, power_fact, u_nm
|
from ..math import abs2, argclosest, power_fact, u_nm
|
||||||
from ..utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
|
|
||||||
from ..utils.cache import np_cache
|
from ..utils.cache import np_cache
|
||||||
from . import materials as mat
|
from . import materials as mat
|
||||||
from . import units
|
from . import units
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from ..defaults import default_plotting
|
|||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..math import *
|
from ..math import *
|
||||||
from ..plotting import plot_setup
|
from ..plotting import plot_setup
|
||||||
from ..utils.parameter import BareParams
|
|
||||||
from . import units
|
from . import units
|
||||||
|
|
||||||
c = 299792458.0
|
c = 299792458.0
|
||||||
@@ -343,54 +342,6 @@ def load_field_file(
|
|||||||
return field_0, peak_power, energy, width
|
return field_0, peak_power, energy, width
|
||||||
|
|
||||||
|
|
||||||
def setup_custom_field(params: BareParams) -> bool:
|
|
||||||
"""sets up a custom field function if necessary and returns
|
|
||||||
True if it did so, False otherwise
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
params : Dict[str, Any]
|
|
||||||
params dictionary
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the field has been modified
|
|
||||||
"""
|
|
||||||
field_0 = params.field_0
|
|
||||||
width = params.width
|
|
||||||
peak_power = params.peak_power
|
|
||||||
energy = params.energy
|
|
||||||
|
|
||||||
did_set = True
|
|
||||||
|
|
||||||
if params.prev_data_dir is not None:
|
|
||||||
spec = io.load_last_spectrum(Path(params.prev_data_dir))[1]
|
|
||||||
field_0 = np.fft.ifft(spec)
|
|
||||||
elif params.field_file is not None:
|
|
||||||
field_data = np.load(params.field_file)
|
|
||||||
field_interp = interp1d(
|
|
||||||
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
|
||||||
)
|
|
||||||
field_0 = field_interp(params.t)
|
|
||||||
|
|
||||||
field_0 = field_0 * modify_field_ratio(
|
|
||||||
params.t,
|
|
||||||
field_0,
|
|
||||||
params.peak_power,
|
|
||||||
params.energy,
|
|
||||||
params.intensity_noise,
|
|
||||||
)
|
|
||||||
width, peak_power, energy = measure_field(params.t, field_0)
|
|
||||||
else:
|
|
||||||
did_set = False
|
|
||||||
|
|
||||||
if did_set:
|
|
||||||
field_0 = field_0 * np.sqrt(params.input_transmission)
|
|
||||||
|
|
||||||
return did_set, width, peak_power, energy, field_0
|
|
||||||
|
|
||||||
|
|
||||||
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
|
def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndarray) -> float:
|
||||||
"""
|
"""
|
||||||
finds a new wavelength parameter such that the maximum of the spectrum corresponding
|
finds a new wavelength parameter such that the maximum of the spectrum corresponding
|
||||||
@@ -481,6 +432,14 @@ def shot_noise(w_c, w0, T, dt):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def add_shot_noise(
|
||||||
|
field_0: np.ndarray, quantum_noise: bool, w_c: bool, w0: float, time_window: float, dt: float
|
||||||
|
) -> np.ndarray:
|
||||||
|
if quantum_noise:
|
||||||
|
field_0 = field_0 + shot_noise(w_c, w0, time_window, dt)
|
||||||
|
return field_0
|
||||||
|
|
||||||
|
|
||||||
def mean_phase(spectra):
|
def mean_phase(spectra):
|
||||||
"""computes the mean phase of spectra
|
"""computes the mean phase of spectra
|
||||||
Parameter
|
Parameter
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class RK4IP:
|
|||||||
self.raman_type = params.raman_type
|
self.raman_type = params.raman_type
|
||||||
self.hr_w = params.hr_w
|
self.hr_w = params.hr_w
|
||||||
self.adapt_step_size = params.adapt_step_size
|
self.adapt_step_size = params.adapt_step_size
|
||||||
self.error_ok = params.error_ok
|
self.error_ok = params.tolerated_error
|
||||||
self.dynamic_dispersion = params.dynamic_dispersion
|
self.dynamic_dispersion = params.dynamic_dispersion
|
||||||
self.starting_num = params.recovery_last_stored
|
self.starting_num = params.recovery_last_stored
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,12 @@ class EvalStat:
|
|||||||
|
|
||||||
|
|
||||||
class Evaluator:
|
class Evaluator:
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "Evaluator":
|
||||||
|
evaluator = cls()
|
||||||
|
evaluator.append(*default_rules)
|
||||||
|
return evaluator
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
self.rules: dict[str, list[Rule]] = defaultdict(list)
|
||||||
self.params = {}
|
self.params = {}
|
||||||
@@ -111,7 +117,7 @@ class Evaluator:
|
|||||||
self.rules[t].append(r)
|
self.rules[t].append(r)
|
||||||
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
self.rules[t].sort(key=lambda el: el.targets[t], reverse=True)
|
||||||
|
|
||||||
def update(self, **params: Any):
|
def set(self, **params: Any):
|
||||||
self.params.update(params)
|
self.params.update(params)
|
||||||
for k in params:
|
for k in params:
|
||||||
self.eval_stats[k].priority = np.inf
|
self.eval_stats[k].priority = np.inf
|
||||||
@@ -256,9 +262,29 @@ default_rules: list[Rule] = [
|
|||||||
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
Rule("field_0", np.fft.ifft, ["spec_0"]),
|
||||||
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
|
Rule("spec_0", pulse.load_previous_spectrum, priorities=3),
|
||||||
Rule(
|
Rule(
|
||||||
["field_0", "peak_power", "energy", "width"], pulse.load_field_file, priorities=[2, 1, 1, 1]
|
["pre_field_0", "peak_power", "energy", "width"],
|
||||||
|
pulse.load_field_file,
|
||||||
|
[
|
||||||
|
"field_file",
|
||||||
|
"t",
|
||||||
|
"peak_power",
|
||||||
|
"energy",
|
||||||
|
"intensity_noise",
|
||||||
|
"noise_correlation",
|
||||||
|
"quantum_noise",
|
||||||
|
"w_c",
|
||||||
|
"w0",
|
||||||
|
"time_window",
|
||||||
|
"dt",
|
||||||
|
],
|
||||||
|
priorities=[2, 1, 1, 1],
|
||||||
|
),
|
||||||
|
Rule("pre_field_0", pulse.initial_field, priorities=1),
|
||||||
|
Rule(
|
||||||
|
"field_0",
|
||||||
|
pulse.add_shot_noise,
|
||||||
|
["pre_field_0", "quantum_noise", "w_c", "w0", "time_window", "dt"],
|
||||||
),
|
),
|
||||||
Rule("field_0", pulse.initial_field, priorities=1),
|
|
||||||
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]),
|
||||||
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
Rule("peak_power", pulse.soliton_num_to_peak_power),
|
||||||
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]),
|
||||||
@@ -329,7 +355,7 @@ def main():
|
|||||||
|
|
||||||
evalor = Evaluator()
|
evalor = Evaluator()
|
||||||
evalor.append(*default_rules)
|
evalor.append(*default_rules)
|
||||||
evalor.update(
|
evalor.set(
|
||||||
**{
|
**{
|
||||||
"length": 1,
|
"length": 1,
|
||||||
"z_num": 128,
|
"z_num": 128,
|
||||||
@@ -343,8 +369,9 @@ def main():
|
|||||||
"width": 30e-15,
|
"width": 30e-15,
|
||||||
"mean_power": 100e-3,
|
"mean_power": 100e-3,
|
||||||
"n2": 2.4e-20,
|
"n2": 2.4e-20,
|
||||||
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_max.npz",
|
"A_eff_file": "/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PM2000D/PM2000D_A_eff_marcuse.npz",
|
||||||
"model": "pcf",
|
"model": "pcf",
|
||||||
|
"quantum_noise": True,
|
||||||
"pitch": 1.2e-6,
|
"pitch": 1.2e-6,
|
||||||
"pitch_ratio": 0.5,
|
"pitch_ratio": 0.5,
|
||||||
}
|
}
|
||||||
@@ -354,6 +381,7 @@ def main():
|
|||||||
print(evalor.params["l"][evalor.params["l"] > 0].min())
|
print(evalor.params["l"][evalor.params["l"] > 0].min())
|
||||||
evalor.compute("spec_0")
|
evalor.compute("spec_0")
|
||||||
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
|
plt.plot(evalor.params["l"], abs(evalor.params["spec_0"]) ** 2)
|
||||||
|
plt.yscale("log")
|
||||||
plt.show()
|
plt.show()
|
||||||
print(evalor.compute("gamma"))
|
print(evalor.compute("gamma"))
|
||||||
print(evalor.compute("beta2"))
|
print(evalor.compute("beta2"))
|
||||||
|
|||||||
@@ -342,6 +342,27 @@ hc_model_specific_parameters = dict(
|
|||||||
)
|
)
|
||||||
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
"""dependecy map only includes actual fiber parameters and exclude gas parameters"""
|
||||||
|
|
||||||
|
mandatory_parameters = [
|
||||||
|
"name",
|
||||||
|
"w_c",
|
||||||
|
"w",
|
||||||
|
"w0",
|
||||||
|
"w_power_fact",
|
||||||
|
"alpha",
|
||||||
|
"spec_0",
|
||||||
|
"z_targets",
|
||||||
|
"length",
|
||||||
|
"beta2_coefficients",
|
||||||
|
"gamma_arr",
|
||||||
|
"behaviors",
|
||||||
|
"raman_type",
|
||||||
|
"hr_w",
|
||||||
|
"adapt_step_size",
|
||||||
|
"tollerated_error",
|
||||||
|
"dynamic_dispersion",
|
||||||
|
"recovery_last_stored",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BareParams:
|
class BareParams:
|
||||||
@@ -445,7 +466,6 @@ class BareParams:
|
|||||||
L_sol: float = Parameter(non_negative(float, int))
|
L_sol: float = Parameter(non_negative(float, int))
|
||||||
dynamic_dispersion: bool = Parameter(boolean)
|
dynamic_dispersion: bool = Parameter(boolean)
|
||||||
adapt_step_size: bool = Parameter(boolean)
|
adapt_step_size: bool = Parameter(boolean)
|
||||||
error_ok: float = Parameter(positive(float))
|
|
||||||
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
hr_w: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
z_targets: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
const_qty: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
|
|||||||
Reference in New Issue
Block a user