Skip to content

WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Gamma,
GaussianRandomWalk,
Gumbel,
GumbelSoftmaxProbs,
HalfCauchy,
HalfNormal,
InverseGamma,
Expand Down Expand Up @@ -88,6 +89,7 @@
'GammaPoisson',
'GaussianRandomWalk',
'Gumbel',
'GumbelSoftmaxProbs',
'HalfCauchy',
'HalfNormal',
'ImproperUniform',
Expand Down
59 changes: 59 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from numpyro.distributions.transforms import AffineTransform, ExpTransform, InvCholeskyTransform, PowerTransform
from numpyro.distributions.util import (
cholesky_of_inverse,
cumsum,
get_dtype,
gumbel_softmax_probs,
lazy_property,
matrix_to_tril_vec,
promote_shapes,
Expand Down Expand Up @@ -402,6 +405,62 @@ def variance(self):
return jnp.broadcast_to(2 * self.scale ** 2, self.batch_shape)


@copy_docs_from(Distribution)
class GumbelSoftmaxProbs(Distribution):
    """
    Gumbel-Softmax (a.k.a. Concrete / relaxed one-hot categorical) distribution
    parameterized by event probabilities.

    For ``k`` categories with probabilities ``probs`` and temperature ``t``,
    the log density evaluated by :meth:`log_prob` at a point ``y`` is::

        gammaln(k) + (k - 1) * log(t)
            - k * log(sum_i probs_i / y_i ** t)
            + sum_i log(probs_i) - (t + 1) * sum_i log(y_i)

    :param probs: event probabilities; the last axis indexes the categories
        and is used as the event dimension.
    :param temperature: relaxation temperature, must be strictly positive.
    :param validate_args: forwarded unchanged to ``Distribution.__init__``.
    """

    # The temperature appears inside `log(temperature)` and as an exponent of
    # the sample, so it has to be strictly positive rather than merely real.
    arg_constraints = {'probs': constraints.simplex,
                       'temperature': constraints.positive}

    def __init__(self, probs, temperature=1., validate_args=None):
        if jnp.ndim(probs) < 1:
            raise ValueError("`probs` parameter must be at least one-dimensional.")
        self.probs = probs
        # `jnp.shape` (rather than `probs.shape`) so that the shape is read the
        # same generic way `ndim` is checked above.
        probs_shape = jnp.shape(probs)
        batch_shape, event_shape = probs_shape[:-1], probs_shape[-1:]
        self.k = probs_shape[-1]
        # Kept as a public attribute for backward compatibility; it is not
        # used by any method of this class.
        self.standard_gumbel = Gumbel()
        self.temperature = temperature
        super(GumbelSoftmaxProbs, self).__init__(batch_shape=batch_shape,
                                                 event_shape=event_shape,
                                                 validate_args=validate_args)

    def sample(self, key, sample_shape=(), one_hot=True):
        """
        Draw samples by delegating to ``gumbel_softmax_probs`` with ``hard=False``.

        :param key: PRNG key passed through to ``gumbel_softmax_probs``.
        :param sample_shape: leading shape prepended to ``batch_shape + event_shape``.
        :param one_hot: passed through to ``gumbel_softmax_probs``
            (NOTE(review): presumably selects between a relaxed one-hot vector
            and a category index -- confirm against the helper).
        """
        full_shape = sample_shape + self.batch_shape + self.event_shape
        return gumbel_softmax_probs(key, self.probs, shape=full_shape,
                                    temperature=self.temperature, hard=False, one_hot=one_hot)

    @validate_sample
    def log_prob(self, value):
        """
        Evaluate the log density at ``value`` (last axis = categories).

        Sample validation is left to the ``validate_sample`` decorator; the
        body does not validate a second time.
        """
        # Local import keeps this block self-contained regardless of what the
        # module header already imports from `jax.scipy.special`.
        from jax.scipy.special import logsumexp

        k = self.k
        temperature = self.temperature

        # Floor the sample at machine epsilon so that its logarithm is finite.
        eps = jnp.finfo(get_dtype(self.probs)).eps
        log_ys = jnp.log(jnp.maximum(value, eps))
        log_probs = jnp.log(promote_shapes(self.probs, shape=jnp.shape(log_ys))[0])

        # log(sum_i probs_i / y_i ** t) evaluated in log space: forming
        # `y ** t` directly underflows to 0 for large temperatures (e.g.
        # 0.25 ** 100 in float32), which turns the ratio into inf.
        log_denominator = logsumexp(log_probs - temperature * log_ys, axis=-1)

        log_normalizer = gammaln(k) + (k - 1) * jnp.log(temperature)
        return (log_normalizer
                - k * log_denominator
                + jnp.sum(log_probs, axis=-1)
                - (temperature + 1) * jnp.sum(log_ys, axis=-1))

    @property
    def mean(self):
        # FIXME: no closed form is implemented; NaN placeholder with the shape
        # of a single draw (batch_shape + event_shape).
        return jnp.full(self.batch_shape + self.event_shape, jnp.nan,
                        dtype=get_dtype(self.probs))

    @property
    def variance(self):
        # FIXME: no closed form is implemented; NaN placeholder with the shape
        # of a single draw (batch_shape + event_shape).
        return jnp.full(self.batch_shape + self.event_shape, jnp.nan,
                        dtype=get_dtype(self.probs))

    @property
    def support(self):
        # The density in `log_prob` is over relaxed one-hot vectors, i.e.
        # points of the probability simplex, not over integer labels.
        return constraints.simplex


@copy_docs_from(Distribution)
class LKJ(TransformedDistribution):
r"""
Expand Down
7 changes: 2 additions & 5 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
multinomial,
promote_shapes,
sum_rightmost,
validate_sample
_to_probs_multinom,
validate_sample,
)
from numpyro.util import copy_docs_from, not_jax_tracer

Expand All @@ -60,10 +61,6 @@ def _to_logits_bernoulli(probs):
return jnp.log(ps_clamped) - jnp.log1p(-ps_clamped)


def _to_probs_multinom(logits):
return softmax(logits, axis=-1)


def _to_logits_multinom(probs):
    """
    Take the elementwise log of ``probs`` and floor the result at the most
    negative finite value of the dtype of ``probs``.
    """
    floor = jnp.finfo(get_dtype(probs)).min
    log_probs = jnp.log(probs)
    return jnp.clip(log_probs, a_min=floor)
Expand Down
5 changes: 5 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax import jit, lax, random, vmap
from jax.dtypes import canonicalize_dtype
from jax.lib import xla_bridge
from jax.nn import softmax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.util import partial
Expand Down Expand Up @@ -215,6 +216,10 @@ def multinomial(key, p, n, shape=()):
return _multinomial(key, p, n, n_max, shape)


def _to_probs_multinom(logits):
return softmax(logits, axis=-1)


def cholesky_of_inverse(matrix):
# This formulation only takes the inverse of a triangular matrix
# which is more numerically stable.
Expand Down
37 changes: 35 additions & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def sample(self, key, sample_shape=()):
T(dist.GaussianRandomWalk, jnp.array([0.1, 0.3, 0.25]), 10),
T(dist.Gumbel, 0., 1.),
T(dist.Gumbel, 0.5, 2.),
T(dist.Gumbel, jnp.array([0., 0.5]), jnp.array([1., 2.])),
T(dist.Gumbel, np.array([0., 0.5]), np.array([1., 2.])),
T(dist.GumbelSoftmaxProbs, np.array([0.1, 0.2, 0.3, 0.4]), 0.0001),
T(dist.GumbelSoftmaxProbs, np.array([0.1, 0.2, 0.3, 0.4]), 100),
T(dist.HalfCauchy, 1.),
T(dist.HalfCauchy, jnp.array([1., 2.])),
T(dist.HalfNormal, 1.),
Expand Down Expand Up @@ -697,7 +699,6 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
valid_params[i] = gen_values_within_bounds(constraint, jnp.shape(params[i]), key_gen)

assert jax_dist(*oob_params)

# Invalid parameter values throw ValueError
if not dependent_constraint and jax_dist is not _ImproperWrapper:
with pytest.raises(ValueError):
Expand Down Expand Up @@ -1011,6 +1012,38 @@ def test_generated_sample_distribution(jax_dist, sp_dist, params,
assert ks_result.pvalue > 0.05


@pytest.mark.parametrize('temperature', [0.0001, 0.00001])
@pytest.mark.parametrize('N_sample', [10_000, 100_000])
def test_relaxations_low(temperature, N_sample, key=jax.random.PRNGKey(52)):
    """ Test that samples from low temperatures are close to samples from a
    Categorical distribution (and consequently that the GumbelSoftmax
    distribution samples correctly).
    """
    probs = np.array([0.1, 0.1, 0.8])

    # The two-sample KS test assumes independent samples, so each sampler
    # gets its own subkey instead of sharing `key`.
    key_gs, key_categorical = jax.random.split(key)

    relaxed = dist.GumbelSoftmaxProbs(probs, temperature=temperature)
    gs_samples = relaxed.sample(key_gs, (N_sample,), one_hot=False)
    categorical_samples = dist.CategoricalProbs(probs).sample(key_categorical, (N_sample,))

    ks_result = osp.ks_2samp(gs_samples, categorical_samples)
    assert ks_result.pvalue > 0.05


@pytest.mark.parametrize('temperature', [100, 1000])
@pytest.mark.parametrize('N_sample', [10_000, 100_000])
def test_relaxations_high(temperature, N_sample, key=jax.random.PRNGKey(52)):
    """ Test that samples from high temperatures are close to samples
    from a Categorical distribution with equal probabilities for the classes.
    """
    # NOTE(review): if `one_hot=False` returns the argmax of the relaxed
    # sample, that index does not depend on the temperature (argmax is
    # invariant under division by a positive constant), so it would follow
    # Categorical(probs) rather than a uniform categorical -- confirm against
    # `gumbel_softmax_probs` that this expectation is the intended one.
    probs = np.array([0.1, 0.1, 0.8])

    # The two-sample KS test assumes independent samples, so each sampler
    # gets its own subkey instead of sharing `key`.
    key_gs, key_uniform = jax.random.split(key)

    relaxed = dist.GumbelSoftmaxProbs(probs, temperature=temperature)
    gs_samples = relaxed.sample(key_gs, (N_sample,), one_hot=False)
    uniform_samples = dist.Categorical(np.array([1./3, 1./3, 1./3])).sample(key_uniform, (N_sample,))

    ks_result = osp.ks_2samp(gs_samples, uniform_samples)
    error_message = ("failed KS between Gumbel Softmax and categorical with "
                     "equal probabilities for temperature {}").format(temperature)
    assert ks_result.pvalue > 0.05, error_message
@pytest.mark.parametrize('jax_dist, params, support', [
(dist.BernoulliLogits, (5.,), jnp.arange(2)),
(dist.BernoulliProbs, (0.5,), jnp.arange(2)),
Expand Down