diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index 7741bcb..1c91c2d 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -675,23 +675,35 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]: def setup_custom_field(params: Dict[str, Any]) -> bool: + """sets up a custom field function if necessary and returns + True if it did so, False otherwise + + Parameters + ---------- + params : Dict[str, Any] + params dictionary + + Returns + ------- + bool + True if the field has been modified + """ logger = get_logger(__name__) - custom_field = True if "prev_data_dir" in params: spec = io.load_last_spectrum(Path(params["prev_data_dir"]))[1] - params["field_0"] = np.fft.ifft(spec) * params["input_transmission"] - elif "field_file" in params: - field_data = np.load(params["field_file"]) - field_interp = interp1d( - field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) - ) - params["field_0"] = field_interp(params["t"]) - elif "field_0" in params: - params = _evalutate_custom_field_equation(params) + params["field_0"] = np.fft.ifft(spec) * np.sqrt(params["input_transmission"]) else: - custom_field = False + if "field_file" in params: + field_data = np.load(params["field_file"]) + field_interp = interp1d( + field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) + ) + params["field_0"] = field_interp(params["t"]) + elif "field_0" in params: + params = _evalutate_custom_field_equation(params) + else: + return False - if custom_field: params["field_0"] = params["field_0"] * pulse.modify_field_ratio( params["t"], params["field_0"], @@ -706,7 +718,7 @@ def setup_custom_field(params: Dict[str, Any]) -> bool: logger.debug(f"had to adjust w by {delta_w}") params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) _update_frequency_domain(params) - return custom_field + return True def _update_pulse_parameters(params): diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index 0662f35..00cb831 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -11,7 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time import itertools import os -from typing import Tuple +from typing import Literal, Tuple import matplotlib.pyplot as plt import numpy as np @@ -115,14 +115,14 @@ def modify_field_ratio( def conform_pulse_params( - shape, - width=None, - t0=None, - peak_power=None, - energy=None, - soliton_num=None, - gamma=None, - beta2=None, + shape: Literal["gaussian", "sech"], + width: float = None, + t0: float = None, + peak_power: float = None, + energy: float = None, + soliton_num: float = None, + gamma: float = None, + beta2: float = None, ): """makes sure all parameters of the pulse are set and consistent diff --git a/src/scgenerator/scripts/slurm_submit.py b/src/scgenerator/scripts/slurm_submit.py index 3dbcbb5..1d46265 100644 --- a/src/scgenerator/scripts/slurm_submit.py +++ b/src/scgenerator/scripts/slurm_submit.py @@ -4,12 +4,65 @@ import re import shutil import subprocess from datetime import datetime, timedelta +from typing import Tuple +import numpy as np from ..initialize import validate_config_sequence from ..io import Paths from ..utils import count_variations +def primes(n): + primfac = [] + d = 2 + while d * d <= n: + while (n % d) == 0: + primfac.append(d) # supposing you want multiple factors repeated + n //= d + d += 1 + if n > 1: + primfac.append(n) + return primfac + + +def balance(n, lim=(32, 32)): + factors = primes(n) + if len(factors) == 1: + factors = primes(n + 1) + a, b, x, y = 1, 1, 1, 1 + while len(factors) > 0 and x <= lim[0] and y <= lim[1]: + a = x + b = y + if y >= x: + x *= factors.pop(0) + else: + y *= factors.pop() + return a, b + + +def distribute( + num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32) +) -> Tuple[int, int]: + if nodes is None and cpus_per_node is None: + balanced = balance(num, lim) + if num > max(lim): + while np.product(balanced) < min(lim): + num += 1 + balanced = balance(num, lim) + nodes = min(balanced) + cpus_per_node = max(balanced) + + elif nodes is None: + nodes = num // cpus_per_node + while nodes > lim[0]: + nodes //= 2 + elif cpus_per_node is None: + cpus_per_node = num // nodes + while cpus_per_node > lim[1]: + cpus_per_node //= 2 + return nodes, cpus_per_node + + def format_time(t): try: t = float(t) @@ -25,9 +78,9 @@ def create_parser(): "-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss" ) parser.add_argument( - "-c", "--cpus-per-node", required=True, type=int, help="number of cpus required per node" + "-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node" ) - parser.add_argument("-n", "--nodes", required=True, type=int, help="number of nodes required") + parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required") parser.add_argument( "--environment-setup", required=False, @@ -70,6 +123,8 @@ def main(): sim_num, _ = count_variations(final_config) + args.nodes, args.cpus_per_nodes = distribute(sim_num, args.nodes, args.cpus_per_nodes) + file_name = ( "submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" ) diff --git a/src/scgenerator/utils.py b/src/scgenerator/utils.py index cc66c54..1f31c35 100644 --- a/src/scgenerator/utils.py +++ b/src/scgenerator/utils.py @@ -25,10 +25,6 @@ from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRE from .logger import get_logger from .math import * -# XXX ############################################ -# XXX ############### Pure Python ################ -# XXX ############################################ - class PBars: @classmethod @@ -245,34 +241,6 @@ def format_value(value): return str(value) -# def variable_list_from_path(s: str) -> List[tuple]: -# s = s.replace("/", "") -# str_list = s.split(PARAM_SEPARATOR) -# out = [] -# for i in range(0, len(str_list) // 2 * 2, 2): -# out.append((str_list[i], get_value(str_list[i + 1]))) -# return out - - -# def get_value(s: str): -# if s.lower() == "true": -# return True -# if s.lower() == "false": -# return False - -# try: -# return int(s) -# except ValueError: -# pass - -# try: -# return float(s) -# except ValueError: -# pass - -# return s - - def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: """given a config with "variable" parameters, iterates through every possible combination, yielding a a list of (parameter_name, value) tuples and a full config dictionary. @@ -385,7 +353,7 @@ def parallelize(func, arg_iter, sim_jobs=4, progress_tracker_kwargs=None, const_ return np.array(results) -def deep_update(d: Mapping, u: Mapping): +def deep_update(d: Mapping, u: Mapping) -> dict: for k, v in u.items(): if isinstance(v, collections.abc.Mapping): d[k] = deep_update(d.get(k, {}), v) diff --git a/testing/configs/custom_field/init_field.npz b/testing/configs/custom_field/init_field.npz new file mode 100644 index 0000000..e5687ed Binary files /dev/null and b/testing/configs/custom_field/init_field.npz differ diff --git a/testing/configs/custom_field/mean_power.toml b/testing/configs/custom_field/mean_power.toml new file mode 100644 index 0000000..953ec5c --- /dev/null +++ b/testing/configs/custom_field/mean_power.toml @@ -0,0 +1,8 @@ +dt = 1e-15 +field_file = "testing/configs/custom_field/init_field.npz" +length = 1 +mean_power = 220e-3 +repetition_rate = 40e6 +t_num = 2048 +wavelength = 1000e-9 +z_num = 32 diff --git a/testing/configs/custom_field/no_change.toml b/testing/configs/custom_field/no_change.toml new file mode 100644 index 0000000..288b1fb --- /dev/null +++ b/testing/configs/custom_field/no_change.toml @@ -0,0 +1,6 @@ +dt = 1e-15 +field_file = "testing/configs/custom_field/init_field.npz" +length = 1 +t_num = 2048 +wavelength = 1000e-9 +z_num = 32 diff --git a/testing/configs/custom_field/peak_power.toml b/testing/configs/custom_field/peak_power.toml new file mode 100644 index 0000000..e658d0a --- /dev/null +++ b/testing/configs/custom_field/peak_power.toml @@ -0,0 +1,7 @@ +dt = 1e-15 +field_file = "testing/configs/custom_field/init_field.npz" +length = 1 +peak_power = 20000 +t_num = 2048 +wavelength = 1000e-9 +z_num = 32 diff --git a/testing/configs/custom_field/recover1.toml b/testing/configs/custom_field/recover1.toml new file mode 100644 index 0000000..52400f3 --- /dev/null +++ b/testing/configs/custom_field/recover1.toml @@ -0,0 +1,7 @@ +dt = 1e-15 +input_transmission = 1 +length = 1 +prev_data_dir = "testing/configs/custom_field/recover_data" +t_num = 2048 +wavelength = 1000e-9 +z_num = 32 diff --git a/testing/configs/custom_field/recover2.toml b/testing/configs/custom_field/recover2.toml new file mode 100644 index 0000000..853b8ea --- /dev/null +++ b/testing/configs/custom_field/recover2.toml @@ -0,0 +1,7 @@ +dt = 1e-15 +input_transmission = 0.9 +length = 1 +prev_data_dir = "testing/configs/custom_field/recover_data" +t_num = 2048 +wavelength = 1000e-9 +z_num = 32 diff --git a/testing/configs/ensure_consistency/good1.toml b/testing/configs/ensure_consistency/good1.toml index 194f6b0..c1b778d 100644 --- a/testing/configs/ensure_consistency/good1.toml +++ b/testing/configs/ensure_consistency/good1.toml @@ -5,6 +5,7 @@ name = "test config" [fiber] gamma = 0.018 length = 1 +model = "pcf" pitch = 1.5e-6 pitch_ratio = 0.37 diff --git a/testing/test_initialize.py b/testing/test_initialize.py index 05d5cd2..f438a61 100644 --- a/testing/test_initialize.py +++ b/testing/test_initialize.py @@ -2,8 +2,9 @@ import unittest from copy import deepcopy import scgenerator.initialize as init +import numpy as np import toml -from scgenerator import utils +from scgenerator import defaults, utils, math from scgenerator.errors import * @@ -30,7 +31,7 @@ class TestParamSequence(unittest.TestCase): for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): l = [] s = [] - for vary_list, _ in param_seq.iterate_without_computing(): + for vary_list, _ in utils.required_simulations(param_seq.config): self.assertNotIn(vary_list, l) self.assertNotIn(utils.format_variable_list(vary_list), s) l.append(vary_list) @@ -39,12 +40,12 @@ class TestParamSequence(unittest.TestCase): def test_init_config_not_affected_by_iteration(self): for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): config = deepcopy(param_seq.config) - for _ in param_seq.iterate_without_computing(): + for _ in utils.required_simulations(param_seq.config): self.assertEqual(config.items(), param_seq.config.items()) def test_no_variations_yields_only_num_and_id(self): for param_seq in self.iterconf(["no_variations"]): - for vary_list, _ in param_seq.iterate_without_computing(): + for vary_list, _ in utils.required_simulations(param_seq.config): self.assertEqual(vary_list[1][0], "num") self.assertEqual(vary_list[0][0], "id") self.assertEqual(2, len(vary_list)) @@ -65,18 +66,18 @@ class TestInitializeMethods(unittest.TestCase): with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"): init._validate_types(conf("bad4")) - with self.assertRaisesRegex(TypeError, "Varying parameters should be specified in a list"): + with self.assertRaisesRegex(TypeError, "Variable parameters should be specified in a list"): init._validate_types(conf("bad5")) with self.assertRaisesRegex( TypeError, - "value '0' of type for key 'repeat' is not valid, must be a strictly positive integer", + "value '0' of type .*int.* for key 'repeat' is not valid, must be a strictly positive integer", ): init._validate_types(conf("bad6")) with self.assertRaisesRegex( ValueError, - r"Varying parameters lists should contain at least 1 element", + r"Variable parameters lists should contain at least 1 element", ): init._ensure_consistency(conf("bad7")) @@ -92,7 +93,7 @@ class TestInitializeMethods(unittest.TestCase): with self.assertRaisesRegex( MissingParameterError, - r"1 of '\['peak_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set", + r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set", ): init._ensure_consistency(conf("bad2")) @@ -156,14 +157,49 @@ class TestInitializeMethods(unittest.TestCase): dict( t_num=16384, time_window=37e-12, - lower_wavelength_interp_limit=0, - upper_wavelength_interp_limit=1900e-9, + lower_wavelength_interp_limit=defaults.default_parameters[ + "lower_wavelength_interp_limit" + ], + upper_wavelength_interp_limit=defaults.default_parameters[ + "upper_wavelength_interp_limit" + ], ).items(), init._ensure_consistency(conf("good6"))["simulation"].items(), ) - # def test_compute_init_parameters(self): - # conf = lambda s: load_conf("compute_init_parameters/" + s) + def test_setup_custom_field(self): + d = np.load("testing/configs/custom_field/init_field.npz") + t = d["time"] + field = d["field"] + conf = load_conf("custom_field/no_change") + conf = init._generate_sim_grid(conf) + result = init.setup_custom_field(conf) + self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4) + self.assertTrue(result) + + conf = load_conf("custom_field/peak_power") + conf = init._generate_sim_grid(conf) + result = init.setup_custom_field(conf) + self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4) + self.assertTrue(result) + + conf = load_conf("custom_field/mean_power") + conf = init._generate_sim_grid(conf) + result = init.setup_custom_field(conf) + self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4) + self.assertTrue(result) + + conf = load_conf("custom_field/recover1") + conf = init._generate_sim_grid(conf) + result = init.setup_custom_field(conf) + self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0) + self.assertTrue(result) + + conf = load_conf("custom_field/recover2") + conf = init._generate_sim_grid(conf) + result = init.setup_custom_field(conf) + self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0) + self.assertTrue(result) if __name__ == "__main__":