bunch of improvements

This commit is contained in:
2024-06-11 08:50:45 +02:00
parent 376abf628f
commit d8deeedfb6
10 changed files with 65 additions and 21 deletions

View File

@@ -122,11 +122,14 @@ class Cache:
stuff = pickle.loads(fn.read_bytes())
return stuff
@check_exists
def save(self, key: str, value: Any):
self.save_raw(key, pickle.dumps(value))
@check_exists
def save_raw(self, key: str, value: bytes):
key = normalize_path(key)
fn = self.dir / key
fn.write_bytes(pickle.dumps(value))
fn.write_bytes(value)
@check_exists
def reset(self):

View File

@@ -528,10 +528,11 @@ default_rules: list[Rule] = [
Rule("gamma", lambda gamma_arr: gamma_arr[0], priorities=-1),
Rule("gamma", fiber.gamma_parameter),
Rule("gamma_arr", fiber.gamma_parameter, ["n2", "w0", "effective_area_arr"]),
Rule("gamma_arr", lambda t_num, gamma: np.ones(t_num) * gamma, priorities=-1),
# Raman
Rule(["hr_w", "raman_fraction"], fiber.delayed_raman_w),
Rule("raman_fraction", fiber.raman_fraction),
Rule("raman_fraction", lambda: 0, priorities=-1),
Rule("raman_fraction", lambda: 0.0, priorities=-1),
# loss
Rule("alpha_arr", fiber.scalar_loss),
Rule("alpha_arr", fiber.safe_capillary_loss, conditions=dict(loss="capillary")),

View File

@@ -131,7 +131,7 @@ def sigmoid(x):
return 1 / (np.exp(-x) + 1)
def to_dB(arr: np.ndarray, ref=None, axis=None) -> np.ndarray:
def to_dB(arr: np.ndarray, ref=None, axis=None, default: float | None = None) -> np.ndarray:
"""
converts unitless values in dB
@@ -151,17 +151,26 @@ def to_dB(arr: np.ndarray, ref=None, axis=None) -> np.ndarray:
array in dB
"""
if axis is not None and arr.ndim > 1 and ref is None:
return np.apply_along_axis(to_dB, axis, arr)
return np.apply_along_axis(to_dB, axis, arr, default=default)
out = np.ones_like(arr)
if default is not None and math.isfinite(default):
out[:] *= default
else:
out *= np.nan
if ref is None:
ref = np.max(arr)
above_0 = arr > 0
if not np.any(above_0) or ref <= 0:
warnings.warn("invalid array to convert to dB, returning 0 instead")
warnings.warn(f"invalid array to convert to dB, returning {default} instead")
if default is None:
out *= np.nan
return out
m = arr / ref
return 10 * np.log10(m, out=out * (10 * np.log10(m[above_0].min())), where=above_0)
if default is None:
out *= 10 * np.log10(m[above_0].min())
return 10 * np.log10(m, out=out, where=above_0)
def u_nm(n, m):

View File

@@ -518,7 +518,7 @@ def no_linear() -> VariableQuantity:
def build_envelope_nonlinear(
w: np.ndarray, gamma: float, self_steepening: bool = True, raman: str | None = "measured"
) -> tuple[VariableQuantity, SpecOperator]:
) -> SpecOperator:
w0 = w[0]
t = math.iwspace(w)
w_c = w - w0
@@ -528,9 +528,8 @@ def build_envelope_nonlinear(
ss_op = constant_quantity(0.0)
if gamma != 0:
raman_frac = fiber.raman_fraction(raman) if raman else 0.0
hr_w, raman_frac = fiber.delayed_raman_w(t, raman)
spm_op = envelope_spm(raman_frac)
hr_w = fiber.delayed_raman_w(t, raman)
raman_op = envelope_raman(hr_w, raman_frac, sfft.fft, sfft.ifft)
gamma_op = constant_quantity(np.ones(len(w)) * gamma)
else:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import datetime as datetime_module
import json
from operator import pos
import os
import tomllib
import warnings
@@ -360,6 +361,7 @@ class Parameters:
shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 10000e-9), unit="m")
intensity_noise: float = Parameter(in_range_incl(0, 1), unit="%")
ase_level: float = Parameter(positive(*number), default=1.0)
noise_correlation: float = Parameter(in_range_incl(-10, 10))
width: float = Parameter(positive(*number), unit="s")
t0: float = Parameter(positive(*number), unit="s")

View File

