Skip to content

Commit 7e48b74

Browse files
authored
Memoize Transform.inv (#885)
1 parent ee5bf10 commit 7e48b74

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

numpyro/distributions/transforms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import math
55
import warnings
6+
import weakref
67

78
import numpy as np
89

@@ -52,6 +53,7 @@ def _clipped_expit(x):
5253
class Transform(object):
5354
domain = constraints.real
5455
codomain = constraints.real
56+
_inv = None
5557

5658
@property
5759
def event_dim(self):
@@ -62,7 +64,13 @@ def event_dim(self):
6264

6365
@property
6466
def inv(self):
65-
return _InverseTransform(self)
67+
inv = None
68+
if self._inv is not None:
69+
inv = self._inv()
70+
if inv is None:
71+
inv = _InverseTransform(self)
72+
self._inv = weakref.ref(inv)
73+
return inv
6674

6775
def __call__(self, x):
6876
return NotImplementedError

test/test_distributions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,7 @@ def test_bijective_transforms(transform, event_shape, batch_shape):
10381038
z = transform.inv(y)
10391039
assert_allclose(x, z, atol=1e-6, rtol=1e-6)
10401040
assert transform.inv.inv is transform
1041+
assert transform.inv is transform.inv
10411042
assert transform.domain is transform.inv.codomain
10421043
assert transform.codomain is transform.inv.domain
10431044

0 commit comments

Comments
 (0)