units refactor

This commit is contained in:
Benoît Sierro
2023-09-25 16:14:38 +02:00
parent 2f9b5005a5
commit 4b5563bf54
13 changed files with 100 additions and 77 deletions

View File

@@ -459,8 +459,8 @@ default_rules: list[Rule] = [
"dynamic_dispersion",
lambda pressure: isinstance(pressure, (list, tuple, np.ndarray)),
),
Rule("w0", units.m, ["wavelength"]),
Rule("l", units.m.inv, ["w"]),
Rule("w0", units.m_rads, ["wavelength"]),
Rule("l", units.m_rads, ["w"]),
Rule("w0_ind", math.argclosest, ["w_for_disp", "w0"]),
Rule("w_num", len, ["w"]),
Rule("dw", lambda w: w[1] - w[0]),
@@ -490,7 +490,7 @@ default_rules: list[Rule] = [
Rule("soliton_length", pulse.soliton_length),
Rule("c_to_a_factor", lambda: 1, priorities=-1),
# Fiber Dispersion
Rule("w_for_disp", units.m, ["wl_for_disp"]),
Rule("w_for_disp", units.m_rads, ["wl_for_disp"]),
Rule("gas_info", materials.Gas),
Rule("chi_gas", lambda gas_info, wl_for_disp: gas_info.sellmeier.chi(wl_for_disp)),
Rule("n_gas_2", materials.n_gas_2),

View File

@@ -1,30 +1,30 @@
"""
series of helper functions
"""
from contextlib import nullcontext
import warnings
from pathlib import Path
from typing import Any, Mapping
import matplotlib.pyplot as plt
import numpy as np
import tomli
from scgenerator.math import all_zeros
from scgenerator.math import abs2, all_zeros
from scgenerator.parameter import Parameters
from scgenerator.physics.fiber import beta2, n_eff_hasan, n_eff_marcatili
from scgenerator.physics.materials import n_gas_2
from scgenerator.physics.units import c, nm
from scgenerator.physics.units import c, nm_rads
from scgenerator.solver import solve43
from scgenerator.spectra import Propagation, propagation
try:
from tqdm import tqdm
except ModuleNotFoundError:
tqdm = None
__all__ = ["capillary_dispersion", "capillary_zdw", "revolver_dispersion", "quick_sim", "w_from_wl"]
__all__ = ["capillary_dispersion", "capillary_zdw", "revolver_dispersion", "compute", "w_from_wl"]
def w_from_wl(wl_min_nm: float, wl_max_nm: float, n: int) -> np.ndarray:
return np.linspace(nm(wl_max_nm), nm(wl_min_nm), n)
return np.linspace(nm_rads(wl_max_nm), nm_rads(wl_min_nm), n)
def capillary_dispersion(
@@ -162,40 +162,29 @@ def extend_axis(axis: np.ndarray) -> np.ndarray:
return axis
def quick_sim(params: dict[str, Any] | Parameters, **_params: Any) -> tuple[Parameters, np.ndarray]:
"""
run a quick simulation
def compute(parameters: Parameters, overwrite: bool = False) -> Propagation:
name = Path(parameters.compute("name")).stem + ".zip"
Parameters
----------
params : dict[str, Any] | Parameters | os.PathLike
a dict of parameters, a Parameters obj or a path to a toml file from which to read the
parameters
_params : Any
override the initial parameters with these keyword arguments
prop_params = parameters.compile()
prop = propagation(name, prop_params, bundle_data=True, overwrite=overwrite)
Example
-------
```
params, sim = quick_sim("long_fiber.toml", energy=10e-6)
```
"""
if isinstance(params, Mapping):
params = Parameters(**(params | _params))
else:
params = Parameters(**(tomli.loads(Path(params).read_text()) | _params))
sim = RK4IP(params)
if tqdm:
pbar = tqdm(total=params.z_num)
def callback(_, __):
with warnings.catch_warnings(), tqdm(total=prop_params.z_num) as pbar:
warnings.filterwarnings("error")
for i, (spec, new_stat) in enumerate(
solve43(
prop_params.spec_0,
prop_params.linear_operator,
prop_params.nonlinear_operator,
prop_params.length,
prop_params.tolerated_error,
prop_params.tolerated_error,
0.9,
targets=prop_params.z_targets,
)
):
pbar.update()
plt.plot(prop_params.t, abs2(prop_params.ifft(spec)))
plt.show()
prop.append(spec)
else:
pbar = nullcontext()
callback = None
with pbar:
return params, sim.run(progress_callback=callback)
return prop

View File

@@ -40,7 +40,7 @@ def total_extent(*vec: np.ndarray) -> float:
def span_above(arr: np.ndarray, threshold: float) -> tuple[int, int]:
"""returns the first and last index where the array is above the specified threshold"""
ind = np.where(arr >= threshold)[0]
ind = np.where(arr >= threshold)[-1]
return np.min(ind), np.max(ind)

View File

@@ -228,4 +228,4 @@ def segments(signal: np.ndarray, num_segments: int) -> np.ndarray:
def quantum_noise_limit(wavelength: float, power: float) -> float:
return units.m(wavelength) * units.hbar * 2 / power
return units.m_rads(wavelength) * units.hbar * 2 / power

View File

@@ -15,7 +15,7 @@ T = TypeVar("T")
def group_delay_to_gdd(wavelength: np.ndarray, group_delay: np.ndarray) -> np.ndarray:
w = units.m.inv(wavelength)
w = units.m_rads(wavelength)
gdd = np.gradient(group_delay, w)
return gdd
@@ -48,7 +48,7 @@ def material_dispersion(
beta2 as function of wavelength
"""
w = units.m(wavelengths)
w = units.m_rads(wavelengths)
sellmeier = materials.Sellmeier.load(material)
n_gas_2 = sellmeier.n_gas_2(wavelengths, temperature, pressure)
@@ -84,7 +84,7 @@ def find_optimal_depth(
w = w_c + w0
disp = np.zeros(len(w))
ind = w > (w0 / 10)
disp[ind] = material_dispersion(units.m.inv(w[ind]), material)
disp[ind] = material_dispersion(units.m_rads(w[ind]), material)
def propagate(z):
return spectrum * np.exp(-0.5j * disp * w_c**2 * z)
@@ -124,6 +124,6 @@ def propagate_field(
propagated field
"""
w_c = math.wspace(t)
l = units.m(w_c + units.nm(center_wl_nm))
l = units.m_rads(w_c + units.nm_rads(center_wl_nm))
disp = material_dispersion(l, material)
return np.fft.ifft(np.fft.fft(field) * np.exp(0.5j * disp * w_c**2 * z))

View File

@@ -141,7 +141,7 @@ def plasma_dispersion(wl_for_disp, number_density, simple=False):
"""
e2_me_e0 = 3182.60735 # e^2 /(m_e * epsilon_0)
w = units.m(wl_for_disp)
w = units.m_rads(wl_for_disp)
if simple:
w_pl = number_density * e2_me_e0
return -(w_pl**2) / (c * w**2)
@@ -544,7 +544,7 @@ def HCPCF_dispersion(
beta2 as function of wavelength
"""
w = units.m(wl_for_disp)
w = units.m_rads(wl_for_disp)
n_gas_2 = mat.Sellmeier.load(gas_name).n_gas_2(wl_for_disp, temperature, pressure)
n_eff_func = dict(

View File

@@ -358,7 +358,7 @@ def delta_gas(w: np.ndarray, gas: Gas) -> np.ndarray:
delta_t
since 2 gradients are computed, it is recommended to exclude the 2 extremum values
"""
chi = gas.sellmeier.chi(units.m.inv(w))
chi = gas.sellmeier.chi(units.m_rads(w))
N0 = gas.number_density_van_der_waals()
dchi_dw = np.gradient(chi, w)

View File

@@ -468,7 +468,7 @@ def correct_wavelength(init_wavelength: float, w_c: np.ndarray, field_0: np.ndar
to field_0 is located at init_wavelength
"""
delta_w = w_c[np.argmax(math.abs2(np.fft.fft(field_0)))]
return units.m.inv(units.m(init_wavelength) - delta_w)
return units.m_rads(units.m_rads(init_wavelength) - delta_w)
def E0_to_P0(energy: float, t0: float, shape: str):

View File

@@ -30,7 +30,7 @@ types are "WL", "FREQ", "AFREQ", "TIME", "PRESSURE", "TEMPERATURE", "OTHER"
"""
_T = TypeVar("_T")
_UT = Callable[[_T], _T]
PRIMARIES = dict(WL="Rad/s", FREQ="Rad/s", AFREQ="Rad/s", TIME="s", PRESSURE="Pa", TEMPERATURE="K")
PRIMARIES = dict(WL="m", FREQ="Rad/s", AFREQ="Rad/s", TIME="s", PRESSURE="Pa", TEMPERATURE="K")
class UnitMap(dict):
@@ -128,17 +128,27 @@ class unit:
@unit("WL", r"Wavelength λ (m)")
def m(l: _T) -> _T:
return 2 * pi * c / l
return l
@unit("WL", r"Wavelength λ (nm)")
def nm(l: _T) -> _T:
return 2 * pi * c / (l * 1e-9)
return l * 1e-9
@nm.inverse
def nm_inv(l: _T) -> _T:
return l * 1e9
@unit("WL", r"Wavelength λ (μm)")
def um(l: _T) -> _T:
return 2 * pi * c / (l * 1e-6)
return l * 1e-6
@nm.inverse
def um_inv(l: _T) -> _T:
return l * 1e6
@unit("FREQ", r"Frequency $f$ (Hz)")
@@ -321,6 +331,18 @@ def no_unit(x: _T) -> _T:
return x
def nm_rads(nm: _T) -> _T:
return 2e9 * np.pi * c / nm
def um_rads(um: _T) -> _T:
return 2e6 * np.pi * c / um
def m_rads(m: _T) -> _T:
return 2 * np.pi * c / m
def get_unit(unit: Union[str, Callable]) -> Callable[[float], float]:
if isinstance(unit, str):
return units_map[unit]

View File

@@ -433,8 +433,8 @@ def plot_2D(
def transform_2D_propagation(
values: np.ndarray,
plt_range: Union[PlotRange, RangeType],
x_axis: np.ndarray = None,
y_axis: np.ndarray = None,
x_axis: np.ndarray,
y_axis: np.ndarray,
log: Union[int, float, bool, str] = "1D",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
@@ -470,10 +470,8 @@ def transform_2D_propagation(
if values.ndim != 2:
raise ValueError(f"shape was {values.shape}. Can only plot 2D array")
is_complex, plt_range = prep_plot_axis(values, plt_range)
if is_complex or any(values.ravel() < 0):
if is_complex or np.any(values < 0):
values = abs2(values)
# if params.full_field and plt_range.unit.type == "TIME":
# values = envelope_2d(x_axis, values)
x_axis, values = uniform_axis(x_axis, values, plt_range)
y_axis, values.T[:] = uniform_axis(y_axis, values.T, None)
@@ -792,7 +790,7 @@ def transform_1D_values(
new_axis, ind, ext = sort_axis(x_axis, plt_range)
values = values[ind]
if plt_range.unit.type == "WL" and plt_range.conserved_quantity:
values = units.to_WL(values, units.m.inv(plt_range.unit(new_axis)))
values = units.to_WL(values, units.m_rads(plt_range.unit(new_axis)))
if isinstance(spacing, (float, np.floating)):
tmp_axis = np.linspace(*span(new_axis), int(len(new_axis) / spacing))
@@ -1006,7 +1004,7 @@ def apply_log(values: np.ndarray, log: Union[str, bool, float, int]) -> np.ndarr
elif log == "2D":
values = math.to_dB(values, ref=values.max())
elif log == "1D" or log is True:
values = np.apply_along_axis(math.to_dB, -1, values)
values = math.to_dB(values)
elif log == "smooth 1D":
ref = np.max(values, axis=1)
ind = np.argmax((ref[:-1] - ref[1:]) < 0)
@@ -1160,10 +1158,14 @@ def summary_plot(
wl_range: PlotRange | None = None,
t_range: PlotRange | None = None,
db_min: float = -50.0,
axes: tuple[Axes, Axes] | None = None,
):
wl_int = specs.wl_int
time_int = specs.time_int
if z is None:
z = np.arange(specs.shape[0])
if wl_range is None:
imin, imax = math.span_above(wl_int, wl_int.max() * 1e-6)
wl_range = PlotRange(specs.wl_disp[imin] * 1e9, specs.wl_disp[imax] * 1e9, "nm")
@@ -1172,5 +1174,13 @@ def summary_plot(
imin, imax = math.span_above(time_int, time_int.max() * 1e-6)
t_range = PlotRange(specs.t[imin] * 1e15, specs.t[imax] * 1e15, "fs")
fig, (left, right) = plt.subplots(1, 2)
transform_2D_propagation(wl_int, wl_range, specs.w, z)
if axes is None:
_, (left, right) = plt.subplots(1, 2)
else:
left, right = axes
x, y, values = transform_2D_propagation(wl_int, wl_range, specs.wl_disp, z, log="1D")
left.imshow(values, extent=get_extent(x, y), origin="lower", aspect="auto", vmin=db_min)
x, y, values = transform_2D_propagation(time_int, t_range, specs.t, z, log=False)
right.imshow(values, extent=get_extent(x, y), origin="lower", aspect="auto", vmin=db_min)

View File

@@ -47,7 +47,7 @@ class Spectrum(np.ndarray):
# We first cast to be our class type
obj = np.asarray(input_array).view(cls)
# add the new attribute to the created instance
obj.order = np.argsort(w)
obj.w_order = np.argsort(w)
obj.w = w
if t is not None:
obj.t = t
@@ -55,6 +55,7 @@ class Spectrum(np.ndarray):
obj.t = math.iwspace(obj.w)
obj.ifft = ifft
obj.l = 2 * np.pi * units.c / obj.w
obj.l_order = np.argsort(obj.l)
if not (len(obj.w) == len(obj.t) == len(obj.l) == obj.shape[-1]):
raise ValueError(
@@ -72,7 +73,8 @@ class Spectrum(np.ndarray):
self.w = getattr(obj, "w", None)
self.t = getattr(obj, "t", None)
self.l = getattr(obj, "l", None)
self.order = getattr(obj, "order", None)
self.w_order = getattr(obj, "w_order", None)
self.l_order = getattr(obj, "l_order", None)
self.ifft = getattr(obj, "ifft", None)
def __getitem__(self, key) -> "Spectrum":
@@ -80,15 +82,15 @@ class Spectrum(np.ndarray):
@property
def wl_disp(self):
return self.l[self.order][::-1]
return self.l[self.l_order]
@property
def w_disp(self):
return self.w[self.order]
return self.w[self.w_order]
@property
def wl_int(self):
return units.to_WL(math.abs2(self), self.l)[self.order][::-1]
return units.to_WL(math.abs2(self), self.l)[..., self.l_order]
@property
def freq_int(self):
@@ -113,15 +115,15 @@ class Spectrum(np.ndarray):
)
* self
/ np.abs(self)
)[self.order][::-1]
)[..., self.l_order]
@property
def freq_amp(self):
return self[self.order]
return self[..., self.w_order[::-1]]
@property
def afreq_amp(self):
return self[self.order]
return self[..., self.w_order[::-1]]
@property
def time_amp(self):

View File

@@ -33,7 +33,7 @@ def test_simple():
evaluator.set(wavelength=800e-9, t_num=1024, dt=5e-15)
assert evaluator.compute("t") == pytest.approx(math.tspace(t_num=1024, dt=5e-15))
assert evaluator.compute("w0") == pytest.approx(units.nm(800))
assert evaluator.compute("w0") == pytest.approx(units.nm_rads(800))
def test_default_args_simple():

View File

@@ -17,7 +17,7 @@ def test_scaling():
def test_wl_dispersion():
t = sc.tspace(t_num=1 << 15, dt=3.8e-15)
w = sc.wspace(t)
wl = sc.units.m.inv(w + sc.units.nm(1546))
wl = sc.units.nm_rads(w + sc.units.nm_rads(1546)) * 1e-9
wl_disp, ind_disp = sc.fiber.lambda_for_envelope_dispersion(wl, (950e-9, 4000e-9))
assert all(np.diff(wl_disp) > 0)