From 40bced10b985df18d6343ad9bcfc6b38615e8734 Mon Sep 17 00:00:00 2001 From: Antoine Debouchage Date: Tue, 1 Feb 2022 17:45:24 +0100 Subject: [PATCH] Diffential Matrix Product State first implementation; projected wings method --- core/model.py | 7 + tensor/__init__.py | 3 +- .../differential_matrix_product_operator.py | 123 ++++++++ test/test_differentiate_mpo.py | 18 ++ test/test_model.py | 269 ++++-------------- 5 files changed, 200 insertions(+), 220 deletions(-) create mode 100644 tensor/differential_matrix_product_operator.py create mode 100644 test/test_differentiate_mpo.py diff --git a/core/model.py b/core/model.py index 4ffa37c..06e0147 100644 --- a/core/model.py +++ b/core/model.py @@ -17,6 +17,13 @@ def predict(self, inputs): return values + def feed_forward(self, sample): + values = sample + for layer in self.layers: + values = layer(values) + return values + + def build(self): for layer in self.layers: if not layer.is_built: diff --git a/tensor/__init__.py b/tensor/__init__.py index cffb22a..8787b7e 100644 --- a/tensor/__init__.py +++ b/tensor/__init__.py @@ -1,3 +1,4 @@ from syngular.tensor.matrix_product_state import MatrixProductState from syngular.tensor.matrix_product_operator import MatrixProductOperator -from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator \ No newline at end of file +from syngular.tensor.matrix_product_density_operator import MatrixProductDensityOperator +from syngular.tensor.differential_matrix_product_operator import DifferentialMatrixProductOperator \ No newline at end of file diff --git a/tensor/differential_matrix_product_operator.py b/tensor/differential_matrix_product_operator.py new file mode 100644 index 0000000..8ac4917 --- /dev/null +++ b/tensor/differential_matrix_product_operator.py @@ -0,0 +1,123 @@ +from turtle import right +from syngular.tensor import MatrixProductOperator, MatrixProductState + +from opt_einsum import contract +import numpy as np + +from typing import Union, List, Any + +class DifferentialMatrixProductOperator(MatrixProductOperator): + + def gradient(self, loss): + pass + + def crumble_site(self, index: int) -> Any: + left_site = self.sites[index] + right_site = self.sites[index+1] + + print(left_site.shape) + print(right_site.shape) + + crumbled_site = contract(*[ + left_site, [1, 2, 3, 4], + right_site, [4, 5, 6, 7], + [1, 2, 3, 5, 6, 7] + ]) + + return crumbled_site + + def project(self, index, state): + + if not isinstance(state, MatrixProductState): + raise Exception("projected wings should come from a matrix product state input") + + if not (0 <= index < self.sites_number-1): + raise Exception("trying to project on non-existant site (site indices should be between 0 and the number of sites - 1)") + + crumbled_site = self.crumble_site(index) + + print(crumbled_site.shape) + + """ Left wing + + (blank, blank, output, ..., output, mps_bond, mpo_bond) + | | | | + ---O---O---O---O--- + | | | | + ---O---O---O---O--- + + => + + (blank, output*...*output_n, mps_bond, mpo_bond) + | + ---[ ]=== + """ + + struct = [] + n = (index+1) + + for idx in range(index): + struct += [state.sites[idx], [ idx+1, n + idx+1, idx+2]] + struct += [self.sites[idx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]] + + struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]] + + raw_left_wing = contract(*struct) + squeezed_left_wing = np.squeeze(raw_left_wing, axis=0) + + left_wing = np.transpose(squeezed_left_wing, axes=(list(range(1, n)) + [0, n, n+1])) + leftover_shape = left_wing.shape[n-1:] + left_wing = np.reshape(left_wing, newshape=(-1, *leftover_shape)) + left_wing = np.transpose(left_wing, axes=(1, 0, 2, 3)) + + print("Left wing", left_wing.shape) + + + """ Right wing + + (mps_bond, mpo_bond, output, ..., output, blank, blank) + | | | | + ---O---O---O---O--- + | | | | + ---O---O---O---O--- + + => + + (mps_bond, mpo_bond, output*...*output_n, blank) + | + ===[ ]--- + """ + + struct = [] + n = (self.sites_number - index - 1) + + for jdx in range(index+2, self.sites_number): + idx = jdx-index-2 + struct += [state.sites[jdx], [ idx+1, n + idx+1, idx+2]] + struct += [self.sites[jdx], [ 2*n+idx, n+idx+1, 3*n+idx, 2*n+idx+1]] + + struct += [[1, 2*n] + list(range(3*n, 4*n-1)) + [n, 3*n-1]] + + raw_right_wing = contract(*struct) + squeezed_right_wing = np.squeeze(raw_right_wing, axis=-1) + + right_wing = np.transpose(squeezed_right_wing, axes=(list(range(2, n+1)) + [0, 1, n+1])) + leftover_shape = right_wing.shape[n-1:] + right_wing = np.reshape(right_wing, newshape=(-1, *leftover_shape)) + right_wing = np.transpose(right_wing, axes=(1, 2, 0, 3)) + + print("Right wing", right_wing.shape) + + left_center = state.sites[index] + right_center = state.sites[index+1] + + print("Left Center", left_center.shape) + print("Right Center", right_center.shape) + + return { + 'center_site': crumbled_site, + 'left_wing': left_wing, + 'left_center': left_center, + 'right_center': right_center, + 'right_wing': right_wing + } \ No newline at end of file diff --git a/test/test_differentiate_mpo.py b/test/test_differentiate_mpo.py new file mode 100644 index 0000000..cd81db5 --- /dev/null +++ b/test/test_differentiate_mpo.py @@ -0,0 +1,18 @@ +from syngular.tensor import MatrixProductState +from syngular.tensor import MatrixProductOperator +from syngular.tensor import DifferentialMatrixProductOperator + +import numpy as np + +n = 8 +tensor_mpo = np.arange(2**(2*n)).reshape(*((2,)*(2*n))).astype('float64') +tensor_mps = np.arange(2**(n)).reshape(*(2,)*n).astype('float64') + +mpo = DifferentialMatrixProductOperator(tensor_mpo, bond_shape=(2,)*(n-1)).decompose() +mps = MatrixProductState(tensor_mps, bond_shape=(2,)*(n-1)).decompose() + + +print(mpo) +print(mps) + +mpo.project(3, mps) \ No newline at end of file diff --git a/test/test_model.py b/test/test_model.py index a9a893c..3cf6b74 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -5,222 +5,53 @@ import numpy as np -w = np.arange(4**6).reshape(4,4,4,4,4,4).astype('float64') / 1.6e5 -W = MatrixProductOperator(w, bond_shape=(4,4,)) -W.decompose() - -model = Model([ - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - Linear(4**3,4**3,core=3, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,4,core=4, bond=4, weights_initializer=W), - # Linear(4,8,core=4,bond=4), - Output((4**3,)) -]) -model.build() - -print("---") -print(model.draw()) -print("---") -print("\n") -print(model.describe()) - -x = np.arange(4**3).reshape(4,4,4) -X = MatrixProductState(x, bond_shape=(4,4,)).decompose() - -y = model.predict(X) -print(y) - -w_ = w.reshape((4**3,4**3)) -print(x.reshape((4**3,)) @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_ @ w_) +# w = np.arange(2**6).reshape(2,2,2,2,2,2).astype('float64') +# W = MatrixProductOperator(w, bond_shape=(2,2,)) +# W.decompose() + +def simple_model(): + + core = 2 + input_dim = 2 + output_dim = 2 + bond_dim = 2 + + input_size = input_dim**core + output_size = 2**core + + input_shape = (input_dim,) * core + output_shape = (output_dim,) * core + bond_shape = (bond_dim,) * (core-1) + + model = Model([ + Linear(input_size, output_size, core=core, bond=bond_dim), + Linear(input_size, output_size, core=core, bond=bond_dim), + Output((output_size, )) + ]) + model.build() + + print(model.describe()) + + + x = np.arange(input_size).reshape(input_shape) + X = MatrixProductState(x, bond_shape=bond_shape).decompose() + + y = model.predict(X) + print("Prediction", y) + print('\n') + + train_df = [X] + + for epoch in range(5): + print(f"Epoch {str(epoch+1)} : ") + + for sample in train_df: + model.feed_forward(sample) + # for layer in model.layers: + # for weight in layer.trainable_tensor_weights: + # weight["weight"] += MatrixProductOperator.random(input_shape, output_shape, bond_shape) + y = model.predict(X) + print("Prediction", y) + + +simple_model() \ No newline at end of file