Skip to content

Commit

Permalink
tensordot: more doc
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanKossaifi committed Jan 2, 2023
1 parent 10d7af9 commit 11cd9e4
Showing 1 changed file with 13 additions and 34 deletions.
47 changes: 13 additions & 34 deletions tltorch/functional/factorized_tensordot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,21 @@
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


# def tensor_dot_tucker(tensor, tucker, modes):
# modes_tensor, modes_tucker = _validate_contraction_modes(tl.shape(tensor), tucker.tensor_shape, modes)
# input_order = tensor.ndim
# weight_order = tucker.order

# sorted_modes_tucker = sorted(modes_tucker, reverse=True)
# sorted_modes_tensor = sorted(modes_tensor, reverse=True)

# # Symbol for dimensionality of the core
# rank_sym = [einsum_symbols[i] for i in range(weight_order)]

# # Symbols for tucker weight size
# tucker_sym = [einsum_symbols[i+weight_order] for i in range(weight_order)]

# # Symbolds for input tensor
# tensor_sym = [einsum_symbols[i+2*weight_order] for i in range(tensor.ndim)]

# # Output: input + weights symbols after removing contraction symbols
# output_sym = tensor_sym + tucker_sym
# for m in sorted_modes_tucker:
# output_sym.pop(m+input_order)
# for m in sorted_modes_tensor:
# output_sym.pop(m)
# for i, e in enumerate(modes_tensor):
# tensor_sym[e] = tucker_sym[modes_tucker[i]]

# # Form the actual equation: tensor, core, factors -> output
# eq = ''.join(tensor_sym)
# eq += ',' + ''.join(rank_sym)
# eq += ',' + ','.join(f'{s}{r}' for s,r in zip(tucker_sym,rank_sym))
# eq += '->' + ''.join(output_sym)
def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()):
"""Batched tensor contraction between a dense tensor and a Tucker tensor on specified modes
# return tl.einsum(eq, tensor, tucker.core, *tucker.factors)

Parameters
----------
tensor : DenseTensor
tucker : TuckerTensor
modes : int list or int
modes on which to contract tensor1 and tensor2
batched_modes : int or tuple[int]
def tensor_dot_tucker(tensor, tucker, modes, batched_modes=()):
Returns
-------
contraction : tensor contracted with cp on the specified modes
"""
modes_tensor, modes_tucker = _validate_contraction_modes(
tl.shape(tensor), tucker.tensor_shape, modes)
input_order = tensor.ndim
Expand Down

0 comments on commit 11cd9e4

Please sign in to comment.