general cleanup

This commit is contained in:
Benoît Sierro
2023-08-23 13:41:02 +02:00
parent 8c84487465
commit 4ea23bedda
6 changed files with 29 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "scgenerator" name = "scgenerator"
version = "0.3.8" version = "0.3.9"
description = "Simulate nonlinear pulse propagation in optical fibers" description = "Simulate nonlinear pulse propagation in optical fibers"
readme = "README.md" readme = "README.md"
authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }]

View File

@@ -13,6 +13,19 @@ from zipfile import ZipFile
import numpy as np import numpy as np
class TimedMessage:
def __init__(self, interval: float = 10.0):
self.interval = datetime.timedelta(seconds=interval)
self.next = datetime.datetime.now()
def ready(self) -> bool:
t = datetime.datetime.now()
if self.next <= t:
self.next = t + self.interval
return True
return False
def data_file(path: str) -> Path: def data_file(path: str) -> Path:
"""returns a `Path` object pointing to the desired data file included in `scgenerator`""" """returns a `Path` object pointing to the desired data file included in `scgenerator`"""
return importlib.resources.path("scgenerator", "data") / path return importlib.resources.path("scgenerator", "data") / path

View File

@@ -492,8 +492,8 @@ class Parameters:
continue continue
yield k, v yield k, v
def copy(self, deep: bool = True, freeze: bool = False) -> Parameters: def copy(self, deep: bool = False, freeze: bool = False) -> Parameters:
"""create a deep copy of self. if freeze is True, the returned copy is read-only""" """create a (deep) copy of self. if freeze is True, the returned copy is read-only"""
if deep: if deep:
params = Parameters(**deepcopy(self._param_dico)) params = Parameters(**deepcopy(self._param_dico))
else: else:

View File

@@ -12,9 +12,12 @@ from typing import Any, Iterator, Sequence
import numba import numba
import numpy as np import numpy as np
from scgenerator.io import TimedMessage
from scgenerator.logger import get_logger
from scgenerator.math import abs2 from scgenerator.math import abs2
from scgenerator.operators import SpecOperator, VariableQuantity from scgenerator.operators import SpecOperator, VariableQuantity
from scgenerator.utils import TimedMessage
logger = get_logger(__name__)
class SimulationResult: class SimulationResult:
@@ -229,10 +232,7 @@ def solve43(
continue continue
else: else:
rejected.append((h, error)) rejected.append((h, error))
print( logger.info(f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}")
f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}",
file=sys.stderr,
)
h = h * next_h_factor h = h * next_h_factor
@@ -243,7 +243,7 @@ def solve43(
store_next = False store_next = False
if msg.ready(): if msg.ready():
print(f"step {step_ind}, {z = :.3f}, {error = :g}, {h = :.3g}", file=sys.stderr) logger.info(f"step {step_ind}, {z = :.3f}, {error = :g}, {h = :.3g}")
def integrate( def integrate(

View File

@@ -207,11 +207,14 @@ class Propagation(Generic[ParamsOrNone]):
def _load_slice(self, key: slice) -> Spectrum: def _load_slice(self, key: slice) -> Spectrum:
_iter = range(len(self))[key] _iter = range(len(self))[key]
out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex)
if self.parameters is not None: if self.parameters is not None:
out = Spectrum(out, self.parameters) out = Spectrum(
np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters
)
for i in _iter: for i in _iter:
out[i] = self.io.load_spectrum(i) out[i] = self.io.load_spectrum(i)
else:
out = np.array([self.io.load_spectrum(i) for i in _iter])
return out return out
def append(self, spectrum: np.ndarray): def append(self, spectrum: np.ndarray):

View File

@@ -27,19 +27,6 @@ from scgenerator.logger import get_logger
T_ = TypeVar("T_") T_ = TypeVar("T_")
class TimedMessage:
def __init__(self, interval: float = 10.0):
self.interval = datetime.timedelta(seconds=interval)
self.next = datetime.datetime.now()
def ready(self) -> bool:
t = datetime.datetime.now()
if self.next <= t:
self.next = t + self.interval
return True
return False
def conform_variable_entry(d) -> list[dict[str, list]]: def conform_variable_entry(d) -> list[dict[str, list]]:
if isinstance(d, MutableMapping): if isinstance(d, MutableMapping):
d = [{k: v} for k, v in d.items()] d = [{k: v} for k, v in d.items()]