diff --git a/pyproject.toml b/pyproject.toml index d23798c..0db193e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "scgenerator" -version = "0.3.8" +version = "0.3.9" description = "Simulate nonlinear pulse propagation in optical fibers" readme = "README.md" authors = [{ name = "Benoit Sierro", email = "benoit.sierro@iap.unibe.ch" }] diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 5ea03d0..9cd9521 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -13,6 +13,19 @@ from zipfile import ZipFile import numpy as np +class TimedMessage: + def __init__(self, interval: float = 10.0): + self.interval = datetime.timedelta(seconds=interval) + self.next = datetime.datetime.now() + + def ready(self) -> bool: + t = datetime.datetime.now() + if self.next <= t: + self.next = t + self.interval + return True + return False + + def data_file(path: str) -> Path: """returns a `Path` object pointing to the desired data file included in `scgenerator`""" return importlib.resources.path("scgenerator", "data") / path diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index d727eff..fde6c64 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -492,8 +492,8 @@ class Parameters: continue yield k, v - def copy(self, deep: bool = True, freeze: bool = False) -> Parameters: - """create a deep copy of self. if freeze is True, the returned copy is read-only""" + def copy(self, deep: bool = False, freeze: bool = False) -> Parameters: + """create a (deep) copy of self. if freeze is True, the returned copy is read-only""" if deep: params = Parameters(**deepcopy(self._param_dico)) else: diff --git a/src/scgenerator/solver.py b/src/scgenerator/solver.py index b36efd9..af0aa9c 100644 --- a/src/scgenerator/solver.py +++ b/src/scgenerator/solver.py @@ -12,9 +12,12 @@ from typing import Any, Iterator, Sequence import numba import numpy as np +from scgenerator.io import TimedMessage +from scgenerator.logger import get_logger from scgenerator.math import abs2 from scgenerator.operators import SpecOperator, VariableQuantity -from scgenerator.utils import TimedMessage + +logger = get_logger(__name__) class SimulationResult: @@ -229,10 +232,7 @@ def solve43( continue else: rejected.append((h, error)) - print( - f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}", - file=sys.stderr, - ) + logger.info(f"{z = :.3f} rejected step {step_ind} with {h = :.2g}, {error = :.2g}") h = h * next_h_factor @@ -243,7 +243,7 @@ def solve43( store_next = False if msg.ready(): - print(f"step {step_ind}, {z = :.3f}, {error = :g}, {h = :.3g}", file=sys.stderr) + logger.info(f"step {step_ind}, {z = :.3f}, {error = :g}, {h = :.3g}") def integrate( diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index fe73a51..82a26d6 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -207,11 +207,14 @@ class Propagation(Generic[ParamsOrNone]): def _load_slice(self, key: slice) -> Spectrum: _iter = range(len(self))[key] - out = np.zeros((len(_iter), self.parameters.t_num), dtype=complex) if self.parameters is not None: - out = Spectrum(out, self.parameters) - for i in _iter: - out[i] = self.io.load_spectrum(i) + out = Spectrum( + np.zeros((len(_iter), self.parameters.t_num), dtype=complex), self.parameters + ) + for i in _iter: + out[i] = self.io.load_spectrum(i) + else: + out = np.array([self.io.load_spectrum(i) for i in _iter]) return out def append(self, spectrum: np.ndarray): diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index b68520c..c04d0c2 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -27,19 +27,6 @@ from scgenerator.logger import get_logger T_ = TypeVar("T_") -class TimedMessage: - def __init__(self, interval: float = 10.0): - self.interval = datetime.timedelta(seconds=interval) - self.next = datetime.datetime.now() - - def ready(self) -> bool: - t = datetime.datetime.now() - if self.next <= t: - self.next = t + self.interval - return True - return False - - def conform_variable_entry(d) -> list[dict[str, list]]: if isinstance(d, MutableMapping): d = [{k: v} for k, v in d.items()]