diff --git a/.gitignore b/.gitignore index d5b7e20..a6dba3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ pyrightconfig.json +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d0714e6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,16 @@ +dispersionapp: a pogramm to interactively play with fiber parameters +and see the result on the dispersion +Copyright (C) 2023 Benoît Sierro + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . diff --git a/README.md b/README.md new file mode 100644 index 0000000..22fb14c --- /dev/null +++ b/README.md @@ -0,0 +1,16 @@ +# Crash course on Python installation, with a few small side tracks + + + +# Installation +1. If not already done so, create a virtual environment to contain all the packages required. Python 3.10 is recommended. The command below will create a new environment named 'app-env' with the latest version of Python 3.10. + +conda create -y -n app-env python=3.10 + +2. activate said environment + +conda activate app-env + +3. The prompt should now read '(app-env)' on the left. The app is not published on Github or anywhere else. The link below points to my own personnal home server (I didn't find any way of getting a direct download link You are now ready to install everything with this command: + +pip install https://bao.dedebenui.me/s/f3KgiqMq7giN73i/download diff --git a/config.toml b/config.toml deleted file mode 100644 index 9d79bf2..0000000 --- a/config.toml +++ /dev/null @@ -1,7 +0,0 @@ -wl_min = 160 -wl_max = 1600 -wl_pump = 800 - -rep_rate = 8e3 - -safety_factor = 10 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7e2f23e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[project] +name = "dispersionapp" +version = "0.1.0" +description = "Model hollow capillary and revolver fiber interactively" +authors = [{ name = "Benoît Sierro", email = "benoit.sierro@unibe.ch" }] + +dependencies = [ + "scgenerator @ git+https://github.com/bsierro/scgenerator.git", + "click", + "pydantic", + "tomli", + "tomli_w", + "PySide6 >= 6.4.0", + "pyqtgraph >= 0.13.1", +] + +license = { file = "LICENSE" } + +classifiers = [ + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", +] + +[project.scripts] + +dispersionapp = "dispersionapp.__main__:main" + +[build-system] +build-backend = "setuptools.build_meta" +requires = ["setuptools >= 65.0.0"] diff --git a/src/dispersionapp/__init__.py b/src/dispersionapp/__init__.py new file mode 100755 index 0000000..66588c4 --- /dev/null +++ b/src/dispersionapp/__init__.py @@ -0,0 +1,3 @@ +import importlib + +__version__ = importlib.metadata.version("dispersionapp") diff --git a/src/dispersionapp/__main__.py b/src/dispersionapp/__main__.py new file mode 100644 index 0000000..92c202a --- /dev/null +++ b/src/dispersionapp/__main__.py @@ -0,0 +1,24 @@ +import os +import sys + +import click + +from dispersionapp.core import DEFAULT_CONFIG_FILE +from dispersionapp.gui import app + + +@click.command() +@click.option( + "-c", + "--config", + default=DEFAULT_CONFIG_FILE, + help="configuration file in TOML format", + type=click.Path(file_okay=True, dir_okay=False, resolve_path=True), +) +@click.version_option() +def main(config: os.PathLike): + app(config) + + +if __name__ == "__main__" and not hasattr(sys, "ps1"): + main() diff --git a/src/dispersionapp/core.py b/src/dispersionapp/core.py new file mode 100644 index 0000000..3d81ff9 --- /dev/null +++ b/src/dispersionapp/core.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any, NamedTuple + +import numpy as np +import scgenerator as sc +import tomli +import tomli_w +from pydantic import BaseModel, PrivateAttr, ValidationError, confloat + +DEFAULT_CONFIG_FILE = "config.toml" + + +class CurrentState(BaseModel): + core_diameter_um: float + pressure_mbar: float + wall_thickness_um: float + n_tubes: int + gap_um: float + t_fwhm_fs: float + + +class Config(BaseModel): + wl_min: confloat(ge=100, le=1000) + wl_max: confloat(ge=500, le=6000) + wl_pump: confloat(ge=200, le=6000) + rep_rate: confloat(gt=0) + gas: str + safety_factor: float + current_state: CurrentState | None = None + + _file_name: Path = PrivateAttr() + + @classmethod + def load(cls, config_file: str | None = None) -> Config: + config_file = Path(config_file or DEFAULT_CONFIG_FILE) + if not config_file.exists(): + config_file.touch() + + with open(config_file, "rb") as file: + d = tomli.load(file) + d = cls.default() | d + try: + out = cls(**d) + except ValidationError as e: + s = f"invalid input in config file {config_file}:\n{e}" + print(s) + sys.exit(1) + out._file_name = Path(config_file) + return out + + @classmethod + def default(cls) -> dict[str, Any]: + return dict( + wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon", safety_factor=10 + ) + + def save(self): + tmp = self._file_name.parent / f"{self._file_name.name}.tmp" + with open(tmp, "wb") as file: + tomli_w.dump(self.dict(), file) + tmp.rename(self._file_name) + + def update_current( + self, core_diameter_um, pressure_mbar, wall_thickness_um, n_tubes, gap_um, t_fwhm_fs + ): + self.current_state = CurrentState( + core_diameter_um=core_diameter_um, + pressure_mbar=pressure_mbar, + wall_thickness_um=wall_thickness_um, + n_tubes=n_tubes, + gap_um=gap_um, + t_fwhm_fs=t_fwhm_fs, + ) + self.save() + + +class LimitValues(NamedTuple): + wl_zero_disp: float + ion_lim: float + sf_lim: float + + +def b2(w, n_eff): + dw = w[1] - w[0] + beta = sc.fiber.beta(w, n_eff) + return sc.math.differentiate_arr(beta, 2, 4, dw) + + +def N_sf_max( + wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0 +) -> np.ndarray: + """ + maximum soliton number according to self focusing + + eq. S15 in Travers2019 + """ + delta_gas = gas.sellmeier.delta(wl, wl_zero_disp) + return t0 * np.sqrt(wl / (safety * np.abs(delta_gas))) + + +def N_ion_max( + wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0 +) -> np.ndarray: + """ + eq. S16 in Travers2019 + """ + ind = sc.math.argclosest(wl, wl_zero_disp) + f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl) + factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3 + delta = factor * (f / f[ind] - 1) + denom = safety * np.pi * wl * np.abs(delta) * f[ind] + return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom) + + +def solition_num( + t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float +) -> float: + gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius)) + ld = sc.pulse.L_D(t0, beta2) + return np.sqrt(gamma * ld * peak_power) + + +def energy( + t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float +) -> float: + gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius)) + peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma) + return sc.pulse.P0_to_E0(peak_power, t0, "sech") diff --git a/dispersion_app.py b/src/dispersionapp/gui.py old mode 100755 new mode 100644 similarity index 57% rename from dispersion_app.py rename to src/dispersionapp/gui.py index 0b93c30..a097110 --- a/dispersion_app.py +++ b/src/dispersionapp/gui.py @@ -1,98 +1,14 @@ from __future__ import annotations import os -import sys -import warnings -from functools import cache -from typing import Any, NamedTuple -import click import numpy as np import scgenerator as sc -import tomli -from customfunc.app import PlotApp -from pydantic import BaseModel, ValidationError, confloat +from functools import cache +import warnings -DEFAULT_CONFIG_FILE = "config.toml" - - -class Config(BaseModel): - wl_min: confloat(ge=100, le=1000) - wl_max: confloat(ge=500, le=6000) - wl_pump: confloat(ge=200, le=6000) - rep_rate: confloat(gt=0) - gas: str - - @classmethod - def load(cls, config_file: str | None = None) -> Config: - config_file = config_file or DEFAULT_CONFIG_FILE - with open(config_file, "rb") as file: - d = tomli.load(file) - d = cls.default() | d - try: - return cls(**d) - except ValidationError as e: - s = f"invalid input in config file {config_file}:\n{e}" - print(s) - sys.exit(1) - - @classmethod - def default(cls) -> dict[str, Any]: - return dict(wl_min=160, wl_max=1600, wl_pump=800, rep_rate=8e3, gas="argon") - - -class LimitValues(NamedTuple): - wl_zero_disp: float - ion_lim: float - sf_lim: float - - -def b2(w, n_eff): - dw = w[1] - w[0] - beta = sc.fiber.beta(w, n_eff) - return sc.math.differentiate_arr(beta, 2, 4, dw) - - -def N_sf_max( - wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0 -) -> np.ndarray: - """ - maximum soliton number according to self focusing - - eq. S15 in Travers2019 - """ - delta_gas = gas.sellmeier.delta(wl, wl_zero_disp) - return t0 * np.sqrt(wl / (safety * np.abs(delta_gas))) - - -def N_ion_max( - wl: np.ndarray, t0: float, wl_zero_disp: float, gas: sc.materials.Gas, safety: float = 10.0 -) -> np.ndarray: - """ - eq. S16 in Travers2019 - """ - ind = sc.math.argclosest(wl, wl_zero_disp) - f = np.gradient(np.gradient(gas.sellmeier.chi(wl), wl), wl) - factor = (sc.math.u_nm(1, 1) / sc.units.c) ** 2 * (0.5 * wl / np.pi) ** 3 - delta = factor * (f / f[ind] - 1) - denom = safety * np.pi * wl * np.abs(delta) * f[ind] - return t0 * sc.math.u_nm(1, 1) * np.sqrt(gas.n2() * gas.barrier_suppression / denom) - - -def solition_num( - t0: float, w0: float, beta2: float, n2: float, core_radius: float, peak_power: float -) -> float: - gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius)) - ld = sc.pulse.L_D(t0, beta2) - return np.sqrt(gamma * ld * peak_power) - - -def energy( - t0: float, w0: float, beta2: float, n2: float, core_radius: float, solition_num: float -) -> float: - gamma = sc.fiber.gamma_parameter(n2, w0, sc.fiber.A_eff_marcatili(core_radius)) - peak_power = solition_num**2 * abs(beta2) / (t0**2 * gamma) - return sc.pulse.P0_to_E0(peak_power, t0, "sech") +from dispersionapp.core import Config, LimitValues, N_ion_max, N_sf_max, energy, b2 +from dispersionapp.plotapp import PlotApp def app(config_file: os.PathLike | None = None): @@ -118,14 +34,24 @@ def app(config_file: os.PathLike | None = None): app[0].horizontal_line("reference", 0, color="gray") app[0].set_xlabel("wavelength (nm)") app[0].set_ylabel("beta2 (fs^2/cm)") - app.params["wall_thickness_um"].value = 1 - app.params["core_diameter_um"].value = 100 - app.params["pressure_mbar"].value = 500 - app.params["n_tubes"].value = 7 - app.params["gap_um"].value = 5 - app.params["t_fwhm_fs"].value = 100 + if config.current_state is not None: + app.params["core_diameter_um"].value = config.current_state.core_diameter_um + app.params["pressure_mbar"].value = config.current_state.pressure_mbar + app.params["wall_thickness_um"].value = config.current_state.wall_thickness_um + app.params["n_tubes"].value = config.current_state.n_tubes + app.params["gap_um"].value = config.current_state.gap_um + app.params["t_fwhm_fs"].value = config.current_state.t_fwhm_fs + else: + app.params["core_diameter_um"].value = 100 + app.params["pressure_mbar"].value = 500 + app.params["wall_thickness_um"].value = 1 + app.params["n_tubes"].value = 7 + app.params["gap_um"].value = 5 + app.params["t_fwhm_fs"].value = 100 app[0].set_lim(ylim=(-4, 2)) + app.update(config.update_current) + @cache def compute_max_energy( core_diameter_um: float, pressure_mbar: float, t_fwhm_fs: float @@ -143,8 +69,8 @@ def app(config_file: os.PathLike | None = None): with warnings.catch_warnings(): warnings.simplefilter("ignore") - ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas)[wl_ind] - sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas)[wl_ind] + ion_limit = N_ion_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind] + sf_limit = N_sf_max(wl, t0, wl_zero_disp, gas, config.safety_factor)[wl_ind] beta2 = disp[wl_ind] n2 = gas.n2(pressure=pressure) @@ -218,19 +144,3 @@ def app(config_file: os.PathLike | None = None): zdw = lim.wl_zero_disp * 1e9 app[0].set_line_data("zdw", [zdw, zdw], [-3, 3]) app[0].set_line_name("zdw", f"ZDW = {zdw:.0f}nm") - - -@click.command() -@click.option( - "-c", - "--config", - default=DEFAULT_CONFIG_FILE, - help="configuration file in TOML format", - type=click.Path(exists=True, file_okay=True, dir_okay=False, resolve_path=True), -) -def main(config: os.PathLike): - app(config) - - -if __name__ == "__main__" and not hasattr(sys, "ps1"): - main() diff --git a/src/dispersionapp/plotapp.py b/src/dispersionapp/plotapp.py new file mode 100644 index 0000000..a464e7b --- /dev/null +++ b/src/dispersionapp/plotapp.py @@ -0,0 +1,722 @@ +from __future__ import annotations + +import inspect +import itertools +from collections.abc import MutableMapping, Sequence +from functools import cache +from types import MethodType +from typing import Any, Callable, Iterable, Iterator, Optional, Type, Union, overload + +from PySide6 import QtCore, QtWidgets, QtGui +import numpy as np +import pyqtgraph as pg +from pyqtgraph.dockarea import Dock, DockArea +from pyqtgraph.graphicsItems.PlotDataItem import PlotDataItem + +MPL_COLORS = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", +] + +key_type = Union[str, int] + + +class Field(QtWidgets.QWidget): + dtype: Type + value: Any + value_changed: QtCore.Signal + timer: QtCore.QTimer + + @property + def value(self) -> Any: + raise NotImplementedError() + + def values(self) -> list[Any]: + raise NotImplementedError() + + +class SliderField(Field): + dtype: Type + possible_values: np.ndarray + _slider_max = 100 + _tuple_signal = QtCore.Signal(tuple) + _int_signal = QtCore.Signal(int) + _float_signal = QtCore.Signal(float) + _str_signal = QtCore.Signal(str) + + def __init__(self, name: str, values: Iterable) -> None: + super().__init__() + self.__value = None + self.slider = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal) + self.slider.setSingleStep(1) + self.slider.valueChanged.connect(self.slider_changed) + + self.field = QtWidgets.QLineEdit() + self.field.setMaximumWidth(150) + self.field.editingFinished.connect(self.field_changed) + + self.step_backward_button = QtWidgets.QPushButton("<") + self.step_backward_button.clicked.connect(self.step_backward) + + self.step_forward_button = QtWidgets.QPushButton(">") + self.step_forward_button.clicked.connect(self.step_forward) + + self._layout = QtWidgets.QHBoxLayout() + self.setLayout(self._layout) + self._layout.setContentsMargins(10, 0, 10, 0) + + pretty_name = " ".join(s.title() for s in name.split("_")) + self.name_label = QtWidgets.QLabel(pretty_name + " :") + + self._layout.addWidget(self.name_label) + self._layout.addWidget(self.slider) + self._layout.addWidget(self.step_backward_button) + self._layout.addWidget(self.step_forward_button) + self._layout.addWidget(self.field) + + self.set_values(values) + self.value_changed = { + int: self._int_signal, + float: self._float_signal, + str: self._str_signal, + tuple: self._tuple_signal, + }[self.dtype] + self.value_changed.emit(self.__value) + + def set_values(self, values: Iterable): + self.possible_values = np.array(values) + self.value_to_slider_map = {v: i for i, v in enumerate(self.possible_values)} + self._slider_max = len(self.possible_values) - 1 + self.slider.setRange(0, self._slider_max) + if self.__value not in self.value_to_slider_map: + self.__value = self.possible_values[0] + if isinstance(self.__value, (np.integer, int)): + new_dtype = int + elif isinstance(self.__value, (np.floating, float)): + new_dtype = float + elif isinstance(self.__value, tuple): + raise NotImplementedError("Tuples currently not supported") + elif isinstance(self.__value, str): + new_dtype = str + else: + raise TypeError(f"parameter of type {type(self.__value)!r} not possible") + if hasattr(self, "dtype") and self.dtype != new_dtype: + raise RuntimeError(f"Field {self.name} cannot change dtype after creation") + self.dtype = new_dtype + self.update_label() + self.update_slider() + + @property + def value(self) -> float: + return self.__value + + @value.setter + def value(self, new_value): + if new_value not in self.possible_values: + new_value = self.possible_values[np.argmin(np.abs(self.possible_values - new_value))] + self.__value = new_value + self.update_label() + self.slider.blockSignals(True) + self.slider.setValue(self.value_to_slider()) + self.slider.blockSignals(False) + self.value_changed.emit(new_value) + + def field_changed(self): + new_val = self.dtype(self.field.text()) + if new_val not in self.value_to_slider_map: + self.update_label() + return + self.value = new_val + + def slider_changed(self): + self.value = self.slider_to_value() + + def update_label(self): + self.field.setText(str(self)) + + def update_slider(self): + self.slider.setValue(self.value_to_slider()) + + def step_forward(self): + next_value = self.slider.value() + 1 + if next_value <= self._slider_max: + self.slider.setValue(next_value) + + def step_backward(self): + next_value = self.slider.value() - 1 + if next_value >= 0: + self.slider.setValue(next_value) + + def slider_to_value(self, i=None) -> float: + if i is None: + i = self.slider.value() + return self.possible_values[i] + + def value_to_slider(self) -> int: + return self.value_to_slider_map[self.__value] + + def __str__(self) -> str: + if self.dtype is int: + return format(self.value) + elif self.dtype is float: + return format(self.value, ".3g") + else: + return format(self.value) + + def values(self) -> list[float]: + return list(self.possible_values) + + +class AnimatedSliderField(SliderField): + def __init__(self, name: str, values: Iterable) -> None: + super().__init__(name, values) + self.timer = QtCore.QTimer() + self.timer.timeout.connect(self.step_forward) + + self.play_button = QtWidgets.QPushButton("⏯") + self.play_button.clicked.connect(self.toggle) + self.play_button.setMaximumWidth(30) + self.playing = False + + self.interval_field = QtWidgets.QLineEdit() + self.interval_field.setMaximumWidth(60) + # self.interval_field.setValidator(QtGui.QIntValidator(1, 5000)) + self.interval_field.editingFinished.connect(self.set_interval) + self.interval_field.inputRejected.connect(self.set_interval) + self.interval = 16 + self.set_interval() + + self._layout.addWidget(self.play_button) + self._layout.addWidget(self.interval_field) + + def toggle(self): + if self.playing: + self.stop() + self.playing = False + else: + self.play() + self.playing = True + + def play(self): + if self.slider.value() == self._slider_max: + self.slider.setValue(0) + self.timer.start(self.interval) + self.play_button.setStyleSheet("QPushButton {background-color: #DEF2DD;}") + + def stop(self): + self.timer.stop() + self.play_button.setStyleSheet("QPushButton {background-color: none;}") + + def set_interval(self): + try: + self.interval = max(1, int(self.interval_field.text())) + except ValueError: + self.interval_field.setText(str(self.interval)) + if self.interval < 16: + self.increment = int(np.ceil(16 / self.interval)) + self.interval *= self.increment + else: + self.increment = 1 + + def step_forward(self): + current = self.slider.value() + if current + self.increment <= self._slider_max: + self.slider.setValue(current + self.increment) + else: + self.slider.setValue(self._slider_max) + self.stop() + + +class Plot: + name: str + dock: Dock + plot_widget: pg.PlotWidget + legend: pg.LegendItem + color_cycle: Iterable[str] + lines: dict[str, PlotDataItem] + pens: dict[str, QtGui.QPen] + + def __init__(self, name: str, area: DockArea): + self.name = name + self.plot_widget = pg.PlotWidget() + self.plot_widget.setBackground("w") + self.dock = Dock(name) + self.legend = pg.LegendItem((80, 60), offset=(70, 20)) + self.legend.setParentItem(self.plot_widget.plotItem) + self.legend.setLabelTextSize("12pt") + self.color_cycle = itertools.cycle(MPL_COLORS) + self.lines = {} + self.pens = {} + self.dock.addWidget(self.plot_widget) + area.addDock(self.dock) + + def __getitem__(self, key: str) -> PlotDataItem: + if key not in self.lines: + plot: PlotDataItem = self.plot_widget.plot() + self.legend.addItem(plot, key) + self.lines[key] = plot + return self.lines[key] + + def __setitem__(self, key: str, item: PlotDataItem): + self.add_item(key, item, item) + + def add_item(self, key: str, curve: PlotDataItem, obj: pg.GraphicsObject): + self.legend.addItem(curve, key) + self.plot_widget.addItem(obj) + self.lines[key] = curve + + def set_line_name(self, line_name: str, new_name: str): + """updates the displayed name of a line. Does not change the internal + name of the plot data though. + + Parameters + ---------- + name : str + original name of the line + new_name : str + new name + """ + line = self[line_name] + self.legend.getLabel(line).setText(new_name) + + def set_line_data( + self, + key: str, + x: Iterable[float], + y: Iterable[float] = None, + label: str = None, + **pen_kwargs, + ): + """plots the given data + + Parameters + ---------- + key : str + internal name of the line. If it already exists, the line is updated, otherwise + it is created + x : Iterable[float] + x data + y : Iterable[float], optional + y data. If not given, takes x as y data, by default None + label : str, optional + if given, updates the legend label, by default None + pen_kwargs : Any, optional + given to the pg.mkPen constructor + """ + if y is None: + x, y = np.arange(len(x)), x + self[key].setData(x, y) + self[key].setPen(self.get_pen(key, **pen_kwargs)) + if label is not None: + self.set_line_name(key, label) + + def vertical_line(self, key: str, x: float, **pen_kwargs): + """plots a vertical, infinite line + + Parameters + ---------- + key : str + name of the line + x : float + position + """ + pen = self.get_pen(key, **pen_kwargs) + line = pg.InfiniteLine(x, 90, pen) + legend_item = PlotDataItem((), pen=pen) + self.add_item(key, legend_item, line) + + def horizontal_line(self, key: str, y: float, **pen_kwargs): + """plots a horizontal, infinite line + + Parameters + ---------- + key : str + name of the line + y : float + position + """ + pen = self.get_pen(key, **pen_kwargs) + line = pg.InfiniteLine(y, 0, pen) + legend_item = PlotDataItem((), pen=pen) + self.add_item(key, legend_item, line) + + def set_lim(self, *, xlim=None, ylim=None): + + x_auto = xlim is None + y_auto = ylim is None + if not x_auto: + self.plot_widget.setXRange(*xlim) + if not y_auto: + self.plot_widget.setYRange(*ylim) + self.plot_widget.plotItem.enableAutoRange(x=x_auto, y=y_auto) + + def set_xlabel(self, label: str, ignore_math=True): + if ignore_math: + label = label.replace("$", "") + self.plot_widget.plotItem.setLabel("bottom", label) + + def set_ylabel(self, label: str, ignore_math=True): + if ignore_math: + label = label.replace("$", "") + self.plot_widget.plotItem.setLabel("left", label) + + def link_x(self, other: Plot): + self.plot_widget.plotItem.setXLink(other.plot_widget.plotItem) + + def link_y(self, other: Plot): + self.plot_widget.plotItem.setYLink(other.plot_widget.plotItem) + + def get_pen(self, key: str, **pen_kwargs) -> QtGui.QPen: + if key not in self.pens: + pen_kwargs = dict(width=3) | pen_kwargs + if "color" not in pen_kwargs: + pen_kwargs["color"] = next(self.color_cycle) + + self.pens[key] = pg.mkPen(**pen_kwargs) + return self.pens[key] + + +class CacheThread(QtCore.QThread): + sig_progress = QtCore.Signal(int) + tasks: list[tuple[Callable, list]] + size: int + + def __init__(self): + super().__init__() + self.tasks = [] + self.size = 0 + self.__stop = False + + def __del__(self): + self.__stop = True + self.wait() + + def add_task(self, func: Callable, args: Iterable[list]): + try: + func.cache_info() + except AttributeError: + raise ValueError(f"func {func.__name__} is not cached") + all_args = list(itertools.product(*args)) + self.tasks.append((func, all_args)) + self.size += len(all_args) + + def run(self): + i = 0 + for func, all_args in self.tasks: + for args in all_args: + if self.__stop: + return + func(*args) + i += 1 + self.sig_progress.emit(i) + + +class CacheWidget(QtWidgets.QWidget): + cache_thread: CacheThread + sig_finished = QtCore.Signal() + + def __init__(self): + super().__init__() + self.cache_thread = CacheThread() + self.cache_thread.sig_progress.connect(self.update_pbar) + + layout = QtWidgets.QHBoxLayout() + self.label = QtWidgets.QLabel("caching") + self.pbar = QtWidgets.QProgressBar() + self.pbar.setValue(0) + self.pbar.setMaximum(self.cache_thread.size) + layout.addWidget(self.label) + layout.addWidget(self.pbar) + self.setLayout(layout) + + def add_task(self, func: Callable, args: Iterable[list]): + if self.cache_thread.isRunning(): + raise RuntimeError("cannot add tasks after caching has begun.") + self.cache_thread.add_task(func, args) + self.pbar.setMaximum(self.cache_thread.size) + + def start(self): + self.cache_thread.start() + + def update_pbar(self, val: int): + self.pbar.setValue(val) + if val == self.cache_thread.size: + self.sig_finished.emit() + self.setParent(None) + + +class PlotApp: + """ + Easy interactive plotting + + Example + ------- + ``` + import numpy as np + from customfunc.app import PlotApp + + + def main(): + t, dt = np.linspace(-10, 10, 4096, retstep=True) + f = np.fft.rfftfreq(len(t), dt) + with PlotApp(freq=(2, 50, 30), offset=(0, 2 * np.pi, 10), spacing=(0, 2, 50)) as app: + + @app.cache_and_update(app["field"]) + def signal(freq, offset, spacing): + out = np.zeros(len(t)) + for i in range(-5, 5 + 1): + out += np.sin(t * 2 * np.pi * freq + offset * i) * np.exp( + -(((t + spacing * i) / 0.3) ** 2) + ) + return t, out + + @app.update + def draw(freq, offset, spacing): + _, y = signal(freq, offset, spacing) + spec = np.fft.rfft(y) + app["spec"].set_line_data("real", f, spec.real) + app["spec"].set_line_data("imag", f, spec.imag) + + + if __name__ == "__main__": + main() + ``` + """ + + app: QtWidgets.QApplication + window: QtWidgets.QMainWindow + central_layout: QtWidgets.QVBoxLayout + params_layout: QtWidgets.QVBoxLayout + central_widget: QtWidgets.QWidget + params_widget: QtWidgets.QWidget + plots: dict[str, Plot] + params: dict[str, Field] + __cache_widget: Optional[CacheWidget] = None + + def __init__(self, name: str = "Plot App", **params: dict[str, Any]): + self.app = pg.mkQApp() + self.window = QtWidgets.QMainWindow() + self.window.setWindowTitle(name) + self.window.resize(1200, 800) + self.central_widget = QtWidgets.QWidget() + self.params_widget = QtWidgets.QWidget() + self.window.setCentralWidget(self.central_widget) + self.central_layout = QtWidgets.QVBoxLayout() + self.params_layout = QtWidgets.QVBoxLayout() + self.central_layout.setContentsMargins(0, 0, 0, 0) + self.central_widget.setLayout(self.central_layout) + self.params_widget.setLayout(self.params_layout) + self.central_layout.addWidget(self.params_widget, stretch=0) + self.dock_area = DockArea() + self.plots = {} + + self.__ran = False + self.params = {} + for p_name, values in params.items(): + field = AnimatedSliderField(p_name, values) + self.params[p_name] = field + self.params_layout.addWidget(field) + + def set_antialiasing(self, val: bool): + pg.setConfigOptions(antialias=val) + + def _parse_params( + self, params: dict[str, Iterable] + ) -> Iterator[tuple[str, Type, dict[str, Any]]]: + for p_name, opts in params.items(): + if isinstance(opts, MutableMapping): + dtype = opts.pop("dtype", float) + elif isinstance(opts, Sequence): + if isinstance(opts[0], type): + dtype, *opts = opts + else: + dtype = float + opts = dict(zip(["v_min", "v_max", "v_num", "v_init"], opts)) + yield p_name, dtype, opts + + def update(self, *p_names: str): + """ + use this as a decorator to connect the decorated function to + the value_changed signal of some parameters + The decorated function will be called with the specified parameters + when any of those changes + if decorating without parameter nor (), PlotApp will read the argument names + in the function signature (those MUST match the parameter names given in __init__) + + Parameters + ---------- + p_names : str, optional + name of the parameter as given in the `PlotApp.__init__` + + Example + ------- + ``` + with PlotApp(speed=(float, 0, 5), num_cars=(int, 1, 5)) as plot_app + @plot_app.update("speed") + def updating_plot(speed:float): + x, y = some_function(speed) + plot_app["speed_plot"].set_line_data("my_car", x, y) + + @plot_app.update + def call_everytime(): + # parameters are also directly accessible anytime + print(plot_app.params["num_car"].value) + ``` + In this example, `call_everytime` is called without argument everytime any of the + parameters changes, whereas `updating_plot` is called only when the `speed` parameter + changes and is given its new value as argument + """ + + if len(p_names) == 1 and callable(p_names[0]): + return self.update(*self._get_func_args(p_names[0]))(p_names[0]) + + def wrapper(func): + def to_call(v=None): + func(**{p_name: self.params[p_name].value for p_name in p_names}) + + for p_name in p_names: + self.params[p_name].value_changed.connect(to_call) + return func + + return wrapper + + def cache(self, *parameters: Union[str, Iterable]): + """ + use this as a decorator to cache a function that outputs data + if decorating without parameter nor (), PlotApp will read the argument names + in the function signature (those MUST match the parameter names given in __init__) + + Paramters + --------- + parameters : str | Iterable + if str : name of the parameter as given in the `PlotApp.__init__` + if Iterable : all possible value of a particular argument of the + decorated function + + Example + ------- + ``` + with PlotApp(speed=(float, 0, 5)) as plot_app + @plot_app.cache(range(20), "speed") + def heavy_computation(n:int, speed:float): + ... + ``` + As soon as the app is launched, caching of the output of `heavy_computation` for + every possible compination of `[0, 1, ..., 19]` and values of the `'speed'` parameter + will begin, with a progress bar showing how it's going. + """ + if len(parameters) == 1 and callable(parameters[0]): + return self.cache(*self._get_func_args(parameters[0]))(parameters[0]) + + all_params: list[list] = [] + for param in parameters: + if isinstance(param, str): + all_params.append(self.params[param].values()) + else: + all_params.append(list(param)) + + def wrapper(func): + cached_func = cache(func) + self.cache_widget.add_task(cached_func, all_params) + return cached_func + + return wrapper + + def cache_and_update(self, plot: Plot, *line_names: str): + """combination of update and cache. the decorated function should return alternating + x, y variable (for example, a function to plot 2 lines would end in `return x1, y1, x2, y2`) + the decorated function should only computed data, not plot anything. + + Parameters + ---------- + plot : Plot + plot on which the lines are to be plotted + line_names : str, ..., optional + name of the lines, by default line + """ + + def wrapper(func): + new_func = self.cache(func) + arg_names = self._get_func_args(func) + + init_vals = new_func(**{a: self.params[a].value for a in arg_names}) + labels = [*line_names, *[f"line {i}" for i in range(len(init_vals) // 2)]] + + def plotter_func(*args): + vals = new_func(*args) + for label, x, y in zip(labels, vals[::2], vals[1::2]): + plot.set_line_data(label, x, y) + + self.update(*arg_names)(plotter_func) + return new_func + + return wrapper + + @property + def cache_widget(self) -> CacheWidget: + if self.__cache_widget is None: + self.__cache_widget = CacheWidget() + self.central_layout.addWidget(self.__cache_widget) + + def delete(): + del self.cache_widget + + self.__cache_widget.sig_finished.connect(delete) + + return self.__cache_widget + + @cache_widget.deleter + def cache_widget(self): + self.params_widget.setMaximumHeight(self.params_widget.geometry().height()) + self.central_layout.removeWidget(self.__cache_widget) + self.__cache_widget = None + + def _get_func_args(self, func) -> list[str]: + arg_names = inspect.getfullargspec(func).args + if isinstance(func, MethodType): + arg_names = arg_names[1:] + for arg in arg_names: + if arg not in self.params: + raise ValueError(f"{arg} not in app parameters") + return arg_names + + def __enter__(self) -> PlotApp: + if self.__ran: + raise RuntimeError("App already ran") + self.__ran = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.central_layout.addWidget(self.dock_area, stretch=1000) + if self.__cache_widget is not None: + self.__cache_widget.start() + for datafield in self.params.values(): + datafield.value_changed.emit(datafield.value) + self.window.show() + self.app.exec() + + @overload + def __getitem__(self, key: tuple[key_type, key_type]) -> PlotDataItem: + ... + + @overload + def __getitem__(self, key: key_type) -> Plot: + ... + + def __getitem__(self, key: key_type) -> Plot: + if isinstance(key, tuple): + return self[key[0]][key[1]] + key = str(key) + if key not in self.plots: + self.plots[key] = Plot(key, self.dock_area) + return self.plots[key] + +