This commit is contained in:
Benoît Sierro
2021-06-17 11:35:27 +02:00
parent b315f501b6
commit b63a77cdd6
4 changed files with 88 additions and 71 deletions

View File

@@ -83,7 +83,7 @@ def main():
def run_sim(args): def run_sim(args):
method = prep_ray(args) method = prep_ray()
run_simulation_sequence(*args.configs, method=method) run_simulation_sequence(*args.configs, method=method)
@@ -95,7 +95,7 @@ def merge(args):
io.merge(args.output_name, path_trees) io.merge(args.output_name, path_trees)
def prep_ray(args): def prep_ray():
logger = get_logger(__name__) logger = get_logger(__name__)
if ray: if ray:
if env.get(const.START_RAY): if env.get(const.START_RAY):
@@ -114,7 +114,7 @@ def prep_ray(args):
def resume_sim(args): def resume_sim(args):
method = prep_ray(args) method = prep_ray()
sim = resume_simulations(Path(args.sim_dir), method=method) sim = resume_simulations(Path(args.sim_dir), method=method)
sim.run() sim.run()
run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir) run_simulation_sequence(*args.configs, method=method, prev_sim_dir=sim.sim_dir)

View File

@@ -19,9 +19,14 @@ def data_folder(task_id: int) -> Optional[str]:
def get(key: str) -> Any: def get(key: str) -> Any:
str_value = os.environ.get(key) str_value = os.environ.get(key)
if isinstance(str_value, str):
try: try:
return global_config[key]["type"](str_value) t = global_config[key]["type"]
if t == bool:
return str_value.lower() == "true"
return t(str_value)
except (ValueError, KeyError): except (ValueError, KeyError):
pass
return None return None

View File

