better submit, more tests

This commit is contained in:
Benoît Sierro
2021-06-01 10:47:57 +02:00
parent 2fcd277563
commit e985f053ac
12 changed files with 176 additions and 69 deletions

View File

@@ -675,12 +675,25 @@ def compute_init_parameters(config: Dict[str, Any]) -> Dict[str, Any]:
def setup_custom_field(params: Dict[str, Any]) -> bool: def setup_custom_field(params: Dict[str, Any]) -> bool:
"""sets up a custom field function if necessary and returns
True if it did so, False otherwise
Parameters
----------
params : Dict[str, Any]
params dictionary
Returns
-------
bool
True if the field has been modified
"""
logger = get_logger(__name__) logger = get_logger(__name__)
custom_field = True
if "prev_data_dir" in params: if "prev_data_dir" in params:
spec = io.load_last_spectrum(Path(params["prev_data_dir"]))[1] spec = io.load_last_spectrum(Path(params["prev_data_dir"]))[1]
params["field_0"] = np.fft.ifft(spec) * params["input_transmission"] params["field_0"] = np.fft.ifft(spec) * np.sqrt(params["input_transmission"])
elif "field_file" in params: else:
if "field_file" in params:
field_data = np.load(params["field_file"]) field_data = np.load(params["field_file"])
field_interp = interp1d( field_interp = interp1d(
field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0) field_data["time"], field_data["field"], bounds_error=False, fill_value=(0, 0)
@@ -689,9 +702,8 @@ def setup_custom_field(params: Dict[str, Any]) -> bool:
elif "field_0" in params: elif "field_0" in params:
params = _evalutate_custom_field_equation(params) params = _evalutate_custom_field_equation(params)
else: else:
custom_field = False return False
if custom_field:
params["field_0"] = params["field_0"] * pulse.modify_field_ratio( params["field_0"] = params["field_0"] * pulse.modify_field_ratio(
params["t"], params["t"],
params["field_0"], params["field_0"],
@@ -706,7 +718,7 @@ def setup_custom_field(params: Dict[str, Any]) -> bool:
logger.debug(f"had to adjust w by {delta_w}") logger.debug(f"had to adjust w by {delta_w}")
params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w) params["wavelength"] = units.m.inv(units.m(params["wavelength"]) - delta_w)
_update_frequency_domain(params) _update_frequency_domain(params)
return custom_field return True
def _update_pulse_parameters(params): def _update_pulse_parameters(params):

View File

@@ -11,7 +11,7 @@ n is the number of spectra at the same z position and nt is the size of the time
import itertools import itertools
import os import os
from typing import Tuple from typing import Literal, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@@ -115,14 +115,14 @@ def modify_field_ratio(
def conform_pulse_params( def conform_pulse_params(
shape, shape: Literal["gaussian", "sech"],
width=None, width: float = None,
t0=None, t0: float = None,
peak_power=None, peak_power: float = None,
energy=None, energy: float = None,
soliton_num=None, soliton_num: float = None,
gamma=None, gamma: float = None,
beta2=None, beta2: float = None,
): ):
"""makes sure all parameters of the pulse are set and consistent """makes sure all parameters of the pulse are set and consistent

View File

@@ -4,12 +4,65 @@ import re
import shutil import shutil
import subprocess import subprocess
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Tuple
import numpy as np
from ..initialize import validate_config_sequence from ..initialize import validate_config_sequence
from ..io import Paths from ..io import Paths
from ..utils import count_variations from ..utils import count_variations
def primes(n):
primfac = []
d = 2
while d * d <= n:
while (n % d) == 0:
primfac.append(d) # supposing you want multiple factors repeated
n //= d
d += 1
if n > 1:
primfac.append(n)
return primfac
def balance(n, lim=(32, 32)):
factors = primes(n)
if len(factors) == 1:
factors = primes(n + 1)
a, b, x, y = 1, 1, 1, 1
while len(factors) > 0 and x <= lim[0] and y <= lim[1]:
a = x
b = y
if y >= x:
x *= factors.pop(0)
else:
y *= factors.pop()
return a, b
def distribute(
num: int, nodes: int = None, cpus_per_node: int = None, lim=(16, 32)
) -> Tuple[int, int]:
if nodes is None and cpus_per_node is None:
balanced = balance(num, lim)
if num > max(lim):
while np.product(balanced) < min(lim):
num += 1
balanced = balance(num, lim)
nodes = min(balanced)
cpus_per_node = max(balanced)
elif nodes is None:
nodes = num // cpus_per_node
while nodes > lim[0]:
nodes //= 2
elif cpus_per_node is None:
cpus_per_node = num // nodes
while cpus_per_node > lim[1]:
cpus_per_node //= 2
return nodes, cpus_per_node
def format_time(t): def format_time(t):
try: try:
t = float(t) t = float(t)
@@ -25,9 +78,9 @@ def create_parser():
"-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss" "-t", "--time", required=True, type=str, help="time required for the job in hh:mm:ss"
) )
parser.add_argument( parser.add_argument(
"-c", "--cpus-per-node", required=True, type=int, help="number of cpus required per node" "-c", "--cpus-per-node", default=None, type=int, help="number of cpus required per node"
) )
parser.add_argument("-n", "--nodes", required=True, type=int, help="number of nodes required") parser.add_argument("-n", "--nodes", default=None, type=int, help="number of nodes required")
parser.add_argument( parser.add_argument(
"--environment-setup", "--environment-setup",
required=False, required=False,
@@ -70,6 +123,8 @@ def main():
sim_num, _ = count_variations(final_config) sim_num, _ = count_variations(final_config)
args.nodes, args.cpus_per_nodes = distribute(sim_num, args.nodes, args.cpus_per_nodes)
file_name = ( file_name = (
"submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh" "submit " + final_config["name"] + "-" + format(datetime.now(), "%Y%m%d%H%M") + ".sh"
) )

View File

@@ -25,10 +25,6 @@ from .const import PARAM_SEPARATOR, PREFIX_KEY_BASE, valid_variable, HUSH_PROGRE
from .logger import get_logger from .logger import get_logger
from .math import * from .math import *
# XXX ############################################
# XXX ############### Pure Python ################
# XXX ############################################
class PBars: class PBars:
@classmethod @classmethod
@@ -245,34 +241,6 @@ def format_value(value):
return str(value) return str(value)
# def variable_list_from_path(s: str) -> List[tuple]:
# s = s.replace("/", "")
# str_list = s.split(PARAM_SEPARATOR)
# out = []
# for i in range(0, len(str_list) // 2 * 2, 2):
# out.append((str_list[i], get_value(str_list[i + 1])))
# return out
# def get_value(s: str):
# if s.lower() == "true":
# return True
# if s.lower() == "false":
# return False
# try:
# return int(s)
# except ValueError:
# pass
# try:
# return float(s)
# except ValueError:
# pass
# return s
def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]: def variable_iterator(config) -> Iterator[Tuple[List[Tuple[str, Any]], dict]]:
"""given a config with "variable" parameters, iterates through every possible combination, """given a config with "variable" parameters, iterates through every possible combination,
yielding a a list of (parameter_name, value) tuples and a full config dictionary. yielding a a list of (parameter_name, value) tuples and a full config dictionary.
@@ -385,7 +353,7 @@ def parallelize(func, arg_iter, sim_jobs=4, progress_tracker_kwargs=None, const_
return np.array(results) return np.array(results)
def deep_update(d: Mapping, u: Mapping): def deep_update(d: Mapping, u: Mapping) -> dict:
for k, v in u.items(): for k, v in u.items():
if isinstance(v, collections.abc.Mapping): if isinstance(v, collections.abc.Mapping):
d[k] = deep_update(d.get(k, {}), v) d[k] = deep_update(d.get(k, {}), v)

Binary file not shown.

View File

@@ -0,0 +1,8 @@
dt = 1e-15
field_file = "testing/configs/custom_field/init_field.npz"
length = 1
mean_power = 220e-3
repetition_rate = 40e6
t_num = 2048
wavelength = 1000e-9
z_num = 32

View File

@@ -0,0 +1,6 @@
dt = 1e-15
field_file = "testing/configs/custom_field/init_field.npz"
length = 1
t_num = 2048
wavelength = 1000e-9
z_num = 32

View File

@@ -0,0 +1,7 @@
dt = 1e-15
field_file = "testing/configs/custom_field/init_field.npz"
length = 1
peak_power = 20000
t_num = 2048
wavelength = 1000e-9
z_num = 32

View File

@@ -0,0 +1,7 @@
dt = 1e-15
input_transmission = 1
length = 1
prev_data_dir = "testing/configs/custom_field/recover_data"
t_num = 2048
wavelength = 1000e-9
z_num = 32

View File

@@ -0,0 +1,7 @@
dt = 1e-15
input_transmission = 0.9
length = 1
prev_data_dir = "testing/configs/custom_field/recover_data"
t_num = 2048
wavelength = 1000e-9
z_num = 32

View File

@@ -5,6 +5,7 @@ name = "test config"
[fiber] [fiber]
gamma = 0.018 gamma = 0.018
length = 1 length = 1
model = "pcf"
pitch = 1.5e-6 pitch = 1.5e-6
pitch_ratio = 0.37 pitch_ratio = 0.37

View File

@@ -2,8 +2,9 @@ import unittest
from copy import deepcopy from copy import deepcopy
import scgenerator.initialize as init import scgenerator.initialize as init
import numpy as np
import toml import toml
from scgenerator import utils from scgenerator import defaults, utils, math
from scgenerator.errors import * from scgenerator.errors import *
@@ -30,7 +31,7 @@ class TestParamSequence(unittest.TestCase):
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
l = [] l = []
s = [] s = []
for vary_list, _ in param_seq.iterate_without_computing(): for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertNotIn(vary_list, l) self.assertNotIn(vary_list, l)
self.assertNotIn(utils.format_variable_list(vary_list), s) self.assertNotIn(utils.format_variable_list(vary_list), s)
l.append(vary_list) l.append(vary_list)
@@ -39,12 +40,12 @@ class TestParamSequence(unittest.TestCase):
def test_init_config_not_affected_by_iteration(self): def test_init_config_not_affected_by_iteration(self):
for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]): for param_seq in self.iterconf(["almost_equal", "equal", "no_variations"]):
config = deepcopy(param_seq.config) config = deepcopy(param_seq.config)
for _ in param_seq.iterate_without_computing(): for _ in utils.required_simulations(param_seq.config):
self.assertEqual(config.items(), param_seq.config.items()) self.assertEqual(config.items(), param_seq.config.items())
def test_no_variations_yields_only_num_and_id(self): def test_no_variations_yields_only_num_and_id(self):
for param_seq in self.iterconf(["no_variations"]): for param_seq in self.iterconf(["no_variations"]):
for vary_list, _ in param_seq.iterate_without_computing(): for vary_list, _ in utils.required_simulations(param_seq.config):
self.assertEqual(vary_list[1][0], "num") self.assertEqual(vary_list[1][0], "num")
self.assertEqual(vary_list[0][0], "id") self.assertEqual(vary_list[0][0], "id")
self.assertEqual(2, len(vary_list)) self.assertEqual(2, len(vary_list))
@@ -65,18 +66,18 @@ class TestInitializeMethods(unittest.TestCase):
with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"): with self.assertRaisesRegex(TypeError, "'parallel' is not a valid variable parameter"):
init._validate_types(conf("bad4")) init._validate_types(conf("bad4"))
with self.assertRaisesRegex(TypeError, "Varying parameters should be specified in a list"): with self.assertRaisesRegex(TypeError, "Variable parameters should be specified in a list"):
init._validate_types(conf("bad5")) init._validate_types(conf("bad5"))
with self.assertRaisesRegex( with self.assertRaisesRegex(
TypeError, TypeError,
"value '0' of type <class 'int'> for key 'repeat' is not valid, must be a strictly positive integer", "value '0' of type .*int.* for key 'repeat' is not valid, must be a strictly positive integer",
): ):
init._validate_types(conf("bad6")) init._validate_types(conf("bad6"))
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
r"Varying parameters lists should contain at least 1 element", r"Variable parameters lists should contain at least 1 element",
): ):
init._ensure_consistency(conf("bad7")) init._ensure_consistency(conf("bad7"))
@@ -92,7 +93,7 @@ class TestInitializeMethods(unittest.TestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
MissingParameterError, MissingParameterError,
r"1 of '\['peak_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set", r"1 of '\['peak_power', 'mean_power', 'energy', 'width', 't0'\]' is required when 'soliton_num' is specified and no defaults have been set",
): ):
init._ensure_consistency(conf("bad2")) init._ensure_consistency(conf("bad2"))
@@ -156,14 +157,49 @@ class TestInitializeMethods(unittest.TestCase):
dict( dict(
t_num=16384, t_num=16384,
time_window=37e-12, time_window=37e-12,
lower_wavelength_interp_limit=0, lower_wavelength_interp_limit=defaults.default_parameters[
upper_wavelength_interp_limit=1900e-9, "lower_wavelength_interp_limit"
],
upper_wavelength_interp_limit=defaults.default_parameters[
"upper_wavelength_interp_limit"
],
).items(), ).items(),
init._ensure_consistency(conf("good6"))["simulation"].items(), init._ensure_consistency(conf("good6"))["simulation"].items(),
) )
# def test_compute_init_parameters(self): def test_setup_custom_field(self):
# conf = lambda s: load_conf("compute_init_parameters/" + s) d = np.load("testing/configs/custom_field/init_field.npz")
t = d["time"]
field = d["field"]
conf = load_conf("custom_field/no_change")
conf = init._generate_sim_grid(conf)
result = init.setup_custom_field(conf)
self.assertAlmostEqual(conf["field_0"].real.max(), field.real.max(), 4)
self.assertTrue(result)
conf = load_conf("custom_field/peak_power")
conf = init._generate_sim_grid(conf)
result = init.setup_custom_field(conf)
self.assertAlmostEqual(math.abs2(conf["field_0"]).max(), 20000, 4)
self.assertTrue(result)
conf = load_conf("custom_field/mean_power")
conf = init._generate_sim_grid(conf)
result = init.setup_custom_field(conf)
self.assertAlmostEqual(np.trapz(math.abs2(conf["field_0"]), conf["t"]), 0.22 / 40e6, 4)
self.assertTrue(result)
conf = load_conf("custom_field/recover1")
conf = init._generate_sim_grid(conf)
result = init.setup_custom_field(conf)
self.assertAlmostEqual(math.abs2(conf["field_0"] - field).sum(), 0)
self.assertTrue(result)
conf = load_conf("custom_field/recover2")
conf = init._generate_sim_grid(conf)
result = init.setup_custom_field(conf)
self.assertAlmostEqual((math.abs2(conf["field_0"]) / 0.9 - math.abs2(field)).sum(), 0)
self.assertTrue(result)
if __name__ == "__main__": if __name__ == "__main__":