working on better dispersion interpolation

This commit is contained in:
Benoît Sierro
2021-06-24 10:13:12 +02:00
parent 97a19d4ffb
commit 5a7bf53e1c
13 changed files with 209 additions and 72 deletions

View File

@@ -5,3 +5,4 @@ from .physics import fiber, materials, pulse, simulate, units
from .physics.simulate import RK4IP, new_simulation, resume_simulations from .physics.simulate import RK4IP, new_simulation, resume_simulations
from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram
from .spectra import Pulse from .spectra import Pulse
from . import utils

View File

@@ -104,7 +104,7 @@ def create_parser():
init_plot_parser.add_argument("config", help="path to the config file") init_plot_parser.add_argument("config", help="path to the config file")
init_plot_parser.add_argument( init_plot_parser.add_argument(
"--dispersion-limits", "--dispersion-limits",
"-s", "-d",
default=None, default=None,
type=float, type=float,
nargs=2, nargs=2,

View File

@@ -23,7 +23,7 @@ default_parameters = dict(
parallel=True, parallel=True,
repeat=1, repeat=1,
tolerated_error=1e-11, tolerated_error=1e-11,
lower_wavelength_interp_limit=300e-9, lower_wavelength_interp_limit=100e-9,
upper_wavelength_interp_limit=1900e-9, upper_wavelength_interp_limit=1900e-9,
interp_degree=8, interp_degree=8,
ideal_gas=False, ideal_gas=False,

View File

@@ -85,8 +85,8 @@ class Params(BareParams):
logger = get_logger(__name__) logger = get_logger(__name__)
self.interp_range = ( self.interp_range = (
max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), max(self.lower_wavelength_interp_limit, self.l[self.l > 0].min()),
min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), min(self.upper_wavelength_interp_limit, self.l[self.l > 0].max()),
) )
temp_gamma = None temp_gamma = None
@@ -106,7 +106,7 @@ class Params(BareParams):
if self.gamma is None: if self.gamma is None:
self.gamma = temp_gamma self.gamma = temp_gamma
logger.info(f"using computed \u0263 = {self.gamma:.2e} W/m^2") logger.info(f"using computed \u0263 = {self.gamma:.2e} W/m\u00B2")
# Raman response # Raman response
if "raman" in self.behaviors: if "raman" in self.behaviors:

View File

@@ -63,6 +63,6 @@ def configure_logger(logger: logging.Logger):
stream_handler.setLevel(print_lvl) stream_handler.setLevel(print_lvl)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
logger.setLevel(min(print_lvl, file_lvl)) logger.setLevel(logging.DEBUG)
logger.already_configured = True logger.already_configured = True
return logger return logger

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal, Tuple from typing import Any, Dict, Iterable, List, Literal, Tuple, Union
import numpy as np import numpy as np
import toml import toml
@@ -6,6 +6,8 @@ from numpy.fft import fft, ifft
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
from ..logger import get_logger
from .. import io from .. import io
from ..math import abs2, argclosest, power_fact, u_nm from ..math import abs2, argclosest, power_fact, u_nm
from ..utils.parameter import BareParams, hc_model_specific_parameters from ..utils.parameter import BareParams, hc_model_specific_parameters
@@ -15,14 +17,14 @@ from . import units
from .units import c, pi from .units import c, pi
def lambda_for_dispersion(): def lambda_for_dispersion(left: float, right: float) -> np.ndarray:
"""Returns a wl vector for dispersion calculation """Returns a wl vector for dispersion calculation
Returns Returns
------- -------
array of wl values array of wl values
""" """
return np.linspace(190e-9, 3000e-9, 4000) return np.arange(left - 2e-9, right + 3e-9, 1e-9)
def is_dynamic_dispersion(pressure=None): def is_dynamic_dispersion(pressure=None):
@@ -679,7 +681,7 @@ def compute_dispersion(params: BareParams):
gamma = None gamma = None
else: else:
interp_range = params.interp_range interp_range = params.interp_range
lambda_ = lambda_for_dispersion() lambda_ = lambda_for_dispersion(*interp_range)
beta2 = np.zeros_like(lambda_) beta2 = np.zeros_like(lambda_)
if params.model == "pcf": if params.model == "pcf":
@@ -773,7 +775,7 @@ def dispersion_coefficients(
beta2_coef : 1D array beta2_coef : 1D array
Taylor coefficients in decreasing order Taylor coefficients in decreasing order
""" """
logger = get_logger()
if interp_range is None: if interp_range is None:
r = slice(2, -2) r = slice(2, -2)
else: else:
@@ -783,15 +785,50 @@ def dispersion_coefficients(
r = (lambda_ > max(lambda_[2], interp_range[0])) & ( r = (lambda_ > max(lambda_[2], interp_range[0])) & (
lambda_ < min(lambda_[-2], interp_range[1]) lambda_ < min(lambda_[-2], interp_range[1])
) )
logger.debug(
f"interpolating dispersion between {lambda_[r].min()*1e9:.1f}nm and {lambda_[r].max()*1e9:.1f}nm"
)
# import matplotlib.pyplot as plt
# we get the beta2 Taylor coeffiecients by making a fit around w0 # we get the beta2 Taylor coeffiecients by making a fit around w0
w_c = units.m(lambda_) - w0 w_c = units.m(lambda_) - w0
# interp = interp1d(w_c[r], beta2[r])
# w_c = np.linspace(w_c)
# fig, ax = plt.subplots()
# ax.plot(w_c[r], beta2[r])
# fig.show()
fit = Chebyshev.fit(w_c[r], beta2[r], deg) fit = Chebyshev.fit(w_c[r], beta2[r], deg)
beta2_coef = cheb2poly(fit.convert().coef) * np.cumprod([1] + list(range(1, deg + 1))) beta2_coef = cheb2poly(fit.convert().coef) * np.cumprod([1] + list(range(1, deg + 1)))
return beta2_coef return beta2_coef
def dispersion_from_coefficients(
w_c: np.ndarray, beta: Union[list[float], np.ndarray]
) -> np.ndarray:
"""computes the dispersion profile (beta2) from the beta coefficients
Parameters
----------
w_c : np.ndarray, shape (n, )
centered angular frequency (0 <=> pump frequency)
beta : Iterable[float]
beta coefficients
Returns
-------
np.ndarray, shape (n, )
beta2 as function of w_c
"""
coef = np.array(beta) / np.cumprod([1] + list(range(1, len(beta))))
beta_arr = np.zeros_like(w_c)
for k, b in reversed(list(enumerate(coef))):
beta_arr = beta_arr + b * w_c ** k
return beta_arr
def delayed_raman_t(t, raman_type="stolen"): def delayed_raman_t(t, raman_type="stolen"):
""" """
computes the unnormalized temporal Raman response function applied to the array t computes the unnormalized temporal Raman response function applied to the array t
@@ -1007,3 +1044,16 @@ def effective_core_radius(lambda_, core_radius, s=0.08, h=200e-9):
def effective_radius_HCARF(core_radius, t, f1, f2, lambda_): def effective_radius_HCARF(core_radius, t, f1, f2, lambda_):
"""eq. 3 in Hasan 2018""" """eq. 3 in Hasan 2018"""
return f1 * core_radius * (1 - f2 * lambda_ ** 2 / (core_radius * t)) return f1 * core_radius * (1 - f2 * lambda_ ** 2 / (core_radius * t))
if __name__ == "__main__":
w = np.linspace(0, 1, 4096)
c = np.arange(8)
import time
t = time.time()
for _ in range(10000):
dispersion_from_coefficients(w, c)
print((time.time() - t) / 10, "ms")

View File

@@ -12,13 +12,16 @@ n is the number of spectra at the same z position and nt is the size of the time
import itertools import itertools
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, Tuple from typing import Literal, Tuple, TypeVar
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from numpy import pi from numpy import pi
from numpy.fft import fft, fftshift, ifft from numpy.fft import fft, fftshift, ifft
from scipy import optimize
from scipy.interpolate import UnivariateSpline from scipy.interpolate import UnivariateSpline
from scipy.optimize import minimize_scalar
from scipy.optimize.optimize import OptimizeResult
from .. import io from .. import io
from ..defaults import default_plotting from ..defaults import default_plotting
@@ -30,7 +33,7 @@ from . import units
c = 299792458.0 c = 299792458.0
hbar = 1.05457148e-34 hbar = 1.05457148e-34
T = TypeVar("T")
# #
fwhm_to_T0_fac = dict( fwhm_to_T0_fac = dict(
@@ -535,14 +538,23 @@ def peak_ind(values, mam=None):
am = np.argmax(values) am = np.argmax(values)
else: else:
m, am = mam m, am = mam
try:
left_ind = ( left_ind = (
am am
- np.where((values[am:0:-1] - values[am - 1 :: -1] < 0) & (values[am:0:-1] < m / 2))[0][0] - np.where((values[am:0:-1] - values[am - 1 :: -1] <= 0) & (values[am:0:-1] < m / 2))[
) 0
][0]
) - 3
except IndexError:
left_ind = 0
try:
right_ind = ( right_ind = (
am + np.where((values[am:-1] - values[am + 1 :] < 0) & (values[am:-1] < m / 2))[0][0] am + np.where((values[am:-1] - values[am + 1 :] <= 0) & (values[am:-1] < m / 2))[0][0]
) ) + 3
return left_ind - 3, right_ind + 3 except IndexError:
right_ind = len(values) - 1
return left_ind, right_ind
def setup_splines(x_axis, values, mam=None): def setup_splines(x_axis, values, mam=None):
@@ -895,3 +907,32 @@ def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float
peak_power = intensity.max() peak_power = intensity.max()
energy = np.trapz(intensity, t) energy = np.trapz(intensity, t)
return fwhm, peak_power, energy return fwhm, peak_power, energy
def remove_2nd_order_dispersion(
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -1.0
) -> tuple[T, OptimizeResult]:
"""attempts to remove 2nd order dispersion from a complex spectrum
Parameters
----------
spectrum : np.ndarray or Spectrum, shape (n, )
spectrum from which to remove 2nd order dispersion
w_c : np.ndarray, shape (n, )
corresponding centered angular frequencies (w-w0)
beta2 : float
2nd order dispersion coefficient
Returns
-------
np.ndarray, shape (n, )
spectrum with 2nd order dispersion removed
"""
# makeshift_t = np.linspace(0, 1, len(w_c))
propagate = lambda z: spectrum * np.exp(-0.5j * beta2 * w_c ** 2 * z)
def score(z):
return 1 / np.max(abs2(np.fft.ifft(propagate(z))))
opti = minimize_scalar(score, bracket=(max_z, 0))
return propagate(opti.x), opti

View File

@@ -6,6 +6,9 @@ import re
from threading import settrace from threading import settrace
from typing import Callable, TypeVar, Union from typing import Callable, TypeVar, Union
from dataclasses import dataclass from dataclasses import dataclass
from matplotlib import pyplot as plt
from numpy.lib.arraysetops import isin
from ..utils.parameter import Parameter, type_checker from ..utils.parameter import Parameter, type_checker
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -255,7 +258,8 @@ def sort_axis(axis, plt_range: PlotRange):
# slice y according to the given ranges # slice y according to the given ranges
y = y[ct][:, cw] y = y[ct][:, cw]
""" """
if isinstance(plt_range, tuple):
plt_range = PlotRange(*plt_range)
r = np.array((plt_range.left, plt_range.right), dtype="float") r = np.array((plt_range.left, plt_range.right), dtype="float")
indices = np.arange(len(axis))[ indices = np.arange(len(axis))[

View File

@@ -30,7 +30,7 @@ def plot_setup(
- an axis - an axis
""" """
out_path = defaults["name"] if out_path is None else out_path out_path = defaults["name"] if out_path is None else out_path
plot_name = out_path.stem plot_name = out_path.name.replace(f".{file_type}", "")
out_dir = out_path.resolve().parent out_dir = out_path.resolve().parent
file_name = plot_name + "." + file_type file_name = plot_name + "." + file_type
@@ -286,9 +286,8 @@ def _finish_plot_2D(
if isinstance(ax, tuple) and len(ax) > 1: if isinstance(ax, tuple) and len(ax) > 1:
ax, cbar_ax = ax[0], ax[1] ax, cbar_ax = ax[0], ax[1]
folder_name = ""
if is_new_plot: if is_new_plot:
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type) out_path, fig, ax = plot_setup(out_path=Path(file_name), file_type=file_type)
else: else:
fig = ax.get_figure() fig = ax.get_figure()

View File

@@ -1,7 +1,7 @@
from itertools import cycle from itertools import cycle
import itertools import itertools
from pathlib import Path from pathlib import Path
from typing import Iterable from typing import Any, Iterable, Optional
from cycler import cycler from cycler import cycler
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -12,10 +12,16 @@ from ..utils.parameter import BareParams
from ..initialize import ParamSequence from ..initialize import ParamSequence
from ..physics import units, fiber from ..physics import units, fiber
from ..spectra import Pulse from ..spectra import Pulse
from ..utils import pretty_format_value from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop
from .. import env, math from .. import env, math
def fingerprint(params: BareParams):
h1 = hash(params.field_0.tobytes())
h2 = tuple(params.beta)
return h1, h2
def plot_all(sim_dir: Path, limits: list[str]): def plot_all(sim_dir: Path, limits: list[str]):
for p in sim_dir.glob("*"): for p in sim_dir.glob("*"):
if not p.is_dir(): if not p.is_dir():
@@ -26,7 +32,13 @@ def plot_all(sim_dir: Path, limits: list[str]):
left, right, unit = lim.split(",") left, right, unit = lim.split(",")
left = float(left) left = float(left)
right = float(right) right = float(right)
pulse.plot_2D(left, right, unit, file_name=p.parent / f"{p.name}_{left}_{right}_{unit}") pulse.plot_2D(
left,
right,
unit,
file_name=p.parent
/ f"{pretty_format_from_file_name(p.name)} {left} {right} {unit}",
)
def plot_init_field_spec( def plot_init_field_spec(
@@ -49,7 +61,7 @@ def plot_init_field_spec(
def plot_dispersion(config_path: Path, lim: tuple[float, float] = None): def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7), sharex=True) fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
left.grid() left.grid()
right.grid() right.grid()
all_labels = [] all_labels = []
@@ -62,7 +74,7 @@ def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
lbl = plot_1_dispersion(lim, left, right, style, lbl, params) lbl = plot_1_dispersion(lim, left, right, style, lbl, params)
all_labels.append(lbl) all_labels.append(lbl)
finish_plot(fig, left, right, all_labels, params) finish_plot(fig, right, all_labels, params)
def plot_init( def plot_init(
@@ -77,8 +89,8 @@ def plot_init(
all_labels = [] all_labels = []
already_plotted = set() already_plotted = set()
for style, lbl, params in plot_helper(config_path): for style, lbl, params in plot_helper(config_path):
if (bbb := hash(params.field_0.tobytes())) not in already_plotted: if (fp := fingerprint(params)) not in already_plotted:
already_plotted.add(bbb) already_plotted.add(fp)
else: else:
continue continue
lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params) lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params)
@@ -90,8 +102,8 @@ def plot_init(
def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params): def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params):
field = math.abs2(params.field_0) field = math.abs2(params.field_0)
spec = math.abs2(params.spec_0) spec = math.abs2(params.spec_0)
t = units.fs.inv(params.t) t = units.To.fs(params.t)
wl = units.nm.inv(params.w) wl = units.To.nm(params.w)
lbl.append(f"max at {wl[spec.argmax()]:.1f} nm") lbl.append(f"max at {wl[spec.argmax()]:.1f} nm")
@@ -100,54 +112,68 @@ def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params):
mt &= t >= lim_t[0] mt &= t >= lim_t[0]
mt &= t <= lim_t[1] mt &= t <= lim_t[1]
else: else:
mt = find_lim(t, field) mt = auto_crop(t, field)
ml = np.ones_like(wl, dtype=bool) ml = np.ones_like(wl, dtype=bool)
if lim_l is not None: if lim_l is not None:
ml &= t >= lim_l[0] ml &= wl >= lim_l[0]
ml &= t <= lim_l[1] ml &= wl <= lim_l[1]
else: else:
ml = find_lim(wl, spec) ml = auto_crop(wl, spec)
left.plot(t[mt], field[mt]) left.plot(t[mt], field[mt])
right.plot(wl[ml], spec[ml], label=" ", **style) right.plot(wl[ml], spec[ml], label=" ", **style)
return lbl return lbl
def plot_1_dispersion(lim, left, right, style, lbl, params): def plot_1_dispersion(
coef = params.beta / np.cumprod([1] + list(range(1, len(params.beta)))) lim: Optional[tuple[float, float]],
w_c = params.w_c left: plt.Axes,
right: plt.Axes,
beta_arr = np.zeros_like(w_c) style: dict[str, Any],
for k, beta in reversed(list(enumerate(coef))): lbl: list[str],
beta_arr = beta_arr + beta * w_c ** k params: BareParams,
):
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta)
wl = units.m.inv(params.w) wl = units.m.inv(params.w)
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
zdw = math.all_zeros(wl, beta_arr) zdw = math.all_zeros(wl, beta_arr)
if len(zdw) > 0: if len(zdw) > 0:
zdw = zdw[np.argmin(abs(zdw - params.wavelength))] zdw = zdw[np.argmin(abs(zdw - params.wavelength))]
lbl.append(f"ZDW at {zdw*1e9:.1f}nm") lbl.append(f"ZDW at {zdw:.1f}nm")
else: else:
lbl.append("") lbl.append("")
m = np.ones_like(wl, dtype=bool) m = np.ones_like(wl, dtype=bool)
if lim is None: if lim is None:
lim = params.interp_range lim = params.interp_range
m &= wl >= lim[0] m &= wl >= (lim[0] if lim[0] < 1 else lim[0] * 1e-9)
m &= wl <= lim[1] m &= wl <= (lim[1] if lim[1] < 1 else lim[1] * 1e-9)
left.annotate(
rf"$\lambda_{{\mathrm{{min}}}}={np.min(params.l[params.l>0])*1e9:.1f}$ nm"
f"lower interpolation limit : {params.interp_range[0]*1e9:.1f} nm",
(0, 1),
xycoords="axes fraction",
va="top",
ha="left",
)
m = np.argwhere(m)[:, 0] m = np.argwhere(m)[:, 0]
m = np.array(sorted(m, key=lambda el: wl[el])) m = np.array(sorted(m, key=lambda el: wl[el]))
if len(m) == 0:
raise ValueError(f"nothing to plot in the range {lim!r}")
# plot D # plot D
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
right.plot(1e9 * wl[m], D[m], label=" ", **style) right.plot(1e9 * wl[m], D[m], label=" ", **style)
right.set_ylabel(units.D_ps_nm_km.label) right.set_ylabel(units.D_ps_nm_km.label)
# plot beta # plot beta
left.plot(1e9 * wl[m], units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style) left.plot(units.To.Prad_s(params.w[m]), units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
left.set_ylabel(units.beta2_fs_cm.label) left.set_ylabel(units.beta2_fs_cm.label)
left.set_xlabel("wavelength (nm)") left.set_xlabel(units.Prad_s.label)
right.set_xlabel("wavelength (nm)") right.set_xlabel("wavelength (nm)")
return lbl return lbl
@@ -188,21 +214,3 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams
for style, (variables, params) in zip(cc, pseq): for style, (variables, params) in zip(cc, pseq):
lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]]
yield style, lbl, params yield style, lbl, params
def find_lim(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> int:
threshold = y.min() + rel_thr * (y.max() - y.min())
above_threshold = y > threshold
ind = np.argsort(x)
valid_ind = [
np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
]
ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
width = len(ind_above)
return np.concatenate(
(
np.arange(max(ind_above[0] - width, 0), ind_above[0]),
ind_above,
np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
)
)

View File

@@ -47,7 +47,7 @@ class Spectrum(np.ndarray):
else: else:
return np.array([s.energy() for s in self]) return np.array([s.energy() for s in self])
def crop_wl(self, left: float, right: float) -> tuple[np.ndarray, np.ndarray]: def crop_wl(self, left: float, right: float) -> np.ndarray:
cond = (self.params.l >= left) & (self.params.l <= right) cond = (self.params.l >= left) & (self.params.l <= right)
return cond return cond
@@ -120,6 +120,9 @@ class Spectrum(np.ndarray):
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2) -(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
) )
def measure(self) -> tuple[float, float, float]:
return pulse.measure_field(self.params.t, self.time_amp)
class Pulse(Sequence): class Pulse(Sequence):
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None): def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
@@ -180,7 +183,7 @@ class Pulse(Sequence):
def __len__(self): def __len__(self):
return self.nmax return self.nmax
def __getitem__(self, key): def __getitem__(self, key) -> Spectrum:
return self.all_spectra(ind=range(self.nmax)[key]).squeeze() return self.all_spectra(ind=range(self.nmax)[key]).squeeze()
def intensity(self, unit): def intensity(self, unit):
@@ -282,6 +285,7 @@ class Pulse(Sequence):
spectra = [] spectra = []
for i in ind: for i in ind:
spectra.append(self._load1(i)) spectra.append(self._load1(i))
spectra = Spectrum(spectra, self.params)
self.logger.debug(f"all spectra from {self.path} successfully loaded") self.logger.debug(f"all spectra from {self.path} successfully loaded")
if len(ind) == 1: if len(ind) == 1:

View File

@@ -4,6 +4,7 @@ scgenerator module but some function may be used in any python program
""" """
from argparse import ArgumentTypeError
import itertools import itertools
import multiprocessing import multiprocessing
import re import re
@@ -214,6 +215,17 @@ def pretty_format_value(name: str, value) -> str:
return getattr(BareParams, name).display(value) return getattr(BareParams, name).display(value)
def pretty_format_from_file_name(name: str) -> str:
s = name.split(PARAM_SEPARATOR)
out = []
for key, value in zip(s[::2], s[1::2]):
try:
out.append(getattr(BareParams, key).display(float(value)))
except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out)
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]: def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
"""given a config with "variable" parameters, iterates through every possible combination, """given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary. yielding a a list of (parameter_name, value) tuples and a full config dictionary.
@@ -282,3 +294,21 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
for k in new: for k in new:
variable.pop(k, None) # remove old ones variable.pop(k, None) # remove old ones
return replace(old, variable=variable, **{k: None for k in variable}, **new) return replace(old, variable=variable, **{k: None for k in variable}, **new)
def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray:
threshold = y.min() + rel_thr * (y.max() - y.min())
above_threshold = y > threshold
ind = np.argsort(x)
valid_ind = [
np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
]
ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
width = len(ind_above)
return np.concatenate(
(
np.arange(max(ind_above[0] - width, 0), ind_above[0]),
ind_above,
np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
)
)

View File

@@ -421,7 +421,7 @@ class BareParams:
dico : dict dico : dict
dictionary dictionary
""" """
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"] forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets", "l"]
types = (np.ndarray, float, int, str, list, tuple, dict) types = (np.ndarray, float, int, str, list, tuple, dict)
out = {} out = {}
for key, value in dico.items(): for key, value in dico.items():