Skip to content

Commit 84593af

Browse files
committed
updates to 1d compression routines to support symmray
1 parent d438901 commit 84593af

File tree

6 files changed

+730
-153
lines changed

6 files changed

+730
-153
lines changed

quimb/tensor/decomp.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,27 @@
3030
# some convenience functions for multiplying diagonals
3131

3232

33+
@compose
3334
def rdmul(x, d):
3435
"""Right-multiplication a matrix by a vector representing a diagonal."""
3536
return x * d[None, :]
3637

3738

39+
@compose
3840
def rddiv(x, d):
3941
"""Right-multiplication of a matrix by a vector representing an inverse
4042
diagonal.
4143
"""
4244
return x / d[None, :]
4345

4446

47+
@compose
4548
def ldmul(d, x):
4649
"""Left-multiplication a matrix by a vector representing a diagonal."""
4750
return x * d[:, None]
4851

4952

53+
@compose
5054
def lddiv(d, x):
5155
"""Left-multiplication of a matrix by a vector representing an inverse
5256
diagonal.
@@ -61,22 +65,22 @@ def dag_numba(x):
6165

6266
@njit # pragma: no cover
6367
def rdmul_numba(x, d):
64-
return x * d.reshape(1, -1)
68+
return x * d[None, :]
6569

6670

6771
@njit # pragma: no cover
6872
def rddiv_numba(x, d):
69-
return x / d.reshape(1, -1)
73+
return x / d[None, :]
7074

7175

7276
@njit # pragma: no cover
7377
def ldmul_numba(d, x):
74-
return x * d.reshape(-1, 1)
78+
return x * d[:, None]
7579

7680

7781
@njit # pragma: no cover
7882
def lddiv_numba(d, x):
79-
return x / d.reshape(-1, 1)
83+
return x / d[:, None]
8084

8185

8286
@compose
@@ -151,6 +155,8 @@ def _trim_and_renorm_svd_result(
151155
# assume already all positive
152156
sabs = s
153157

158+
d = do("shape", sabs)[0]
159+
154160
if (cutoff > 0.0) or (renorm > 0):
155161
if cutoff_mode == 1: # 'abs'
156162
n_chi = do("count_nonzero", sabs > cutoff)
@@ -184,9 +190,9 @@ def _trim_and_renorm_svd_result(
184190
n_chi = max_bond
185191
else:
186192
# neither maximum bond dimension nor cutoff specified
187-
n_chi = do("shape", s)[0]
193+
n_chi = d
188194

189-
if n_chi < do("shape", s)[0]:
195+
if n_chi < d:
190196
s = s[:n_chi]
191197
U = U[:, :n_chi]
192198
VH = VH[:n_chi, :]
@@ -583,14 +589,21 @@ def eigh_truncated(
583589
max_bond=-1,
584590
absorb=0,
585591
renorm=0,
592+
positive=0,
586593
backend=None,
587594
):
588595
with backend_like(backend):
589596
s, U = do("linalg.eigh", x)
590597

591598
# make sure largest singular value first
592-
idx = do("argsort", -do("abs", s))
593-
s, U = s[idx], U[:, idx]
599+
if not positive:
600+
idx = do("argsort", -do("abs", s))
601+
s, U = s[idx], U[:, idx]
602+
else:
603+
# assume all positive, simply reverse
604+
s = s[::-1]
605+
U = U[:, ::-1]
606+
594607
VH = dag(U)
595608

596609
# XXX: better to absorb phase in V and return positive 'values'?
@@ -613,16 +626,26 @@ def eigh_truncated(
613626
@eigh_truncated.register("numpy")
614627
@njit # pragma: no cover
615628
def eigh_truncated_numba(
616-
x, cutoff=-1.0, cutoff_mode=4, max_bond=-1, absorb=0, renorm=0
629+
x,
630+
cutoff=-1.0,
631+
cutoff_mode=4,
632+
max_bond=-1,
633+
absorb=0,
634+
renorm=0,
635+
positive=0,
617636
):
618637
"""SVD-decomposition, using hermitian eigen-decomposition, only works if
619638
``x`` is hermitian.
620639
"""
621640
s, U = np.linalg.eigh(x)
622641

623642
# make sure largest singular value first
624-
k = np.argsort(-np.abs(s))
625-
s, U = s[k], U[:, k]
643+
if not positive:
644+
k = np.argsort(-np.abs(s))
645+
s, U = s[k], U[:, k]
646+
else:
647+
s = s[::-1]
648+
U = U[:, ::-1]
626649
VH = dag_numba(U)
627650

628651
# XXX: better to absorb phase in V and return positive 'values'?
@@ -1210,47 +1233,40 @@ def squared_op_to_reduced_factor(x2, dl, dr, right=True):
12101233
# know exactly low-rank, so truncate
12111234
keep = dl
12121235
else:
1213-
keep = None
1236+
keep = -1
12141237
else:
12151238
if dl > dr:
12161239
# know exactly low-rank, so truncate
12171240
keep = dr
12181241
else:
1219-
keep = None
1242+
keep = -1
12201243

12211244
try:
12221245
# attempt faster hermitian eigendecomposition
1223-
s2, W = do("linalg.eigh", x2)
1224-
1225-
if keep is not None:
1226-
# outer dimension smaller -> exactly low-rank
1227-
s2 = s2[-keep:]
1228-
W = W[:, -keep:]
1229-
1246+
U, s2, VH = eigh_truncated(
1247+
x2,
1248+
max_bond=keep,
1249+
cutoff=0.0,
1250+
absorb=None,
1251+
positive=1, # know positive
1252+
)
12301253
# might have negative eigenvalues due to numerical error from squaring
12311254
s2 = do("clip", s2, 0.0, None)
1232-
s = do("sqrt", s2)
1233-
1234-
if right:
1235-
factor = ldmul(s, dag(W))
1236-
else: # 'left'
1237-
factor = rdmul(W, s)
12381255

12391256
except Exception:
1240-
# fallback to SVD
1241-
U, s2, VH = do("linalg.svd", x2)
1242-
if keep is not None:
1243-
# outer dimension smaller -> exactly low-rank
1244-
s2 = s2[:keep]
1245-
if right:
1246-
VH = VH[:keep, :]
1247-
else:
1248-
U = U[:, :keep]
1249-
s = do("sqrt", s2)
1250-
if right:
1251-
factor = ldmul(s, VH)
1252-
else: # 'left'
1253-
factor = rdmul(U, s)
1257+
# fallback to SVD if maybe badly conditioned etc.
1258+
U, s2, VH = svd_truncated(
1259+
x2,
1260+
max_bond=keep,
1261+
cutoff=0.0,
1262+
absorb=None,
1263+
)
1264+
1265+
s = do("sqrt", s2)
1266+
if right:
1267+
factor = ldmul(s, VH)
1268+
else: # 'left'
1269+
factor = rdmul(U, s)
12541270

12551271
return factor
12561272

quimb/tensor/tensor_1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1980,7 +1980,7 @@ def permute_arrays(self, shape="lrp"):
19801980
----------
19811981
shape : str, optional
19821982
A permutation of ``'lrp'`` specifying the *desired* order of the
1983-
left, right, and physical indices respectively.
1983+
[l]eft, [r]ight, and [p]hysical indices respectively.
19841984
"""
19851985
self.ensure_bonds_exist()
19861986

0 commit comments

Comments
 (0)