rudimentary new io data structure

Benoît Sierro
2023-08-08 10:59:08 +02:00
parent 7b6e33ca0f
commit 98fa32c24b
8 changed files with 475 additions and 346 deletions

View File

@@ -1,8 +1,14 @@
from __future__ import annotations
import datetime
import json
import os
from dataclasses import dataclass
from pathlib import Path
-from typing import Sequence
+from typing import BinaryIO, Protocol, Sequence
from zipfile import ZipFile
import numpy as np
import pkg_resources
@@ -11,27 +17,207 @@ def data_file(path: str) -> Path:
return Path(pkg_resources.resource_filename("scgenerator", path))
-class DatetimeEncoder(json.JSONEncoder):
+class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()
elif isinstance(obj, np.ndarray):
return tuple(obj)
elif isinstance(obj, DataFile):
return obj.__json__()
-def decode_datetime_hook(obj):
+def _decode_datetime(s: str) -> datetime.date | datetime.datetime | str:
try:
return datetime.datetime.fromisoformat(s)
except Exception:
pass
try:
return datetime.date.fromisoformat(s)
except Exception:
pass
return s
def custom_decode_hook(obj):
"""
Some simple, non-JSON-compatible objects are encoded as strings; this function reconstructs
the original objects and puts them back into the decoded JSON structure
"""
for k, v in obj.items():
-if not isinstance(v, str):
-continue
-try:
-dt = datetime.datetime.fromisoformat(v)
-except Exception:
-try:
-dt = datetime.date.fromisoformat(v)
-except Exception:
-continue
-obj[k] = dt
+if isinstance(v, str):
+obj[k] = _decode_datetime(v)
+elif isinstance(v, list):
+obj[k] = tuple(v)
return obj
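For reference, a minimal round-trip sketch of the encoder and decode hook above (not part of this commit; the payload values are made up):

import datetime
import json

import numpy as np

from scgenerator.io import CustomEncoder, custom_decode_hook

payload = {
    "datetime": datetime.datetime(2023, 8, 8, 10, 59),
    "wavelength_window": np.array([400e-9, 4000e-9]),  # encoded as a tuple
}
text = json.dumps(payload, cls=CustomEncoder)
restored = json.loads(text, object_hook=custom_decode_hook)
assert isinstance(restored["datetime"], datetime.datetime)   # ISO string parsed back
assert restored["wavelength_window"] == (400e-9, 4000e-9)    # list turned back into a tuple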
class PropagationIOHandler(Protocol):
def __len__(self) -> int:
...
def keys(self) -> list[str]:
...
def save_spectrum(self, index: int, spectrum: np.ndarray):
...
def load_spectrum(self, index: int) -> np.ndarray:
...
def save_data(self, name: str, data: bytes):
...
def load_data(self, name: str) -> bytes:
...
class MemoryIOHandler:
spectra: dict[int, np.ndarray]
data: dict[str, bytes]
def __init__(self):
self.spectra = {}
self.data = {}
def __len__(self) -> int:
return len(self.spectra)
def keys(self) -> list[str]:
return list(self.data.keys())
def save_spectrum(self, index: int, spectrum: np.ndarray):
self.spectra[index] = spectrum
def load_spectrum(self, index: int) -> np.ndarray:
return self.spectra[index]
def save_data(self, name: str, data: bytes):
self.data[name] = data
def load_data(self, name: str) -> bytes:
return self.data[name]
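A short sketch of how such a handler is exercised through the PropagationIOHandler protocol (illustration only; the stored values are arbitrary):

import numpy as np

from scgenerator.io import MemoryIOHandler

io = MemoryIOHandler()
io.save_data("notes.txt", b"free-form bytes")      # named blobs go into `data`
io.save_spectrum(0, np.zeros(16, dtype=complex))   # spectra are indexed by integer
io.save_spectrum(1, np.ones(16, dtype=complex))

assert len(io) == 2                  # __len__ counts stored spectra only
assert io.keys() == ["notes.txt"]    # keys() lists named data, not spectra
assert io.load_data("notes.txt") == b"free-form bytes"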
class ZipFileIOHandler:
file: BinaryIO
SPECTRUM_FN = "spectra/spectrum_{}.npy"
def __init__(self, file: os.PathLike):
"""
Create an IO handler to be used in Propagation.
This handler saves spectra as numpy `.npy` files.
Parameters
----------
file : os.PathLike
path to the desired zip file. Will be created if it doesn't exist.
Passing in an already opened file-like object is not supported at the moment
"""
self.file = Path(file)
if not self.file.exists():
ZipFile(self.file, "w").close()
def __len__(self) -> int:
with ZipFile(self.file, "r") as zip_file:
i = 0
while True:
try:
zip_file.open(self.SPECTRUM_FN.format(i))
except KeyError:
return i
i += 1
def keys(self) -> list[str]:
with ZipFile(self.file, "r") as zip_file:
return zip_file.namelist()
def save_spectrum(self, index: int, spectrum: np.ndarray):
with ZipFile(self.file, "a") as zip_file, zip_file.open(
self.SPECTRUM_FN.format(index), "w"
) as file:
np.lib.format.write_array(file, spectrum, allow_pickle=False)
def load_spectrum(self, index: int) -> np.ndarray:
with ZipFile(self.file, "r") as zip_file, zip_file.open(
self.SPECTRUM_FN.format(index), "r"
) as file:
return np.load(file)
def save_data(self, name: str, data: bytes):
with ZipFile(self.file, "a") as zip_file, zip_file.open(name, "w") as file:
file.write(data)
def load_data(self, name: str) -> bytes:
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
return file.read()
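A usage sketch for the zip-backed handler (the file name is a placeholder); spectra are stored as spectra/spectrum_<index>.npy entries inside the archive:

import numpy as np

from scgenerator.io import ZipFileIOHandler

io = ZipFileIOHandler("propagation.zip")        # created on the spot if missing
io.save_data("params.json", b"{}")
io.save_spectrum(0, np.zeros(8, dtype=complex))

assert len(io) == 1                             # counts spectrum entries in the archive
assert "spectra/spectrum_0.npy" in io.keys()
assert np.array_equal(io.load_spectrum(0), np.zeros(8, dtype=complex))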
@dataclass
class DataFile:
"""
Holds information about external files necessary for a simulation.
In the current implementation, only reading data from the file is supported
"""
prefix: str | None
path: str
io: PropagationIOHandler
PREFIX_SEP = "::"
@classmethod
def from_str(cls, s: str, io: PropagationIOHandler | None = None) -> DataFile:
if cls.PREFIX_SEP in s:
prefix, path = s.split(cls.PREFIX_SEP, 1)
else:
prefix = None
path = s
return cls(prefix, path, io)
@classmethod
def validate(cls, name: str, data: str | DataFile) -> DataFile:
"""To be used to automatically construct a DataFile when creating a Parameters obj"""
if isinstance(data, cls):
return data
elif isinstance(data, str):
return cls.from_str(data)
elif isinstance(data, Path):
return cls.from_str(os.fspath(data))
else:
raise TypeError(
f"{name!r} must be a path or a bundled file specifier, not a {type(data)!r}"
)
def __json__(self) -> str:
if self.prefix is None:
return os.fspath(Path(self.path))
return self.prefix + self.PREFIX_SEP + self.path
def load_data(self) -> bytes:
if self.prefix is not None and self.io is None:
raise ValueError(
f"a bundled file prefixed with {self.prefix} "
"must have a PropagationIOHandler attached"
)
if self.io is not None:
return self.io.load_data(self.path)
else:
return Path(self.path).read_bytes()
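A quick sketch of the prefix convention DataFile understands (file names are placeholders):

from scgenerator.io import DataFile, MemoryIOHandler

# plain path: bytes are read from disk when load_data() is called
plain = DataFile.from_str("dispersion.npz")
assert plain.prefix is None and plain.path == "dispersion.npz"

# "zip::<name>" refers to an entry stored through an attached IO handler
io = MemoryIOHandler()
io.save_data("dispersion.npz", b"raw npz bytes")
bundled = DataFile.from_str("zip::dispersion.npz", io)
assert bundled.prefix == "zip"
assert bundled.load_data() == b"raw npz bytes"
assert bundled.__json__() == "zip::dispersion.npz"   # how it ends up in params.json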
def unique_name(base_name: str, existing: set[str]) -> str:
name = base_name
p = Path(base_name)
base = p.stem
ext = p.suffix
i = 0
while name in existing:
name = f"{base}_{i}{ext}"
i += 1
return name
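unique_name appends a counter before the suffix until the name is free, for example:

from scgenerator.io import unique_name

taken = {"loss.npz", "loss_0.npz"}
assert unique_name("loss.npz", taken) == "loss_1.npz"
assert unique_name("field.npz", taken) == "field.npz"   # untouched when already unique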
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
if len(left_elements) == 0:
left_elements = [""]

View File

@@ -4,7 +4,7 @@ import datetime as datetime_module
import json
import os
from copy import copy, deepcopy
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field
from functools import lru_cache, wraps
from math import isnan
from pathlib import Path
@@ -12,12 +12,10 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
import numpy as np
-from scgenerator import utils
from scgenerator.const import MANDATORY_PARAMETERS, __version__
from scgenerator.evaluator import Evaluator, EvaluatorError
-from scgenerator.io import DatetimeEncoder, decode_datetime_hook
-from scgenerator.utils import update_path_name
+from scgenerator.io import CustomEncoder, DataFile, custom_decode_hook
from scgenerator.operators import Qualifier, SpecOperator
T = TypeVar("T")
DISPLAY_INFO = {}
@@ -77,6 +75,11 @@ def string(name, n):
raise ValueError(f"{name!r} must not be empty") raise ValueError(f"{name!r} must not be empty")
def low_string(name, n):
string(name, n)
return n.lower()
def in_range_excl(_min, _max):
@type_checker(float, int)
def _in_range(name, n):
@@ -255,9 +258,6 @@ class Parameter:
del instance._param_dico[self.name]
def __set__(self, instance: Parameters, value):
-if instance._frozen:
-raise AttributeError("Parameters instance is frozen and can no longer be modified")
if isinstance(value, Parameter):
if self.default is not None:
instance._param_dico[self.name] = copy(self.default)
@@ -288,9 +288,9 @@ class Parameter:
except TypeError:
is_value = True
if is_value:
-if self.converter is not None:
-v = self.converter(v)
-self._validator(self.name, v)
+ret_val = self._validator(self.name, v)
+if ret_val is not None:
+v = ret_val
return is_value, v
@@ -307,7 +307,6 @@ class Parameters:
# root
name: str = Parameter(string, default="no name")
-output_path: Path = Parameter(type_checker(Path), converter=Path)
# fiber
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
@@ -315,10 +314,10 @@ class Parameters:
n2: float = Parameter(non_negative(float, int))
chi3: float = Parameter(non_negative(float, int))
loss: str = Parameter(literal("capillary"))
-loss_file: str = Parameter(string)
+loss_file: DataFile = Parameter(DataFile.validate)
effective_mode_diameter: float = Parameter(positive(float, int))
effective_area: float = Parameter(non_negative(float, int))
-effective_area_file: str = Parameter(string)
+effective_area_file: DataFile = Parameter(DataFile.validate)
numerical_aperture: float = Parameter(in_range_excl(0, 1))
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm"))
pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1))
@@ -326,7 +325,7 @@ class Parameters:
he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1))
fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9))
beta2_coefficients: Iterable[float] = Parameter(num_list)
-dispersion_file: str = Parameter(string)
+dispersion_file: DataFile = Parameter(DataFile.validate)
model: str = Parameter(
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
)
@@ -345,7 +344,7 @@ class Parameters:
capillary_nested: int = Parameter(non_negative(int), default=0)
# gas
-gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
+gas_name: str = Parameter(low_string, default="vacuum")
pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
@@ -353,7 +352,7 @@ class Parameters:
plasma_density: float = Parameter(non_negative(float, int), default=0)
# pulse
-field_file: str = Parameter(string)
+field_file: DataFile = Parameter(DataFile.validate)
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
repetition_rate: float = Parameter(
@@ -380,11 +379,9 @@ class Parameters:
# simulation
full_field: bool = Parameter(boolean, default=False)
integration_scheme: str = Parameter(
-literal("erk43", "erk54", "cqe", "sd", "constant"),
-converter=str.lower,
-default="erk43",
+literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
)
-raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
+raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
raman_fraction: float = Parameter(non_negative(float, int))
spm: bool = Parameter(boolean, default=True)
repeat: int = Parameter(positive(int), default=1)
@@ -437,7 +434,8 @@ class Parameters:
@classmethod
def from_json(cls, s: str) -> Parameters:
-return cls(**json.loads(s, object_hook=decode_datetime_hook))
+decoded = json.loads(s, object_hook=custom_decode_hook)
+return cls(**decoded)
@classmethod
def load(cls, path: os.PathLike) -> Parameters:
@@ -462,18 +460,32 @@ class Parameters:
self.__post_init__()
def __setattr__(self, k, v):
-if self._frozen:
+if self._frozen and not k.endswith("_file"):
raise AttributeError(
f"cannot set attribute to frozen {self.__class__.__name__} instance"
)
object.__setattr__(self, k, v)
-def copy(self) -> Parameters:
-return Parameters(**deepcopy(self.strip_params_dict()))
+def items(self) -> Iterator[tuple[str, Any]]:
+for k, v in self._param_dico.items():
if v is None:
continue
yield k, v
def copy(self, freeze: bool = False) -> Parameters:
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
params = Parameters(**deepcopy(self.strip_params_dict()))
if freeze:
params.freeze()
return params
def freeze(self):
"""render the current instance read-only. This is not reversible"""
self._frozen = True
def to_json(self) -> str:
d = self.dump_dict()
-return json.dumps(d, cls=DatetimeEncoder, default=list)
+return json.dumps(d, cls=CustomEncoder, indent=4)
def get_evaluator(self):
evaluator = Evaluator.default(self.full_field)
@@ -611,7 +623,7 @@ class Parameters:
"linear_op", "linear_op",
"c_to_a_factor", "c_to_a_factor",
} }
types = (np.ndarray, float, int, str, list, tuple, Path) types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
c = deepcopy if copy else lambda x: x c = deepcopy if copy else lambda x: x
out = {} out = {}
for key, value in self._param_dico.items(): for key, value in self._param_dico.items():
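Taken together, these parameter.py changes mean that *_file parameters are coerced into DataFile objects and survive a JSON round trip. A sketch (not part of this commit; it assumes, as the tests in this commit do, that a partial keyword set is accepted):

from pathlib import Path

from scgenerator.io import DataFile
from scgenerator.parameter import Parameters

params = Parameters(name="example", wavelength=1546e-9, dispersion_file="./dispersion.npz")
assert isinstance(params.dispersion_file, DataFile)    # validated by DataFile.validate

text = params.to_json()                                # serialized through CustomEncoder
restored = Parameters.from_json(text)                  # decoded through custom_decode_hook
assert Path(restored.dispersion_file.path).name == Path(params.dispersion_file.path).name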

View File

@@ -1,3 +1,4 @@
from io import BytesIO
from typing import Iterable, TypeVar
import numpy as np
@@ -7,6 +8,7 @@ from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
from scipy.interpolate import interp1d
from scgenerator import io
from scgenerator.io import DataFile
from scgenerator.math import argclosest, u_nm
from scgenerator.physics import materials as mat
from scgenerator.physics import units
@@ -653,7 +655,7 @@ def saitoh_paramters(pcf_pitch_ratio: float) -> tuple[float, float]:
return A, B
-def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.ndarray:
+def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray:
"""
loads custom effective area file
@@ -669,14 +671,14 @@ def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.nd
np.ndarray, shape (n,)
wl-dependent effective mode field area
"""
-data = np.load(effective_area_file)
+data = np.load(BytesIO(effective_area_file.load_data()))
effective_area = data.get("A_eff", data.get("effective_area"))
wl = data["wavelength"]
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
-def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray]:
-disp_file = np.load(dispersion_file)
+def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
+disp_file = np.load(BytesIO(dispersion_file.load_data()))
wl_for_disp = disp_file["wavelength"]
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
D = disp_file["dispersion"]
@@ -684,7 +686,7 @@ def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray
return wl_for_disp, beta2, interp_range
-def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
+def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
"""
loads a npz loss file that contains a wavelength and a loss entry
@@ -700,7 +702,7 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
np.ndarray, shape (n,)
loss in 1/m units
"""
-loss_data = np.load(loss_file)
+loss_data = np.load(BytesIO(loss_file.load_data()))
wl = loss_data["wavelength"]
loss = loss_data["loss"]
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
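All three loaders above now read npz bytes through DataFile.load_data(). A minimal sketch of the pattern with an in-memory handler (array contents are made up):

from io import BytesIO

import numpy as np

from scgenerator.io import DataFile, MemoryIOHandler

# build a tiny npz payload with the keys load_custom_loss expects
buf = BytesIO()
np.savez(buf, wavelength=np.linspace(400e-9, 2000e-9, 50), loss=np.full(50, 0.1))
io = MemoryIOHandler()
io.save_data("loss.npz", buf.getvalue())

loss_file = DataFile.from_str("zip::loss.npz", io)
data = np.load(BytesIO(loss_file.load_data()))
assert set(data.files) >= {"wavelength", "loss"}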

View File

@@ -12,6 +12,7 @@ n is the number of spectra at the same z position and nt is the size of the time
import itertools
import os
from dataclasses import astuple, dataclass
from io import BytesIO
from pathlib import Path
from typing import Literal, Tuple, TypeVar
@@ -25,6 +26,7 @@ from scipy.optimize._optimize import OptimizeResult
from scgenerator import math
from scgenerator.defaults import default_plotting
from scgenerator.io import DataFile
from scgenerator.physics import units
c = 299792458.0
@@ -410,8 +412,9 @@ def interp_custom_field(
return field_0
-def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
-field_data = np.load(field_file)
+def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
+data = field_file.load_data()
+field_data = np.load(BytesIO(data))
return field_data["time"], field_data["field"]

View File

@@ -1,37 +1,32 @@
from __future__ import annotations
import os
import warnings
from pathlib import Path
-from typing import Callable, Iterator, Optional, Union
-import matplotlib.pyplot as plt
import numpy as np
from scgenerator import math
-from scgenerator.const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
-from scgenerator.logger import get_logger
+from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.physics import pulse, units
-from scgenerator.physics.units import PlotRange
-from scgenerator.plotting import (
-mean_values_plot,
-propagation_plot,
-single_position_plot,
-transform_1D_values,
-transform_2D_propagation,
-)
-from scgenerator.utils import load_spectrum, load_toml, simulations_list
class Spectrum(np.ndarray):
-params: Parameters
+w: np.ndarray
+t: np.ndarray
+l: np.ndarray
+ifft: Callable[[np.ndarray], np.ndarray]
def __new__(cls, input_array, params: Parameters):
# Input array is an already formed ndarray instance
# We first cast to be our class type
obj = np.asarray(input_array).view(cls)
# add the new attribute to the created instance
-obj.params = params
+obj.w = params.compute("w")
+obj.t = params.compute("t")
+obj.l = params.compute("l")
+obj.ifft = params.compute("ifft")
# Finally, we must return the newly created object:
return obj
@@ -40,14 +35,17 @@ class Spectrum(np.ndarray):
# see InfoArray.__array_finalize__ for comments
if obj is None:
return
-self.params = getattr(obj, "params", None)
+self.w = getattr(obj, "w", None)
+self.t = getattr(obj, "t", None)
+self.l = getattr(obj, "l", None)
+self.ifft = getattr(obj, "ifft", None)
def __getitem__(self, key) -> "Spectrum":
return super().__getitem__(key)
@property
def wl_int(self):
-return units.to_WL(math.abs2(self), self.params.l)
+return units.to_WL(math.abs2(self), self.l)
@property
def freq_int(self):
@@ -59,13 +57,13 @@ class Spectrum(np.ndarray):
@property
def time_int(self):
-return math.abs2(self.params.ifft(self))
+return math.abs2(self.ifft(self))
def amplitude(self, unit):
if unit.type in ["WL", "FREQ", "AFREQ"]:
-x_axis = unit.inv(self.params.w)
+x_axis = unit.inv(self.w)
else:
-x_axis = unit.inv(self.params.t)
+x_axis = unit.inv(self.t)
order = np.argsort(x_axis)
func = dict(
@@ -84,7 +82,7 @@ class Spectrum(np.ndarray):
np.sqrt(
units.to_WL(
math.abs2(self),
-self.params.l,
+self.l,
)
)
* self
@@ -106,311 +104,115 @@ class Spectrum(np.ndarray):
@property
def wl_max(self):
if self.ndim == 1:
-return self.params.l[np.argmax(self.wl_int, axis=-1)]
+return self.l[np.argmax(self.wl_int, axis=-1)]
return np.array([s.wl_max for s in self])
def mask_wl(self, pos: float, width: float) -> Spectrum:
-return self * np.exp(
--(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
-)
+return self * np.exp(-(((self.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2))
def measure(self) -> tuple[float, float, float]:
-return pulse.measure_field(self.params.t, self.time_amp)
+return pulse.measure_field(self.t, self.time_amp)
-class SimulationSeries:
-"""
-SimulationsSeries are the interface the user should use to load and
-interact with simulation data. The object loads each fiber of the simulation
-into a separate object and exposes convenience methods to make the series behave
-as a single fiber.
-It should be noted that the last spectrum of a fiber and the first one of the next
-fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
-than when manually mergin the corresponding data.
-"""
-path: Path
-fibers: list[SimulatedFiber]
-params: Parameters
-z_indices: list[tuple[int, int]]
-fiber_positions: list[tuple[str, float]]
-def __init__(self, path: os.PathLike):
-"""
-Create a SimulationSeries
-Parameters
-----------
-path : os.PathLike
-path to the last fiber of the series
-Raises
-------
-FileNotFoundError
-No simulation found in specified directory
-"""
+class Propagation:
+io: PropagationIOHandler
+params: Parameters
+_current_index: int
+PARAMS_FN = "params.json"
+def __init__(
+self,
+io_handler: PropagationIOHandler,
+params: Parameters | None = None,
+bundle_data: bool = False,
+):
+"""
+A propagation is the object that manages IO for one single propagation.
+It is recommended to use the "memory_propagation" and "zip_propagation" convenience
+factories instead of creating this object directly.
+Parameters
+----------
+io : PropagationIOHandler
+object that implements the PropagationIOHandler Protocol.
self.logger = get_logger()
for self.path in simulations_list(path):
break
else:
raise FileNotFoundError(f"No simulation in {path}")
self.fibers = [SimulatedFiber(self.path)]
while (p := self.fibers[-1].params.prev_data_dir) is not None:
p = Path(p)
if not p.is_absolute():
p = Path(self.fibers[-1].params.output_path) / p
self.fibers.append(SimulatedFiber(p))
self.fibers = self.fibers[::-1]
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
z_targets = list(self.params.z_targets)
self.z_indices = [(0, j) for j in range(self.params.z_num)]
for i, fiber in enumerate(self.fibers[1:]):
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
self.params.z_targets = np.array(z_targets)
self.params.length = self.params.z_targets[-1]
self.params.z_num = len(self.params.z_targets)
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
...
if z_descr is None:
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
else:
if isinstance(z_descr, (float, np.floating)):
fib_ind, z_ind = self.z_ind(z_descr)
else:
fib_ind, z_ind = self.z_indices[z_descr]
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
return Spectrum(out, self.params)
def fields(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> Spectrum:
return self.params.ifft(self.spectra(z_descr, sim_ind))
def z_ind(self, pos: float) -> tuple[int, int]:
if 0 <= pos <= self.params.length:
ind = np.argmin(np.abs(self.params.z_targets - pos))
return self.z_indices[ind]
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def plot_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
def plot_values_2D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
return transform_2D_propagation(vals, plot_range, self.params, **kwargs)
def plot_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
z_pos: int,
sim_ind: int = 0,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
def plot_values_1D(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
z_pos: int,
sim_ind: int = 0,
**kwargs,
) -> tuple[np.ndarray, np.ndarray]:
"""
gives the desired values already tranformes according to the give range
Parameters
----------
left : float
leftmost limit in unit
right : float
rightmost limit in unit
unit : Union[Callable[[float], float], str]
unit
z_pos : Union[int, float]
position either as an index (int) or a real position (float)
sim_ind : Optional[int]
which simulation to take when more than one are present
Returns
-------
np.ndarray
x axis
np.ndarray
y values
"""
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
return transform_1D_values(vals, plot_range, self.params, **kwargs)
def plot_mean(
self,
left: float,
right: float,
unit: Union[Callable[[float], float], str],
ax: plt.Axes,
z_pos: int,
**kwargs,
):
plot_range = PlotRange(left, right, unit)
vals = self.retrieve_plot_values(plot_range, z_pos, None)
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
def retrieve_plot_values(
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
):
if plot_range.unit.type == "TIME":
return self.fields(z_pos, sim_ind)
else:
return self.spectra(z_pos, sim_ind)
def rin_propagation(
self, left: float, right: float, unit: str
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
returns the RIN as function of unit and z
Parameters
----------
left : float
left limit in unit
right : float
right limit in unit
unit : str
unit descriptor
Returns
-------
x : np.ndarray, shape (nt,)
x axis
y : np.ndarray, shape (z_num, )
y axis
rin_prop : np.ndarray, shape (z_num, nt)
RIN
"""
spectra = []
for spec in np.moveaxis(self.spectra(None, None), 1, 0):
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
spectra.append(tmp)
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
# Magic methods
def __iter__(self) -> Iterator[Spectrum]:
for i, j in self.z_indices:
yield self.fibers[i].spectra(j, None)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path})"
def __eq__(self, other: SimulationSeries) -> bool:
return self.path == other.path and self.params == other.params
def __contains__(self, fiber: SimulatedFiber) -> bool:
return fiber in self.fibers
def __getitem__(self, key) -> Spectrum:
if isinstance(key, tuple):
return self.spectra(*key)
else:
return self.spectra(key, None)
class SimulatedFiber:
-params : Parameters
-t: np.ndarray
+params : Parameters
+simulations parameters. Those will be saved via the
w: np.ndarray
def __init__(self, path: os.PathLike):
self.path = Path(path)
self.params = Parameters(**load_toml(self.path / PARAM_FN))
self.params.output_path = str(self.path.resolve())
self.t = self.params.t
self.w = self.params.w
self.z = self.params.z_targets
def spectra(
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
) -> np.ndarray:
if z_descr is None:
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
else:
if isinstance(z_descr, (float, np.floating)):
return self.spectra(self.z_ind(z_descr), sim_ind)
else:
z_ind = z_descr
if z_ind < 0:
z_ind = self.params.z_num + z_ind
if sim_ind is None:
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
else:
out = self._load_1(z_ind)
return Spectrum(out, self.params)
def z_ind(self, pos: float) -> int:
if 0 <= pos <= self.z[-1]:
return np.argmin(np.abs(self.z - pos))
else:
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
""" """
loads a spectrum file self.io = io_handler
self._current_index = len(self.io)
Parameters new_params = params is not None
---------- if not new_params:
z_ind : int if bundle_data:
z_index relative to the entire simulation raise ValueError(
sim_ind : int, optional "cannot bundle data to existing Propagation. Create a new one instead"
simulation index, used when repeated simulations with same parameters are ran, )
by default 0 try:
params_data = self.io.load_data(self.PARAMS_FN)
params = Parameters.from_json(params_data.decode())
except KeyError:
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
Returns self.params = params
-------
np.ndarray
loaded spectrum file
"""
if sim_ind > 0:
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
else:
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
-psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
-def __repr__(self) -> str:
-return f"{self.__class__.__name__}(path={self.path})"
+if bundle_data:
+if not params._frozen:
warnings.warn(
f"Parameters instance {params.name!r} is not frozen but will be saved "
"in an unmodifiable state anyway."
)
_bundle_external_files(self.params, self.io)
if new_params:
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
def __len__(self) -> int:
return self._current_index
def __getitem__(self, key: int | slice) -> Spectrum:
if isinstance(key, slice):
return self._load_slice(key)
if isinstance(key, (float, np.floating)):
key = math.argclosest(self.params.compute("z_targets"), key)
array = self.io.load_spectrum(key)
return Spectrum(array, self.params)
def __setitem__(self, key: int, value: np.ndarray):
if not isinstance(key, int):
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
self.io.save_spectrum(key, np.asarray(value))
def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key]
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
for out_index, spec_index in enumerate(_iter):
out[out_index] = self.io.load_spectrum(spec_index)
return out
def append(self, spectrum: np.ndarray):
self.io.save_spectrum(self._current_index, np.asarray(spectrum))
self._current_index += 1
def load_all(self) -> Spectrum:
return self._load_slice(slice(None))
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
"""copies every external file specified in the parameters and saves it"""
existing_files = set(io.keys())
for _, value in params.items():
if isinstance(value, DataFile):
data = value.load_data()
value.io = io
value.prefix = "zip"
value.path = unique_name(Path(value.path).name, existing_files)
existing_files.add(value.path)
io.save_data(value.path, data)
def load_all(path: os.PathLike) -> Spectrum:
io = ZipFileIOHandler(path)
return Propagation(io).load_all()
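A usage sketch of the new Propagation container with an in-memory handler (not part of this commit; shapes and values are arbitrary, and it assumes, as the tests below do, that this partial parameter set is accepted):

import numpy as np

from scgenerator.io import MemoryIOHandler
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation

params = Parameters(name="demo", wavelength=1546e-9, t_num=512, dt=5e-15, z_num=8, length=1.0)
io = MemoryIOHandler()
prop = Propagation(io, params.copy(True))   # writes params.json into the handler

for i in range(8):
    prop[i] = np.zeros(512, dtype=complex)  # __setitem__ stores spectrum i

# re-opening from the same handler picks the parameters and the spectrum count back up
reopened = Propagation(io)
assert len(reopened) == 8
# reopened[0] and reopened.load_all() return Spectrum objects built with the stored parameters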

Binary file not shown.

Binary file not shown.

tests/test_io_handlers.py Normal file
View File

@@ -0,0 +1,124 @@
import json
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pytest
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
from scgenerator.parameter import Parameters
from scgenerator.spectra import Propagation
PARAMS = dict(
name="PM2000D sech pulse",
# pulse
wavelength=1546e-9,
# field_file="./Pos30000New.npz",
width=1.5e-12,
shape="sech",
repetition_rate=40e6,
# fiber
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
effective_area_file="./PM2000D_A_eff_marcuse.npz",
wavelength_window=(400e-9, 4000e-9),
n2=4.5e-20,
# simulation
raman_type="measured",
quantum_noise=True,
interpolation_degree=11,
z_num=128,
length=1.5,
t_num=512,
dt=5e-15,
)
def test_file(tmp_path: Path):
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = ZipFileIOHandler(tmp_path / "test.zip")
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
assert new_params is not params
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
def test_memory():
params = Parameters(**PARAMS)
stuff = np.random.rand(8, 512)
io = MemoryIOHandler()
assert len(io) == 0
io.save_data("params.json", params.to_json().encode())
for i, spec in enumerate(stuff):
io.save_spectrum(i, spec)
new_params = Parameters.from_json(io.load_data("params.json").decode())
for k, v in params.items():
v_new = getattr(new_params, k)
if isinstance(v, DataFile):
assert Path(v.path).name == Path(v_new.path).name
else:
assert v == getattr(new_params, k)
for i in range(8):
assert np.all(io.load_spectrum(i) == stuff[i])
assert len(io) == 8
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)
io = ZipFileIOHandler(tmp_path / "file.zip")
prop = Propagation(io, params.copy(True))
assert (tmp_path / "file.zip").exists()
assert (tmp_path / "file.zip").read_bytes() != b""
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
params.dispersion_file.path = str(new_disp_path)
params.effective_area_file.path = str(new_aeff_path)
params.freeze()
io = ZipFileIOHandler(tmp_path / "file2.zip")
prop2 = Propagation(io, params, True)
assert params.dispersion_file.path == new_disp_path.name
assert params.dispersion_file.prefix == "zip"
assert params.effective_area_file.path == new_aeff_path.name
assert params.effective_area_file.prefix == "zip"
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
with zfile.open(new_aeff_path.name) as file:
assert file.read() == new_aeff_path.read_bytes()
with zfile.open(new_disp_path.name) as file:
assert file.read() == new_disp_path.read_bytes()
with zfile.open(Propagation.PARAMS_FN) as file:
df = json.loads(file.read().decode())
assert (
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
)
assert (
df["effective_area_file"]
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
)
def test_unique_name():
existing = {"spec.npy", "spec_0.npy"}
assert unique_name("spec.npy", existing) == "spec_1.npy"