From b63a77cdd66d698528d00aa14957713b4ce9468c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Thu, 17 Jun 2021 11:35:27 +0200 Subject: [PATCH] misc --- src/scgenerator/cli/cli.py | 6 +- src/scgenerator/env.py | 13 ++-- src/scgenerator/initialize.py | 132 ++++++++++++++++++---------------- src/scgenerator/logger.py | 8 +-- 4 files changed, 88 insertions(+), 71 deletions(-) diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py index f4636c6..087eede 100644 --- a/src/scgenerator/cli/cli.py +++ b/src/scgenerator/cli/cli.py @@ -83,7 +83,7 @@ def main(): def run_sim(args): - method = prep_ray(args) + method = prep_ray() run_simulation_sequence(*args.configs, method=method) @@ -95,7 +95,7 @@ def merge(args): io.merge(args.output_name, path_trees) -def prep_ray(args): +def prep_ray(): logger = get_logger(__name__) if ray: if env.get(const.START_RAY): @@ -114,7 +114,7 @@ def prep_ray(args): def resume_sim(args): - method = prep_ray(args) + method = prep_ray() sim = resume_simulations(Path(args.sim_dir), method=method) sim.run() run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir) diff --git a/src/scgenerator/env.py b/src/scgenerator/env.py index 0f6ff82..6f44b71 100644 --- a/src/scgenerator/env.py +++ b/src/scgenerator/env.py @@ -19,10 +19,15 @@ def data_folder(task_id: int) -> Optional[str]: def get(key: str) -> Any: str_value = os.environ.get(key) - try: - return global_config[key]["type"](str_value) - except (ValueError, KeyError): - return None + if isinstance(str_value, str): + try: + t = global_config[key]["type"] + if t == bool: + return str_value.lower() == "true" + return t(str_value) + except (ValueError, KeyError): + pass + return None def all_environ() -> Dict[str, str]: diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index a3d726a..006e8dc 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -7,15 +7,14 @@ from collections import defaultdict import numpy as np from numpy import pi -from numpy.core.fromnumeric import var from . import io, utils from .defaults import default_parameters from .errors import * from .logger import get_logger -from .math import abs2, power_fact +from .math import power_fact from .physics import fiber, pulse, units -from .utils import count_variations, override_config, required_simulations, variable_iterator +from .utils import count_variations, override_config, required_simulations from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters @@ -31,34 +30,60 @@ class Params(BareParams): def compute(self): logger = get_logger(__name__) + self.__build_sim_grid() + did_set_custom_pulse = self.__compute_custom_pulse() + self.__compute_fiber() + if not did_set_custom_pulse: + logger.info(f"using generic input pulse of {self.shape.title()} shape") + self.__compute_generic_pulse() + + if self.quantum_noise and self.prev_sim_dir is None: + self.field_0 = self.field_0 + pulse.shot_noise( + self.w_c, self.w0, self.time_window, self.dt + ) + logger.info("added some quantum noise") + + self.spec_0 = np.fft.fft(self.field_0) + + def __build_sim_grid(self): build_sim_grid_in_place(self) - # Initial field may influence the grid - if self.mean_power is not None: - self.energy = self.mean_power / self.repetition_rate + def __compute_generic_pulse(self): ( - custom_field, self.width, + self.t0, self.peak_power, self.energy, - self.field_0, - ) = pulse.setup_custom_field(self) - if self.readjust_wavelength: - old_wl = self.wavelength - self.wavelength = pulse.correct_wavelength(self.wavelength, self.w_c, self.field_0) - logger.info(f"moved wavelength from {1e9*old_wl:.2f} to {1e9*self.wavelength:.2f}") - self.w_c, self.w0, self.w, self.w_power_fact = update_frequency_domain( - self.t, self.wavelength - ) + self.soliton_num, + ) = pulse.conform_pulse_params( + self.shape, + self.width, + self.t0, + self.peak_power, + self.energy, + self.soliton_num, + self.gamma, + self.beta[0], + ) + logger = get_logger(__name__) + logger.info(f"computed initial N = {self.soliton_num:.3g}") - if self.step_size is not None: - self.error_ok = self.step_size - self.adapt_step_size = False - else: - self.error_ok = self.tolerated_error - self.adapt_step_size = True + self.L_D = self.t0 ** 2 / abs(self.beta[0]) + self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf + self.L_sol = pi / 2 * self.L_D + + # Technical noise + if self.intensity_noise is not None and self.intensity_noise > 0: + delta_int, delta_T0 = pulse.technical_noise(self.intensity_noise) + self.peak_power *= delta_int + self.t0 *= delta_T0 + self.width *= delta_T0 + + self.field_0 = pulse.initial_field(self.t, self.shape, self.t0, self.peak_power) + + def __compute_fiber(self): + logger = get_logger(__name__) - # FIBER self.interp_range = ( max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), @@ -87,46 +112,33 @@ class Params(BareParams): if "raman" in self.behaviors: self.hr_w = fiber.delayed_raman_w(self.t, self.dt, self.raman_type) - # GENERIC PULSE - if not custom_field: - custom_field = False - ( - self.width, - self.t0, - self.peak_power, - self.energy, - self.soliton_num, - ) = pulse.conform_pulse_params( - self.shape, - self.width, - self.t0, - self.peak_power, - self.energy, - self.soliton_num, - self.gamma, - self.beta[0], - ) - logger.info(f"computed initial N = {self.soliton_num:.3g}") + def __compute_custom_pulse(self): + logger = get_logger(__name__) - self.L_D = self.t0 ** 2 / abs(self.beta[0]) - self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf - self.L_sol = pi / 2 * self.L_D - - # Technical noise - if self.intensity_noise is not None and self.intensity_noise > 0: - delta_int, delta_T0 = pulse.technical_noise(self.intensity_noise) - self["peak_power"] *= delta_int - self["t0"] *= delta_T0 - self["width"] *= delta_T0 - - self.field_0 = pulse.initial_field(self.t, self.shape, self.t0, self.peak_power) - - if self.quantum_noise: - self.field_0 = self.field_0 + pulse.shot_noise( - self.w_c, self.w0, self.time_window, self.dt + if self.mean_power is not None: + self.energy = self.mean_power / self.repetition_rate + ( + did_set_custom_pulse, + self.width, + self.peak_power, + self.energy, + self.field_0, + ) = pulse.setup_custom_field(self) + if self.readjust_wavelength: + old_wl = self.wavelength + self.wavelength = pulse.correct_wavelength(self.wavelength, self.w_c, self.field_0) + logger.info(f"moved wavelength from {1e9*old_wl:.2f} to {1e9*self.wavelength:.2f}") + self.w_c, self.w0, self.w, self.w_power_fact = update_frequency_domain( + self.t, self.wavelength ) - self.spec_0 = np.fft.fft(self.field_0) + if self.step_size is not None: + self.error_ok = self.step_size + self.adapt_step_size = False + else: + self.error_ok = self.tolerated_error + self.adapt_step_size = True + return did_set_custom_pulse @dataclass diff --git a/src/scgenerator/logger.py b/src/scgenerator/logger.py index 04de7fb..c669740 100644 --- a/src/scgenerator/logger.py +++ b/src/scgenerator/logger.py @@ -49,16 +49,16 @@ def configure_logger(logger: logging.Logger): """ if not hasattr(logger, "already_configured"): - print_lvl = lvl_map.get(log_print_level()) - file_lvl = lvl_map.get(log_file_level()) + print_lvl = lvl_map.get(log_print_level(), logging.NOTSET) + file_lvl = lvl_map.get(log_file_level(), logging.NOTSET) - if file_lvl is not None: + if file_lvl > logging.NOTSET: formatter = logging.Formatter("{levelname}: {name}: {message}", style="{") file_handler1 = logging.FileHandler("scgenerator.log", "a+") file_handler1.setFormatter(formatter) file_handler1.setLevel(file_lvl) logger.addHandler(file_handler1) - if print_lvl is not None: + if print_lvl > logging.NOTSET: stream_handler = logging.StreamHandler() stream_handler.setLevel(print_lvl) logger.addHandler(stream_handler)