rudimentary new io data structure

This commit is contained in:
Benoît Sierro
2023-08-08 10:59:08 +02:00
parent 7b6e33ca0f
commit 98fa32c24b
8 changed files with 475 additions and 346 deletions

View File

@@ -1,8 +1,14 @@
from __future__ import annotations
import datetime
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
from typing import BinaryIO, Protocol, Sequence
from zipfile import ZipFile
import numpy as np
import pkg_resources
@@ -11,27 +17,207 @@ def data_file(path: str) -> Path:
return Path(pkg_resources.resource_filename("scgenerator", path))
class DatetimeEncoder(json.JSONEncoder):
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()
elif isinstance(obj, np.ndarray):
return tuple(obj)
elif isinstance(obj, DataFile):
return obj.__json__()
def decode_datetime_hook(obj):
def _decode_datetime(s: str) -> datetime.date | datetime.datetime | str:
try:
return datetime.datetime.fromisoformat(s)
except Exception:
pass
try:
return datetime.date.fromisoformat(s)
except Exception:
pass
return s
def custom_decode_hook(obj):
"""
Some simple, non json-compatible objects are encoded as str, this function reconstructs the
original object and sets it in the decoded json structure
"""
for k, v in obj.items():
if not isinstance(v, str):
continue
try:
dt = datetime.datetime.fromisoformat(v)
except Exception:
try:
dt = datetime.date.fromisoformat(v)
except Exception:
continue
obj[k] = dt
if isinstance(v, str):
obj[k] = _decode_datetime(v)
elif isinstance(v, list):
obj[k] = tuple(v)
return obj
class PropagationIOHandler(Protocol):
def __len__(self) -> int:
...
def keys(self) -> list[str]:
...
def save_spectrum(self, index: int, spectrum: np.ndarray):
...
def load_spectrum(self, index: int) -> np.ndarray:
...
def save_data(self, name: str, data: bytes):
...
def load_data(self, name: str) -> bytes:
...
class MemoryIOHandler:
spectra: dict[int, np.ndarray]
data: dict[str, bytes]
def __init__(self):
self.spectra = {}
self.data = {}
def __len__(self) -> int:
return len(self.spectra)
def keys(self) -> list[str]:
return list(self.data.keys())
def save_spectrum(self, index: int, spectrum: np.ndarray):
self.spectra[index] = spectrum
def load_spectrum(self, index: int) -> np.ndarray:
return self.spectra[index]
def save_data(self, name: str, data: bytes):
self.data[name] = data
def load_data(self, name: str) -> bytes:
return self.data[name]
class ZipFileIOHandler:
file: BinaryIO
SPECTRUM_FN = "spectra/spectrum_{}.npy"
def __init__(self, file: os.PathLike):
"""
Create a IO handler to be used in Propagation.
This handler saves spectra as numpy `.npy` files.
Parameters
----------
file : os.PathLike
path to the desired zip file. Will be created if it doesn't exist.
Passing in an already opened file-like object is not supported at the moment
"""
self.file = Path(file)
if not self.file.exists():
ZipFile(self.file, "w").close()
def __len__(self) -> int:
with ZipFile(self.file, "r") as zip_file:
i = 0
while True:
try:
zip_file.open(self.SPECTRUM_FN.format(i))
except KeyError:
return i
i += 1
def keys(self) -> list[str]:
with ZipFile(self.file, "r") as zip_file:
return zip_file.namelist()
def save_spectrum(self, index: int, spectrum: np.ndarray):
with ZipFile(self.file, "a") as zip_file, zip_file.open(
self.SPECTRUM_FN.format(index), "w"
) as file:
np.lib.format.write_array(file, spectrum, allow_pickle=False)
def load_spectrum(self, index: int) -> np.ndarray:
with ZipFile(self.file, "r") as zip_file, zip_file.open(
self.SPECTRUM_FN.format(index), "r"
) as file:
return np.load(file)
def save_data(self, name: str, data: bytes):
with ZipFile(self.file, "a") as zip_file, zip_file.open(name, "w") as file:
file.write(data)
def load_data(self, name: str) -> bytes:
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
return file.read()
@dataclass
class DataFile:
"""
Holds information about external files necessary for a simulation.
In the current implementation, only reading data from the file is supported
"""
prefix: str | None
path: str
io: PropagationIOHandler
PREFIX_SEP = "::"
@classmethod
def from_str(cls, s: str, io: PropagationIOHandler | None = None) -> DataFile:
if cls.PREFIX_SEP in s:
prefix, path = s.split(cls.PREFIX_SEP, 1)
else:
prefix = None
path = s
return cls(prefix, path, io)
@classmethod
def validate(cls, name: str, data: str | DataFile) -> DataFile:
"""To be used to automatically construct a DataFile when creating a Parameters obj"""
if isinstance(data, cls):
return data
elif isinstance(data, str):
return cls.from_str(data)
elif isinstance(data, Path):
return cls.from_str(os.fspath(data))
else:
raise TypeError(
f"{name!r} must be a path or a bundled file specifier, not a {type(data)!r}"
)
def __json__(self) -> str:
if self.prefix is None:
return os.fspath(Path(self.path))
return self.prefix + self.PREFIX_SEP + self.path
def load_data(self) -> bytes:
if self.prefix is not None and self.io is None:
raise ValueError(
f"a bundled file prefixed with {self.prefix} "
"must have a PropagationIOHandler attached"
)
if self.io is not None:
return self.io.load_data(self.path)
else:
return Path(self.path).read_bytes()
def unique_name(base_name: str, existing: set[str]) -> str:
name = base_name
p = Path(base_name)
base = p.stem
ext = p.suffix
i = 0
while name in existing:
name = f"{base}_{i}{ext}"
i += 1
return name
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
if len(left_elements) == 0:
left_elements = [""]

