parameters are no longer auto-computed

This commit is contained in:
Benoît Sierro
2021-09-02 15:53:30 +02:00
parent 15eafd4781
commit fb23786c70
8 changed files with 74 additions and 28 deletions

View File

@@ -236,6 +236,9 @@ quantum_noise: bool
intensity_noise : float intensity_noise : float
relative intensity noise relative intensity noise
noise_correlation : float
correlation between intensity noise and pulse width noise. a negative value means anti-correlation
shape: str {"gaussian", "sech"} shape: str {"gaussian", "sech"}
shape of the pulse. default : gaussian shape of the pulse. default : gaussian

View File

@@ -118,6 +118,13 @@ def create_parser():
) )
init_plot_parser.set_defaults(func=plot_init) init_plot_parser.set_defaults(func=plot_init)
convert_parser = subparsers.add_parser(
"convert",
help="convert parameter files that have been saved with an older version of the program",
)
convert_parser.add_argument("config", help="path to config/parameter file")
convert_parser.set_defaults(func=translate_parameters)
return parser return parser
@@ -224,5 +231,10 @@ def plot_dispersion(args):
scripts.plot_dispersion(args.config, lims) scripts.plot_dispersion(args.config, lims)
def translate_parameters(args):
path = args.config
scripts.convert_params(path)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,4 +1,4 @@
__version__ = "0.2.1rules" __version__ = "0.2.2rules"
from typing import Any from typing import Any

View File

@@ -1,4 +1,3 @@
from send2trash import send2trash
import multiprocessing import multiprocessing
import multiprocessing.connection import multiprocessing.connection
import os import os
@@ -8,6 +7,7 @@ from pathlib import Path
from typing import Any, Generator, Type from typing import Any, Generator, Type
import numpy as np import numpy as np
from send2trash import send2trash
from .. import env, utils from .. import env, utils
from ..logger import get_logger from ..logger import get_logger
@@ -638,20 +638,16 @@ class RaySimulations(Simulations, priority=2):
) )
) )
self.propagator = ray.remote(RayRK4IP).options(runtime_env=dict(env_vars=env.all_environ())) self.propagator = ray.remote(RayRK4IP)
self.update_cluster_frequency = 3 self.update_cluster_frequency = 3
self.jobs = [] self.jobs = []
self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total)) self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total))
self.num_submitted = 0 self.num_submitted = 0
self.rolling_id = 0 self.rolling_id = 0
self.p_actor = ( self.p_actor = ray.remote(utils.ProgressBarActor).remote(
ray.remote(utils.ProgressBarActor)
.options(runtime_env=dict(env_vars=env.all_environ()))
.remote(
self.configuration.name, self.sim_jobs_total, self.configuration.total_num_steps self.configuration.name, self.sim_jobs_total, self.configuration.total_num_steps
) )
)
self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num)) self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num))
def new_sim(self, v_list_str: str, params: Parameters): def new_sim(self, v_list_str: str, params: Parameters):

View File

@@ -1,4 +1,5 @@
import itertools import itertools
import os
from itertools import cycle from itertools import cycle
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Optional from typing import Any, Iterable, Optional
@@ -9,11 +10,11 @@ from cycler import cycler
from tqdm import tqdm from tqdm import tqdm
from .. import env, math from .. import env, math
from ..const import PARAM_SEPARATOR from ..const import PARAM_FN, PARAM_SEPARATOR
from ..physics import fiber, units from ..physics import fiber, units
from ..plotting import plot_setup from ..plotting import plot_setup
from ..spectra import Pulse from ..spectra import Pulse
from ..utils import auto_crop, load_toml from ..utils import auto_crop, load_toml, save_toml, translate_parameters
from ..utils.parameter import ( from ..utils.parameter import (
Configuration, Configuration,
Parameters, Parameters,
@@ -262,3 +263,18 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters
for style, (variables, params) in zip(cc, pseq): for style, (variables, params) in zip(cc, pseq):
lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]]
yield style, lbl, params yield style, lbl, params
def convert_params(params_file: os.PathLike):
p = Path(params_file)
if p.name == PARAM_FN:
d = load_toml(params_file)
d = translate_parameters(d)
save_toml(params_file, d)
print(f"converted {p}")
else:
for pp in p.glob(PARAM_FN):
convert_params(pp)
for pp in p.glob("fiber*"):
if pp.is_dir():
convert_params(pp)

View File

@@ -150,6 +150,7 @@ class Pulse(Sequence):
raise FileNotFoundError(f"Folder {self.path} does not exist") raise FileNotFoundError(f"Folder {self.path} does not exist")
self.params = Parameters.load(self.path / "params.toml") self.params = Parameters.load(self.path / "params.toml")
self.params.compute(["t", "l", "w_c", "w0", "z_targets"])
try: try:
self.z = np.load(os.path.join(path, "z.npy")) self.z = np.load(os.path.join(path, "z.npy"))

View File

