diff --git a/src/scgenerator/parameter.py b/src/scgenerator/parameter.py index 8f0f6e9..d9c9ae5 100644 --- a/src/scgenerator/parameter.py +++ b/src/scgenerator/parameter.py @@ -42,6 +42,7 @@ VALID_VARIABLE = { "pitch_ratio", "effective_mode_diameter", "core_radius", + "model", "capillary_num", "capillary_radius", "capillary_thickness", @@ -87,6 +88,7 @@ MANDATORY_PARAMETERS = [ "alpha", "spec_0", "field_0", + "mean_power", "input_transmission", "z_targets", "length", @@ -256,13 +258,17 @@ def func_validator(name, n): class Parameter: - def __init__(self, validator, converter=None, default=None, display_info=None, rules=None): + def __init__( + self, + validator: Callable[[str, Any], None], + converter: Callable = None, + default=None, + display_info: tuple[float, str] = None, + ): """Single parameter Parameters ---------- - tpe : type - type of the paramter validator : Callable[[str, Any], None] signature : validator(name, value) must raise a ValueError when value doesn't fit the criteria checked by @@ -277,21 +283,17 @@ class Parameter: self.converter = converter self.default = default self.display_info = display_info - if rules is None: - self.rules = [] - else: - self.rules = rules def __set_name__(self, owner, name): self.name = name def __get__(self, instance, owner): - if not instance: + if instance is None: return self return instance.__dict__[self.name] def __delete__(self, instance): - del instance.__dict__[self.name] + raise AttributeError("Cannot delete parameter") def __set__(self, instance, value): if isinstance(value, Parameter): @@ -351,9 +353,9 @@ class Parameters(_AbstractParameters): A_eff: float = Parameter(non_negative(float, int)) A_eff_file: str = Parameter(string) numerical_aperture: float = Parameter(in_range_excl(0, 1)) - pitch: float = Parameter(in_range_excl(0, 1e-3)) + pitch: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) pitch_ratio: float = Parameter(in_range_excl(0, 1)) - core_radius: float = Parameter(in_range_excl(0, 1e-3)) + core_radius: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) he_mode: tuple[int, int] = Parameter(int_pair, default=(1, 1)) fit_parameters: tuple[int, int] = Parameter(float_pair, default=(0.08, 200e-9)) beta2_coefficients: Iterable[float] = Parameter(num_list) @@ -361,11 +363,11 @@ class Parameters(_AbstractParameters): model: str = Parameter( literal("pcf", "marcatili", "marcatili_adjusted", "hasan", "custom"), default="custom" ) - length: float = Parameter(non_negative(float, int)) + length: float = Parameter(non_negative(float, int), display_info=(1e2, "cm")) capillary_num: int = Parameter(positive(int)) - capillary_radius: float = Parameter(in_range_excl(0, 1e-3)) - capillary_thickness: float = Parameter(in_range_excl(0, 1e-3)) - capillary_spacing: float = Parameter(in_range_excl(0, 1e-3)) + capillary_radius: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) + capillary_thickness: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) + capillary_spacing: float = Parameter(in_range_excl(0, 1e-3), display_info=(1e6, "μm")) capillary_resonance_strengths: Iterable[float] = Parameter( validator_list(type_checker(int, float, np.ndarray)) ) @@ -383,7 +385,7 @@ class Parameters(_AbstractParameters): # pulse field_file: str = Parameter(string) repetition_rate: float = Parameter( - non_negative(float, int), display_info=(1e-6, "MHz"), default=40e6 + non_negative(float, int), display_info=(1e-3, "kHz"), default=40e6 ) peak_power: float = Parameter(positive(float, int), display_info=(1e-3, "kW")) mean_power: float = Parameter(positive(float, int), display_info=(1e3, "mW")) @@ -467,6 +469,12 @@ class Parameters(_AbstractParameters): setattr(self, k, v) return results + def pformat(self) -> str: + return "\n".join( + f"{k} = {VariationDescriptor.format_value(k, v)}" + for k, v in self.prepare_for_dump().items() + ) + @classmethod def all_parameters(cls) -> list[str]: return [f.name for f in fields(cls)] @@ -556,6 +564,9 @@ class Rule: def __repr__(self) -> str: return f"Rule(targets={self.targets!r}, func={self.func!r}, args={self.args!r})" + def __str__(self) -> str: + return f"[{', '.join(self.args)}] -- {self.func.__module__}.{self.func.__name__} --> [{', '.join(self.targets)}]" + @classmethod def deduce( cls, @@ -629,7 +640,8 @@ class Evaluator: def __init__(self): self.rules: dict[str, list[Rule]] = defaultdict(list) self.params = {} - self.__curent_lookup = set() + self.__curent_lookup: list[str] = [] + self.__failed_rules: dict[str, list[Rule]] = defaultdict(list) self.eval_stats: dict[str, EvalStat] = defaultdict(EvalStat) self.logger = get_logger(__name__) @@ -683,10 +695,11 @@ class Evaluator: raise EvaluatorError( "cyclic dependency detected : " f"{target!r} seems to depend on itself, " - f"please provide a value for at least one variable in {self.__curent_lookup}" + f"please provide a value for at least one variable in {self.__curent_lookup!r}. " + + self.attempted_rules_str(target) ) else: - self.__curent_lookup.add(target) + self.__curent_lookup.append(target) if len(self.rules[target]) == 0: error = EvaluatorError(f"no rule for {target}") @@ -733,23 +746,27 @@ class Evaluator: self.logger.debug( prefix + f"error using {rule.func.__name__} : {str(error).strip()}" ) + self.__failed_rules[target].append(rule) continue else: default = self.get_default(target) if default is None: - error = NoDefaultError( + error = error or NoDefaultError( prefix - + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}" + + f"No default provided for {target}. Current lookup cycle : {self.__curent_lookup!r}. " + + self.attempted_rules_str(target) ) else: value = default - self.logger.info(f"using default value of {value} for {target}") + self.logger.info(prefix + f"using default value of {value} for {target}") self.set_value(target, value, 0) + assert target == self.__curent_lookup.pop() + self.__failed_rules[target] = [] + if value is None and error is not None: raise error - self.__curent_lookup.remove(target) return value def __getitem__(self, key: str) -> Any: @@ -762,23 +779,11 @@ class Evaluator: def validate_condition(self, rule: Rule) -> bool: return all(self.compute(k) == v for k, v in rule.conditions.items()) - def __call__(self, target: str, args: list[str] = None): - """creates a wrapper that adds decorated functions to the set of rules - - Parameters - ---------- - target : str - name of the target - args : list[str], optional - list of name of arguments. Automatically deduced from function signature if - not provided, by default None - """ - - def wrapper(func): - self.append(Rule(target, func, args)) - return func - - return wrapper + def attempted_rules_str(self, target: str) -> str: + rules = ", ".join(str(r) for r in self.__failed_rules[target]) + if len(rules) == 0: + return "" + return "attempted rules : " + rules class Configuration: @@ -1075,6 +1080,7 @@ default_rules: list[Rule] = [ ), Rule("peak_power", pulse.E0_to_P0, ["energy", "t0", "shape"]), Rule("peak_power", pulse.soliton_num_to_peak_power), + Rule("mean_power", pulse.energy_to_mean_power), Rule("energy", pulse.P0_to_E0, ["peak_power", "t0", "shape"]), Rule("energy", pulse.mean_power_to_energy, priorities=2), Rule("t0", pulse.width_to_t0), diff --git a/src/scgenerator/physics/pulse.py b/src/scgenerator/physics/pulse.py index f9e5102..53cd159 100644 --- a/src/scgenerator/physics/pulse.py +++ b/src/scgenerator/physics/pulse.py @@ -289,6 +289,10 @@ def mean_power_to_energy(mean_power: float, repetition_rate: float) -> float: return mean_power / repetition_rate +def energy_to_mean_power(energy: float, repetition_rate: float) -> float: + return energy * repetition_rate + + def soliton_num_to_peak_power(soliton_num, beta2, gamma, t0): return soliton_num ** 2 * abs(beta2) / (gamma * t0 ** 2) @@ -367,7 +371,11 @@ def P0_to_E0(P0, t0, shape): def sech_pulse(t, t0, P0, offset=0): - return np.sqrt(P0) / np.cosh((t - offset) / t0) + arg = (t - offset) / t0 + ind = (arg < 700) & (arg > -700) + out = np.zeros_like(t) + out[ind] = np.sqrt(P0) / np.cosh(arg[ind]) + return out def gaussian_pulse(t, t0, P0, offset=0): diff --git a/src/scgenerator/scripts/__init__.py b/src/scgenerator/scripts/__init__.py index a50f34e..f223f10 100644 --- a/src/scgenerator/scripts/__init__.py +++ b/src/scgenerator/scripts/__init__.py @@ -26,8 +26,10 @@ def plot_all(sim_dir: Path, limits: list[str], show=False, **opts): for k, v in opts.items(): if k in ["skip"]: opts[k] = int(v) - if k in {"log", "renormalize"}: - opts[k] = True if v == "True" else False + if v == "True": + opts[k] = True + elif v == "False": + opts[k] = False dir_list = simulations_list(sim_dir) if len(dir_list) == 0: dir_list = [sim_dir] @@ -124,6 +126,7 @@ def plot_init( lbl = plot_1_dispersion(lim_disp, tl, tr, style, lbl, params, loss_ax) lbl = plot_1_init_spec_field(lim_field, lim_spec, bl, br, style, lbl, params) all_labels.append(lbl) + print(params.pformat()) finish_plot(fig, tr, all_labels, params) diff --git a/src/scgenerator/variationer.py b/src/scgenerator/variationer.py index eb96c49..b7cedd4 100644 --- a/src/scgenerator/variationer.py +++ b/src/scgenerator/variationer.py @@ -134,6 +134,23 @@ class VariationDescriptor(BaseModel): """ cls._format_registry[p_name] = func + @classmethod + def format_value(cls, name: str, value) -> str: + if value is True or value is False: + return str(value) + elif isinstance(value, (float, int)): + try: + return cls._format_registry[name](value) + except KeyError: + return format(value, ".9g") + elif isinstance(value, (list, tuple, np.ndarray)): + return "-".join([str(v) for v in value]) + elif isinstance(value, str): + p = Path(value) + if p.exists(): + return p.stem + return str(value) + class Config: allow_mutation = False @@ -165,22 +182,6 @@ class VariationDescriptor(BaseModel): self.identifier + PARAM_SEPARATOR + self.branch.identifier + PARAM_SEPARATOR + tmp_name ) - def format_value(self, name: str, value) -> str: - if value is True or value is False: - return str(value) - elif isinstance(value, (float, int)): - try: - return self._format_registry[name](value) - except KeyError: - return format(value, ".9g") - elif isinstance(value, (list, tuple, np.ndarray)): - return "-".join([str(v) for v in value]) - elif isinstance(value, str): - p = Path(value) - if p.exists(): - return p.stem - return str(value) - def __getitem__(self, key) -> "VariationDescriptor": return VariationDescriptor( raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator