Skip to content

Commit d438901

Browse files
committed
tn.draw: show abelian signature
1 parent 5e99e66 commit d438901

File tree

4 files changed

+55
-8
lines changed

4 files changed

+55
-8
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Release notes for `quimb`.
1212
- [`schematic.Drawing`](quimb.schematic.drawing): add [`grid`](quimb.schematic.drawing.grid), [`grid3d`](quimb.schematic.drawing.grid3d), [`bezier`](quimb.schematic.drawing.bezier), [`star`](quimb.schematic.drawing.star), [`cross`](quimb.schematic.drawing.cross) and [`zigzag`](quimb.schematic.drawing.zigzag) methods.
1313
- [`schematic.Drawing`](quimb.schematic.drawing): add `relative` option to [`arrowhead`](quimb.schematic.drawing.arrowhead), `shorten` option to [`text_between`](quimb.schematic.drawing.text_between) and `text_left` and `text_right` options to [`line`](quimb.schematic.drawing.line).
1414
- refactor [`TEBDGen`](quimb.tensor.tensor_arbgeom_tebd.TEBDGen) and [`SimpleUpdateGen`](quimb.tensor.tensor_arbgeom_tebd.SimpleUpdateGen)
15+
- [`tn.draw()`](quimb.tensor.drawing.draw_tn): show abelian signature if using `symmray` arrays.
1516

1617
(whats-new-1-11-2)=
1718
## v1.11.2 (2025-07-30)

