Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derived functions #11

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ Note that unlike in NumPy or PyTorch, equations are pre-compiled using
:py:func:`~torch_semiring_einsum.compile_equation` rather than re-parsed from
scratch every time einsum is called.

Derived Functions
-----------------

For convenience, the module also implements the functions ``tensordot``, ``matmul``, ``inner``, ``dot``, ``mm``, ``bmm``, ``mv``, and ``outer`` in terms of einsum. Each of these functions takes a required ``block_size`` keyword argument and an optional ``einsum`` keyword argument, which defaults to :py:func:`~torch_semiring_einsum.einsum` but can be set to another einsum implementation, such as :py:func:`~torch_semiring_einsum.log_einsum`.

API Documentation
-----------------

Expand Down
79 changes: 79 additions & 0 deletions tests/test_derived.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest
import torch
import torch_semiring_einsum as semiring
import numpy

class TestDerived(unittest.TestCase):
    """Compare the einsum-derived functions with their PyTorch counterparts.

    Every case is checked twice: once with the default einsum against the
    plain PyTorch result, and once with ``semiring.log_einsum`` against the
    PyTorch result computed on exponentiated inputs and mapped back with
    ``log()``.
    """

    def setUp(self):
        self.device = torch.device('cpu')
        # Seed the generator so that the random inputs are reproducible.
        self.generator = torch.manual_seed(123)

    def _random_tensor(self, size):
        """Return a tensor of shape ``size`` on ``self.device`` filled by ``uniform_(-10., 10.)``."""
        tensor = torch.empty(size, device=self.device)
        tensor.uniform_(-10., 10., generator=self.generator)
        return tensor

    def _check_against_torch(self, semiring_func, torch_func, x, y, *args):
        """Assert that ``semiring_func`` agrees with ``torch_func`` on ``x`` and ``y``.

        ``args`` are extra positional arguments forwarded to both functions.
        """
        semiring_out = semiring_func(x, y, *args, block_size=1)
        torch_out = torch_func(x, y, *args)
        numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)

        # With log_einsum the same equation is evaluated on log-space
        # inputs, so the reference is computed on exp(x), exp(y).
        semiring_out = semiring_func(
            x, y, *args, block_size=1, einsum=semiring.log_einsum)
        torch_out = torch_func(x.exp(), y.exp(), *args).log()
        numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)

    def test_tensordot(self):
        A, B, C, D, E, F = 2, 3, 5, 7, 11, 13
        cases = [
            ((A, B, C, D), (C, D, E, F), 2),
            ((A, B, C, D), (D, E, F), 1),
            ((A, B, C, D), (E, F), 0),
        ]
        for i, (x_size, y_size, inner_dims) in enumerate(cases):
            with self.subTest(i):
                x = self._random_tensor(x_size)
                y = self._random_tensor(y_size)
                self._check_against_torch(
                    semiring.tensordot, torch.tensordot, x, y, inner_dims)

    def test_matmul(self):
        J, K, M, N, P = 2, 3, 5, 7, 11
        cases = [
            # ((J, 1, N, M), (K, M, P)),  # from torch.matmul docs
            ((J, K, N, M), (K, M, P)),
            ((J, K, N, M), (M,)),
            ((M,), (K, M, P)),
        ]
        for i, (x_size, y_size) in enumerate(cases):
            with self.subTest(i):
                x = self._random_tensor(x_size)
                y = self._random_tensor(y_size)
                self._check_against_torch(semiring.matmul, torch.matmul, x, y)

    def test_inner(self):
        A, B, C, D = 2, 3, 5, 7
        cases = [
            ((A, B), (C, D, B)),
            ((A, B), ()),
            ((), (A, B)),
        ]
        for i, (x_size, y_size) in enumerate(cases):
            with self.subTest(i):
                x = self._random_tensor(x_size)
                y = self._random_tensor(y_size)
                self._check_against_torch(semiring.inner, torch.inner, x, y)


if __name__ == '__main__':
    unittest.main()
1 change: 1 addition & 0 deletions torch_semiring_einsum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .log_backward import log_einsum_backward
from .log_differentiable import log_einsum
from .log_viterbi_forward import log_viterbi_einsum_forward
from .derived import *

