diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py
index 6d5a28bfe..58e5af889 100644
--- a/numpyro/distributions/__init__.py
+++ b/numpyro/distributions/__init__.py
@@ -13,6 +13,7 @@
     Gamma,
     GaussianRandomWalk,
     Gumbel,
+    GumbelSoftmaxProbs,
     HalfCauchy,
     HalfNormal,
     InverseGamma,
@@ -88,6 +89,7 @@
     'GammaPoisson',
     'GaussianRandomWalk',
     'Gumbel',
+    'GumbelSoftmaxProbs',
     'HalfCauchy',
     'HalfNormal',
     'ImproperUniform',
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 6c4846150..25445910d 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -38,6 +38,9 @@
 from numpyro.distributions.transforms import AffineTransform, ExpTransform, InvCholeskyTransform, PowerTransform
 from numpyro.distributions.util import (
     cholesky_of_inverse,
+    cumsum,
+    get_dtype,
+    gumbel_softmax_probs,
     lazy_property,
     matrix_to_tril_vec,
     promote_shapes,
@@ -402,6 +405,70 @@ def variance(self):
         return jnp.broadcast_to(2 * self.scale ** 2, self.batch_shape)
 
 
+@copy_docs_from(Distribution)
+class GumbelSoftmaxProbs(Distribution):
+    """
+    Gumbel-Softmax (Concrete) distribution [1, 2]: a continuous relaxation of
+    the categorical distribution over the probability simplex. Lower
+    temperatures concentrate mass near the one-hot vertices.
+
+    **References:**
+
+    1. *Categorical Reparameterization with Gumbel-Softmax*,
+       Eric Jang, Shixiang Gu, Ben Poole
+    2. *The Concrete Distribution: A Continuous Relaxation of Discrete Random
+       Variables*, Chris J. Maddison, Andriy Mnih, Yee Whye Teh
+    """
+    arg_constraints = {'probs': constraints.simplex,
+                       'temperature': constraints.positive}
+
+    def __init__(self, probs, temperature=1., validate_args=None):
+        if jnp.ndim(probs) < 1:
+            raise ValueError("`probs` parameter must be at least one-dimensional.")
+        self.probs = probs
+        batch_shape, event_shape = probs.shape[:-1], probs.shape[-1:]
+        self.k = probs.shape[-1]
+        self.temperature = temperature
+        super(GumbelSoftmaxProbs, self).__init__(batch_shape=batch_shape,
+                                                 event_shape=event_shape,
+                                                 validate_args=validate_args)
+
+    def sample(self, key, sample_shape=(), one_hot=True):
+        return gumbel_softmax_probs(key, self.probs,
+                                    shape=sample_shape + self.batch_shape + self.event_shape,
+                                    temperature=self.temperature, hard=False, one_hot=one_hot)
+
+    @validate_sample
+    def log_prob(self, value):
+        # TODO: ensure `value` is a (relaxed) one-hot sample; convert otherwise.
+        probs = self.probs
+        k = self.k
+        temperature = self.temperature
+        # Clip values away from zero so the logs and the denominator stay finite.
+        eps = jnp.finfo(get_dtype(probs)).eps
+        ys = jnp.clip(value, a_min=eps)
+        probs = promote_shapes(probs, shape=ys.shape)[0]
+        # Log-density of the Concrete distribution (Maddison et al., 2017).
+        res = gammaln(k) + (k - 1) * jnp.log(temperature)
+        res -= k * jnp.log(jnp.sum(probs / ys ** temperature, axis=-1))
+        res += jnp.sum(jnp.log(probs), axis=-1) - (temperature + 1) * jnp.sum(jnp.log(ys), axis=-1)
+        return res
+
+    @property
+    def mean(self):
+        # TODO: no closed form available; NaN is a placeholder.
+        return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs))
+
+    @property
+    def variance(self):
+        # TODO: no closed form available; NaN is a placeholder.
+        return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs))
+
+    @property
+    def support(self):
+        return constraints.simplex
+
+
 @copy_docs_from(Distribution)
 class LKJ(TransformedDistribution):
     r"""
diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index 98018628b..3eb130220 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -46,7 +46,8 @@
     multinomial,
     promote_shapes,
     sum_rightmost,
-    validate_sample
+    _to_probs_multinom,
+    validate_sample,
 )
 
 from numpyro.util import copy_docs_from, not_jax_tracer
@@ -60,10 +61,6 @@ def _to_logits_bernoulli(probs):
     return jnp.log(ps_clamped) - jnp.log1p(-ps_clamped)
 
 
-def _to_probs_multinom(logits):
-    return softmax(logits, axis=-1)
-
-
 def _to_logits_multinom(probs):
     minval = jnp.finfo(get_dtype(probs)).min
     return jnp.clip(jnp.log(probs), a_min=minval)
diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py
index 10b1c6fa2..555b079cd 100644
--- a/numpyro/distributions/util.py
+++ b/numpyro/distributions/util.py
@@ -7,6 +7,7 @@
 from jax import jit, lax, random, vmap
 from jax.dtypes import canonicalize_dtype
 from jax.lib import xla_bridge
+from jax.nn import softmax
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
 from jax.util import partial
@@ -215,6 +216,10 @@
     return _multinomial(key, p, n, n_max, shape)
 
 
+def _to_probs_multinom(logits):
+    return softmax(logits, axis=-1)
+
+
 def cholesky_of_inverse(matrix):
     # This formulation only takes the inverse of a triangular matrix
     # which is more numerically stable.
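Note that the hunks above import `gumbel_softmax_probs` from `numpyro.distributions.util`, but its definition is not part of this excerpt. For reference, `log_prob` implements the Concrete density of Maddison et al., p(y) = (k-1)! * tau^(k-1) * (prod_i p_i * y_i^(-tau-1)) / (sum_i p_i * y_i^(-tau))^k, and a sampler consistent with the call sites might look like the sketch below. The signature and the `hard`/`one_hot` semantics are assumptions inferred from how the method is called, not the PR's actual helper.

# Sketch only: the real `gumbel_softmax_probs` is defined elsewhere in this PR.
from jax import lax, nn, random
import jax.numpy as jnp


def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False, one_hot=True):
    # Reparameterized Gumbel-Softmax sample: softmax((log(probs) + G) / temperature)
    # with G ~ Gumbel(0, 1) i.i.d. (Jang et al., 2017; Maddison et al., 2017).
    gumbels = random.gumbel(key, shape)
    ys = nn.softmax((jnp.log(probs) + gumbels) / temperature, axis=-1)
    if not one_hot:
        # Return integer class indices, as the KS tests below expect.
        return jnp.argmax(ys, axis=-1)
    if hard:
        # Straight-through variant: exactly one-hot on the forward pass,
        # gradients flow through the relaxed sample.
        hard_ys = nn.one_hot(jnp.argmax(ys, axis=-1), probs.shape[-1])
        ys = lax.stop_gradient(hard_ys - ys) + ys
    return ys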
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 53dad79a1..8931308f8 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -113,7 +113,9 @@ def sample(self, key, sample_shape=()):
     T(dist.GaussianRandomWalk, jnp.array([0.1, 0.3, 0.25]), 10),
     T(dist.Gumbel, 0., 1.),
     T(dist.Gumbel, 0.5, 2.),
     T(dist.Gumbel, jnp.array([0., 0.5]), jnp.array([1., 2.])),
+    T(dist.GumbelSoftmaxProbs, jnp.array([0.1, 0.2, 0.3, 0.4]), 0.0001),
+    T(dist.GumbelSoftmaxProbs, jnp.array([0.1, 0.2, 0.3, 0.4]), 100.),
     T(dist.HalfCauchy, 1.),
     T(dist.HalfCauchy, jnp.array([1., 2.])),
     T(dist.HalfNormal, 1.),
@@ -697,7 +699,6 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
         valid_params[i] = gen_values_within_bounds(constraint, jnp.shape(params[i]), key_gen)
 
     assert jax_dist(*oob_params)
-
     # Invalid parameter values throw ValueError
     if not dependent_constraint and jax_dist is not _ImproperWrapper:
         with pytest.raises(ValueError):
@@ -1011,6 +1012,37 @@
     assert ks_result.pvalue > 0.05
 
 
+@pytest.mark.parametrize('temperature', [0.0001, 0.00001])
+@pytest.mark.parametrize('N_sample', [10_000, 100_000])
+def test_relaxations_low(temperature, N_sample, key=jax.random.PRNGKey(52)):
+    """ Test that samples at low temperatures are close to samples from the
+    corresponding Categorical distribution (and consequently that the
+    GumbelSoftmaxProbs distribution samples correctly).
+    """
+    probs = jnp.array([0.1, 0.1, 0.8])
+    gumbel_softmax = dist.GumbelSoftmaxProbs(probs, temperature=temperature)
+    gs_samples = gumbel_softmax.sample(key, (N_sample,), one_hot=False)
+    categorical_samples = dist.CategoricalProbs(probs).sample(key, (N_sample,))
+    ks_result = osp.ks_2samp(gs_samples, categorical_samples)
+    assert ks_result.pvalue > 0.05
+
+
+@pytest.mark.parametrize('temperature', [100, 1000])
+@pytest.mark.parametrize('N_sample', [10_000, 100_000])
+def test_relaxations_high(temperature, N_sample, key=jax.random.PRNGKey(52)):
+    """ Test that samples at high temperatures are close to samples from a
+    Categorical distribution with equal probabilities for the classes.
+    """
+    probs = jnp.array([0.1, 0.1, 0.8])
+    gumbel_softmax = dist.GumbelSoftmaxProbs(probs, temperature=temperature)
+    gs_samples = gumbel_softmax.sample(key, (N_sample,), one_hot=False)
+    uniform_samples = dist.CategoricalProbs(jnp.array([1. / 3, 1. / 3, 1. / 3])).sample(key, (N_sample,))
+    ks_result = osp.ks_2samp(gs_samples, uniform_samples)
+    error_message = ("failed KS test between Gumbel-Softmax and categorical with "
+                     "equal probabilities for temperature {}".format(temperature))
+    assert ks_result.pvalue > 0.05, error_message
+
+
 @pytest.mark.parametrize('jax_dist, params, support', [
     (dist.BernoulliLogits, (5.,), jnp.arange(2)),
     (dist.BernoulliProbs, (0.5,), jnp.arange(2)),
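For context, a small usage sketch (not part of the diff), assuming the PR is applied:

from jax import random
import jax.numpy as jnp
import numpyro.distributions as dist

key = random.PRNGKey(0)
gs = dist.GumbelSoftmaxProbs(jnp.array([0.1, 0.2, 0.7]), temperature=0.5)
relaxed = gs.sample(key, (1000,))                # relaxed one-hot samples on the simplex, shape (1000, 3)
labels = gs.sample(key, (1000,), one_hot=False)  # integer class labels in {0, 1, 2}
log_density = gs.log_prob(relaxed)               # Concrete log-density, shape (1000,)

Lower temperatures push `relaxed` toward the one-hot vertices, which is why `test_relaxations_low` compares against `CategoricalProbs(probs)`; higher temperatures push it toward the uniform barycenter, which is why `test_relaxations_high` compares against a uniform categorical.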