diff --git a/src/scgenerator/__init__.py b/src/scgenerator/__init__.py index 7276957..7eac434 100644 --- a/src/scgenerator/__init__.py +++ b/src/scgenerator/__init__.py @@ -5,7 +5,7 @@ from scgenerator import io, noise, operators, plotting from scgenerator.helpers import * from scgenerator.logger import get_logger from scgenerator.math import abs2, argclosest, normalized, span, tspace, wspace -from scgenerator.parameter import Parameters +from scgenerator.parameter import Parameters, format_value from scgenerator.spectra import Spectrum, propagation, propagation_series from scgenerator.physics import fiber, materials, plasma, pulse from scgenerator.physics.units import PlotRange diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index d16abed..d513b2c 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -27,7 +27,7 @@ SPEC1_FN_N = "spectrum_{}_{}.npy" Z_FN = "z.npy" PARAM_FN = "params.toml" PARAM_SEPARATOR = " " -DECIMALS_DISPLAY = 6 +DECIMALS_DISPLAY = 4 MANDATORY_PARAMETERS = { "name", diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 44a8af0..cbf093b 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -5,7 +5,7 @@ from __future__ import annotations from operator import itemgetter -from typing import Callable, TypeVar, Union +from typing import Callable, Sequence, TypeVar, Union import numpy as np from numpy import pi @@ -124,7 +124,13 @@ def unit_formatter( prefix = "yzafpnµm kMGTPEZY"[8 + exp] if exp else "" return f"{float(base)*10**mult:.{decimals}g}{prefix}{unit}" - return formatter + def _format(val): + if isinstance(val, (Sequence, np.ndarray)): + return f"({', '.join(_format(el) for el in val)})" + else: + return formatter(val) + + return _format class unit: diff --git a/tests/test_variableparameters.py b/tests/test_variableparameters.py index b95581e..42906ee 100644 --- a/tests/test_variableparameters.py +++ b/tests/test_variableparameters.py @@ -21,6 +21,10 @@ def test_debase(): def test_get_sequence(): + s = get_sequence([[0, 0, 20]]) + assert len(s) == 20 + assert all(el == 0 for el in s) + s = get_sequence([1, 2, 3]) assert isinstance(s, np.ndarray) and s.dtype == int assert np.all(s == np.array([1, 2, 3])) @@ -41,6 +45,9 @@ def test_get_sequence(): assert isinstance(s, np.ndarray) and s.dtype == int assert np.all(s == np.array([0])) + s = get_sequence(np.zeros((1, 2))) + assert s.shape == (1, 2) + s = get_sequence([[1e-12, 1e-5, 8, "geometric"]]) assert isinstance(s, np.ndarray) and s.dtype == float assert np.all(s == np.geomspace(1e-12, 1e-5, 8)) @@ -49,6 +56,10 @@ def test_get_sequence(): assert isinstance(s, np.ndarray) and s.dtype == float assert np.all(s == np.concatenate(((0,), np.geomspace(1e-12, 1e-5, 8)))) + s = get_sequence(np.arange(10).reshape(5, 2)) + assert s[0][1] == 1 + assert s[4][0] == 8 + with pytest.raises(ValueError): get_sequence([[1, 2, 3, "logarithm"]]) @@ -88,18 +99,21 @@ def test_constant_list(): class Conf: x: Variable = vfield(default=[1, 2, 7]) y: Variable = vfield(default=[2, 0]) + z: Variable = vfield(default=np.arange(10).reshape(5, 2)) conf = Conf() - assert len(conf) == 6 + assert len(conf) == 3 * 2 * 5 assert conf.x(0) == 1 assert conf.y(0) == 2 assert conf.x(len(conf) - 1) == 7 assert conf.y(len(conf) - 1) == 0 + assert list(conf.z(0)) == [0, 1] + assert list(conf.z(len(conf) - 1)) == [8, 9] with pytest.raises(ValueError): - conf.x(6) + conf.x(30) with pytest.raises(ValueError): - conf.y(6) + conf.y(30) def test_simple_synchronize(): @@ -254,6 +268,12 @@ def test_unit_forammter(): assert fmt(0.000001235) == "0.00012%" assert fmt(0.000000001235) == "1.2e-07%" + fmt = unit_formatter("", 4) + assert fmt((4, 2)) == "(4, 2)" + assert fmt([4, 2]) == "(4, 2)" + assert fmt(np.array([4, 2])) == "(4, 2)" + assert fmt(np.zeros((2, 2, 2))) == "(((0, 0), (0, 0)), ((0, 0), (0, 0)))" + def test_param_formatting(): """formatting is always respected"""