better submit, more tests
This commit is contained in:
@@ -675,23 +675,35 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def setup_custom_field(params: Dict[str, Any]) -> bool:
|
||||
"""sets up a custom field function if necessary and returns
|
||||
True if it did so, False otherwise
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : Dict[str, Any]
|
||||
params dictionary
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the field has been modified
|
||||
"""
|
||||
logger = get_logger(__name__)
|
||||
custom_field = True
|
||||
if "prev_data_dir" in params:
|
||||
spec = io.load_last_spectrum(Path(params["prev_data_dir"]))[1]
|
||||
params["field_0"] = np.fft.ifft(spec) * params["input_transmission"]
|
||||
elif "field_file" in params:
|
||||
field_data = np.load(params["field_file"])
|
||||
field_interp = interp1d(
|
||||
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
||||
)
|
||||
params["field_0"] = field_interp(params["t"])
|
||||
elif "field_0" in params:
|
||||
params = _evalutate_custom_field_equation(params)
|
||||
params["field_0"] = np.fft.ifft(spec) * np.sqrt(params["input_transmission"])
|
||||
else:
|
||||
custom_field = False
|
||||
if "field_file" in params:
|
||||
field_data = np.load(params["field_file"])
|
||||
field_interp = interp1d(
|
||||
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
|
||||
)
|
||||
params["field_0"] = field_interp(params["t"])
|
||||
elif "field_0" in params:
|
||||
params = _evalutate_custom_field_equation(params)
|
||||
else:
|
||||
return False
|
||||
|
||||
if custom_field:
|
||||
params["field_0"] = params["field_0"] * pulse.modify_field_ratio(
|
||||
params["t"],
|
||||
params["field_0"],
|
||||
@@ -706,7 +718,7 @@ def setup_custom_field(params: Dict[str, Any]) -> bool:
|
||||
logger.debug(f"had to adjust w by {delta_w}")
|
||||
params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w)
|
||||
_update_frequency_domain(params)
|
||||
return custom_field
|
||||
return True
|
||||
|
||||
|
||||
def _update_pulse_parameters(params):
|
||||
|
||||
@@ -11,7 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time
|
||||
|
||||
import itertools
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@@ -115,14 +115,14 @@ def modify_field_ratio(
|
||||
|
||||
|
||||
def conform_pulse_params(
|
||||
shape,
|
||||
width=None,
|
||||
t0=None,
|
||||
peak_power=None,
|
||||
energy=None,
|
||||
soliton_num=None,
|
||||
gamma=None,
|
||||
beta2=None,
|
||||
shape: Literal["gaussian", "sech"],
|
||||
width: float = None,
|
||||
t0: float = None,
|
||||
peak_power: float = None,
|
||||
energy: float = None,
|
||||
soliton_num: float = None,
|
||||
gamma: float = None,
|
||||
beta2: float = None,
|
||||
):
|
||||
"""makes sure all parameters of the pulse are set and consistent
|
||||
|
||||
|
||||
@@ -4,12 +4,65 @@ import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
|
||||
from ..initialize import validate_config_sequence
|
||||
from ..io import Paths
|
||||
from ..utils import count_variations
|
||||
|
||||
|
||||
def primes(n):
|
||||
primfac = []
|
||||
d = 2
|
||||
while d * d <= n:
|
||||
while (n % d) == 0:
|
||||
primfac.append(d) # supposing you want multiple factors repeated
|
||||
n //= d
|
||||
d += 1
|
||||
if n > 1:
|
||||
primfac.append(n)
|
||||
return primfac
|
||||
|
||||
|
||||
def balance(n, lim=(32, 32)):
|
||||
factors = primes(n)
|
||||
if len(factors) == 1:
|
||||
factors = primes(n + 1)
|
||||
a, b, x, y = 1, 1, 1, 1
|
||||
while len(factors) > 0 and x <= lim[0] and y <= lim[1]:
|
||||
a = x
|
||||
b = y
|
||||
if y >= x:
|
||||
x *= factors.pop(0)
|
||||
else:
|
||||
y *= factors.pop()
|
||||
return a, b
|
||||
|
||||
|
||||
def distribute(
|
||||
num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32)
|
||||
) -> Tuple[int, int]:
|
||||
if nodes is None and cpus_per_node is None:
|
||||
balanced = balance(num, lim)
|
||||
if num > max(lim):
|
||||
while np.product(balanced) < min(lim):
|
||||
num += 1
|
||||
balanced = balance(num, lim)
|
||||
nodes = min(balanced)
|
||||
cpus_per_node = max(balanced)
|
||||
|
||||
elif nodes is None:
|
||||
nodes = num // cpus_per_node
|
||||
while nodes > lim[0]:
|
||||
nodes //= 2
|
||||
elif cpus_per_node is None:
|
||||
cpus_per_node = num // nodes
|
||||
while cpus_per_node > lim[1]:
|
||||
cpus_per_node //= 2
|
||||
return nodes, cpus_per_node
|
||||
|
||||
|
||||
def format_time(t):
|
||||
try:
|
||||
t = float(t)
|
||||
@@ -25,9 +78,9 @@ def create_parser():
|
||||
"-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--cpus-per-node", required=True, type=int, help="number of cpus required per node"
|
||||
"-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node"
|
||||
)
|
||||
parser.add_argument("-n", "--nodes", required=True, type=int, help="number of nodes required")
|
||||
parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required")
|
||||
parser.add_argument(
|
||||
"--environment-setup",
|
||||
required=False,
|
||||
@@ -70,6 +123,8 @@ def main():
|
||||
|
||||
sim_num, _ = count_variations(final_config)
|
||||
|
||||
args.nodes, args.cpus_per_nodes = distribute(sim_num, args.nodes, args.cpus_per_nodes)
|
||||
|
||||
file_name = (
|
||||
"submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
|
||||
)
|
||||
|
||||
@@ -25,10 +25,6 @@ from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRE
|
||||
from .logger import get_logger
|
||||
from .math import *
|
||||
|
||||
# XXX ############################################
|
||||
# XXX ############### Pure Python ################
|
||||
# XXX ############################################
|
||||
|
||||
|
||||
class PBars:
|
||||
@classmethod
|
||||
@@ -245,34 +241,6 @@ def format_value(value):
|
||||
return str(value)
|
||||
|
||||
|
||||
# def variable_list_from_path(s: str) -> List[tuple]:
|
||||
# s = s.replace("/", "")
|
||||
# str_list = s.split(PARAM_SEPARATOR)
|
||||
# out = []
|
||||
# for i in range(0, len(str_list) // 2 * 2, 2):
|
||||
# out.append((str_list[i], get_value(str_list[i + 1])))
|
||||
# return out
|
||||
|
||||
|
||||
# def get_value(s: str):
|
||||
# if s.lower() == "true":
|
||||
# return True
|
||||
# if s.lower() == "false":
|
||||
# return False
|
||||
|
||||
# try:
|
||||
# return int(s)
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# return float(s)
|
||||
# except ValueError:
|
||||
# pass
|
||||
|
||||
# return s
|
||||
|
||||
|
||||
def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
|
||||
"""given a config with "variable" parameters, iterates through every possible combination,
|
||||
yielding a a list of (parameter_name, value) tuples and a full config dictionary.
|
||||
@@ -385,7 +353,7 @@ def parallelize(func, arg_iter, sim_jobs=4, progress_tracker_kwargs=None, const_
|
||||
return np.array(results)
|
||||
|
||||
|
||||
def deep_update(d: Mapping, u: Mapping):
|
||||
def deep_update(d: Mapping, u: Mapping) -> dict:
|
||||
for k, v in u.items():
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
d[k] = deep_update(d.get(k, {}), v)
|
||||
|
||||
BIN
testing/configs/custom_field/init_field.npz
Normal file
BIN
testing/configs/custom_field/init_field.npz
Normal file
Binary file not shown.
8
testing/configs/custom_field/mean_power.toml
Normal file
8
testing/configs/custom_field/mean_power.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
dt = 1e-15
|
||||
field_file = "testing/configs/custom_field/init_field.npz"
|
||||
length = 1
|
||||
mean_power = 220e-3
|
||||
repetition_rate = 40e6
|
||||
t_num = 2048
|
||||
wavelength = 1000e-9
|
||||
z_num = 32
|
||||
6
testing/configs/custom_field/no_change.toml
Normal file
6
testing/configs/custom_field/no_change.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
dt = 1e-15
|
||||
field_file = "testing/configs/custom_field/init_field.npz"
|
||||
length = 1
|
||||
t_num = 2048
|
||||
wavelength = 1000e-9
|
||||
z_num = 32
|
||||
7
testing/configs/custom_field/peak_power.toml
Normal file
7
testing/configs/custom_field/peak_power.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
dt = 1e-15
|
||||
field_file = "testing/configs/custom_field/init_field.npz"
|
||||
length = 1
|
||||
peak_power = 20000
|
||||
t_num = 2048
|
||||
wavelength = 1000e-9
|
||||
z_num = 32
|
||||
7
testing/configs/custom_field/recover1.toml
Normal file
7
testing/configs/custom_field/recover1.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
dt = 1e-15
|
||||
input_transmission = 1
|
||||
length = 1
|
||||
prev_data_dir = "testing/configs/custom_field/recover_data"
|
||||
t_num = 2048
|
||||
wavelength = 1000e-9
|
||||
z_num = 32
|
||||
7
testing/configs/custom_field/recover2.toml
Normal file
7
testing/configs/custom_field/recover2.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
dt = 1e-15
|
||||
input_transmission = 0.9
|
||||
length = 1
|
||||
prev_data_dir = "testing/configs/custom_field/recover_data"
|
||||
t_num = 2048
|
||||
wavelength = 1000e-9
|
||||
z_num = 32
|
||||
@@ -5,6 +5,7 @@ name = "test config"
|
||||
[fiber]
|
||||
gamma = 0.018
|
||||
length = 1
|
||||
model = "pcf"
|
||||
pitch = 1.5e-6
|
||||
pitch_ratio = 0.37
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@ import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import scgenerator.initialize as init
|
||||
import numpy as np
|
||||
import toml
|
||||
from scgenerator import utils
|
||||
from scgenerator import defaults, utils, math
|
||||
from scgenerator.errors import *
|
||||
|
||||
|
||||
@@ -30,7 +31,7 @@ class TestParamSequence(unittest.TestCase):
|
||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
||||
l = []
|
||||
s = []
|
||||
for vary_list, _ in param_seq.iterate_without_computing():
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertNotIn(vary_list, l)
|
||||
self.assertNotIn(utils.format_variable_list(vary_list), s)
|
||||
l.append(vary_list)
|
||||
@@ -39,12 +40,12 @@ class TestParamSequence(unittest.TestCase):
|
||||
def test_init_config_not_affected_by_iteration(self):
|
||||
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
|
||||
config = deepcopy(param_seq.config)
|
||||
for _ in param_seq.iterate_without_computing():
|
||||
for _ in utils.required_simulations(param_seq.config):
|
||||
self.assertEqual(config.items(), param_seq.config.items())
|
||||
|
||||
def test_no_variations_yields_only_num_and_id(self):
|
||||
for param_seq in self.iterconf(["no_variations"]):
|
||||
for vary_list, _ in param_seq.iterate_without_computing():
|
||||
for vary_list, _ in utils.required_simulations(param_seq.config):
|
||||
self.assertEqual(vary_list[1][0], "num")
|
||||
self.assertEqual(vary_list[0][0], "id")
|
||||
self.assertEqual(2, len(vary_list))
|
||||
@@ -65,18 +66,18 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
|
||||
init._validate_types(conf("bad4"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Varying parameters should be specified in a list"):
|
||||
with self.assertRaisesRegex(TypeError, "Variable parameters should be specified in a list"):
|
||||
init._validate_types(conf("bad5"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"value '0' of type <class 'int'> for key 'repeat' is not valid, must be a strictly positive integer",
|
||||
"value '0' of type .*int.* for key 'repeat' is not valid, must be a strictly positive integer",
|
||||
):
|
||||
init._validate_types(conf("bad6"))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"Varying parameters lists should contain at least 1 element",
|
||||
r"Variable parameters lists should contain at least 1 element",
|
||||
):
|
||||
init._ensure_consistency(conf("bad7"))
|
||||
|
||||
@@ -92,7 +93,7 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
MissingParameterError,
|
||||
r"1 of '\['peak_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
||||
r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
|
||||
):
|
||||
init._ensure_consistency(conf("bad2"))
|
||||
|
||||
@@ -156,14 +157,49 @@ class TestInitializeMethods(unittest.TestCase):
|
||||
dict(
|
||||
t_num=16384,
|
||||
time_window=37e-12,
|
||||
lower_wavelength_interp_limit=0,
|
||||
upper_wavelength_interp_limit=1900e-9,
|
||||
lower_wavelength_interp_limit=defaults.default_parameters[
|
||||
"lower_wavelength_interp_limit"
|
||||
],
|
||||
upper_wavelength_interp_limit=defaults.default_parameters[
|
||||
"upper_wavelength_interp_limit"
|
||||
],
|
||||
).items(),
|
||||
init._ensure_consistency(conf("good6"))["simulation"].items(),
|
||||
)
|
||||
|
||||
# def test_compute_init_parameters(self):
|
||||
# conf = lambda s: load_conf("compute_init_parameters/" + s)
|
||||
def test_setup_custom_field(self):
|
||||
d = np.load("testing/configs/custom_field/init_field.npz")
|
||||
t = d["time"]
|
||||
field = d["field"]
|
||||
conf = load_conf("custom_field/no_change")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/peak_power")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/mean_power")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/recover1")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
conf = load_conf("custom_field/recover2")
|
||||
conf = init._generate_sim_grid(conf)
|
||||
result = init.setup_custom_field(conf)
|
||||
self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0)
|
||||
self.assertTrue(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user