unit fixes

This commit is contained in:
Benoît Sierro
2021-11-03 17:00:27 +01:00
parent 50c9128fa5
commit 514bd0a433
2 changed files with 51 additions and 39 deletions

View File

@@ -3,11 +3,10 @@
# to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ... # to be used especially when giving plotting ranges : (400, 1400, nm), (-4, 8, ps), ...
from __future__ import annotations from __future__ import annotations
from collections import defaultdict
from typing import Callable, TypeVar, Union
from functools import wraps
from operator import itemgetter from operator import itemgetter
from typing import Callable, TypeVar, Union
import numpy as np import numpy as np
from numpy import pi from numpy import pi
@@ -30,37 +29,42 @@ types are "WL", "FREQ", "AFREQ", "TIME", "PRESSURE", "TEMPERATURE", "OTHER"
""" """
_T = TypeVar("_T") _T = TypeVar("_T")
_UT = Callable[[_T], _T] _UT = Callable[[_T], _T]
PRIMARIES = dict(WL="Rad/s", FREQ="Rad/s", AFREQ="Rad/s", TIME="s", PRESSURE="Pa", TEMPERATURE="K")
def chain(c1: _UT, c2: _UT) -> _UT:
def chained_function(x: _T) -> _T:
return c1(c2(x))
return chained_function
class UnitMap(dict): class UnitMap(dict):
def __setitem__(self, new_name, new_func): def __setitem__(self, new_name, new_func):
super().__setitem__(new_name, new_func) super().__setitem__(new_name, new_func)
already_here = [(name, func) for name, func in self.items() if isinstance(name, str)] already_here = [name for name in self if isinstance(name, str)]
for name, func in already_here: for old_name in already_here:
super().__setitem__((name, new_name), chain(self[name].inv, new_func)) super().__setitem__((old_name, new_name), self._chain(old_name, new_name))
super().__setitem__((new_name, name), chain(self[new_name].inv, func)) super().__setitem__((new_name, old_name), self._chain(new_name, old_name))
def _chain(self, name_1: str, name_2: str) -> _UT:
c1 = self[name_1]
c2 = self[name_2]
def chained_function(x: _T) -> _T:
return c2.inv(c1(x))
chained_function.__name__ = f"{name_1}_to_{name_2}"
chained_function.__doc__ = f"converts x from {name_1} to {name_2}"
return chained_function
units_map: dict[str, dict[Union[str, tuple[str, str]], Unit]] = defaultdict(UnitMap) units_map: dict[Union[str, tuple[str, str]], Unit] = UnitMap()
class To: class To:
def __init__(self, name: str, tpe: str): def __init__(self, name: str):
self.name = name self.name = name
self.type = tpe
def __getattr__(self, key: str): def __getattr__(self, key: str):
try: try:
return units_map[self.type][self.name, key] return units_map[self.name, key]
except KeyError: except KeyError:
raise KeyError(f"no registered unit named {key!r} of type {self.type!r}") from None raise KeyError(f"no registered unit named {key!r}") from None
def W_to_Vm(n0: float, A_eff: float) -> float: def W_to_Vm(n0: float, A_eff: float) -> float:
@@ -80,35 +84,45 @@ def W_to_Vm(n0: float, A_eff: float) -> float:
class Unit: class Unit:
__func: _UT func: _UT
to: To
inv: _UT inv: _UT
name: str to: To = None
label: str name: str = "unit"
type: str = "other"
label: str = ""
def __init__(self, func: _UT, inv: _UT, name: str, label: str, tpe: str): def __init_subclass__(cls):
self.__func = func def call(self, x: _T) -> _T:
return self.func(x)
call.__doc__ = f"Transform x from {cls.name!r} to {PRIMARIES.get(cls.type)!r}"
cls.__call__ = call
def __init__(self, func: _UT, inv: _UT = None):
self.func = func
if inv is None:
self.inv = func
else:
self.inv = inv self.inv = inv
self.to = To(name, tpe) self.to = To(self.name)
self.name = name units_map[self.name] = self
self.label = label
self.type = tpe
self.__name__ = name
def __call__(self, x: _T) -> _T: def __call__(self, x: _T) -> _T:
"""call the original unit function""" """call the original unit function"""
return self.__func(x) return self.func(x)
def inverse(self, func: _UT):
self.inv = func
return self
def unit(tpe: str, label: str, inv: Callable = None): def unit(tpe: str, label: str, inv: Callable = None):
def unit_maker(func) -> Unit: def unit_maker(func) -> Unit:
nonlocal inv nonlocal inv
name = func.__name__ name = func.__name__
if inv is None: unit_type = type(f"Unit_{name}", (Unit,), dict(name=name, label=label, type=tpe))
inv = func
unit_obj = wraps(func)(Unit(func, inv, name, label, tpe)) return unit_type(func, inv)
units_map[tpe][name] = unit_obj
return unit_obj
return unit_maker return unit_maker

View File

@@ -30,8 +30,6 @@ def main():
spec_ax.set_xlabel(rs.unit.label) spec_ax.set_xlabel(rs.unit.label)
field_ax = app[1] field_ax = app[1]
field_ax.set_xlabel(rt.unit.label) field_ax.set_xlabel(rt.unit.label)
x: float = 4.5
y = sc.units.m.to.nm(x)
@app.update @app.update
def draw(i): def draw(i):