new spectrogam method

This commit is contained in:
Benoît Sierro
2024-01-08 15:01:49 +01:00
parent 2cfbb714de
commit c3ea042d71
3 changed files with 116 additions and 3 deletions

12
examples/cli_script.py Normal file
View File

@@ -0,0 +1,12 @@
from pathlib import Path
import scgenerator as sc
class Config(sc.MainConfig):
parameters: sc.Parameters
root: Path
if __name__ == "__main__":
Config.main()

62
src/scgenerator/cli.py Normal file
View File

@@ -0,0 +1,62 @@
import os
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol, TypeVar
import click
from matplotlib.pyplot import inspect
from scgenerator.parameter import Parameters
T = TypeVar("T")
class MultiSimConfig(Protocol):
def prepare(self, id: int) -> Parameters:
...
def output(self, id: int) -> Path:
...
def load_params(file: Path) -> tuple[Parameters, dict[str, Any]]:
file_config = {}
dico = tomllib.loads(file.read_text())
file_config = dico.pop("config", {})
params = Parameters(**dico)
if params.name is None:
params.name = file.stem
return params, file_config
def main_config(cls: T) -> T:
param_key = None
for k, v in inspect.get_annotations(cls, eval_str=True).items():
if v is Parameters:
param_key = k
break
@click.group()
@click.option(
"-p",
"--parameters",
type=click.Path(
exists=True, file_okay=True, dir_okay=False, resolve_path=True, path_type=Path
),
)
@click.pass_context
def main(ctx, parameters: Path | None):
params, config = load_params(Path(f"{root.stem}.toml"))
config = cls(**{param_key: params}, **config)
datadir = config.fn(0).parent
if not datadir.exists():
try:
os.makedirs(datadir, exist_ok=True)
except OSError:
pass
config.load_noise()
ctx.obj = config
setattr(cls, "main", main)
return cls

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload from typing import Callable, Generic, Iterable, Iterator, Sequence, TypeVar, overload
import numpy as np import numpy as np
import scipy.signal as ss
from scgenerator import math from scgenerator import math
from scgenerator.io import ( from scgenerator.io import (
@@ -161,11 +162,49 @@ class Spectrum(np.ndarray):
""" """
return pulse.g12(self, axis)[..., self.wl_order] return pulse.g12(self, axis)[..., self.wl_order]
def spectrogram(self, delays: Iterable[float], gate_width: float = 2e-13) -> np.ndarray: def spectrogram(
return np.fft.fftshift( self,
pulse.spectrogram(self.t, delays, self.time_amp, gate_width), axes=-1 gate_width: float = 100e-15,
wavelength: bool = True,
autocrop: bool | float = 1e-5,
) -> np.ndarray:
dt = self.t[1] - self.t[0]
sigma = gate_width / (2 * np.sqrt(2 * np.log(2))) / dt
nperseg = int(sigma) * 16
f, t, s = ss.stft(
self.time_amp,
1 / dt,
window=("gaussian", sigma),
nperseg=nperseg,
noverlap=nperseg - 4,
detrend=False,
scaling="psd",
boundary=None,
return_onesided=False,
padded=False,
) )
f = np.fft.fftshift(f) + self.w[0] * 0.5 / np.pi
t += self.t[0]
s = np.fft.fftshift(math.abs2(s), axes=0)
if wavelength:
f = units.m_hz(f)
s = units.to_WL(s.T, f).T
f = f[::-1]
s = s[::-1]
if autocrop:
thr = s.max()
yind, xind = np.where(s > thr * autocrop)
xmin, xmax = xind.min(), xind.max()
ymin, ymax = yind.min(), yind.max()
t = t[xmin : xmax + 1]
f = f[ymin : ymax + 1]
s = s[ymin : ymax + 1][:, xmin : xmax + 1]
return t, f, s
freq_int = afreq_int freq_int = afreq_int
freq_amp = afreq_amp freq_amp = afreq_amp