diff --git a/src/scgenerator/_utils/parameter.py b/src/scgenerator/_utils/parameter.py index 90dda50..347cc15 100644 --- a/src/scgenerator/_utils/parameter.py +++ b/src/scgenerator/_utils/parameter.py @@ -54,6 +54,7 @@ VALID_VARIABLE = { "pitch_ratio", "effective_mode_diameter", "core_radius", + "model", "capillary_num", "capillary_radius", "capillary_thickness", @@ -99,6 +100,7 @@ MANDATORY_PARAMETERS = [ "alpha", "spec_0", "field_0", + "mean_power", "input_transmission", "z_targets", "length", @@ -395,7 +397,7 @@ class Parameters(_AbstractParameters): # pulse field_file: str = Parameter(string) repetition_rate: float = Parameter( - non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6 + non_negative(float, int), display_info=(1e-3, "kHz"), default=40e6 ) peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW")) mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) @@ -479,6 +481,12 @@ class Parameters(_AbstractParameters): setattr(self, k, v) return results + def pformat(self) -> str: + return "\n".join( + f"{k} = {VariationDescriptor.format_value(k, v)}" + for k, v in self.prepare_for_dump().items() + ) + @classmethod def all_parameters(cls) -> list[str]: return [f.name for f in fields(cls)] @@ -749,13 +757,13 @@ class Evaluator: else: default = self.get_default(target) if default is None: - error = NoDefaultError( + error = error or NoDefaultError( prefix + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}" ) else: value = default - self.logger.info(f"using default value of {value} for {target}") + self.logger.info(prefix + f"using default value of {value} for {target}") self.set_value(target, value, 0) if value is None and error is not None: @@ -1089,6 +1097,7 @@ default_rules: list[Rule] = [ Rule("peak_power", pulse.soliton_num_to_peak_power), Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), Rule("energy", pulse.mean_power_to_energy, priorities=2), + Rule("mean_power", pulse.energy_to_mean_power), Rule("t0", pulse.width_to_t0), Rule("t0", pulse.soliton_num_to_t0), Rule("width", pulse.t0_to_width), diff --git a/src/scgenerator/_utils/utils.py b/src/scgenerator/_utils/utils.py index 2931f09..38ce1ef 100644 --- a/src/scgenerator/_utils/utils.py +++ b/src/scgenerator/_utils/utils.py @@ -171,7 +171,10 @@ def combine_simulations(path: Path, dest: Path = None): if p.is_dir(): paths[p.name.split()[1]].append(p) for l in paths.values(): - l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) + try: + l.sort(key=lambda el: re.search(r"(?<=num )[0-9]+", el.name)[0]) + except TypeError: + pass for pulses in paths.values(): new_path = dest / update_path(pulses[0].name) os.makedirs(new_path, exist_ok=True) diff --git a/src/scgenerator/_utils/variationer.py b/src/scgenerator/_utils/variationer.py index 538ca4f..2e8ba9f 100644 --- a/src/scgenerator/_utils/variationer.py +++ b/src/scgenerator/_utils/variationer.py @@ -9,7 +9,6 @@ from pydantic import validator from pydantic.main import BaseModel from ..const import PARAM_SEPARATOR -from . import utils T = TypeVar("T") @@ -135,6 +134,23 @@ class VariationDescriptor(BaseModel): """ cls._format_registry[p_name] = func + @classmethod + def format_value(cls, name: str, value) -> str: + if value is True or value is False: + return str(value) + elif isinstance(value, (float, int)): + try: + return cls._format_registry[name](value) + except KeyError: + return format(value, ".9g") + elif isinstance(value, (list, tuple, np.ndarray)): + return "-".join([str(v) for v in value]) + elif isinstance(value, str): + p = Path(value) + if p.exists(): + return p.stem + return str(value) + class Config: allow_mutation = False @@ -166,22 +182,6 @@ class VariationDescriptor(BaseModel): self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name ) - def format_value(self, name: str, value) -> str: - if value is True or value is False: - return str(value) - elif isinstance(value, (float, int)): - try: - return self._format_registry[name](value) - except KeyError: - return format(value, ".9g") - elif isinstance(value, (list, tuple, np.ndarray)): - return "-".join([str(v) for v in value]) - elif isinstance(value, str): - p = Path(value) - if p.exists(): - return p.stem - return str(value) - def __getitem__(self, key) -> "VariationDescriptor": return VariationDescriptor( raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index f9e5102..cbd8d2e 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -289,6 +289,10 @@ def mean_power_to_energy(mean_power: float, repetition_rate: float) -> float: return mean_power / repetition_rate +def energy_to_mean_power(energy: float, repetition_rate: float) -> float: + return energy * repetition_rate + + def soliton_num_to_peak_power(soliton_num, beta2, gamma, t0): return soliton_num ** 2 * abs(beta2) / (gamma * t0 ** 2) diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index 1d5fc8f..a54407d 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -128,6 +128,7 @@ def plot_init( lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params, loss_ax) lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params) all_labels.append(lbl) + print(params.pformat()) finish_plot(fig, tr, all_labels, params)