Skip to content

Commit f074801

Browse files
committed
Add feature to increase iPEPS bond dim
1 parent 5c08c3b commit f074801

File tree

7 files changed

+366
-60
lines changed

7 files changed

+366
-60
lines changed

varipeps/contractions/definitions.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,66 @@ def _prepare_defs(cls):
691691
],
692692
}
693693

694+
unitcell_bond_dim_change_left: Definition = {
695+
"tensors": [["tensor", "tensor_conj", "C1", "T1", "T3", "C4", "T4"]],
696+
"network": [
697+
[
698+
(5, 9, 11, -4, 4), # tensor
699+
(7, 10, 11, -3, 6), # tensor_conj
700+
(1, 3), # C1
701+
(3, 4, 6, -1), # T1
702+
(2, -2, 10, 9), # T3
703+
(2, 8), # C4
704+
(8, 7, 5, 1), # T4
705+
]
706+
],
707+
}
708+
709+
unitcell_bond_dim_change_right: Definition = {
710+
"tensors": [["tensor", "tensor_conj", "T1", "C2", "T2", "T3", "C3"]],
711+
"network": [
712+
[
713+
(-1, 8, 11, 5, 4), # tensor
714+
(-4, 9, 11, 7, 6), # tensor_conj
715+
(-2, 4, 6, 3), # T1
716+
(3, 1), # C2
717+
(5, 7, 10, 1), # T2
718+
(-3, 2, 9, 8), # T3
719+
(2, 10), # C3
720+
]
721+
],
722+
}
723+
724+
unitcell_bond_dim_change_top: Definition = {
725+
"tensors": [["tensor", "tensor_conj", "C1", "T1", "C2", "T2", "T4"]],
726+
"network": [
727+
[
728+
(8, -4, 11, 4, 5), # tensor
729+
(9, -3, 11, 6, 7), # tensor_conj
730+
(2, 10), # C1
731+
(10, 5, 7, 1), # T1
732+
(1, 3), # C2
733+
(4, 6, -2, 3), # T2
734+
(-1, 9, 8, 2), # T4
735+
]
736+
],
737+
}
738+
739+
unitcell_bond_dim_change_bottom: Definition = {
740+
"tensors": [["tensor", "tensor_conj", "T2", "C3", "T3", "C4", "T4"]],
741+
"network": [
742+
[
743+
(4, 5, 11, 8, -1), # tensor
744+
(6, 7, 11, 9, -4), # tensor_conj
745+
(8, 9, 2, -3), # T2
746+
(10, 2), # C3
747+
(1, 10, 7, 5), # T3
748+
(1, 3), # C4
749+
(3, 6, 4, -2), # T4
750+
]
751+
],
752+
}
753+
694754
kagome_pess3_single_site_mapping: Definition = {
695755
"tensors": ["up_simplex", "down_simplex", "site1", "site2", "site3"],
696756
"network": [

varipeps/ctmrg/absorption.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from varipeps.contractions import apply_contraction, apply_contraction_jitted
1111
from varipeps.utils.svd import gauge_fixed_svd
1212
from varipeps.utils.periodic_indices import calculate_periodic_indices
13+
from varipeps.utils.projector_dict import Projector_Dict
1314
from .projectors import (
1415
calc_left_projectors,
1516
calc_right_projectors,
1617
calc_top_projectors,
1718
calc_bottom_projectors,
18-
T_Projector,
1919
)
2020
from varipeps.expectation.one_site import calc_one_site_single_gate_obj
2121
from varipeps.config import PEPS_AD_Config
@@ -27,40 +27,6 @@
2727
CTMRG_Orientation = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
2828

2929

30-
@dataclass
31-
class _Projector_Dict(collections.abc.MutableMapping):
32-
max_x: int
33-
max_y: int
34-
projector_dict: Dict[Tuple[int, int], T_Projector] = field(default_factory=dict) # type: ignore
35-
36-
def __getitem__(self, key: Tuple[int, int]) -> T_Projector:
37-
return self.projector_dict[key]
38-
39-
def __setitem__(self, key: Tuple[int, int], value: T_Projector) -> None:
40-
self.projector_dict[key] = value
41-
42-
def __delitem__(self, key: Tuple[int, int]) -> None:
43-
self.projector_dict.__delitem__(key)
44-
45-
def __iter__(self):
46-
return self.projector_dict.__iter__()
47-
48-
def __len__(self):
49-
return self.projector_dict.__len__()
50-
51-
def get_projector(
52-
self,
53-
current_x: int,
54-
current_y: int,
55-
relative_x: int,
56-
relative_y: int,
57-
) -> T_Projector:
58-
select_x = (current_x + relative_x) % self.max_x
59-
select_y = (current_y + relative_y) % self.max_y
60-
61-
return self.projector_dict[(select_x, select_y)]
62-
63-
6430
def _tensor_list_from_indices(
6531
peps_tensors: Sequence[jnp.ndarray], indices: Sequence[Sequence[int]]
6632
) -> List[List[jnp.ndarray]]:
@@ -150,7 +116,7 @@ def do_left_absorption(
150116
all elements of the unitcell.
151117
"""
152118
max_x, max_y = unitcell.get_size()
153-
left_projectors = _Projector_Dict(max_x=max_x, max_y=max_y)
119+
left_projectors = Projector_Dict(max_x=max_x, max_y=max_y)
154120

155121
working_unitcell = unitcell.copy()
156122

@@ -242,7 +208,7 @@ def do_right_absorption(
242208
all elements of the unitcell.
243209
"""
244210
max_x, max_y = unitcell.get_size()
245-
right_projectors = _Projector_Dict(max_x=max_x, max_y=max_y)
211+
right_projectors = Projector_Dict(max_x=max_x, max_y=max_y)
246212

247213
working_unitcell = unitcell.copy()
248214

@@ -338,7 +304,7 @@ def do_top_absorption(
338304
all elements of the unitcell.
339305
"""
340306
max_x, max_y = unitcell.get_size()
341-
top_projectors = _Projector_Dict(max_x=max_x, max_y=max_y)
307+
top_projectors = Projector_Dict(max_x=max_x, max_y=max_y)
342308

343309
working_unitcell = unitcell.copy()
344310

@@ -430,7 +396,7 @@ def do_bottom_absorption(
430396
all elements of the unitcell.
431397
"""
432398
max_x, max_y = unitcell.get_size()
433-
bottom_projectors = _Projector_Dict(max_x=max_x, max_y=max_y)
399+
bottom_projectors = Projector_Dict(max_x=max_x, max_y=max_y)
434400

435401
working_unitcell = unitcell.copy()
436402

varipeps/ctmrg/projectors.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from collections import namedtuple
21
import enum
32
from functools import partial
43

@@ -10,22 +9,18 @@
109
from varipeps import varipeps_config
1110
from varipeps.utils.func_cache import Checkpointing_Cache
1211
from varipeps.utils.svd import gauge_fixed_svd
12+
from varipeps.utils.projector_dict import (
13+
Left_Projectors,
14+
Right_Projectors,
15+
Top_Projectors,
16+
Bottom_Projectors,
17+
)
1318
from varipeps.config import Projector_Method, PEPS_AD_Config
1419
from varipeps.global_state import PEPS_AD_Global_State
1520

1621
from typing import Sequence, Tuple, TypeVar
1722

1823

19-
Left_Projectors = namedtuple("Left_Projectors", ("top", "bottom"))
20-
Right_Projectors = namedtuple("Right_Projectors", ("top", "bottom"))
21-
Top_Projectors = namedtuple("Top_Projectors", ("left", "right"))
22-
Bottom_Projectors = namedtuple("Bottom_Projectors", ("left", "right"))
23-
24-
T_Projector = TypeVar(
25-
"T_Projector", Left_Projectors, Right_Projectors, Top_Projectors, Bottom_Projectors
26-
)
27-
28-
2924
class _Projectors_Func_Cache:
3025
_left = None
3126
_right = None

varipeps/peps/tensor.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def replace_tensor(
327327
new_tensor: Tensor,
328328
*,
329329
reinitialize_env_as_identities: bool = True,
330+
new_D: Optional[Tuple[int, int, int, int]] = None,
330331
) -> T_PEPS_Tensor:
331332
"""
332333
Replace the PEPS tensor and returns new object of the class.
@@ -337,31 +338,38 @@ def replace_tensor(
337338
Keyword args:
338339
reinitialize_env_as_identities (:obj:`bool`):
339340
Reinitialize the CTM tensors as identities.
341+
new_D (:obj:`tuple` of four :obj:`int`, optional):
342+
Tuple of new iPEPS bond dimensions if tensor has changed dimensions
340343
Returns:
341344
:obj:`~varipeps.peps.PEPS_Tensor`:
342345
New instance of the class with the tensor replaced.
343346
"""
347+
if new_D is None:
348+
new_D = self.D
349+
elif not isinstance(new_D, tuple) and len(new_D) != 4:
350+
raise ValueError("Invalid argument for parameter new_D")
351+
344352
if reinitialize_env_as_identities:
345353
return type(self)(
346354
tensor=new_tensor,
347355
C1=jnp.ones((1, 1), dtype=self.C1.dtype),
348356
C2=jnp.ones((1, 1), dtype=self.C2.dtype),
349357
C3=jnp.ones((1, 1), dtype=self.C3.dtype),
350358
C4=jnp.ones((1, 1), dtype=self.C4.dtype),
351-
T1=jnp.eye(self.D[3], dtype=self.T1.dtype).reshape(
352-
1, self.D[3], self.D[3], 1
359+
T1=jnp.eye(new_D[3], dtype=self.T1.dtype).reshape(
360+
1, new_D[3], new_D[3], 1
353361
),
354-
T2=jnp.eye(self.D[2], dtype=self.T2.dtype).reshape(
355-
self.D[2], self.D[2], 1, 1
362+
T2=jnp.eye(new_D[2], dtype=self.T2.dtype).reshape(
363+
new_D[2], new_D[2], 1, 1
356364
),
357-
T3=jnp.eye(self.D[1], dtype=self.T3.dtype).reshape(
358-
1, 1, self.D[1], self.D[1]
365+
T3=jnp.eye(new_D[1], dtype=self.T3.dtype).reshape(
366+
1, 1, new_D[1], new_D[1]
359367
),
360-
T4=jnp.eye(self.D[0], dtype=self.T4.dtype).reshape(
361-
1, self.D[0], self.D[0], 1
368+
T4=jnp.eye(new_D[0], dtype=self.T4.dtype).reshape(
369+
1, new_D[0], new_D[0], 1
362370
),
363371
d=self.d,
364-
D=self.D,
372+
D=new_D,
365373
chi=self.chi,
366374
max_chi=self.max_chi,
367375
)
@@ -377,7 +385,7 @@ def replace_tensor(
377385
T3=self.T3,
378386
T4=self.T4,
379387
d=self.d,
380-
D=self.D,
388+
D=new_D,
381389
chi=self.chi,
382390
max_chi=self.max_chi,
383391
)

0 commit comments

Comments
 (0)