diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 2748532..9e8624d 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -41,6 +41,22 @@ class Spectrum(np.ndarray): ifft = params.compute("ifft") return cls(array, w, t, ifft) + @classmethod + def from_bytes(cls, buf: bytes) -> Spectrum: + if buf[:2] != b"S\xc9": + raise OSError("not a valid buffer") + buf = buf[2:] + ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft + shape_n = buf[4] + nt, *shape = ( + int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n) + ) + buf = buf[5 + 8 * shape_n :] + t = np.frombuffer(buf[: (i := nt * 8)], np.float64) + w = np.frombuffer(buf[i : (i := i + shape[-1] * 8)], np.float64) + data = np.frombuffer(buf[i:], np.complex128).reshape(shape) + return cls(data, w, t, ifft) + def __new__( cls, input_array, @@ -63,7 +79,7 @@ class Spectrum(np.ndarray): obj.l_order = np.argsort(obj.l) obj.spectrum_factor = (obj.t[1] - obj.t[0]) / np.sqrt(2 * np.pi) - if not (len(obj.w) == len(obj.t) == len(obj.l) == obj.shape[-1]): + if not (len(obj.w) == len(obj.l) == obj.shape[-1]): raise ValueError( f"shape mismatch when creating Spectrum object. input shape: {obj.shape}, " f"len(w) = {len(obj.w)}, len(t) = {len(obj.t)}, len(l) = {len(obj.l)}" @@ -87,6 +103,38 @@ class Spectrum(np.ndarray): def __getitem__(self, key) -> "Spectrum": return super().__getitem__(key) + def __bytes__(self) -> bytes: + """ + data model + 4 bytes: ascii code for ifft function + `comp` - complex (np.fft.ifft) + `real` - real (np.fft.irfft) + 1 byte: number `ns` corresponding to number of dimensions + 1 + ns*8 bytes: shape (nt, *rest, nw) as sequence of 64bit big-endian integers + nt*8 bytes: sequence of float64 representing the time axis + nw*8 bytes: sequence of float64 representing the angular frequency axis + rest: sequence of complex128 representing the data. Can be then reshaped into + and array of shape (*rest, nw) + """ + if self.ifft is np.fft.ifft: + f_name = b"comp" + elif self.ifft is np.fft.irfft: + f_name = b"real" + else: + raise ValueError(f"cannot export ifft function {self.ifft!r}") + + shape = (len(self.t), *self.shape) + bshape = len(shape).to_bytes(1) + b"".join(el.to_bytes(8, "big") for el in shape) + + return ( + b"S\xc9" + + f_name + + bshape + + self.t.astype(np.float64).tobytes() + + self.w.astype(np.float64).tobytes() + + self.astype(np.complex128, subok=False).tobytes() + ) + @property def wl_disp(self): return self.l[self.l_order] diff --git a/tests/test_spectra.py b/tests/test_spectra.py index 40aa6db..e182816 100644 --- a/tests/test_spectra.py +++ b/tests/test_spectra.py @@ -6,6 +6,33 @@ from scgenerator.physics.units import m_rads from scgenerator.spectra import Spectrum +def test_export(): + t = np.linspace(-1e-12, 1e-12, 1024) + w = wspace(t) + m_rads(800e-9) + + spec = np.fft.fft(np.exp(-((t / 1e-13) ** 2))) + spec = Spectrum(spec, w, t) + + new_spec = Spectrum.from_bytes(bytes(spec)) + assert np.all(new_spec.w == spec.w) + assert np.all(new_spec.t == spec.t) + assert np.all(new_spec.l == spec.l) + assert np.all(new_spec == spec) + assert new_spec.ifft is spec.ifft + + t = np.linspace(-1e-12, 1e-12, 512) + w = np.fft.rfftfreq(len(t), d=t[1] - t[0]) * 2 * np.pi + m_rads(800e-9) + spec = np.exp(2j * np.pi * np.random.rand(4, 2, 5, 257)) + spec = Spectrum(spec, w, t, np.fft.irfft) + + new_spec = Spectrum.from_bytes(bytes(spec)) + assert np.all(new_spec.w == spec.w) + assert np.all(new_spec.t == spec.t) + assert np.all(new_spec.l == spec.l) + assert np.all(new_spec == spec) + assert new_spec.ifft is spec.ifft + + def test_center_gravity(): t = np.linspace(-1e-12, 1e-12, 1024) w = wspace(t) + m_rads(800e-9)