rudimentary new io data structure
This commit is contained in:
@@ -1,8 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
from typing import BinaryIO, Protocol, Sequence
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
|
|
||||||
@@ -11,27 +17,207 @@ def data_file(path: str) -> Path:
|
|||||||
return Path(pkg_resources.resource_filename("scgenerator", path))
|
return Path(pkg_resources.resource_filename("scgenerator", path))
|
||||||
|
|
||||||
|
|
||||||
class DatetimeEncoder(json.JSONEncoder):
|
class CustomEncoder(json.JSONEncoder):
|
||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
if isinstance(obj, (datetime.date, datetime.datetime)):
|
if isinstance(obj, (datetime.date, datetime.datetime)):
|
||||||
return obj.isoformat()
|
return obj.isoformat()
|
||||||
|
elif isinstance(obj, np.ndarray):
|
||||||
|
return tuple(obj)
|
||||||
|
elif isinstance(obj, DataFile):
|
||||||
|
return obj.__json__()
|
||||||
|
|
||||||
|
|
||||||
def decode_datetime_hook(obj):
|
def _decode_datetime(s: str) -> datetime.date | datetime.datetime | str:
|
||||||
|
try:
|
||||||
|
return datetime.datetime.fromisoformat(s)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return datetime.date.fromisoformat(s)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def custom_decode_hook(obj):
|
||||||
|
"""
|
||||||
|
Some simple, non json-compatible objects are encoded as str, this function reconstructs the
|
||||||
|
original object and sets it in the decoded json structure
|
||||||
|
"""
|
||||||
for k, v in obj.items():
|
for k, v in obj.items():
|
||||||
if not isinstance(v, str):
|
if isinstance(v, str):
|
||||||
continue
|
obj[k] = _decode_datetime(v)
|
||||||
try:
|
elif isinstance(v, list):
|
||||||
dt = datetime.datetime.fromisoformat(v)
|
obj[k] = tuple(v)
|
||||||
except Exception:
|
|
||||||
try:
|
|
||||||
dt = datetime.date.fromisoformat(v)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
obj[k] = dt
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class PropagationIOHandler(Protocol):
|
||||||
|
def __len__(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
def keys(self) -> list[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||||
|
...
|
||||||
|
|
||||||
|
def load_spectrum(self, index: int) -> np.ndarray:
|
||||||
|
...
|
||||||
|
|
||||||
|
def save_data(self, name: str, data: bytes):
|
||||||
|
...
|
||||||
|
|
||||||
|
def load_data(self, name: str) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryIOHandler:
|
||||||
|
spectra: dict[int, np.ndarray]
|
||||||
|
data: dict[str, bytes]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.spectra = {}
|
||||||
|
self.data = {}
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.spectra)
|
||||||
|
|
||||||
|
def keys(self) -> list[str]:
|
||||||
|
return list(self.data.keys())
|
||||||
|
|
||||||
|
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||||
|
self.spectra[index] = spectrum
|
||||||
|
|
||||||
|
def load_spectrum(self, index: int) -> np.ndarray:
|
||||||
|
return self.spectra[index]
|
||||||
|
|
||||||
|
def save_data(self, name: str, data: bytes):
|
||||||
|
self.data[name] = data
|
||||||
|
|
||||||
|
def load_data(self, name: str) -> bytes:
|
||||||
|
return self.data[name]
|
||||||
|
|
||||||
|
|
||||||
|
class ZipFileIOHandler:
|
||||||
|
file: BinaryIO
|
||||||
|
SPECTRUM_FN = "spectra/spectrum_{}.npy"
|
||||||
|
|
||||||
|
def __init__(self, file: os.PathLike):
|
||||||
|
"""
|
||||||
|
Create a IO handler to be used in Propagation.
|
||||||
|
This handler saves spectra as numpy `.npy` files.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file : os.PathLike
|
||||||
|
path to the desired zip file. Will be created if it doesn't exist.
|
||||||
|
Passing in an already opened file-like object is not supported at the moment
|
||||||
|
"""
|
||||||
|
self.file = Path(file)
|
||||||
|
if not self.file.exists():
|
||||||
|
ZipFile(self.file, "w").close()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
with ZipFile(self.file, "r") as zip_file:
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
zip_file.open(self.SPECTRUM_FN.format(i))
|
||||||
|
except KeyError:
|
||||||
|
return i
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
def keys(self) -> list[str]:
|
||||||
|
with ZipFile(self.file, "r") as zip_file:
|
||||||
|
return zip_file.namelist()
|
||||||
|
|
||||||
|
def save_spectrum(self, index: int, spectrum: np.ndarray):
|
||||||
|
with ZipFile(self.file, "a") as zip_file, zip_file.open(
|
||||||
|
self.SPECTRUM_FN.format(index), "w"
|
||||||
|
) as file:
|
||||||
|
np.lib.format.write_array(file, spectrum, allow_pickle=False)
|
||||||
|
|
||||||
|
def load_spectrum(self, index: int) -> np.ndarray:
|
||||||
|
with ZipFile(self.file, "r") as zip_file, zip_file.open(
|
||||||
|
self.SPECTRUM_FN.format(index), "r"
|
||||||
|
) as file:
|
||||||
|
return np.load(file)
|
||||||
|
|
||||||
|
def save_data(self, name: str, data: bytes):
|
||||||
|
with ZipFile(self.file, "a") as zip_file, zip_file.open(name, "w") as file:
|
||||||
|
file.write(data)
|
||||||
|
|
||||||
|
def load_data(self, name: str) -> bytes:
|
||||||
|
with ZipFile(self.file, "r") as zip_file, zip_file.open(name, "r") as file:
|
||||||
|
return file.read()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataFile:
|
||||||
|
"""
|
||||||
|
Holds information about external files necessary for a simulation.
|
||||||
|
In the current implementation, only reading data from the file is supported
|
||||||
|
"""
|
||||||
|
|
||||||
|
prefix: str | None
|
||||||
|
path: str
|
||||||
|
io: PropagationIOHandler
|
||||||
|
|
||||||
|
PREFIX_SEP = "::"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, s: str, io: PropagationIOHandler | None = None) -> DataFile:
|
||||||
|
if cls.PREFIX_SEP in s:
|
||||||
|
prefix, path = s.split(cls.PREFIX_SEP, 1)
|
||||||
|
else:
|
||||||
|
prefix = None
|
||||||
|
path = s
|
||||||
|
return cls(prefix, path, io)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, name: str, data: str | DataFile) -> DataFile:
|
||||||
|
"""To be used to automatically construct a DataFile when creating a Parameters obj"""
|
||||||
|
if isinstance(data, cls):
|
||||||
|
return data
|
||||||
|
elif isinstance(data, str):
|
||||||
|
return cls.from_str(data)
|
||||||
|
elif isinstance(data, Path):
|
||||||
|
return cls.from_str(os.fspath(data))
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"{name!r} must be a path or a bundled file specifier, not a {type(data)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __json__(self) -> str:
|
||||||
|
if self.prefix is None:
|
||||||
|
return os.fspath(Path(self.path))
|
||||||
|
return self.prefix + self.PREFIX_SEP + self.path
|
||||||
|
|
||||||
|
def load_data(self) -> bytes:
|
||||||
|
if self.prefix is not None and self.io is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"a bundled file prefixed with {self.prefix} "
|
||||||
|
"must have a PropagationIOHandler attached"
|
||||||
|
)
|
||||||
|
if self.io is not None:
|
||||||
|
return self.io.load_data(self.path)
|
||||||
|
else:
|
||||||
|
return Path(self.path).read_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
def unique_name(base_name: str, existing: set[str]) -> str:
|
||||||
|
name = base_name
|
||||||
|
p = Path(base_name)
|
||||||
|
base = p.stem
|
||||||
|
ext = p.suffix
|
||||||
|
i = 0
|
||||||
|
while name in existing:
|
||||||
|
name = f"{base}_{i}{ext}"
|
||||||
|
i += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
|
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
|
||||||
if len(left_elements) == 0:
|
if len(left_elements) == 0:
|
||||||
left_elements = [""]
|
left_elements = [""]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import datetime as datetime_module
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from math import isnan
|
from math import isnan
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -12,12 +12,10 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator import utils
|
|
||||||
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
from scgenerator.const import MANDATORY_PARAMETERS, __version__
|
||||||
from scgenerator.evaluator import Evaluator, EvaluatorError
|
from scgenerator.evaluator import Evaluator, EvaluatorError
|
||||||
from scgenerator.io import DatetimeEncoder, decode_datetime_hook
|
from scgenerator.io import CustomEncoder, DataFile, custom_decode_hook
|
||||||
from scgenerator.operators import Qualifier, SpecOperator
|
from scgenerator.operators import Qualifier, SpecOperator
|
||||||
from scgenerator.utils import update_path_name
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
DISPLAY_INFO = {}
|
DISPLAY_INFO = {}
|
||||||
@@ -77,6 +75,11 @@ def string(name, n):
|
|||||||
raise ValueError(f"{name!r} must not be empty")
|
raise ValueError(f"{name!r} must not be empty")
|
||||||
|
|
||||||
|
|
||||||
|
def low_string(name, n):
|
||||||
|
string(name, n)
|
||||||
|
return n.lower()
|
||||||
|
|
||||||
|
|
||||||
def in_range_excl(_min, _max):
|
def in_range_excl(_min, _max):
|
||||||
@type_checker(float, int)
|
@type_checker(float, int)
|
||||||
def _in_range(name, n):
|
def _in_range(name, n):
|
||||||
@@ -255,9 +258,6 @@ class Parameter:
|
|||||||
del instance._param_dico[self.name]
|
del instance._param_dico[self.name]
|
||||||
|
|
||||||
def __set__(self, instance: Parameters, value):
|
def __set__(self, instance: Parameters, value):
|
||||||
if instance._frozen:
|
|
||||||
raise AttributeError("Parameters instance is frozen and can no longer be modified")
|
|
||||||
|
|
||||||
if isinstance(value, Parameter):
|
if isinstance(value, Parameter):
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
instance._param_dico[self.name] = copy(self.default)
|
instance._param_dico[self.name] = copy(self.default)
|
||||||
@@ -288,9 +288,9 @@ class Parameter:
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
is_value = True
|
is_value = True
|
||||||
if is_value:
|
if is_value:
|
||||||
if self.converter is not None:
|
ret_val = self._validator(self.name, v)
|
||||||
v = self.converter(v)
|
if ret_val is not None:
|
||||||
self._validator(self.name, v)
|
v = ret_val
|
||||||
return is_value, v
|
return is_value, v
|
||||||
|
|
||||||
|
|
||||||
@@ -307,7 +307,6 @@ class Parameters:
|
|||||||
|
|
||||||
# root
|
# root
|
||||||
name: str = Parameter(string, default="no name")
|
name: str = Parameter(string, default="no name")
|
||||||
output_path: Path = Parameter(type_checker(Path), converter=Path)
|
|
||||||
|
|
||||||
# fiber
|
# fiber
|
||||||
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
input_transmission: float = Parameter(in_range_incl(0, 1), default=1.0)
|
||||||
@@ -315,10 +314,10 @@ class Parameters:
|
|||||||
n2: float = Parameter(non_negative(float, int))
|
n2: float = Parameter(non_negative(float, int))
|
||||||
chi3: float = Parameter(non_negative(float, int))
|
chi3: float = Parameter(non_negative(float, int))
|
||||||
loss: str = Parameter(literal("capillary"))
|
loss: str = Parameter(literal("capillary"))
|
||||||
loss_file: str = Parameter(string)
|
loss_file: DataFile = Parameter(DataFile.validate)
|
||||||
effective_mode_diameter: float = Parameter(positive(float, int))
|
effective_mode_diameter: float = Parameter(positive(float, int))
|
||||||
effective_area: float = Parameter(non_negative(float, int))
|
effective_area: float = Parameter(non_negative(float, int))
|
||||||
effective_area_file: str = Parameter(string)
|
effective_area_file: DataFile = Parameter(DataFile.validate)
|
||||||
numerical_aperture: float = Parameter(in_range_excl(0, 1))
|
numerical_aperture: float = Parameter(in_range_excl(0, 1))
|
||||||
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm"))
|
pcf_pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm"))
|
||||||
pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1))
|
pcf_pitch_ratio: float = Parameter(in_range_excl(0, 1))
|
||||||
@@ -326,7 +325,7 @@ class Parameters:
|
|||||||
he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1))
|
he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1))
|
||||||
fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9))
|
fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9))
|
||||||
beta2_coefficients: Iterable[float] = Parameter(num_list)
|
beta2_coefficients: Iterable[float] = Parameter(num_list)
|
||||||
dispersion_file: str = Parameter(string)
|
dispersion_file: DataFile = Parameter(DataFile.validate)
|
||||||
model: str = Parameter(
|
model: str = Parameter(
|
||||||
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
|
literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"),
|
||||||
)
|
)
|
||||||
@@ -345,7 +344,7 @@ class Parameters:
|
|||||||
capillary_nested: int = Parameter(non_negative(int), default=0)
|
capillary_nested: int = Parameter(non_negative(int), default=0)
|
||||||
|
|
||||||
# gas
|
# gas
|
||||||
gas_name: str = Parameter(string, converter=str.lower, default="vacuum")
|
gas_name: str = Parameter(low_string, default="vacuum")
|
||||||
pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
pressure: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||||
pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
pressure_in: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||||
pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
pressure_out: float = Parameter(non_negative(float, int), display_info=(1e-5, "bar"))
|
||||||
@@ -353,7 +352,7 @@ class Parameters:
|
|||||||
plasma_density: float = Parameter(non_negative(float, int), default=0)
|
plasma_density: float = Parameter(non_negative(float, int), default=0)
|
||||||
|
|
||||||
# pulse
|
# pulse
|
||||||
field_file: str = Parameter(string)
|
field_file: DataFile = Parameter(DataFile.validate)
|
||||||
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
|
input_time: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
|
input_field: np.ndarray = Parameter(type_checker(np.ndarray))
|
||||||
repetition_rate: float = Parameter(
|
repetition_rate: float = Parameter(
|
||||||
@@ -380,11 +379,9 @@ class Parameters:
|
|||||||
# simulation
|
# simulation
|
||||||
full_field: bool = Parameter(boolean, default=False)
|
full_field: bool = Parameter(boolean, default=False)
|
||||||
integration_scheme: str = Parameter(
|
integration_scheme: str = Parameter(
|
||||||
literal("erk43", "erk54", "cqe", "sd", "constant"),
|
literal("erk43", "erk54", "cqe", "sd", "constant"), default="erk43"
|
||||||
converter=str.lower,
|
|
||||||
default="erk43",
|
|
||||||
)
|
)
|
||||||
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"), converter=str.lower)
|
raman_type: str = Parameter(literal("measured", "agrawal", "stolen"))
|
||||||
raman_fraction: float = Parameter(non_negative(float, int))
|
raman_fraction: float = Parameter(non_negative(float, int))
|
||||||
spm: bool = Parameter(boolean, default=True)
|
spm: bool = Parameter(boolean, default=True)
|
||||||
repeat: int = Parameter(positive(int), default=1)
|
repeat: int = Parameter(positive(int), default=1)
|
||||||
@@ -437,7 +434,8 @@ class Parameters:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, s: str) -> Parameters:
|
def from_json(cls, s: str) -> Parameters:
|
||||||
return cls(**json.loads(s, object_hook=decode_datetime_hook))
|
decoded = json.loads(s, object_hook=custom_decode_hook)
|
||||||
|
return cls(**decoded)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: os.PathLike) -> Parameters:
|
def load(cls, path: os.PathLike) -> Parameters:
|
||||||
@@ -462,18 +460,32 @@ class Parameters:
|
|||||||
self.__post_init__()
|
self.__post_init__()
|
||||||
|
|
||||||
def __setattr__(self, k, v):
|
def __setattr__(self, k, v):
|
||||||
if self._frozen:
|
if self._frozen and not k.endswith("_file"):
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
f"cannot set attribute to frozen {self.__class__.__name__} instance"
|
f"cannot set attribute to frozen {self.__class__.__name__} instance"
|
||||||
)
|
)
|
||||||
object.__setattr__(self, k, v)
|
object.__setattr__(self, k, v)
|
||||||
|
|
||||||
def copy(self) -> Parameters:
|
def items(self) -> Iterator[tuple[str, Any]]:
|
||||||
return Parameters(**deepcopy(self.strip_params_dict()))
|
for k, v in self._param_dico.items():
|
||||||
|
if v is None:
|
||||||
|
continue
|
||||||
|
yield k, v
|
||||||
|
|
||||||
|
def copy(self, freeze: bool = False) -> Parameters:
|
||||||
|
"""create a deep copy of self. if freeze is True, the returned copy is read-only"""
|
||||||
|
params = Parameters(**deepcopy(self.strip_params_dict()))
|
||||||
|
if freeze:
|
||||||
|
params.freeze()
|
||||||
|
return params
|
||||||
|
|
||||||
|
def freeze(self):
|
||||||
|
"""render the current instance read-only. This is not reversible"""
|
||||||
|
self._frozen = True
|
||||||
|
|
||||||
def to_json(self) -> str:
|
def to_json(self) -> str:
|
||||||
d = self.dump_dict()
|
d = self.dump_dict()
|
||||||
return json.dumps(d, cls=DatetimeEncoder, default=list)
|
return json.dumps(d, cls=CustomEncoder, indent=4)
|
||||||
|
|
||||||
def get_evaluator(self):
|
def get_evaluator(self):
|
||||||
evaluator = Evaluator.default(self.full_field)
|
evaluator = Evaluator.default(self.full_field)
|
||||||
@@ -611,7 +623,7 @@ class Parameters:
|
|||||||
"linear_op",
|
"linear_op",
|
||||||
"c_to_a_factor",
|
"c_to_a_factor",
|
||||||
}
|
}
|
||||||
types = (np.ndarray, float, int, str, list, tuple, Path)
|
types = (np.ndarray, float, int, str, list, tuple, Path, DataFile)
|
||||||
c = deepcopy if copy else lambda x: x
|
c = deepcopy if copy else lambda x: x
|
||||||
out = {}
|
out = {}
|
||||||
for key, value in self._param_dico.items():
|
for key, value in self._param_dico.items():
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from io import BytesIO
|
||||||
from typing import Iterable, TypeVar
|
from typing import Iterable, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -7,6 +8,7 @@ from numpy.polynomial.chebyshev import Chebyshev, cheb2poly
|
|||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
|
||||||
from scgenerator import io
|
from scgenerator import io
|
||||||
|
from scgenerator.io import DataFile
|
||||||
from scgenerator.math import argclosest, u_nm
|
from scgenerator.math import argclosest, u_nm
|
||||||
from scgenerator.physics import materials as mat
|
from scgenerator.physics import materials as mat
|
||||||
from scgenerator.physics import units
|
from scgenerator.physics import units
|
||||||
@@ -653,7 +655,7 @@ def saitoh_paramters(pcf_pitch_ratio: float) -> tuple[float, float]:
|
|||||||
return A, B
|
return A, B
|
||||||
|
|
||||||
|
|
||||||
def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.ndarray:
|
def load_custom_effective_area(effective_area_file: DataFile, l: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
loads custom effective area file
|
loads custom effective area file
|
||||||
|
|
||||||
@@ -669,14 +671,14 @@ def load_custom_effective_area(effective_area_file: str, l: np.ndarray) -> np.nd
|
|||||||
np.ndarray, shape (n,)
|
np.ndarray, shape (n,)
|
||||||
wl-dependent effective mode field area
|
wl-dependent effective mode field area
|
||||||
"""
|
"""
|
||||||
data = np.load(effective_area_file)
|
data = np.load(BytesIO(effective_area_file.load_data()))
|
||||||
effective_area = data.get("A_eff", data.get("effective_area"))
|
effective_area = data.get("A_eff", data.get("effective_area"))
|
||||||
wl = data["wavelength"]
|
wl = data["wavelength"]
|
||||||
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
|
return interp1d(wl, effective_area, fill_value=1, bounds_error=False)(l)
|
||||||
|
|
||||||
|
|
||||||
def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray]:
|
def load_custom_dispersion(dispersion_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||||
disp_file = np.load(dispersion_file)
|
disp_file = np.load(BytesIO(dispersion_file.load_data()))
|
||||||
wl_for_disp = disp_file["wavelength"]
|
wl_for_disp = disp_file["wavelength"]
|
||||||
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
|
interp_range = (np.min(wl_for_disp), np.max(wl_for_disp))
|
||||||
D = disp_file["dispersion"]
|
D = disp_file["dispersion"]
|
||||||
@@ -684,7 +686,7 @@ def load_custom_dispersion(dispersion_file: str) -> tuple[np.ndarray, np.ndarray
|
|||||||
return wl_for_disp, beta2, interp_range
|
return wl_for_disp, beta2, interp_range
|
||||||
|
|
||||||
|
|
||||||
def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
|
def load_custom_loss(l: np.ndarray, loss_file: DataFile) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
loads a npz loss file that contains a wavelength and a loss entry
|
loads a npz loss file that contains a wavelength and a loss entry
|
||||||
|
|
||||||
@@ -700,7 +702,7 @@ def load_custom_loss(l: np.ndarray, loss_file: str) -> np.ndarray:
|
|||||||
np.ndarray, shape (n,)
|
np.ndarray, shape (n,)
|
||||||
loss in 1/m units
|
loss in 1/m units
|
||||||
"""
|
"""
|
||||||
loss_data = np.load(loss_file)
|
loss_data = np.load(BytesIO(loss_file.load_data()))
|
||||||
wl = loss_data["wavelength"]
|
wl = loss_data["wavelength"]
|
||||||
loss = loss_data["loss"]
|
loss = loss_data["loss"]
|
||||||
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
return interp1d(wl, loss, fill_value=0, bounds_error=False)(l)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ n is the number of spectra at the same z position and nt is the size of the time
|
|||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Tuple, TypeVar
|
from typing import Literal, Tuple, TypeVar
|
||||||
|
|
||||||
@@ -25,6 +26,7 @@ from scipy.optimize._optimize import OptimizeResult
|
|||||||
|
|
||||||
from scgenerator import math
|
from scgenerator import math
|
||||||
from scgenerator.defaults import default_plotting
|
from scgenerator.defaults import default_plotting
|
||||||
|
from scgenerator.io import DataFile
|
||||||
from scgenerator.physics import units
|
from scgenerator.physics import units
|
||||||
|
|
||||||
c = 299792458.0
|
c = 299792458.0
|
||||||
@@ -410,8 +412,9 @@ def interp_custom_field(
|
|||||||
return field_0
|
return field_0
|
||||||
|
|
||||||
|
|
||||||
def load_custom_field(field_file: str) -> tuple[np.ndarray, np.ndarray]:
|
def load_custom_field(field_file: DataFile) -> tuple[np.ndarray, np.ndarray]:
|
||||||
field_data = np.load(field_file)
|
data = field_file.load_data()
|
||||||
|
field_data = np.load(BytesIO(data))
|
||||||
return field_data["time"], field_data["field"]
|
return field_data["time"], field_data["field"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +1,32 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Iterator, Optional, Union
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from scgenerator import math
|
from scgenerator import math
|
||||||
from scgenerator.const import PARAM_FN, SPEC1_FN, SPEC1_FN_N
|
from scgenerator.io import DataFile, PropagationIOHandler, ZipFileIOHandler, unique_name
|
||||||
from scgenerator.logger import get_logger
|
|
||||||
from scgenerator.parameter import Parameters
|
from scgenerator.parameter import Parameters
|
||||||
from scgenerator.physics import pulse, units
|
from scgenerator.physics import pulse, units
|
||||||
from scgenerator.physics.units import PlotRange
|
|
||||||
from scgenerator.plotting import (
|
|
||||||
mean_values_plot,
|
|
||||||
propagation_plot,
|
|
||||||
single_position_plot,
|
|
||||||
transform_1D_values,
|
|
||||||
transform_2D_propagation,
|
|
||||||
)
|
|
||||||
from scgenerator.utils import load_spectrum, load_toml, simulations_list
|
|
||||||
|
|
||||||
|
|
||||||
class Spectrum(np.ndarray):
|
class Spectrum(np.ndarray):
|
||||||
params: Parameters
|
w: np.ndarray
|
||||||
|
t: np.ndarray
|
||||||
|
l: np.ndarray
|
||||||
|
ifft: Callable[[np.ndarray], np.ndarray]
|
||||||
|
|
||||||
def __new__(cls, input_array, params: Parameters):
|
def __new__(cls, input_array, params: Parameters):
|
||||||
# Input array is an already formed ndarray instance
|
# Input array is an already formed ndarray instance
|
||||||
# We first cast to be our class type
|
# We first cast to be our class type
|
||||||
obj = np.asarray(input_array).view(cls)
|
obj = np.asarray(input_array).view(cls)
|
||||||
# add the new attribute to the created instance
|
# add the new attribute to the created instance
|
||||||
obj.params = params
|
obj.w = params.compute("w")
|
||||||
|
obj.t = params.compute("t")
|
||||||
|
obj.t = params.compute("t")
|
||||||
|
obj.ifft = params.compute("ifft")
|
||||||
|
|
||||||
# Finally, we must return the newly created object:
|
# Finally, we must return the newly created object:
|
||||||
return obj
|
return obj
|
||||||
@@ -40,14 +35,17 @@ class Spectrum(np.ndarray):
|
|||||||
# see InfoArray.__array_finalize__ for comments
|
# see InfoArray.__array_finalize__ for comments
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return
|
return
|
||||||
self.params = getattr(obj, "params", None)
|
self.w = getattr(obj, "w", None)
|
||||||
|
self.t = getattr(obj, "t", None)
|
||||||
|
self.l = getattr(obj, "l", None)
|
||||||
|
self.ifft = getattr(obj, "ifft", None)
|
||||||
|
|
||||||
def __getitem__(self, key) -> "Spectrum":
|
def __getitem__(self, key) -> "Spectrum":
|
||||||
return super().__getitem__(key)
|
return super().__getitem__(key)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def wl_int(self):
|
def wl_int(self):
|
||||||
return units.to_WL(math.abs2(self), self.params.l)
|
return units.to_WL(math.abs2(self), self.l)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def freq_int(self):
|
def freq_int(self):
|
||||||
@@ -59,13 +57,13 @@ class Spectrum(np.ndarray):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def time_int(self):
|
def time_int(self):
|
||||||
return math.abs2(self.params.ifft(self))
|
return math.abs2(self.ifft(self))
|
||||||
|
|
||||||
def amplitude(self, unit):
|
def amplitude(self, unit):
|
||||||
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||||
x_axis = unit.inv(self.params.w)
|
x_axis = unit.inv(self.w)
|
||||||
else:
|
else:
|
||||||
x_axis = unit.inv(self.params.t)
|
x_axis = unit.inv(self.t)
|
||||||
|
|
||||||
order = np.argsort(x_axis)
|
order = np.argsort(x_axis)
|
||||||
func = dict(
|
func = dict(
|
||||||
@@ -84,7 +82,7 @@ class Spectrum(np.ndarray):
|
|||||||
np.sqrt(
|
np.sqrt(
|
||||||
units.to_WL(
|
units.to_WL(
|
||||||
math.abs2(self),
|
math.abs2(self),
|
||||||
self.params.l,
|
self.l,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
* self
|
* self
|
||||||
@@ -106,311 +104,115 @@ class Spectrum(np.ndarray):
|
|||||||
@property
|
@property
|
||||||
def wl_max(self):
|
def wl_max(self):
|
||||||
if self.ndim == 1:
|
if self.ndim == 1:
|
||||||
return self.params.l[np.argmax(self.wl_int, axis=-1)]
|
return self.l[np.argmax(self.wl_int, axis=-1)]
|
||||||
return np.array([s.wl_max for s in self])
|
return np.array([s.wl_max for s in self])
|
||||||
|
|
||||||
def mask_wl(self, pos: float, width: float) -> Spectrum:
|
def mask_wl(self, pos: float, width: float) -> Spectrum:
|
||||||
return self * np.exp(
|
return self * np.exp(-(((self.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2))
|
||||||
-(((self.params.l - pos) / (pulse.fwhm_to_T0_fac["gaussian"] * width)) ** 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
def measure(self) -> tuple[float, float, float]:
|
def measure(self) -> tuple[float, float, float]:
|
||||||
return pulse.measure_field(self.params.t, self.time_amp)
|
return pulse.measure_field(self.t, self.time_amp)
|
||||||
|
|
||||||
|
|
||||||
class SimulationSeries:
|
class Propagation:
|
||||||
"""
|
io: PropagationIOHandler
|
||||||
SimulationsSeries are the interface the user should use to load and
|
|
||||||
interact with simulation data. The object loads each fiber of the simulation
|
|
||||||
into a separate object and exposes convenience methods to make the series behave
|
|
||||||
as a single fiber.
|
|
||||||
|
|
||||||
It should be noted that the last spectrum of a fiber and the first one of the next
|
|
||||||
fibers are identical. Therefore, SimulationSeries object will return fewer datapoints
|
|
||||||
than when manually mergin the corresponding data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
path: Path
|
|
||||||
fibers: list[SimulatedFiber]
|
|
||||||
params: Parameters
|
params: Parameters
|
||||||
z_indices: list[tuple[int, int]]
|
_current_index: int
|
||||||
fiber_positions: list[tuple[str, float]]
|
|
||||||
|
|
||||||
def __init__(self, path: os.PathLike):
|
PARAMS_FN = "params.json"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
io_handler: PropagationIOHandler,
|
||||||
|
params: Parameters | None = None,
|
||||||
|
bundle_data: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create a SimulationSeries
|
A propagation is the object that manages IO for one single propagation.
|
||||||
|
It is recommended to use the "memory_propagation" and "zip_propagation" convenience
|
||||||
|
factories instead of creating this object directly.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : os.PathLike
|
io : PropagationIOHandler
|
||||||
path to the last fiber of the series
|
object that implements the PropagationIOHandler Protocol.
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
No simulation found in specified directory
|
|
||||||
"""
|
|
||||||
self.logger = get_logger()
|
|
||||||
for self.path in simulations_list(path):
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"No simulation in {path}")
|
|
||||||
self.fibers = [SimulatedFiber(self.path)]
|
|
||||||
while (p := self.fibers[-1].params.prev_data_dir) is not None:
|
|
||||||
p = Path(p)
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(self.fibers[-1].params.output_path) / p
|
|
||||||
self.fibers.append(SimulatedFiber(p))
|
|
||||||
self.fibers = self.fibers[::-1]
|
|
||||||
|
|
||||||
self.fiber_positions = [(self.fibers[0].params.name, 0.0)]
|
|
||||||
self.params = Parameters(**self.fibers[0].params.dump_dict(False, False))
|
|
||||||
z_targets = list(self.params.z_targets)
|
|
||||||
self.z_indices = [(0, j) for j in range(self.params.z_num)]
|
|
||||||
for i, fiber in enumerate(self.fibers[1:]):
|
|
||||||
self.fiber_positions.append((fiber.params.name, z_targets[-1]))
|
|
||||||
z_targets += list(fiber.params.z_targets[1:] + z_targets[-1])
|
|
||||||
self.z_indices += [(i + 1, j) for j in range(1, fiber.params.z_num)]
|
|
||||||
self.params.z_targets = np.array(z_targets)
|
|
||||||
self.params.length = self.params.z_targets[-1]
|
|
||||||
self.params.z_num = len(self.params.z_targets)
|
|
||||||
|
|
||||||
def spectra(
|
|
||||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
|
||||||
) -> Spectrum:
|
|
||||||
...
|
|
||||||
if z_descr is None:
|
|
||||||
out = [self.fibers[i].spectra(j, sim_ind) for i, j in self.z_indices]
|
|
||||||
else:
|
|
||||||
if isinstance(z_descr, (float, np.floating)):
|
|
||||||
fib_ind, z_ind = self.z_ind(z_descr)
|
|
||||||
else:
|
|
||||||
fib_ind, z_ind = self.z_indices[z_descr]
|
|
||||||
out = self.fibers[fib_ind].spectra(z_ind, sim_ind)
|
|
||||||
return Spectrum(out, self.params)
|
|
||||||
|
|
||||||
def fields(
|
|
||||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
|
||||||
) -> Spectrum:
|
|
||||||
return self.params.ifft(self.spectra(z_descr, sim_ind))
|
|
||||||
|
|
||||||
def z_ind(self, pos: float) -> tuple[int, int]:
|
|
||||||
if 0 <= pos <= self.params.length:
|
|
||||||
ind = np.argmin(np.abs(self.params.z_targets - pos))
|
|
||||||
return self.z_indices[ind]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
|
||||||
|
|
||||||
def plot_2D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
|
|
||||||
return propagation_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def plot_values_2D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, None, sim_ind)
|
|
||||||
return transform_2D_propagation(vals, plot_range, self.params, **kwargs)
|
|
||||||
|
|
||||||
def plot_1D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
z_pos: int,
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
|
||||||
return single_position_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def plot_values_1D(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
z_pos: int,
|
|
||||||
sim_ind: int = 0,
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
gives the desired values already tranformes according to the give range
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
left : float
|
|
||||||
leftmost limit in unit
|
|
||||||
right : float
|
|
||||||
rightmost limit in unit
|
|
||||||
unit : Union[Callable[[float], float], str]
|
|
||||||
unit
|
|
||||||
z_pos : Union[int, float]
|
|
||||||
position either as an index (int) or a real position (float)
|
|
||||||
sim_ind : Optional[int]
|
|
||||||
which simulation to take when more than one are present
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
x axis
|
|
||||||
np.ndarray
|
|
||||||
y values
|
|
||||||
"""
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, sim_ind)
|
|
||||||
return transform_1D_values(vals, plot_range, self.params, **kwargs)
|
|
||||||
|
|
||||||
def plot_mean(
|
|
||||||
self,
|
|
||||||
left: float,
|
|
||||||
right: float,
|
|
||||||
unit: Union[Callable[[float], float], str],
|
|
||||||
ax: plt.Axes,
|
|
||||||
z_pos: int,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
plot_range = PlotRange(left, right, unit)
|
|
||||||
vals = self.retrieve_plot_values(plot_range, z_pos, None)
|
|
||||||
return mean_values_plot(vals, plot_range, self.params, ax, **kwargs)
|
|
||||||
|
|
||||||
def retrieve_plot_values(
|
|
||||||
self, plot_range: PlotRange, z_pos: Optional[Union[int, float]], sim_ind: Optional[int]
|
|
||||||
):
|
|
||||||
if plot_range.unit.type == "TIME":
|
|
||||||
return self.fields(z_pos, sim_ind)
|
|
||||||
else:
|
|
||||||
return self.spectra(z_pos, sim_ind)
|
|
||||||
|
|
||||||
def rin_propagation(
|
|
||||||
self, left: float, right: float, unit: str
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
returns the RIN as function of unit and z
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
left : float
|
|
||||||
left limit in unit
|
|
||||||
right : float
|
|
||||||
right limit in unit
|
|
||||||
unit : str
|
|
||||||
unit descriptor
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
x : np.ndarray, shape (nt,)
|
|
||||||
x axis
|
|
||||||
y : np.ndarray, shape (z_num, )
|
|
||||||
y axis
|
|
||||||
rin_prop : np.ndarray, shape (z_num, nt)
|
|
||||||
RIN
|
|
||||||
"""
|
|
||||||
spectra = []
|
|
||||||
for spec in np.moveaxis(self.spectra(None, None), 1, 0):
|
|
||||||
x, z, tmp = transform_2D_propagation(spec, (left, right, unit), self.params, False)
|
|
||||||
spectra.append(tmp)
|
|
||||||
return x, z, pulse.rin_curve(np.moveaxis(spectra, 0, 1))
|
|
||||||
|
|
||||||
# Magic methods
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Spectrum]:
|
|
||||||
for i, j in self.z_indices:
|
|
||||||
yield self.fibers[i].spectra(j, None)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"{self.__class__.__name__}(path={self.path})"
|
|
||||||
|
|
||||||
def __eq__(self, other: SimulationSeries) -> bool:
|
|
||||||
return self.path == other.path and self.params == other.params
|
|
||||||
|
|
||||||
def __contains__(self, fiber: SimulatedFiber) -> bool:
|
|
||||||
return fiber in self.fibers
|
|
||||||
|
|
||||||
def __getitem__(self, key) -> Spectrum:
|
|
||||||
if isinstance(key, tuple):
|
|
||||||
return self.spectra(*key)
|
|
||||||
else:
|
|
||||||
return self.spectra(key, None)
|
|
||||||
|
|
||||||
|
|
||||||
class SimulatedFiber:
|
|
||||||
params : Parameters
|
params : Parameters
|
||||||
t: np.ndarray
|
simulations parameters. Those will be saved via the
|
||||||
w: np.ndarray
|
|
||||||
|
|
||||||
def __init__(self, path: os.PathLike):
|
|
||||||
self.path = Path(path)
|
|
||||||
self.params = Parameters(**load_toml(self.path / PARAM_FN))
|
|
||||||
self.params.output_path = str(self.path.resolve())
|
|
||||||
self.t = self.params.t
|
|
||||||
self.w = self.params.w
|
|
||||||
self.z = self.params.z_targets
|
|
||||||
|
|
||||||
def spectra(
|
|
||||||
self, z_descr: Union[float, int, None] = None, sim_ind: Optional[int] = 0
|
|
||||||
) -> np.ndarray:
|
|
||||||
if z_descr is None:
|
|
||||||
out = [self.spectra(i, sim_ind) for i in range(self.params.z_num)]
|
|
||||||
else:
|
|
||||||
if isinstance(z_descr, (float, np.floating)):
|
|
||||||
return self.spectra(self.z_ind(z_descr), sim_ind)
|
|
||||||
else:
|
|
||||||
z_ind = z_descr
|
|
||||||
|
|
||||||
if z_ind < 0:
|
|
||||||
z_ind = self.params.z_num + z_ind
|
|
||||||
|
|
||||||
if sim_ind is None:
|
|
||||||
out = [self._load_1(z_ind, i) for i in range(self.params.repeat)]
|
|
||||||
else:
|
|
||||||
out = self._load_1(z_ind)
|
|
||||||
return Spectrum(out, self.params)
|
|
||||||
|
|
||||||
def z_ind(self, pos: float) -> int:
|
|
||||||
if 0 <= pos <= self.z[-1]:
|
|
||||||
return np.argmin(np.abs(self.z - pos))
|
|
||||||
else:
|
|
||||||
raise ValueError(f"cannot match z={pos} with max length of {self.params.length}")
|
|
||||||
|
|
||||||
def _load_1(self, z_ind: int, sim_ind=0) -> np.ndarray:
|
|
||||||
"""
|
"""
|
||||||
loads a spectrum file
|
self.io = io_handler
|
||||||
|
self._current_index = len(self.io)
|
||||||
|
|
||||||
Parameters
|
new_params = params is not None
|
||||||
----------
|
if not new_params:
|
||||||
z_ind : int
|
if bundle_data:
|
||||||
z_index relative to the entire simulation
|
raise ValueError(
|
||||||
sim_ind : int, optional
|
"cannot bundle data to existing Propagation. Create a new one instead"
|
||||||
simulation index, used when repeated simulations with same parameters are ran,
|
)
|
||||||
by default 0
|
try:
|
||||||
|
params_data = self.io.load_data(self.PARAMS_FN)
|
||||||
|
params = Parameters.from_json(params_data.decode())
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"Missing Parameters in {self.__class__.__name__}.") from None
|
||||||
|
|
||||||
Returns
|
self.params = params
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
loaded spectrum file
|
|
||||||
"""
|
|
||||||
if sim_ind > 0:
|
|
||||||
return load_spectrum(self.path / SPEC1_FN_N.format(z_ind, sim_ind))
|
|
||||||
else:
|
|
||||||
return load_spectrum(self.path / SPEC1_FN.format(z_ind))
|
|
||||||
psd = np.fft.rfft(signal) / np.sqrt(0.5 * len(time) / dt)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
if bundle_data:
|
||||||
return f"{self.__class__.__name__}(path={self.path})"
|
if not params._frozen:
|
||||||
|
warnings.warn(
|
||||||
|
f"Parameters instance {params.name!r} is not frozen but will be saved "
|
||||||
|
"in an unmodifiable state anyway."
|
||||||
|
)
|
||||||
|
_bundle_external_files(self.params, self.io)
|
||||||
|
|
||||||
|
if new_params:
|
||||||
|
self.io.save_data(self.PARAMS_FN, self.params.to_json().encode())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._current_index
|
||||||
|
|
||||||
|
def __getitem__(self, key: int | slice) -> Spectrum:
|
||||||
|
if isinstance(key, slice):
|
||||||
|
return self._load_slice(key)
|
||||||
|
if isinstance(key, (float, np.floating)):
|
||||||
|
key = math.argclosest(self.params.compute("z_targets"), key)
|
||||||
|
array = self.io.load_spectrum(key)
|
||||||
|
return Spectrum(array, self.params)
|
||||||
|
|
||||||
|
def __setitem__(self, key: int, value: np.ndarray):
|
||||||
|
if not isinstance(key, int):
|
||||||
|
raise TypeError(f"Cannot save a spectrum at index {key!r} of type {type(key)!r}")
|
||||||
|
self.io.save_spectrum(key, np.asarray(value))
|
||||||
|
|
||||||
|
def _load_slice(self, key: slice) -> Spectrum:
|
||||||
|
_iter = range(len(self))[key]
|
||||||
|
out = Spectrum(np.zeros((len(_iter), self.params.t_num), dtype=complex), self.params)
|
||||||
|
for i in _iter:
|
||||||
|
out[i] = self.io.load_spectrum(i)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def append(self, spectrum: np.ndarray):
|
||||||
|
self.io.save_spectrum(self._current_index, spectrum.asarray())
|
||||||
|
self._current_index += 1
|
||||||
|
|
||||||
|
def load_all(self) -> Spectrum:
|
||||||
|
return self._load_slice(slice(None))
|
||||||
|
|
||||||
|
|
||||||
|
def _bundle_external_files(params: Parameters, io: PropagationIOHandler):
|
||||||
|
"""copies every external file specified in the parameters and saves it"""
|
||||||
|
existing_files = set(io.keys())
|
||||||
|
for _, value in params.items():
|
||||||
|
if isinstance(value, DataFile):
|
||||||
|
data = value.load_data()
|
||||||
|
|
||||||
|
value.io = io
|
||||||
|
value.prefix = "zip"
|
||||||
|
value.path = unique_name(Path(value.path).name, existing_files)
|
||||||
|
existing_files.add(value.path)
|
||||||
|
|
||||||
|
io.save_data(value.path, data)
|
||||||
|
|
||||||
|
|
||||||
|
def load_all(path: os.PathLike) -> Spectrum:
|
||||||
|
io = ZipFileIOHandler(path)
|
||||||
|
return Propagation(io).load_all()
|
||||||
|
|||||||
BIN
tests/data/PM2000D_2 extrapolated 4 0.npz
Normal file
BIN
tests/data/PM2000D_2 extrapolated 4 0.npz
Normal file
Binary file not shown.
BIN
tests/data/PM2000D_A_eff_marcuse.npz
Normal file
BIN
tests/data/PM2000D_A_eff_marcuse.npz
Normal file
Binary file not shown.
124
tests/test_io_handlers.py
Normal file
124
tests/test_io_handlers.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scgenerator.io import DataFile, MemoryIOHandler, ZipFileIOHandler, unique_name
|
||||||
|
from scgenerator.parameter import Parameters
|
||||||
|
from scgenerator.spectra import Propagation
|
||||||
|
|
||||||
|
PARAMS = dict(
|
||||||
|
name="PM2000D sech pulse",
|
||||||
|
# pulse
|
||||||
|
wavelength=1546e-9,
|
||||||
|
# field_file="./Pos30000New.npz",
|
||||||
|
width=1.5e-12,
|
||||||
|
shape="sech",
|
||||||
|
repetition_rate=40e6,
|
||||||
|
# fiber
|
||||||
|
dispersion_file="./PM2000D_2 extrapolated 4 0.npz",
|
||||||
|
effective_area_file="./PM2000D_A_eff_marcuse.npz",
|
||||||
|
wavelength_window=(400e-9, 4000e-9),
|
||||||
|
n2=4.5e-20,
|
||||||
|
# simulation
|
||||||
|
raman_type="measured",
|
||||||
|
quantum_noise=True,
|
||||||
|
interpolation_degree=11,
|
||||||
|
z_num=128,
|
||||||
|
length=1.5,
|
||||||
|
t_num=512,
|
||||||
|
dt=5e-15,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_file(tmp_path: Path):
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
stuff = np.random.rand(8, 512)
|
||||||
|
io = ZipFileIOHandler(tmp_path / "test.zip")
|
||||||
|
io.save_data("params.json", params.to_json().encode())
|
||||||
|
for i, spec in enumerate(stuff):
|
||||||
|
io.save_spectrum(i, spec)
|
||||||
|
|
||||||
|
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||||
|
assert new_params is not params
|
||||||
|
for k, v in params.items():
|
||||||
|
v_new = getattr(new_params, k)
|
||||||
|
if isinstance(v, DataFile):
|
||||||
|
assert Path(v.path).name == Path(v_new.path).name
|
||||||
|
else:
|
||||||
|
assert v == getattr(new_params, k)
|
||||||
|
|
||||||
|
for i in range(8):
|
||||||
|
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||||
|
|
||||||
|
assert len(ZipFileIOHandler(tmp_path / "test.zip")) == len(io) == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory():
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
stuff = np.random.rand(8, 512)
|
||||||
|
io = MemoryIOHandler()
|
||||||
|
assert len(io) == 0
|
||||||
|
io.save_data("params.json", params.to_json().encode())
|
||||||
|
for i, spec in enumerate(stuff):
|
||||||
|
io.save_spectrum(i, spec)
|
||||||
|
|
||||||
|
new_params = Parameters.from_json(io.load_data("params.json").decode())
|
||||||
|
for k, v in params.items():
|
||||||
|
v_new = getattr(new_params, k)
|
||||||
|
if isinstance(v, DataFile):
|
||||||
|
assert Path(v.path).name == Path(v_new.path).name
|
||||||
|
else:
|
||||||
|
assert v == getattr(new_params, k)
|
||||||
|
for i in range(8):
|
||||||
|
assert np.all(io.load_spectrum(i) == stuff[i])
|
||||||
|
|
||||||
|
assert len(io) == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_zip_bundle(tmp_path: Path):
|
||||||
|
params = Parameters(**PARAMS)
|
||||||
|
io = ZipFileIOHandler(tmp_path / "file.zip")
|
||||||
|
prop = Propagation(io, params.copy(True))
|
||||||
|
|
||||||
|
assert (tmp_path / "file.zip").exists()
|
||||||
|
assert (tmp_path / "file.zip").read_bytes() != b""
|
||||||
|
|
||||||
|
new_disp_path = Path("./tests/data/PM2000D_2 extrapolated 4 0.npz")
|
||||||
|
new_aeff_path = Path("./tests/data/PM2000D_A_eff_marcuse.npz")
|
||||||
|
params.dispersion_file.path = str(new_disp_path)
|
||||||
|
params.effective_area_file.path = str(new_aeff_path)
|
||||||
|
params.freeze()
|
||||||
|
|
||||||
|
io = ZipFileIOHandler(tmp_path / "file2.zip")
|
||||||
|
prop2 = Propagation(io, params, True)
|
||||||
|
|
||||||
|
assert params.dispersion_file.path == new_disp_path.name
|
||||||
|
assert params.dispersion_file.prefix == "zip"
|
||||||
|
assert params.effective_area_file.path == new_aeff_path.name
|
||||||
|
assert params.effective_area_file.prefix == "zip"
|
||||||
|
|
||||||
|
with ZipFile(tmp_path / "file2.zip", "r") as zfile:
|
||||||
|
with zfile.open(new_aeff_path.name) as file:
|
||||||
|
assert file.read() == new_aeff_path.read_bytes()
|
||||||
|
with zfile.open(new_disp_path.name) as file:
|
||||||
|
assert file.read() == new_disp_path.read_bytes()
|
||||||
|
|
||||||
|
with zfile.open(Propagation.PARAMS_FN) as file:
|
||||||
|
df = json.loads(file.read().decode())
|
||||||
|
|
||||||
|
assert (
|
||||||
|
df["dispersion_file"] == params.dispersion_file.prefix + "::" + params.dispersion_file.path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
df["effective_area_file"]
|
||||||
|
== params.effective_area_file.prefix + "::" + params.effective_area_file.path
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unique_name():
|
||||||
|
existing = {"spec.npy", "spec_0.npy"}
|
||||||
|
assert unique_name("spec.npy", existing) == "spec_1.npy"
|
||||||
Reference in New Issue
Block a user