rudimentary new io data structure
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
from typing import BinaryIO, Protocol, Sequence
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
|
||||
|
||||
@@ -11,27 +17,207 @@ def data_file(path: str) -> Path:
|
||||
return Path(pkg_resources.resource_filename("scgenerator", path))
|
||||
|
||||
|
||||
class DatetimeEncoder(json.JSONEncoder):
|
||||
class CustomEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return tuple(obj)
|
||||
elif isinstance(obj, DataFile):
|
||||
return obj.__json__()
|
||||
|
||||
|
||||
def decode_datetime_hook(obj):
|
||||
def _decode_datetime(s: str) -> datetime.date | datetime.datetime | str:
|
||||
try:
|
||||
return datetime.datetime.fromisoformat(s)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return datetime.date.fromisoformat(s)
|
||||
except Exception:
|
||||
pass
|
||||
return s
|
||||
|
||||
|
||||
def custom_decode_hook(obj):
|
||||
"""
|
||||
Some simple, non json-compatible objects are encoded as str, this function reconstructs the
|
||||
original object and sets it in the decoded json structure
|
||||
"""
|
||||
for k, v in obj.items():
|
||||
if not isinstance(v, str):
|
||||
continue
|
||||
try:
|
||||
dt = datetime.datetime.fromisoformat(v)
|
||||
except Exception:
|
||||
try:
|
||||
dt = datetime.date.fromisoformat(v)
|
||||
except Exception:
|
||||
continue
|
||||
obj[k] = dt
|
||||
if isinstance(v, str):
|
||||
obj[k] = _decode_datetime(v)
|
||||
elif isinstance(v, list):
|
||||
obj[k] = tuple(v)
|
||||
return obj
|
||||
|
||||
|
||||
class PropagationIOHandler(Protocol):
|
||||
def __len__(self) -> int:
|
||||
...
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
...
|
||||
|
||||
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||
...
|
||||
|
||||
def load_spectrum(self, index: int) -> np.ndarray:
|
||||
...
|
||||
|
||||
def save_data(self, name: str, data: bytes):
|
||||
...
|
||||
|
||||
def load_data(self, name: str) -> bytes:
|
||||
...
|
||||
|
||||
|
||||
class MemoryIOHandler:
|
||||
spectra: dict[int, np.ndarray]
|
||||
data: dict[str, bytes]
|
||||
|
||||
def __init__(self):
|
||||
self.spectra = {}
|
||||
self.data = {}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.spectra)
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
return list(self.data.keys())
|
||||
|
||||
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||
self.spectra[index] = spectrum
|
||||
|
||||
def load_spectrum(self, index: int) -> np.ndarray:
|
||||
return self.spectra[index]
|
||||
|
||||
def save_data(self, name: str, data: bytes):
|
||||
self.data[name] = data
|
||||
|
||||
def load_data(self, name: str) -> bytes:
|
||||
return self.data[name]
|
||||
|
||||
|
||||
class ZipFileIOHandler:
|
||||
file: BinaryIO
|
||||
SPECTRUM_FN = "spectra/spectrum_{}.npy"
|
||||
|
||||
def __init__(self, file: os.PathLike):
|
||||
"""
|
||||
Create a IO handler to be used in Propagation.
|
||||
This handler saves spectra as numpy `.npy` files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : os.PathLike
|
||||
path to the desired zip file. Will be created if it doesn't exist.
|
||||
Passing in an already opened file-like object is not supported at the moment
|
||||
"""
|
||||
self.file = Path(file)
|
||||
if not self.file.exists():
|
||||
ZipFile(self.file, "w").close()
|
||||
|
||||
def __len__(self) -> int:
|
||||
with ZipFile(self.file, "r") as zip_file:
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
zip_file.open(self.SPECTRUM_FN.format(i))
|
||||
except KeyError:
|
||||
return i
|
||||
i += 1
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
with ZipFile(self.file, "r") as zip_file:
|
||||
return zip_file.namelist()
|
||||
|
||||
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||
with ZipFile(self.file, "a") as zip_file, zip_file.open(
|
||||
self.SPECTRUM_FN.format(index), "w"
|
||||
) as file:
|
||||
np.lib.format.write_array(file, spectrum, allow_pickle=False)
|
||||
|
||||
def load_spectrum(self, index: int) -> np.ndarray:
|
||||
with ZipFile(self.file, "r") as zip_file, zip_file.open(
|
||||
self.SPECTRUM_FN.format(index), "r"
|
||||
) as file:
|
||||
return np.load(file)
|
||||
|
||||
def save_data(self, name: str, data: bytes):
|
||||
with ZipFile(self.file, "a") as zip_file, zip_file.open(name, "w") as file:
|
||||
file.write(data)
|
||||
|
||||
def load_data(self, name: str) -> bytes:
|
||||
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
|
||||
return file.read()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataFile:
|
||||
"""
|
||||
Holds information about external files necessary for a simulation.
|
||||
In the current implementation, only reading data from the file is supported
|
||||
"""
|
||||
|
||||
prefix: str | None
|
||||
path: str
|
||||
io: PropagationIOHandler
|
||||
|
||||
PREFIX_SEP = "::"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str, io: PropagationIOHandler | None = None) -> DataFile:
|
||||
if cls.PREFIX_SEP in s:
|
||||
prefix, path = s.split(cls.PREFIX_SEP, 1)
|
||||
else:
|
||||
prefix = None
|
||||
path = s
|
||||
return cls(prefix, path, io)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, name: str, data: str | DataFile) -> DataFile:
|
||||
"""To be used to automatically construct a DataFile when creating a Parameters obj"""
|
||||
if isinstance(data, cls):
|
||||
return data
|
||||
elif isinstance(data, str):
|
||||
return cls.from_str(data)
|
||||
elif isinstance(data, Path):
|
||||
return cls.from_str(os.fspath(data))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{name!r} must be a path or a bundled file specifier, not a {type(data)!r}"
|
||||
)
|
||||
|
||||
def __json__(self) -> str:
|
||||
if self.prefix is None:
|
||||
return os.fspath(Path(self.path))
|
||||
return self.prefix + self.PREFIX_SEP + self.path
|
||||
|
||||
def load_data(self) -> bytes:
|
||||
if self.prefix is not None and self.io is None:
|
||||
raise ValueError(
|
||||
f"a bundled file prefixed with {self.prefix} "
|
||||
"must have a PropagationIOHandler attached"
|
||||
)
|
||||
if self.io is not None:
|
||||
return self.io.load_data(self.path)
|
||||
else:
|
||||
return Path(self.path).read_bytes()
|
||||
|
||||
|
||||
def unique_name(base_name: str, existing: set[str]) -> str:
|
||||
name = base_name
|
||||
p = Path(base_name)
|
||||
base = p.stem
|
||||
ext = p.suffix
|
||||
i = 0
|
||||
while name in existing:
|
||||
name = f"{base}_{i}{ext}"
|
||||
i += 1
|
||||
return name
|
||||
|
||||
|
||||
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
|
||||
if len(left_elements) == 0:
|
||||
left_elements = [""]
|
||||
|
||||
@@ -4,7 +4,7 @@ import datetime as datetime_module
|
||||
import json
|
||||
import os
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, wraps
|
||||
from math import isnan
|
||||
from pathlib import Path
|
||||
@@ -12,12 +12,10 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import utils
|
||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||
from scgenerator.evaluator import Evaluator, EvaluatorError
|
||||
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
||||
from scgenerator.io import CustomEncoder, DataFile, custom_decode_hook
|
||||
from scgenerator.operators import Qualifier, SpecOperator
|
||||
from scgenerator.utils import update_path_name
|
||||
|
||||
T = TypeVar("T")
|
||||
DISPLAY_INFO = {}
|
||||
@@ -77,6 +75,11 @@ def string(name, n):
|
||||
raise ValueError(f"{name!r} must not be empty")
|
||||
|
||||
|
||||
def low_string(name, n):
|
||||
string(name, n)
|
||||
return n.lower()
|
||||
|
||||
|
||||
def in_range_excl(_min, _max):
|
||||
@type_checker(float, int)
|
||||
def _in_range(name, n):
|
||||
@@ -255,9 +258,6 @@ class Parameter:
|
||||
del instance._param_dico[self.name]
|
||||
|
||||
def __set__(self, instance: Parameters, value):
|
||||
if instance._frozen:
|
||||
raise AttributeError("Parameters instance is frozen and can no longer be modified")
|
||||
|
||||
if isinstance(value, Parameter):
|
||||
if self.default is not None:
|
||||
instance._param_dico[self.name] = copy(self.default)
|
||||
@@ -288,9 +288,9 @@ class Parameter:
|
||||
except TypeError:
|
||||
is_value = True
|
||||
if is_value:
|
||||
if self.converter is not None:
|
||||
v = self.converter(v)
|
||||
self._validator(self.name, v)
|
||||
ret_val = self._validator(self.name, v)
|
||||
if ret_val is not None:
|
||||
v = ret_val
|
||||
return is_value, v
|
||||
|
||||
|
||||
@@ -307,7 +307,6 @@ class Parameters:
|
||||
|
||||
# root
|
||||
name: str = Parameter(string, default="no name")
|
||||
output_path: Path = Parameter(type_checker(Path), converter=Path)
|
||||
|
||||
# fiber
|
||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||
@@ -315,10 +314,10 @@ class Parameters:
|
||||
n2: float = Parameter(non_negative(float, int))
|
||||
chi3: float = Parameter(non_negative(float, int))
|
||||
loss: str = Parameter(literal("capillary"))
|
||||
loss_file: str = Parameter(string)
|
||||
loss_file: DataFile = Parameter(DataFile.validate)
|
||||
effective_mode_diameter: float = Parameter(positive(float, int))
|
||||
effective_area: float = Parameter(non_negative(float, int))
|
||||
effective_area_file: str = Parameter(string)
|
||||
effective_area_file: DataFile = Parameter(DataFile.validate)
|
||||
numerical_aperture: float = Parameter(in_range_excl(0, 1))
|
||||
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm"))
|
||||
pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1))
|
||||
@@ -326,7 +325,7 @@ class Parameters:
|
||||
he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1))
|
||||
fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9))
|
||||
beta2_coefficients: Iterable[float] = Parameter(num_list)
|
||||
dispersion_file: str = Parameter(string)
|
||||
dispersion_file: DataFile = Parameter(DataFile.validate)
|
||||
model: str = Parameter(
|
||||
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
|
||||
)
|
||||
@@ -345,7 +344,7 @@ class Parameters:
|
||||
capillary_nested: int = Parameter(non_negative(int), default=0)
|
||||
|
||||
# gas
|
||||
gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
|
||||
gas_name: str = Parameter(low_string, default="vacuum")
|
||||
pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||
pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||
pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||
@@ -353,7 +352,7 @@ class Parameters:
|
||||
plasma_density: float = Parameter(non_negative(float, int), default=0)
|
||||
|
||||
# pulse
|
||||
field_file: str = Parameter(string)
|
||||
field_file: DataFile = Parameter(DataFile.validate)
|
||||
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||
repetition_rate: float = Parameter(
|
||||
@@ -380,11 +379,9 @@ class Parameters:
|
||||
# simulation
|
||||
full_field: bool = Parameter(boolean, default=False)
|
||||
integration_scheme: str = Parameter(
|
||||
literal("erk43", "erk54", "cqe", "sd", "constant"),
|
||||
converter=str.lower,
|
||||
default="erk43",
|
||||
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
|
||||
)
|
||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
|
||||
raman_fraction: float = Parameter(non_negative(float, int))
|
||||
spm: bool = Parameter(boolean, default=True)
|
||||
repeat: int = Parameter(positive(int), default=1)
|
||||
@@ -437,7 +434,8 @@ class Parameters:
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, s: str) -> Parameters:
|
||||
return cls(**json.loads(s, object_hook=decode_datetime_hook))
|
||||
decoded = json.loads(s, object_hook=custom_decode_hook)
|
||||
return cls(**decoded)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: os.PathLike) -> Parameters:
|
||||
@@ -462,18 +460,32 @@ class Parameters:
|
||||
self.__post_init__()
|
||||
|
||||
def __setattr__(self, k, v):
|
||||
if self._frozen:
|
||||
if self._frozen and not k.endswith("_file"):
|
||||
raise AttributeError(
|
||||
f"cannot set attribute to frozen {self.__class__.__name__} instance"
|
||||
)
|
||||
object.__setattr__(self, k, v)
|
||||
|
||||
def copy(self) -> Parameters:
|
||||
return Parameters(**deepcopy(self.strip_params_dict()))
|
||||
def items(self) -> Iterator[tuple[str, Any]]:
|
||||
for k, v in self._param_dico.items():
|
||||
if v is None:
|
||||
continue
|
||||
yield k, v
|
||||
|
||||
def copy(self, freeze: bool = False) -> Parameters:
|
||||
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
|
||||
params = Parameters(**deepcopy(self.strip_params_dict()))
|
||||
if freeze:
|
||||
params.freeze()
|
||||
return params
|
||||
|
||||
def freeze(self):
|
||||
"""render the current instance read-only. This is not reversible"""
|
||||
self._frozen = True
|
||||
|
||||
def to_json(self) -> str:
|
||||
d = self.dump_dict()
|
||||
return json.dumps(d, cls=DatetimeEncoder, default=list)
|
||||
return json.dumps(d, cls=CustomEncoder, indent=4)
|
||||
|
||||
def get_evaluator(self):
|
||||
evaluator = Evaluator.default(self.full_field)
|
||||
@@ -611,7 +623,7 @@ class Parameters:
|
||||
"linear_op",
|
||||
"c_to_a_factor",
|
||||
}
|
||||
types = (np.ndarray, float, int, str, list, tuple, Path)
|
||||
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
|
||||
c = deepcopy if copy else lambda x: x
|
||||
out = {}
|
||||
for key, value in self._param_dico.items():
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from io import BytesIO
|
||||
from typing import Iterable, TypeVar
|
||||
|
||||
import numpy as np
|
||||
@@ -7,6 +8,7 @@ from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from scgenerator import io
|
||||
from scgenerator.io import DataFile
|
||||
from scgenerator.math import argclosest, u_nm
|
||||
from scgenerator.physics import materials as mat
|
||||
from scgenerator.physics import units
|
||||
@@ -653,7 +655,7 @@ def saitoh_paramters(pcf_pitch_ratio: float) -> tuple[float, float]:
|
||||
return A, B
|
||||
|
||||
|
||||
def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.ndarray:
|
||||
def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
loads custom effective area file
|
||||
|
||||
@@ -669,14 +671,14 @@ def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.nd
|
||||
np.ndarray, shape (n,)
|
||||
wl-dependent effective mode field area
|
||||
"""
|
||||
data = np.load(effective_area_file)
|
||||
data = np.load(BytesIO(effective_area_file.load_data()))
|
||||
effective_area = data.get("A_eff", data.get("effective_area"))
|
||||
wl = data["wavelength"]
|
||||
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
|
||||
|
||||
|
||||
def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray]:
|
||||
disp_file = np.load(dispersion_file)
|
||||
def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||
disp_file = np.load(BytesIO(dispersion_file.load_data()))
|
||||
wl_for_disp = disp_file["wavelength"]
|
||||
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
|
||||
D = disp_file["dispersion"]
|
||||
@@ -684,7 +686,7 @@ def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray
|
||||
return wl_for_disp, beta2, interp_range
|
||||
|
||||
|
||||
def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
|
||||
def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
|
||||
"""
|
||||
loads a npz loss file that contains a wavelength and a loss entry
|
||||
|
||||
@@ -700,7 +702,7 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
|
||||
np.ndarray, shape (n,)
|
||||
loss in 1/m units
|
||||
"""
|
||||
loss_data = np.load(loss_file)
|
||||
loss_data = np.load(BytesIO(loss_file.load_data()))
|
||||
wl = loss_data["wavelength"]
|
||||
loss = loss_data["loss"]
|
||||
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||
|
||||
@@ -12,6 +12,7 @@ n is the number of spectra at the same z position and nt is the size of the time
|
||||
import itertools
|
||||
import os
|
||||
from dataclasses import astuple, dataclass
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal, Tuple, TypeVar
|
||||
|
||||
@@ -25,6 +26,7 @@ from scipy.optimize._optimize import OptimizeResult
|
||||
|
||||
from scgenerator import math
|
||||
from scgenerator.defaults import default_plotting
|
||||
from scgenerator.io import DataFile
|
||||
from scgenerator.physics import units
|
||||
|
||||
c = 299792458.0
|
||||
@@ -410,8 +412,9 @@ def interp_custom_field(
|
||||
return field_0
|
||||
|
||||
|
||||
def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
|
||||
field_data = np.load(field_file)
|
||||
def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||
data = field_file.load_data()
|
||||
field_data = np.load(BytesIO(data))
|
||||
return field_data["time"], field_data["field"]
|
||||
|
||||
|
||||
|
||||
@@ -1,37 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, Optional, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from scgenerator import math
|
||||
from scgenerator.const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
|
||||
from scgenerator.logger import get_logger
|
||||
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.physics import pulse, units
|
||||
from scgenerator.physics.units import PlotRange
|
||||
from scgenerator.plotting import (
|
||||
mean_values_plot,
|
||||
propagation_plot,
|
||||
single_position_plot,
|
||||
transform_1D_values,
|
||||
transform_2D_propagation,
|
||||
)
|
||||
from scgenerator.utils import load_spectrum, load_toml, simulations_list
|
||||
|
||||
|
||||
class Spectrum(np.ndarray):
|
||||
params: Parameters
|
||||
w: np.ndarray
|
||||
t: np.ndarray
|
||||
l: np.ndarray
|
||||
ifft: Callable[[np.ndarray], np.ndarray]
|
||||
|
||||
def __new__(cls, input_array, params: Parameters):
|
||||
# Input array is an already formed ndarray instance
|
||||
# We first cast to be our class type
|
||||
obj = np.asarray(input_array).view(cls)
|
||||
# add the new attribute to the created instance
|
||||
obj.params = params
|
||||
obj.w = params.compute("w")
|
||||
obj.t = params.compute("t")
|
||||
obj.t = params.compute("t")
|
||||
obj.ifft = params.compute("ifft")
|
||||
|
||||
# Finally, we must return the newly created object:
|
||||
return obj
|
||||
@@ -40,14 +35,17 @@ class Spectrum(np.ndarray):
|
||||
# see InfoArray.__array_finalize__ for comments
|
||||
if obj is None:
|
||||
return
|
||||
self.params = getattr(obj, "params", None)
|
||||
self.w = getattr(obj, "w", None)
|
||||
self.t = getattr(obj, "t", None)
|
||||
self.l = getattr(obj, "l", None)
|
||||
self.ifft = getattr(obj, "ifft", None)
|
||||
|
||||
def __getitem__(self, key) -> "Spectrum":
|
||||
return super().__getitem__(key)
|
||||
|
||||
@property
|
||||
def wl_int(self):
|
||||
return units.to_WL(math.abs2(self), self.params.l)
|
||||
return units.to_WL(math.abs2(self), self.l)
|
||||
|
||||
@property
|
||||
def freq_int(self):
|
||||
@@ -59,13 +57,13 @@ class Spectrum(np.ndarray):
|
||||
|
||||
@property
|
||||
def time_int(self):
|
||||
return math.abs2(self.params.ifft(self))
|
||||
return math.abs2(self.ifft(self))
|
||||
|
||||
def amplitude(self, unit):
|
||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||
x_axis = unit.inv(self.params.w)
|
||||
x_axis = unit.inv(self.w)
|
||||
else:
|
||||
x_axis = unit.inv(self.params.t)
|
||||
x_axis = unit.inv(self.t)
|
||||
|
||||
order = np.argsort(x_axis)
|
||||
func = dict(
|
||||
@@ -84,7 +82,7 @@ class Spectrum(np.ndarray):
|
||||
np.sqrt(
|
||||
units.to_WL(
|
||||
math.abs2(self),
|
||||
self.params.l,
|
||||
self.l,
|
||||
)
|
||||
)
|
||||
* self
|
||||
@@ -106,311 +104,115 @@ class Spectrum(np.ndarray):
|
||||
@property
|
||||
def wl_max(self):
|
||||
if self.ndim == 1:
|
||||
return self.params.l[np.argmax(self.wl_int, axis=-1)]
|
||||
return self.l[np.argmax(self.wl_int, axis=-1)]
|
||||
return np.array([s.wl_max for s in self])
|
||||
|
||||
def mask_wl(self, pos: float, width: float) -> Spectrum:
|
||||
return self * np.exp(
|
||||
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
||||
)
|
||||
return self * np.exp(-(((self.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2))
|
||||
|
||||
def measure(self) -> tuple[float, float, float]:
|
||||
return pulse.measure_field(self.params.t, self.time_amp)
|
||||
return pulse.measure_field(self.t, self.time_amp)
|
||||
|
||||
|
||||
class SimulationSeries:
|
||||
"""
|
||||
SimulationsSeries are the interface the user should use to load and
|
||||
interact with simulation data. The object loads each fiber of the simulation
|
||||
into a separate object and exposes convenience methods to make the series behave
|
||||
as a single fiber.
|
||||
|
||||
It should be noted that the last spectrum of a fiber and the first one of the next
|
||||
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
|
||||
than when manually mergin the corresponding data.
|
||||
|
||||
"""
|
||||
|
||||
path: Path
|
||||
fibers: list[SimulatedFiber]
|
||||
class Propagation:
|
||||
io: PropagationIOHandler
|
||||
params: Parameters
|
||||
z_indices: list[tuple[int, int]]
|
||||
fiber_positions: list[tuple[str, float]]
|
||||
_current_index: int
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
PARAMS_FN = "params.json"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_handler: PropagationIOHandler,
|
||||
params: Parameters | None = None,
|
||||
bundle_data: bool = False,
|
||||
):
|
||||
"""
|
||||
Create a SimulationSeries
|
||||
A propagation is the object that manages IO for one single propagation.
|
||||
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
|
||||
factories instead of creating this object directly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : os.PathLike
|
||||
path to the last fiber of the series
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
No simulation found in specified directory
|
||||
io : PropagationIOHandler
|
||||
object that implements the PropagationIOHandler Protocol.
|
||||
params : Parameters
|
||||
simulations parameters. Those will be saved via the
|
||||
"""
|
||||
self.logger = get_logger()
|
||||
for self.path in simulations_list(path):
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(f"No simulation in {path}")
|
||||
self.fibers = [SimulatedFiber(self.path)]
|
||||
while (p := self.fibers[-1].params.prev_data_dir) is not None:
|
||||
p = Path(p)
|
||||
if not p.is_absolute():
|
||||
p = Path(self.fibers[-1].params.output_path) / p
|
||||
self.fibers.append(SimulatedFiber(p))
|
||||
self.fibers = self.fibers[::-1]
|
||||
self.io = io_handler
|
||||
self._current_index = len(self.io)
|
||||
|
||||
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
|
||||
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
|
||||
z_targets = list(self.params.z_targets)
|
||||
self.z_indices = [(0, j) for j in range(self.params.z_num)]
|
||||
for i, fiber in enumerate(self.fibers[1:]):
|
||||
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
|
||||
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
|
||||
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
|
||||
self.params.z_targets = np.array(z_targets)
|
||||
self.params.length = self.params.z_targets[-1]
|
||||
self.params.z_num = len(self.params.z_targets)
|
||||
new_params = params is not None
|
||||
if not new_params:
|
||||
if bundle_data:
|
||||
raise ValueError(
|
||||
"cannot bundle data to existing Propagation. Create a new one instead"
|
||||
)
|
||||
try:
|
||||
params_data = self.io.load_data(self.PARAMS_FN)
|
||||
params = Parameters.from_json(params_data.decode())
|
||||
except KeyError:
|
||||
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
|
||||
|
||||
def spectra(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> Spectrum:
|
||||
...
|
||||
if z_descr is None:
|
||||
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
|
||||
else:
|
||||
if isinstance(z_descr, (float, np.floating)):
|
||||
fib_ind, z_ind = self.z_ind(z_descr)
|
||||
else:
|
||||
fib_ind, z_ind = self.z_indices[z_descr]
|
||||
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
|
||||
return Spectrum(out, self.params)
|
||||
self.params = params
|
||||
|
||||
def fields(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> Spectrum:
|
||||
return self.params.ifft(self.spectra(z_descr, sim_ind))
|
||||
if bundle_data:
|
||||
if not params._frozen:
|
||||
warnings.warn(
|
||||
f"Parameters instance {params.name!r} is not frozen but will be saved "
|
||||
"in an unmodifiable state anyway."
|
||||
)
|
||||
_bundle_external_files(self.params, self.io)
|
||||
|
||||
def z_ind(self, pos: float) -> tuple[int, int]:
|
||||
if 0 <= pos <= self.params.length:
|
||||
ind = np.argmin(np.abs(self.params.z_targets - pos))
|
||||
return self.z_indices[ind]
|
||||
else:
|
||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||
if new_params:
|
||||
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
|
||||
|
||||
def plot_2D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
|
||||
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
def __len__(self) -> int:
|
||||
return self._current_index
|
||||
|
||||
def plot_values_2D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
|
||||
return transform_2D_propagation(vals, plot_range, self.params, **kwargs)
|
||||
def __getitem__(self, key: int | slice) -> Spectrum:
|
||||
if isinstance(key, slice):
|
||||
return self._load_slice(key)
|
||||
if isinstance(key, (float, np.floating)):
|
||||
key = math.argclosest(self.params.compute("z_targets"), key)
|
||||
array = self.io.load_spectrum(key)
|
||||
return Spectrum(array, self.params)
|
||||
|
||||
def plot_1D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
z_pos: int,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
||||
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
def __setitem__(self, key: int, value: np.ndarray):
|
||||
if not isinstance(key, int):
|
||||
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||
self.io.save_spectrum(key, np.asarray(value))
|
||||
|
||||
def plot_values_1D(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
z_pos: int,
|
||||
sim_ind: int = 0,
|
||||
**kwargs,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
gives the desired values already tranformes according to the give range
|
||||
def _load_slice(self, key: slice) -> Spectrum:
|
||||
_iter = range(len(self))[key]
|
||||
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
|
||||
for i in _iter:
|
||||
out[i] = self.io.load_spectrum(i)
|
||||
return out
|
||||
|
||||
Parameters
|
||||
----------
|
||||
left : float
|
||||
leftmost limit in unit
|
||||
right : float
|
||||
rightmost limit in unit
|
||||
unit : Union[Callable[[float], float], str]
|
||||
unit
|
||||
z_pos : Union[int, float]
|
||||
position either as an index (int) or a real position (float)
|
||||
sim_ind : Optional[int]
|
||||
which simulation to take when more than one are present
|
||||
def append(self, spectrum: np.ndarray):
|
||||
self.io.save_spectrum(self._current_index, spectrum.asarray())
|
||||
self._current_index += 1
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
x axis
|
||||
np.ndarray
|
||||
y values
|
||||
"""
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
||||
return transform_1D_values(vals, plot_range, self.params, **kwargs)
|
||||
|
||||
def plot_mean(
|
||||
self,
|
||||
left: float,
|
||||
right: float,
|
||||
unit: Union[Callable[[float], float], str],
|
||||
ax: plt.Axes,
|
||||
z_pos: int,
|
||||
**kwargs,
|
||||
):
|
||||
plot_range = PlotRange(left, right, unit)
|
||||
vals = self.retrieve_plot_values(plot_range, z_pos, None)
|
||||
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
|
||||
|
||||
def retrieve_plot_values(
|
||||
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
|
||||
):
|
||||
if plot_range.unit.type == "TIME":
|
||||
return self.fields(z_pos, sim_ind)
|
||||
else:
|
||||
return self.spectra(z_pos, sim_ind)
|
||||
|
||||
def rin_propagation(
|
||||
self, left: float, right: float, unit: str
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
returns the RIN as function of unit and z
|
||||
|
||||
Parameters
|
||||
----------
|
||||
left : float
|
||||
left limit in unit
|
||||
right : float
|
||||
right limit in unit
|
||||
unit : str
|
||||
unit descriptor
|
||||
|
||||
Returns
|
||||
-------
|
||||
x : np.ndarray, shape (nt,)
|
||||
x axis
|
||||
y : np.ndarray, shape (z_num, )
|
||||
y axis
|
||||
rin_prop : np.ndarray, shape (z_num, nt)
|
||||
RIN
|
||||
"""
|
||||
spectra = []
|
||||
for spec in np.moveaxis(self.spectra(None, None), 1, 0):
|
||||
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
|
||||
spectra.append(tmp)
|
||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
||||
|
||||
# Magic methods
|
||||
|
||||
def __iter__(self) -> Iterator[Spectrum]:
|
||||
for i, j in self.z_indices:
|
||||
yield self.fibers[i].spectra(j, None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path})"
|
||||
|
||||
def __eq__(self, other: SimulationSeries) -> bool:
|
||||
return self.path == other.path and self.params == other.params
|
||||
|
||||
def __contains__(self, fiber: SimulatedFiber) -> bool:
|
||||
return fiber in self.fibers
|
||||
|
||||
def __getitem__(self, key) -> Spectrum:
|
||||
if isinstance(key, tuple):
|
||||
return self.spectra(*key)
|
||||
else:
|
||||
return self.spectra(key, None)
|
||||
def load_all(self) -> Spectrum:
|
||||
return self._load_slice(slice(None))
|
||||
|
||||
|
||||
class SimulatedFiber:
|
||||
params: Parameters
|
||||
t: np.ndarray
|
||||
w: np.ndarray
|
||||
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
|
||||
"""copies every external file specified in the parameters and saves it"""
|
||||
existing_files = set(io.keys())
|
||||
for _, value in params.items():
|
||||
if isinstance(value, DataFile):
|
||||
data = value.load_data()
|
||||
|
||||
def __init__(self, path: os.PathLike):
|
||||
self.path = Path(path)
|
||||
self.params = Parameters(**load_toml(self.path / PARAM_FN))
|
||||
self.params.output_path = str(self.path.resolve())
|
||||
self.t = self.params.t
|
||||
self.w = self.params.w
|
||||
self.z = self.params.z_targets
|
||||
value.io = io
|
||||
value.prefix = "zip"
|
||||
value.path = unique_name(Path(value.path).name, existing_files)
|
||||
existing_files.add(value.path)
|
||||
|
||||
def spectra(
|
||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
||||
) -> np.ndarray:
|
||||
if z_descr is None:
|
||||
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
|
||||
else:
|
||||
if isinstance(z_descr, (float, np.floating)):
|
||||
return self.spectra(self.z_ind(z_descr), sim_ind)
|
||||
else:
|
||||
z_ind = z_descr
|
||||
io.save_data(value.path, data)
|
||||
|
||||
if z_ind < 0:
|
||||
z_ind = self.params.z_num + z_ind
|
||||
|
||||
if sim_ind is None:
|
||||
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
||||
else:
|
||||
out = self._load_1(z_ind)
|
||||
return Spectrum(out, self.params)
|
||||
|
||||
def z_ind(self, pos: float) -> int:
|
||||
if 0 <= pos <= self.z[-1]:
|
||||
return np.argmin(np.abs(self.z - pos))
|
||||
else:
|
||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
||||
|
||||
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
||||
"""
|
||||
loads a spectrum file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z_ind : int
|
||||
z_index relative to the entire simulation
|
||||
sim_ind : int, optional
|
||||
simulation index, used when repeated simulations with same parameters are ran,
|
||||
by default 0
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
loaded spectrum file
|
||||
"""
|
||||
if sim_ind > 0:
|
||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
||||
else:
|
||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
|
||||
psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(path={self.path})"
|
||||
def load_all(path: os.PathLike) -> Spectrum:
|
||||
io = ZipFileIOHandler(path)
|
||||
return Propagation(io).load_all()
|
||||
|
||||
BIN
tests/data/PM2000D_2 extrapolated 4 0.npz
Normal file
BIN
tests/data/PM2000D_2 extrapolated 4 0.npz
Normal file
Binary file not shown.
BIN
tests/data/PM2000D_A_eff_marcuse.npz
Normal file
BIN
tests/data/PM2000D_A_eff_marcuse.npz
Normal file
Binary file not shown.
124
tests/test_io_handlers.py
Normal file
124
tests/test_io_handlers.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||
from scgenerator.parameter import Parameters
|
||||
from scgenerator.spectra import Propagation
|
||||
|
||||
PARAMS = dict(
|
||||
name="PM2000D sech pulse",
|
||||
# pulse
|
||||
wavelength=1546e-9,
|
||||
# field_file="./Pos30000New.npz",
|
||||
width=1.5e-12,
|
||||
shape="sech",
|
||||
repetition_rate=40e6,
|
||||
# fiber
|
||||
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
|
||||
effective_area_file="./PM2000D_A_eff_marcuse.npz",
|
||||
wavelength_window=(400e-9, 4000e-9),
|
||||
n2=4.5e-20,
|
||||
# simulation
|
||||
raman_type="measured",
|
||||
quantum_noise=True,
|
||||
interpolation_degree=11,
|
||||
z_num=128,
|
||||
length=1.5,
|
||||
t_num=512,
|
||||
dt=5e-15,
|
||||
)
|
||||
|
||||
|
||||
def test_file(tmp_path: Path):
|
||||
params = Parameters(**PARAMS)
|
||||
stuff = np.random.rand(8, 512)
|
||||
io = ZipFileIOHandler(tmp_path / "test.zip")
|
||||
io.save_data("params.json", params.to_json().encode())
|
||||
for i, spec in enumerate(stuff):
|
||||
io.save_spectrum(i, spec)
|
||||
|
||||
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||
assert new_params is not params
|
||||
for k, v in params.items():
|
||||
v_new = getattr(new_params, k)
|
||||
if isinstance(v, DataFile):
|
||||
assert Path(v.path).name == Path(v_new.path).name
|
||||
else:
|
||||
assert v == getattr(new_params, k)
|
||||
|
||||
for i in range(8):
|
||||
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||
|
||||
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
|
||||
|
||||
|
||||
def test_memory():
|
||||
params = Parameters(**PARAMS)
|
||||
stuff = np.random.rand(8, 512)
|
||||
io = MemoryIOHandler()
|
||||
assert len(io) == 0
|
||||
io.save_data("params.json", params.to_json().encode())
|
||||
for i, spec in enumerate(stuff):
|
||||
io.save_spectrum(i, spec)
|
||||
|
||||
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||
for k, v in params.items():
|
||||
v_new = getattr(new_params, k)
|
||||
if isinstance(v, DataFile):
|
||||
assert Path(v.path).name == Path(v_new.path).name
|
||||
else:
|
||||
assert v == getattr(new_params, k)
|
||||
for i in range(8):
|
||||
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||
|
||||
assert len(io) == 8
|
||||
|
||||
|
||||
def test_zip_bundle(tmp_path: Path):
|
||||
params = Parameters(**PARAMS)
|
||||
io = ZipFileIOHandler(tmp_path / "file.zip")
|
||||
prop = Propagation(io, params.copy(True))
|
||||
|
||||
assert (tmp_path / "file.zip").exists()
|
||||
assert (tmp_path / "file.zip").read_bytes() != b""
|
||||
|
||||
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
|
||||
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
|
||||
params.dispersion_file.path = str(new_disp_path)
|
||||
params.effective_area_file.path = str(new_aeff_path)
|
||||
params.freeze()
|
||||
|
||||
io = ZipFileIOHandler(tmp_path / "file2.zip")
|
||||
prop2 = Propagation(io, params, True)
|
||||
|
||||
assert params.dispersion_file.path == new_disp_path.name
|
||||
assert params.dispersion_file.prefix == "zip"
|
||||
assert params.effective_area_file.path == new_aeff_path.name
|
||||
assert params.effective_area_file.prefix == "zip"
|
||||
|
||||
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
|
||||
with zfile.open(new_aeff_path.name) as file:
|
||||
assert file.read() == new_aeff_path.read_bytes()
|
||||
with zfile.open(new_disp_path.name) as file:
|
||||
assert file.read() == new_disp_path.read_bytes()
|
||||
|
||||
with zfile.open(Propagation.PARAMS_FN) as file:
|
||||
df = json.loads(file.read().decode())
|
||||
|
||||
assert (
|
||||
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
|
||||
)
|
||||
|
||||
assert (
|
||||
df["effective_area_file"]
|
||||
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
|
||||
)
|
||||
|
||||
|
||||
def test_unique_name():
|
||||
existing = {"spec.npy", "spec_0.npy"}
|
||||
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
||||
Reference in New Issue
Block a user