From b0c63dff91ad579c897b71227d4febfc1e7bb903 Mon Sep 17 00:00:00 2001 From: David Chiang Date: Wed, 29 Dec 2021 08:40:14 -0500 Subject: [PATCH 1/4] document log_einsum and (renamed) log_viterbi_einsum --- docs/index.rst | 2 ++ torch_semiring_einsum/__init__.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index b46f2cb..b0db499 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -86,6 +86,8 @@ Note that unlike in NumPy or PyTorch, equations are pre-compiled using :py:func:`~torch_semiring_einsum.compile_equation` rather than re-parsed from scratch every time einsum is called. +In addition to `einsum`, the module also exposes a differentiable `log_einsum` and a non-differentiable `log_viterbi_einsum`. + API Documentation ----------------- diff --git a/torch_semiring_einsum/__init__.py b/torch_semiring_einsum/__init__.py index eede46c..4e023d1 100644 --- a/torch_semiring_einsum/__init__.py +++ b/torch_semiring_einsum/__init__.py @@ -21,6 +21,8 @@ :py:func:`log_einsum_backward` into one auto-differentiable function. """ +log_viterbi_einsum = log_viterbi_einsum_forward + __all__ = [ 'compile_equation', 'real_einsum_forward', @@ -29,7 +31,7 @@ 'log_einsum_forward', 'log_einsum_backward', 'log_einsum', - 'log_viterbi_einsum_forward', + 'log_viterbi_einsum', 'semiring_einsum_forward', 'combine' ] From 08b8096d3878ce3c3d9ad0c04f67d64e9d1698b2 Mon Sep 17 00:00:00 2001 From: David Chiang Date: Tue, 4 Jan 2022 12:26:48 -0500 Subject: [PATCH 2/4] Add functions like tensordot and matmul derived from einsum --- docs/index.rst | 5 ++ tests/test_derived.py | 79 ++++++++++++++++++++++++++ torch_semiring_einsum/__init__.py | 1 + torch_semiring_einsum/derived.py | 94 +++++++++++++++++++++++++++++++ 4 files changed, 179 insertions(+) create mode 100644 tests/test_derived.py create mode 100644 torch_semiring_einsum/derived.py diff --git a/docs/index.rst b/docs/index.rst index b0db499..e2cbec4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -88,6 +88,11 @@ scratch every time einsum is called. In addition to `einsum`, the module also exposes a differentiable `log_einsum` and a non-differentiable `log_viterbi_einsum`. +Derived Functions +----------------- + +For convenience, the module also implements functions `tensordot`, `matmul`, `inner`, `dot`, `mm`, `bmm`, `mv`, and `outer` in terms of einsum. All of these functions take a `block_size` argument and an `einsum` argument, which defaults to `torch_semiring_einsum.einsum` but can be set to other einsum replacements like `log_einsum`, etc. + API Documentation ----------------- diff --git a/tests/test_derived.py b/tests/test_derived.py new file mode 100644 index 0000000..d4b8dcc --- /dev/null +++ b/tests/test_derived.py @@ -0,0 +1,79 @@ +import unittest +import torch +import torch_semiring_einsum as semiring +import numpy + +class TestDerived(unittest.TestCase): + def setUp(self): + self.device = torch.device('cpu') + self.generator = torch.manual_seed(123) + + def test_tensordot(self): + A, B, C, D, E, F = 2, 3, 5, 7, 11, 13 + for i, (x_size, y_size, inner_dims) in enumerate([ + ((A, B, C, D), (C, D, E, F), 2), + ((A, B, C, D), (D, E, F), 1), + ((A, B, C, D), (E, F), 0), + ]): + with self.subTest(i): + x = torch.empty(x_size, device=self.device) + x.uniform_(-10., 10., generator=self.generator) + y = torch.empty(y_size, device=self.device) + y.uniform_(-10., 10., generator=self.generator) + + semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1) + torch_out = torch.tensordot(x, y, inner_dims) + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3) + + semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1, einsum=semiring.log_einsum) + torch_out = torch.tensordot(x.exp(), y.exp(), inner_dims).log() + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2) + + def test_matmul(self): + J, K, M, N, P = 2, 3, 5, 7, 11 + + for i, (x_size, y_size) in enumerate([ + # ((J, 1, N, M), (K, M, P)), # from torch.matmul docs + ((J, K, N, M), (K, M, P)), + ((J, K, N, M), (M,)), + ((M), (K, M, P)), + ]): + with self.subTest(i): + x = torch.empty(x_size) + x.uniform_(-10., 10., generator=self.generator) + y = torch.empty(y_size) + y.uniform_(-10., 10., generator=self.generator) + + semiring_out = semiring.matmul(x, y, block_size=1) + torch_out = torch.matmul(x, y) + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3) + + semiring_out = semiring.matmul(x, y, block_size=1, einsum=semiring.log_einsum) + torch_out = torch.matmul(x.exp(), y.exp()).log() + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2) + + def test_inner(self): + A, B, C, D, E = 2, 3, 5, 7, 11 + + for i, (x_size, y_size) in enumerate([ + ((A, B), (C, D, B)), + ((A, B), ()), + ((), (A, B)), + ]): + with self.subTest(i): + x = torch.empty(x_size) + x.uniform_(-10., 10., generator=self.generator) + y = torch.empty(y_size) + y.uniform_(-10., 10., generator=self.generator) + + semiring_out = semiring.inner(x, y, block_size=1) + torch_out = torch.inner(x, y) + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3) + + semiring_out = semiring.inner(x, y, block_size=1, einsum=semiring.log_einsum) + torch_out = torch.inner(x.exp(), y.exp()).log() + numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2) + + +if __name__ == '__main__': + unittest.main() diff --git a/torch_semiring_einsum/__init__.py b/torch_semiring_einsum/__init__.py index 4e023d1..85e505b 100644 --- a/torch_semiring_einsum/__init__.py +++ b/torch_semiring_einsum/__init__.py @@ -6,6 +6,7 @@ from .log_forward import log_einsum_forward from .log_backward import log_einsum_backward from .log_viterbi_forward import log_viterbi_einsum_forward +from .derived import * einsum = combine(real_einsum_forward, real_einsum_backward) r"""Differentiable version of ordinary (real) einsum. diff --git a/torch_semiring_einsum/derived.py b/torch_semiring_einsum/derived.py new file mode 100644 index 0000000..73bc0bc --- /dev/null +++ b/torch_semiring_einsum/derived.py @@ -0,0 +1,94 @@ +__all__ = ['tensordot', 'matmul', 'inner', 'dot', 'mm', 'bmm', 'mv', 'outer'] + +from .real_forward import real_einsum_forward +from .real_backward import real_einsum_backward +from .equation import compile_equation +from .function import combine + +default_einsum = combine(real_einsum_forward, real_einsum_backward) + +def index_range(start, stop): + e = [] + for i in range(start, stop): + e.append(chr(ord('a')+i)) + return ''.join(e) + +def tensordot(a, b, ndim, *, block_size, einsum=default_einsum): + if isinstance(ndim, (tuple, list)): + raise NotImplementedError() + e = (index_range(0, a.ndim) + + ',' + + index_range(a.ndim-ndim,a.ndim+b.ndim-ndim) + + '->' + + index_range(0, a.ndim-ndim) + + index_range(a.ndim, a.ndim+b.ndim-ndim)) + e = compile_equation(e) + return einsum(e, a, b, block_size=block_size) + +def matmul(a, b, *, block_size, einsum=default_einsum): + """Like torch.matmul""" + if a.ndim == 0 or b.ndim == 0: + raise ValueError('matmul of 0-dimensional tensors is not allowed') + + ndim = max(a.ndim, b.ndim) + + oi = index_range(3, ndim+1) + if a.ndim == 1: + ai = 'b' + else: + ai = index_range(ndim+1-(a.ndim-2), ndim+1) + 'ab' + oi += 'a' + + if b.ndim == 1: + bi = 'b' + else: + bi = index_range(ndim+1-(b.ndim-2), ndim+1) + 'bc' + oi += 'c' + + e = compile_equation(ai+','+bi+'->'+oi) + return einsum(e, a, b, block_size=block_size) + +def inner(a, b, *, block_size, einsum=default_einsum): + if a.ndim == 0: + e = ','+index_range(0, b.ndim) + '->' + index_range(0, b.ndim) + elif b.ndim == 0: + e = index_range(0, a.ndim) + ',->' + index_range(0, a.ndim) + else: + ai = index_range(1, a.ndim) + bi = index_range(a.ndim+1, a.ndim+b.ndim) + e = ai + 'a,' + bi + 'a->' + ai + bi + e = compile_equation(e) + return einsum(e, a, b, block_size=block_size) + +dot_equation = compile_equation('i,i->i') +def dot(a, b, *, block_size, einsum=default_einsum): + if a.ndim != 1 or b.ndim != 1: + raise ValueError('arguments must be 1-dimensional') + return einsum(dot_equation, a, b, block_size=block_size) + +mm_equation = compile_equation('ij,jk->ik') +def mm(a, b, *, block_size, einsum=default_einsum): + if a.ndim != 2 or b.ndim != 2: + raise ValueError('arguments must be 2-dimensional') + return einsum(mm_equation, a, b, block_size=block_size) + +mv_equation = compile_equation('ij,j->i') +def mm(a, b, *, block_size, einsum=default_einsum): + if a.ndim != 2 or b.ndim != 1: + raise ValueError('arguments must be 2-dimensional and 1-dimensional, respectively') + return einsum(mv_equation, a, b, block_size=block_size) + +bmm_equation = compile_equation('bij,bjk->bik') +def bmm(a, b, *, block_size, einsum=default_einsum): + if a.ndim != 3 or b.ndim != 3: + raise ValueError('arguments must be 3-dimensional') + return einsum(bmm_equation, a, b, block_size=block_size) + +outer_equation = compile_equation('i,j->ij') +def outer(a, b, *, block_size, einsum=default_einsum): + if a.ndim != 1 or b.ndim != 1: + raise ValueError('arguments must be 1-dimensional') + return einsum(outer_equation, a, b, block_size=block_size) +ger = outer + + From e146f03dae9dc454eee79b0d61fee282fdebdad6 Mon Sep 17 00:00:00 2001 From: David Chiang Date: Tue, 4 Jan 2022 14:47:30 -0500 Subject: [PATCH 3/4] tiny fix --- torch_semiring_einsum/derived.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_semiring_einsum/derived.py b/torch_semiring_einsum/derived.py index 73bc0bc..6162f5b 100644 --- a/torch_semiring_einsum/derived.py +++ b/torch_semiring_einsum/derived.py @@ -73,7 +73,7 @@ def mm(a, b, *, block_size, einsum=default_einsum): return einsum(mm_equation, a, b, block_size=block_size) mv_equation = compile_equation('ij,j->i') -def mm(a, b, *, block_size, einsum=default_einsum): +def mv(a, b, *, block_size, einsum=default_einsum): if a.ndim != 2 or b.ndim != 1: raise ValueError('arguments must be 2-dimensional and 1-dimensional, respectively') return einsum(mv_equation, a, b, block_size=block_size) From eb6a3456be3391572eace35bdf10dcc9e5cb885e Mon Sep 17 00:00:00 2001 From: David Chiang Date: Wed, 20 Jul 2022 09:37:10 -0400 Subject: [PATCH 4/4] clean up --- docs/index.rst | 2 -- torch_semiring_einsum/__init__.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 12f77eb..a0f024c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -86,8 +86,6 @@ Note that unlike in NumPy or PyTorch, equations are pre-compiled using :py:func:`~torch_semiring_einsum.compile_equation` rather than re-parsed from scratch every time einsum is called. -In addition to `einsum`, the module also exposes a differentiable `log_einsum` and a non-differentiable `log_viterbi_einsum`. - Derived Functions ----------------- diff --git a/torch_semiring_einsum/__init__.py b/torch_semiring_einsum/__init__.py index bccdcb9..7a6740d 100644 --- a/torch_semiring_einsum/__init__.py +++ b/torch_semiring_einsum/__init__.py @@ -20,7 +20,7 @@ 'log_einsum_forward', 'log_einsum_backward', 'log_einsum', - 'log_viterbi_einsum', + 'log_viterbi_einsum_forward', 'semiring_einsum_forward', 'combine' ]