better submit, more tests
This commit is contained in:
@@ -2,8 +2,9 @@ import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import scgenerator.initialize as init
|
||||
import numpy as np
|
||||
import toml
|
||||
from scgenerator import utils
|
||||
from scgenerator import defaults, utils, math
|
||||
from scgenerator.errors import *
|
||||
|
||||
|
||||
@@ -30,7 +31,7 @@ class TestParamSequence(unittest.TestCase):
|
||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
||||
l = []
|
||||
s = []
|
||||
for vary_list, _ in param_seq.iterate_without_computing():
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertNotIn(vary_list, l)
|
||||
self.assertNotIn(utils.format_variable_list(vary_list), s)
|
||||
l.append(vary_list)
|
||||
@@ -39,12 +40,12 @@ class TestParamSequence(unittest.TestCase):
|
||||
def test_init_config_not_affected_by_iteration(self):
|
||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
||||
config = deepcopy(param_seq.config)
|
||||
for _ in param_seq.iterate_without_computing():
|
||||
for _ in utils.required_simulations(param_seq.config):
|
||||
self.assertEqual(config.items(), param_seq.config.items())
|
||||
|
||||
def test_no_variations_yields_only_num_and_id(self):
|
||||
for param_seq in self.iterconf(["no_variations"]):
|
||||
for vary_list, _ in param_seq.iterate_without_computing():
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertEqual(vary_list[1][0], "num")
|
||||
self.assertEqual(vary_list[0][0], "id")
|
||||
self.assertEqual(2, len(vary_list))
|
||||
@@ -65,18 +66,18 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
|
||||
init._validate_types(conf("bad4"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Varying parameters should be specified in a list"):
|
||||
with self.assertRaisesRegex(TypeError, "Variable parameters should be specified in a list"):
|
||||
init._validate_types(conf("bad5"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"value '0' of type <class 'int'> for key 'repeat' is not valid, must be a strictly positive integer",
|
||||
"value '0' of type .*int.* for key 'repeat' is not valid, must be a strictly positive integer",
|
||||
):
|
||||
init._validate_types(conf("bad6"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Varying parameters lists should contain at least 1 element",
|
||||
r"Variable parameters lists should contain at least 1 element",
|
||||
):
|
||||
init._ensure_consistency(conf("bad7"))
|
||||
|
||||
@@ -92,7 +93,7 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"1 of '\['peak_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
||||
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
||||
):
|
||||
init._ensure_consistency(conf("bad2"))
|
||||
|
||||
@@ -156,14 +157,49 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
dict(
|
||||
t_num=16384,
|
||||
time_window=37e-12,
|
||||
lower_wavelength_interp_limit=0,
|
||||
upper_wavelength_interp_limit=1900e-9,
|
||||
lower_wavelength_interp_limit=defaults.default_parameters[
|
||||
"lower_wavelength_interp_limit"
|
||||
],
|
||||
upper_wavelength_interp_limit=defaults.default_parameters[
|
||||
"upper_wavelength_interp_limit"
|
||||
],
|
||||
).items(),
|
||||
init._ensure_consistency(conf("good6"))["simulation"].items(),
|
||||
)
|
||||
|
||||
# def test_compute_init_parameters(self):
|
||||
# conf = lambda s: load_conf("compute_init_parameters/" + s)
|
||||
def test_setup_custom_field(self):
|
||||
d = np.load("testing/configs/custom_field/init_field.npz")
|
||||
t = d["time"]
|
||||
field = d["field"]
|
||||
conf = load_conf("custom_field/no_change")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/peak_power")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/mean_power")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/recover1")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/recover2")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user