added Spectrum export/import to bytes
This commit is contained in:
@@ -41,6 +41,22 @@ class Spectrum(np.ndarray):
|
|||||||
ifft = params.compute("ifft")
|
ifft = params.compute("ifft")
|
||||||
return cls(array, w, t, ifft)
|
return cls(array, w, t, ifft)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, buf: bytes) -> Spectrum:
|
||||||
|
if buf[:2] != b"S\xc9":
|
||||||
|
raise OSError("not a valid buffer")
|
||||||
|
buf = buf[2:]
|
||||||
|
ifft = np.fft.ifft if buf[:4] == b"comp" else np.fft.irfft
|
||||||
|
shape_n = buf[4]
|
||||||
|
nt, *shape = (
|
||||||
|
int.from_bytes(buf[5 + i * 8 : 5 + (i + 1) * 8], "big") for i in range(shape_n)
|
||||||
|
)
|
||||||
|
buf = buf[5 + 8 * shape_n :]
|
||||||
|
t = np.frombuffer(buf[: (i := nt * 8)], np.float64)
|
||||||
|
w = np.frombuffer(buf[i : (i := i + shape[-1] * 8)], np.float64)
|
||||||
|
data = np.frombuffer(buf[i:], np.complex128).reshape(shape)
|
||||||
|
return cls(data, w, t, ifft)
|
||||||
|
|
||||||
def __new__(
|
def __new__(
|
||||||
cls,
|
cls,
|
||||||
input_array,
|
input_array,
|
||||||
@@ -63,7 +79,7 @@ class Spectrum(np.ndarray):
|
|||||||
obj.l_order = np.argsort(obj.l)
|
obj.l_order = np.argsort(obj.l)
|
||||||
obj.spectrum_factor = (obj.t[1] - obj.t[0]) / np.sqrt(2 * np.pi)
|
obj.spectrum_factor = (obj.t[1] - obj.t[0]) / np.sqrt(2 * np.pi)
|
||||||
|
|
||||||
if not (len(obj.w) == len(obj.t) == len(obj.l) == obj.shape[-1]):
|
if not (len(obj.w) == len(obj.l) == obj.shape[-1]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"shape mismatch when creating Spectrum object. input shape: {obj.shape}, "
|
f"shape mismatch when creating Spectrum object. input shape: {obj.shape}, "
|
||||||
f"len(w) = {len(obj.w)}, len(t) = {len(obj.t)}, len(l) = {len(obj.l)}"
|
f"len(w) = {len(obj.w)}, len(t) = {len(obj.t)}, len(l) = {len(obj.l)}"
|
||||||
@@ -87,6 +103,38 @@ class Spectrum(np.ndarray):
|
|||||||
def __getitem__(self, key) -> "Spectrum":
|
def __getitem__(self, key) -> "Spectrum":
|
||||||
return super().__getitem__(key)
|
return super().__getitem__(key)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
"""
|
||||||
|
data model
|
||||||
|
4 bytes: ascii code for ifft function
|
||||||
|
`comp` - complex (np.fft.ifft)
|
||||||
|
`real` - real (np.fft.irfft)
|
||||||
|
1 byte: number `ns` corresponding to number of dimensions + 1
|
||||||
|
ns*8 bytes: shape (nt, *rest, nw) as sequence of 64bit big-endian integers
|
||||||
|
nt*8 bytes: sequence of float64 representing the time axis
|
||||||
|
nw*8 bytes: sequence of float64 representing the angular frequency axis
|
||||||
|
rest: sequence of complex128 representing the data. Can be then reshaped into
|
||||||
|
and array of shape (*rest, nw)
|
||||||
|
"""
|
||||||
|
if self.ifft is np.fft.ifft:
|
||||||
|
f_name = b"comp"
|
||||||
|
elif self.ifft is np.fft.irfft:
|
||||||
|
f_name = b"real"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"cannot export ifft function {self.ifft!r}")
|
||||||
|
|
||||||
|
shape = (len(self.t), *self.shape)
|
||||||
|
bshape = len(shape).to_bytes(1) + b"".join(el.to_bytes(8, "big") for el in shape)
|
||||||
|
|
||||||
|
return (
|
||||||
|
b"S\xc9"
|
||||||
|
+ f_name
|
||||||
|
+ bshape
|
||||||
|
+ self.t.astype(np.float64).tobytes()
|
||||||
|
+ self.w.astype(np.float64).tobytes()
|
||||||
|
+ self.astype(np.complex128, subok=False).tobytes()
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def wl_disp(self):
|
def wl_disp(self):
|
||||||
return self.l[self.l_order]
|
return self.l[self.l_order]
|
||||||
|
|||||||
@@ -6,6 +6,33 @@ from scgenerator.physics.units import m_rads
|
|||||||
from scgenerator.spectra import Spectrum
|
from scgenerator.spectra import Spectrum
|
||||||
|
|
||||||
|
|
||||||
|
def test_export():
|
||||||
|
t = np.linspace(-1e-12, 1e-12, 1024)
|
||||||
|
w = wspace(t) + m_rads(800e-9)
|
||||||
|
|
||||||
|
spec = np.fft.fft(np.exp(-((t / 1e-13) ** 2)))
|
||||||
|
spec = Spectrum(spec, w, t)
|
||||||
|
|
||||||
|
new_spec = Spectrum.from_bytes(bytes(spec))
|
||||||
|
assert np.all(new_spec.w == spec.w)
|
||||||
|
assert np.all(new_spec.t == spec.t)
|
||||||
|
assert np.all(new_spec.l == spec.l)
|
||||||
|
assert np.all(new_spec == spec)
|
||||||
|
assert new_spec.ifft is spec.ifft
|
||||||
|
|
||||||
|
t = np.linspace(-1e-12, 1e-12, 512)
|
||||||
|
w = np.fft.rfftfreq(len(t), d=t[1] - t[0]) * 2 * np.pi + m_rads(800e-9)
|
||||||
|
spec = np.exp(2j * np.pi * np.random.rand(4, 2, 5, 257))
|
||||||
|
spec = Spectrum(spec, w, t, np.fft.irfft)
|
||||||
|
|
||||||
|
new_spec = Spectrum.from_bytes(bytes(spec))
|
||||||
|
assert np.all(new_spec.w == spec.w)
|
||||||
|
assert np.all(new_spec.t == spec.t)
|
||||||
|
assert np.all(new_spec.l == spec.l)
|
||||||
|
assert np.all(new_spec == spec)
|
||||||
|
assert new_spec.ifft is spec.ifft
|
||||||
|
|
||||||
|
|
||||||
def test_center_gravity():
|
def test_center_gravity():
|
||||||
t = np.linspace(-1e-12, 1e-12, 1024)
|
t = np.linspace(-1e-12, 1e-12, 1024)
|
||||||
w = wspace(t) + m_rads(800e-9)
|
w = wspace(t) + m_rads(800e-9)
|
||||||
|
|||||||
Reference in New Issue
Block a user