@@ -14,7 +14,6 @@ import re
import shutil import shutil
import threading import threading
from collections import abc from collections import abc
from copy import deepcopy
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from string import printable as str_printable from string import printable as str_printable
@@ -26,8 +25,7 @@ import toml
from tqdm import tqdm from tqdm import tqdm
from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__
from ..env import TMP_FOLDER_KEY_BASE, data_folder, pbar_policy from ..env import pbar_policy
from ..errors import IncompleteDataFolderError
from ..logger import get_logger from ..logger import get_logger
T_ = TypeVar("T_") T_ = TypeVar("T_")
@@ -126,8 +124,11 @@ def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str,
fiber_list = loaded_config.pop("Fiber") fiber_list = loaded_config.pop("Fiber")
configs = [] configs = []
if fiber_list is not None: if fiber_list is not None:
master_variable = loaded_config.get("variable", {})
for i, params in enumerate(fiber_list): for i, params in enumerate(fiber_list):
params.setdefault("variable", loaded_config.get("variable", {}) if i == 0 else {}) params.setdefault("variable", master_variable if i == 0 else {})
if i == 0:
params["variable"] |= master_variable
configs.append(loaded_config | params) configs.append(loaded_config | params)
else: else:
configs.append(loaded_config) configs.append(loaded_config)
@@ -618,11 +619,21 @@ def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray
def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: def translate_parameters(d: dict[str, Any]) -> dict[str, Any]:
old_names = dict(interp_degree="interpolation_degree") old_names = dict(
interp_degree="interpolation_degree",
beta="beta2_coefficients",
interp_range="interpolation_range",
)
deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"}
defaults_to_add = dict(repeat=1)
new = {} new = {}
for k, v in d.items(): for k, v in d.items():
if isinstance(v, MutableMapping): if k == "error_ok":
new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v
elif k in deleted_names:
continue
elif isinstance(v, MutableMapping):
new[k] = translate_parameters(v) new[k] = translate_parameters(v)
else: else:
new[old_names.get(k, k)] = v new[old_names.get(k, k)] = v
return new return defaults_to_add | new

View File

@@ -67,6 +67,7 @@ VALID_VARIABLE = {
"step_size", "step_size",
"interpolation_degree", "interpolation_degree",
"ideal_gas", "ideal_gas",
"length",
} }
MANDATORY_PARAMETERS = [ MANDATORY_PARAMETERS = [
@@ -91,6 +92,7 @@ MANDATORY_PARAMETERS = [
"dynamic_dispersion", "dynamic_dispersion",
"recovery_last_stored", "recovery_last_stored",
"output_path", "output_path",
"repeat",
] ]
@@ -428,11 +430,11 @@ class Parameters:
param["version"] = __version__ param["version"] = __version__
return param return param
def __post_init__(self): def compute(self, to_compute: list[str] = MANDATORY_PARAMETERS):
param_dict = {k: v for k, v in asdict(self).items() if v is not None} param_dict = {k: v for k, v in asdict(self).items() if v is not None}
evaluator = Evaluator.default() evaluator = Evaluator.default()
evaluator.set(**param_dict) evaluator.set(**param_dict)
for p_name in MANDATORY_PARAMETERS: for p_name in to_compute:
evaluator.compute(p_name) evaluator.compute(p_name)
valid_fields = self.all_parameters() valid_fields = self.all_parameters()
for k, v in evaluator.params.items(): for k, v in evaluator.params.items():
@@ -447,6 +449,12 @@ class Parameters:
def load(cls, path: os.PathLike) -> "Parameters": def load(cls, path: os.PathLike) -> "Parameters":
return cls(**utils.load_toml(path)) return cls(**utils.load_toml(path))
@classmethod
def load_and_compute(cls, path: os.PathLike) -> "Parameters":
p = cls.load(path)
p.compute()
return p
@staticmethod @staticmethod
def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]: def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]:
"""prepares a dictionary for serialization. Some keys may not be preserved """prepares a dictionary for serialization. Some keys may not be preserved
@@ -753,7 +761,6 @@ class Configuration:
num_sim: int num_sim: int
repeat: int repeat: int
z_num: int z_num: int
total_length: float
total_num_steps: int total_num_steps: int
worker_num: int worker_num: int
parallel: bool parallel: bool
@@ -789,7 +796,6 @@ class Configuration:
if self.name is None: if self.name is None:
self.name = Parameters.name.default self.name = Parameters.name.default
self.z_num = 0 self.z_num = 0
self.total_length = 0.0
self.total_num_steps = 0 self.total_num_steps = 0
self.sim_dirs = [] self.sim_dirs = []
self.overwrite = overwrite self.overwrite = overwrite
@@ -800,7 +806,6 @@ class Configuration:
names = set() names = set()
for i, config in enumerate(self.configs): for i, config in enumerate(self.configs):
self.z_num += config["z_num"] self.z_num += config["z_num"]
self.total_length += config["length"]
config.setdefault("name", f"{Parameters.name.default} {i}") config.setdefault("name", f"{Parameters.name.default} {i}")
given_name = config["name"] given_name = config["name"]
i = 0 i = 0
@@ -858,8 +863,8 @@ class Configuration:
) )
self.data_dirs[i].append(this_path) self.data_dirs[i].append(this_path)
this_conf.pop("variable") this_conf.pop("variable")
this_conf.update({k: v for k, v in this_vary if k != "num"}) conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf
self.all_required[i].append((this_vary, this_conf)) self.all_required[i].append((this_vary, conf_to_use))
def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]: def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]:
for sim_paths, fiber in zip(self.data_dirs, self.all_required): for sim_paths, fiber in zip(self.data_dirs, self.all_required):
@@ -897,7 +902,9 @@ class Configuration:
task, config_dict = self.__decide(data_dir, config_dict) task, config_dict = self.__decide(data_dir, config_dict)
if task == self.Action.RUN: if task == self.Action.RUN:
sim_dict.pop(data_dir) sim_dict.pop(data_dir)
yield variable_list, data_dir, Parameters(**config_dict) p = Parameters(**config_dict)
p.compute()
yield variable_list, data_dir, p
if "recovery_last_stored" in config_dict and self.skip_callback is not None: if "recovery_last_stored" in config_dict and self.skip_callback is not None:
self.skip_callback(config_dict["recovery_last_stored"]) self.skip_callback(config_dict["recovery_last_stored"])
break break