View File

@@ -4,7 +4,7 @@ import datetime as datetime_module
import json
import os
from copy import copy, deepcopy
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from functools import lru_cache, wraps
from math import isnan
from pathlib import Path
@@ -12,12 +12,10 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
import numpy as np
from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.evaluator import Evaluator, EvaluatorError
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
from scgenerator.io import CustomEncoder, DataFile, custom_decode_hook
from scgenerator.operators import Qualifier, SpecOperator
from scgenerator.utils import update_path_name
T = TypeVar("T")
DISPLAY_INFO = {}
@@ -77,6 +75,11 @@ def string(name, n):
raise ValueError(f"{name!r} must not be empty")
def low_string(name, n):
string(name, n)
return n.lower()
def in_range_excl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
@@ -255,9 +258,6 @@ class Parameter:
del instance._param_dico[self.name]
def __set__(self, instance: Parameters, value):
if instance._frozen:
raise AttributeError("Parameters instance is frozen and can no longer be modified")
if isinstance(value, Parameter):
if self.default is not None:
instance._param_dico[self.name] = copy(self.default)
@@ -288,9 +288,9 @@ class Parameter:
except TypeError:
is_value = True
if is_value:
if self.converter is not None:
v = self.converter(v)
self._validator(self.name, v)
ret_val = self._validator(self.name, v)
if ret_val is not None:
v = ret_val
return is_value, v
@@ -307,7 +307,6 @@ class Parameters:
# root
name: str = Parameter(string, default="no name")
output_path: Path = Parameter(type_checker(Path), converter=Path)
# fiber
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
@@ -315,10 +314,10 @@ class Parameters:
n2: float = Parameter(non_negative(float, int))
chi3: float = Parameter(non_negative(float, int))
loss: str = Parameter(literal("capillary"))
loss_file: str = Parameter(string)
loss_file: DataFile = Parameter(DataFile.validate)
effective_mode_diameter: float = Parameter(positive(float, int))
effective_area: float = Parameter(non_negative(float, int))
effective_area_file: str = Parameter(string)
effective_area_file: DataFile = Parameter(DataFile.validate)
numerical_aperture: float = Parameter(in_range_excl(0, 1))
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm"))
pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1))
@@ -326,7 +325,7 @@ class Parameters:
he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1))
fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9))
beta2_coefficients: Iterable[float] = Parameter(num_list)
dispersion_file: str = Parameter(string)
dispersion_file: DataFile = Parameter(DataFile.validate)
model: str = Parameter(
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
)
@@ -345,7 +344,7 @@ class Parameters:
capillary_nested: int = Parameter(non_negative(int), default=0)
# gas
gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
gas_name: str = Parameter(low_string, default="vacuum")
pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
@@ -353,7 +352,7 @@ class Parameters:
plasma_density: float = Parameter(non_negative(float, int), default=0)
# pulse
field_file: str = Parameter(string)
field_file: DataFile = Parameter(DataFile.validate)
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
repetition_rate: float = Parameter(
@@ -380,11 +379,9 @@ class Parameters:
# simulation
full_field: bool = Parameter(boolean, default=False)
integration_scheme: str = Parameter(
literal("erk43", "erk54", "cqe", "sd", "constant"),
converter=str.lower,
default="erk43",
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
raman_fraction: float = Parameter(non_negative(float, int))
spm: bool = Parameter(boolean, default=True)
repeat: int = Parameter(positive(int), default=1)
@@ -437,7 +434,8 @@ class Parameters:
@classmethod
def from_json(cls, s: str) -> Parameters:
return cls(**json.loads(s, object_hook=decode_datetime_hook))
decoded = json.loads(s, object_hook=custom_decode_hook)
return cls(**decoded)
@classmethod
def load(cls, path: os.PathLike) -> Parameters:
@@ -462,18 +460,32 @@ class Parameters:
self.__post_init__()
def __setattr__(self, k, v):
if self._frozen:
if self._frozen and not k.endswith("_file"):
raise AttributeError(
f"cannot set attribute to frozen {self.__class__.__name__} instance"
)
object.__setattr__(self, k, v)
def copy(self) -> Parameters:
return Parameters(**deepcopy(self.strip_params_dict()))
def items(self) -> Iterator[tuple[str, Any]]:
for k, v in self._param_dico.items():
if v is None:
continue
yield k, v
def copy(self, freeze: bool = False) -> Parameters:
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
params = Parameters(**deepcopy(self.strip_params_dict()))
if freeze:
params.freeze()
return params
def freeze(self):
"""render the current instance read-only. This is not reversible"""
self._frozen = True
def to_json(self) -> str:
d = self.dump_dict()
return json.dumps(d, cls=DatetimeEncoder, default=list)
return json.dumps(d, cls=CustomEncoder, indent=4)
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
@@ -611,7 +623,7 @@ class Parameters:
"linear_op",
"c_to_a_factor",
}
types = (np.ndarray, float, int, str, list, tuple, Path)
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
c = deepcopy if copy else lambda x: x
out = {}
for key, value in self._param_dico.items():

View File

@@ -1,3 +1,4 @@
from io import BytesIO
from typing import Iterable, TypeVar
import numpy as np
@@ -7,6 +8,7 @@ from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from scgenerator import io
from scgenerator.io import DataFile
from scgenerator.math import argclosest, u_nm
from scgenerator.physics import materials as mat
from scgenerator.physics import units
@@ -653,7 +655,7 @@ def saitoh_paramters(pcf_pitch_ratio: float) -> tuple[float, float]:
return A, B
def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.ndarray:
def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray:
"""
loads custom effective area file
@@ -669,14 +671,14 @@ def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.nd
np.ndarray, shape (n,)
wl-dependent effective mode field area
"""
data = np.load(effective_area_file)
data = np.load(BytesIO(effective_area_file.load_data()))
effective_area = data.get("A_eff", data.get("effective_area"))
wl = data["wavelength"]
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray]:
disp_file = np.load(dispersion_file)
def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
disp_file = np.load(BytesIO(dispersion_file.load_data()))
wl_for_disp = disp_file["wavelength"]
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
D = disp_file["dispersion"]
@@ -684,7 +686,7 @@ def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray
return wl_for_disp, beta2, interp_range
def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
"""
loads a npz loss file that contains a wavelength and a loss entry
@@ -700,7 +702,7 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
np.ndarray, shape (n,)
loss in 1/m units
"""
loss_data = np.load(loss_file)
loss_data = np.load(BytesIO(loss_file.load_data()))
wl = loss_data["wavelength"]
loss = loss_data["loss"]
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)

