partway trough big data structure revamp
This commit is contained in:
@@ -1,35 +0,0 @@
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
import toml
|
||||
from scgenerator import logger
|
||||
from send2trash import send2trash
|
||||
|
||||
TMP = "testing/.tmp"
|
||||
|
||||
|
||||
class TestRecoveryParamSequence(unittest.TestCase):
|
||||
def setUp(self):
|
||||
shutil.copytree("/Users/benoitsierro/sc_tests/scgenerator_full anomalous55", TMP)
|
||||
self.conf = toml.load(TMP + "/initial_config.toml")
|
||||
io.set_data_folder(55, TMP)
|
||||
|
||||
def test_remaining_simulations_count(self):
|
||||
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
|
||||
self.assertEqual(5, len(param_seq))
|
||||
|
||||
def test_only_one_to_complete(self):
|
||||
param_seq = initialize.RecoveryParamSequence(self.conf, 55)
|
||||
i = 0
|
||||
for expected, (vary_list, params) in zip([True, False, False, False, False], param_seq):
|
||||
i += 1
|
||||
self.assertEqual(expected, "recovery_last_stored" in params)
|
||||
|
||||
self.assertEqual(5, i)
|
||||
|
||||
def tearDown(self):
|
||||
send2trash(TMP)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,216 +0,0 @@
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import toml
|
||||
from scgenerator import defaults, utils, math
|
||||
from scgenerator.errors import *
|
||||
from scgenerator.physics import pulse, units
|
||||
from scgenerator.utils.parameter import Config, Parameters
|
||||
|
||||
|
||||
def load_conf(name):
|
||||
with open("testing/configs/" + name + ".toml") as file:
|
||||
conf = toml.load(file)
|
||||
return conf
|
||||
|
||||
|
||||
def conf_maker(folder):
|
||||
def conf(name):
|
||||
return load_conf(folder + "/" + name)
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
class TestParamSequence(unittest.TestCase):
|
||||
def iterconf(self, files):
|
||||
conf = conf_maker("param_sequence")
|
||||
for path in files:
|
||||
yield init.ParamSequence(conf(path))
|
||||
|
||||
def test_no_repeat_in_sub_folder_names(self):
|
||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
||||
l = []
|
||||
s = []
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertNotIn(vary_list, l)
|
||||
self.assertNotIn(utils.format_variable_list(vary_list), s)
|
||||
l.append(vary_list)
|
||||
s.append(utils.format_variable_list(vary_list))
|
||||
|
||||
def test_no_variations_yields_only_num_and_id(self):
|
||||
for param_seq in self.iterconf(["no_variations"]):
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertEqual(vary_list[1][0], "num")
|
||||
self.assertEqual(vary_list[0][0], "id")
|
||||
self.assertEqual(2, len(vary_list))
|
||||
|
||||
|
||||
class TestInitializeMethods(unittest.TestCase):
|
||||
def test_validate_types(self):
|
||||
conf = lambda s: load_conf("validate_types/" + s)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, r"'behaviors\[3\]' must be a str in"):
|
||||
init.Config(**conf("bad2"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "value must be of type <class 'float'>"):
|
||||
init.Config(**conf("bad3"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
|
||||
init.Config(**conf("bad4"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "'variable intensity_noise' value must be of type <class 'list'>"
|
||||
):
|
||||
init.Config(**conf("bad5"))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "'repeat' must be positive"):
|
||||
init.Config(**conf("bad6"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "variable parameter 'intensity_noise' must not be empty"
|
||||
):
|
||||
init.Config(**conf("bad7"))
|
||||
|
||||
self.assertIsNone(init.Config(**conf("good")).hr_w)
|
||||
|
||||
def test_ensure_consistency(self):
|
||||
conf = lambda s: load_conf("ensure_consistency/" + s)
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"1 of '\['t0', 'width'\]' is required and no defaults have been set",
|
||||
):
|
||||
init.Config(**conf("bad1"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
||||
):
|
||||
init.Config(**conf("bad2"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"2 of '\['dt', 't_num', 'time_window'\]' are required and no defaults have been set",
|
||||
):
|
||||
init.Config(**conf("bad3"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
DuplicateParameterError,
|
||||
r"got multiple values for parameter 'width'",
|
||||
):
|
||||
init.Config(**conf("bad4"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"'capillary_thickness' is a required parameter for fiber model 'hasan' and no defaults have been set",
|
||||
):
|
||||
init.Config(**conf("bad5"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"1 of '\['capillary_spacing', 'capillary_outer_d'\]' is required for fiber model 'hasan' and no defaults have been set",
|
||||
):
|
||||
init.Config(**conf("bad6"))
|
||||
|
||||
self.assertLessEqual(
|
||||
{"model": "pcf"}.items(), init.Config(**conf("good1")).__dict__.items()
|
||||
)
|
||||
|
||||
self.assertIsNone(init.Config(**conf("good4")).gamma)
|
||||
|
||||
self.assertLessEqual(
|
||||
{"raman_type": "agrawal"}.items(),
|
||||
init.Config(**conf("good2")).__dict__.items(),
|
||||
)
|
||||
|
||||
self.assertLessEqual(
|
||||
{"name": "no name"}.items(), init.Config(**conf("good3")).__dict__.items()
|
||||
)
|
||||
|
||||
self.assertLessEqual(
|
||||
{"capillary_nested": 0, "capillary_resonance_strengths": []}.items(),
|
||||
init.Config(**conf("good4")).__dict__.items(),
|
||||
)
|
||||
|
||||
self.assertLessEqual(
|
||||
dict(he_mode=(1, 1)).items(),
|
||||
init.Config(**conf("good5")).__dict__.items(),
|
||||
)
|
||||
|
||||
self.assertLessEqual(
|
||||
dict(temperature=300, pressure=1e5, gas_name="vacuum", plasma_density=0).items(),
|
||||
init.Config(**conf("good5")).__dict__.items(),
|
||||
)
|
||||
|
||||
def setup_conf_custom_field(self, path) -> Parameters:
|
||||
|
||||
conf = load_conf(path)
|
||||
conf = Parameters(**conf)
|
||||
init.build_sim_grid_in_place(conf)
|
||||
return conf
|
||||
|
||||
def test_setup_custom_field(self):
|
||||
d = np.load("testing/configs/custom_field/init_field.npz")
|
||||
t = d["time"]
|
||||
field = d["field"]
|
||||
conf = self.setup_conf_custom_field("custom_field/no_change")
|
||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
||||
conf
|
||||
)
|
||||
self.assertAlmostEqual(conf.field_0.real.max(), field.real.max(), 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/peak_power")
|
||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
||||
conf
|
||||
)
|
||||
conf.wavelength = pulse.correct_wavelength(conf.wavelength, conf.w_c, conf.field_0)
|
||||
self.assertAlmostEqual(math.abs2(conf.field_0).max(), 20000, 4)
|
||||
self.assertTrue(result)
|
||||
self.assertNotAlmostEqual(conf.wavelength, 1593e-9)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/mean_power")
|
||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
||||
conf
|
||||
)
|
||||
self.assertAlmostEqual(np.trapz(math.abs2(conf.field_0), conf.t), 0.22 / 40e6, 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/recover1")
|
||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
||||
conf
|
||||
)
|
||||
self.assertAlmostEqual(math.abs2(conf.field_0 - field).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/recover2")
|
||||
result, conf.width, conf.peak_power, conf.energy, conf.field_0 = pulse.setup_custom_field(
|
||||
conf
|
||||
)
|
||||
self.assertAlmostEqual((math.abs2(conf.field_0) / 0.9 - math.abs2(field)).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
||||
result = Parameters(**conf)
|
||||
self.assertAlmostEqual(units.m.inv(result.w)[np.argmax(math.abs2(result.spec_0))], 1050e-9)
|
||||
|
||||
conf = self.setup_conf_custom_field("custom_field/wavelength_shift1")
|
||||
conf.wavelength = 1593e-9
|
||||
result = Parameters(**conf)
|
||||
|
||||
conf = load_conf("custom_field/wavelength_shift2")
|
||||
conf = init.Config(**conf)
|
||||
for target, (variable, config) in zip(
|
||||
[1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf)
|
||||
):
|
||||
init.build_sim_grid_in_place(conf)
|
||||
self.assertAlmostEqual(
|
||||
units.m.inv(config.w)[np.argmax(math.abs2(config.spec_0))], target
|
||||
)
|
||||
print(config.wavelength, target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
conf = conf_maker("validate_types")
|
||||
|
||||
unittest.main()
|
||||
@@ -1,41 +0,0 @@
|
||||
import unittest
|
||||
from scgenerator.physics.pulse import conform_pulse_params
|
||||
|
||||
|
||||
class TestPulseMethods(unittest.TestCase):
|
||||
def test_conform_pulse_params(self):
|
||||
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, energy=6))
|
||||
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, energy=6))
|
||||
self.assertNotIn(None, conform_pulse_params("gaussian", t0=5, peak_power=6))
|
||||
self.assertNotIn(None, conform_pulse_params("gaussian", width=5, peak_power=6))
|
||||
|
||||
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, energy=6)))
|
||||
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, energy=6)))
|
||||
self.assertEqual(4, len(conform_pulse_params("gaussian", t0=5, peak_power=6)))
|
||||
self.assertEqual(4, len(conform_pulse_params("gaussian", width=5, peak_power=6)))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
|
||||
):
|
||||
conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "when soliton number is desired, both gamma and beta2 must be specified"
|
||||
):
|
||||
conform_pulse_params("gaussian", t0=5, energy=6, beta2=0.01)
|
||||
|
||||
self.assertEqual(
|
||||
5, len(conform_pulse_params("gaussian", t0=5, energy=6, gamma=0.01, beta2=2e-6))
|
||||
)
|
||||
self.assertEqual(
|
||||
5, len(conform_pulse_params("gaussian", width=5, energy=6, gamma=0.01, beta2=2e-6))
|
||||
)
|
||||
self.assertEqual(
|
||||
5, len(conform_pulse_params("gaussian", t0=5, peak_power=6, gamma=0.01, beta2=2e-6))
|
||||
)
|
||||
self.assertEqual(
|
||||
5, len(conform_pulse_params("gaussian", width=5, peak_power=6, gamma=0.01, beta2=2e-6))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,67 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import toml
|
||||
from scgenerator import utils
|
||||
|
||||
|
||||
def load_conf(name):
|
||||
with open("testing/configs/" + name + ".toml") as file:
|
||||
conf = toml.load(file)
|
||||
return conf
|
||||
|
||||
|
||||
def conf_maker(folder, val=True):
|
||||
def conf(name):
|
||||
if val:
|
||||
return initialize.Config(**load_conf(folder + "/" + name))
|
||||
else:
|
||||
return initialize.Config(**load_conf(folder + "/" + name))
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
class TestUtilsMethods(unittest.TestCase):
|
||||
def test_count_variations(self):
|
||||
conf = conf_maker("count_variations")
|
||||
|
||||
for sim, vary in [(1, 0), (1, 1), (2, 1), (2, 0), (120, 3)]:
|
||||
self.assertEqual((sim, vary), utils.count_variations(conf(f"{sim}sim_{vary}vary")))
|
||||
|
||||
def test_format_value(self):
|
||||
values = [
|
||||
122e-6,
|
||||
True,
|
||||
["raman", "ss"],
|
||||
np.arange(5),
|
||||
1.123,
|
||||
1.1230001,
|
||||
0.002e122,
|
||||
12.3456e-9,
|
||||
]
|
||||
s = [
|
||||
"0.000122",
|
||||
"True",
|
||||
"raman-ss",
|
||||
"0-1-2-3-4",
|
||||
"1.123",
|
||||
"1.1230001",
|
||||
"2e+119",
|
||||
"1.23456e-08",
|
||||
]
|
||||
|
||||
for value, target in zip(values, s):
|
||||
self.assertEqual(target, utils.format_value(value))
|
||||
|
||||
def test_override_config(self):
|
||||
conf = conf_maker("override", False)
|
||||
old = conf("initial_config")
|
||||
new = conf("fiber2")
|
||||
|
||||
over = utils.override_config(vars(new), old)
|
||||
self.assertNotIn("input_transmission", over.variable)
|
||||
self.assertIsNone(over.input_transmission)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
54
testing/test_variationer.py
Normal file
54
testing/test_variationer.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from pydantic import main
|
||||
import scgenerator as sc
|
||||
|
||||
|
||||
def test_descriptor():
|
||||
# Same branch
|
||||
var1 = sc.VariationDescriptor(
|
||||
raw_descr=[[("num", 1), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
|
||||
)
|
||||
var2 = sc.VariationDescriptor(
|
||||
raw_descr=[[("num", 2), ("a", False)], [("b", 0)]], index=[[1, 0], [0]]
|
||||
)
|
||||
assert var1.branch.identifier == "b_0"
|
||||
assert var1.identifier != var1.branch.identifier
|
||||
assert var1.identifier != var2.identifier
|
||||
assert var1.branch.identifier == var2.branch.identifier
|
||||
|
||||
# different branch
|
||||
var3 = sc.VariationDescriptor(
|
||||
raw_descr=[[("num", 2), ("a", True)], [("b", 0)]], index=[[1, 0], [0]]
|
||||
)
|
||||
assert var1.branch.identifier != var3.branch.identifier
|
||||
assert var1.formatted_descriptor() != var2.formatted_descriptor()
|
||||
assert var1.formatted_descriptor() != var3.formatted_descriptor()
|
||||
|
||||
|
||||
def test_variationer():
|
||||
var = sc.Variationer(
|
||||
[
|
||||
dict(a=[1, 2], num=[0, 1, 2]),
|
||||
[dict(b=["000", "111"], c=["a", "-1"])],
|
||||
dict(),
|
||||
dict(),
|
||||
[dict(aaa=[True, False], bb=[1, 3])],
|
||||
]
|
||||
)
|
||||
assert var.var_num(0) == 6
|
||||
assert var.var_num(1) == 12
|
||||
assert var.var_num() == 24
|
||||
|
||||
cfg = dict(bb=None)
|
||||
branches = set()
|
||||
for descr in var.iterate():
|
||||
assert descr.update_config(cfg).items() >= set(descr.raw_descr[-1])
|
||||
branches.add(descr.branch.identifier)
|
||||
assert len(branches) == 8
|
||||
|
||||
|
||||
def main():
|
||||
test_descriptor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user