new: generic open_data_file function

This commit is contained in:
Benoît Sierro
2023-09-14 16:40:55 +02:00
parent 6f8f7cc999
commit 7f07b73626
2 changed files with 20 additions and 10 deletions

View File

@@ -234,19 +234,30 @@ class DataFile:
def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]: def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]:
raw_data = self.load_bytes() raw_data = self.load_bytes()
extension = self.path.lower().split(".")[-1] extension = self.path.lower().split(".")[-1]
if extension == "npz": return _parse_raw_data(raw_data, extension, *labels)
df = np.load(BytesIO(raw_data))
return load_npz_data(df, *labels)
elif extension == "npy":
return tuple(np.load(BytesIO(raw_data)))
else:
return load_txt_data(raw_data.decode())
def similar_to(self, other: DataFile) -> bool: def similar_to(self, other: DataFile) -> bool:
return Path(self.path).name == Path(other.path).name return Path(self.path).name == Path(other.path).name
def load_txt_data(s: str) -> tuple[np.ndarray, ...]: def open_data_file(path: os.PathLike, *labels: str) -> tuple[np.ndarray, ...]:
path = Path(path)
raw_data = path.read_bytes()
extension = path.suffix.lower()
return _parse_raw_data(raw_data, extension, *labels)
def _parse_raw_data(data: bytes, extension: str, *keys) -> tuple[np.ndarray, ...]:
if extension.endswith("npz"):
df = np.load(BytesIO(data))
return _parse_npz_data(df, *keys)
elif extension.endswith("npy"):
return tuple(np.load(BytesIO(data)))
else:
return _parse_text_data(data.decode())
def _parse_text_data(s: str) -> tuple[np.ndarray, ...]:
lines = s.splitlines() lines = s.splitlines()
for delimiter in ", \t;": for delimiter in ", \t;":
try: try:
@@ -256,7 +267,7 @@ def load_txt_data(s: str) -> tuple[np.ndarray, ...]:
raise ValueError("Could not load text data as numpy array") raise ValueError("Could not load text data as numpy array")
def load_npz_data( def _parse_npz_data(
df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...] df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...]
) -> tuple[np.ndarray, ...]: ) -> tuple[np.ndarray, ...]:
if not labels: if not labels:

View File

@@ -11,7 +11,6 @@ n is the number of spectra at the same z position and nt is the size of the time
import os import os
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Literal, Tuple, TypeVar from typing import Literal, Tuple, TypeVar