Trying to make Local Error method work, no success

This commit is contained in:
Benoît Sierro
2021-11-09 09:53:33 +01:00
parent 0d5d529ba3
commit f757466fe1
6 changed files with 427 additions and 135 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Union
import numpy as np
from . import math, operators, utils
from . import math, operators, utils, solver
from .const import MANDATORY_PARAMETERS
from .errors import EvaluatorError, NoDefaultError
from .physics import fiber, materials, pulse, units
@@ -377,6 +377,8 @@ default_rules: list[Rule] = [
Rule("loss_op", operators.NoLoss, priorities=-1),
Rule("plasma_op", operators.NoPlasma, priorities=-1),
Rule("conserved_quantity", operators.NoConservedQuantity, priorities=-1),
Rule("step_taker", solver.RK4IPStepTaker),
Rule("integrator", solver.ConstantStepIntegrator, priorities=-1),
]
envelope_rules = default_rules + [
@@ -417,6 +419,7 @@ envelope_rules = default_rules + [
Rule("dispersion_op", operators.DirectDispersion),
Rule("linear_operator", operators.EnvelopeLinearOperator),
Rule("conserved_quantity", operators.conserved_quantity),
Rule("integrator", solver.ConservedQuantityIntegrator),
]
full_field_rules = default_rules + [
@@ -440,4 +443,6 @@ full_field_rules = default_rules + [
operators.FullFieldLinearOperator,
),
Rule("nonlinear_operator", operators.FullFieldNonLinearOperator),
# Integration
Rule("integrator", solver.LocalErrorIntegrator),
]

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass
import re
from typing import Any, Callable
import numpy as np
@@ -23,14 +24,18 @@ class SpectrumDescriptor:
name: str
value: np.ndarray = None
_counter = 0
_full_field: bool = False
_converter: Callable[[np.ndarray], np.ndarray]
def __init__(self, spec2_name: str, field_name: str, field2_name: str):
self.spec2_name = spec2_name
self.field_name = field_name
self.field2_name = field2_name
def __set__(self, instance: CurrentState, value: np.ndarray):
self._counter += 1
instance.spec2 = math.abs2(value)
instance.field = instance.converter(value)
instance.field2 = math.abs2(instance.field)
setattr(instance, self.spec2_name, math.abs2(value))
setattr(instance, self.field_name, instance.converter(value))
setattr(instance, self.field2_name, math.abs2(getattr(instance, self.field_name)))
self.value = value
def __get__(self, instance, owner):
@@ -46,41 +51,96 @@ class SpectrumDescriptor:
self.name = name
class SpectrumDescriptor2:
name: str
spectrum: np.ndarray = None
__spec2: np.ndarray = None
__field: np.ndarray = None
__field2: np.ndarray = None
_converter: Callable[[np.ndarray], np.ndarray]
def __set__(self, instance: CurrentState, value: np.ndarray):
self._converter = instance.converter
self.spectrum = value
self.__spec2 = None
self.__field = None
self.__field2 = None
@property
def spec2(self) -> np.ndarray:
if self.__spec2 is None:
self.__spec2 = math.abs2(self.spectrum)
return self.__spec2
@property
def field(self) -> np.ndarray:
if self.__field is None:
self.__field = self._converter(self.spectrum)
return self.__field
@property
def field2(self) -> np.ndarray:
if self.__field2 is None:
self.__field2 = math.abs2(self.field)
return self.__field2
def __delete__(self, instance):
raise AttributeError("Cannot delete Spectrum field")
def __set_name__(self, owner, name):
self.name = name
@dataclasses.dataclass
class CurrentState:
length: float
z: float
h: float
current_step_size: float
previous_step_size: float
step: int
C_to_A_factor: np.ndarray
converter: Callable[[np.ndarray], np.ndarray] = np.fft.ifft
spectrum: np.ndarray = SpectrumDescriptor()
spectrum: np.ndarray = SpectrumDescriptor("spec2", "field", "field2")
spec2: np.ndarray = dataclasses.field(init=False)
field: np.ndarray = dataclasses.field(init=False)
field2: np.ndarray = dataclasses.field(init=False)
prev_spectrum: np.ndarray = SpectrumDescriptor("prev_spec2", "prev_field", "prev_field2")
prev_spec2: np.ndarray = dataclasses.field(init=False)
prev_field: np.ndarray = dataclasses.field(init=False)
prev_field2: np.ndarray = dataclasses.field(init=False)
@property
def z_ratio(self) -> float:
return self.z / self.length
def replace(self, new_spectrum) -> CurrentState:
return CurrentState(
self.length, self.z, self.h, self.C_to_A_factor, self.converter, new_spectrum
def replace(self, new_spectrum: np.ndarray, **new_params) -> CurrentState:
"""returns a new state with new attributes"""
params = dict(
spectrum=new_spectrum,
length=self.length,
z=self.z,
current_step_size=self.current_step_size,
previous_step_size=self.previous_step_size,
step=self.step,
C_to_A_factor=self.C_to_A_factor,
converter=self.converter,
)
return CurrentState(**(params | new_params))
class Operator(ABC):
class ValueTracker(ABC):
def values(self) -> dict[str, float]:
return {}
def get_values(self) -> dict[str, float]:
def all_values(self) -> dict[str, float]:
out = self.values()
for operator in self.__dict__.values():
if isinstance(operator, Operator):
out |= operator.get_values()
for operator in vars(self).values():
if isinstance(operator, ValueTracker):
out = operator.all_values() | out
return out
def __repr__(self) -> str:
value_pair_list = list(self.__dict__.items())
value_pair_list = list(vars(self).items())
if len(value_pair_list) == 0:
value_pair_str_list = ""
elif len(value_pair_list) == 1:
@@ -95,6 +155,8 @@ class Operator(ABC):
return repr(v[0])
return repr(v)
class Operator(ValueTracker):
@abstractmethod
def __call__(self, state: CurrentState) -> np.ndarray:
pass
@@ -757,7 +819,12 @@ class PhotonNumberLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return pulse.photon_number_with_loss(
state.spec2, self.w, self.dw, self.gamma_op(state), self.loss_op(state), state.h
state.spec2,
self.w,
self.dw,
self.gamma_op(state),
self.loss_op(state),
state.current_step_size,
)
@@ -778,7 +845,10 @@ class EnergyLoss(AbstractConservedQuantity):
def __call__(self, state: CurrentState) -> float:
return pulse.pulse_energy_with_loss(
math.abs2(state.C_to_A_factor * state.spectrum), self.dw, self.loss_op(state), state.h
math.abs2(state.C_to_A_factor * state.spectrum),
self.dw,
self.loss_op(state),
state.current_step_size,
)

View File

@@ -13,18 +13,13 @@ from typing import Any, Callable, ClassVar, Iterable, Iterator, Set, Type, TypeV
import numpy as np
from scgenerator.physics import units
from . import env, legacy, utils
from .const import MANDATORY_PARAMETERS, PARAM_FN, VALID_VARIABLE, __version__
from .errors import EvaluatorError
from .evaluator import Evaluator
from .logger import get_logger
from .operators import (
AbstractConservedQuantity,
LinearOperator,
NonLinearOperator,
)
from .operators import AbstractConservedQuantity, LinearOperator, NonLinearOperator
from .solver import Integrator, StepTaker
from .utils import fiber_folder, update_path_name
from .variationer import VariationDescriptor, Variationer
@@ -382,6 +377,8 @@ class Parameters:
# computed
linear_operator: LinearOperator = Parameter(type_checker(LinearOperator))
nonlinear_operator: NonLinearOperator = Parameter(type_checker(NonLinearOperator))
step_taker: StepTaker = Parameter(type_checker(StepTaker))
integrator: Integrator = Parameter(type_checker(Integrator))
conserved_quantity: AbstractConservedQuantity = Parameter(
type_checker(AbstractConservedQuantity)
)

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
import multiprocessing
import multiprocessing.connection
import os
@@ -13,7 +14,6 @@ from ..logger import get_logger
from ..operators import CurrentState
from ..parameter import Configuration, Parameters
from ..pbar import PBars, ProgressBarActor, progress_worker
from ..const import ONE_2, ONE_3, ONE_6
try:
import ray
@@ -21,6 +21,15 @@ except ModuleNotFoundError:
ray = None
class TrackedValues(defaultdict):
def __init__(self):
super().__init__(list)
def append(self, d: dict[str, Any]):
for k, v in d.items():
self[k].append(v)
class RK4IP:
params: Parameters
save_data: bool
@@ -53,19 +62,7 @@ class RK4IP:
save_data : bool, optional
save calculated spectra to disk, by default False
"""
self.set(params, save_data)
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
yield from self.irun()
def __len__(self) -> int:
return self.params.z_num
def set(
self,
params: Parameters,
save_data=False,
):
self.params = params
self.save_data = save_data
@@ -77,16 +74,12 @@ class RK4IP:
self.logger = get_logger(self.params.output_path.name)
self.dw = self.params.w[1] - self.params.w[0]
self.z_targets = self.params.z_targets
self.error_ok = (
params.tolerated_error if self.params.adapt_step_size else self.params.step_size
)
self._setup_sim_parameters()
def _setup_sim_parameters(self):
# making sure to keep only the z that we want
# setup save targets
self.z_targets = self.params.z_targets
self.z_stored = list(self.z_targets.copy()[0 : self.params.recovery_last_stored + 1])
self.z_targets = list(self.z_targets.copy()[self.params.recovery_last_stored :])
self.z_targets.sort()
@@ -97,16 +90,18 @@ class RK4IP:
C_to_A_factor = (self.params.A_eff_arr / self.params.A_eff_arr[0]) ** (1 / 4)
else:
C_to_A_factor = 1.0
z = self.z_targets.pop(0)
# Initial step size
if self.params.adapt_step_size:
initial_h = (self.z_targets[0] - z) / 2
initial_h = (self.z_targets[1] - self.z_targets[0]) / 2
else:
initial_h = self.error_ok
self.state = CurrentState(
length=self.params.length,
z=z,
h=initial_h,
z=self.z_targets.pop(0),
current_step_size=initial_h,
previous_step_size=0.0,
step=1,
C_to_A_factor=C_to_A_factor,
converter=self.params.ifft,
spectrum=self.params.spec_0.copy() / C_to_A_factor,
@@ -114,11 +109,7 @@ class RK4IP:
self.stored_spectra = self.params.recovery_last_stored * [None] + [
self.state.spectrum.copy()
]
self.cons_qty = [
self.params.conserved_quantity(self.state),
0,
]
self.size_fac = 2 ** (1 / 5)
self.tracked_values = TrackedValues()
def _save_current_spectrum(self, num: int):
"""saves the spectrum and the corresponding cons_qty array
@@ -128,8 +119,8 @@ class RK4IP:
num : int
index of the z postition
"""
self._save_data(self.get_current_spectrum(), f"spectrum_{num}")
self._save_data(self.cons_qty, "cons_qty")
self.write(self.get_current_spectrum(), f"spectrum_{num}")
self.write(self.tracked_values, "tracked_values")
self.step_saved()
def get_current_spectrum(self) -> np.ndarray:
@@ -142,7 +133,7 @@ class RK4IP:
"""
return self.state.C_to_A_factor * self.state.spectrum
def _save_data(self, data: np.ndarray, name: str):
def write(self, data: np.ndarray, name: str):
"""calls the appropriate method to save data
Parameters
@@ -168,7 +159,7 @@ class RK4IP:
)
if self.save_data:
self._save_data(self.z_stored, "z.npy")
self.write(self.z_stored, "z.npy")
return self.stored_spectra
@@ -185,40 +176,36 @@ class RK4IP:
spectrum
"""
# Print introduction
self.logger.debug(
"Computing {} new spectra, first one at {}m".format(self.store_num, self.z_targets[0])
)
store = False
# Start of the integration
step = 1
store = False # store a spectrum
yield step, len(self.stored_spectra) - 1, self.get_current_spectrum()
yield self.state.step, len(self.stored_spectra) - 1, self.get_current_spectrum()
while self.state.z < self.params.length:
h_taken = self.take_step(step)
self.state = self.params.integrator(self.state)
step += 1
self.cons_qty.append(0)
self.state.step += 1
new_tracked_values = (
dict(step=self.state.step, z=self.state.z) | self.params.integrator.all_values()
)
self.logger.debug(f"tracked values at z={self.state.z} : {new_tracked_values}")
self.tracked_values.append(new_tracked_values)
# Whether the current spectrum has to be stored depends on previous step
if store:
self.logger.debug(
"{} steps, z = {:.4f}, h = {:.5g}".format(step, self.state.z, h_taken)
)
current_spec = self.get_current_spectrum()
self.stored_spectra.append(current_spec)
yield step, len(self.stored_spectra) - 1, current_spec
yield self.state.step, len(self.stored_spectra) - 1, current_spec
self.z_stored.append(self.state.z)
del self.z_targets[0]
# reset the constant step size after a spectrum is stored
if not self.params.adapt_step_size:
self.state.h = self.error_ok
self.state.current_step_size = self.error_ok
if len(self.z_targets) == 0:
break
@@ -226,69 +213,19 @@ class RK4IP:
# if the next step goes over a position at which we want to store
# a spectrum, we shorten the step to reach this position exactly
if self.state.z + self.state.h >= self.z_targets[0]:
if self.state.z + self.state.current_step_size >= self.z_targets[0]:
store = True
self.state.h = self.z_targets[0] - self.state.z
def take_step(self, step: int) -> float:
"""computes a new spectrum, whilst adjusting step size if required, until the error estimation
validates the new spectrum. Saves the result in the internal state attribute
Parameters
----------
step : int
index of the current
Returns
-------
h : float
step sized used
"""
keep = False
h_next_step = self.state.h
while not keep:
h = h_next_step
expD = np.exp(h * ONE_2 * self.params.linear_operator(self.state))
A_I = expD * self.state.spectrum
k1 = expD * (h * self.params.nonlinear_operator(self.state))
k2 = h * self.params.nonlinear_operator(self.state.replace(A_I + k1 * ONE_2))
k3 = h * self.params.nonlinear_operator(self.state.replace(A_I + k2 * ONE_2))
k4 = h * self.params.nonlinear_operator(self.state.replace(expD * (A_I + k3)))
new_state = self.state.replace(
expD * (A_I + k1 * ONE_6 + k2 * ONE_3 + k3 * ONE_3) + k4 * ONE_6
)
self.cons_qty[step] = self.params.conserved_quantity(new_state)
if self.params.adapt_step_size:
curr_p_change = np.abs(self.cons_qty[step - 1] - self.cons_qty[step])
cons_qty_change_ok = self.error_ok * self.cons_qty[step - 1]
if curr_p_change > 2 * cons_qty_change_ok:
progress_str = f"step {step} rejected with h = {h:.4e}, doing over"
self.logger.debug(progress_str)
keep = False
h_next_step = h * ONE_2
elif cons_qty_change_ok < curr_p_change <= 2.0 * cons_qty_change_ok:
keep = True
h_next_step = h / self.size_fac
elif curr_p_change < 0.1 * cons_qty_change_ok:
keep = True
h_next_step = h * self.size_fac
else:
keep = True
h_next_step = h
else:
keep = True
self.state = new_state
self.state.h = h_next_step
self.state.z += h
return h
self.state.current_step_size = self.z_targets[0] - self.state.z
def step_saved(self):
pass
def __iter__(self) -> Generator[tuple[int, int, np.ndarray], None, None]:
yield from self.irun()
def __len__(self) -> int:
return self.params.z_num
class SequentialRK4IP(RK4IP):
def __init__(
@@ -339,7 +276,7 @@ class RayRK4IP(RK4IP):
):
self.worker_id = worker_id
self.p_actor = p_actor
super().set(
super().__init__(
params,
save_data=save_data,
)

280
src/scgenerator/solver.py Normal file
View File

@@ -0,0 +1,280 @@
from abc import abstractmethod
import numpy as np
from . import math
from .logger import get_logger
from .operators import (
AbstractConservedQuantity,
CurrentState,
LinearOperator,
NonLinearOperator,
ValueTracker,
)
##################################################
################### STEP-TAKER ###################
##################################################
class StepTaker(ValueTracker):
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
_tracked_values: dict[str, float]
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self._tracked_values = {}
@abstractmethod
def __call__(self, state: CurrentState, step_size: float) -> np.ndarray:
...
def all_values(self) -> dict[str, float]:
"""override ValueTracker.all_values to account for the fact that operators are called
multiple times per step, sometimes with different state, so we use value recorded
earlier. Please call self.recorde_tracked_values() one time only just after calling
the linear and nonlinear operators in your StepTaker.
Returns
-------
dict[str, float]
tracked values
"""
return self.values() | self._tracked_values
def record_tracked_values(self):
self._tracked_values = super().all_values()
class RK4IPStepTaker(StepTaker):
c2 = 1 / 2
c3 = 1 / 3
c6 = 1 / 6
_cached_values: tuple[np.ndarray, np.ndarray]
_cached_key: float
_cache_hits: int
_cache_misses: int
def __init__(self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator):
super().__init__(linear_operator, nonlinear_operator)
self._cached_key = None
self._cached_values = None
self._cache_hits = 0
self._cache_misses = 0
def __call__(self, state: CurrentState, step_size: float) -> np.ndarray:
h = step_size
l0, nl0 = self.cached_values(state)
expD = np.exp(h * self.c2 * l0)
A_I = expD * state.spectrum
k1 = expD * (h * nl0)
k2 = h * self.nonlinear_operator(state.replace(A_I + k1 * self.c2))
k3 = h * self.nonlinear_operator(state.replace(A_I + k2 * self.c2))
k4 = h * self.nonlinear_operator(state.replace(expD * (A_I + k3)))
return expD * (A_I + k1 * self.c6 + k2 * self.c3 + k3 * self.c3) + k4 * self.c6
def cached_values(self, state: CurrentState) -> tuple[np.ndarray, np.ndarray]:
"""the evaluation of the linear and nonlinear operators at the start of the step don't
depend on the step size, so we cache them in case we need them more than once (which
can happen depending on the adaptive step size controller)
Parameters
----------
state : CurrentState
current state of the simulation. state.z is used as the key for the cache
Returns
-------
np.ndarray
result of the linear operator
np.ndarray
result of the nonlinear operator
"""
if self._cached_key != state.z:
self._cache_misses += 1
self._cached_key = state.z
self._cached_values = self.linear_operator(state), self.nonlinear_operator(state)
self.record_tracked_values()
else:
self._cache_hits += 1
return self._cached_values
def values(self) -> dict[str, float]:
return dict(RK4IP_cache_hits=self._cache_hits, RK4IP_cache_misses=self._cache_misses)
##################################################
################### INTEGRATOR ###################
##################################################
class Integrator(ValueTracker):
last_step = 0.0
@abstractmethod
def __call__(self, state: CurrentState) -> CurrentState:
"""propagate the state with a step size of state.current_step_size
and return a new state with updated z and previous_step_size attributes"""
...
class ConstantStepIntegrator(Integrator):
def __call__(self, state: CurrentState) -> CurrentState:
new_state = state.replace(self.step_taker(state, state.current_step_size))
new_state.z += new_state.current_step_size
new_state.previous_step_size = new_state.current_step_size
return new_state
def values(self) -> dict[str, float]:
return dict(h=self.last_step)
class ConservedQuantityIntegrator(Integrator):
step_taker: StepTaker
conserved_quantity: AbstractConservedQuantity
last_quantity_value: float
tolerated_error: float
local_error: float = 0.0
def __init__(
self,
step_taker: StepTaker,
conserved_quantity: AbstractConservedQuantity,
tolerated_error: float,
):
self.conserved_quantity = conserved_quantity
self.last_quantity_value = 0
self.tolerated_error = tolerated_error
self.logger = get_logger(self.__class__.__name__)
self.size_fac = 2.0 ** (1.0 / 5.0)
self.step_taker = step_taker
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
new_state = state.replace(self.step_taker(state, h))
new_qty = self.conserved_quantity(new_state)
delta = np.abs(new_qty - self.last_quantity_value) / self.last_quantity_value
if delta > 2 * self.tolerated_error:
progress_str = f"step {state.step} rejected with h = {h:.4e}, doing over"
self.logger.info(progress_str)
keep = False
h_next_step = h * 0.5
elif self.tolerated_error < delta <= 2.0 * self.tolerated_error:
keep = True
h_next_step = h / self.size_fac
elif delta < 0.1 * self.tolerated_error:
keep = True
h_next_step = h * self.size_fac
else:
keep = True
h_next_step = h
self.local_error = delta
self.last_quantity_value = new_qty
new_state.current_step_size = h_next_step
new_state.previous_step_size = h
new_state.z += h
self.last_step = h
return new_state
def values(self) -> dict[str, float]:
return dict(
cons_qty=self.last_quantity_value, h=self.last_step, relative_error=self.local_error
)
class LocalErrorIntegrator(Integrator):
step_taker: StepTaker
tolerated_error: float
local_error: float
def __init__(self, step_taker: StepTaker, tolerated_error: float, w_num: int):
self.tolerated_error = tolerated_error
self.local_error = 0.0
self.logger = get_logger(self.__class__.__name__)
self.size_fac, self.fine_fac, self.coarse_fac = 2.0 ** (1.0 / 5.0), 16 / 15, -1 / 15
self.step_taker = step_taker
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
h_half = h / 2
coarse_spec = self.step_taker(state, h)
fine_spec1 = self.step_taker(state, h_half)
fine_state = state.replace(fine_spec1, z=state.z + h_half)
fine_spec = self.step_taker(fine_state, h_half)
delta = self.compute_diff(coarse_spec, fine_spec)
if delta > 2 * self.tolerated_error:
keep = False
h_next_step = h_half
elif self.tolerated_error <= delta <= 2 * self.tolerated_error:
keep = True
h_next_step = h / self.size_fac
elif 0.5 * self.tolerated_error <= delta < self.tolerated_error:
keep = True
h_next_step = h
else:
keep = True
h_next_step = h * self.size_fac
self.local_error = delta
fine_state.spectrum = fine_spec * self.fine_fac + coarse_spec * self.coarse_fac
fine_state.current_step_size = h_next_step
fine_state.previous_step_size = h
fine_state.z += h
self.last_step = h
return fine_state
def compute_diff(self, coarse_spec: np.ndarray, fine_spec: np.ndarray) -> float:
return np.sqrt(math.abs2(coarse_spec - fine_spec).sum() / math.abs2(fine_spec).sum())
def values(self) -> dict[str, float]:
return dict(relative_error=self.local_error, h=self.last_step)
class ERK43(Integrator):
linear_operator: LinearOperator
nonlinear_operator: NonLinearOperator
dt: float
def __init__(
self, linear_operator: LinearOperator, nonlinear_operator: NonLinearOperator, dt: float
):
self.linear_operator = linear_operator
self.nonlinear_operator = nonlinear_operator
self.dt = dt
def __call__(self, state: CurrentState) -> CurrentState:
keep = False
h_next_step = state.current_step_size
while not keep:
h = h_next_step
expD = np.exp(h * 0.5 * self.linear_operator(state))
A_I = expD * state.spectrum
k1 = expD * state.prev_spectrum
k2 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k1))
k3 = self.nonlinear_operator(state.replace(A_I + 0.5 * h * k2))
k4 = self.nonlinear_operator(state.replace(expD * A_I + h * k3))
r = expD * (A_I + h / 6 * (k1 + 2 * k2 + 2 * k3))
new_fine = r + h / 6 * k4
k5 = self.nonlinear_operator(state.replace(new_fine))
new_coarse = r + h / 30 * (2 * k4 + 3 * k5)

View File

@@ -14,7 +14,7 @@ from dataclasses import dataclass
from functools import cache
from pathlib import Path
from string import printable as str_printable
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set
from typing import Any, Callable, MutableMapping, Sequence, TypeVar, Set, Union
import numpy as np
import pkg_resources as pkg
@@ -251,12 +251,12 @@ def load_material_dico(name: str) -> dict[str, Any]:
return tomli.loads(Paths.gets("materials"))[name]
def save_data(data: np.ndarray, data_dir: Path, file_name: str):
def save_data(data: Union[np.ndarray, MutableMapping], data_dir: Path, file_name: str):
"""saves numpy array to disk
Parameters
----------
data : np.ndarray
data : Union[np.ndarray, MutableMapping]
data to save
file_name : str
file name
@@ -266,7 +266,10 @@ def save_data(data: np.ndarray, data_dir: Path, file_name: str):
identifier in the main data folder of the task, by default ""
"""
path = data_dir / file_name
if isinstance(data, np.ndarray):
np.save(path, data)
elif isinstance(data, MutableMapping):
np.savez(path, **data)
get_logger(__name__).debug(f"saved data in {path}")
return