From a50e14f765c9c19e126e5f0a611e9d29b9917b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Fri, 28 May 2021 16:02:20 +0200 Subject: [PATCH] adjust wl --- src/scgenerator/initialize.py | 18 ++++++++++++++---- src/scgenerator/plotting.py | 3 ++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index be9ba4d..feeb007 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -1,6 +1,6 @@ import os from collections.abc import Mapping -from typing import Any, Dict, Iterator, List, Set, Tuple +from typing import Any, Dict, Iterator, List, Set, Tuple, Union import numpy as np from numpy import pi @@ -12,13 +12,15 @@ from . import defaults, io, utils from .const import hc_model_specific_parameters, valid_param_types, valid_variable from .errors import * from .logger import get_logger -from .math import length, power_fact +from .math import abs2, length, power_fact from .physics import fiber, pulse, units from .utils import count_variations, override_config, required_simulations class ParamSequence(Mapping): - def __init__(self, config): + def __init__(self, config: Union[Dict[str, Any], os.PathLike]): + if not isinstance(config, Mapping): + config = io.load_toml(config) self.config = validate(config) self.name = self.config["name"] self.logger = get_logger(__name__) @@ -598,6 +600,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: if "mean_power" in params: params["energy"] = params["mean_power"] / params["repetition_rate"] + custom_field = True if "field_file" in params: field_data = np.load(params["field_file"]) field_interp = interp1d( @@ -610,6 +613,7 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: params = _evalutate_custom_field_equation(params) params = _comform_custom_field(params) else: + custom_field = False params = _update_pulse_parameters(params) logger.info(f"computed initial N = {params['soliton_num']:.3g}") @@ -632,6 +636,12 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: params["spec_0"] = np.fft.fft(params["field_0"]) + # central wavelength may be off with custom fields + if custom_field: + delta_w = params["w_c"][np.argmax(abs2(params["spec_0"]))] + logger.debug(f"had to adjust w by {delta_w}") + params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) + _update_frequency_domain(params) return params @@ -655,7 +665,7 @@ def _comform_custom_field(params): params["width"], params["peak_power"], params["energy"] = pulse.measure_field( params["t"], params["field_0"] ) - wl = params["wavelength"] + return params diff --git a/src/scgenerator/plotting.py b/src/scgenerator/plotting.py index c7de0e1..b45cb78 100644 --- a/src/scgenerator/plotting.py +++ b/src/scgenerator/plotting.py @@ -4,6 +4,7 @@ import matplotlib.gridspec as gs import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import ListedColormap +from scgenerator.utils import variable_iterator from scipy.interpolate import UnivariateSpline from . import io, math @@ -703,7 +704,7 @@ def plot_results_1D( if is_new_plot: fig.savefig(os.path.join(folder_name, file_name), bbox_inches="tight", dpi=200) print(f"plot saved in {os.path.join(folder_name, file_name)}") - return fig, ax + return fig, ax, x_axis, values def _prep_plot(values, plt_range, params):