From d95aaf819fdc213b97c63d294f6daa45fe364f16 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Sun, 29 Mar 2020 21:38:04 +0200 Subject: [PATCH 01/12] WIP: Add initial experiments based on pytorch's gumbel_softmax: need more reading on relaxed_categorical and transformations of distributions instead --- numpyro/distributions/__init__.py | 2 ++ numpyro/distributions/discrete.py | 53 ++++++++++++++++++++++++++++--- numpyro/distributions/util.py | 23 ++++++++++++++ 3 files changed, 73 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 297ef3587..d7d5ee513 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -38,6 +38,7 @@ CategoricalLogits, CategoricalProbs, Delta, + GumbelSoftmaxProbs, Multinomial, MultinomialLogits, MultinomialProbs, @@ -75,6 +76,7 @@ 'GammaPoisson', 'GaussianRandomWalk', 'Gumbel', + 'GumbelSoftmaxProbs', 'HalfCauchy', 'HalfNormal', 'Independent', diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index f6c29e22f..29366ae7b 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -27,7 +27,6 @@ from jax import device_put, lax from jax.dtypes import canonicalize_dtype -from jax.nn import softmax import jax.numpy as np import jax.random as random from jax.scipy.special import expit, gammaln, logsumexp, xlog1py, xlogy @@ -41,11 +40,14 @@ categorical_logits, clamp_probs, get_dtype, + gumbel_softmax_logits, + gumbel_softmax_probs, lazy_property, multinomial, poisson, promote_shapes, sum_rightmost, + _to_probs_multinom, validate_sample, ) from numpyro.util import copy_docs_from @@ -60,10 +62,6 @@ def _to_logits_bernoulli(probs): return np.log(ps_clamped) - np.log1p(-ps_clamped) -def _to_probs_multinom(logits): - return softmax(logits, axis=-1) - - def _to_logits_multinom(probs): minval = np.finfo(get_dtype(probs)).min return np.clip(np.log(probs), a_min=minval) @@ -336,6 +334,51 @@ def variance(self): return np.zeros(self.batch_shape + self.event_shape) +@copy_docs_from(Distribution) +class GumbelSoftmaxProbs(Distribution): + + arg_constraints = {'probs': constraints.simplex} + + def __init__(self, probs, temperature=1., validate_args=None): + if np.ndim(probs) < 1: + raise ValueError("`probs` parameter must be at least one-dimensional.") + self.probs = probs + self.standard_gumbel = Gumbel() + self.temperature = temperature + super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], + validate_args=validate_args) + + def sample(self, key, sample_shape=()): + gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape) + return _to_probs_multinom((np.log(self.probs) + gs)/self.temperature) + # gumbel_softmax_probs(key, probs, + # shape=sample_shape + self.batch_shape + self.event_shape, + # temperature=self.temperature, hard=True) + """ + @validate_sample + def log_prob(self, value): + # FIXME: implement + batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape) + value = np.expand_dims(value, axis=-1) + value = np.broadcast_to(value, batch_shape + (1,)) + logits = _to_logits_multinom(self.probs) + log_pmf = np.broadcast_to(logits, batch_shape + np.shape(logits)[-1:]) + return np.take_along_axis(log_pmf, value, axis=-1)[..., 0] + """ + + @property + def mean(self): + return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) + + @property + def variance(self): + return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) + + @property + def support(self): + return constraints.integer_interval(0, np.shape(self.probs)[-1]) + + class OrderedLogistic(CategoricalProbs): """ A categorical distribution with ordered outcomes. diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index d5e5d2669..a539d66b5 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -7,6 +7,7 @@ from jax import custom_transforms, defjvp, jit, lax, random, vmap from jax.dtypes import canonicalize_dtype from jax.lib import xla_bridge +from jax.nn import softmax import jax.numpy as np from jax.scipy.linalg import solve_triangular from jax.scipy.special import gammaln @@ -185,6 +186,24 @@ def categorical_logits(key, logits, shape=()): + logits, axis=-1) +def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False): + shape = shape or logits.shape[:-1] + y_soft = softmax((random.gumbel(key, shape + logits.shape[-1:], logits.dtype) + + logits)/temperature, axis=-1) + + + if hard: + y_hard = np.where(y_soft == np.amax(y_soft, axis=-1, keepdims=True), 1., 0.) + ret = y_hard - jax.lax.stop_gradient(y_soft) + y_soft + else: + ret = y_soft + return ret + + +def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False): + return gumbel_softmax_logits(key, np.log(probs), shape, temperature=temperature, hard=hard) + + # Ref https://github.com/numpy/numpy/blob/8a0858f3903e488495a56b4a6d19bbefabc97dca/ # numpy/random/src/distributions/distributions.c#L574 def _poisson_large(val): @@ -295,6 +314,10 @@ def multinomial(key, p, n, shape=()): return _multinomial(key, p, n, n_max, shape) +def _to_probs_multinom(logits): + return softmax(logits, axis=-1) + + def cholesky_of_inverse(matrix): # This formulation only takes the inverse of a triangular matrix # which is more numerically stable. From 3a847af9be8daec841f78f89f4df6cc3666447c7 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Sun, 5 Apr 2020 17:04:52 +0200 Subject: [PATCH 02/12] Move GumbelSoftmaxProbs to continuous, because it needs access to Gumbel --- numpyro/distributions/__init__.py | 2 +- numpyro/distributions/continuous.py | 48 +++++++++++++++++++++++++++++ numpyro/distributions/discrete.py | 47 ---------------------------- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index d7d5ee513..45e00e883 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -13,6 +13,7 @@ Gamma, GaussianRandomWalk, Gumbel, + GumbelSoftmaxProbs, HalfCauchy, HalfNormal, InverseGamma, @@ -38,7 +39,6 @@ CategoricalLogits, CategoricalProbs, Delta, - GumbelSoftmaxProbs, Multinomial, MultinomialLogits, MultinomialProbs, diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6c5dabf7c..6ebb85d4b 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -38,10 +38,13 @@ from numpyro.distributions.util import ( cholesky_of_inverse, cumsum, + gumbel_softmax_logits, + gumbel_softmax_probs, lazy_property, matrix_to_tril_vec, promote_shapes, signed_stick_breaking_tril, + _to_probs_multinom, validate_sample, vec_to_tril_matrix ) @@ -362,6 +365,51 @@ def variance(self): self.batch_shape) +@copy_docs_from(Distribution) +class GumbelSoftmaxProbs(Distribution): + + arg_constraints = {'probs': constraints.simplex} + + def __init__(self, probs, temperature=1., validate_args=None): + if np.ndim(probs) < 1: + raise ValueError("`probs` parameter must be at least one-dimensional.") + self.probs = probs + self.standard_gumbel = Gumbel() + self.temperature = temperature + super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], + validate_args=validate_args) + + def sample(self, key, sample_shape=()): + gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape) + return _to_probs_multinom((np.log(self.probs) + gs)/self.temperature) + # gumbel_softmax_probs(key, probs, + # shape=sample_shape + self.batch_shape + self.event_shape, + # temperature=self.temperature, hard=True) + """ + @validate_sample + def log_prob(self, value): + # FIXME: implement + batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape) + value = np.expand_dims(value, axis=-1) + value = np.broadcast_to(value, batch_shape + (1,)) + logits = _to_logits_multinom(self.probs) + log_pmf = np.broadcast_to(logits, batch_shape + np.shape(logits)[-1:]) + return np.take_along_axis(log_pmf, value, axis=-1)[..., 0] + """ + + @property + def mean(self): + return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) + + @property + def variance(self): + return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) + + @property + def support(self): + return constraints.integer_interval(0, np.shape(self.probs)[-1]) + + @copy_docs_from(Distribution) class LKJ(TransformedDistribution): r""" diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 29366ae7b..196b6edc2 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -40,8 +40,6 @@ categorical_logits, clamp_probs, get_dtype, - gumbel_softmax_logits, - gumbel_softmax_probs, lazy_property, multinomial, poisson, @@ -334,51 +332,6 @@ def variance(self): return np.zeros(self.batch_shape + self.event_shape) -@copy_docs_from(Distribution) -class GumbelSoftmaxProbs(Distribution): - - arg_constraints = {'probs': constraints.simplex} - - def __init__(self, probs, temperature=1., validate_args=None): - if np.ndim(probs) < 1: - raise ValueError("`probs` parameter must be at least one-dimensional.") - self.probs = probs - self.standard_gumbel = Gumbel() - self.temperature = temperature - super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], - validate_args=validate_args) - - def sample(self, key, sample_shape=()): - gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape) - return _to_probs_multinom((np.log(self.probs) + gs)/self.temperature) - # gumbel_softmax_probs(key, probs, - # shape=sample_shape + self.batch_shape + self.event_shape, - # temperature=self.temperature, hard=True) - """ - @validate_sample - def log_prob(self, value): - # FIXME: implement - batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape) - value = np.expand_dims(value, axis=-1) - value = np.broadcast_to(value, batch_shape + (1,)) - logits = _to_logits_multinom(self.probs) - log_pmf = np.broadcast_to(logits, batch_shape + np.shape(logits)[-1:]) - return np.take_along_axis(log_pmf, value, axis=-1)[..., 0] - """ - - @property - def mean(self): - return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) - - @property - def variance(self): - return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) - - @property - def support(self): - return constraints.integer_interval(0, np.shape(self.probs)[-1]) - - class OrderedLogistic(CategoricalProbs): """ A categorical distribution with ordered outcomes. From e648ceba399b1c7f14206d2303d4d943fad1c87d Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 12:06:23 +0200 Subject: [PATCH 03/12] rewrite proposal by using gumbel_soft_max_logits in utils --- numpyro/distributions/continuous.py | 8 +++++--- numpyro/distributions/util.py | 13 ++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6ebb85d4b..6552c93c5 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -380,11 +380,13 @@ def __init__(self, probs, temperature=1., validate_args=None): validate_args=validate_args) def sample(self, key, sample_shape=()): - gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape) - return _to_probs_multinom((np.log(self.probs) + gs)/self.temperature) + #gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape + self.probs.shape) + #s = (np.log(self.probs) + gs)/self.temperature + return gumbel_softmax_probs(key, self.probs, shape=sample_shape, + temperature=self.temperature, hard=False) + # gumbel_softmax_probs(key, probs, # shape=sample_shape + self.batch_shape + self.event_shape, - # temperature=self.temperature, hard=True) """ @validate_sample def log_prob(self, value): diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index a539d66b5..c40e848c8 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -186,18 +186,21 @@ def categorical_logits(key, logits, shape=()): + logits, axis=-1) -def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False): +def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one_hot=False): shape = shape or logits.shape[:-1] y_soft = softmax((random.gumbel(key, shape + logits.shape[-1:], logits.dtype) - + logits)/temperature, axis=-1) - + + logits) / temperature, axis=-1) if hard: y_hard = np.where(y_soft == np.amax(y_soft, axis=-1, keepdims=True), 1., 0.) - ret = y_hard - jax.lax.stop_gradient(y_soft) + y_soft + ret = y_hard - lax.stop_gradient(y_soft) + y_soft else: ret = y_soft - return ret + + if one_hot: + return ret + else: + return _categorical(key, ret, shape) def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False): From 810adba82040a1e8f1a29f07445b6d7c3542ca33 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 12:06:46 +0200 Subject: [PATCH 04/12] Add first test of correct sampling from GumbelSoftmax for high and low temperatures. --- test/test_distributions.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index 148de1f70..06e58fe02 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -107,6 +107,8 @@ def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag): T(dist.Gumbel, 0., 1.), T(dist.Gumbel, 0.5, 2.), T(dist.Gumbel, np.array([0., 0.5]), np.array([1., 2.])), + T(dist.GumbelSoftmaxProbs([0.1, 0.2, 0.3, 0.4], temperature=0.0001)), + T(dist.GumbelSoftmaxProbs([0.1, 0.2, 0.3, 0.4], temperature=100)), T(dist.HalfCauchy, 1.), T(dist.HalfCauchy, np.array([1., 2.])), T(dist.HalfNormal, 1.), @@ -962,3 +964,35 @@ def test_generated_sample_distribution(jax_dist, sp_dist, params, our_samples = jax_dist.sample(key, (N_sample,)) ks_result = osp.kstest(our_samples, sp_dist(*params).cdf) assert ks_result.pvalue > 0.05 + + +@pytest.mark.parametrize('temperature', [0.0001, 0.00001]) +@pytest.mark.parametrize('N_sample', [10_000, 100_000]) +def test_relaxations_low(temperature, N_sample, key=jax.random.PRNGKey(52)): + """ Test that samples from low temperatures are close to samples from a + Categorical distribution (and consequently that the GumbelSoftmax + distribution samples correctly). + """ + + probs = np.array([0.1, 0.1, 0.8]) + GS1 = dist.GumbelSoftmaxProbs(probs, temperature=temperature) + gs_samples = GS1.sample(key, (N_sample,)) + categorical_samples = dist.CategoricalProbs(probs).sample(key, (N_sample, )) + ks_result = osp.ks_2samp(gs_samples, categorical_samples) + assert ks_result.pvalue > 0.05 + + +@pytest.mark.parametrize('temperature', [100, 1000]) +@pytest.mark.parametrize('N_sample', [10_000, 100_000]) +def test_relaxations_high(temperature, N_sample, key=jax.random.PRNGKey(52)): + """ Test that samples from high temperatures are close to samples + from a Categorical distribution with equal probabilities for the classes. + """ + probs = np.array([0.1, 0.1, 0.8]) + + GS1 = dist.GumbelSoftmaxProbs(probs, temperature=temperature) + gs_samples = GS1.sample(key, (N_sample,)) + uniform_samples = dist.Categorical(np.array([1./3, 1./3, 1./3])).sample(key, (N_sample, )) + + ks_result = osp.ks_2samp(gs_samples, uniform_samples) + assert ks_result.pvalue > 0.05, "failed KS betwen Gumbel Softmax and categorical with equal probabilities for temperature {}".format(temperature) From 249112a97c46e1bc7c6c469aef01de49c21d8252 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 15:27:30 +0200 Subject: [PATCH 05/12] Fix constructor of GumbelSoftmaxProbs; many tests still failing --- test/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 06e58fe02..b0dca05b5 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -107,8 +107,8 @@ def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag): T(dist.Gumbel, 0., 1.), T(dist.Gumbel, 0.5, 2.), T(dist.Gumbel, np.array([0., 0.5]), np.array([1., 2.])), - T(dist.GumbelSoftmaxProbs([0.1, 0.2, 0.3, 0.4], temperature=0.0001)), - T(dist.GumbelSoftmaxProbs([0.1, 0.2, 0.3, 0.4], temperature=100)), + T(dist.GumbelSoftmaxProbs(np.array([0.1, 0.2, 0.3, 0.4]), temperature=0.0001)), + T(dist.GumbelSoftmaxProbs(np.array([0.1, 0.2, 0.3, 0.4]), temperature=100)), T(dist.HalfCauchy, 1.), T(dist.HalfCauchy, np.array([1., 2.])), T(dist.HalfNormal, 1.), From e20fc99bb344afdc972ea8c29cf368b4c3b828a8 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 16:05:22 +0200 Subject: [PATCH 06/12] add working but not tested version of log_prob for GumbelSoftmaxProbs --- numpyro/distributions/continuous.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6552c93c5..15d91e3f4 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -374,6 +374,7 @@ def __init__(self, probs, temperature=1., validate_args=None): if np.ndim(probs) < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs + self.k = probs.shape[-1] self.standard_gumbel = Gumbel() self.temperature = temperature super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], @@ -387,6 +388,18 @@ def sample(self, key, sample_shape=()): # gumbel_softmax_probs(key, probs, # shape=sample_shape + self.batch_shape + self.event_shape, + + @validate_sample + def log_prob(self, ys): + probs = self.probs + k = self.k + temperature = self.temperature + + #ys = np.array([0., 0., 1.]) # TODO: ENSURE ONE HOT + eps = np.finfo(ys.dtype).eps + ys = np.clip(ys, a_min=eps, a_max=1-eps) + return gammaln(k) + (k-1)*np.log(temperature) -k*np.log( (probs / (ys**temperature)).sum(axis=-1)) + np.sum( np.log(probs / (ys**temperature)), axis=-1) + """ @validate_sample def log_prob(self, value): From c4f3ec87f553ec9dc298d91493448266dbe8ae1c Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 17:38:37 +0200 Subject: [PATCH 07/12] Fix gradient for most cases --- numpyro/distributions/continuous.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 15d91e3f4..16b3ea4b5 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -38,6 +38,7 @@ from numpyro.distributions.util import ( cholesky_of_inverse, cumsum, + get_dtype, gumbel_softmax_logits, gumbel_softmax_probs, lazy_property, @@ -380,10 +381,10 @@ def __init__(self, probs, temperature=1., validate_args=None): super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], validate_args=validate_args) - def sample(self, key, sample_shape=()): + def sample(self, rng_key, sample_shape=()): #gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape + self.probs.shape) #s = (np.log(self.probs) + gs)/self.temperature - return gumbel_softmax_probs(key, self.probs, shape=sample_shape, + return gumbel_softmax_probs(rng_key, self.probs, shape=sample_shape, temperature=self.temperature, hard=False) # gumbel_softmax_probs(key, probs, @@ -396,10 +397,14 @@ def log_prob(self, ys): temperature = self.temperature #ys = np.array([0., 0., 1.]) # TODO: ENSURE ONE HOT - eps = np.finfo(ys.dtype).eps - ys = np.clip(ys, a_min=eps, a_max=1-eps) - return gammaln(k) + (k-1)*np.log(temperature) -k*np.log( (probs / (ys**temperature)).sum(axis=-1)) + np.sum( np.log(probs / (ys**temperature)), axis=-1) - + eps = np.finfo(probs.dtype).eps + probs = np.clip(probs, a_min=eps, a_max=1-eps) + # We clip the ys to be positive to have a positive denominator + ys = np.clip(ys, a_min=eps) + + res = gammaln(k) + (k-1)*np.log(temperature) -k*np.log( (probs / (ys**temperature)).sum(axis=-1)) + np.sum( np.log(probs / (ys**temperature)), axis=-1) + #res += np.sum(np.log(probs), axis=-1) - (temperature + 1) * np.sum(np.log(ys), axis=-1) + return res """ @validate_sample def log_prob(self, value): From 8790e5aef16d16345aa9794e92052fffc2ec80ef Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 20:25:53 +0200 Subject: [PATCH 08/12] Make most tests, except d_log_prob broadcasting, and test_log_prob_gradient work --- numpyro/distributions/continuous.py | 39 +++++++++++------------------ numpyro/distributions/util.py | 5 ++-- test/test_distributions.py | 9 +++---- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 16b3ea4b5..14f427c11 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -369,7 +369,7 @@ def variance(self): @copy_docs_from(Distribution) class GumbelSoftmaxProbs(Distribution): - arg_constraints = {'probs': constraints.simplex} + arg_constraints = {'probs': constraints.simplex, 'temperature':constraints.real} def __init__(self, probs, temperature=1., validate_args=None): if np.ndim(probs) < 1: @@ -379,50 +379,39 @@ def __init__(self, probs, temperature=1., validate_args=None): self.standard_gumbel = Gumbel() self.temperature = temperature super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], - validate_args=validate_args) + event_shape=np.shape(self.probs)[-1:], + validate_args=validate_args) + - def sample(self, rng_key, sample_shape=()): - #gs = self.standard_gumbel.sample(key=key, sample_shape=sample_shape + self.probs.shape) - #s = (np.log(self.probs) + gs)/self.temperature - return gumbel_softmax_probs(rng_key, self.probs, shape=sample_shape, - temperature=self.temperature, hard=False) + def sample(self, key, sample_shape=(), one_hot=True): + return gumbel_softmax_probs(key, self.probs, shape=sample_shape + self.batch_shape,#sample_shape, + temperature=self.temperature, hard=False, one_hot=one_hot) - # gumbel_softmax_probs(key, probs, - # shape=sample_shape + self.batch_shape + self.event_shape, - @validate_sample - def log_prob(self, ys): + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) probs = self.probs k = self.k temperature = self.temperature - #ys = np.array([0., 0., 1.]) # TODO: ENSURE ONE HOT + # TODO: ENSURE ONE HOT, OTHERWISE CONVERT eps = np.finfo(probs.dtype).eps probs = np.clip(probs, a_min=eps, a_max=1-eps) # We clip the ys to be positive to have a positive denominator - ys = np.clip(ys, a_min=eps) + ys = np.clip(value, a_min=eps) res = gammaln(k) + (k-1)*np.log(temperature) -k*np.log( (probs / (ys**temperature)).sum(axis=-1)) + np.sum( np.log(probs / (ys**temperature)), axis=-1) - #res += np.sum(np.log(probs), axis=-1) - (temperature + 1) * np.sum(np.log(ys), axis=-1) return res - """ - @validate_sample - def log_prob(self, value): - # FIXME: implement - batch_shape = lax.broadcast_shapes(np.shape(value), self.batch_shape) - value = np.expand_dims(value, axis=-1) - value = np.broadcast_to(value, batch_shape + (1,)) - logits = _to_logits_multinom(self.probs) - log_pmf = np.broadcast_to(logits, batch_shape + np.shape(logits)[-1:]) - return np.take_along_axis(log_pmf, value, axis=-1)[..., 0] - """ @property def mean(self): + #FIXME: correct return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) @property def variance(self): + #FIXME: correct return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) @property diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c40e848c8..201f2341d 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -203,8 +203,9 @@ def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one return _categorical(key, ret, shape) -def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False): - return gumbel_softmax_logits(key, np.log(probs), shape, temperature=temperature, hard=hard) +def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False, one_hot=False): + return gumbel_softmax_logits(key, np.log(probs), shape, + temperature=temperature, hard=hard, one_hot=one_hot) # Ref https://github.com/numpy/numpy/blob/8a0858f3903e488495a56b4a6d19bbefabc97dca/ diff --git a/test/test_distributions.py b/test/test_distributions.py index b0dca05b5..6f2cbbdab 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -107,8 +107,8 @@ def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag): T(dist.Gumbel, 0., 1.), T(dist.Gumbel, 0.5, 2.), T(dist.Gumbel, np.array([0., 0.5]), np.array([1., 2.])), - T(dist.GumbelSoftmaxProbs(np.array([0.1, 0.2, 0.3, 0.4]), temperature=0.0001)), - T(dist.GumbelSoftmaxProbs(np.array([0.1, 0.2, 0.3, 0.4]), temperature=100)), + T(dist.GumbelSoftmaxProbs, np.array([0.1, 0.2, 0.3, 0.4]), 0.0001), + T(dist.GumbelSoftmaxProbs, np.array([0.1, 0.2, 0.3, 0.4]), 100), T(dist.HalfCauchy, 1.), T(dist.HalfCauchy, np.array([1., 2.])), T(dist.HalfNormal, 1.), @@ -666,7 +666,6 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): valid_params[i] = gen_values_within_bounds(constraint, np.shape(params[i]), key_gen) assert jax_dist(*oob_params) - # Invalid parameter values throw ValueError if not dependent_constraint: with pytest.raises(ValueError): @@ -976,7 +975,7 @@ def test_relaxations_low(temperature, N_sample, key=jax.random.PRNGKey(52)): probs = np.array([0.1, 0.1, 0.8]) GS1 = dist.GumbelSoftmaxProbs(probs, temperature=temperature) - gs_samples = GS1.sample(key, (N_sample,)) + gs_samples = GS1.sample(key, (N_sample,), one_hot=False) categorical_samples = dist.CategoricalProbs(probs).sample(key, (N_sample, )) ks_result = osp.ks_2samp(gs_samples, categorical_samples) assert ks_result.pvalue > 0.05 @@ -991,7 +990,7 @@ def test_relaxations_high(temperature, N_sample, key=jax.random.PRNGKey(52)): probs = np.array([0.1, 0.1, 0.8]) GS1 = dist.GumbelSoftmaxProbs(probs, temperature=temperature) - gs_samples = GS1.sample(key, (N_sample,)) + gs_samples = GS1.sample(key, (N_sample,), one_hot=False) uniform_samples = dist.Categorical(np.array([1./3, 1./3, 1./3])).sample(key, (N_sample, )) ks_result = osp.ks_2samp(gs_samples, uniform_samples) From c67abad5661a30484f10821ef10069acae0eb4b0 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Mon, 6 Apr 2020 22:06:43 +0200 Subject: [PATCH 09/12] Write out the calculation of the log_prob to prepare for dimension expansion where needed --- numpyro/distributions/continuous.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 14f427c11..4b699f411 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -397,11 +397,12 @@ def log_prob(self, value): # TODO: ENSURE ONE HOT, OTHERWISE CONVERT eps = np.finfo(probs.dtype).eps - probs = np.clip(probs, a_min=eps, a_max=1-eps) # We clip the ys to be positive to have a positive denominator ys = np.clip(value, a_min=eps) - res = gammaln(k) + (k-1)*np.log(temperature) -k*np.log( (probs / (ys**temperature)).sum(axis=-1)) + np.sum( np.log(probs / (ys**temperature)), axis=-1) + res = gammaln(k) + (k-1) * np.log(temperature) + res += -k * np.log( (probs / (ys**temperature)).sum(axis=-1)) #+ np.sum( np.log(probs / (ys**temperature)), axis=-1) + res += np.sum(np.log(probs), axis=-1) - (temperature + 1) * np.sum(np.log(ys), axis=-1) return res @property From 78451d6c02b4562ea6b651287849d46e60518876 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Tue, 7 Apr 2020 08:59:38 +0200 Subject: [PATCH 10/12] Fix lint problems --- numpyro/distributions/continuous.py | 14 ++++++-------- numpyro/distributions/util.py | 7 ++++--- test/test_distributions.py | 6 ++++-- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4b699f411..8ce73f155 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -39,13 +39,11 @@ cholesky_of_inverse, cumsum, get_dtype, - gumbel_softmax_logits, gumbel_softmax_probs, lazy_property, matrix_to_tril_vec, promote_shapes, signed_stick_breaking_tril, - _to_probs_multinom, validate_sample, vec_to_tril_matrix ) @@ -369,7 +367,8 @@ def variance(self): @copy_docs_from(Distribution) class GumbelSoftmaxProbs(Distribution): - arg_constraints = {'probs': constraints.simplex, 'temperature':constraints.real} + arg_constraints = {'probs': constraints.simplex, + 'temperature': constraints.real} def __init__(self, probs, temperature=1., validate_args=None): if np.ndim(probs) < 1: @@ -382,9 +381,8 @@ def __init__(self, probs, temperature=1., validate_args=None): event_shape=np.shape(self.probs)[-1:], validate_args=validate_args) - def sample(self, key, sample_shape=(), one_hot=True): - return gumbel_softmax_probs(key, self.probs, shape=sample_shape + self.batch_shape,#sample_shape, + return gumbel_softmax_probs(key, self.probs, shape=sample_shape + self.batch_shape, temperature=self.temperature, hard=False, one_hot=one_hot) @validate_sample @@ -401,18 +399,18 @@ def log_prob(self, value): ys = np.clip(value, a_min=eps) res = gammaln(k) + (k-1) * np.log(temperature) - res += -k * np.log( (probs / (ys**temperature)).sum(axis=-1)) #+ np.sum( np.log(probs / (ys**temperature)), axis=-1) + res += -k * np.log((probs / (ys**temperature)).sum(axis=-1)) res += np.sum(np.log(probs), axis=-1) - (temperature + 1) * np.sum(np.log(ys), axis=-1) return res @property def mean(self): - #FIXME: correct + # FIXME: correct return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) @property def variance(self): - #FIXME: correct + # FIXME: correct return np.full(self.batch_shape, np.nan, dtype=get_dtype(self.probs)) @property diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 201f2341d..d3b81fe72 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -190,13 +190,13 @@ def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one shape = shape or logits.shape[:-1] y_soft = softmax((random.gumbel(key, shape + logits.shape[-1:], logits.dtype) + logits) / temperature, axis=-1) - + if hard: y_hard = np.where(y_soft == np.amax(y_soft, axis=-1, keepdims=True), 1., 0.) ret = y_hard - lax.stop_gradient(y_soft) + y_soft else: ret = y_soft - + if one_hot: return ret else: @@ -205,7 +205,8 @@ def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False, one_hot=False): return gumbel_softmax_logits(key, np.log(probs), shape, - temperature=temperature, hard=hard, one_hot=one_hot) + temperature=temperature, hard=hard, + one_hot=one_hot) # Ref https://github.com/numpy/numpy/blob/8a0858f3903e488495a56b4a6d19bbefabc97dca/ diff --git a/test/test_distributions.py b/test/test_distributions.py index 6f2cbbdab..8067dcfdd 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -969,7 +969,7 @@ def test_generated_sample_distribution(jax_dist, sp_dist, params, @pytest.mark.parametrize('N_sample', [10_000, 100_000]) def test_relaxations_low(temperature, N_sample, key=jax.random.PRNGKey(52)): """ Test that samples from low temperatures are close to samples from a - Categorical distribution (and consequently that the GumbelSoftmax + Categorical distribution (and consequently that the GumbelSoftmax distribution samples correctly). """ @@ -994,4 +994,6 @@ def test_relaxations_high(temperature, N_sample, key=jax.random.PRNGKey(52)): uniform_samples = dist.Categorical(np.array([1./3, 1./3, 1./3])).sample(key, (N_sample, )) ks_result = osp.ks_2samp(gs_samples, uniform_samples) - assert ks_result.pvalue > 0.05, "failed KS betwen Gumbel Softmax and categorical with equal probabilities for temperature {}".format(temperature) + error_message = """failed KS betwen Gumbel Softmax and categorical with + equal probabilities for temperature {}""".format(temperature) + assert ks_result.pvalue > 0.05, error_message From 8ee6ed069fda88fc31293d5e45ab5295c7e90a55 Mon Sep 17 00:00:00 2001 From: daydreamt Date: Tue, 7 Apr 2020 12:24:52 +0200 Subject: [PATCH 11/12] Make gumbel_soft_max_probs behavior consistent in preparation of passing the shape tests; --- numpyro/distributions/continuous.py | 12 +++++++----- numpyro/distributions/util.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 8ce73f155..8db7329bf 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -374,15 +374,16 @@ def __init__(self, probs, temperature=1., validate_args=None): if np.ndim(probs) < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs + batch_shape, event_shape = probs.shape[:-1], probs.shape[-1:] self.k = probs.shape[-1] self.standard_gumbel = Gumbel() self.temperature = temperature - super(GumbelSoftmaxProbs, self).__init__(batch_shape=np.shape(self.probs)[:-1], - event_shape=np.shape(self.probs)[-1:], + super(GumbelSoftmaxProbs, self).__init__(batch_shape=batch_shape, + event_shape=event_shape, validate_args=validate_args) def sample(self, key, sample_shape=(), one_hot=True): - return gumbel_softmax_probs(key, self.probs, shape=sample_shape + self.batch_shape, + return gumbel_softmax_probs(key, self.probs, shape=sample_shape + self.batch_shape + self.event_shape, temperature=self.temperature, hard=False, one_hot=one_hot) @validate_sample @@ -394,11 +395,12 @@ def log_prob(self, value): temperature = self.temperature # TODO: ENSURE ONE HOT, OTHERWISE CONVERT - eps = np.finfo(probs.dtype).eps + eps = np.finfo(probs.dtype).eps # TODO: Make depending on value? But then must be inexact # We clip the ys to be positive to have a positive denominator ys = np.clip(value, a_min=eps) - + probs = promote_shapes(probs, shape=ys.shape)[0] res = gammaln(k) + (k-1) * np.log(temperature) + # lax._safe_mul(x, 1/ ys**temperature) res += -k * np.log((probs / (ys**temperature)).sum(axis=-1)) res += np.sum(np.log(probs), axis=-1) - (temperature + 1) * np.sum(np.log(ys), axis=-1) return res diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index d3b81fe72..10d4de0bb 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -188,7 +188,7 @@ def categorical_logits(key, logits, shape=()): def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one_hot=False): shape = shape or logits.shape[:-1] - y_soft = softmax((random.gumbel(key, shape + logits.shape[-1:], logits.dtype) + y_soft = softmax((random.gumbel(key, shape, logits.dtype) + logits) / temperature, axis=-1) if hard: @@ -204,7 +204,7 @@ def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False, one_hot=False): - return gumbel_softmax_logits(key, np.log(probs), shape, + return gumbel_softmax_logits(key, np.log(probs), shape=shape, temperature=temperature, hard=hard, one_hot=one_hot) From 7c6837bb2577d94c5e9d9bc1954a467538359b4b Mon Sep 17 00:00:00 2001 From: daydreamt Date: Tue, 7 Apr 2020 15:25:32 +0200 Subject: [PATCH 12/12] Fix high/low temperature non-one hot tests --- numpyro/distributions/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 10d4de0bb..4b22bb8b7 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -200,7 +200,7 @@ def gumbel_softmax_logits(key, logits, shape=(), temperature=1., hard=False, one if one_hot: return ret else: - return _categorical(key, ret, shape) + return _categorical(key, ret, None) def gumbel_softmax_probs(key, probs, shape=(), temperature=1., hard=False, one_hot=False):