@@ -7,15 +7,14 @@ from collections import defaultdict
import numpy as np import numpy as np
from numpy import pi from numpy import pi
from numpy.core.fromnumeric import var
from . import io, utils from . import io, utils
from .defaults import default_parameters from .defaults import default_parameters
from .errors import * from .errors import *
from .logger import get_logger from .logger import get_logger
from .math import abs2, power_fact from .math import power_fact
from .physics import fiber, pulse, units from .physics import fiber, pulse, units
from .utils import count_variations, override_config, required_simulations, variable_iterator from .utils import count_variations, override_config, required_simulations
from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters from .utils.parameter import BareConfig, BareParams, hc_model_specific_parameters
@@ -31,34 +30,60 @@ class Params(BareParams):
def compute(self): def compute(self):
logger = get_logger(__name__) logger = get_logger(__name__)
self.__build_sim_grid()
did_set_custom_pulse = self.__compute_custom_pulse()
self.__compute_fiber()
if not did_set_custom_pulse:
logger.info(f"using generic input pulse of {self.shape.title()} shape")
self.__compute_generic_pulse()
if self.quantum_noise and self.prev_sim_dir is None:
self.field_0 = self.field_0 + pulse.shot_noise(
self.w_c, self.w0, self.time_window, self.dt
)
logger.info("added some quantum noise")
self.spec_0 = np.fft.fft(self.field_0)
def __build_sim_grid(self):
build_sim_grid_in_place(self) build_sim_grid_in_place(self)
# Initial field may influence the grid def __compute_generic_pulse(self):
if self.mean_power is not None:
self.energy = self.mean_power / self.repetition_rate
( (
custom_field,
self.width, self.width,
self.t0,
self.peak_power, self.peak_power,
self.energy, self.energy,
self.field_0, self.soliton_num,
) = pulse.setup_custom_field(self) ) = pulse.conform_pulse_params(
if self.readjust_wavelength: self.shape,
old_wl = self.wavelength self.width,
self.wavelength = pulse.correct_wavelength(self.wavelength, self.w_c, self.field_0) self.t0,
logger.info(f"moved wavelength from {1e9*old_wl:.2f} to {1e9*self.wavelength:.2f}") self.peak_power,
self.w_c, self.w0, self.w, self.w_power_fact = update_frequency_domain( self.energy,
self.t, self.wavelength self.soliton_num,
self.gamma,
self.beta[0],
) )
logger = get_logger(__name__)
logger.info(f"computed initial N = {self.soliton_num:.3g}")
if self.step_size is not None: self.L_D = self.t0 ** 2 / abs(self.beta[0])
self.error_ok = self.step_size self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf
self.adapt_step_size = False self.L_sol = pi / 2 * self.L_D
else:
self.error_ok = self.tolerated_error # Technical noise
self.adapt_step_size = True if self.intensity_noise is not None and self.intensity_noise > 0:
delta_int, delta_T0 = pulse.technical_noise(self.intensity_noise)
self.peak_power *= delta_int
self.t0 *= delta_T0
self.width *= delta_T0
self.field_0 = pulse.initial_field(self.t, self.shape, self.t0, self.peak_power)
def __compute_fiber(self):
logger = get_logger(__name__)
# FIBER
self.interp_range = ( self.interp_range = (
max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))), max(self.lower_wavelength_interp_limit, units.m.inv(np.max(self.w[self.w > 0]))),
min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))), min(self.upper_wavelength_interp_limit, units.m.inv(np.min(self.w[self.w > 0]))),
@@ -87,46 +112,33 @@ class Params(BareParams):
if "raman" in self.behaviors: if "raman" in self.behaviors:
self.hr_w = fiber.delayed_raman_w(self.t, self.dt, self.raman_type) self.hr_w = fiber.delayed_raman_w(self.t, self.dt, self.raman_type)
# GENERIC PULSE def __compute_custom_pulse(self):
if not custom_field: logger = get_logger(__name__)
custom_field = False
if self.mean_power is not None:
self.energy = self.mean_power / self.repetition_rate
( (
did_set_custom_pulse,
self.width, self.width,
self.t0,
self.peak_power, self.peak_power,
self.energy, self.energy,
self.soliton_num, self.field_0,
) = pulse.conform_pulse_params( ) = pulse.setup_custom_field(self)
self.shape, if self.readjust_wavelength:
self.width, old_wl = self.wavelength
self.t0, self.wavelength = pulse.correct_wavelength(self.wavelength, self.w_c, self.field_0)
self.peak_power, logger.info(f"moved wavelength from {1e9*old_wl:.2f} to {1e9*self.wavelength:.2f}")
self.energy, self.w_c, self.w0, self.w, self.w_power_fact = update_frequency_domain(
self.soliton_num, self.t, self.wavelength
self.gamma,
self.beta[0],
)
logger.info(f"computed initial N = {self.soliton_num:.3g}")
self.L_D = self.t0 ** 2 / abs(self.beta[0])
self.L_NL = 1 / (self.gamma * self.peak_power) if self.gamma else np.inf
self.L_sol = pi / 2 * self.L_D
# Technical noise
if self.intensity_noise is not None and self.intensity_noise > 0:
delta_int, delta_T0 = pulse.technical_noise(self.intensity_noise)
self["peak_power"] *= delta_int
self["t0"] *= delta_T0
self["width"] *= delta_T0
self.field_0 = pulse.initial_field(self.t, self.shape, self.t0, self.peak_power)
if self.quantum_noise:
self.field_0 = self.field_0 + pulse.shot_noise(
self.w_c, self.w0, self.time_window, self.dt
) )
self.spec_0 = np.fft.fft(self.field_0) if self.step_size is not None:
self.error_ok = self.step_size
self.adapt_step_size = False
else:
self.error_ok = self.tolerated_error
self.adapt_step_size = True
return did_set_custom_pulse
@dataclass @dataclass

View File

@@ -49,16 +49,16 @@ def configure_logger(logger: logging.Logger):
""" """
if not hasattr(logger, "already_configured"): if not hasattr(logger, "already_configured"):
print_lvl = lvl_map.get(log_print_level()) print_lvl = lvl_map.get(log_print_level(), logging.NOTSET)
file_lvl = lvl_map.get(log_file_level()) file_lvl = lvl_map.get(log_file_level(), logging.NOTSET)
if file_lvl is not None: if file_lvl > logging.NOTSET:
formatter = logging.Formatter("{levelname}: {name}: {message}", style="{") formatter = logging.Formatter("{levelname}: {name}: {message}", style="{")
file_handler1 = logging.FileHandler("scgenerator.log", "a+") file_handler1 = logging.FileHandler("scgenerator.log", "a+")
file_handler1.setFormatter(formatter) file_handler1.setFormatter(formatter)
file_handler1.setLevel(file_lvl) file_handler1.setLevel(file_lvl)
logger.addHandler(file_handler1) logger.addHandler(file_handler1)
if print_lvl is not None: if print_lvl > logging.NOTSET:
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setLevel(print_lvl) stream_handler.setLevel(print_lvl)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)