From fb23786c70f63ad55dae172e0fba801dc09bd0c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 2 Sep 2021 15:53:30 +0200 Subject: [PATCH] parameters are no longer auto-computed --- README.md | 7 +++++-- src/scgenerator/cli/cli.py | 12 ++++++++++++ src/scgenerator/const.py | 2 +- src/scgenerator/physics/simulate.py | 12 ++++-------- src/scgenerator/scripts/__init__.py | 20 ++++++++++++++++++-- src/scgenerator/spectra.py | 1 + src/scgenerator/utils/__init__.py | 25 ++++++++++++++++++------- src/scgenerator/utils/parameter.py | 23 +++++++++++++++-------- 8 files changed, 74 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 3338609..d5890e5 100644 --- a/README.md +++ b/README.md @@ -230,12 +230,15 @@ soliton_num: float field_file : str if you have an initial field to use, convert it to a npz file with time (key : 'time') in s and electric field (key : 'field') in sqrt(W) (can be complex). You the use it with this config key. You can then scale it by settings any 1 of mean_power, energy and peak_power (priority is in this order) -quantum_noise: bool +quantum_noise : bool whether or not one-photon-per-mode quantum noise is activated. default : False -intensity_noise: float +intensity_noise : float relative intensity noise +noise_correlation : float + correlation between intensity noise and pulse width noise. a negative value means anti-correlation + shape: str {"gaussian", "sech"} shape of the pulse. default : gaussian diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index 42807b4..23c45dd 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -118,6 +118,13 @@ def create_parser(): ) init_plot_parser.set_defaults(func=plot_init) + convert_parser = subparsers.add_parser( + "convert", + help="convert parameter files that have been saved with an older version of the program", + ) + convert_parser.add_argument("config", help="path to config/parameter file") + convert_parser.set_defaults(func=translate_parameters) + return parser @@ -224,5 +231,10 @@ def plot_dispersion(args): scripts.plot_dispersion(args.config, lims) +def translate_parameters(args): + path = args.config + scripts.convert_params(path) + + if __name__ == "__main__": main() diff --git a/src/scgenerator/const.py b/src/scgenerator/const.py index fff0e62..974c44c 100644 --- a/src/scgenerator/const.py +++ b/src/scgenerator/const.py @@ -1,4 +1,4 @@ -__version__ = "0.2.1rules" +__version__ = "0.2.2rules" from typing import Any diff --git a/src/scgenerator/physics/simulate.py b/src/scgenerator/physics/simulate.py index 0f28412..cd62985 100644 --- a/src/scgenerator/physics/simulate.py +++ b/src/scgenerator/physics/simulate.py @@ -1,4 +1,3 @@ -from send2trash import send2trash import multiprocessing import multiprocessing.connection import os @@ -8,6 +7,7 @@ from pathlib import Path from typing import Any, Generator, Type import numpy as np +from send2trash import send2trash from .. import env, utils from ..logger import get_logger @@ -638,19 +638,15 @@ class RaySimulations(Simulations, priority=2): ) ) - self.propagator = ray.remote(RayRK4IP).options(runtime_env=dict(env_vars=env.all_environ())) + self.propagator = ray.remote(RayRK4IP) self.update_cluster_frequency = 3 self.jobs = [] self.pool = ray.util.ActorPool(self.propagator.remote() for _ in range(self.sim_jobs_total)) self.num_submitted = 0 self.rolling_id = 0 - self.p_actor = ( - ray.remote(utils.ProgressBarActor) - .options(runtime_env=dict(env_vars=env.all_environ())) - .remote( - self.configuration.name, self.sim_jobs_total, self.configuration.total_num_steps - ) + self.p_actor = ray.remote(utils.ProgressBarActor).remote( + self.configuration.name, self.sim_jobs_total, self.configuration.total_num_steps ) self.configuration.skip_callback = lambda num: ray.get(self.p_actor.update.remote(0, num)) diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 10c5789..f249345 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -1,4 +1,5 @@ import itertools +import os from itertools import cycle from pathlib import Path from typing import Any, Iterable, Optional @@ -9,11 +10,11 @@ from cycler import cycler from tqdm import tqdm from .. import env, math -from ..const import PARAM_SEPARATOR +from ..const import PARAM_FN, PARAM_SEPARATOR from ..physics import fiber, units from ..plotting import plot_setup from ..spectra import Pulse -from ..utils import auto_crop, load_toml +from ..utils import auto_crop, load_toml, save_toml, translate_parameters from ..utils.parameter import ( Configuration, Parameters, @@ -262,3 +263,18 @@ def plot_helper(config_path: Path) -> Iterable[tuple[dict, list[str], Parameters for style, (variables, params) in zip(cc, pseq): lbl = [pretty_format_value(name, value) for name, value in variables[1:-1]] yield style, lbl, params + + +def convert_params(params_file: os.PathLike): + p = Path(params_file) + if p.name == PARAM_FN: + d = load_toml(params_file) + d = translate_parameters(d) + save_toml(params_file, d) + print(f"converted {p}") + else: + for pp in p.glob(PARAM_FN): + convert_params(pp) + for pp in p.glob("fiber*"): + if pp.is_dir(): + convert_params(pp) diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 7bd0225..6e8b089 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -150,6 +150,7 @@ class Pulse(Sequence): raise FileNotFoundError(f"Folder {self.path} does not exist") self.params = Parameters.load(self.path / "params.toml") + self.params.compute(["t", "l", "w_c", "w0", "z_targets"]) try: self.z = np.load(os.path.join(path, "z.npy")) diff --git a/src/scgenerator/utils/__init__.py b/src/scgenerator/utils/__init__.py index ed7a8ac..8447303 100644 --- a/src/scgenerator/utils/__init__.py +++ b/src/scgenerator/utils/__init__.py @@ -14,7 +14,6 @@ import re import shutil import threading from collections import abc -from copy import deepcopy from io import StringIO from pathlib import Path from string import printable as str_printable @@ -26,8 +25,7 @@ import toml from tqdm import tqdm from ..const import PARAM_FN, PARAM_SEPARATOR, SPEC1_FN, SPECN_FN, Z_FN, __version__ -from ..env import TMP_FOLDER_KEY_BASE, data_folder, pbar_policy -from ..errors import IncompleteDataFolderError +from ..env import pbar_policy from ..logger import get_logger T_ = TypeVar("T_") @@ -126,8 +124,11 @@ def load_config_sequence(final_config_path: os.PathLike) -> tuple[list[dict[str, fiber_list = loaded_config.pop("Fiber") configs = [] if fiber_list is not None: + master_variable = loaded_config.get("variable", {}) for i, params in enumerate(fiber_list): - params.setdefault("variable", loaded_config.get("variable", {}) if i == 0 else {}) + params.setdefault("variable", master_variable if i == 0 else {}) + if i == 0: + params["variable"] |= master_variable configs.append(loaded_config | params) else: configs.append(loaded_config) @@ -618,11 +619,21 @@ def auto_crop(x: np.ndarray, y: np.ndarray, rel_thr: float = 0.01) -> np.ndarray def translate_parameters(d: dict[str, Any]) -> dict[str, Any]: - old_names = dict(interp_degree="interpolation_degree") + old_names = dict( + interp_degree="interpolation_degree", + beta="beta2_coefficients", + interp_range="interpolation_range", + ) + deleted_names = {"lower_wavelength_interp_limit", "upper_wavelength_interp_limit"} + defaults_to_add = dict(repeat=1) new = {} for k, v in d.items(): - if isinstance(v, MutableMapping): + if k == "error_ok": + new["tolerated_error" if d.get("adapt_step_size", True) else "step_size"] = v + elif k in deleted_names: + continue + elif isinstance(v, MutableMapping): new[k] = translate_parameters(v) else: new[old_names.get(k, k)] = v - return new + return defaults_to_add | new diff --git a/src/scgenerator/utils/parameter.py b/src/scgenerator/utils/parameter.py index 4be75ca..05275ba 100644 --- a/src/scgenerator/utils/parameter.py +++ b/src/scgenerator/utils/parameter.py @@ -67,6 +67,7 @@ VALID_VARIABLE = { "step_size", "interpolation_degree", "ideal_gas", + "length", } MANDATORY_PARAMETERS = [ @@ -91,6 +92,7 @@ MANDATORY_PARAMETERS = [ "dynamic_dispersion", "recovery_last_stored", "output_path", + "repeat", ] @@ -428,11 +430,11 @@ class Parameters: param["version"] = __version__ return param - def __post_init__(self): + def compute(self, to_compute: list[str] = MANDATORY_PARAMETERS): param_dict = {k: v for k, v in asdict(self).items() if v is not None} evaluator = Evaluator.default() evaluator.set(**param_dict) - for p_name in MANDATORY_PARAMETERS: + for p_name in to_compute: evaluator.compute(p_name) valid_fields = self.all_parameters() for k, v in evaluator.params.items(): @@ -447,6 +449,12 @@ class Parameters: def load(cls, path: os.PathLike) -> "Parameters": return cls(**utils.load_toml(path)) + @classmethod + def load_and_compute(cls, path: os.PathLike) -> "Parameters": + p = cls.load(path) + p.compute() + return p + @staticmethod def strip_params_dict(dico: dict[str, Any]) -> dict[str, Any]: """prepares a dictionary for serialization. Some keys may not be preserved @@ -753,7 +761,6 @@ class Configuration: num_sim: int repeat: int z_num: int - total_length: float total_num_steps: int worker_num: int parallel: bool @@ -789,7 +796,6 @@ class Configuration: if self.name is None: self.name = Parameters.name.default self.z_num = 0 - self.total_length = 0.0 self.total_num_steps = 0 self.sim_dirs = [] self.overwrite = overwrite @@ -800,7 +806,6 @@ class Configuration: names = set() for i, config in enumerate(self.configs): self.z_num += config["z_num"] - self.total_length += config["length"] config.setdefault("name", f"{Parameters.name.default} {i}") given_name = config["name"] i = 0 @@ -858,8 +863,8 @@ class Configuration: ) self.data_dirs[i].append(this_path) this_conf.pop("variable") - this_conf.update({k: v for k, v in this_vary if k != "num"}) - self.all_required[i].append((this_vary, this_conf)) + conf_to_use = {k: v for k, v in this_vary if k != "num"} | this_conf + self.all_required[i].append((this_vary, conf_to_use)) def __iter__(self) -> Generator[tuple[list[tuple[str, Any]], Parameters], None, None]: for sim_paths, fiber in zip(self.data_dirs, self.all_required): @@ -897,7 +902,9 @@ class Configuration: task, config_dict = self.__decide(data_dir, config_dict) if task == self.Action.RUN: sim_dict.pop(data_dir) - yield variable_list, data_dir, Parameters(**config_dict) + p = Parameters(**config_dict) + p.compute() + yield variable_list, data_dir, p if "recovery_last_stored" in config_dict and self.skip_callback is not None: self.skip_callback(config_dict["recovery_last_stored"]) break