Skip to content

Commit

Permalink
Diffential Matrix Product State first implementation; projected wings…
Browse files Browse the repository at this point in the history
… method
  • Loading branch information
antoine311200 committed Feb 1, 2022
1 parent 13eeabb commit 40bced1
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 220 deletions.
7 changes: 7 additions & 0 deletions core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def predict(self, inputs):

return values

def feed_forward(self, sample):
values = sample
for layer in self.layers:
values = layer(values)
return values


def build(self):
for layer in self.layers:
if not layer.is_built:
Expand Down
3 changes: 2 additions & 1 deletion tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from syngular.tensor.matrix_product_state import MatrixProductState
from syngular.tensor.matrix_product_operator import MatrixProductOperator
from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator
from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator
from syngular.tensor.differential_matrix_product_operator import DifferentialMatrixProductOperator
123 changes: 123 additions & 0 deletions tensor/differential_matrix_product_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from turtle import right
from syngular.tensor import MatrixProductOperator, MatrixProductState

from opt_einsum import contract
import numpy as np

from typing import Union, List, Any

class DifferentialMatrixProductOperator(MatrixProductOperator):

def gradient(self, loss):
pass

def crumble_site(self, index: int) -> Any:
left_site = self.sites[index]
right_site = self.sites[index+1]

print(left_site.shape)
print(right_site.shape)

crumbled_site = contract(*[
left_site, [1, 2, 3, 4],
right_site, [4, 5, 6, 7],
[1, 2, 3, 5, 6, 7]
])

return crumbled_site

def project(self, index, state):

if not isinstance(state, MatrixProductState):
raise Exception("projected wings should come from a matrix product state input")

if not (0 <= index < self.sites_number-1):
raise Exception("trying to project on non-existant site (site indices should be between 0 and the number of sites - 1)")

crumbled_site = self.crumble_site(index)

print(crumbled_site.shape)

""" Left wing
(blank, blank, output, ..., output, mps_bond, mpo_bond)
| | | |
---O---O---O---O---
| | | |
---O---O---O---O---
=>
(blank, output*...*output_n, mps_bond, mpo_bond)
|
---[ ]===
"""

struct = []
n = (index+1)

for idx in range(index):
struct += [state.sites[idx], [ idx+1, n + idx+1, idx+2]]
struct += [self.sites[idx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]]

struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]]

raw_left_wing = contract(*struct)
squeezed_left_wing = np.squeeze(raw_left_wing, axis=0)

left_wing = np.transpose(squeezed_left_wing, axes=(list(range(1, n)) + [0, n, n+1]))
leftover_shape = left_wing.shape[n-1:]
left_wing = np.reshape(left_wing, newshape=(-1, *leftover_shape))
left_wing = np.transpose(left_wing, axes=(1, 0, 2, 3))

print("Left wing", left_wing.shape)


""" Right wing
(mps_bond, mpo_bond, output, ..., output, blank, blank)
| | | |
---O---O---O---O---
| | | |
---O---O---O---O---
=>
(mps_bond, mpo_bond, output*...*output_n, blank)
|
===[ ]---
"""

struct = []
n = (self.sites_number - index - 1)

for jdx in range(index+2, self.sites_number):
idx = jdx-index-2
struct += [state.sites[jdx], [ idx+1, n + idx+1, idx+2]]
struct += [self.sites[jdx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]]

struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]]

raw_right_wing = contract(*struct)
squeezed_right_wing = np.squeeze(raw_right_wing, axis=-1)

right_wing = np.transpose(squeezed_right_wing, axes=(list(range(2, n+1)) + [0, 1, n+1]))
leftover_shape = right_wing.shape[n-1:]
right_wing = np.reshape(right_wing, newshape=(-1, *leftover_shape))
right_wing = np.transpose(right_wing, axes=(1, 2, 0, 3))

print("Right wing", right_wing.shape)

left_center = state.sites[index]
right_center = state.sites[index+1]

print("Left Center", left_center.shape)
print("Right Center", right_center.shape)

return {
'center_site': crumbled_site,
'left_wing': left_wing,
'left_center': left_center,
'right_center': right_center,
'right_wing': right_wing
}
18 changes: 18 additions & 0 deletions test/test_differentiate_mpo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from syngular.tensor import MatrixProductState
from syngular.tensor import MatrixProductOperator
from syngular.tensor import DifferentialMatrixProductOperator

import numpy as np

n = 8
tensor_mpo = np.arange(2**(2*n)).reshape(*((2,)*(2*n))).astype('float64')
tensor_mps = np.arange(2**(n)).reshape(*(2,)*n).astype('float64')

mpo = DifferentialMatrixProductOperator(tensor_mpo, bond_shape=(2,)*(n-1)).decompose()
mps = MatrixProductState(tensor_mps, bond_shape=(2,)*(n-1)).decompose()


print(mpo)
print(mps)

mpo.project(3, mps)
Loading

0 comments on commit 40bced1

Please sign in to comment.