Split into files. No more plotapp dependency

This commit is contained in:
Benoît Sierro
2023-03-21 09:19:40 +01:00
parent 81d7dceb9c
commit 6bc3e9510b
10 changed files with 964 additions and 119 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
pyrightconfig.json
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/

16
LICENSE Normal file
View File

@@ -0,0 +1,16 @@
dispersionapp: a pogramm to interactively play with fiber parameters
and see the result on the dispersion
Copyright (C) 2023 Benoît Sierro
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.

16
README.md Normal file
View File

@@ -0,0 +1,16 @@
# Crash course on Python installation, with a few small side tracks
# Installation
1. If not already done so, create a virtual environment to contain all the packages required. Python 3.10 is recommended. The command below will create a new environment named 'app-env' with the latest version of Python 3.10.
conda create -y -n app-env python=3.10
2. activate said environment
conda activate app-env
3. The prompt should now read '(app-env)' on the left. The app is not published on Github or anywhere else. The link below points to my own personnal home server (I didn't find any way of getting a direct download link You are now ready to install everything with this command:
pip install https://bao.dedebenui.me/s/f3KgiqMq7giN73i/download

View File

@@ -1,7 +0,0 @@
wl_min = 160
wl_max = 1600
wl_pump = 800
rep_rate = 8e3
safety_factor = 10

29
pyproject.toml Normal file
View File

@@ -0,0 +1,29 @@
[project]
name = "dispersionapp"
version = "0.1.0"
description = "Model hollow capillary and revolver fiber interactively"
authors = [{ name = "Benoît Sierro", email = "benoit.sierro@unibe.ch" }]
dependencies = [
"scgenerator @ git+https://github.com/bsierro/scgenerator.git",
"click",
"pydantic",
"tomli",
"tomli_w",
"PySide6 >= 6.4.0",
"pyqtgraph >= 0.13.1",
]
license = { file = "LICENSE" }
classifiers = [
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
]
[project.scripts]
dispersionapp = "dispersionapp.__main__:main"
[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools >= 65.0.0"]

3
src/dispersionapp/__init__.py Executable file
View File

@@ -0,0 +1,3 @@
import importlib
__version__ = importlib.metadata.version("dispersionapp")

View File

@@ -0,0 +1,24 @@
import os
import sys
import click
from dispersionapp.core import DEFAULT_CONFIG_FILE
from dispersionapp.gui import app
@click.command()
@click.option(
"-c",
"--config",
default=DEFAULT_CONFIG_FILE,
help="configuration file in TOML format",
type=click.Path(file_okay=True, dir_okay=False, resolve_path=True),
)
@click.version_option()
def main(config: os.PathLike):
app(config)
if __name__ == "__main__" and not hasattr(sys, "ps1"):
main()

131
src/dispersionapp/core.py Normal file
View File

@@ -0,0 +1,131 @@
from __future__ import annotations
import sys
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
import scgenerator as sc
import tomli
import tomli_w
from pydantic import BaseModel, PrivateAttr, ValidationError, confloat
DEFAULT_CONFIG_FILE = "config.toml"
class CurrentState(BaseModel):
core_diameter_um: float
pressure_mbar: float
wall_thickness_um: float
n_tubes: int
gap_um: float
t_fwhm_fs: float
class Config(BaseModel):
wl_min: confloat(ge=100, le=1000)
wl_max: confloat(ge=500, le=6000)
wl_pump: confloat(ge=200, le=6000)
rep_rate: confloat(gt=0)
gas: str
safety_factor: float
current_state: CurrentState | None = None
_file_name: Path = PrivateAttr()
@classmethod
def load(cls, config_file: str | None = None) -> Config:
config_file = Path(config_file or DEFAULT_CONFIG_FILE)
if not config_file.exists():
config_file.touch()
with open(config_file, "rb") as file:
d = tomli.load(file)
d = cls.default() | d
try:
out = cls(**d)
except ValidationError as e:
s = f"invalid input in config file {config_file}:\n{e}"
print(s)
sys.exit(1)
out._file_name = Path(config_file)
return out
@classmethod
def default(cls) -> dict[str, Any]:
return dict(
wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon", safety_factor=10
)
def save(self):
tmp = self._file_name.parent / f"{self._file_name.name}.tmp"
with open(tmp, "wb") as file:
tomli_w.dump(self.dict(), file)
tmp.rename(self._file_name)
def update_current(
self, core_diameter_um, pressure_mbar, wall_thickness_um, n_tubes, gap_um, t_fwhm_fs
):
self.current_state = CurrentState(
core_diameter_um=core_diameter_um,
pressure_mbar=pressure_mbar,
wall_thickness_um=wall_thickness_um,
n_tubes=n_tubes,
gap_um=gap_um,
t_fwhm_fs=t_fwhm_fs,
)
self.save()
class LimitValues(NamedTuple):
wl_zero_disp: float
ion_lim: float
sf_lim: float
def b2(w, n_eff):
dw = w[1] - w[0]
beta = sc.fiber.beta(w, n_eff)
return sc.math.differentiate_arr(beta, 2, 4, dw)
def N_sf_max(
wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
) -> np.ndarray:
"""
maximum soliton number according to self focusing
eq. S15 in Travers2019
"""
delta_gas = gas.sellmeier.delta(wl, wl_zero_disp)
return t0 * np.sqrt(wl / (safety * np.abs(delta_gas)))
def N_ion_max(
wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
) -> np.ndarray:
"""
eq. S16 in Travers2019
"""
ind = sc.math.argclosest(wl, wl_zero_disp)
f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl)
factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3
delta = factor * (f / f[ind] - 1)
denom = safety * np.pi * wl * np.abs(delta) * f[ind]
return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom)
def solition_num(
t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float
) -> float:
gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
ld = sc.pulse.L_D(t0, beta2)
return np.sqrt(gamma * ld * peak_power)
def energy(
t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float
) -> float:
gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma)
return sc.pulse.P0_to_E0(peak_power, t0, "sech")

