diff --git a/.gitignore b/.gitignore
index d5b7e20..a6dba3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
pyrightconfig.json
+.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d0714e6
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,16 @@
+dispersionapp: a pogramm to interactively play with fiber parameters
+and see the result on the dispersion
+Copyright (C) 2023 Benoît Sierro
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see .
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..22fb14c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,16 @@
+# Crash course on Python installation, with a few small side tracks
+
+
+
+# Installation
+1. If not already done so, create a virtual environment to contain all the packages required. Python 3.10 is recommended. The command below will create a new environment named 'app-env' with the latest version of Python 3.10.
+
+conda create -y -n app-env python=3.10
+
+2. activate said environment
+
+conda activate app-env
+
+3. The prompt should now read '(app-env)' on the left. The app is not published on Github or anywhere else. The link below points to my own personnal home server (I didn't find any way of getting a direct download link You are now ready to install everything with this command:
+
+pip install https://bao.dedebenui.me/s/f3KgiqMq7giN73i/download
diff --git a/config.toml b/config.toml
deleted file mode 100644
index 9d79bf2..0000000
--- a/config.toml
+++ /dev/null
@@ -1,7 +0,0 @@
-wl_min = 160
-wl_max = 1600
-wl_pump = 800
-
-rep_rate = 8e3
-
-safety_factor = 10
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..7e2f23e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,29 @@
+[project]
+name = "dispersionapp"
+version = "0.1.0"
+description = "Model hollow capillary and revolver fiber interactively"
+authors = [{ name = "Benoît Sierro", email = "benoit.sierro@unibe.ch" }]
+
+dependencies = [
+ "scgenerator @ git+https://github.com/bsierro/scgenerator.git",
+ "click",
+ "pydantic",
+ "tomli",
+ "tomli_w",
+ "PySide6 >= 6.4.0",
+ "pyqtgraph >= 0.13.1",
+]
+
+license = { file = "LICENSE" }
+
+classifiers = [
+ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
+]
+
+[project.scripts]
+
+dispersionapp = "dispersionapp.__main__:main"
+
+[build-system]
+build-backend = "setuptools.build_meta"
+requires = ["setuptools >= 65.0.0"]
diff --git a/src/dispersionapp/__init__.py b/src/dispersionapp/__init__.py
new file mode 100755
index 0000000..66588c4
--- /dev/null
+++ b/src/dispersionapp/__init__.py
@@ -0,0 +1,3 @@
+import importlib
+
+__version__ = importlib.metadata.version("dispersionapp")
diff --git a/src/dispersionapp/__main__.py b/src/dispersionapp/__main__.py
new file mode 100644
index 0000000..92c202a
--- /dev/null
+++ b/src/dispersionapp/__main__.py
@@ -0,0 +1,24 @@
+import os
+import sys
+
+import click
+
+from dispersionapp.core import DEFAULT_CONFIG_FILE
+from dispersionapp.gui import app
+
+
+@click.command()
+@click.option(
+ "-c",
+ "--config",
+ default=DEFAULT_CONFIG_FILE,
+ help="configuration file in TOML format",
+ type=click.Path(file_okay=True, dir_okay=False, resolve_path=True),
+)
+@click.version_option()
+def main(config: os.PathLike):
+ app(config)
+
+
+if __name__ == "__main__" and not hasattr(sys, "ps1"):
+ main()
diff --git a/src/dispersionapp/core.py b/src/dispersionapp/core.py
new file mode 100644
index 0000000..3d81ff9
--- /dev/null
+++ b/src/dispersionapp/core.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import Any, NamedTuple
+
+import numpy as np
+import scgenerator as sc
+import tomli
+import tomli_w
+from pydantic import BaseModel, PrivateAttr, ValidationError, confloat
+
+DEFAULT_CONFIG_FILE = "config.toml"
+
+
+class CurrentState(BaseModel):
+ core_diameter_um: float
+ pressure_mbar: float
+ wall_thickness_um: float
+ n_tubes: int
+ gap_um: float
+ t_fwhm_fs: float
+
+
+class Config(BaseModel):
+ wl_min: confloat(ge=100, le=1000)
+ wl_max: confloat(ge=500, le=6000)
+ wl_pump: confloat(ge=200, le=6000)
+ rep_rate: confloat(gt=0)
+ gas: str
+ safety_factor: float
+ current_state: CurrentState | None = None
+
+ _file_name: Path = PrivateAttr()
+
+ @classmethod
+ def load(cls, config_file: str | None = None) -> Config:
+ config_file = Path(config_file or DEFAULT_CONFIG_FILE)
+ if not config_file.exists():
+ config_file.touch()
+
+ with open(config_file, "rb") as file:
+ d = tomli.load(file)
+ d = cls.default() | d
+ try:
+ out = cls(**d)
+ except ValidationError as e:
+ s = f"invalid input in config file {config_file}:\n{e}"
+ print(s)
+ sys.exit(1)
+ out._file_name = Path(config_file)
+ return out
+
+ @classmethod
+ def default(cls) -> dict[str, Any]:
+ return dict(
+ wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon", safety_factor=10
+ )
+
+ def save(self):
+ tmp = self._file_name.parent / f"{self._file_name.name}.tmp"
+ with open(tmp, "wb") as file:
+ tomli_w.dump(self.dict(), file)
+ tmp.rename(self._file_name)
+
+ def update_current(
+ self, core_diameter_um, pressure_mbar, wall_thickness_um, n_tubes, gap_um, t_fwhm_fs
+ ):
+ self.current_state = CurrentState(
+ core_diameter_um=core_diameter_um,
+ pressure_mbar=pressure_mbar,
+ wall_thickness_um=wall_thickness_um,
+ n_tubes=n_tubes,
+ gap_um=gap_um,
+ t_fwhm_fs=t_fwhm_fs,
+ )
+ self.save()
+
+
+class LimitValues(NamedTuple):
+ wl_zero_disp: float
+ ion_lim: float
+ sf_lim: float
+
+
+def b2(w, n_eff):
+ dw = w[1] - w[0]
+ beta = sc.fiber.beta(w, n_eff)
+ return sc.math.differentiate_arr(beta, 2, 4, dw)
+
+
+def N_sf_max(
+ wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
+) -> np.ndarray:
+ """
+ maximum soliton number according to self focusing
+
+ eq. S15 in Travers2019
+ """
+ delta_gas = gas.sellmeier.delta(wl, wl_zero_disp)
+ return t0 * np.sqrt(wl / (safety * np.abs(delta_gas)))
+
+
+def N_ion_max(
+ wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
+) -> np.ndarray:
+ """
+ eq. S16 in Travers2019
+ """
+ ind = sc.math.argclosest(wl, wl_zero_disp)
+ f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl)
+ factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3
+ delta = factor * (f / f[ind] - 1)
+ denom = safety * np.pi * wl * np.abs(delta) * f[ind]
+ return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom)
+
+
+def solition_num(
+ t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float
+) -> float:
+ gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
+ ld = sc.pulse.L_D(t0, beta2)
+ return np.sqrt(gamma * ld * peak_power)
+
+
+def energy(
+ t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float
+) -> float:
+ gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
+ peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma)
+ return sc.pulse.P0_to_E0(peak_power, t0, "sech")
diff --git a/dispersion_app.py b/src/dispersionapp/gui.py
old mode 100755
new mode 100644
similarity index 57%
rename from dispersion_app.py
rename to src/dispersionapp/gui.py
index 0b93c30..a097110
--- a/dispersion_app.py
+++ b/src/dispersionapp/gui.py
@@ -1,98 +1,14 @@
from __future__ import annotations
import os
-import sys
-import warnings
-from functools import cache
-from typing import Any, NamedTuple
-import click
import numpy as np
import scgenerator as sc
-import tomli
-from customfunc.app import PlotApp
-from pydantic import BaseModel, ValidationError, confloat
+from functools import cache
+import warnings
-DEFAULT_CONFIG_FILE = "config.toml"
-
-
-class Config(BaseModel):
- wl_min: confloat(ge=100, le=1000)
- wl_max: confloat(ge=500, le=6000)
- wl_pump: confloat(ge=200, le=6000)
- rep_rate: confloat(gt=0)
- gas: str
-
- @classmethod
- def load(cls, config_file: str | None = None) -> Config:
- config_file = config_file or DEFAULT_CONFIG_FILE
- with open(config_file, "rb") as file:
- d = tomli.load(file)
- d = cls.default() | d
- try:
- return cls(**d)
- except ValidationError as e:
- s = f"invalid input in config file {config_file}:\n{e}"
- print(s)
- sys.exit(1)
-
- @classmethod
- def default(cls) -> dict[str, Any]:
- return dict(wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon")
-
-
-class LimitValues(NamedTuple):
- wl_zero_disp: float
- ion_lim: float
- sf_lim: float
-
-
-def b2(w, n_eff):
- dw = w[1] - w[0]
- beta = sc.fiber.beta(w, n_eff)
- return sc.math.differentiate_arr(beta, 2, 4, dw)
-
-
-def N_sf_max(
- wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
-) -> np.ndarray:
- """
- maximum soliton number according to self focusing
-
- eq. S15 in Travers2019
- """
- delta_gas = gas.sellmeier.delta(wl, wl_zero_disp)
- return t0 * np.sqrt(wl / (safety * np.abs(delta_gas)))
-
-
-def N_ion_max(
- wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
-) -> np.ndarray:
- """
- eq. S16 in Travers2019
- """
- ind = sc.math.argclosest(wl, wl_zero_disp)
- f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl)
- factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3
- delta = factor * (f / f[ind] - 1)
- denom = safety * np.pi * wl * np.abs(delta) * f[ind]
- return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom)
-
-
-def solition_num(
- t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float
-) -> float:
- gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
- ld = sc.pulse.L_D(t0, beta2)
- return np.sqrt(gamma * ld * peak_power)
-
-
-def energy(
- t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float
-) -> float:
- gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
- peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma)
- return sc.pulse.P0_to_E0(peak_power, t0, "sech")
+from dispersionapp.core import Config, LimitValues, N_ion_max, N_sf_max, energy, b2
+from dispersionapp.plotapp import PlotApp
def app(config_file: os.PathLike | None = None):
@@ -118,14 +34,24 @@ def app(config_file: os.PathLike | None = None):
app[0].horizontal_line("reference", 0, color="gray")
app[0].set_xlabel("wavelength (nm)")
app[0].set_ylabel("beta2 (fs^2/cm)")
- app.params["wall_thickness_um"].value = 1
- app.params["core_diameter_um"].value = 100
- app.params["pressure_mbar"].value = 500
- app.params["n_tubes"].value = 7
- app.params["gap_um"].value = 5
- app.params["t_fwhm_fs"].value = 100
+ if config.current_state is not None:
+ app.params["core_diameter_um"].value = config.current_state.core_diameter_um
+ app.params["pressure_mbar"].value = config.current_state.pressure_mbar
+ app.params["wall_thickness_um"].value = config.current_state.wall_thickness_um
+ app.params["n_tubes"].value = config.current_state.n_tubes
+ app.params["gap_um"].value = config.current_state.gap_um
+ app.params["t_fwhm_fs"].value = config.current_state.t_fwhm_fs
+ else:
+ app.params["core_diameter_um"].value = 100
+ app.params["pressure_mbar"].value = 500
+ app.params["wall_thickness_um"].value = 1
+ app.params["n_tubes"].value = 7
+ app.params["gap_um"].value = 5
+ app.params["t_fwhm_fs"].value = 100
app[0].set_lim(ylim=(-4, 2))
+ app.update(config.update_current)
+
@cache
def compute_max_energy(
core_diameter_um: float, pressure_mbar: float, t_fwhm_fs: float
@@ -143,8 +69,8 @@ def app(config_file: os.PathLike | None = None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
- ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas)[wl_ind]
- sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas)[wl_ind]
+ ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind]
+ sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind]
beta2 = disp[wl_ind]
n2 = gas.n2(pressure=pressure)
@@ -218,19 +144,3 @@ def app(config_file: os.PathLike | None = None):
zdw = lim.wl_zero_disp * 1e9
app[0].set_line_data("zdw", [zdw, zdw], [-3, 3])
app[0].set_line_name("zdw", f"ZDW = {zdw:.0f}nm")
-
-
-@click.command()
-@click.option(
- "-c",
- "--config",
- default=DEFAULT_CONFIG_FILE,
- help="configuration file in TOML format",
- type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
-)
-def main(config: os.PathLike):
- app(config)
-
-
-if __name__ == "__main__" and not hasattr(sys, "ps1"):
- main()
diff --git a/src/dispersionapp/plotapp.py b/src/dispersionapp/plotapp.py
new file mode 100644
index 0000000..a464e7b
--- /dev/null
+++ b/src/dispersionapp/plotapp.py
@@ -0,0 +1,722 @@
+from __future__ import annotations
+
+import inspect
+import itertools
+from collections.abc import MutableMapping, Sequence
+from functools import cache
+from types import MethodType
+from typing import Any, Callable, Iterable, Iterator, Optional, Type, Union, overload
+
+from PySide6 import QtCore, QtWidgets, QtGui
+import numpy as np
+import pyqtgraph as pg
+from pyqtgraph.dockarea import Dock, DockArea
+from pyqtgraph.graphicsItems.PlotDataItem import PlotDataItem
+
+MPL_COLORS = [
+ "#1f77b4",
+ "#ff7f0e",
+ "#2ca02c",
+ "#d62728",
+ "#9467bd",
+ "#8c564b",
+ "#e377c2",
+ "#7f7f7f",
+ "#bcbd22",
+ "#17becf",
+]
+
+key_type = Union[str, int]
+
+
+class Field(QtWidgets.QWidget):
+ dtype: Type
+ value: Any
+ value_changed: QtCore.Signal
+ timer: QtCore.QTimer
+
+ @property
+ def value(self) -> Any:
+ raise NotImplementedError()
+
+ def values(self) -> list[Any]:
+ raise NotImplementedError()
+
+
+class SliderField(Field):
+ dtype: Type
+ possible_values: np.ndarray
+ _slider_max = 100
+ _tuple_signal = QtCore.Signal(tuple)
+ _int_signal = QtCore.Signal(int)
+ _float_signal = QtCore.Signal(float)
+ _str_signal = QtCore.Signal(str)
+
+ def __init__(self, name: str, values: Iterable) -> None:
+ super().__init__()
+ self.__value = None
+ self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal)
+ self.slider.setSingleStep(1)
+ self.slider.valueChanged.connect(self.slider_changed)
+
+ self.field = QtWidgets.QLineEdit()
+ self.field.setMaximumWidth(150)
+ self.field.editingFinished.connect(self.field_changed)
+
+ self.step_backward_button = QtWidgets.QPushButton("<")
+ self.step_backward_button.clicked.connect(self.step_backward)
+
+ self.step_forward_button = QtWidgets.QPushButton(">")
+ self.step_forward_button.clicked.connect(self.step_forward)
+
+ self._layout = QtWidgets.QHBoxLayout()
+ self.setLayout(self._layout)
+ self._layout.setContentsMargins(10, 0, 10, 0)
+
+ pretty_name = " ".join(s.title() for s in name.split("_"))
+ self.name_label = QtWidgets.QLabel(pretty_name + " :")
+
+ self._layout.addWidget(self.name_label)
+ self._layout.addWidget(self.slider)
+ self._layout.addWidget(self.step_backward_button)
+ self._layout.addWidget(self.step_forward_button)
+ self._layout.addWidget(self.field)
+
+ self.set_values(values)
+ self.value_changed = {
+ int: self._int_signal,
+ float: self._float_signal,
+ str: self._str_signal,
+ tuple: self._tuple_signal,
+ }[self.dtype]
+ self.value_changed.emit(self.__value)
+
+ def set_values(self, values: Iterable):
+ self.possible_values = np.array(values)
+ self.value_to_slider_map = {v: i for i, v in enumerate(self.possible_values)}
+ self._slider_max = len(self.possible_values) - 1
+ self.slider.setRange(0, self._slider_max)
+ if self.__value not in self.value_to_slider_map:
+ self.__value = self.possible_values[0]
+ if isinstance(self.__value, (np.integer, int)):
+ new_dtype = int
+ elif isinstance(self.__value, (np.floating, float)):
+ new_dtype = float
+ elif isinstance(self.__value, tuple):
+ raise NotImplementedError("Tuples currently not supported")
+ elif isinstance(self.__value, str):
+ new_dtype = str
+ else:
+ raise TypeError(f"parameter of type {type(self.__value)!r} not possible")
+ if hasattr(self, "dtype") and self.dtype != new_dtype:
+ raise RuntimeError(f"Field {self.name} cannot change dtype after creation")
+ self.dtype = new_dtype
+ self.update_label()
+ self.update_slider()
+
+ @property
+ def value(self) -> float:
+ return self.__value
+
+ @value.setter
+ def value(self, new_value):
+ if new_value not in self.possible_values:
+ new_value = self.possible_values[np.argmin(np.abs(self.possible_values - new_value))]
+ self.__value = new_value
+ self.update_label()
+ self.slider.blockSignals(True)
+ self.slider.setValue(self.value_to_slider())
+ self.slider.blockSignals(False)
+ self.value_changed.emit(new_value)
+
+ def field_changed(self):
+ new_val = self.dtype(self.field.text())
+ if new_val not in self.value_to_slider_map:
+ self.update_label()
+ return
+ self.value = new_val
+
+ def slider_changed(self):
+ self.value = self.slider_to_value()
+
+ def update_label(self):
+ self.field.setText(str(self))
+
+ def update_slider(self):
+ self.slider.setValue(self.value_to_slider())
+
+ def step_forward(self):
+ next_value = self.slider.value() + 1
+ if next_value <= self._slider_max:
+ self.slider.setValue(next_value)
+
+ def step_backward(self):
+ next_value = self.slider.value() - 1
+ if next_value >= 0:
+ self.slider.setValue(next_value)
+
+ def slider_to_value(self, i=None) -> float:
+ if i is None:
+ i = self.slider.value()
+ return self.possible_values[i]
+
+ def value_to_slider(self) -> int:
+ return self.value_to_slider_map[self.__value]
+
+ def __str__(self) -> str:
+ if self.dtype is int:
+ return format(self.value)
+ elif self.dtype is float:
+ return format(self.value, ".3g")
+ else:
+ return format(self.value)
+
+ def values(self) -> list[float]:
+ return list(self.possible_values)
+
+
+class AnimatedSliderField(SliderField):
+ def __init__(self, name: str, values: Iterable) -> None:
+ super().__init__(name, values)
+ self.timer = QtCore.QTimer()
+ self.timer.timeout.connect(self.step_forward)
+
+ self.play_button = QtWidgets.QPushButton("⏯")
+ self.play_button.clicked.connect(self.toggle)
+ self.play_button.setMaximumWidth(30)
+ self.playing = False
+
+ self.interval_field = QtWidgets.QLineEdit()
+ self.interval_field.setMaximumWidth(60)
+ # self.interval_field.setValidator(QtGui.QIntValidator(1, 5000))
+ self.interval_field.editingFinished.connect(self.set_interval)
+ self.interval_field.inputRejected.connect(self.set_interval)
+ self.interval = 16
+ self.set_interval()
+
+ self._layout.addWidget(self.play_button)
+ self._layout.addWidget(self.interval_field)
+
+ def toggle(self):
+ if self.playing:
+ self.stop()
+ self.playing = False
+ else:
+ self.play()
+ self.playing = True
+
+ def play(self):
+ if self.slider.value() == self._slider_max:
+ self.slider.setValue(0)
+ self.timer.start(self.interval)
+ self.play_button.setStyleSheet("QPushButton {background-color: #DEF2DD;}")
+
+ def stop(self):
+ self.timer.stop()
+ self.play_button.setStyleSheet("QPushButton {background-color: none;}")
+
+ def set_interval(self):
+ try:
+ self.interval = max(1, int(self.interval_field.text()))
+ except ValueError:
+ self.interval_field.setText(str(self.interval))
+ if self.interval < 16:
+ self.increment = int(np.ceil(16 / self.interval))
+ self.interval *= self.increment
+ else:
+ self.increment = 1
+
+ def step_forward(self):
+ current = self.slider.value()
+ if current + self.increment <= self._slider_max:
+ self.slider.setValue(current + self.increment)
+ else:
+ self.slider.setValue(self._slider_max)
+ self.stop()
+
+
+class Plot:
+ name: str
+ dock: Dock
+ plot_widget: pg.PlotWidget
+ legend: pg.LegendItem
+ color_cycle: Iterable[str]
+ lines: dict[str, PlotDataItem]
+ pens: dict[str, QtGui.QPen]
+
+ def __init__(self, name: str, area: DockArea):
+ self.name = name
+ self.plot_widget = pg.PlotWidget()
+ self.plot_widget.setBackground("w")
+ self.dock = Dock(name)
+ self.legend = pg.LegendItem((80, 60), offset=(70, 20))
+ self.legend.setParentItem(self.plot_widget.plotItem)
+ self.legend.setLabelTextSize("12pt")
+ self.color_cycle = itertools.cycle(MPL_COLORS)
+ self.lines = {}
+ self.pens = {}
+ self.dock.addWidget(self.plot_widget)
+ area.addDock(self.dock)
+
+ def __getitem__(self, key: str) -> PlotDataItem:
+ if key not in self.lines:
+ plot: PlotDataItem = self.plot_widget.plot()
+ self.legend.addItem(plot, key)
+ self.lines[key] = plot
+ return self.lines[key]
+
+ def __setitem__(self, key: str, item: PlotDataItem):
+ self.add_item(key, item, item)
+
+ def add_item(self, key: str, curve: PlotDataItem, obj: pg.GraphicsObject):
+ self.legend.addItem(curve, key)
+ self.plot_widget.addItem(obj)
+ self.lines[key] = curve
+
+ def set_line_name(self, line_name: str, new_name: str):
+ """updates the displayed name of a line. Does not change the internal
+ name of the plot data though.
+
+ Parameters
+ ----------
+ name : str
+ original name of the line
+ new_name : str
+ new name
+ """
+ line = self[line_name]
+ self.legend.getLabel(line).setText(new_name)
+
+ def set_line_data(
+ self,
+ key: str,
+ x: Iterable[float],
+ y: Iterable[float] = None,
+ label: str = None,
+ **pen_kwargs,
+ ):
+ """plots the given data
+
+ Parameters
+ ----------
+ key : str
+ internal name of the line. If it already exists, the line is updated, otherwise
+ it is created
+ x : Iterable[float]
+ x data
+ y : Iterable[float], optional
+ y data. If not given, takes x as y data, by default None
+ label : str, optional
+ if given, updates the legend label, by default None
+ pen_kwargs : Any, optional
+ given to the pg.mkPen constructor
+ """
+ if y is None:
+ x, y = np.arange(len(x)), x
+ self[key].setData(x, y)
+ self[key].setPen(self.get_pen(key, **pen_kwargs))
+ if label is not None:
+ self.set_line_name(key, label)
+
+ def vertical_line(self, key: str, x: float, **pen_kwargs):
+ """plots a vertical, infinite line
+
+ Parameters
+ ----------
+ key : str
+ name of the line
+ x : float
+ position
+ """
+ pen = self.get_pen(key, **pen_kwargs)
+ line = pg.InfiniteLine(x, 90, pen)
+ legend_item = PlotDataItem((), pen=pen)
+ self.add_item(key, legend_item, line)
+
+ def horizontal_line(self, key: str, y: float, **pen_kwargs):
+ """plots a horizontal, infinite line
+
+ Parameters
+ ----------
+ key : str
+ name of the line
+ y : float
+ position
+ """
+ pen = self.get_pen(key, **pen_kwargs)
+ line = pg.InfiniteLine(y, 0, pen)
+ legend_item = PlotDataItem((), pen=pen)
+ self.add_item(key, legend_item, line)
+
+ def set_lim(self, *, xlim=None, ylim=None):
+
+ x_auto = xlim is None
+ y_auto = ylim is None
+ if not x_auto:
+ self.plot_widget.setXRange(*xlim)
+ if not y_auto:
+ self.plot_widget.setYRange(*ylim)
+ self.plot_widget.plotItem.enableAutoRange(x=x_auto, y=y_auto)
+
+ def set_xlabel(self, label: str, ignore_math=True):
+ if ignore_math:
+ label = label.replace("$", "")
+ self.plot_widget.plotItem.setLabel("bottom", label)
+
+ def set_ylabel(self, label: str, ignore_math=True):
+ if ignore_math:
+ label = label.replace("$", "")
+ self.plot_widget.plotItem.setLabel("left", label)
+
+ def link_x(self, other: Plot):
+ self.plot_widget.plotItem.setXLink(other.plot_widget.plotItem)
+
+ def link_y(self, other: Plot):
+ self.plot_widget.plotItem.setYLink(other.plot_widget.plotItem)
+
+ def get_pen(self, key: str, **pen_kwargs) -> QtGui.QPen:
+ if key not in self.pens:
+ pen_kwargs = dict(width=3) | pen_kwargs
+ if "color" not in pen_kwargs:
+ pen_kwargs["color"] = next(self.color_cycle)
+
+ self.pens[key] = pg.mkPen(**pen_kwargs)
+ return self.pens[key]
+
+
+class CacheThread(QtCore.QThread):
+ sig_progress = QtCore.Signal(int)
+ tasks: list[tuple[Callable, list]]
+ size: int
+
+ def __init__(self):
+ super().__init__()
+ self.tasks = []
+ self.size = 0
+ self.__stop = False
+
+ def __del__(self):
+ self.__stop = True
+ self.wait()
+
+ def add_task(self, func: Callable, args: Iterable[list]):
+ try:
+ func.cache_info()
+ except AttributeError:
+ raise ValueError(f"func {func.__name__} is not cached")
+ all_args = list(itertools.product(*args))
+ self.tasks.append((func, all_args))
+ self.size += len(all_args)
+
+ def run(self):
+ i = 0
+ for func, all_args in self.tasks:
+ for args in all_args:
+ if self.__stop:
+ return
+ func(*args)
+ i += 1
+ self.sig_progress.emit(i)
+
+
+class CacheWidget(QtWidgets.QWidget):
+ cache_thread: CacheThread
+ sig_finished = QtCore.Signal()
+
+ def __init__(self):
+ super().__init__()
+ self.cache_thread = CacheThread()
+ self.cache_thread.sig_progress.connect(self.update_pbar)
+
+ layout = QtWidgets.QHBoxLayout()
+ self.label = QtWidgets.QLabel("caching")
+ self.pbar = QtWidgets.QProgressBar()
+ self.pbar.setValue(0)
+ self.pbar.setMaximum(self.cache_thread.size)
+ layout.addWidget(self.label)
+ layout.addWidget(self.pbar)
+ self.setLayout(layout)
+
+ def add_task(self, func: Callable, args: Iterable[list]):
+ if self.cache_thread.isRunning():
+ raise RuntimeError("cannot add tasks after caching has begun.")
+ self.cache_thread.add_task(func, args)
+ self.pbar.setMaximum(self.cache_thread.size)
+
+ def start(self):
+ self.cache_thread.start()
+
+ def update_pbar(self, val: int):
+ self.pbar.setValue(val)
+ if val == self.cache_thread.size:
+ self.sig_finished.emit()
+ self.setParent(None)
+
+
+class PlotApp:
+ """
+ Easy interactive plotting
+
+ Example
+ -------
+ ```
+ import numpy as np
+ from customfunc.app import PlotApp
+
+
+ def main():
+ t, dt = np.linspace(-10, 10, 4096, retstep=True)
+ f = np.fft.rfftfreq(len(t), dt)
+ with PlotApp(freq=(2, 50, 30), offset=(0, 2 * np.pi, 10), spacing=(0, 2, 50)) as app:
+
+ @app.cache_and_update(app["field"])
+ def signal(freq, offset, spacing):
+ out = np.zeros(len(t))
+ for i in range(-5, 5 + 1):
+ out += np.sin(t * 2 * np.pi * freq + offset * i) * np.exp(
+ -(((t + spacing * i) / 0.3) ** 2)
+ )
+ return t, out
+
+ @app.update
+ def draw(freq, offset, spacing):
+ _, y = signal(freq, offset, spacing)
+ spec = np.fft.rfft(y)
+ app["spec"].set_line_data("real", f, spec.real)
+ app["spec"].set_line_data("imag", f, spec.imag)
+
+
+ if __name__ == "__main__":
+ main()
+ ```
+ """
+
+ app: QtWidgets.QApplication
+ window: QtWidgets.QMainWindow
+ central_layout: QtWidgets.QVBoxLayout
+ params_layout: QtWidgets.QVBoxLayout
+ central_widget: QtWidgets.QWidget
+ params_widget: QtWidgets.QWidget
+ plots: dict[str, Plot]
+ params: dict[str, Field]
+ __cache_widget: Optional[CacheWidget] = None
+
+ def __init__(self, name: str = "Plot App", **params: dict[str, Any]):
+ self.app = pg.mkQApp()
+ self.window = QtWidgets.QMainWindow()
+ self.window.setWindowTitle(name)
+ self.window.resize(1200, 800)
+ self.central_widget = QtWidgets.QWidget()
+ self.params_widget = QtWidgets.QWidget()
+ self.window.setCentralWidget(self.central_widget)
+ self.central_layout = QtWidgets.QVBoxLayout()
+ self.params_layout = QtWidgets.QVBoxLayout()
+ self.central_layout.setContentsMargins(0, 0, 0, 0)
+ self.central_widget.setLayout(self.central_layout)
+ self.params_widget.setLayout(self.params_layout)
+ self.central_layout.addWidget(self.params_widget, stretch=0)
+ self.dock_area = DockArea()
+ self.plots = {}
+
+ self.__ran = False
+ self.params = {}
+ for p_name, values in params.items():
+ field = AnimatedSliderField(p_name, values)
+ self.params[p_name] = field
+ self.params_layout.addWidget(field)
+
+ def set_antialiasing(self, val: bool):
+ pg.setConfigOptions(antialias=val)
+
+ def _parse_params(
+ self, params: dict[str, Iterable]
+ ) -> Iterator[tuple[str, Type, dict[str, Any]]]:
+ for p_name, opts in params.items():
+ if isinstance(opts, MutableMapping):
+ dtype = opts.pop("dtype", float)
+ elif isinstance(opts, Sequence):
+ if isinstance(opts[0], type):
+ dtype, *opts = opts
+ else:
+ dtype = float
+ opts = dict(zip(["v_min", "v_max", "v_num", "v_init"], opts))
+ yield p_name, dtype, opts
+
+ def update(self, *p_names: str):
+ """
+ use this as a decorator to connect the decorated function to
+ the value_changed signal of some parameters
+ The decorated function will be called with the specified parameters
+ when any of those changes
+ if decorating without parameter nor (), PlotApp will read the argument names
+ in the function signature (those MUST match the parameter names given in __init__)
+
+ Parameters
+ ----------
+ p_names : str, optional
+ name of the parameter as given in the `PlotApp.__init__`
+
+ Example
+ -------
+ ```
+ with PlotApp(speed=(float, 0, 5), num_cars=(int, 1, 5)) as plot_app
+ @plot_app.update("speed")
+ def updating_plot(speed:float):
+ x, y = some_function(speed)
+ plot_app["speed_plot"].set_line_data("my_car", x, y)
+
+ @plot_app.update
+ def call_everytime():
+ # parameters are also directly accessible anytime
+ print(plot_app.params["num_car"].value)
+ ```
+ In this example, `call_everytime` is called without argument everytime any of the
+ parameters changes, whereas `updating_plot` is called only when the `speed` parameter
+ changes and is given its new value as argument
+ """
+
+ if len(p_names) == 1 and callable(p_names[0]):
+ return self.update(*self._get_func_args(p_names[0]))(p_names[0])
+
+ def wrapper(func):
+ def to_call(v=None):
+ func(**{p_name: self.params[p_name].value for p_name in p_names})
+
+ for p_name in p_names:
+ self.params[p_name].value_changed.connect(to_call)
+ return func
+
+ return wrapper
+
+ def cache(self, *parameters: Union[str, Iterable]):
+ """
+ use this as a decorator to cache a function that outputs data
+ if decorating without parameter nor (), PlotApp will read the argument names
+ in the function signature (those MUST match the parameter names given in __init__)
+
+ Paramters
+ ---------
+ parameters : str | Iterable
+ if str : name of the parameter as given in the `PlotApp.__init__`
+ if Iterable : all possible value of a particular argument of the
+ decorated function
+
+ Example
+ -------
+ ```
+ with PlotApp(speed=(float, 0, 5)) as plot_app
+ @plot_app.cache(range(20), "speed")
+ def heavy_computation(n:int, speed:float):
+ ...
+ ```
+ As soon as the app is launched, caching of the output of `heavy_computation` for
+ every possible compination of `[0, 1, ..., 19]` and values of the `'speed'` parameter
+ will begin, with a progress bar showing how it's going.
+ """
+ if len(parameters) == 1 and callable(parameters[0]):
+ return self.cache(*self._get_func_args(parameters[0]))(parameters[0])
+
+ all_params: list[list] = []
+ for param in parameters:
+ if isinstance(param, str):
+ all_params.append(self.params[param].values())
+ else:
+ all_params.append(list(param))
+
+ def wrapper(func):
+ cached_func = cache(func)
+ self.cache_widget.add_task(cached_func, all_params)
+ return cached_func
+
+ return wrapper
+
+ def cache_and_update(self, plot: Plot, *line_names: str):
+ """combination of update and cache. the decorated function should return alternating
+ x, y variable (for example, a function to plot 2 lines would end in `return x1, y1, x2, y2`)
+ the decorated function should only computed data, not plot anything.
+
+ Parameters
+ ----------
+ plot : Plot
+ plot on which the lines are to be plotted
+ line_names : str, ..., optional
+ name of the lines, by default line
+ """
+
+ def wrapper(func):
+ new_func = self.cache(func)
+ arg_names = self._get_func_args(func)
+
+ init_vals = new_func(**{a: self.params[a].value for a in arg_names})
+ labels = [*line_names, *[f"line {i}" for i in range(len(init_vals) // 2)]]
+
+ def plotter_func(*args):
+ vals = new_func(*args)
+ for label, x, y in zip(labels, vals[::2], vals[1::2]):
+ plot.set_line_data(label, x, y)
+
+ self.update(*arg_names)(plotter_func)
+ return new_func
+
+ return wrapper
+
+ @property
+ def cache_widget(self) -> CacheWidget:
+ if self.__cache_widget is None:
+ self.__cache_widget = CacheWidget()
+ self.central_layout.addWidget(self.__cache_widget)
+
+ def delete():
+ del self.cache_widget
+
+ self.__cache_widget.sig_finished.connect(delete)
+
+ return self.__cache_widget
+
+ @cache_widget.deleter
+ def cache_widget(self):
+ self.params_widget.setMaximumHeight(self.params_widget.geometry().height())
+ self.central_layout.removeWidget(self.__cache_widget)
+ self.__cache_widget = None
+
+ def _get_func_args(self, func) -> list[str]:
+ arg_names = inspect.getfullargspec(func).args
+ if isinstance(func, MethodType):
+ arg_names = arg_names[1:]
+ for arg in arg_names:
+ if arg not in self.params:
+ raise ValueError(f"{arg} not in app parameters")
+ return arg_names
+
+ def __enter__(self) -> PlotApp:
+ if self.__ran:
+ raise RuntimeError("App already ran")
+ self.__ran = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ self.central_layout.addWidget(self.dock_area, stretch=1000)
+ if self.__cache_widget is not None:
+ self.__cache_widget.start()
+ for datafield in self.params.values():
+ datafield.value_changed.emit(datafield.value)
+ self.window.show()
+ self.app.exec()
+
+ @overload
+ def __getitem__(self, key: tuple[key_type, key_type]) -> PlotDataItem:
+ ...
+
+ @overload
+ def __getitem__(self, key: key_type) -> Plot:
+ ...
+
+ def __getitem__(self, key: key_type) -> Plot:
+ if isinstance(key, tuple):
+ return self[key[0]][key[1]]
+ key = str(key)
+ if key not in self.plots:
+ self.plots[key] = Plot(key, self.dock_area)
+ return self.plots[key]
+
+