3030# some convenience functions for multiplying diagonals
3131
3232
33+ @compose
3334def rdmul (x , d ):
3435 """Right-multiplication a matrix by a vector representing a diagonal."""
3536 return x * d [None , :]
3637
3738
39+ @compose
3840def rddiv (x , d ):
3941 """Right-multiplication of a matrix by a vector representing an inverse
4042 diagonal.
4143 """
4244 return x / d [None , :]
4345
4446
47+ @compose
4548def ldmul (d , x ):
4649 """Left-multiplication a matrix by a vector representing a diagonal."""
4750 return x * d [:, None ]
4851
4952
53+ @compose
5054def lddiv (d , x ):
5155 """Left-multiplication of a matrix by a vector representing an inverse
5256 diagonal.
@@ -61,22 +65,22 @@ def dag_numba(x):
6165
6266@njit # pragma: no cover
6367def rdmul_numba (x , d ):
64- return x * d . reshape ( 1 , - 1 )
68+ return x * d [ None , :]
6569
6670
6771@njit # pragma: no cover
6872def rddiv_numba (x , d ):
69- return x / d . reshape ( 1 , - 1 )
73+ return x / d [ None , :]
7074
7175
7276@njit # pragma: no cover
7377def ldmul_numba (d , x ):
74- return x * d . reshape ( - 1 , 1 )
78+ return x * d [:, None ]
7579
7680
7781@njit # pragma: no cover
7882def lddiv_numba (d , x ):
79- return x / d . reshape ( - 1 , 1 )
83+ return x / d [:, None ]
8084
8185
8286@compose
@@ -151,6 +155,8 @@ def _trim_and_renorm_svd_result(
151155 # assume already all positive
152156 sabs = s
153157
158+ d = do ("shape" , sabs )[0 ]
159+
154160 if (cutoff > 0.0 ) or (renorm > 0 ):
155161 if cutoff_mode == 1 : # 'abs'
156162 n_chi = do ("count_nonzero" , sabs > cutoff )
@@ -184,9 +190,9 @@ def _trim_and_renorm_svd_result(
184190 n_chi = max_bond
185191 else :
186192 # neither maximum bond dimension nor cutoff specified
187- n_chi = do ( "shape" , s )[ 0 ]
193+ n_chi = d
188194
189- if n_chi < do ( "shape" , s )[ 0 ] :
195+ if n_chi < d :
190196 s = s [:n_chi ]
191197 U = U [:, :n_chi ]
192198 VH = VH [:n_chi , :]
@@ -583,14 +589,21 @@ def eigh_truncated(
583589 max_bond = - 1 ,
584590 absorb = 0 ,
585591 renorm = 0 ,
592+ positive = 0 ,
586593 backend = None ,
587594):
588595 with backend_like (backend ):
589596 s , U = do ("linalg.eigh" , x )
590597
591598 # make sure largest singular value first
592- idx = do ("argsort" , - do ("abs" , s ))
593- s , U = s [idx ], U [:, idx ]
599+ if not positive :
600+ idx = do ("argsort" , - do ("abs" , s ))
601+ s , U = s [idx ], U [:, idx ]
602+ else :
603+ # assume all positive, simply reverse
604+ s = s [::- 1 ]
605+ U = U [:, ::- 1 ]
606+
594607 VH = dag (U )
595608
596609 # XXX: better to absorb phase in V and return positive 'values'?
@@ -613,16 +626,26 @@ def eigh_truncated(
613626@eigh_truncated .register ("numpy" )
614627@njit # pragma: no cover
615628def eigh_truncated_numba (
616- x , cutoff = - 1.0 , cutoff_mode = 4 , max_bond = - 1 , absorb = 0 , renorm = 0
629+ x ,
630+ cutoff = - 1.0 ,
631+ cutoff_mode = 4 ,
632+ max_bond = - 1 ,
633+ absorb = 0 ,
634+ renorm = 0 ,
635+ positive = 0 ,
617636):
618637 """SVD-decomposition, using hermitian eigen-decomposition, only works if
619638 ``x`` is hermitian.
620639 """
621640 s , U = np .linalg .eigh (x )
622641
623642 # make sure largest singular value first
624- k = np .argsort (- np .abs (s ))
625- s , U = s [k ], U [:, k ]
643+ if not positive :
644+ k = np .argsort (- np .abs (s ))
645+ s , U = s [k ], U [:, k ]
646+ else :
647+ s = s [::- 1 ]
648+ U = U [:, ::- 1 ]
626649 VH = dag_numba (U )
627650
628651 # XXX: better to absorb phase in V and return positive 'values'?
@@ -1210,47 +1233,40 @@ def squared_op_to_reduced_factor(x2, dl, dr, right=True):
12101233 # know exactly low-rank, so truncate
12111234 keep = dl
12121235 else :
1213- keep = None
1236+ keep = - 1
12141237 else :
12151238 if dl > dr :
12161239 # know exactly low-rank, so truncate
12171240 keep = dr
12181241 else :
1219- keep = None
1242+ keep = - 1
12201243
12211244 try :
12221245 # attempt faster hermitian eigendecomposition
1223- s2 , W = do ( "linalg.eigh" , x2 )
1224-
1225- if keep is not None :
1226- # outer dimension smaller -> exactly low-rank
1227- s2 = s2 [ - keep :]
1228- W = W [:, - keep :]
1229-
1246+ U , s2 , VH = eigh_truncated (
1247+ x2 ,
1248+ max_bond = keep ,
1249+ cutoff = 0.0 ,
1250+ absorb = None ,
1251+ positive = 1 , # know positive
1252+ )
12301253 # might have negative eigenvalues due to numerical error from squaring
12311254 s2 = do ("clip" , s2 , 0.0 , None )
1232- s = do ("sqrt" , s2 )
1233-
1234- if right :
1235- factor = ldmul (s , dag (W ))
1236- else : # 'left'
1237- factor = rdmul (W , s )
12381255
12391256 except Exception :
1240- # fallback to SVD
1241- U , s2 , VH = do ("linalg.svd" , x2 )
1242- if keep is not None :
1243- # outer dimension smaller -> exactly low-rank
1244- s2 = s2 [:keep ]
1245- if right :
1246- VH = VH [:keep , :]
1247- else :
1248- U = U [:, :keep ]
1249- s = do ("sqrt" , s2 )
1250- if right :
1251- factor = ldmul (s , VH )
1252- else : # 'left'
1253- factor = rdmul (U , s )
1257+ # fallback to SVD if maybe badly conditioned etc.
1258+ U , s2 , VH = svd_truncated (
1259+ x2 ,
1260+ max_bond = keep ,
1261+ cutoff = 0.0 ,
1262+ absorb = None ,
1263+ )
1264+
1265+ s = do ("sqrt" , s2 )
1266+ if right :
1267+ factor = ldmul (s , VH )
1268+ else : # 'left'
1269+ factor = rdmul (U , s )
12541270
12551271 return factor
12561272
0 commit comments