__all__ = [
'compile_equation',
Expand Down
94 changes: 94 additions & 0 deletions torch_semiring_einsum/derived.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Names re-exported by ``from .derived import *``. ``ger`` is listed because
# this module defines it as an alias of ``outer``.
__all__ = ['tensordot', 'matmul', 'inner', 'dot', 'mm', 'bmm', 'mv', 'outer', 'ger']

from .real_forward import real_einsum_forward
from .real_backward import real_einsum_backward
from .equation import compile_equation
from .function import combine

# Einsum callable used by the derived functions in this module whenever the
# caller does not pass ``einsum``. It is the result of handing
# ``real_einsum_forward`` and ``real_einsum_backward`` to ``combine``.
default_einsum = combine(real_einsum_forward, real_einsum_backward)

def index_range(start, stop):
    """Return consecutive index letters for positions ``start`` to ``stop - 1``.

    Position 0 maps to ``'a'``, position 1 to ``'b'``, and so on, each
    character being ``chr(ord('a') + position)``. The result is the empty
    string when ``stop <= start``.

    :param start: First position (inclusive).
    :param stop: End position (exclusive).
    :return: A string with one character per position.
    """
    base = ord('a')
    return ''.join(chr(base + i) for i in range(start, stop))

def tensordot(a, b, ndim, *, block_size, einsum=default_einsum):
    """Contract the last ``ndim`` dimensions of ``a`` with the first ``ndim``
    dimensions of ``b``, expressed as a single einsum equation.

    For example, with ``a.ndim == 4``, ``b.ndim == 4`` and ``ndim == 2`` the
    equation is ``'abcd,cdef->abef'``.

    :param a: First operand; only its ``ndim`` attribute is inspected here.
    :param b: Second operand; only its ``ndim`` attribute is inspected here.
    :param ndim: Number of dimensions to contract (an integer).
    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(equation, a, b, block_size=block_size)``.
    :return: Whatever ``einsum`` returns.
    :raises NotImplementedError: If ``ndim`` is a tuple or list of explicit
        dimensions.
    :raises ValueError: If ``ndim`` is negative or larger than the number of
        dimensions of either operand.
    """
    if isinstance(ndim, (tuple, list)):
        raise NotImplementedError(
            'explicit lists of dimensions are not supported; '
            'pass an integer number of dimensions')
    # Without this check an out-of-range count produces index characters
    # outside the intended range and an equation that does not describe a
    # tensordot at all.
    if not 0 <= ndim <= min(a.ndim, b.ndim):
        raise ValueError(
            'ndim must be between 0 and the number of dimensions of both '
            'arguments')
    kept = a.ndim - ndim          # leading dimensions of a that survive
    end = a.ndim + b.ndim - ndim  # one past the last index letter used by b
    e = (index_range(0, a.ndim) +
         ',' +
         index_range(kept, end) +
         '->' +
         index_range(0, kept) +
         index_range(a.ndim, end))
    e = compile_equation(e)
    return einsum(e, a, b, block_size=block_size)

def matmul(a, b, *, block_size, einsum=default_einsum):
    """Matrix product of ``a`` and ``b`` in the manner of ``torch.matmul``.

    The equation is assembled from the operands' ``ndim``: a 1-dimensional
    operand contributes only the contracted index ``b``; any other operand
    contributes its batch indices followed by ``ab`` (for ``a``) or ``bc``
    (for ``b``). Batch indices are assigned so that the two operands'
    trailing batch dimensions share letters.

    NOTE(review): because batch dimensions share index letters, they
    presumably have to agree in size; broadcasting of size-1 batch
    dimensions looks unsupported -- confirm against the einsum in use.

    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(equation, a, b, block_size=block_size)``.
    :raises ValueError: If either operand is 0-dimensional.
    """
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError('matmul of 0-dimensional tensors is not allowed')

    # One past the last batch index position; batch letters start at 'd'
    # because 'a', 'b' and 'c' are reserved for the matrix dimensions.
    batch_stop = max(a.ndim, b.ndim) + 1

    lhs = rhs = 'b'
    out = index_range(3, batch_stop)

    if a.ndim != 1:
        lhs = index_range(batch_stop - (a.ndim - 2), batch_stop) + 'ab'
        out += 'a'

    if b.ndim != 1:
        rhs = index_range(batch_stop - (b.ndim - 2), batch_stop) + 'bc'
        out += 'c'

    equation = compile_equation(lhs + ',' + rhs + '->' + out)
    return einsum(equation, a, b, block_size=block_size)

def inner(a, b, *, block_size, einsum=default_einsum):
    """Contract the last dimension of ``a`` with the last dimension of ``b``.

    When either operand is 0-dimensional, the equation instead keeps every
    index of the other operand (for example ``',ab->ab'``), so nothing is
    summed out.

    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(equation, a, b, block_size=block_size)``.
    :return: Whatever ``einsum`` returns.
    """
    if a.ndim == 0:
        kept = index_range(0, b.ndim)
        spec = f',{kept}->{kept}'
    elif b.ndim == 0:
        kept = index_range(0, a.ndim)
        spec = f'{kept},->{kept}'
    else:
        # 'a' is the shared last index; the leading indices of the two
        # operands are drawn from disjoint letter ranges.
        lead_a = index_range(1, a.ndim)
        lead_b = index_range(a.ndim + 1, a.ndim + b.ndim)
        spec = f'{lead_a}a,{lead_b}a->{lead_a}{lead_b}'
    return einsum(compile_equation(spec), a, b, block_size=block_size)

# The output side is empty so that the shared index ``i`` is summed out.
# (``'i,i->i'`` would keep ``i`` and yield the elementwise product instead.)
dot_equation = compile_equation('i,i->')
def dot(a, b, *, block_size, einsum=default_einsum):
    """Dot product of two 1-dimensional tensors, in the manner of ``torch.dot``.

    :param a: First operand; must have ``ndim == 1``.
    :param b: Second operand; must have ``ndim == 1``.
    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(dot_equation, a, b, block_size=block_size)``.
    :return: Whatever ``einsum`` returns for the equation ``'i,i->'``.
    :raises ValueError: If either operand is not 1-dimensional.
    """
    if a.ndim != 1 or b.ndim != 1:
        raise ValueError('arguments must be 1-dimensional')
    return einsum(dot_equation, a, b, block_size=block_size)

# Compiled once at import time and reused by every call to ``mm``.
mm_equation = compile_equation('ij,jk->ik')
def mm(a, b, *, block_size, einsum=default_einsum):
    """Product of two 2-dimensional tensors via the equation ``'ij,jk->ik'``.

    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(mm_equation, a, b, block_size=block_size)``.
    :raises ValueError: If either operand is not 2-dimensional.
    """
    if not (a.ndim == 2 and b.ndim == 2):
        raise ValueError('arguments must be 2-dimensional')
    product = einsum(mm_equation, a, b, block_size=block_size)
    return product

# Compiled once at import time and reused by every call to ``mv``.
mv_equation = compile_equation('ij,j->i')
def mv(a, b, *, block_size, einsum=default_einsum):
    """Matrix-vector product via the equation ``'ij,j->i'``.

    :param a: Must have ``ndim == 2``.
    :param b: Must have ``ndim == 1``.
    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(mv_equation, a, b, block_size=block_size)``.
    :raises ValueError: If the operands do not have 2 and 1 dimensions,
        respectively.
    """
    if not (a.ndim == 2 and b.ndim == 1):
        raise ValueError('arguments must be 2-dimensional and 1-dimensional, respectively')
    product = einsum(mv_equation, a, b, block_size=block_size)
    return product

# Compiled once at import time and reused by every call to ``bmm``.
bmm_equation = compile_equation('bij,bjk->bik')
def bmm(a, b, *, block_size, einsum=default_einsum):
    """Batched matrix product via the equation ``'bij,bjk->bik'``.

    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(bmm_equation, a, b, block_size=block_size)``.
    :raises ValueError: If either operand is not 3-dimensional.
    """
    if not (a.ndim == 3 and b.ndim == 3):
        raise ValueError('arguments must be 3-dimensional')
    product = einsum(bmm_equation, a, b, block_size=block_size)
    return product

# Compiled once at import time and reused by every call to ``outer``.
outer_equation = compile_equation('i,j->ij')
def outer(a, b, *, block_size, einsum=default_einsum):
    """Outer product of two 1-dimensional tensors via the equation ``'i,j->ij'``.

    :param block_size: Forwarded unchanged to ``einsum``.
    :param einsum: Callable invoked as
        ``einsum(outer_equation, a, b, block_size=block_size)``.
    :raises ValueError: If either operand is not 1-dimensional.
    """
    if not (a.ndim == 1 and b.ndim == 1):
        raise ValueError('arguments must be 1-dimensional')
    product = einsum(outer_equation, a, b, block_size=block_size)
    return product
# Alternative name bound to the same function object.
ger = outer