From 514bd0a433188ec3a6de84cd8f33095c2b9549a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Sierro?= Date: Wed, 3 Nov 2021 17:00:27 +0100 Subject: [PATCH] unit fixes --- src/scgenerator/physics/units.py | 88 ++++++++++++++++++-------------- testing/test_full_field.py | 2 - 2 files changed, 51 insertions(+), 39 deletions(-) diff --git a/src/scgenerator/physics/units.py b/src/scgenerator/physics/units.py index 0d78a0b..23f3748 100644 --- a/src/scgenerator/physics/units.py +++ b/src/scgenerator/physics/units.py @@ -3,11 +3,10 @@ # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... from __future__ import annotations -from collections import defaultdict -from typing import Callable, TypeVar, Union -from functools import wraps from operator import itemgetter +from typing import Callable, TypeVar, Union + import numpy as np from numpy import pi @@ -30,37 +29,42 @@ types are "WL", "FREQ", "AFREQ", "TIME", "PRESSURE", "TEMPERATURE", "OTHER" """ _T = TypeVar("_T") _UT = Callable[[_T], _T] - - -def chain(c1: _UT, c2: _UT) -> _UT: - def chained_function(x: _T) -> _T: - return c1(c2(x)) - - return chained_function +PRIMARIES = dict(WL="Rad/s", FREQ="Rad/s", AFREQ="Rad/s", TIME="s", PRESSURE="Pa", TEMPERATURE="K") class UnitMap(dict): def __setitem__(self, new_name, new_func): super().__setitem__(new_name, new_func) - already_here = [(name, func) for name, func in self.items() if isinstance(name, str)] - for name, func in already_here: - super().__setitem__((name, new_name), chain(self[name].inv, new_func)) - super().__setitem__((new_name, name), chain(self[new_name].inv, func)) + already_here = [name for name in self if isinstance(name, str)] + for old_name in already_here: + super().__setitem__((old_name, new_name), self._chain(old_name, new_name)) + super().__setitem__((new_name, old_name), self._chain(new_name, old_name)) + + def _chain(self, name_1: str, name_2: str) -> _UT: + c1 = self[name_1] + c2 = self[name_2] + + def chained_function(x: _T) -> _T: + return c2.inv(c1(x)) + + chained_function.__name__ = f"{name_1}_to_{name_2}" + chained_function.__doc__ = f"converts x from {name_1} to {name_2}" + + return chained_function -units_map: dict[str, dict[Union[str, tuple[str, str]], Unit]] = defaultdict(UnitMap) +units_map: dict[Union[str, tuple[str, str]], Unit] = UnitMap() class To: - def __init__(self, name: str, tpe: str): + def __init__(self, name: str): self.name = name - self.type = tpe def __getattr__(self, key: str): try: - return units_map[self.type][self.name, key] + return units_map[self.name, key] except KeyError: - raise KeyError(f"no registered unit named {key!r} of type {self.type!r}") from None + raise KeyError(f"no registered unit named {key!r}") from None def W_to_Vm(n0: float, A_eff: float) -> float: @@ -80,35 +84,45 @@ def W_to_Vm(n0: float, A_eff: float) -> float: class Unit: - __func: _UT - to: To + func: _UT inv: _UT - name: str - label: str + to: To = None + name: str = "unit" + type: str = "other" + label: str = "" - def __init__(self, func: _UT, inv: _UT, name: str, label: str, tpe: str): - self.__func = func - self.inv = inv - self.to = To(name, tpe) - self.name = name - self.label = label - self.type = tpe - self.__name__ = name + def __init_subclass__(cls): + def call(self, x: _T) -> _T: + return self.func(x) + + call.__doc__ = f"Transform x from {cls.name!r} to {PRIMARIES.get(cls.type)!r}" + cls.__call__ = call + + def __init__(self, func: _UT, inv: _UT = None): + self.func = func + if inv is None: + self.inv = func + else: + self.inv = inv + self.to = To(self.name) + units_map[self.name] = self def __call__(self, x: _T) -> _T: """call the original unit function""" - return self.__func(x) + return self.func(x) + + def inverse(self, func: _UT): + self.inv = func + return self def unit(tpe: str, label: str, inv: Callable = None): def unit_maker(func) -> Unit: nonlocal inv name = func.__name__ - if inv is None: - inv = func - unit_obj = wraps(func)(Unit(func, inv, name, label, tpe)) - units_map[tpe][name] = unit_obj - return unit_obj + unit_type = type(f"Unit_{name}", (Unit,), dict(name=name, label=label, type=tpe)) + + return unit_type(func, inv) return unit_maker diff --git a/testing/test_full_field.py b/testing/test_full_field.py index 1a338a8..0218dcd 100644 --- a/testing/test_full_field.py +++ b/testing/test_full_field.py @@ -30,8 +30,6 @@ def main(): spec_ax.set_xlabel(rs.unit.label) field_ax = app[1] field_ax.set_xlabel(rt.unit.label) - x: float = 4.5 - y = sc.units.m.to.nm(x) @app.update def draw(i):