-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Diffential Matrix Product State first implementation; projected wings…
… method
- Loading branch information
1 parent
13eeabb
commit 40bced1
Showing
5 changed files
with
200 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from syngular.tensor.matrix_product_state import MatrixProductState | ||
from syngular.tensor.matrix_product_operator import MatrixProductOperator | ||
from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator | ||
from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator | ||
from syngular.tensor.differential_matrix_product_operator import DifferentialMatrixProductOperator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from turtle import right | ||
from syngular.tensor import MatrixProductOperator, MatrixProductState | ||
|
||
from opt_einsum import contract | ||
import numpy as np | ||
|
||
from typing import Union, List, Any | ||
|
||
class DifferentialMatrixProductOperator(MatrixProductOperator): | ||
|
||
def gradient(self, loss): | ||
pass | ||
|
||
def crumble_site(self, index: int) -> Any: | ||
left_site = self.sites[index] | ||
right_site = self.sites[index+1] | ||
|
||
print(left_site.shape) | ||
print(right_site.shape) | ||
|
||
crumbled_site = contract(*[ | ||
left_site, [1, 2, 3, 4], | ||
right_site, [4, 5, 6, 7], | ||
[1, 2, 3, 5, 6, 7] | ||
]) | ||
|
||
return crumbled_site | ||
|
||
def project(self, index, state): | ||
|
||
if not isinstance(state, MatrixProductState): | ||
raise Exception("projected wings should come from a matrix product state input") | ||
|
||
if not (0 <= index < self.sites_number-1): | ||
raise Exception("trying to project on non-existant site (site indices should be between 0 and the number of sites - 1)") | ||
|
||
crumbled_site = self.crumble_site(index) | ||
|
||
print(crumbled_site.shape) | ||
|
||
""" Left wing | ||
(blank, blank, output, ..., output, mps_bond, mpo_bond) | ||
| | | | | ||
---O---O---O---O--- | ||
| | | | | ||
---O---O---O---O--- | ||
=> | ||
(blank, output*...*output_n, mps_bond, mpo_bond) | ||
| | ||
---[ ]=== | ||
""" | ||
|
||
struct = [] | ||
n = (index+1) | ||
|
||
for idx in range(index): | ||
struct += [state.sites[idx], [ idx+1, n + idx+1, idx+2]] | ||
struct += [self.sites[idx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]] | ||
|
||
struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]] | ||
|
||
raw_left_wing = contract(*struct) | ||
squeezed_left_wing = np.squeeze(raw_left_wing, axis=0) | ||
|
||
left_wing = np.transpose(squeezed_left_wing, axes=(list(range(1, n)) + [0, n, n+1])) | ||
leftover_shape = left_wing.shape[n-1:] | ||
left_wing = np.reshape(left_wing, newshape=(-1, *leftover_shape)) | ||
left_wing = np.transpose(left_wing, axes=(1, 0, 2, 3)) | ||
|
||
print("Left wing", left_wing.shape) | ||
|
||
|
||
""" Right wing | ||
(mps_bond, mpo_bond, output, ..., output, blank, blank) | ||
| | | | | ||
---O---O---O---O--- | ||
| | | | | ||
---O---O---O---O--- | ||
=> | ||
(mps_bond, mpo_bond, output*...*output_n, blank) | ||
| | ||
===[ ]--- | ||
""" | ||
|
||
struct = [] | ||
n = (self.sites_number - index - 1) | ||
|
||
for jdx in range(index+2, self.sites_number): | ||
idx = jdx-index-2 | ||
struct += [state.sites[jdx], [ idx+1, n + idx+1, idx+2]] | ||
struct += [self.sites[jdx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]] | ||
|
||
struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]] | ||
|
||
raw_right_wing = contract(*struct) | ||
squeezed_right_wing = np.squeeze(raw_right_wing, axis=-1) | ||
|
||
right_wing = np.transpose(squeezed_right_wing, axes=(list(range(2, n+1)) + [0, 1, n+1])) | ||
leftover_shape = right_wing.shape[n-1:] | ||
right_wing = np.reshape(right_wing, newshape=(-1, *leftover_shape)) | ||
right_wing = np.transpose(right_wing, axes=(1, 2, 0, 3)) | ||
|
||
print("Right wing", right_wing.shape) | ||
|
||
left_center = state.sites[index] | ||
right_center = state.sites[index+1] | ||
|
||
print("Left Center", left_center.shape) | ||
print("Right Center", right_center.shape) | ||
|
||
return { | ||
'center_site': crumbled_site, | ||
'left_wing': left_wing, | ||
'left_center': left_center, | ||
'right_center': right_center, | ||
'right_wing': right_wing | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from syngular.tensor import MatrixProductState | ||
from syngular.tensor import MatrixProductOperator | ||
from syngular.tensor import DifferentialMatrixProductOperator | ||
|
||
import numpy as np | ||
|
||
n = 8 | ||
tensor_mpo = np.arange(2**(2*n)).reshape(*((2,)*(2*n))).astype('float64') | ||
tensor_mps = np.arange(2**(n)).reshape(*(2,)*n).astype('float64') | ||
|
||
mpo = DifferentialMatrixProductOperator(tensor_mpo, bond_shape=(2,)*(n-1)).decompose() | ||
mps = MatrixProductState(tensor_mps, bond_shape=(2,)*(n-1)).decompose() | ||
|
||
|
||
print(mpo) | ||
print(mps) | ||
|
||
mpo.project(3, mps) |
Oops, something went wrong.