adjust wl

This commit is contained in:
Benoît Sierro
2021-05-28 16:02:20 +02:00
parent ce9a11e16e
commit a50e14f765
2 changed files with 16 additions and 5 deletions

View File

@@ -1,6 +1,6 @@
import os import os
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Dict, Iterator, List, Set, Tuple from typing import Any, Dict, Iterator, List, Set, Tuple, Union
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -12,13 +12,15 @@ from . import defaults, io, utils
from .const import hc_model_specific_parameters, valid_param_types, valid_variable from .const import hc_model_specific_parameters, valid_param_types, valid_variable
from .errors import * from .errors import *
from .logger import get_logger from .logger import get_logger
from .math import length, power_fact from .math import abs2, length, power_fact
from .physics import fiber, pulse, units from .physics import fiber, pulse, units
from .utils import count_variations, override_config, required_simulations from .utils import count_variations, override_config, required_simulations
class ParamSequence(Mapping): class ParamSequence(Mapping):
def __init__(self, config): def __init__(self, config: Union[Dict[str, Any], os.PathLike]):
if not isinstance(config, Mapping):
config = io.load_toml(config)
self.config = validate(config) self.config = validate(config)
self.name = self.config["name"] self.name = self.config["name"]
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
@@ -598,6 +600,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
if "mean_power" in params: if "mean_power" in params:
params["energy"] = params["mean_power"] / params["repetition_rate"] params["energy"] = params["mean_power"] / params["repetition_rate"]
custom_field = True
if "field_file" in params: if "field_file" in params:
field_data = np.load(params["field_file"]) field_data = np.load(params["field_file"])
field_interp = interp1d( field_interp = interp1d(
@@ -610,6 +613,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
params = _evalutate_custom_field_equation(params) params = _evalutate_custom_field_equation(params)
params = _comform_custom_field(params) params = _comform_custom_field(params)
else: else:
custom_field = False
params = _update_pulse_parameters(params) params = _update_pulse_parameters(params)
logger.info(f"computed initial N = {params['soliton_num']:.3g}") logger.info(f"computed initial N = {params['soliton_num']:.3g}")
@@ -632,6 +636,12 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
params["spec_0"] = np.fft.fft(params["field_0"]) params["spec_0"] = np.fft.fft(params["field_0"])
# central wavelength may be off with custom fields
if custom_field:
delta_w = params["w_c"][np.argmax(abs2(params["spec_0"]))]
logger.debug(f"had to adjust w by {delta_w}")
params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w)
_update_frequency_domain(params)
return params return params
@@ -655,7 +665,7 @@ def _comform_custom_field(params):
params["width"], params["peak_power"], params["energy"] = pulse.measure_field( params["width"], params["peak_power"], params["energy"] = pulse.measure_field(
params["t"], params["field_0"] params["t"], params["field_0"]
) )
wl = params["wavelength"]
return params return params

View File

@@ -4,6 +4,7 @@ import matplotlib.gridspec as gs
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from scgenerator.utils import variable_iterator
from scipy.interpolate import UnivariateSpline from scipy.interpolate import UnivariateSpline
from . import io, math from . import io, math
@@ -703,7 +704,7 @@ def plot_results_1D(
if is_new_plot: if is_new_plot:
fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200) fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200)
print(f"plot saved in {os.path.join(folder_name, file_name)}") print(f"plot saved in {os.path.join(folder_name, file_name)}")
return fig, ax return fig, ax, x_axis, values
def _prep_plot(values, plt_range, params): def _prep_plot(values, plt_range, params):