From 7f07b73626b4ed2c9e786495a72a50c6d2537df3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 14 Sep 2023 16:40:55 +0200 Subject: [PATCH] new: generic open_data_file function --- src/scgenerator/io.py | 29 ++++++++++++++++++++--------- src/scgenerator/physics/pulse.py | 1 - 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index cc316c3..3b6a9d3 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -234,19 +234,30 @@ class DataFile: def load_arrays(self, *labels: str | tuple[str, ...]) -> tuple[np.ndarray, ...]: raw_data = self.load_bytes() extension = self.path.lower().split(".")[-1] - if extension == "npz": - df = np.load(BytesIO(raw_data)) - return load_npz_data(df, *labels) - elif extension == "npy": - return tuple(np.load(BytesIO(raw_data))) - else: - return load_txt_data(raw_data.decode()) + return _parse_raw_data(raw_data, extension, *labels) def similar_to(self, other: DataFile) -> bool: return Path(self.path).name == Path(other.path).name -def load_txt_data(s: str) -> tuple[np.ndarray, ...]: +def open_data_file(path: os.PathLike, *labels: str) -> tuple[np.ndarray, ...]: + path = Path(path) + raw_data = path.read_bytes() + extension = path.suffix.lower() + return _parse_raw_data(raw_data, extension, *labels) + + +def _parse_raw_data(data: bytes, extension: str, *keys) -> tuple[np.ndarray, ...]: + if extension.endswith("npz"): + df = np.load(BytesIO(data)) + return _parse_npz_data(df, *keys) + elif extension.endswith("npy"): + return tuple(np.load(BytesIO(data))) + else: + return _parse_text_data(data.decode()) + + +def _parse_text_data(s: str) -> tuple[np.ndarray, ...]: lines = s.splitlines() for delimiter in ", \t;": try: @@ -256,7 +267,7 @@ def load_txt_data(s: str) -> tuple[np.ndarray, ...]: raise ValueError("Could not load text data as numpy array") -def load_npz_data( +def _parse_npz_data( df: np.lib.npyio.NpzFile, *labels: str | tuple[str, ...] ) -> tuple[np.ndarray, ...]: if not labels: diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 64a066b..13a90d7 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -11,7 +11,6 @@ n is the number of spectra at the same z position and nt is the size of the time import os from dataclasses import astuple, dataclass -from io import BytesIO from pathlib import Path from typing import Literal, Tuple, TypeVar