From d1e0154c0c2cc6f929e25f576332adafe4c19994 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Thu, 30 Nov 2023 14:05:03 -0800 Subject: [PATCH] TN.draw: show 'hyper outer' indices dangling --- quimb/tensor/drawing.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/quimb/tensor/drawing.py b/quimb/tensor/drawing.py index 4136655d..c1f5b2ba 100644 --- a/quimb/tensor/drawing.py +++ b/quimb/tensor/drawing.py @@ -309,7 +309,10 @@ def draw_tn( # compute a label for this index if ishyper: # each tensor connects to the dummy node represeting the hyper edge - pairs = ((tid, ix) for tid in tids) + pairs = [(tid, ix) for tid in tids] + if isouter and len(tids) > 1: + # 'hyper outer' index + pairs.append((("outer", ix), ix)) # hyper labels get put on dummy node label = "" @@ -364,24 +367,29 @@ def draw_tn( edges[pair]["label_color"] = label_color edges[pair]["label_fontfamily"] = font_family - # tensor side can always have an incoming arrow - tl_left_inds = tn.tensor_map[pair[0]].left_inds - edges[pair]["arrow_left"].append( - show_left_inds - and (tl_left_inds is not None) - and (ix in tl_left_inds) - ) - if ishyper: - # hyper edge can't have an incoming arrow + if isinstance(pair[0], tuple): + # dummy hyper outer edge - no arrows + edges[pair]["arrow_left"].append(False) edges[pair]["arrow_right"].append(False) else: - # standard edge can - tr_left_inds = tn.tensor_map[pair[1]].left_inds - edges[pair]["arrow_right"].append( + # tensor side can always have an incoming arrow + tl_left_inds = tn.tensor_map[pair[0]].left_inds + edges[pair]["arrow_left"].append( show_left_inds - and (tr_left_inds is not None) - and (ix in tr_left_inds) + and (tl_left_inds is not None) + and (ix in tl_left_inds) ) + if ishyper: + # hyper edge can't have an incoming arrow + edges[pair]["arrow_right"].append(False) + else: + # standard edge can + tr_left_inds = tn.tensor_map[pair[1]].left_inds + edges[pair]["arrow_right"].append( + show_left_inds + and (tr_left_inds is not None) + and (ix in tr_left_inds) + ) # parse all tensors / nodes for tid, t in tn.tensor_map.items():