From 045c4ba44ea5410506f515fa41dbc9822850c953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Tue, 1 Jun 2021 11:47:12 +0200 Subject: [PATCH] more tests, wl adjusted on every custom field --- src/scgenerator/initialize.py | 8 +++---- src/scgenerator/io.py | 2 +- testing/configs/custom_field/peak_power.toml | 2 +- .../custom_field/wavelength_shift1.toml | 22 ++++++++++++++++++ .../custom_field/wavelength_shift2.toml | 23 +++++++++++++++++++ testing/test_initialize.py | 21 +++++++++++++++++ 6 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 testing/configs/custom_field/wavelength_shift1.toml create mode 100644 testing/configs/custom_field/wavelength_shift2.toml diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 1c91c2d..1f03a53 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -714,10 +714,10 @@ def setup_custom_field(params: Dict[str, Any]) -> bool: params["width"], params["peak_power"], params["energy"] = pulse.measure_field( params["t"], params["field_0"] ) - delta_w = params["w_c"][np.argmax(abs2(np.fft.fft(params["field_0"])))] - logger.debug(f"had to adjust w by {delta_w}") - params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) - _update_frequency_domain(params) + delta_w = params["w_c"][np.argmax(abs2(np.fft.fft(params["field_0"])))] + logger.debug(f"adjusted w by {delta_w}") + params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) + _update_frequency_domain(params) return True diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index c7edc87..54db741 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -413,7 +413,7 @@ def append_and_merge(final_sim_path: os.PathLike, new_name=None): update_appended_params( final_sim_path / "initial_config.toml", destination_path / "initial_config.toml", z_arr ) - + pbars.close() merge(destination_path, delete=True) diff --git a/testing/configs/custom_field/peak_power.toml b/testing/configs/custom_field/peak_power.toml index e658d0a..5ce44f8 100644 --- a/testing/configs/custom_field/peak_power.toml +++ b/testing/configs/custom_field/peak_power.toml @@ -3,5 +3,5 @@ field_file = "testing/configs/custom_field/init_field.npz" length = 1 peak_power = 20000 t_num = 2048 -wavelength = 1000e-9 +wavelength = 1593e-9 z_num = 32 diff --git a/testing/configs/custom_field/wavelength_shift1.toml b/testing/configs/custom_field/wavelength_shift1.toml new file mode 100644 index 0000000..1dcca62 --- /dev/null +++ b/testing/configs/custom_field/wavelength_shift1.toml @@ -0,0 +1,22 @@ +name = "test config" + +[fiber] +length = 1 +model = "pcf" +pitch = 1.5e-6 +pitch_ratio = 0.37 + +[pulse] +field_file = "testing/configs/custom_field/init_field.npz" +quantum_noise = false +wavelength = 1050e-9 + +[simulation] +behaviors = ["spm", "raman", "ss"] +lower_wavelength_interp_limit = 300e-9 +raman_type = "agrawal" +t_num = 16384 +time_window = 37e-12 +tolerated_error = 1e-11 +upper_wavelength_interp_limit = 1900e-9 +z_num = 128 diff --git a/testing/configs/custom_field/wavelength_shift2.toml b/testing/configs/custom_field/wavelength_shift2.toml new file mode 100644 index 0000000..9aa79ec --- /dev/null +++ b/testing/configs/custom_field/wavelength_shift2.toml @@ -0,0 +1,23 @@ +name = "test config" + +[fiber] +length = 1 +model = "pcf" +pitch = 1.5e-6 +pitch_ratio = 0.37 + +[pulse] +field_file = "testing/configs/custom_field/init_field.npz" +quantum_noise = false +[pulse.variable] +wavelength = [1050e-9, 1321e-9, 1593e-9] + +[simulation] +behaviors = ["spm", "raman", "ss"] +lower_wavelength_interp_limit = 300e-9 +raman_type = "agrawal" +t_num = 16384 +time_window = 37e-12 +tolerated_error = 1e-11 +upper_wavelength_interp_limit = 1900e-9 +z_num = 128 diff --git a/testing/test_initialize.py b/testing/test_initialize.py index f438a61..39a060f 100644 --- a/testing/test_initialize.py +++ b/testing/test_initialize.py @@ -6,6 +6,7 @@ import numpy as np import toml from scgenerator import defaults, utils, math from scgenerator.errors import * +from scgenerator.physics import units def load_conf(name): @@ -182,6 +183,7 @@ class TestInitializeMethods(unittest.TestCase): result = init.setup_custom_field(conf) self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4) self.assertTrue(result) + self.assertNotAlmostEqual(conf["wavelength"], 1593e-9) conf = load_conf("custom_field/mean_power") conf = init._generate_sim_grid(conf) @@ -201,6 +203,25 @@ class TestInitializeMethods(unittest.TestCase): self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0) self.assertTrue(result) + conf = load_conf("custom_field/wavelength_shift1") + result = init.compute_init_parameters(conf) + self.assertAlmostEqual( + units.m.inv(result["w"])[np.argmax(math.abs2(result["spec_0"]))], 1050e-9 + ) + + conf = load_conf("custom_field/wavelength_shift1") + conf["pulse"]["wavelength"] = 1593e-9 + result = init.compute_init_parameters(conf) + + conf = load_conf("custom_field/wavelength_shift2") + for target, (variable, config) in zip( + [1050e-9, 1321e-9, 1593e-9], init.ParamSequence(conf) + ): + self.assertAlmostEqual( + units.m.inv(config["w"])[np.argmax(math.abs2(config["spec_0"]))], target + ) + print(config["wavelength"], target) + if __name__ == "__main__": conf = conf_maker("validate_types")