diff --git a/docs/index.rst b/docs/index.rst
index 2c508f6..a0f024c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -86,6 +86,11 @@
 Note that unlike in NumPy or PyTorch, equations are pre-compiled using
 :py:func:`~torch_semiring_einsum.compile_equation` rather than re-parsed from
 scratch every time einsum is called.
 
+Derived Functions
+-----------------
+
+For convenience, the module also implements the functions ``tensordot``, ``matmul``, ``inner``, ``dot``, ``mm``, ``bmm``, ``mv``, and ``outer`` in terms of einsum. Each of these functions takes a ``block_size`` keyword argument and an ``einsum`` argument, which defaults to :py:func:`~torch_semiring_einsum.einsum` but can be set to another einsum implementation, such as :py:func:`~torch_semiring_einsum.log_einsum`.
+
 API Documentation
 -----------------
diff --git a/tests/test_derived.py b/tests/test_derived.py
new file mode 100644
index 0000000..d4b8dcc
--- /dev/null
+++ b/tests/test_derived.py
@@ -0,0 +1,79 @@
+import unittest
+import torch
+import torch_semiring_einsum as semiring
+import numpy
+
+class TestDerived(unittest.TestCase):
+    def setUp(self):
+        self.device = torch.device('cpu')
+        self.generator = torch.manual_seed(123)
+
+    def test_tensordot(self):
+        A, B, C, D, E, F = 2, 3, 5, 7, 11, 13
+        for i, (x_size, y_size, inner_dims) in enumerate([
+            ((A, B, C, D), (C, D, E, F), 2),
+            ((A, B, C, D), (D, E, F), 1),
+            ((A, B, C, D), (E, F), 0),
+        ]):
+            with self.subTest(i):
+                x = torch.empty(x_size, device=self.device)
+                x.uniform_(-10., 10., generator=self.generator)
+                y = torch.empty(y_size, device=self.device)
+                y.uniform_(-10., 10., generator=self.generator)
+
+                semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1)
+                torch_out = torch.tensordot(x, y, inner_dims)
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)
+
+                semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1, einsum=semiring.log_einsum)
+                torch_out = torch.tensordot(x.exp(), y.exp(), inner_dims).log()
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)
+
+    def test_matmul(self):
+        J, K, M, N, P = 2, 3, 5, 7, 11
+
+        for i, (x_size, y_size) in enumerate([
+            # ((J, 1, N, M), (K, M, P)),  # from torch.matmul docs; needs batch broadcasting, which matmul does not support
+            ((J, K, N, M), (K, M, P)),
+            ((J, K, N, M), (M,)),
+            ((M,), (K, M, P)),
+        ]):
+            with self.subTest(i):
+                x = torch.empty(x_size)
+                x.uniform_(-10., 10., generator=self.generator)
+                y = torch.empty(y_size)
+                y.uniform_(-10., 10., generator=self.generator)
+
+                semiring_out = semiring.matmul(x, y, block_size=1)
+                torch_out = torch.matmul(x, y)
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)
+
+                semiring_out = semiring.matmul(x, y, block_size=1, einsum=semiring.log_einsum)
+                torch_out = torch.matmul(x.exp(), y.exp()).log()
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)
+
+    def test_inner(self):
+        A, B, C, D = 2, 3, 5, 7
+
+        for i, (x_size, y_size) in enumerate([
+            ((A, B), (C, D, B)),
+            ((A, B), ()),
+            ((), (A, B)),
+        ]):
+            with self.subTest(i):
+                x = torch.empty(x_size)
+                x.uniform_(-10., 10., generator=self.generator)
+                y = torch.empty(y_size)
+                y.uniform_(-10., 10., generator=self.generator)
+
+                semiring_out = semiring.inner(x, y, block_size=1)
+                torch_out = torch.inner(x, y)
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)
+
+                semiring_out = semiring.inner(x, y, block_size=1, einsum=semiring.log_einsum)
+                torch_out = torch.inner(x.exp(), y.exp()).log()
+                numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/torch_semiring_einsum/__init__.py b/torch_semiring_einsum/__init__.py
index 3cf2402..7a6740d 100644
--- a/torch_semiring_einsum/__init__.py
+++ b/torch_semiring_einsum/__init__.py
@@ -10,6 +10,7 @@
 from .log_backward import log_einsum_backward
 from .log_differentiable import log_einsum
 from .log_viterbi_forward import log_viterbi_einsum_forward
+from .derived import *
 
 __all__ = [
     'compile_equation',
diff --git a/torch_semiring_einsum/derived.py b/torch_semiring_einsum/derived.py
new file mode 100644
index 0000000..6162f5b
--- /dev/null
+++ b/torch_semiring_einsum/derived.py
@@ -0,0 +1,94 @@
+__all__ = ['tensordot', 'matmul', 'inner', 'dot', 'mm', 'bmm', 'mv', 'outer']
+
+from .real_forward import real_einsum_forward
+from .real_backward import real_einsum_backward
+from .equation import compile_equation
+from .function import combine
+
+default_einsum = combine(real_einsum_forward, real_einsum_backward)
+
+def index_range(start, stop):
+    """Return the string of index letters 'a'+start through 'a'+stop-1."""
+    return ''.join(chr(ord('a') + i) for i in range(start, stop))
+
+def tensordot(a, b, ndim, *, block_size, einsum=default_einsum):
+    """Like torch.tensordot: contracts the last ndim dimensions of a with
+    the first ndim dimensions of b. Only an integer ndim is supported."""
+    if isinstance(ndim, (tuple, list)):
+        raise NotImplementedError('explicit lists of dimensions are not supported')
+    e = (index_range(0, a.ndim) +
+         ',' +
+         index_range(a.ndim - ndim, a.ndim + b.ndim - ndim) +
+         '->' +
+         index_range(0, a.ndim - ndim) +
+         index_range(a.ndim, a.ndim + b.ndim - ndim))
+    e = compile_equation(e)
+    return einsum(e, a, b, block_size=block_size)
+
+def matmul(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.matmul, except that batch dimensions must match exactly
+    (no broadcasting)."""
+    if a.ndim == 0 or b.ndim == 0:
+        raise ValueError('matmul of 0-dimensional tensors is not allowed')
+
+    ndim = max(a.ndim, b.ndim)
+
+    # Letters 'a', 'b', 'c' name the matrix dimensions; letters from 'd'
+    # onward name the batch dimensions, aligned to the right.
+    oi = index_range(3, ndim+1)
+    if a.ndim == 1:
+        ai = 'b'
+    else:
+        ai = index_range(ndim+1-(a.ndim-2), ndim+1) + 'ab'
+        oi += 'a'
+
+    if b.ndim == 1:
+        bi = 'b'
+    else:
+        bi = index_range(ndim+1-(b.ndim-2), ndim+1) + 'bc'
+        oi += 'c'
+
+    e = compile_equation(ai+','+bi+'->'+oi)
+    return einsum(e, a, b, block_size=block_size)
+
+def inner(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.inner: contracts the last dimension of a with the last
+    dimension of b; a 0-dimensional argument acts as a scalar factor."""
+    if a.ndim == 0:
+        e = ',' + index_range(0, b.ndim) + '->' + index_range(0, b.ndim)
+    elif b.ndim == 0:
+        e = index_range(0, a.ndim) + ',->' + index_range(0, a.ndim)
+    else:
+        # 'a' is the contracted index; the free indices of each argument
+        # get distinct letters.
+        ai = index_range(1, a.ndim)
+        bi = index_range(a.ndim+1, a.ndim+b.ndim)
+        e = ai + 'a,' + bi + 'a->' + ai + bi
+    e = compile_equation(e)
+    return einsum(e, a, b, block_size=block_size)
+
+dot_equation = compile_equation('i,i->')
+def dot(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.dot: the inner product of two 1-dimensional tensors."""
+    if a.ndim != 1 or b.ndim != 1:
+        raise ValueError('arguments must be 1-dimensional')
+    return einsum(dot_equation, a, b, block_size=block_size)
+
+mm_equation = compile_equation('ij,jk->ik')
+def mm(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.mm: matrix-matrix multiplication."""
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError('arguments must be 2-dimensional')
+    return einsum(mm_equation, a, b, block_size=block_size)
+
+mv_equation = compile_equation('ij,j->i')
+def mv(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.mv: matrix-vector multiplication."""
+    if a.ndim != 2 or b.ndim != 1:
+        raise ValueError('arguments must be 2-dimensional and 1-dimensional, respectively')
+    return einsum(mv_equation, a, b, block_size=block_size)
+
+bmm_equation = compile_equation('bij,bjk->bik')
+def bmm(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.bmm: batched matrix-matrix multiplication."""
+    if a.ndim != 3 or b.ndim != 3:
+        raise ValueError('arguments must be 3-dimensional')
+    return einsum(bmm_equation, a, b, block_size=block_size)
+
+outer_equation = compile_equation('i,j->ij')
+def outer(a, b, *, block_size, einsum=default_einsum):
+    """Like torch.outer: the outer product of two 1-dimensional tensors."""
+    if a.ndim != 1 or b.ndim != 1:
+        raise ValueError('arguments must be 1-dimensional')
+    return einsum(outer_equation, a, b, block_size=block_size)
+
+ger = outer
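For illustration, here is a minimal usage sketch of the derived functions added by this patch. It is not part of the diff; it assumes the package is installed and exports the names shown above, and the variable names are placeholders.

    import torch
    import torch_semiring_einsum as semiring

    x = torch.randn(3, 4)
    y = torch.randn(4, 5)

    # Ordinary (real semiring) matrix product, summing over the shared
    # index in blocks of 2.
    z = semiring.mm(x, y, block_size=2)

    # Log semiring: passing a different einsum implementation makes this
    # equivalent to torch.mm(x.exp(), y.exp()).log().
    z_log = semiring.mm(x, y, block_size=2, einsum=semiring.log_einsum)

The ``einsum`` keyword is the extension point: any function with the signature ``einsum(equation, *args, block_size=...)``, such as ``log_einsum``, can be substituted without changing the shape logic of the derived functions.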