new spectrogam method

This commit is contained in:
Benoît Sierro
2024-01-08 15:01:49 +01:00
parent 2cfbb714de
commit c3ea042d71
3 changed files with 116 additions and 3 deletions

12
examples/cli_script.py Normal file
View File

@@ -0,0 +1,12 @@
from pathlib import Path
import scgenerator as sc
class Config(sc.MainConfig):
parameters: sc.Parameters
root: Path
if __name__ == "__main__":
Config.main()

62
src/scgenerator/cli.py Normal file
View File

@@ -0,0 +1,62 @@
import os
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol, TypeVar
import click
from matplotlib.pyplot import inspect
from scgenerator.parameter import Parameters
T = TypeVar("T")
class MultiSimConfig(Protocol):
def prepare(self, id: int) -> Parameters:
...
def output(self, id: int) -> Path:
...
def load_params(file: Path) -> tuple[Parameters, dict[str, Any]]:
file_config = {}
dico = tomllib.loads(file.read_text())
file_config = dico.pop("config", {})
params = Parameters(**dico)
if params.name is None:
params.name = file.stem
return params, file_config
def main_config(cls: T) -> T:
param_key = None
for k, v in inspect.get_annotations(cls, eval_str=True).items():
if v is Parameters:
param_key = k
break
@click.group()
@click.option(
"-p",
"--parameters",
type=click.Path(
exists=True, file_okay=True, dir_okay=False, resolve_path=True, path_type=Path
),
)
@click.pass_context
def main(ctx, parameters: Path | None):
params, config = load_params(Path(f"{root.stem}.toml"))
config = cls(**{param_key: params}, **config)
datadir = config.fn(0).parent
if not datadir.exists():
try:
os.makedirs(datadir, exist_ok=True)
except OSError:
pass
config.load_noise()
ctx.obj = config
setattr(cls, "main", main)
return cls

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload
import numpy as np
import scipy.signal as ss
from scgenerator import math
from scgenerator.io import (
@@ -161,11 +162,49 @@ class Spectrum(np.ndarray):
"""
return pulse.g12(self, axis)[..., self.wl_order]
def spectrogram(self, delays: Iterable[float], gate_width: float = 2e-13) -> np.ndarray:
return np.fft.fftshift(
pulse.spectrogram(self.t, delays, self.time_amp, gate_width), axes=-1
def spectrogram(
self,
gate_width: float = 100e-15,
wavelength: bool = True,
autocrop: bool | float = 1e-5,
) -> np.ndarray:
dt = self.t[1] - self.t[0]
sigma = gate_width / (2 * np.sqrt(2 * np.log(2))) / dt
nperseg = int(sigma) * 16
f, t, s = ss.stft(
self.time_amp,
1 / dt,
window=("gaussian", sigma),
nperseg=nperseg,
noverlap=nperseg - 4,
detrend=False,
scaling="psd",
boundary=None,
return_onesided=False,
padded=False,
)
f = np.fft.fftshift(f) + self.w[0] * 0.5 / np.pi
t += self.t[0]
s = np.fft.fftshift(math.abs2(s), axes=0)
if wavelength:
f = units.m_hz(f)
s = units.to_WL(s.T, f).T
f = f[::-1]
s = s[::-1]
if autocrop:
thr = s.max()
yind, xind = np.where(s > thr * autocrop)
xmin, xmax = xind.min(), xind.max()
ymin, ymax = yind.min(), yind.max()
t = t[xmin : xmax + 1]
f = f[ymin : ymax + 1]
s = s[ymin : ymax + 1][:, xmin : xmax + 1]
return t, f, s
freq_int = afreq_int
freq_amp = afreq_amp