quimb/tensor/drawing.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,26 +379,44 @@ def draw_tn(
379379
# dummy hyper outer edge - no arrows
380380
edges[pair]["arrow_left"].append(False)
381381
edges[pair]["arrow_right"].append(False)
382+
edges[pair]["label_left"].append(None)
383+
edges[pair]["label_right"].append(None)
382384
else:
383385
# tensor side can always have an incoming arrow
384-
tl_left_inds = tn.tensor_map[pair[0]].left_inds
386+
tl = tn.tensor_map[pair[0]]
387+
tl_left_inds = tl.left_inds
385388
edges[pair]["arrow_left"].append(
386389
show_left_inds
387390
and (tl_left_inds is not None)
388391
and (ix in tl_left_inds)
389392
)
393+
394+
if hasattr(tl.data, "signature"):
395+
sigl = tl.data.signature[tl.inds.index(ix)]
396+
edges[pair]["label_left"].append(sigl)
397+
else:
398+
edges[pair]["label_left"].append(None)
399+
390400
if ishyper:
391401
# hyper edge can't have an incoming arrow
392402
edges[pair]["arrow_right"].append(False)
403+
edges[pair]["label_right"].append(None)
393404
else:
394405
# standard edge can
395-
tr_left_inds = tn.tensor_map[pair[1]].left_inds
406+
tr = tn.tensor_map[pair[1]]
407+
tr_left_inds = tr.left_inds
396408
edges[pair]["arrow_right"].append(
397409
show_left_inds
398410
and (tr_left_inds is not None)
399411
and (ix in tr_left_inds)
400412
)
401413

414+
if hasattr(tr.data, "signature"):
415+
sigr = tr.data.signature[tr.inds.index(ix)]
416+
edges[pair]["label_right"].append(sigr)
417+
else:
418+
edges[pair]["label_right"].append(None)
419+
402420
# parse all tensors / nodes
403421
for tid, t in tn.tensor_map.items():
404422
nodes[tid]["tid"] = tid
@@ -519,7 +537,11 @@ def draw_tn(
519537
nodes[node]["coo"] = G.nodes[node]["coo"] = pos[node]
520538

521539
for (i, j), edge_data in edges.items():
522-
edges[i, j]["coos"] = G.edges[i, j]["coos"] = pos[i], pos[j]
540+
edge_data["coos"] = G.edges[i, j]["coos"] = pos[i], pos[j]
541+
edge_data["shorten"] = G.edges[i, j]["shorten"] = (
542+
nodes[i]["size"],
543+
nodes[j]["size"],
544+
)
523545

524546
if get == "pos":
525547
return pos
@@ -677,7 +699,7 @@ def _draw_matplotlib(
677699
fig = None
678700

679701
arrow_opts = arrow_opts or {}
680-
arrow_opts.setdefault("center", 3 / 4)
702+
arrow_opts.setdefault("center", 0.8)
681703
arrow_opts.setdefault("linewidth", 1)
682704
arrow_opts.setdefault("width", 0.08)
683705
arrow_opts.setdefault("length", 0.12)
@@ -692,7 +714,10 @@ def _draw_matplotlib(
692714
labels = edge_data["label"]
693715
arrow_lefts = edge_data["arrow_left"]
694716
arrow_rights = edge_data["arrow_right"]
717+
label_lefts = edge_data["label_left"]
718+
label_rights = edge_data["label_right"]
695719
multiplicity = len(edge_colors)
720+
shorten = edge_data["shorten"]
696721

697722
if multiplicity > 1:
698723
offsets = np.linspace(
@@ -731,11 +756,25 @@ def _draw_matplotlib(
731756
color=edge_data["label_color"],
732757
fontfamily=edge_data["label_fontfamily"],
733758
)
759+
if label_lefts[m]:
760+
line_opts["text_left"] = dict(
761+
text=label_lefts[m],
762+
fontsize=edge_data["label_fontsize"] + 3,
763+
color=edge_data["label_color"],
764+
fontfamily=edge_data["label_fontfamily"],
765+
)
766+
if label_rights[m]:
767+
line_opts["text_right"] = dict(
768+
text=label_rights[m],
769+
fontsize=edge_data["label_fontsize"] + 3,
770+
color=edge_data["label_color"],
771+
fontfamily=edge_data["label_fontfamily"],
772+
)
734773

735774
if multiplicity > 1:
736775
d.line_offset(offset=offsets[m], **line_opts)
737776
else:
738-
d.line(**line_opts)
777+
d.line(shorten=shorten, **line_opts)
739778

740779
# draw the tensors
741780
for _, node_data in nodes.items():

quimb/tensor/tensor_1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1971,7 +1971,7 @@ def add_MPS(self, other, inplace=False, **kwargs):
19711971
add_MPS_ = functools.partialmethod(add_MPS, inplace=True)
19721972

19731973
def permute_arrays(self, shape="lrp"):
1974-
"""Permute the indices of each tensor in this MPS to match ``shape``.
1974+
"""Ensure the arrays are stored internally in the specified order.
19751975
This doesn't change how the overall object interacts with other tensor
19761976
networks but may be useful for extracting the underlying arrays
19771977
consistently. This is an inplace operation.

quimb/tensor/tensor_core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,6 +1725,9 @@ def check(self):
17251725
f"Tensor data contains non-finite values: {self.data}."
17261726
)
17271727

1728+
if hasattr(self.data, "check"):
1729+
self.data.check()
1730+
17281731
@property
17291732
def owners(self):
17301733
return self._owners
@@ -4539,6 +4542,10 @@ def check(self):
45394542
"Mismatched index dimension for index "
45404543
f"'{ix}' in tensors {ts}."
45414544
)
4545+
if len(ts) == 2 and hasattr(ts[0].data, "check_with"):
4546+
axa = ts[0].inds.index(ix)
4547+
axb = ts[1].inds.index(ix)
4548+
ts[0].data.check_with(ts[1].data, [axa], [axb])
45424549

45434550
def add_tag(self, tag, where=None, which="all", record=None):
45444551
"""Add tag(s) to every tensor in this network, or if ``where`` is
@@ -5498,7 +5505,7 @@ def _select_local_tids(
54985505
)
54995506

55005507
# full-rank decompose the outer tensor
5501-
l, r = self.tensor_map[tid_out].split(
5508+
_, r = self.tensor_map[tid_out].split(
55025509
left_inds=None,
55035510
right_inds=[ix],
55045511
max_bond=None,
@@ -9593,7 +9600,7 @@ def split(self, left_inds, right_inds=None, **split_opts):
95939600

95949601
def trace(self, left_inds, right_inds, **contract_opts):
95959602
"""Trace over ``left_inds`` joined with ``right_inds``"""
9596-
tn = self.reindex({u: l for u, l in zip(left_inds, right_inds)})
9603+
tn = self.reindex(dict(zip(left_inds, right_inds)))
95979604
return tn.contract_tags(..., **contract_opts)
95989605

95999606
def to_dense(self, *inds_seq, to_qarray=False, **contract_opts):

0 commit comments

Comments
 (0)