@@ -620,7 +620,7 @@ def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) ->
wl, effective_area = effective_area_file.load_arrays(
("wavelength", "wl"), ("A_eff", "effective_area")
)
return np.interp(l, wl, effective_area, left=0, right=0)
return np.interp(l, wl, effective_area)
def load_custom_dispersion(
@@ -777,8 +777,7 @@ def delayed_raman_t(t: np.ndarray, raman_type: str) -> np.ndarray:
hr_arr = np.interp(t, t_stored, hr_arr_stored, left=0, right=0)
else:
print("invalid raman response function, aborting")
quit()
raise ValueError("invalid raman response function, aborting")
return hr_arr

View File

@@ -538,7 +538,13 @@ def sech_pulse(t: np.ndarray, t0: float, P0: float, chirp: float = 0.0):
def gaussian_pulse(t: np.ndarray, t0: float, P0: float, chirp: float = 0.0):
return np.sqrt(P0) * np.exp(-((t / t0) ** 2) * (1 - 1j * chirp))
"""
for unchirped pulses, the following, slower, formula is equivalent
```
return np.sqrt(P0) * 4 ** (-(t / fwhm) ** 2)
```
"""
return np.sqrt(P0) * np.exp((1j * chirp - 1.0) * ((t / t0) ** 2))
def pulse_envelope(
@@ -659,10 +665,11 @@ def finalize_pulse(
quantum_noise: ShotNoiseParameter,
w: np.ndarray,
input_transmission: float,
ase_level: float = 1.0,
) -> np.ndarray:
pre_field_0 *= np.sqrt(input_transmission)
if quantum_noise:
pre_field_0 = pre_field_0 + np.fft.ifft(shot_noise(w, *quantum_noise))
pre_field_0 = pre_field_0 + np.fft.ifft(np.sqrt(ase_level) * shot_noise(w, *quantum_noise))
return pre_field_0
@@ -1379,3 +1386,13 @@ def remove_2nd_order_dispersion2(
def gdd(w: np.ndarray, gdd: float) -> np.ndarray:
return np.exp(0.5j * w**2 * gdd)
def mask_with_noise(
w: np.ndarray,
spectrum: T,
mask: np.ndarray,
sn_params: ShotNoiseParameter = ShotNoiseParameter(False, False),
) -> T:
sn = shot_noise(w, *sn_params)
return spectrum * mask + np.sqrt(1.0 - mask**2) * sn

View File

@@ -46,6 +46,13 @@ def alt_color(c, fac: float):
return (*colorsys.hsv_to_rgb(h, s, min(1, v * fac)), alpha)
def lighten_rgb(c, fac: float):
if isinstance(c, (list, tuple)) or (isinstance(c, np.ndarray) and c.ndim > 1):
return np.array([alt_color(el, fac) for el in c])
*color, a = ColorConverter.to_rgba(c)
return (*(min(max(el + (1.0 - el) * fac, 0.0), 1.0) for el in color), a)
def get_extent(x, y, facx=1, facy=1):
"""
returns the extent 4-tuple needed for imshow, aligning each pixel
@@ -1005,7 +1012,9 @@ def uniform_axis(
return new_axis, values.squeeze()
def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarray:
def apply_log(
values: np.ndarray, log: Union[str, bool, float, int], default: float = 1.0
) -> np.ndarray:
"""
apply log transform
@@ -1034,15 +1043,15 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
if log is not False:
if isinstance(log, (float, int, np.floating, np.integer)) and log is not True:
values = math.to_dB(values, ref=log)
values = math.to_dB(values, ref=log, default=default)
elif log == "2D":
values = math.to_dB(values, ref=values.max())
values = math.to_dB(values, ref=values.max(), default=default)
elif log == "1D" or log is True:
values = math.to_dB(values, axis=1)
values = math.to_dB(values, axis=1, default=default)
elif log == "smooth 1D":
ref = np.max(values, axis=1)
ind = np.argmax((ref[:-1] - ref[1:]) < 0)
values = math.to_dB(values, ref=np.max(ref[ind:]))
values = math.to_dB(values, ref=np.max(ref[ind:]), default=default)
elif log == "linear 1D":
values = (values.T / values.max(axis=1)).T
else:

View File

@@ -151,6 +151,9 @@ class Spectrum(np.ndarray):
+ self.astype(np.complex128, subok=False).tobytes()
)
def __reduce__(self) -> tuple[type, tuple[bytes]]:
return self.__class__.from_bytes, (bytes(self),)
def tobytes(self, *args, **kwargs) -> bytes:
warnings.warn(
"Calling `tobytes` (numpy function) on Spectrum object. Did you mean `bytes(obj)`?"
@@ -205,6 +208,7 @@ class Spectrum(np.ndarray):
def energy(self) -> np.ndarray:
return np.trapz(self.time_int, x=self.t, axis=-1)
# FIX: rename with bandpass and make a generic wrapper of pulse.mask_with_noise
def mask_wl(
self,
pos: float,
@@ -263,7 +267,7 @@ class Spectrum(np.ndarray):
gate_width: float = 100e-15,
wavelength: bool = True,
autocrop: bool | float = 1e-5,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
dt = self.t[1] - self.t[0]
sigma = gate_width / (2 * np.sqrt(2 * np.log(2))) / dt
nperseg = int(sigma) * 16

View File

@@ -32,3 +32,4 @@ def test_chirp():
c2 = np.unwrap(np.angle(f2))
c2 -= c2.min()
assert pytest.approx(c2, rel=1e-2, abs=1e-4) == t2**2
assert pytest.approx(sc.abs2(f2)) == 16 ** (-((t2 / np.sqrt(2 * np.log(2))) ** 2))