134
dispersion_app.py → src/dispersionapp/gui.py Executable file → Normal file
View File

@@ -1,98 +1,14 @@
from __future__ import annotations
import os
import sys
import warnings
from functools import cache
from typing import Any, NamedTuple
import click
import numpy as np
import scgenerator as sc
import tomli
from customfunc.app import PlotApp
from pydantic import BaseModel, ValidationError, confloat
from functools import cache
import warnings
DEFAULT_CONFIG_FILE = "config.toml"
class Config(BaseModel):
wl_min: confloat(ge=100, le=1000)
wl_max: confloat(ge=500, le=6000)
wl_pump: confloat(ge=200, le=6000)
rep_rate: confloat(gt=0)
gas: str
@classmethod
def load(cls, config_file: str | None = None) -> Config:
config_file = config_file or DEFAULT_CONFIG_FILE
with open(config_file, "rb") as file:
d = tomli.load(file)
d = cls.default() | d
try:
return cls(**d)
except ValidationError as e:
s = f"invalid input in config file {config_file}:\n{e}"
print(s)
sys.exit(1)
@classmethod
def default(cls) -> dict[str, Any]:
return dict(wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon")
class LimitValues(NamedTuple):
wl_zero_disp: float
ion_lim: float
sf_lim: float
def b2(w, n_eff):
dw = w[1] - w[0]
beta = sc.fiber.beta(w, n_eff)
return sc.math.differentiate_arr(beta, 2, 4, dw)
def N_sf_max(
wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
) -> np.ndarray:
"""
maximum soliton number according to self focusing
eq. S15 in Travers2019
"""
delta_gas = gas.sellmeier.delta(wl, wl_zero_disp)
return t0 * np.sqrt(wl / (safety * np.abs(delta_gas)))
def N_ion_max(
wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0
) -> np.ndarray:
"""
eq. S16 in Travers2019
"""
ind = sc.math.argclosest(wl, wl_zero_disp)
f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl)
factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3
delta = factor * (f / f[ind] - 1)
denom = safety * np.pi * wl * np.abs(delta) * f[ind]
return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom)
def solition_num(
t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float
) -> float:
gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
ld = sc.pulse.L_D(t0, beta2)
return np.sqrt(gamma * ld * peak_power)
def energy(
t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float
) -> float:
gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius))
peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma)
return sc.pulse.P0_to_E0(peak_power, t0, "sech")
from dispersionapp.core import Config, LimitValues, N_ion_max, N_sf_max, energy, b2
from dispersionapp.plotapp import PlotApp
def app(config_file: os.PathLike | None = None):
@@ -118,14 +34,24 @@ def app(config_file: os.PathLike | None = None):
app[0].horizontal_line("reference", 0, color="gray")
app[0].set_xlabel("wavelength (nm)")
app[0].set_ylabel("beta2 (fs^2/cm)")
app.params["wall_thickness_um"].value = 1
app.params["core_diameter_um"].value = 100
app.params["pressure_mbar"].value = 500
app.params["n_tubes"].value = 7
app.params["gap_um"].value = 5
app.params["t_fwhm_fs"].value = 100
if config.current_state is not None:
app.params["core_diameter_um"].value = config.current_state.core_diameter_um
app.params["pressure_mbar"].value = config.current_state.pressure_mbar
app.params["wall_thickness_um"].value = config.current_state.wall_thickness_um
app.params["n_tubes"].value = config.current_state.n_tubes
app.params["gap_um"].value = config.current_state.gap_um
app.params["t_fwhm_fs"].value = config.current_state.t_fwhm_fs
else:
app.params["core_diameter_um"].value = 100
app.params["pressure_mbar"].value = 500
app.params["wall_thickness_um"].value = 1
app.params["n_tubes"].value = 7
app.params["gap_um"].value = 5
app.params["t_fwhm_fs"].value = 100
app[0].set_lim(ylim=(-4, 2))
app.update(config.update_current)
@cache
def compute_max_energy(
core_diameter_um: float, pressure_mbar: float, t_fwhm_fs: float
@@ -143,8 +69,8 @@ def app(config_file: os.PathLike | None = None):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas)[wl_ind]
sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas)[wl_ind]
ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind]
sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind]
beta2 = disp[wl_ind]
n2 = gas.n2(pressure=pressure)
@@ -218,19 +144,3 @@ def app(config_file: os.PathLike | None = None):
zdw = lim.wl_zero_disp * 1e9
app[0].set_line_data("zdw", [zdw, zdw], [-3, 3])
app[0].set_line_name("zdw", f"ZDW = {zdw:.0f}nm")
@click.command()
@click.option(
"-c",
"--config",
default=DEFAULT_CONFIG_FILE,
help="configuration file in TOML format",
type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True),
)
def main(config: os.PathLike):
app(config)
if __name__ == "__main__" and not hasattr(sys, "ps1"):
main()

