working on old/new cq solver
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
|
build/
|
||||||
.DS_store
|
.DS_store
|
||||||
.idea
|
.idea
|
||||||
.conda-env
|
.conda-env
|
||||||
|
|||||||
55
examples/compute_coherence.py
Normal file
55
examples/compute_coherence.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import scgenerator as sc
|
||||||
|
|
||||||
|
PARAMS = Path("./examples/noisy.toml")
|
||||||
|
SEED = 2564
|
||||||
|
|
||||||
|
|
||||||
|
def path(i) -> Path:
|
||||||
|
return Path(f"build/{Path(__file__).stem}_{i}.zip")
|
||||||
|
|
||||||
|
|
||||||
|
def propagate(i):
|
||||||
|
np.random.seed(SEED + i)
|
||||||
|
params = sc.Parameters.load("./examples/noisy.toml")
|
||||||
|
|
||||||
|
sc.propagation(path(i), params.compile()).simulate()
|
||||||
|
|
||||||
|
|
||||||
|
def propagate_all(n):
|
||||||
|
to_do = [i for i in range(n) if not path(i).exists()]
|
||||||
|
with multiprocessing.Pool(4) as pool:
|
||||||
|
pool.map(propagate, to_do)
|
||||||
|
|
||||||
|
spec, props = sc.propagation_series([path(i) for i in range(n)])
|
||||||
|
|
||||||
|
return spec, props
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
n = 1
|
||||||
|
spec, props = propagate_all(n)
|
||||||
|
print(spec.shape)
|
||||||
|
|
||||||
|
wl, ind, _ = sc.PlotRange(
|
||||||
|
spec.wl_disp[spec.wl_disp > 0].min() * 1e9,
|
||||||
|
props.parameters.wavelength_window[1] * 1e9,
|
||||||
|
"nm",
|
||||||
|
).sort_axis(spec.wl_disp)
|
||||||
|
|
||||||
|
for i in range(spec.shape[1]):
|
||||||
|
fig, (left, right) = plt.subplots(1, 2)
|
||||||
|
for s in spec[:, i].time_int:
|
||||||
|
left.plot(spec.t, s)
|
||||||
|
for s in spec[:, i].wl_int:
|
||||||
|
right.plot(wl, s[ind])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
26
examples/noisy.toml
Normal file
26
examples/noisy.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
name = "Sierro2020 Noisy"
|
||||||
|
width = 50e-15
|
||||||
|
shape = "gaussian"
|
||||||
|
peak_power = 100e3
|
||||||
|
wavelength = 1050e-9
|
||||||
|
|
||||||
|
gamma = 0.018
|
||||||
|
beta2_coefficients = [
|
||||||
|
1.001190e-26,
|
||||||
|
-2.131124e-41,
|
||||||
|
3.286793e-55,
|
||||||
|
-1.290523e-69,
|
||||||
|
1.047255e-84,
|
||||||
|
1.696410e-98,
|
||||||
|
-9.261236e-113,
|
||||||
|
2.149311e-127,
|
||||||
|
-2.028394e-142,
|
||||||
|
]
|
||||||
|
length = 0.1
|
||||||
|
raman_type = "measured"
|
||||||
|
input_transmission = 1.0
|
||||||
|
|
||||||
|
wavelength_window = [550e-9, 2000e-9]
|
||||||
|
t_num = 8192
|
||||||
|
quantum_noise = true
|
||||||
|
z_num = 21
|
||||||
@@ -40,7 +40,6 @@ MANDATORY_PARAMETERS = {
|
|||||||
"spec_0",
|
"spec_0",
|
||||||
"c_to_a_factor",
|
"c_to_a_factor",
|
||||||
"field_0",
|
"field_0",
|
||||||
"mean_power",
|
|
||||||
"input_transmission",
|
"input_transmission",
|
||||||
"z_targets",
|
"z_targets",
|
||||||
"length",
|
"length",
|
||||||
|
|||||||
@@ -559,7 +559,7 @@ envelope_rules = default_rules + [
|
|||||||
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
Rule("c_to_a_factor", pulse.c_to_a_factor),
|
||||||
# Dispersion
|
# Dispersion
|
||||||
Rule("beta2_coefficients", fiber.auto_dispersion_coefficients),
|
Rule("beta2_coefficients", fiber.auto_dispersion_coefficients),
|
||||||
Rule("beta2_coefficients", fiber.handle_dispersion_paramter),
|
Rule("beta2_coefficients", fiber.handle_dispersion_parameter),
|
||||||
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
Rule("beta2_arr", fiber.dispersion_from_coefficients),
|
||||||
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
Rule("beta2", lambda beta2_coefficients: beta2_coefficients[0]),
|
||||||
Rule("beta2", lambda beta2_arr, w0_ind: beta2_arr[w0_ind]),
|
Rule("beta2", lambda beta2_arr, w0_ind: beta2_arr[w0_ind]),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import datetime as datetime_module
|
import datetime as datetime_module
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tomllib
|
||||||
import warnings
|
import warnings
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
@@ -429,6 +430,17 @@ class Parameters:
|
|||||||
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
datetime: datetime_module.datetime = Parameter(type_checker(datetime_module.datetime))
|
||||||
version: str = Parameter(string)
|
version: str = Parameter(string)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, s: str) -> Parameters:
|
||||||
|
decoded = tomllib.loads(s)
|
||||||
|
extras = set(decoded) - cls._p_names
|
||||||
|
if len(extras) > 0:
|
||||||
|
warnings.warn(f"extra keys (ignored) in parameter json: {extras!r}")
|
||||||
|
for e in extras:
|
||||||
|
decoded.pop(e)
|
||||||
|
|
||||||
|
return cls(**decoded)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, s: str) -> Parameters:
|
def from_json(cls, s: str) -> Parameters:
|
||||||
decoded = json.loads(s, object_hook=custom_decode_hook)
|
decoded = json.loads(s, object_hook=custom_decode_hook)
|
||||||
@@ -442,7 +454,12 @@ class Parameters:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: os.PathLike) -> Parameters:
|
def load(cls, path: os.PathLike) -> Parameters:
|
||||||
return cls.from_json(Path(path).read_text())
|
path = Path(path)
|
||||||
|
if path.suffix == ".toml":
|
||||||
|
decode = cls.from_toml
|
||||||
|
else:
|
||||||
|
decode = cls.from_json
|
||||||
|
return decode(path.read_text())
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
return "Parameter(" + ", ".join(self.__repr_list__()) + ")"
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ def dispersion_slope_to_beta3(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_dispersion_paramter(
|
def handle_dispersion_parameter(
|
||||||
wavelength: float, dispersion_parameter: float, dispersion_slope: float | None = None
|
wavelength: float, dispersion_parameter: float, dispersion_slope: float | None = None
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
coeffs = [dispersion_parameter_to_beta2(dispersion_parameter, wavelength)]
|
coeffs = [dispersion_parameter_to_beta2(dispersion_parameter, wavelength)]
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ def create_zoom_axis(
|
|||||||
return inset
|
return inset
|
||||||
|
|
||||||
|
|
||||||
def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0.05, **text_kwargs):
|
def corner_annotation(text, ax, position="tl", pts_x=8, pts_y=8, **text_kwargs):
|
||||||
"""puts an annotatin in a corner of an ax
|
"""puts an annotatin in a corner of an ax
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -250,24 +250,26 @@ def corner_annotation(text, ax, position="tl", rel_x_offset=0.05, rel_y_offset=0
|
|||||||
# yoff = length(ylim) * rel_y_offset
|
# yoff = length(ylim) * rel_y_offset
|
||||||
|
|
||||||
if position[0] == "t":
|
if position[0] == "t":
|
||||||
y = 1 - rel_y_offset
|
y = 1
|
||||||
|
pts_y = -pts_y
|
||||||
va = "top"
|
va = "top"
|
||||||
else:
|
else:
|
||||||
y = 0 + rel_y_offset
|
y = 0
|
||||||
va = "bottom"
|
va = "bottom"
|
||||||
if position[1] == "l":
|
if position[1] == "l":
|
||||||
x = 0 + rel_x_offset
|
x = 0
|
||||||
ha = "left"
|
ha = "left"
|
||||||
else:
|
else:
|
||||||
x = 1 - rel_x_offset
|
x = 1
|
||||||
|
pts_x = -pts_x
|
||||||
ha = "right"
|
ha = "right"
|
||||||
|
|
||||||
ax.annotate(
|
ax.annotate(
|
||||||
text,
|
text,
|
||||||
(x, y),
|
(x, y),
|
||||||
(x, y),
|
(pts_x, pts_y),
|
||||||
xycoords="axes fraction",
|
xycoords="axes fraction",
|
||||||
textcoords="axes fraction",
|
textcoords="offset points",
|
||||||
verticalalignment=va,
|
verticalalignment=va,
|
||||||
horizontalalignment=ha,
|
horizontalalignment=ha,
|
||||||
**text_kwargs,
|
**text_kwargs,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import warnings
|
|||||||
import zipfile
|
import zipfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterator, Sequence
|
from typing import Any, Callable, Iterator, Sequence
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -16,6 +16,7 @@ from scgenerator.io import TimedMessage
|
|||||||
from scgenerator.logger import get_logger
|
from scgenerator.logger import get_logger
|
||||||
from scgenerator.math import abs2
|
from scgenerator.math import abs2
|
||||||
from scgenerator.operators import SpecOperator, VariableQuantity
|
from scgenerator.operators import SpecOperator, VariableQuantity
|
||||||
|
from scgenerator.physics.pulse import photon_number
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -120,6 +121,130 @@ def pi_step_factor(error: float, last_error: float, order: int, eps: float = 0.8
|
|||||||
return 1 + np.arctan(fac - 1)
|
return 1 + np.arctan(fac - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def solve_rk4ip_cq(
|
||||||
|
spec: np.ndarray,
|
||||||
|
linear: VariableQuantity,
|
||||||
|
nonlinear: SpecOperator,
|
||||||
|
z_max: float,
|
||||||
|
atol: float,
|
||||||
|
rtol: float,
|
||||||
|
safety: float,
|
||||||
|
cq: Callable[[np.ndarray], float],
|
||||||
|
h_const: float | None = None,
|
||||||
|
targets: Sequence[float] | None = None,
|
||||||
|
) -> Iterator[tuple[np.ndarray, dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Solve the GNLSE using a Runge-Kutta of order 4 in the interaction picture. Error estimation is
|
||||||
|
done via photon number estimation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : np.ndarray
|
||||||
|
initial spectrum
|
||||||
|
linear : Operator
|
||||||
|
linear operator
|
||||||
|
nonlinear : Operator
|
||||||
|
nonlinear operator
|
||||||
|
z_max : float
|
||||||
|
stop propagation when z >= z_max (the last step is not guaranteed to be exactly on z_max)
|
||||||
|
atol : float
|
||||||
|
absolute tolerance
|
||||||
|
rtol : float
|
||||||
|
relative tolerance
|
||||||
|
safety : float
|
||||||
|
safety factor when computing new step size
|
||||||
|
h_const : float | None, optional
|
||||||
|
constant step size to use, by default None (automatic step size based on atol and rtol)
|
||||||
|
|
||||||
|
Yields
|
||||||
|
------
|
||||||
|
np.ndarray
|
||||||
|
last computed spectrum
|
||||||
|
dict[str, Any]
|
||||||
|
stats about the last step, including `z`
|
||||||
|
"""
|
||||||
|
if h_const is not None:
|
||||||
|
h = h_const
|
||||||
|
const_step_size = True
|
||||||
|
else:
|
||||||
|
h = 0.000664237859 # from Luna
|
||||||
|
const_step_size = False
|
||||||
|
z = 0
|
||||||
|
stats = {}
|
||||||
|
rejected = []
|
||||||
|
if targets is not None:
|
||||||
|
if len(targets) <= 1:
|
||||||
|
return
|
||||||
|
targets = list(sorted(set(targets)))
|
||||||
|
z = targets[0]
|
||||||
|
if not const_step_size:
|
||||||
|
h = min(h, (targets[1] - targets[0]) / 2)
|
||||||
|
targets.pop(0)
|
||||||
|
|
||||||
|
step_ind = 0
|
||||||
|
last_cq = cq(abs2(spec))
|
||||||
|
error = 0
|
||||||
|
last_error = 0
|
||||||
|
store_next = False
|
||||||
|
|
||||||
|
def stats():
|
||||||
|
return dict(z=z, rejected=rejected.copy(), error=error, h=h)
|
||||||
|
|
||||||
|
yield spec, stats() | dict(h=0)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
expD = np.exp(h * 0.5 * linear(z))
|
||||||
|
|
||||||
|
A_I = expD * spec
|
||||||
|
k1 = expD * (h * nonlinear(spec, z))
|
||||||
|
k2 = h * nonlinear(A_I + 0.5 * k1, z + 0.5 * h)
|
||||||
|
k3 = h * nonlinear(A_I + 0.5 * k2, z + 0.5 * h)
|
||||||
|
k4 = h * nonlinear(expD * (A_I + k3), z + h)
|
||||||
|
|
||||||
|
new_spec = expD * (A_I + k1 / 6 + k2 / 3 + k3 / 3) + k4 / 6
|
||||||
|
|
||||||
|
new_cq = cq(abs2(new_spec))
|
||||||
|
error = abs(last_cq - new_cq) / last_cq
|
||||||
|
print(h, error)
|
||||||
|
|
||||||
|
if error == 0: # solution is exact if no nonlinerity is included
|
||||||
|
next_h_factor = 1.5
|
||||||
|
elif 0 < error <= rtol:
|
||||||
|
next_h_factor = safety * pi_step_factor(error, last_error, 4, 0.8)
|
||||||
|
else:
|
||||||
|
next_h_factor = max(0.1, safety * error ** (-0.25))
|
||||||
|
|
||||||
|
if const_step_size or error <= 1:
|
||||||
|
spec = new_spec
|
||||||
|
last_cq = new_cq
|
||||||
|
z += h
|
||||||
|
|
||||||
|
step_ind += 1
|
||||||
|
last_error = error
|
||||||
|
|
||||||
|
if targets is None or store_next:
|
||||||
|
if targets is not None:
|
||||||
|
targets.pop(0)
|
||||||
|
yield spec, stats()
|
||||||
|
|
||||||
|
rejected.clear()
|
||||||
|
if z >= z_max:
|
||||||
|
return
|
||||||
|
if const_step_size:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
rejected.append((h, error))
|
||||||
|
logger.info(f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}")
|
||||||
|
|
||||||
|
h = h * next_h_factor
|
||||||
|
|
||||||
|
if targets is not None and z + h > targets[0]:
|
||||||
|
h = targets[0] - z
|
||||||
|
store_next = True
|
||||||
|
else:
|
||||||
|
store_next = False
|
||||||
|
|
||||||
|
|
||||||
def solve43(
|
def solve43(
|
||||||
spec: np.ndarray,
|
spec: np.ndarray,
|
||||||
linear: VariableQuantity,
|
linear: VariableQuantity,
|
||||||
@@ -259,6 +384,7 @@ def integrate(
|
|||||||
safety: float = 0.9,
|
safety: float = 0.9,
|
||||||
targets: Sequence[float] | None = None,
|
targets: Sequence[float] | None = None,
|
||||||
) -> SimulationResult:
|
) -> SimulationResult:
|
||||||
|
"""legacy function"""
|
||||||
spec0 = initial_spectrum.copy()
|
spec0 = initial_spectrum.copy()
|
||||||
all_spectra = []
|
all_spectra = []
|
||||||
stats = defaultdict(list)
|
stats = defaultdict(list)
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property, partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload
|
from typing import Callable, Generic, Iterator, Sequence, TypeVar, overload
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.signal as ss
|
import scipy.signal as ss
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from scgenerator import math
|
from scgenerator import math
|
||||||
from scgenerator.io import (
|
from scgenerator.io import (
|
||||||
@@ -22,6 +23,7 @@ from scgenerator.io import (
|
|||||||
from scgenerator.logger import get_logger
|
from scgenerator.logger import get_logger
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.physics import pulse, units
|
from scgenerator.physics import pulse, units
|
||||||
|
from scgenerator.solver import solve43, solve_rk4ip_cq
|
||||||
|
|
||||||
PARAMS_FN = "params.json"
|
PARAMS_FN = "params.json"
|
||||||
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
ParamsOrNone = TypeVar("ParamsOrNone", Parameters, None)
|
||||||
@@ -208,7 +210,7 @@ class Spectrum(np.ndarray):
|
|||||||
returns the coherence of the spectrum, computed by collapsing axis `axis` and aligned
|
returns the coherence of the spectrum, computed by collapsing axis `axis` and aligned
|
||||||
on the wavelength grid
|
on the wavelength grid
|
||||||
"""
|
"""
|
||||||
return pulse.g12(self, axis)[..., self.wl_order]
|
return pulse.g12(self, axis)[..., self.l_order]
|
||||||
|
|
||||||
def spectrogram(
|
def spectrogram(
|
||||||
self,
|
self,
|
||||||
@@ -366,6 +368,39 @@ class Propagation(Generic[ParamsOrNone]):
|
|||||||
def load_all(self) -> Spectrum:
|
def load_all(self) -> Spectrum:
|
||||||
return self._load_slice(slice(None))
|
return self._load_slice(slice(None))
|
||||||
|
|
||||||
|
def simulate(self) -> Propagation:
|
||||||
|
"""
|
||||||
|
Run the simulations as specified by the Parameters
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Propagation
|
||||||
|
returns itself, so that you can use this method as a 'builder' method:
|
||||||
|
`popagation = sc.propagation(parameters).simulate()`
|
||||||
|
"""
|
||||||
|
params = self.parameters
|
||||||
|
with warnings.catch_warnings(), tqdm(total=params.z_num) as pbar:
|
||||||
|
warnings.filterwarnings("error")
|
||||||
|
for spec, _ in solve43(
|
||||||
|
params.spec_0,
|
||||||
|
params.linear_operator,
|
||||||
|
params.nonlinear_operator,
|
||||||
|
params.length,
|
||||||
|
params.tolerated_error,
|
||||||
|
params.tolerated_error,
|
||||||
|
0.95,
|
||||||
|
targets=params.z_targets,
|
||||||
|
# cq=partial(
|
||||||
|
# pulse.photon_number,
|
||||||
|
# w=params.w,
|
||||||
|
# dw=params.w[1] - params.w[0],
|
||||||
|
# gamma=params.gamma,
|
||||||
|
# ),
|
||||||
|
):
|
||||||
|
self.append(spec)
|
||||||
|
pbar.update()
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PropagationCollection:
|
class PropagationCollection:
|
||||||
|
|||||||
Reference in New Issue
Block a user