added pretty print of Rule

This commit is contained in:
Benoît Sierro
2023-08-03 09:43:07 +02:00
parent 0ab946dc57
commit ef30f618cc
3 changed files with 98 additions and 129 deletions

View File

@@ -1,123 +0,0 @@
def print_graph(left_els, middle, right_els):
left_off = 1
right_off = 1
middle_off = 1
middle_els = middle.split(".")
for i, el in enumerate(middle_els[1:]):
middle_els[i + 1] = "." + el
left_els = [el + " " for el in left_els]
right_els = [" " + el for el in right_els]
left_width = max(len(e) for e in left_els)
middle_width_inner = max(len(e) for e in middle_els)
middle_width = middle_width_inner + 2 * middle_off
right_width = max(len(e) for e in right_els)
left_height = len(left_els)
middle_height = len(middle_els) + 1
right_height = len(right_els)
total_width = left_width + middle_width + right_width + left_off + right_off + 2
total_height = max(left_height, middle_height, right_height)
horiz_pos = min(total_height - middle_height, left_height, right_height)
lines = []
left_els = (
[None] * max(horiz_pos - left_height + 1, 0)
+ list(left_els)
+ [None] * (total_height - left_height)
)
middle_els = (
[" " * middle_width] * (horiz_pos)
+ ["" * middle_width]
+ [
" " * middle_off + format(el, f"<{middle_width_inner}") + " " * middle_off
for el in middle_els[::-1]
]
+ [" " * middle_width] * (total_height - horiz_pos)
)
right_els = (
[None] * max(horiz_pos - right_height + 1, 0)
+ list(right_els)
+ [None] * (total_height - right_height)
)
for i, (left_el, middle_el, right_el) in enumerate(zip(left_els, middle_els, right_els)):
line = get_left(left_el, i, horiz_pos, left_height, left_width, left_off)
line += middle_el
line += get_right(right_el, i, horiz_pos, right_height, right_width, right_off)
lines.append(line)
print(f"{horiz_pos = }, {total_height = }, {middle_height = }")
print("\n".join(lines[::-1]))
def get_left(left_el, i, horiz_pos, left_height, left_width, left_off):
if not left_el:
return " " * (left_width + left_off + 1)
line = format(left_el, f">{left_width}") + "" * left_off
if i == horiz_pos:
if left_height == 1:
line += ""
elif i == 0:
line += ""
elif i == left_height:
line += ""
else:
line += ""
else:
if i == left_height - 1:
line += ""
elif i == 0:
line += ""
else:
line += ""
return line
def get_right(right_el, i, horiz_pos, right_height, right_width, right_off):
if not right_el:
return " " * (right_width + right_off + 1)
if i == horiz_pos:
if right_height == 1:
line = ""
elif i == 0:
line = ""
elif i == right_height:
line = ""
else:
line = ""
else:
if i == right_height - 1:
line = ""
elif i == 0:
line = ""
else:
line = ""
line += "" * right_off + format(right_el, f"<{right_width}")
return line
if __name__ == "__main__":
print_graph(
("ads", "s", "45 sdksd dkfj"),
"main.fn.bonjour.s",
("asdf", "bonjour", "gamma", "wavelengt", "dt"),
)
print()
print()
print_graph(("ads", "s", "45 sdksd dkfj"), "main.fn.bonjour.s", ("asdf",))
print()
print()
print_graph(("ads", "s", "45 sdksd dkfj"), "main", ("asdf",))
print()
print()
print_graph(
("ads", "s", "45 sdksd dkfj", "a", "b", "c"),
"main.ollol",
("asdf", "some super long variable name"),
)
print()
print()

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, MutableMapping, NamedTuple, Optional, Type, Un
import numpy as np
from scgenerator import math, operators, utils
from scgenerator import io, math, operators, utils
from scgenerator.const import INF, MANDATORY_PARAMETERS
from scgenerator.physics import fiber, materials, plasma, pulse, units
from scgenerator.utils import _mock_function, func_rewrite, get_arg_names, get_logger
@@ -118,11 +118,7 @@ class Rule:
return self.func == other.func
def pretty_format(self) -> str:
func_name_elements = self.func_name.split(".")
targets = list(self.targets)
arg_size = max(self.args, key=len)
func_size = max(func_name_elements, key=len)
return io.format_graph(self.args, self.func_name, self.targets)
@property
def func_name(self) -> str:

View File

@@ -29,3 +29,99 @@ def decode_datetime_hook(obj):
continue
obj[k] = dt
return obj
def format_graph(left_elements: Sequence[str], middle: str, right_elements: Sequence[str]):
if len(left_elements) == 0:
left_elements = [""]
if len(right_elements) == 0:
right_elements = [""]
mid_elements = ["."[:i] + el for i, el in enumerate(middle.split("."))]
left_height = len(left_elements)
right_height = len(right_elements)
max_left = max(len(el) for el in left_elements) + 1
max_right = max(len(el) for el in right_elements) + 1
max_mid = max(len(el) for el in mid_elements) + 2
right_elements = [fit_right(el, max_right) for el in right_elements][::-1]
left_elements = [fit_left(el, max_left) for el in left_elements][::-1]
mid_elements = ["" * max_mid] + [fit_mid(el, max_mid) for el in mid_elements][::-1]
total_height = max(len(left_elements), len(right_elements), len(mid_elements))
line_pos = total_height - len(mid_elements)
left = left_col(
left_elements, line_pos, total_height, max_left, start=max(0, line_pos - left_height + 1)
)
right = right_col(
right_elements, line_pos, total_height, max_right, start=max(0, line_pos - right_height + 1)
)
middle = mid_col(mid_elements, line_pos, total_height, max_mid)
final = "\n".join(l + m + r for l, m, r in zip(left, middle, right))
return final
def fit_left(el, length, line=1, pad=1):
return f"{' '*length}{el}{' '*pad}{''*line}"[-(length + 1) :]
def fit_right(el, length, line=1, pad=1):
return f"{''*line}{' '*pad}{el}{' '*length}"[: length + 1]
def fit_mid(el, length, pad=1):
return f"{' '*pad}{el}{' '*length}"[:length]
def mid_col(mid_els, line_pos, height, pad_length):
out = [" " * pad_length] * line_pos + mid_els
while len(out) < height:
out.append(" " * pad_length)
return out[::-1]
def left_col(left_els, line_pos, height, pad_length, start=0):
out = [" " * (pad_length + 1)] * start
for rel_i, el in enumerate(left_els):
abs_j = rel_i + start
out.append(el + get_symb(False, rel_i, abs_j, line_pos, len(left_els) - 1))
while len(out) < height:
out.append(" " * (pad_length + 2))
return out[::-1]
def right_col(right_els, line_pos, height, pad_length, start=0):
out = [" " * (pad_length + 1)] * start
for rel_i, el in enumerate(right_els):
abs_j = rel_i + start
out.append(get_symb(True, rel_i, abs_j, line_pos, len(right_els) - 1) + el)
while len(out) < height:
out.append(" " * (pad_length + 2))
return out[::-1]
def get_symb(right: bool, rel_i, abs_j, line_pos, max_ind):
if max_ind == 0:
return ""
elif rel_i == 0:
if abs_j < line_pos:
return "╯╰"[right]
elif abs_j == line_pos:
return ""
else:
raise ValueError("bottom of left columns cannot be above line")
elif rel_i < max_ind:
if abs_j == line_pos:
return ""
else:
return "┤├"[right]
else:
if abs_j == line_pos:
return ""
elif abs_j > line_pos:
return "╮╭"[right]
else:
raise ValueError("top of left columns cannot be below line")