View File

@@ -0,0 +1,722 @@
from __future__ import annotations
import inspect
import itertools
from collections.abc import MutableMapping, Sequence
from functools import cache
from types import MethodType
from typing import Any, Callable, Iterable, Iterator, Optional, Type, Union, overload
from PySide6 import QtCore, QtWidgets, QtGui
import numpy as np
import pyqtgraph as pg
from pyqtgraph.dockarea import Dock, DockArea
from pyqtgraph.graphicsItems.PlotDataItem import PlotDataItem
MPL_COLORS = [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#d62728",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf",
]
key_type = Union[str, int]
class Field(QtWidgets.QWidget):
dtype: Type
value: Any
value_changed: QtCore.Signal
timer: QtCore.QTimer
@property
def value(self) -> Any:
raise NotImplementedError()
def values(self) -> list[Any]:
raise NotImplementedError()
class SliderField(Field):
dtype: Type
possible_values: np.ndarray
_slider_max = 100
_tuple_signal = QtCore.Signal(tuple)
_int_signal = QtCore.Signal(int)
_float_signal = QtCore.Signal(float)
_str_signal = QtCore.Signal(str)
def __init__(self, name: str, values: Iterable) -> None:
super().__init__()
self.__value = None
self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal)
self.slider.setSingleStep(1)
self.slider.valueChanged.connect(self.slider_changed)
self.field = QtWidgets.QLineEdit()
self.field.setMaximumWidth(150)
self.field.editingFinished.connect(self.field_changed)
self.step_backward_button = QtWidgets.QPushButton("<")
self.step_backward_button.clicked.connect(self.step_backward)
self.step_forward_button = QtWidgets.QPushButton(">")
self.step_forward_button.clicked.connect(self.step_forward)
self._layout = QtWidgets.QHBoxLayout()
self.setLayout(self._layout)
self._layout.setContentsMargins(10, 0, 10, 0)
pretty_name = " ".join(s.title() for s in name.split("_"))
self.name_label = QtWidgets.QLabel(pretty_name + " :")
self._layout.addWidget(self.name_label)
self._layout.addWidget(self.slider)
self._layout.addWidget(self.step_backward_button)
self._layout.addWidget(self.step_forward_button)
self._layout.addWidget(self.field)
self.set_values(values)
self.value_changed = {
int: self._int_signal,
float: self._float_signal,
str: self._str_signal,
tuple: self._tuple_signal,
}[self.dtype]
self.value_changed.emit(self.__value)
def set_values(self, values: Iterable):
self.possible_values = np.array(values)
self.value_to_slider_map = {v: i for i, v in enumerate(self.possible_values)}
self._slider_max = len(self.possible_values) - 1
self.slider.setRange(0, self._slider_max)
if self.__value not in self.value_to_slider_map:
self.__value = self.possible_values[0]
if isinstance(self.__value, (np.integer, int)):
new_dtype = int
elif isinstance(self.__value, (np.floating, float)):
new_dtype = float
elif isinstance(self.__value, tuple):
raise NotImplementedError("Tuples currently not supported")
elif isinstance(self.__value, str):
new_dtype = str
else:
raise TypeError(f"parameter of type {type(self.__value)!r} not possible")
if hasattr(self, "dtype") and self.dtype != new_dtype:
raise RuntimeError(f"Field {self.name} cannot change dtype after creation")
self.dtype = new_dtype
self.update_label()
self.update_slider()
@property
def value(self) -> float:
return self.__value
@value.setter
def value(self, new_value):
if new_value not in self.possible_values:
new_value = self.possible_values[np.argmin(np.abs(self.possible_values - new_value))]
self.__value = new_value
self.update_label()
self.slider.blockSignals(True)
self.slider.setValue(self.value_to_slider())
self.slider.blockSignals(False)
self.value_changed.emit(new_value)
def field_changed(self):
new_val = self.dtype(self.field.text())
if new_val not in self.value_to_slider_map:
self.update_label()
return
self.value = new_val
def slider_changed(self):
self.value = self.slider_to_value()
def update_label(self):
self.field.setText(str(self))
def update_slider(self):
self.slider.setValue(self.value_to_slider())
def step_forward(self):
next_value = self.slider.value() + 1
if next_value <= self._slider_max:
self.slider.setValue(next_value)
def step_backward(self):
next_value = self.slider.value() - 1
if next_value >= 0:
self.slider.setValue(next_value)
def slider_to_value(self, i=None) -> float:
if i is None:
i = self.slider.value()
return self.possible_values[i]
def value_to_slider(self) -> int:
return self.value_to_slider_map[self.__value]
def __str__(self) -> str:
if self.dtype is int:
return format(self.value)
elif self.dtype is float:
return format(self.value, ".3g")
else:
return format(self.value)
def values(self) -> list[float]:
return list(self.possible_values)
class AnimatedSliderField(SliderField):
def __init__(self, name: str, values: Iterable) -> None:
super().__init__(name, values)
self.timer = QtCore.QTimer()
self.timer.timeout.connect(self.step_forward)
self.play_button = QtWidgets.QPushButton("")
self.play_button.clicked.connect(self.toggle)
self.play_button.setMaximumWidth(30)
self.playing = False
self.interval_field = QtWidgets.QLineEdit()
self.interval_field.setMaximumWidth(60)
# self.interval_field.setValidator(QtGui.QIntValidator(1, 5000))
self.interval_field.editingFinished.connect(self.set_interval)
self.interval_field.inputRejected.connect(self.set_interval)
self.interval = 16
self.set_interval()
self._layout.addWidget(self.play_button)
self._layout.addWidget(self.interval_field)
def toggle(self):
if self.playing:
self.stop()
self.playing = False
else:
self.play()
self.playing = True
def play(self):
if self.slider.value() == self._slider_max:
self.slider.setValue(0)
self.timer.start(self.interval)
self.play_button.setStyleSheet("QPushButton {background-color: #DEF2DD;}")
def stop(self):
self.timer.stop()
self.play_button.setStyleSheet("QPushButton {background-color: none;}")
def set_interval(self):
try:
self.interval = max(1, int(self.interval_field.text()))
except ValueError:
self.interval_field.setText(str(self.interval))
if self.interval < 16:
self.increment = int(np.ceil(16 / self.interval))
self.interval *= self.increment
else:
self.increment = 1
def step_forward(self):
current = self.slider.value()
if current + self.increment <= self._slider_max:
self.slider.setValue(current + self.increment)
else:
self.slider.setValue(self._slider_max)
self.stop()
class Plot:
name: str
dock: Dock
plot_widget: pg.PlotWidget
legend: pg.LegendItem
color_cycle: Iterable[str]
lines: dict[str, PlotDataItem]
pens: dict[str, QtGui.QPen]
def __init__(self, name: str, area: DockArea):
self.name = name
self.plot_widget = pg.PlotWidget()
self.plot_widget.setBackground("w")
self.dock = Dock(name)
self.legend = pg.LegendItem((80, 60), offset=(70, 20))
self.legend.setParentItem(self.plot_widget.plotItem)
self.legend.setLabelTextSize("12pt")
self.color_cycle = itertools.cycle(MPL_COLORS)
self.lines = {}
self.pens = {}
self.dock.addWidget(self.plot_widget)
area.addDock(self.dock)
def __getitem__(self, key: str) -> PlotDataItem:
if key not in self.lines:
plot: PlotDataItem = self.plot_widget.plot()
self.legend.addItem(plot, key)
self.lines[key] = plot
return self.lines[key]
def __setitem__(self, key: str, item: PlotDataItem):
self.add_item(key, item, item)
def add_item(self, key: str, curve: PlotDataItem, obj: pg.GraphicsObject):
self.legend.addItem(curve, key)
self.plot_widget.addItem(obj)
self.lines[key] = curve
def set_line_name(self, line_name: str, new_name: str):
"""updates the displayed name of a line. Does not change the internal
name of the plot data though.
Parameters
----------
name : str
original name of the line
new_name : str
new name
"""
line = self[line_name]
self.legend.getLabel(line).setText(new_name)
def set_line_data(
self,
key: str,
x: Iterable[float],
y: Iterable[float] = None,
label: str = None,
**pen_kwargs,
):
"""plots the given data
Parameters
----------
key : str
internal name of the line. If it already exists, the line is updated, otherwise
it is created
x : Iterable[float]
x data
y : Iterable[float], optional
y data. If not given, takes x as y data, by default None
label : str, optional
if given, updates the legend label, by default None
pen_kwargs : Any, optional
given to the pg.mkPen constructor
"""
if y is None:
x, y = np.arange(len(x)), x
self[key].setData(x, y)
self[key].setPen(self.get_pen(key, **pen_kwargs))
if label is not None:
self.set_line_name(key, label)
def vertical_line(self, key: str, x: float, **pen_kwargs):
"""plots a vertical, infinite line
Parameters
----------
key : str
name of the line
x : float
position
"""
pen = self.get_pen(key, **pen_kwargs)
line = pg.InfiniteLine(x, 90, pen)
legend_item = PlotDataItem((), pen=pen)
self.add_item(key, legend_item, line)
def horizontal_line(self, key: str, y: float, **pen_kwargs):
"""plots a horizontal, infinite line
Parameters
----------
key : str
name of the line
y : float
position
"""
pen = self.get_pen(key, **pen_kwargs)
line = pg.InfiniteLine(y, 0, pen)
legend_item = PlotDataItem((), pen=pen)
self.add_item(key, legend_item, line)
def set_lim(self, *, xlim=None, ylim=None):
x_auto = xlim is None
y_auto = ylim is None
if not x_auto:
self.plot_widget.setXRange(*xlim)
if not y_auto:
self.plot_widget.setYRange(*ylim)
self.plot_widget.plotItem.enableAutoRange(x=x_auto, y=y_auto)
def set_xlabel(self, label: str, ignore_math=True):
if ignore_math:
label = label.replace("$", "")
self.plot_widget.plotItem.setLabel("bottom", label)
def set_ylabel(self, label: str, ignore_math=True):
if ignore_math:
label = label.replace("$", "")
self.plot_widget.plotItem.setLabel("left", label)
def link_x(self, other: Plot):
self.plot_widget.plotItem.setXLink(other.plot_widget.plotItem)
def link_y(self, other: Plot):
self.plot_widget.plotItem.setYLink(other.plot_widget.plotItem)
def get_pen(self, key: str, **pen_kwargs) -> QtGui.QPen:
if key not in self.pens:
pen_kwargs = dict(width=3) | pen_kwargs
if "color" not in pen_kwargs:
pen_kwargs["color"] = next(self.color_cycle)
self.pens[key] = pg.mkPen(**pen_kwargs)
return self.pens[key]
class CacheThread(QtCore.QThread):
sig_progress = QtCore.Signal(int)
tasks: list[tuple[Callable, list]]
size: int
def __init__(self):
super().__init__()
self.tasks = []
self.size = 0
self.__stop = False
def __del__(self):
self.__stop = True
self.wait()
def add_task(self, func: Callable, args: Iterable[list]):
try:
func.cache_info()
except AttributeError:
raise ValueError(f"func {func.__name__} is not cached")
all_args = list(itertools.product(*args))
self.tasks.append((func, all_args))
self.size += len(all_args)
def run(self):
i = 0
for func, all_args in self.tasks:
for args in all_args:
if self.__stop:
return
func(*args)
i += 1
self.sig_progress.emit(i)
class CacheWidget(QtWidgets.QWidget):
cache_thread: CacheThread
sig_finished = QtCore.Signal()
def __init__(self):
super().__init__()
self.cache_thread = CacheThread()
self.cache_thread.sig_progress.connect(self.update_pbar)
layout = QtWidgets.QHBoxLayout()
self.label = QtWidgets.QLabel("caching")
self.pbar = QtWidgets.QProgressBar()
self.pbar.setValue(0)
self.pbar.setMaximum(self.cache_thread.size)
layout.addWidget(self.label)
layout.addWidget(self.pbar)
self.setLayout(layout)
def add_task(self, func: Callable, args: Iterable[list]):
if self.cache_thread.isRunning():
raise RuntimeError("cannot add tasks after caching has begun.")
self.cache_thread.add_task(func, args)
self.pbar.setMaximum(self.cache_thread.size)
def start(self):
self.cache_thread.start()
def update_pbar(self, val: int):
self.pbar.setValue(val)
if val == self.cache_thread.size:
self.sig_finished.emit()
self.setParent(None)
class PlotApp:
"""
Easy interactive plotting
Example
-------
```
import numpy as np
from customfunc.app import PlotApp
def main():
t, dt = np.linspace(-10, 10, 4096, retstep=True)
f = np.fft.rfftfreq(len(t), dt)
with PlotApp(freq=(2, 50, 30), offset=(0, 2 * np.pi, 10), spacing=(0, 2, 50)) as app:
@app.cache_and_update(app["field"])
def signal(freq, offset, spacing):
out = np.zeros(len(t))
for i in range(-5, 5 + 1):
out += np.sin(t * 2 * np.pi * freq + offset * i) * np.exp(
-(((t + spacing * i) / 0.3) ** 2)
)
return t, out
@app.update
def draw(freq, offset, spacing):
_, y = signal(freq, offset, spacing)
spec = np.fft.rfft(y)
app["spec"].set_line_data("real", f, spec.real)
app["spec"].set_line_data("imag", f, spec.imag)
if __name__ == "__main__":
main()
```
"""
app: QtWidgets.QApplication
window: QtWidgets.QMainWindow
central_layout: QtWidgets.QVBoxLayout
params_layout: QtWidgets.QVBoxLayout
central_widget: QtWidgets.QWidget
params_widget: QtWidgets.QWidget
plots: dict[str, Plot]
params: dict[str, Field]
__cache_widget: Optional[CacheWidget] = None
def __init__(self, name: str = "Plot App", **params: dict[str, Any]):
self.app = pg.mkQApp()
self.window = QtWidgets.QMainWindow()
self.window.setWindowTitle(name)
self.window.resize(1200, 800)
self.central_widget = QtWidgets.QWidget()
self.params_widget = QtWidgets.QWidget()
self.window.setCentralWidget(self.central_widget)
self.central_layout = QtWidgets.QVBoxLayout()
self.params_layout = QtWidgets.QVBoxLayout()
self.central_layout.setContentsMargins(0, 0, 0, 0)
self.central_widget.setLayout(self.central_layout)
self.params_widget.setLayout(self.params_layout)
self.central_layout.addWidget(self.params_widget, stretch=0)
self.dock_area = DockArea()
self.plots = {}
self.__ran = False
self.params = {}
for p_name, values in params.items():
field = AnimatedSliderField(p_name, values)
self.params[p_name] = field
self.params_layout.addWidget(field)
def set_antialiasing(self, val: bool):
pg.setConfigOptions(antialias=val)
def _parse_params(
self, params: dict[str, Iterable]
) -> Iterator[tuple[str, Type, dict[str, Any]]]:
for p_name, opts in params.items():
if isinstance(opts, MutableMapping):
dtype = opts.pop("dtype", float)
elif isinstance(opts, Sequence):
if isinstance(opts[0], type):
dtype, *opts = opts
else:
dtype = float
opts = dict(zip(["v_min", "v_max", "v_num", "v_init"], opts))
yield p_name, dtype, opts
def update(self, *p_names: str):
"""
use this as a decorator to connect the decorated function to
the value_changed signal of some parameters
The decorated function will be called with the specified parameters
when any of those changes
if decorating without parameter nor (), PlotApp will read the argument names
in the function signature (those MUST match the parameter names given in __init__)
Parameters
----------
p_names : str, optional
name of the parameter as given in the `PlotApp.__init__`
Example
-------
```
with PlotApp(speed=(float, 0, 5), num_cars=(int, 1, 5)) as plot_app
@plot_app.update("speed")
def updating_plot(speed:float):
x, y = some_function(speed)
plot_app["speed_plot"].set_line_data("my_car", x, y)
@plot_app.update
def call_everytime():
# parameters are also directly accessible anytime
print(plot_app.params["num_car"].value)
```
In this example, `call_everytime` is called without argument everytime any of the
parameters changes, whereas `updating_plot` is called only when the `speed` parameter
changes and is given its new value as argument
"""
if len(p_names) == 1 and callable(p_names[0]):
return self.update(*self._get_func_args(p_names[0]))(p_names[0])
def wrapper(func):
def to_call(v=None):
func(**{p_name: self.params[p_name].value for p_name in p_names})
for p_name in p_names:
self.params[p_name].value_changed.connect(to_call)
return func
return wrapper
def cache(self, *parameters: Union[str, Iterable]):
"""
use this as a decorator to cache a function that outputs data
if decorating without parameter nor (), PlotApp will read the argument names
in the function signature (those MUST match the parameter names given in __init__)
Paramters
---------
parameters : str | Iterable
if str : name of the parameter as given in the `PlotApp.__init__`
if Iterable : all possible value of a particular argument of the
decorated function
Example
-------
```
with PlotApp(speed=(float, 0, 5)) as plot_app
@plot_app.cache(range(20), "speed")
def heavy_computation(n:int, speed:float):
...
```
As soon as the app is launched, caching of the output of `heavy_computation` for
every possible compination of `[0, 1, ..., 19]` and values of the `'speed'` parameter
will begin, with a progress bar showing how it's going.
"""
if len(parameters) == 1 and callable(parameters[0]):
return self.cache(*self._get_func_args(parameters[0]))(parameters[0])
all_params: list[list] = []
for param in parameters:
if isinstance(param, str):
all_params.append(self.params[param].values())
else:
all_params.append(list(param))
def wrapper(func):
cached_func = cache(func)
self.cache_widget.add_task(cached_func, all_params)
return cached_func
return wrapper
def cache_and_update(self, plot: Plot, *line_names: str):
"""combination of update and cache. the decorated function should return alternating
x, y variable (for example, a function to plot 2 lines would end in `return x1, y1, x2, y2`)
the decorated function should only computed data, not plot anything.
Parameters
----------
plot : Plot
plot on which the lines are to be plotted
line_names : str, ..., optional
name of the lines, by default line <n>
"""
def wrapper(func):
new_func = self.cache(func)
arg_names = self._get_func_args(func)
init_vals = new_func(**{a: self.params[a].value for a in arg_names})
labels = [*line_names, *[f"line {i}" for i in range(len(init_vals) // 2)]]
def plotter_func(*args):
vals = new_func(*args)
for label, x, y in zip(labels, vals[::2], vals[1::2]):
plot.set_line_data(label, x, y)
self.update(*arg_names)(plotter_func)
return new_func
return wrapper
@property
def cache_widget(self) -> CacheWidget:
if self.__cache_widget is None:
self.__cache_widget = CacheWidget()
self.central_layout.addWidget(self.__cache_widget)
def delete():
del self.cache_widget
self.__cache_widget.sig_finished.connect(delete)
return self.__cache_widget
@cache_widget.deleter
def cache_widget(self):
self.params_widget.setMaximumHeight(self.params_widget.geometry().height())
self.central_layout.removeWidget(self.__cache_widget)
self.__cache_widget = None
def _get_func_args(self, func) -> list[str]:
arg_names = inspect.getfullargspec(func).args
if isinstance(func, MethodType):
arg_names = arg_names[1:]
for arg in arg_names:
if arg not in self.params:
raise ValueError(f"{arg} not in app parameters")
return arg_names
def __enter__(self) -> PlotApp:
if self.__ran:
raise RuntimeError("App already ran")
self.__ran = True
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.central_layout.addWidget(self.dock_area, stretch=1000)
if self.__cache_widget is not None:
self.__cache_widget.start()
for datafield in self.params.values():
datafield.value_changed.emit(datafield.value)
self.window.show()
self.app.exec()
@overload
def __getitem__(self, key: tuple[key_type, key_type]) -> PlotDataItem:
...
@overload
def __getitem__(self, key: key_type) -> Plot:
...
def __getitem__(self, key: key_type) -> Plot:
if isinstance(key, tuple):
return self[key[0]][key[1]]
key = str(key)
if key not in self.plots:
self.plots[key] = Plot(key, self.dock_area)
return self.plots[key]