added Spectrum export/import to bytes

This commit is contained in:
Benoît Sierro
2024-01-12 11:36:02 +01:00
parent db77bef2af
commit 19acb550e7
2 changed files with 76 additions and 1 deletions

View File

@@ -41,6 +41,22 @@ class Spectrum(np.ndarray):
ifft = params.compute("ifft") ifft = params.compute("ifft")
return cls(array, w, t, ifft) return cls(array, w, t, ifft)
@classmethod
def from_bytes(cls, buf: bytes) -> Spectrum:
if buf[:2] != b"S\xc9":
raise OSError("not a valid buffer")
buf = buf[2:]
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft
shape_n = buf[4]
nt, *shape = (
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
)
buf = buf[5 + 8 * shape_n :]
t = np.frombuffer(buf[: (i := nt * 8)], np.float64)
w = np.frombuffer(buf[i : (i := i + shape[-1] * 8)], np.float64)
data = np.frombuffer(buf[i:], np.complex128).reshape(shape)
return cls(data, w, t, ifft)
def __new__( def __new__(
cls, cls,
input_array, input_array,
@@ -63,7 +79,7 @@ class Spectrum(np.ndarray):
obj.l_order = np.argsort(obj.l) obj.l_order = np.argsort(obj.l)
obj.spectrum_factor = (obj.t[1] - obj.t[0]) / np.sqrt(2 * np.pi) obj.spectrum_factor = (obj.t[1] - obj.t[0]) / np.sqrt(2 * np.pi)
if not (len(obj.w) == len(obj.t) == len(obj.l) == obj.shape[-1]): if not (len(obj.w) == len(obj.l) == obj.shape[-1]):
raise ValueError( raise ValueError(
f"shape mismatch when creating Spectrum object. input shape: {obj.shape}, " f"shape mismatch when creating Spectrum object. input shape: {obj.shape}, "
f"len(w) = {len(obj.w)}, len(t) = {len(obj.t)}, len(l) = {len(obj.l)}" f"len(w) = {len(obj.w)}, len(t) = {len(obj.t)}, len(l) = {len(obj.l)}"
@@ -87,6 +103,38 @@ class Spectrum(np.ndarray):
def __getitem__(self, key) -> "Spectrum": def __getitem__(self, key) -> "Spectrum":
return super().__getitem__(key) return super().__getitem__(key)
def __bytes__(self) -> bytes:
"""
data model
4 bytes: ascii code for ifft function
`comp` - complex (np.fft.ifft)
`real` - real (np.fft.irfft)
1 byte: number `ns` corresponding to number of dimensions + 1
ns*8 bytes: shape (nt, *rest, nw) as sequence of 64bit big-endian integers
nt*8 bytes: sequence of float64 representing the time axis
nw*8 bytes: sequence of float64 representing the angular frequency axis
rest: sequence of complex128 representing the data. Can be then reshaped into
and array of shape (*rest, nw)
"""
if self.ifft is np.fft.ifft:
f_name = b"comp"
elif self.ifft is np.fft.irfft:
f_name = b"real"
else:
raise ValueError(f"cannot export ifft function {self.ifft!r}")
shape = (len(self.t), *self.shape)
bshape = len(shape).to_bytes(1) + b"".join(el.to_bytes(8, "big") for el in shape)
return (
b"S\xc9"
+ f_name
+ bshape
+ self.t.astype(np.float64).tobytes()
+ self.w.astype(np.float64).tobytes()
+ self.astype(np.complex128, subok=False).tobytes()
)
@property @property
def wl_disp(self): def wl_disp(self):
return self.l[self.l_order] return self.l[self.l_order]

View File

@@ -6,6 +6,33 @@ from scgenerator.physics.units import m_rads
from scgenerator.spectra import Spectrum from scgenerator.spectra import Spectrum
def test_export():
t = np.linspace(-1e-12, 1e-12, 1024)
w = wspace(t) + m_rads(800e-9)
spec = np.fft.fft(np.exp(-((t / 1e-13) ** 2)))
spec = Spectrum(spec, w, t)
new_spec = Spectrum.from_bytes(bytes(spec))
assert np.all(new_spec.w == spec.w)
assert np.all(new_spec.t == spec.t)
assert np.all(new_spec.l == spec.l)
assert np.all(new_spec == spec)
assert new_spec.ifft is spec.ifft
t = np.linspace(-1e-12, 1e-12, 512)
w = np.fft.rfftfreq(len(t), d=t[1] - t[0]) * 2 * np.pi + m_rads(800e-9)
spec = np.exp(2j * np.pi * np.random.rand(4, 2, 5, 257))
spec = Spectrum(spec, w, t, np.fft.irfft)
new_spec = Spectrum.from_bytes(bytes(spec))
assert np.all(new_spec.w == spec.w)
assert np.all(new_spec.t == spec.t)
assert np.all(new_spec.l == spec.l)
assert np.all(new_spec == spec)
assert new_spec.ifft is spec.ifft
def test_center_gravity(): def test_center_gravity():
t = np.linspace(-1e-12, 1e-12, 1024) t = np.linspace(-1e-12, 1e-12, 1024)
w = wspace(t) + m_rads(800e-9) w = wspace(t) + m_rads(800e-9)