Skip to content

Commit 66134b6

Browse files
authored
Merge pull request #333 from sjdu10/clean_quimb
add codes of 2-site MPS-MPO fitting for fermionic TN
2 parents 02c3acb + a59bd1a commit 66134b6

File tree

4 files changed

+308
-2
lines changed

4 files changed

+308
-2
lines changed

ci/requirements/py-openblas.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ dependencies:
1616
- pytest-cov
1717
- scipy<1.15.0
1818
- tqdm
19+
- pip
20+
- pip:
21+
- git+https://github.com/jcmgray/symmray.git@main

quimb/tensor/fitting.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from ..utils import check_opt
66
from .contraction import contract_strategy
7+
from .array_ops import isfermionic
78

89

910
def tensor_network_distance(
@@ -94,6 +95,16 @@ def tensor_network_distance(
9495
if method == "dense":
9596
tnA = tnA.contract(output_inds=oix, preserve_tensor=True)
9697
tnB = tnB.contract(output_inds=oix, preserve_tensor=True)
98+
if isfermionic(tnA.data):
99+
# if fermion tensor, flip dual outer indices in A
100+
data = tnA.data
101+
dual_outer_axs = tuple(
102+
ax
103+
for ax, ix in enumerate(tnA.inds)
104+
if (ix in oix) and not data.indices[ax].dual
105+
)
106+
if dual_outer_axs:
107+
tnA.modify(data=data.phase_flip(*dual_outer_axs))
97108

98109
# overlap method
99110
if xAA is None:

quimb/tensor/tensor_1d_compress.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .tensor_arbgeom import tensor_network_apply_op_vec
2424
from .tensor_arbgeom_compress import tensor_network_ag_compress
2525
from .tensor_builder import TN_matching, rand_tensor
26+
from .array_ops import isfermionic
2627
from .tensor_core import (
2728
Tensor,
2829
TensorNetwork,
@@ -1280,6 +1281,10 @@ def _tn1d_fit_sum_sweep_1site(
12801281
N = len(site_tags)
12811282
K = len(tn_overlaps)
12821283

1284+
fermion = isfermionic(tn_fit.tensors[0].data)
1285+
if fermion:
1286+
raise NotImplementedError("Fermionic 1-site fitting not implemented, use 2-site (bsz=2).")
1287+
12831288
if max_bond is not None:
12841289
current_bond_dim = tn_fit.max_bond()
12851290
if current_bond_dim < max_bond:
@@ -1461,7 +1466,9 @@ def _tn1d_fit_sum_sweep_2site(
14611466
left_inds = tuple(ix for ix in tfi0.inds if ix != bond)
14621467
right_inds = tuple(ix for ix in tfi1.inds if ix != bond)
14631468
tfinew = None
1464-
1469+
1470+
fermion = isfermionic(tfi0.data)
1471+
14651472
for k, tn_overlap in enumerate(tn_overlaps):
14661473
# form local overlap
14671474
tnik = (
@@ -1470,6 +1477,14 @@ def _tn1d_fit_sum_sweep_2site(
14701477
| envs["R", i + 1, k]
14711478
)
14721479

1480+
if fermion:
1481+
# keep track of the environment legs for possible phase flips later
1482+
left_env_ind, right_env_ind = None, None
1483+
if type(envs["L", i, k]) is Tensor:
1484+
(left_env_ind,) = tfi0.bonds(envs["L", i, k])
1485+
if type(envs["R", i + 1, k]) is Tensor:
1486+
(right_env_ind,) = tfi1.bonds(envs["R", i + 1, k])
1487+
14731488
# remove old tensors
14741489
del tnik["__FIT__", site0]
14751490
del tnik["__FIT__", site1]
@@ -1479,13 +1494,43 @@ def _tn1d_fit_sum_sweep_2site(
14791494
all, optimize=optimize, output_inds=left_inds + right_inds
14801495
)
14811496

1497+
if fermion:
1498+
# flip the dual indices of the environment legs if needed
1499+
lind_id, rind_id = None, None
1500+
lind_id = tfiknew.inds.index(left_env_ind) if left_env_ind is not None else None
1501+
rind_id = tfiknew.inds.index(right_env_ind) if right_env_ind is not None else None
1502+
if lind_id is not None and rind_id is not None:
1503+
if tfiknew.data.duals[lind_id]:
1504+
tfiknew.data.phase_flip(lind_id, inplace=True)
1505+
else:
1506+
tfiknew.data.phase_flip(rind_id, inplace=True)
1507+
elif lind_id is not None and rind_id is None:
1508+
if tfiknew.data.duals[lind_id]:
1509+
tfiknew.data.phase_flip(lind_id, inplace=True)
1510+
elif lind_id is None and rind_id is not None:
1511+
if tfiknew.data.duals[rind_id]:
1512+
tfiknew.data.phase_flip(rind_id, inplace=True)
1513+
14821514
# sum into fitted tensor
14831515
if tfinew is None:
14841516
tfinew = tfiknew
14851517
else:
14861518
tfinew += tfiknew
14871519

1488-
tfinew.conj_()
1520+
if fermion:
1521+
# when conjugating tn_fit (a ftn) the dual outer indices will be flipped (phase_dual=True)
1522+
# but ftensor.conj() assumes phase_dual=False by default
1523+
# we need to manually flip the dual 'physical' indices in tfinew before conjugating
1524+
data = tfinew.data
1525+
dual_outer_axs = tuple(
1526+
ax
1527+
for ax, ix in enumerate(tfinew.inds)
1528+
if ix not in (left_env_ind, right_env_ind) and data.indices[ax].dual # or tfixnew.data.duals[ax]
1529+
)
1530+
if dual_outer_axs:
1531+
tfinew.modify(data=data.phase_flip(*dual_outer_axs))
1532+
1533+
tfinew.conj_()
14891534

14901535
tfinew0, tfinew1 = tfinew.split(
14911536
max_bond=max_bond,
@@ -1509,6 +1554,13 @@ def _tn1d_fit_sum_sweep_2site(
15091554
tfi0.modify(data=tfinew0.data, left_inds=tfinew0.left_inds)
15101555
tfi1.modify(data=tfinew1.data, left_inds=tfinew1.left_inds)
15111556

1557+
if fermion:
1558+
# deal with the global signs generated during conjugation
1559+
for ts in (tfi0|tfi1).tensors:
1560+
if len(ts.data._oddpos) % 2 == 1:
1561+
ts.data.phase_global(inplace=True)
1562+
1563+
15121564
return max_tdiff
15131565

15141566

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import importlib
2+
3+
import numpy as np
4+
import pytest
5+
6+
requires_symmray = pytest.mark.skipif(
7+
importlib.util.find_spec("symmray") is None,
8+
reason="symmray not installed",
9+
)
10+
11+
@requires_symmray
12+
@pytest.fixture
13+
def get_fpeps_and_norm():
14+
import symmray as sr
15+
# fPEPS parameters
16+
Lx = int(4)
17+
Ly = int(4)
18+
symmetry = "Z2"
19+
D = 4
20+
# Load PEPS
21+
fpeps = sr.PEPS_fermionic_rand(
22+
symmetry, Lx, Ly, bond_dim=D, phys_dim=4, seed=42
23+
)
24+
fpeps.equalize_norms_()
25+
for ts in fpeps.tensors:
26+
ts.data.phase_sync(inplace=True)
27+
fpeps_norm = fpeps.make_norm()
28+
# benchmark_norm = fpeps_norm.contract_boundary_from_xmax(xrange=(0, Lx-1), max_bond=256, cutoff=0.0, mode='direct').contract()
29+
benchmark_norm = np.float64(9.347604511732736e18)
30+
return fpeps_norm, benchmark_norm
31+
32+
@requires_symmray
33+
@pytest.mark.parametrize("from_which", ["xmin", "xmax", "ymin", "ymax"])
34+
def test_fmps_mpo_fitting(from_which, get_fpeps_and_norm):
35+
36+
fpeps_norm, benchmark_norm = get_fpeps_and_norm
37+
print(f"Benchmark norm: {benchmark_norm}")
38+
print("Boundary fMPS-MPO fitting contraction test:")
39+
# contraction bond dimension
40+
chi = 128
41+
if from_which == "xmin":
42+
print("xmin:")
43+
c_xmin_0 = fpeps_norm.contract_boundary_from_xmin(
44+
xrange=(0, 1),
45+
max_bond=chi,
46+
cutoff=0.0,
47+
mode="fit",
48+
tol=1e-5,
49+
tn_fit="zipup",
50+
bsz=2,
51+
max_iterations=6,
52+
).contract()
53+
c_xmin_1 = fpeps_norm.contract_boundary_from_xmin(
54+
xrange=(0, 2),
55+
max_bond=chi,
56+
cutoff=0.0,
57+
mode="fit",
58+
tol=1e-5,
59+
tn_fit="zipup",
60+
bsz=2,
61+
max_iterations=6,
62+
).contract()
63+
c_xmin_2 = fpeps_norm.contract_boundary_from_xmin(
64+
xrange=(0, 1),
65+
max_bond=chi,
66+
cutoff=0.0,
67+
mode="fit",
68+
tol=1e-5,
69+
tn_fit="zipup",
70+
bsz=2,
71+
max_iterations=5,
72+
).contract()
73+
c_xmin_3 = fpeps_norm.contract_boundary_from_xmin(
74+
xrange=(0, 2),
75+
max_bond=chi,
76+
cutoff=0.0,
77+
mode="fit",
78+
tol=1e-5,
79+
tn_fit="zipup",
80+
bsz=2,
81+
max_iterations=5,
82+
).contract()
83+
print(c_xmin_0, np.allclose(c_xmin_0, benchmark_norm, rtol=1e-4))
84+
print(c_xmin_1, np.allclose(c_xmin_1, benchmark_norm, rtol=1e-4))
85+
print(c_xmin_2, np.allclose(c_xmin_2, benchmark_norm, rtol=1e-4))
86+
print(c_xmin_3, np.allclose(c_xmin_3, benchmark_norm, rtol=1e-4))
87+
assert np.allclose(c_xmin_0, benchmark_norm, rtol=1e-4)
88+
assert np.allclose(c_xmin_1, benchmark_norm, rtol=1e-4)
89+
assert np.allclose(c_xmin_2, benchmark_norm, rtol=1e-4)
90+
assert np.allclose(c_xmin_3, benchmark_norm, rtol=1e-4)
91+
elif from_which == "xmax":
92+
print("xmax:")
93+
c_xmax_0 = fpeps_norm.contract_boundary_from_xmax(
94+
xrange=(2, 3),
95+
max_bond=chi,
96+
cutoff=0.0,
97+
mode="fit",
98+
tol=1e-5,
99+
tn_fit="zipup",
100+
bsz=2,
101+
max_iterations=6,
102+
).contract()
103+
c_xmax_1 = fpeps_norm.contract_boundary_from_xmax(
104+
xrange=(1, 3),
105+
max_bond=chi,
106+
cutoff=0.0,
107+
mode="fit",
108+
tol=1e-5,
109+
tn_fit="zipup",
110+
bsz=2,
111+
max_iterations=6,
112+
).contract()
113+
c_xmax_2 = fpeps_norm.contract_boundary_from_xmax(
114+
xrange=(2, 3),
115+
max_bond=chi,
116+
cutoff=0.0,
117+
mode="fit",
118+
tol=1e-5,
119+
tn_fit="zipup",
120+
bsz=2,
121+
max_iterations=5,
122+
).contract()
123+
c_xmax_3 = fpeps_norm.contract_boundary_from_xmax(
124+
xrange=(1, 3),
125+
max_bond=chi,
126+
cutoff=0.0,
127+
mode="fit",
128+
tol=1e-5,
129+
tn_fit="zipup",
130+
bsz=2,
131+
max_iterations=5,
132+
).contract()
133+
print(c_xmax_0, np.allclose(c_xmax_0, benchmark_norm, rtol=1e-4))
134+
print(c_xmax_1, np.allclose(c_xmax_1, benchmark_norm, rtol=1e-4))
135+
print(c_xmax_2, np.allclose(c_xmax_2, benchmark_norm, rtol=1e-4))
136+
print(c_xmax_3, np.allclose(c_xmax_3, benchmark_norm, rtol=1e-4))
137+
assert np.allclose(c_xmax_0, benchmark_norm, rtol=1e-4)
138+
assert np.allclose(c_xmax_1, benchmark_norm, rtol=1e-4)
139+
assert np.allclose(c_xmax_2, benchmark_norm, rtol=1e-4)
140+
assert np.allclose(c_xmax_3, benchmark_norm, rtol=1e-4)
141+
elif from_which == "ymin":
142+
print("ymin:")
143+
c_ymin_0 = fpeps_norm.contract_boundary_from_ymin(
144+
yrange=(0, 1),
145+
max_bond=chi,
146+
cutoff=0.0,
147+
mode="fit",
148+
tol=1e-5,
149+
tn_fit="zipup",
150+
bsz=2,
151+
max_iterations=6,
152+
).contract()
153+
c_ymin_1 = fpeps_norm.contract_boundary_from_ymin(
154+
yrange=(0, 2),
155+
max_bond=chi,
156+
cutoff=0.0,
157+
mode="fit",
158+
tol=1e-5,
159+
tn_fit="zipup",
160+
bsz=2,
161+
max_iterations=6,
162+
).contract()
163+
c_ymin_2 = fpeps_norm.contract_boundary_from_ymin(
164+
yrange=(0, 1),
165+
max_bond=chi,
166+
cutoff=0.0,
167+
mode="fit",
168+
tol=1e-5,
169+
tn_fit="zipup",
170+
bsz=2,
171+
max_iterations=5,
172+
).contract()
173+
c_ymin_3 = fpeps_norm.contract_boundary_from_ymin(
174+
yrange=(0, 2),
175+
max_bond=chi,
176+
cutoff=0.0,
177+
mode="fit",
178+
tol=1e-5,
179+
tn_fit="zipup",
180+
bsz=2,
181+
max_iterations=5,
182+
).contract()
183+
print(c_ymin_0, np.allclose(c_ymin_0, benchmark_norm, rtol=1e-4))
184+
print(c_ymin_1, np.allclose(c_ymin_1, benchmark_norm, rtol=1e-4))
185+
print(c_ymin_2, np.allclose(c_ymin_2, benchmark_norm, rtol=1e-4))
186+
print(c_ymin_3, np.allclose(c_ymin_3, benchmark_norm, rtol=1e-4))
187+
assert np.allclose(c_ymin_0, benchmark_norm, rtol=1e-4)
188+
assert np.allclose(c_ymin_1, benchmark_norm, rtol=1e-4)
189+
assert np.allclose(c_ymin_2, benchmark_norm, rtol=1e-4)
190+
assert np.allclose(c_ymin_3, benchmark_norm, rtol=1e-4)
191+
elif from_which == "ymax":
192+
print("ymax:")
193+
c_ymax_0 = fpeps_norm.contract_boundary_from_ymax(
194+
yrange=(2, 3),
195+
max_bond=chi,
196+
cutoff=0.0,
197+
mode="fit",
198+
tol=1e-5,
199+
tn_fit="zipup",
200+
bsz=2,
201+
max_iterations=6,
202+
).contract()
203+
c_ymax_1 = fpeps_norm.contract_boundary_from_ymax(
204+
yrange=(1, 3),
205+
max_bond=chi,
206+
cutoff=0.0,
207+
mode="fit",
208+
tol=1e-5,
209+
tn_fit="zipup",
210+
bsz=2,
211+
max_iterations=6,
212+
).contract()
213+
c_ymax_2 = fpeps_norm.contract_boundary_from_ymax(
214+
yrange=(2, 3),
215+
max_bond=chi,
216+
cutoff=0.0,
217+
mode="fit",
218+
tol=1e-5,
219+
tn_fit="zipup",
220+
bsz=2,
221+
max_iterations=5,
222+
).contract()
223+
c_ymax_3 = fpeps_norm.contract_boundary_from_ymax(
224+
yrange=(1, 3),
225+
max_bond=chi,
226+
cutoff=0.0,
227+
mode="fit",
228+
tol=1e-5,
229+
tn_fit="zipup",
230+
bsz=2,
231+
max_iterations=5,
232+
).contract()
233+
print(c_ymax_0, np.allclose(c_ymax_0, benchmark_norm, rtol=1e-4))
234+
print(c_ymax_1, np.allclose(c_ymax_1, benchmark_norm, rtol=1e-4))
235+
print(c_ymax_2, np.allclose(c_ymax_2, benchmark_norm, rtol=1e-4))
236+
print(c_ymax_3, np.allclose(c_ymax_3, benchmark_norm, rtol=1e-4))
237+
assert np.allclose(c_ymax_0, benchmark_norm, rtol=1e-4)
238+
assert np.allclose(c_ymax_1, benchmark_norm, rtol=1e-4)
239+
assert np.allclose(c_ymax_2, benchmark_norm, rtol=1e-4)
240+
assert np.allclose(c_ymax_3, benchmark_norm, rtol=1e-4)

0 commit comments

Comments
 (0)