working on old/new cq solver

This commit is contained in:
Benoît Sierro
2024-01-24 15:24:46 +01:00
parent 820dbbdea5
commit 400ae2fd48
10 changed files with 276 additions and 15 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
build/
.DS_store .DS_store
.idea .idea
.conda-env .conda-env

View File

@@ -0,0 +1,55 @@
import multiprocessing
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import scgenerator as sc
PARAMS = Path("./examples/noisy.toml")
SEED = 2564
def path(i) -> Path:
return Path(f"build/{Path(__file__).stem}_{i}.zip")
def propagate(i):
np.random.seed(SEED + i)
params = sc.Parameters.load("./examples/noisy.toml")
sc.propagation(path(i), params.compile()).simulate()
def propagate_all(n):
to_do = [i for i in range(n) if not path(i).exists()]
with multiprocessing.Pool(4) as pool:
pool.map(propagate, to_do)
spec, props = sc.propagation_series([path(i) for i in range(n)])
return spec, props
def main():
n = 1
spec, props = propagate_all(n)
print(spec.shape)
wl, ind, _ = sc.PlotRange(
spec.wl_disp[spec.wl_disp > 0].min() * 1e9,
props.parameters.wavelength_window[1] * 1e9,
"nm",
).sort_axis(spec.wl_disp)
for i in range(spec.shape[1]):
fig, (left, right) = plt.subplots(1, 2)
for s in spec[:, i].time_int:
left.plot(spec.t, s)
for s in spec[:, i].wl_int:
right.plot(wl, s[ind])
plt.show()
if __name__ == "__main__":
main()

26
examples/noisy.toml Normal file
View File

@@ -0,0 +1,26 @@
name = "Sierro2020 Noisy"
width = 50e-15
shape = "gaussian"
peak_power = 100e3
wavelength = 1050e-9
gamma = 0.018
beta2_coefficients = [
1.001190e-26,
-2.131124e-41,
3.286793e-55,
-1.290523e-69,
1.047255e-84,
1.696410e-98,
-9.261236e-113,
2.149311e-127,
-2.028394e-142,
]
length = 0.1
raman_type = "measured"
input_transmission = 1.0
wavelength_window = [550e-9, 2000e-9]
t_num = 8192
quantum_noise = true
z_num = 21

View File

@@ -40,7 +40,6 @@ MANDATORY_PARAMETERS = {
"spec_0", "spec_0",
"c_to_a_factor", "c_to_a_factor",
"field_0", "field_0",
"mean_power",
"input_transmission", "input_transmission",
"z_targets", "z_targets",
"length", "length",

View File

@@ -559,7 +559,7 @@ envelope_rules = default_rules + [
Rule("c_to_a_factor", pulse.c_to_a_factor), Rule("c_to_a_factor", pulse.c_to_a_factor),
# Dispersion # Dispersion
Rule("beta2_coefficients", fiber.auto_dispersion_coefficients), Rule("beta2_coefficients", fiber.auto_dispersion_coefficients),
Rule("beta2_coefficients", fiber.handle_dispersion_paramter), Rule("beta2_coefficients", fiber.handle_dispersion_parameter),
Rule("beta2_arr", fiber.dispersion_from_coefficients), Rule("beta2_arr", fiber.dispersion_from_coefficients),
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]), Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
Rule("beta2", lambda beta2_arr, w0_ind: beta2_arr[w0_ind]), Rule("beta2", lambda beta2_arr, w0_ind: beta2_arr[w0_ind]),

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import datetime as datetime_module import datetime as datetime_module
import json import json
import os import os
import tomllib
import warnings import warnings
from copy import copy, deepcopy from copy import copy, deepcopy
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
@@ -429,6 +430,17 @@ class Parameters:
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime)) datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
version: str = Parameter(string) version: str = Parameter(string)
@classmethod
def from_toml(cls, s: str) -> Parameters:
decoded = tomllib.loads(s)
extras = set(decoded) - cls._p_names
if len(extras) > 0:
warnings.warn(f"extra keys (ignored) in parameter json: {extras!r}")
for e in extras:
decoded.pop(e)
return cls(**decoded)
@classmethod @classmethod
def from_json(cls, s: str) -> Parameters: def from_json(cls, s: str) -> Parameters:
decoded = json.loads(s, object_hook=custom_decode_hook) decoded = json.loads(s, object_hook=custom_decode_hook)
@@ -442,7 +454,12 @@ class Parameters:
@classmethod @classmethod
def load(cls, path: os.PathLike) -> Parameters: def load(cls, path: os.PathLike) -> Parameters:
return cls.from_json(Path(path).read_text()) path = Path(path)
if path.suffix == ".toml":
decode = cls.from_toml
else:
decode = cls.from_json
return decode(path.read_text())
def __repr__(self) -> str: def __repr__(self) -> str:
return "Parameter(" + ", ".join(self.__repr_list__()) + ")" return "Parameter(" + ", ".join(self.__repr_list__()) + ")"