View File

@@ -12,6 +12,7 @@ n is the number of spectra at the same z position and nt is the size of the time
import itertools
import os
from dataclasses import astuple, dataclass
from io import BytesIO
from pathlib import Path
from typing import Literal, Tuple, TypeVar
@@ -25,6 +26,7 @@ from scipy.optimize._optimize import OptimizeResult
from scgenerator import math
from scgenerator.defaults import default_plotting
from scgenerator.io import DataFile
from scgenerator.physics import units
c = 299792458.0
@@ -410,8 +412,9 @@ def interp_custom_field(
return field_0
def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
field_data = np.load(field_file)
def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
data = field_file.load_data()
field_data = np.load(BytesIO(data))
return field_data["time"], field_data["field"]

View File

@@ -1,37 +1,32 @@
from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Callable, Iterator, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from scgenerator import math
from scgenerator.const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
from scgenerator.logger import get_logger
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units
from scgenerator.physics.units import PlotRange
from scgenerator.plotting import (
mean_values_plot,
propagation_plot,
single_position_plot,
transform_1D_values,
transform_2D_propagation,
)
from scgenerator.utils import load_spectrum, load_toml, simulations_list
class Spectrum(np.ndarray):
params: Parameters
w: np.ndarray
t: np.ndarray
l: np.ndarray
ifft: Callable[[np.ndarray], np.ndarray]
def __new__(cls, input_array, params: Parameters):
# Input array is an already formed ndarray instance
# We first cast to be our class type
obj = np.asarray(input_array).view(cls)
# add the new attribute to the created instance
obj.params = params
obj.w = params.compute("w")
obj.t = params.compute("t")
obj.t = params.compute("t")
obj.ifft = params.compute("ifft")
# Finally, we must return the newly created object:
return obj
@@ -40,14 +35,17 @@ class Spectrum(np.ndarray):
# see InfoArray.__array_finalize__ for comments
if obj is None:
return
self.params = getattr(obj, "params", None)
self.w = getattr(obj, "w", None)
self.t = getattr(obj, "t", None)
self.l = getattr(obj, "l", None)
self.ifft = getattr(obj, "ifft", None)
def __getitem__(self, key) -> "Spectrum":
return super().__getitem__(key)
@property
def wl_int(self):
return units.to_WL(math.abs2(self), self.params.l)
return units.to_WL(math.abs2(self), self.l)
@property
def freq_int(self):
@@ -59,13 +57,13 @@ class Spectrum(np.ndarray):
@property
def time_int(self):
return math.abs2(self.params.ifft(self))
return math.abs2(self.ifft(self))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
x_axis = unit.inv(self.params.w)
x_axis = unit.inv(self.w)
else:
x_axis = unit.inv(self.params.t)
x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
@@ -84,7 +82,7 @@ class Spectrum(np.ndarray):
np.sqrt(
units.to_WL(
math.abs2(self),
self.params.l,
self.l,
)
)
* self
@@ -106,311 +104,115 @@ class Spectrum(np.ndarray):
@property
def wl_max(self):
if self.ndim == 1:
return self.params.l[np.argmax(self.wl_int, axis=-1)]
return self.l[np.argmax(self.wl_int, axis=-1)]
return np.array([s.wl_max for s in self])
def mask_wl(self, pos: float, width: float) -> Spectrum:
return self * np.exp(
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
)
return self * np.exp(-(((self.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2))
def measure(self) -> tuple[float, float, float]:
return pulse.measure_field(self.params.t, self.time_amp)
return pulse.measure_field(self.t, self.time_amp)
class SimulationSeries:
"""
SimulationsSeries are the interface the user should use to load and
interact with simulation data. The object loads each fiber of the simulation
into a separate object and exposes convenience methods to make the series behave
as a single fiber.
It should be noted that the last spectrum of a fiber and the first one of the next
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
than when manually mergin the corresponding data.
"""
path: Path
fibers: list[SimulatedFiber]
class Propagation:
io: PropagationIOHandler
params: Parameters
z_indices: list[tuple[int, int]]
fiber_positions: list[tuple[str, float]]
_current_index: int
def __init__(self, path: os.PathLike):
PARAMS_FN = "params.json"
def __init__(
self,
io_handler: PropagationIOHandler,
params: Parameters | None = None,
bundle_data: bool = False,
):
"""
Create a SimulationSeries
A propagation is the object that manages IO for one single propagation.
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
factories instead of creating this object directly.
Parameters
----------
path : os.PathLike
path to the last fiber of the series
Raises
------
FileNotFoundError
No simulation found in specified directory
io : PropagationIOHandler
object that implements the PropagationIOHandler Protocol.
params : Parameters
simulations parameters. Those will be saved via the
"""
self.logger = get_logger()
for self.path in simulations_list(path):
break
else:
raise FileNotFoundError(f"No simulation in {path}")
self.fibers = [SimulatedFiber(self.path)]
while (p := self.fibers[-1].params.prev_data_dir) is not None:
p = Path(p)
if not p.is_absolute():
p = Path(self.fibers[-1].params.output_path) / p
self.fibers.append(SimulatedFiber(p))
self.fibers = self.fibers[::-1]
self.io = io_handler
self._current_index = len(self.io)
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
z_targets = list(self.params.z_targets)
self.z_indices = [(0, j) for j in range(self.params.z_num)]
for i, fiber in enumerate(self.fibers[1:]):
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
self.params.z_targets = np.array(z_targets)
self.params.length = self.params.z_targets[-1]
self.params.z_num = len(self.params.z_targets)
new_params = params is not None
if not new_params:
if bundle_data:
raise ValueError(
"cannot bundle data to existing Propagation. Create a new one instead"
)
try:
params_data = self.io.load_data(self.PARAMS_FN)
params = Parameters.from_json(params_data.decode())
except KeyError:
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
...
if z_descr is None:
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
else:
if isinstance(z_descr, (float, np.floating)):
fib_ind, z_ind = self.z_ind(z_descr)
else:
fib_ind, z_ind = self.z_indices[z_descr]
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
return Spectrum(out, self.params)
self.params = params
def fields(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
return self.params.ifft(self.spectra(z_descr, sim_ind))
if bundle_data:
if not params._frozen:
warnings.warn(
f"Parameters instance {params.name!r} is not frozen but will be saved "
"in an unmodifiable state anyway."
)
_bundle_external_files(self.params, self.io)
def z_ind(self, pos: float) -> tuple[int, int]:
if 0 <= pos <= self.params.length:
ind = np.argmin(np.abs(self.params.z_targets - pos))
return self.z_indices[ind]
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
if new_params:
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
def plot_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
def __len__(self) -> int:
return self._current_index
def plot_values_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
return transform_2D_propagation(vals, plot_range, self.params, **kwargs)
def __getitem__(self, key: int | slice) -> Spectrum:
if isinstance(key, slice):
return self._load_slice(key)
if isinstance(key, (float, np.floating)):
key = math.argclosest(self.params.compute("z_targets"), key)
array = self.io.load_spectrum(key)
return Spectrum(array, self.params)
def plot_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
z_pos: int,
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value))
def plot_values_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_pos: int,
sim_ind: int = 0,
**kwargs,
) -> tuple[np.ndarray, np.ndarray]:
"""
gives the desired values already tranformes according to the give range
def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key]
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
for i in _iter:
out[i] = self.io.load_spectrum(i)
return out
Parameters
----------
left : float
leftmost limit in unit
right : float
rightmost limit in unit
unit : Union[Callable[[float], float], str]
unit
z_pos : Union[int, float]
position either as an index (int) or a real position (float)
sim_ind : Optional[int]
which simulation to take when more than one are present
def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, spectrum.asarray())
self._current_index += 1
Returns
-------
np.ndarray
x axis
np.ndarray
y values
"""
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
return transform_1D_values(vals, plot_range, self.params, **kwargs)
def plot_mean(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
z_pos: int,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, None)
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
def retrieve_plot_values(
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
):
if plot_range.unit.type == "TIME":
return self.fields(z_pos, sim_ind)
else:
return self.spectra(z_pos, sim_ind)
def rin_propagation(
self, left: float, right: float, unit: str
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
returns the RIN as function of unit and z
Parameters
----------
left : float
left limit in unit
right : float
right limit in unit
unit : str
unit descriptor
Returns
-------
x : np.ndarray, shape (nt,)
x axis
y : np.ndarray, shape (z_num, )
y axis
rin_prop : np.ndarray, shape (z_num, nt)
RIN
"""
spectra = []
for spec in np.moveaxis(self.spectra(None, None), 1, 0):
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
spectra.append(tmp)
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
# Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i, j in self.z_indices:
yield self.fibers[i].spectra(j, None)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path})"
def __eq__(self, other: SimulationSeries) -> bool:
return self.path == other.path and self.params == other.params
def __contains__(self, fiber: SimulatedFiber) -> bool:
return fiber in self.fibers
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)
def load_all(self) -> Spectrum:
return self._load_slice(slice(None))
class SimulatedFiber:
params: Parameters
t: np.ndarray
w: np.ndarray
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
"""copies every external file specified in the parameters and saves it"""
existing_files = set(io.keys())
for _, value in params.items():
if isinstance(value, DataFile):
data = value.load_data()
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters(**load_toml(self.path / PARAM_FN))
self.params.output_path = str(self.path.resolve())
self.t = self.params.t
self.w = self.params.w
self.z = self.params.z_targets
value.io = io
value.prefix = "zip"
value.path = unique_name(Path(value.path).name, existing_files)
existing_files.add(value.path)
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> np.ndarray:
if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
else:
if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind)
else:
z_ind = z_descr
io.save_data(value.path, data)
if z_ind < 0:
z_ind = self.params.z_num + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if 0 <= pos <= self.z[-1]:
return np.argmin(np.abs(self.z - pos))
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
"""
loads a spectrum file
Parameters
----------
z_ind : int
z_index relative to the entire simulation
sim_ind : int, optional
simulation index, used when repeated simulations with same parameters are ran,
by default 0
Returns
-------
np.ndarray
loaded spectrum file
"""
if sim_ind > 0:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
else:
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path})"
def load_all(path: os.PathLike) -> Spectrum:
io = ZipFileIOHandler(path)
return Propagation(io).load_all()

Binary file not shown.

Binary file not shown.

124
tests/test_io_handlers.py Normal file
View File

@@ -0,0 +1,124 @@
import json
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation
PARAMS = dict(
name="PM2000D sech pulse",
# pulse
wavelength=1546e-9,
# field_file="./Pos30000New.npz",
width=1.5e-12,
shape="sech",
repetition_rate=40e6,
# fiber
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
effective_area_file="./PM2000D_A_eff_marcuse.npz",
wavelength_window=(400e-9, 4000e-9),
n2=4.5e-20,
# simulation
raman_type="measured",
quantum_noise=True,
interpolation_degree=11,
z_num=128,
length=1.5,
t_num=512,
dt=5e-15,
)
def test_file(tmp_path: Path):
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = ZipFileIOHandler(tmp_path / "test.zip")
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
assert new_params is not params
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
def test_memory():
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = MemoryIOHandler()
assert len(io) == 0
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(io) == 8
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)
io = ZipFileIOHandler(tmp_path / "file.zip")
prop = Propagation(io, params.copy(True))
assert (tmp_path / "file.zip").exists()
assert (tmp_path / "file.zip").read_bytes() != b""
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
params.dispersion_file.path = str(new_disp_path)
params.effective_area_file.path = str(new_aeff_path)
params.freeze()
io = ZipFileIOHandler(tmp_path / "file2.zip")
prop2 = Propagation(io, params, True)
assert params.dispersion_file.path == new_disp_path.name
assert params.dispersion_file.prefix == "zip"
assert params.effective_area_file.path == new_aeff_path.name
assert params.effective_area_file.prefix == "zip"
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()
with zfile.open(new_disp_path.name) as file:
assert file.read() == new_disp_path.read_bytes()
with zfile.open(Propagation.PARAMS_FN) as file:
df = json.loads(file.read().decode())
assert (
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
)
assert (
df["effective_area_file"]
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
)
def test_unique_name():
existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy"