working on better dispersion interpolation
This commit is contained in:
@@ -5,3 +5,4 @@ from .physics import fiber, materials, pulse, simulate, units
|
||||
from .physics.simulate import RK4IP, new_simulation, resume_simulations
|
||||
from .plotting import plot_avg, plot_results_1D, plot_results_2D, plot_spectrogram
|
||||
from .spectra import Pulse
|
||||
from . import utils
|
||||
|
||||
@@ -104,7 +104,7 @@ def create_parser():
|
||||
init_plot_parser.add_argument("config", help="path to the config file")
|
||||
init_plot_parser.add_argument(
|
||||
"--dispersion-limits",
|
||||
"-s",
|
||||
"-d",
|
||||
default=None,
|
||||
type=float,
|
||||
nargs=2,
|
||||
|
||||
@@ -23,7 +23,7 @@ default_parameters = dict(
|
||||
parallel=True,
|
||||
repeat=1,
|
||||
tolerated_error=1e-11,
|
||||
lower_wavelength_interp_limit=300e-9,
|
||||
lower_wavelength_interp_limit=100e-9,
|
||||
upper_wavelength_interp_limit=1900e-9,
|
||||
interp_degree=8,
|
||||
ideal_gas=False,
|
||||
|
||||
@@ -85,8 +85,8 @@ class Params(BareParams):
|
||||
logger = get_logger(__name__)
|
||||
|
||||
self.interp_range = (
|
||||
max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))),
|
||||
min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))),
|
||||
max(self.lower_wavelength_interp_limit, self.l[self.l > 0].min()),
|
||||
min(self.upper_wavelength_interp_limit, self.l[self.l > 0].max()),
|
||||
)
|
||||
|
||||
temp_gamma = None
|
||||
@@ -106,7 +106,7 @@ class Params(BareParams):
|
||||
|
||||
if self.gamma is None:
|
||||
self.gamma = temp_gamma
|
||||
logger.info(f"using computed \u0263 = {self.gamma:.2e} W/m^2")
|
||||
logger.info(f"using computed \u0263 = {self.gamma:.2e} W/m\u00B2")
|
||||
|
||||
# Raman response
|
||||
if "raman" in self.behaviors:
|
||||
|
||||
@@ -63,6 +63,6 @@ def configure_logger(logger: logging.Logger):
|
||||
stream_handler.setLevel(print_lvl)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
logger.setLevel(min(print_lvl, file_lvl))
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.already_configured = True
|
||||
return logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Literal, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Literal, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import toml
|
||||
@@ -6,6 +6,8 @@ from numpy.fft import fft, ifft
|
||||
from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from ..logger import get_logger
|
||||
|
||||
from .. import io
|
||||
from ..math import abs2, argclosest, power_fact, u_nm
|
||||
from ..utils.parameter import BareParams, hc_model_specific_parameters
|
||||
@@ -15,14 +17,14 @@ from . import units
|
||||
from .units import c, pi
|
||||
|
||||
|
||||
def lambda_for_dispersion():
|
||||
def lambda_for_dispersion(left: float, right: float) -> np.ndarray:
|
||||
"""Returns a wl vector for dispersion calculation
|
||||
|
||||
Returns
|
||||
-------
|
||||
array of wl values
|
||||
"""
|
||||
return np.linspace(190e-9, 3000e-9, 4000)
|
||||
return np.arange(left - 2e-9, right + 3e-9, 1e-9)
|
||||
|
||||
|
||||
def is_dynamic_dispersion(pressure=None):
|
||||
@@ -679,7 +681,7 @@ def compute_dispersion(params: BareParams):
|
||||
gamma = None
|
||||
else:
|
||||
interp_range = params.interp_range
|
||||
lambda_ = lambda_for_dispersion()
|
||||
lambda_ = lambda_for_dispersion(*interp_range)
|
||||
beta2 = np.zeros_like(lambda_)
|
||||
|
||||
if params.model == "pcf":
|
||||
@@ -773,7 +775,7 @@ def dispersion_coefficients(
|
||||
beta2_coef : 1D array
|
||||
Taylor coefficients in decreasing order
|
||||
"""
|
||||
|
||||
logger = get_logger()
|
||||
if interp_range is None:
|
||||
r = slice(2, -2)
|
||||
else:
|
||||
@@ -783,15 +785,50 @@ def dispersion_coefficients(
|
||||
r = (lambda_ > max(lambda_[2], interp_range[0])) & (
|
||||
lambda_ < min(lambda_[-2], interp_range[1])
|
||||
)
|
||||
logger.debug(
|
||||
f"interpolating dispersion between {lambda_[r].min()*1e9:.1f}nm and {lambda_[r].max()*1e9:.1f}nm"
|
||||
)
|
||||
# import matplotlib.pyplot as plt
|
||||
|
||||
# we get the beta2 Taylor coeffiecients by making a fit around w0
|
||||
w_c = units.m(lambda_) - w0
|
||||
# interp = interp1d(w_c[r], beta2[r])
|
||||
# w_c = np.linspace(w_c)
|
||||
# fig, ax = plt.subplots()
|
||||
# ax.plot(w_c[r], beta2[r])
|
||||
# fig.show()
|
||||
|
||||
fit = Chebyshev.fit(w_c[r], beta2[r], deg)
|
||||
beta2_coef = cheb2poly(fit.convert().coef) * np.cumprod([1] + list(range(1, deg + 1)))
|
||||
|
||||
return beta2_coef
|
||||
|
||||
|
||||
def dispersion_from_coefficients(
|
||||
w_c: np.ndarray, beta: Union[list[float], np.ndarray]
|
||||
) -> np.ndarray:
|
||||
"""computes the dispersion profile (beta2) from the beta coefficients
|
||||
|
||||
Parameters
|
||||
----------
|
||||
w_c : np.ndarray, shape (n, )
|
||||
centered angular frequency (0 <=> pump frequency)
|
||||
beta : Iterable[float]
|
||||
beta coefficients
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape (n, )
|
||||
beta2 as function of w_c
|
||||
"""
|
||||
|
||||
coef = np.array(beta) / np.cumprod([1] + list(range(1, len(beta))))
|
||||
beta_arr = np.zeros_like(w_c)
|
||||
for k, b in reversed(list(enumerate(coef))):
|
||||
beta_arr = beta_arr + b * w_c ** k
|
||||
return beta_arr
|
||||
|
||||
|
||||
def delayed_raman_t(t, raman_type="stolen"):
|
||||
"""
|
||||
computes the unnormalized temporal Raman response function applied to the array t
|
||||
@@ -1007,3 +1044,16 @@ def effective_core_radius(lambda_, core_radius, s=0.08, h=200e-9):
|
||||
def effective_radius_HCARF(core_radius, t, f1, f2, lambda_):
|
||||
"""eq. 3 in Hasan 2018"""
|
||||
return f1 * core_radius * (1 - f2 * lambda_ ** 2 / (core_radius * t))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
w = np.linspace(0, 1, 4096)
|
||||
c = np.arange(8)
|
||||
import time
|
||||
|
||||
t = time.time()
|
||||
|
||||
for _ in range(10000):
|
||||
dispersion_from_coefficients(w, c)
|
||||
|
||||
print((time.time() - t) / 10, "ms")
|
||||
|
||||
@@ -12,13 +12,16 @@ n is the number of spectra at the same z position and nt is the size of the time
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal, Tuple
|
||||
from typing import Literal, Tuple, TypeVar
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
from numpy.fft import fft, fftshift, ifft
|
||||
from scipy import optimize
|
||||
from scipy.interpolate import UnivariateSpline
|
||||
from scipy.optimize import minimize_scalar
|
||||
from scipy.optimize.optimize import OptimizeResult
|
||||
|
||||
from .. import io
|
||||
from ..defaults import default_plotting
|
||||
@@ -30,7 +33,7 @@ from . import units
|
||||
|
||||
c = 299792458.0
|
||||
hbar = 1.05457148e-34
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
#
|
||||
fwhm_to_T0_fac = dict(
|
||||
@@ -535,14 +538,23 @@ def peak_ind(values, mam=None):
|
||||
am = np.argmax(values)
|
||||
else:
|
||||
m, am = mam
|
||||
|
||||
try:
|
||||
left_ind = (
|
||||
am
|
||||
- np.where((values[am:0:-1] - values[am - 1 :: -1] < 0) & (values[am:0:-1] < m / 2))[0][0]
|
||||
)
|
||||
- np.where((values[am:0:-1] - values[am - 1 :: -1] <= 0) & (values[am:0:-1] < m / 2))[
|
||||
0
|
||||
][0]
|
||||
) - 3
|
||||
except IndexError:
|
||||
left_ind = 0
|
||||
try:
|
||||
right_ind = (
|
||||
am + np.where((values[am:-1] - values[am + 1 :] < 0) & (values[am:-1] < m / 2))[0][0]
|
||||
)
|
||||
return left_ind - 3, right_ind + 3
|
||||
am + np.where((values[am:-1] - values[am + 1 :] <= 0) & (values[am:-1] < m / 2))[0][0]
|
||||
) + 3
|
||||
except IndexError:
|
||||
right_ind = len(values) - 1
|
||||
return left_ind, right_ind
|
||||
|
||||
|
||||
def setup_splines(x_axis, values, mam=None):
|
||||
@@ -895,3 +907,32 @@ def measure_field(t: np.ndarray, field: np.ndarray) -> Tuple[float, float, float
|
||||
peak_power = intensity.max()
|
||||
energy = np.trapz(intensity, t)
|
||||
return fwhm, peak_power, energy
|
||||
|
||||
|
||||
def remove_2nd_order_dispersion(
|
||||
spectrum: T, w_c: np.ndarray, beta2: float, max_z: float = -1.0
|
||||
) -> tuple[T, OptimizeResult]:
|
||||
"""attempts to remove 2nd order dispersion from a complex spectrum
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spectrum : np.ndarray or Spectrum, shape (n, )
|
||||
spectrum from which to remove 2nd order dispersion
|
||||
w_c : np.ndarray, shape (n, )
|
||||
corresponding centered angular frequencies (w-w0)
|
||||
beta2 : float
|
||||
2nd order dispersion coefficient
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray, shape (n, )
|
||||
spectrum with 2nd order dispersion removed
|
||||
"""
|
||||
# makeshift_t = np.linspace(0, 1, len(w_c))
|
||||
propagate = lambda z: spectrum * np.exp(-0.5j * beta2 * w_c ** 2 * z)
|
||||
|
||||
def score(z):
|
||||
return 1 / np.max(abs2(np.fft.ifft(propagate(z))))
|
||||
|
||||
opti = minimize_scalar(score, bracket=(max_z, 0))
|
||||
return propagate(opti.x), opti
|
||||
|
||||
@@ -6,6 +6,9 @@ import re
|
||||
from threading import settrace
|
||||
from typing import Callable, TypeVar, Union
|
||||
from dataclasses import dataclass
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from numpy.lib.arraysetops import isin
|
||||
from ..utils.parameter import Parameter, type_checker
|
||||
import numpy as np
|
||||
from numpy import pi
|
||||
@@ -255,7 +258,8 @@ def sort_axis(axis, plt_range: PlotRange):
|
||||
# slice y according to the given ranges
|
||||
y = y[ct][:, cw]
|
||||
"""
|
||||
|
||||
if isinstance(plt_range, tuple):
|
||||
plt_range = PlotRange(*plt_range)
|
||||
r = np.array((plt_range.left, plt_range.right), dtype="float")
|
||||
|
||||
indices = np.arange(len(axis))[
|
||||
|
||||
@@ -30,7 +30,7 @@ def plot_setup(
|
||||
- an axis
|
||||
"""
|
||||
out_path = defaults["name"] if out_path is None else out_path
|
||||
plot_name = out_path.stem
|
||||
plot_name = out_path.name.replace(f".{file_type}", "")
|
||||
out_dir = out_path.resolve().parent
|
||||
|
||||
file_name = plot_name + "." + file_type
|
||||
@@ -286,9 +286,8 @@ def _finish_plot_2D(
|
||||
if isinstance(ax, tuple) and len(ax) > 1:
|
||||
ax, cbar_ax = ax[0], ax[1]
|
||||
|
||||
folder_name = ""
|
||||
if is_new_plot:
|
||||
out_path, fig, ax = plot_setup(out_path=Path(folder_name) / file_name, file_type=file_type)
|
||||
out_path, fig, ax = plot_setup(out_path=Path(file_name), file_type=file_type)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from itertools import cycle
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from typing import Any, Iterable, Optional
|
||||
from cycler import cycler
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -12,10 +12,16 @@ from ..utils.parameter import BareParams
|
||||
from ..initialize import ParamSequence
|
||||
from ..physics import units, fiber
|
||||
from ..spectra import Pulse
|
||||
from ..utils import pretty_format_value
|
||||
from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop
|
||||
from .. import env, math
|
||||
|
||||
|
||||
def fingerprint(params: BareParams):
|
||||
h1 = hash(params.field_0.tobytes())
|
||||
h2 = tuple(params.beta)
|
||||
return h1, h2
|
||||
|
||||
|
||||
def plot_all(sim_dir: Path, limits: list[str]):
|
||||
for p in sim_dir.glob("*"):
|
||||
if not p.is_dir():
|
||||
@@ -26,7 +32,13 @@ def plot_all(sim_dir: Path, limits: list[str]):
|
||||
left, right, unit = lim.split(",")
|
||||
left = float(left)
|
||||
right = float(right)
|
||||
pulse.plot_2D(left, right, unit, file_name=p.parent / f"{p.name}_{left}_{right}_{unit}")
|
||||
pulse.plot_2D(
|
||||
left,
|
||||
right,
|
||||
unit,
|
||||
file_name=p.parent
|
||||
/ f"{pretty_format_from_file_name(p.name)} {left} {right} {unit}",
|
||||
)
|
||||
|
||||
|
||||
def plot_init_field_spec(
|
||||
@@ -49,7 +61,7 @@ def plot_init_field_spec(
|
||||
|
||||
|
||||
def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7), sharex=True)
|
||||
fig, (left, right) = plt.subplots(1, 2, figsize=(12, 7))
|
||||
left.grid()
|
||||
right.grid()
|
||||
all_labels = []
|
||||
@@ -62,7 +74,7 @@ def plot_dispersion(config_path: Path, lim: tuple[float, float] = None):
|
||||
|
||||
lbl = plot_1_dispersion(lim, left, right, style, lbl, params)
|
||||
all_labels.append(lbl)
|
||||
finish_plot(fig, left, right, all_labels, params)
|
||||
finish_plot(fig, right, all_labels, params)
|
||||
|
||||
|
||||
def plot_init(
|
||||
@@ -77,8 +89,8 @@ def plot_init(
|
||||
all_labels = []
|
||||
already_plotted = set()
|
||||
for style, lbl, params in plot_helper(config_path):
|
||||
if (bbb := hash(params.field_0.tobytes())) not in already_plotted:
|
||||
already_plotted.add(bbb)
|
||||
if (fp := fingerprint(params)) not in already_plotted:
|
||||
already_plotted.add(fp)
|
||||
else:
|
||||
continue
|
||||
lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params)
|
||||
@@ -90,8 +102,8 @@ def plot_init(
|
||||
def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params):
|
||||
field = math.abs2(params.field_0)
|
||||
spec = math.abs2(params.spec_0)
|
||||
t = units.fs.inv(params.t)
|
||||
wl = units.nm.inv(params.w)
|
||||
t = units.To.fs(params.t)
|
||||
wl = units.To.nm(params.w)
|
||||
|
||||
lbl.append(f"max at {wl[spec.argmax()]:.1f} nm")
|
||||
|
||||
@@ -100,54 +112,68 @@ def plot_1_init_spec_field(lim_t, lim_l, left, right, style, lbl, params):
|
||||
mt &= t >= lim_t[0]
|
||||
mt &= t <= lim_t[1]
|
||||
else:
|
||||
mt = find_lim(t, field)
|
||||
mt = auto_crop(t, field)
|
||||
ml = np.ones_like(wl, dtype=bool)
|
||||
if lim_l is not None:
|
||||
ml &= t >= lim_l[0]
|
||||
ml &= t <= lim_l[1]
|
||||
ml &= wl >= lim_l[0]
|
||||
ml &= wl <= lim_l[1]
|
||||
else:
|
||||
ml = find_lim(wl, spec)
|
||||
ml = auto_crop(wl, spec)
|
||||
|
||||
left.plot(t[mt], field[mt])
|
||||
right.plot(wl[ml], spec[ml], label=" ", **style)
|
||||
return lbl
|
||||
|
||||
|
||||
def plot_1_dispersion(lim, left, right, style, lbl, params):
|
||||
coef = params.beta / np.cumprod([1] + list(range(1, len(params.beta))))
|
||||
w_c = params.w_c
|
||||
|
||||
beta_arr = np.zeros_like(w_c)
|
||||
for k, beta in reversed(list(enumerate(coef))):
|
||||
beta_arr = beta_arr + beta * w_c ** k
|
||||
def plot_1_dispersion(
|
||||
lim: Optional[tuple[float, float]],
|
||||
left: plt.Axes,
|
||||
right: plt.Axes,
|
||||
style: dict[str, Any],
|
||||
lbl: list[str],
|
||||
params: BareParams,
|
||||
):
|
||||
beta_arr = fiber.dispersion_from_coefficients(params.w_c, params.beta)
|
||||
wl = units.m.inv(params.w)
|
||||
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
|
||||
|
||||
zdw = math.all_zeros(wl, beta_arr)
|
||||
if len(zdw) > 0:
|
||||
zdw = zdw[np.argmin(abs(zdw - params.wavelength))]
|
||||
lbl.append(f"ZDW at {zdw*1e9:.1f}nm")
|
||||
lbl.append(f"ZDW at {zdw:.1f}nm")
|
||||
else:
|
||||
lbl.append("")
|
||||
|
||||
m = np.ones_like(wl, dtype=bool)
|
||||
if lim is None:
|
||||
lim = params.interp_range
|
||||
m &= wl >= lim[0]
|
||||
m &= wl <= lim[1]
|
||||
m &= wl >= (lim[0] if lim[0] < 1 else lim[0] * 1e-9)
|
||||
m &= wl <= (lim[1] if lim[1] < 1 else lim[1] * 1e-9)
|
||||
|
||||
left.annotate(
|
||||
rf"$\lambda_{{\mathrm{{min}}}}={np.min(params.l[params.l>0])*1e9:.1f}$ nm"
|
||||
f"lower interpolation limit : {params.interp_range[0]*1e9:.1f} nm",
|
||||
(0, 1),
|
||||
xycoords="axes fraction",
|
||||
va="top",
|
||||
ha="left",
|
||||
)
|
||||
|
||||
m = np.argwhere(m)[:, 0]
|
||||
m = np.array(sorted(m, key=lambda el: wl[el]))
|
||||
|
||||
if len(m) == 0:
|
||||
raise ValueError(f"nothing to plot in the range {lim!r}")
|
||||
|
||||
# plot D
|
||||
D = fiber.beta2_to_D(beta_arr, wl) * 1e6
|
||||
right.plot(1e9 * wl[m], D[m], label=" ", **style)
|
||||
right.set_ylabel(units.D_ps_nm_km.label)
|
||||
|
||||
# plot beta
|
||||
left.plot(1e9 * wl[m], units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
|
||||
left.plot(units.To.Prad_s(params.w[m]), units.beta2_fs_cm.inv(beta_arr[m]), label=" ", **style)
|
||||
left.set_ylabel(units.beta2_fs_cm.label)
|
||||
|
||||
left.set_xlabel("wavelength (nm)")
|
||||
left.set_xlabel(units.Prad_s.label)
|
||||
right.set_xlabel("wavelength (nm)")
|
||||
return lbl
|
||||
|
||||
@@ -188,21 +214,3 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], BareParams
|
||||
for style, (variables, params) in zip(cc, pseq):
|
||||
lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]]
|
||||
yield style, lbl, params
|
||||
|
||||
|
||||
def find_lim(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> int:
|
||||
threshold = y.min() + rel_thr * (y.max() - y.min())
|
||||
above_threshold = y > threshold
|
||||
ind = np.argsort(x)
|
||||
valid_ind = [
|
||||
np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
|
||||
]
|
||||
ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
|
||||
width = len(ind_above)
|
||||
return np.concatenate(
|
||||
(
|
||||
np.arange(max(ind_above[0] - width, 0), ind_above[0]),
|
||||
ind_above,
|
||||
np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ class Spectrum(np.ndarray):
|
||||
else:
|
||||
return np.array([s.energy() for s in self])
|
||||
|
||||
def crop_wl(self, left: float, right: float) -> tuple[np.ndarray, np.ndarray]:
|
||||
def crop_wl(self, left: float, right: float) -> np.ndarray:
|
||||
cond = (self.params.l >= left) & (self.params.l <= right)
|
||||
return cond
|
||||
|
||||
@@ -120,6 +120,9 @@ class Spectrum(np.ndarray):
|
||||
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
||||
)
|
||||
|
||||
def measure(self) -> tuple[float, float, float]:
|
||||
return pulse.measure_field(self.params.t, self.time_amp)
|
||||
|
||||
|
||||
class Pulse(Sequence):
|
||||
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
||||
@@ -180,7 +183,7 @@ class Pulse(Sequence):
|
||||
def __len__(self):
|
||||
return self.nmax
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
return self.all_spectra(ind=range(self.nmax)[key]).squeeze()
|
||||
|
||||
def intensity(self, unit):
|
||||
@@ -282,6 +285,7 @@ class Pulse(Sequence):
|
||||
spectra = []
|
||||
for i in ind:
|
||||
spectra.append(self._load1(i))
|
||||
spectra = Spectrum(spectra, self.params)
|
||||
|
||||
self.logger.debug(f"all spectra from {self.path} successfully loaded")
|
||||
if len(ind) == 1:
|
||||
|
||||
@@ -4,6 +4,7 @@ scgenerator module but some function may be used in any python program
|
||||
|
||||
"""
|
||||
|
||||
from argparse import ArgumentTypeError
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import re
|
||||
@@ -214,6 +215,17 @@ def pretty_format_value(name: str, value) -> str:
|
||||
return getattr(BareParams, name).display(value)
|
||||
|
||||
|
||||
def pretty_format_from_file_name(name: str) -> str:
|
||||
s = name.split(PARAM_SEPARATOR)
|
||||
out = []
|
||||
for key, value in zip(s[::2], s[1::2]):
|
||||
try:
|
||||
out.append(getattr(BareParams, key).display(float(value)))
|
||||
except (AttributeError, ValueError):
|
||||
out.append(key + PARAM_SEPARATOR + value)
|
||||
return PARAM_SEPARATOR.join(out)
|
||||
|
||||
|
||||
def variable_iterator(config: BareConfig) -> Iterator[Tuple[List[Tuple[str, Any]], BareParams]]:
|
||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||
@@ -282,3 +294,21 @@ def override_config(new: Dict[str, Any], old: BareConfig = None) -> BareConfig:
|
||||
for k in new:
|
||||
variable.pop(k, None) # remove old ones
|
||||
return replace(old, variable=variable, **{k: None for k in variable}, **new)
|
||||
|
||||
|
||||
def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray:
|
||||
threshold = y.min() + rel_thr * (y.max() - y.min())
|
||||
above_threshold = y > threshold
|
||||
ind = np.argsort(x)
|
||||
valid_ind = [
|
||||
np.array(list(g)) for k, g in itertools.groupby(ind, key=lambda i: above_threshold[i]) if k
|
||||
]
|
||||
ind_above = sorted(valid_ind, key=lambda el: len(el), reverse=True)[0]
|
||||
width = len(ind_above)
|
||||
return np.concatenate(
|
||||
(
|
||||
np.arange(max(ind_above[0] - width, 0), ind_above[0]),
|
||||
ind_above,
|
||||
np.arange(ind_above[-1] + 1, min(len(y), ind_above[-1] + width)),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -421,7 +421,7 @@ class BareParams:
|
||||
dico : dict
|
||||
dictionary
|
||||
"""
|
||||
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets"]
|
||||
forbiden_keys = ["w_c", "w_power_fact", "field_0", "spec_0", "w", "t", "z_targets", "l"]
|
||||
types = (np.ndarray, float, int, str, list, tuple, dict)
|
||||
out = {}
|
||||
for key, value in dico.items():
|
||||
|
||||
Reference in New Issue
Block a user