View File

@@ -63,7 +63,7 @@ def dispersion_slope_to_beta3(
) )
def handle_dispersion_paramter( def handle_dispersion_parameter(
wavelength: float, dispersion_parameter: float, dispersion_slope: float | None = None wavelength: float, dispersion_parameter: float, dispersion_slope: float | None = None
) -> list[float]: ) -> list[float]:
coeffs = [dispersion_parameter_to_beta2(dispersion_parameter, wavelength)] coeffs = [dispersion_parameter_to_beta2(dispersion_parameter, wavelength)]

View File

@@ -230,7 +230,7 @@ def create_zoom_axis(
return inset return inset
def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs): def corner_annotation(text, ax, position="tl", pts_x=8, pts_y=8, **text_kwargs):
"""puts an annotatin in a corner of an ax """puts an annotatin in a corner of an ax
Parameters Parameters
---------- ----------
@@ -250,24 +250,26 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
# yoff = length(ylim) * rel_y_offset # yoff = length(ylim) * rel_y_offset
if position[0] == "t": if position[0] == "t":
y = 1 - rel_y_offset y = 1
pts_y = -pts_y
va = "top" va = "top"
else: else:
y = 0 + rel_y_offset y = 0
va = "bottom" va = "bottom"
if position[1] == "l": if position[1] == "l":
x = 0 + rel_x_offset x = 0
ha = "left" ha = "left"
else: else:
x = 1 - rel_x_offset x = 1
pts_x = -pts_x
ha = "right" ha = "right"
ax.annotate( ax.annotate(
text, text,
(x, y), (x, y),
(x, y), (pts_x, pts_y),
xycoords="axes fraction", xycoords="axes fraction",
textcoords="axes fraction", textcoords="offset points",
verticalalignment=va, verticalalignment=va,
horizontalalignment=ha, horizontalalignment=ha,
**text_kwargs, **text_kwargs,

View File

@@ -7,7 +7,7 @@ import warnings
import zipfile import zipfile
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Iterator, Sequence from typing import Any, Callable, Iterator, Sequence
import numba import numba
import numpy as np import numpy as np
@@ -16,6 +16,7 @@ from scgenerator.io import TimedMessage
from scgenerator.logger import get_logger from scgenerator.logger import get_logger
from scgenerator.math import abs2 from scgenerator.math import abs2
from scgenerator.operators import SpecOperator, VariableQuantity from scgenerator.operators import SpecOperator, VariableQuantity
from scgenerator.physics.pulse import photon_number
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -120,6 +121,130 @@ def pi_step_factor(error: float, last_error: float, order: int, eps: float = 0.8
return 1 + np.arctan(fac - 1) return 1 + np.arctan(fac - 1)
def solve_rk4ip_cq(
spec: np.ndarray,
linear: VariableQuantity,
nonlinear: SpecOperator,
z_max: float,
atol: float,
rtol: float,
safety: float,
cq: Callable[[np.ndarray], float],
h_const: float | None = None,
targets: Sequence[float] | None = None,
) -> Iterator[tuple[np.ndarray, dict[str, Any]]]:
"""
Solve the GNLSE using a Runge-Kutta of order 4 in the interaction picture. Error estimation is
done via photon number estimation.
Parameters
----------
spec : np.ndarray
initial spectrum
linear : Operator
linear operator
nonlinear : Operator
nonlinear operator
z_max : float
stop propagation when z >= z_max (the last step is not guaranteed to be exactly on z_max)
atol : float
absolute tolerance
rtol : float
relative tolerance
safety : float
safety factor when computing new step size
h_const : float | None, optional
constant step size to use, by default None (automatic step size based on atol and rtol)
Yields
------
np.ndarray
last computed spectrum
dict[str, Any]
stats about the last step, including `z`
"""
if h_const is not None:
h = h_const
const_step_size = True
else:
h = 0.000664237859 # from Luna
const_step_size = False
z = 0
stats = {}
rejected = []
if targets is not None:
if len(targets) <= 1:
return
targets = list(sorted(set(targets)))
z = targets[0]
if not const_step_size:
h = min(h, (targets[1] - targets[0]) / 2)
targets.pop(0)
step_ind = 0
last_cq = cq(abs2(spec))
error = 0
last_error = 0
store_next = False
def stats():
return dict(z=z, rejected=rejected.copy(), error=error, h=h)
yield spec, stats() | dict(h=0)
while True:
expD = np.exp(h * 0.5 * linear(z))
A_I = expD * spec
k1 = expD * (h * nonlinear(spec, z))
k2 = h * nonlinear(A_I + 0.5 * k1, z + 0.5 * h)
k3 = h * nonlinear(A_I + 0.5 * k2, z + 0.5 * h)
k4 = h * nonlinear(expD * (A_I + k3), z + h)
new_spec = expD * (A_I + k1 / 6 + k2 / 3 + k3 / 3) + k4 / 6
new_cq = cq(abs2(new_spec))
error = abs(last_cq - new_cq) / last_cq
print(h, error)
if error == 0: # solution is exact if no nonlinerity is included
next_h_factor = 1.5
elif 0 < error <= rtol:
next_h_factor = safety * pi_step_factor(error, last_error, 4, 0.8)
else:
next_h_factor = max(0.1, safety * error ** (-0.25))
if const_step_size or error <= 1:
spec = new_spec
last_cq = new_cq
z += h
step_ind += 1
last_error = error
if targets is None or store_next:
if targets is not None:
targets.pop(0)
yield spec, stats()
rejected.clear()
if z >= z_max:
return
if const_step_size:
continue
else:
rejected.append((h, error))
logger.info(f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}")
h = h * next_h_factor
if targets is not None and z + h > targets[0]:
h = targets[0] - z
store_next = True
else:
store_next = False
def solve43( def solve43(
spec: np.ndarray, spec: np.ndarray,
linear: VariableQuantity, linear: VariableQuantity,
@@ -259,6 +384,7 @@ def integrate(
safety: float = 0.9, safety: float = 0.9,
targets: Sequence[float] | None = None, targets: Sequence[float] | None = None,
) -> SimulationResult: ) -> SimulationResult:
"""legacy function"""
spec0 = initial_spectrum.copy() spec0 = initial_spectrum.copy()
all_spectra = [] all_spectra = []
stats = defaultdict(list) stats = defaultdict(list)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
import os import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property, partial
from pathlib import Path from pathlib import Path
from typing import Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
import numpy as np import numpy as np
import scipy.signal as ss import scipy.signal as ss
from tqdm import tqdm
from scgenerator import math from scgenerator import math
from scgenerator.io import ( from scgenerator.io import (
@@ -22,6 +23,7 @@ from scgenerator.io import (
from scgenerator.logger import get_logger from scgenerator.logger import get_logger
from scgenerator.parameter import Parameters from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units from scgenerator.physics import pulse, units
from scgenerator.solver import solve43, solve_rk4ip_cq
PARAMS_FN = "params.json" PARAMS_FN = "params.json"
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None) ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
@@ -208,7 +210,7 @@ class Spectrum(np.ndarray):
returns the coherence of the spectrum, computed by collapsing axis `axis` and aligned returns the coherence of the spectrum, computed by collapsing axis `axis` and aligned
on the wavelength grid on the wavelength grid
""" """
return pulse.g12(self, axis)[..., self.wl_order] return pulse.g12(self, axis)[..., self.l_order]
def spectrogram( def spectrogram(
self, self,
@@ -366,6 +368,39 @@ class Propagation(Generic[ParamsOrNone]):
def load_all(self) -> Spectrum: def load_all(self) -> Spectrum:
return self._load_slice(slice(None)) return self._load_slice(slice(None))
def simulate(self) -> Propagation:
"""
Run the simulations as specified by the Parameters
Returns
-------
Propagation
returns itself, so that you can use this method as a 'builder' method:
`popagation = sc.propagation(parameters).simulate()`
"""
params = self.parameters
with warnings.catch_warnings(), tqdm(total=params.z_num) as pbar:
warnings.filterwarnings("error")
for spec, _ in solve43(
params.spec_0,
params.linear_operator,
params.nonlinear_operator,
params.length,
params.tolerated_error,
params.tolerated_error,
0.95,
targets=params.z_targets,
# cq=partial(
# pulse.photon_number,
# w=params.w,
# dw=params.w[1] - params.w[0],
# gamma=params.gamma,
# ),
):
self.append(spec)
pbar.update()
return self
@dataclass @dataclass
class PropagationCollection: class PropagationCollection: