diff --git a/funsor/distribution.py b/funsor/distribution.py
index 4537e282f..73235c09d 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -16,13 +16,13 @@
 import funsor.ops as ops
 from funsor.affine import is_affine
 from funsor.cnf import Contraction, GaussianMixture
-from funsor.domains import Array, Real, Reals
+from funsor.domains import Array, Real, Reals, RealsType
 from funsor.gaussian import Gaussian
 from funsor.interpreter import gensym
 from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
                            ignore_jit_warnings, numeric_array, stack)
 from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \
-    eager, to_data, to_funsor
+    eager, reflect, to_data, to_funsor
 from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property
@@ -57,12 +57,36 @@ class DistributionMeta(FunsorMeta):
     """
     def __call__(cls, *args, **kwargs):
         kwargs.update(zip(cls._ast_fields, args))
-        value = kwargs.pop('value', 'value')
-        kwargs = OrderedDict(
-            (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
-            for k in cls._ast_fields if k != 'value')
-        value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
-        args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
+        kwargs["value"] = kwargs.get("value", "value")
+        kwargs = OrderedDict((k, kwargs[k]) for k in cls._ast_fields)  # make sure args are sorted
+
+        domains = OrderedDict()
+        for k, v in kwargs.items():
+            if k == "value":
+                continue
+
+            # compute unbroadcasted param domains
+            domain = cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))
+            # use to_funsor to infer output dimensions of e.g. tensors
+            domains[k] = domain if domain is not None else to_funsor(v).output
+
+            # broadcast individual param domains with Funsor inputs
+            # this avoids .expand-ing underlying parameter tensors
+            if isinstance(v, Funsor) and isinstance(v.output, RealsType):
+                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+            elif ops.is_numeric_array(v):
+                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+
+        # now use the broadcasted parameter shapes to infer the event_shape
+        domains["value"] = cls._infer_value_domain(**domains)
+        if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
+            # try to broadcast the event shape with the value, in case they disagree
+            domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
+
+        # finally, perform conversions to funsors
+        kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
+        args = numbers_to_tensors(*kwargs.values())
+
         return super(DistributionMeta, cls).__call__(*args)
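Reviewer note: a minimal sketch of the domain broadcasting introduced in this hunk, simplified to call `broadcast_shape` directly (the real path goes through `cls._infer_value_domain`, which instantiates a backend distribution); the shapes here are illustrative, not from this patch.

```python
# Hedged sketch: how unexpanded parameter domains broadcast into a value domain.
from funsor.domains import Real, Reals
from funsor.util import broadcast_shape

loc_domain = Reals[4, 7]   # inferred output domain for a batched loc
scale_domain = Real        # a scalar scale is left unexpanded

# the value domain is the broadcast of the parameter output shapes,
# so the underlying scale tensor never needs an .expand
value_shape = broadcast_shape(loc_domain.shape, scale_domain.shape)
assert Reals[value_shape] == Reals[4, 7]
```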
@@ -98,14 +122,6 @@ def eager_reduce(self, op, reduced_vars):
             return Number(0.)  # distributions are normalized
         return super(Distribution, self).eager_reduce(op, reduced_vars)
 
-    @classmethod
-    def eager_log_prob(cls, *params):
-        inputs, tensors = align_tensors(*params)
-        params = dict(zip(cls._ast_fields, tensors))
-        value = params.pop('value')
-        data = cls.dist_class(**params).log_prob(value)
-        return Tensor(data, inputs)
-
     def _get_raw_dist(self):
         """
         Internal method for working with underlying distribution attributes
@@ -129,6 +145,26 @@ def has_rsample(self):
     def has_enumerate_support(self):
         return getattr(self.dist_class, "has_enumerate_support", False)
 
+    @classmethod
+    def eager_log_prob(cls, *params):
+        params, value = params[:-1], params[-1]
+        params = params + (Variable("value", value.output),)
+        instance = reflect(cls, *params)
+        raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
+        assert value.output == value_output
+        name_to_dim = {v: k for k, v in dim_to_name.items()}
+        dim_to_name.update({-1 - d - len(raw_dist.batch_shape): name
+                            for d, name in enumerate(value.inputs) if name not in name_to_dim})
+        name_to_dim.update({v: k for k, v in dim_to_name.items() if v not in name_to_dim})
+        raw_log_prob = raw_dist.log_prob(to_data(value, name_to_dim=name_to_dim))
+        log_prob = to_funsor(raw_log_prob, Real, dim_to_name=dim_to_name)
+        # this logic ensures that the inputs have the canonical order
+        # implied by align_tensors, which is assumed pervasively in tests
+        inputs = OrderedDict()
+        for x in params[:-1] + (value,):
+            inputs.update(x.inputs)
+        return log_prob.align(tuple(inputs))
+
     def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
         # note this should handle transforms correctly via distribution_to_data
@@ -191,7 +227,13 @@ def _infer_value_domain(cls, **kwargs):
     # rely on the underlying distribution's logic to infer the event_shape given param domains
     instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()},
                               validate_args=False)
-    out_shape = instance.event_shape
+
+    # Note inclusion of batch_shape here to handle independent event dimensions.
+    # The arguments to _infer_value_domain are the .output shapes of parameters,
+    # so any extra batch dimensions that aren't part of the instance event_shape
+    # must be broadcasted output dimensions by construction.
+    out_shape = instance.batch_shape + instance.event_shape
+
     if type(instance.support).__name__ == "_IntegerInterval":
         out_dtype = int(instance.support.upper_bound + 1)
     else:
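To see why `batch_shape` must be folded in: `_infer_value_domain` receives the `.output` shapes of parameters, so a dummy backend instance built from them reports any extra output dims as `batch_shape` rather than `event_shape`. A hedged illustration using raw `torch.distributions` (shapes made up):

```python
import torch
import torch.distributions as td

# a parameter whose funsor *output* shape is (4, 7), passed through unchanged
instance = td.Normal(loc=torch.zeros(4, 7), scale=torch.ones(4, 7))
print(instance.batch_shape, instance.event_shape)
# torch.Size([4, 7]) torch.Size([])

# from funsor's point of view those dims belong to the value's output, hence
# the inferred domain is Reals[batch_shape + event_shape] == Reals[4, 7]
```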
@@ -400,10 +442,32 @@ def __call__(self, cls, args, kwargs):
 
 @to_data.register(Distribution)
 def distribution_to_data(funsor_dist, name_to_dim=None):
-    params = [to_data(getattr(funsor_dist, param_name), name_to_dim=name_to_dim)
-              for param_name in funsor_dist._ast_fields if param_name != 'value']
-    pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
     funsor_event_shape = funsor_dist.value.output.shape
+
+    # attempt to generically infer the independent output dimensions
+    instance = funsor_dist.dist_class(**{
+        k: dummy_numeric_array(v.output)
+        for k, v in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1])
+    }, validate_args=False)
+    event_shape = broadcast_shape(instance.event_shape, funsor_dist.value.output.shape)
+    reinterpreted_batch_ndims = len(event_shape) - len(instance.event_shape)
+    assert reinterpreted_batch_ndims >= 0  # XXX is this ever nonzero?
+    indep_shape = broadcast_shape(instance.batch_shape, event_shape[:reinterpreted_batch_ndims])
+
+    params = []
+    for param_name, funsor_param in zip(funsor_dist._ast_fields, funsor_dist._ast_values[:-1]):
+        param = to_data(funsor_param, name_to_dim=name_to_dim)
+
+        # infer the independent dimensions of each parameter separately, since we chose to keep them unbroadcasted
+        param_event_shape = getattr(funsor_dist._infer_param_domain(param_name, funsor_param.output.shape), "shape", ())
+        param_indep_shape = funsor_param.output.shape[:len(funsor_param.output.shape) - len(param_event_shape)]
+        for i in range(max(0, len(indep_shape) - len(param_indep_shape))):
+            # add singleton event dimensions, leave broadcasting/expanding to backend
+            param = ops.unsqueeze(param, -1 - len(funsor_param.output.shape))
+
+        params.append(param)
+
+    pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
     pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))
 
     # TODO get this working for all backends
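A toy version of the singleton-unsqueeze alignment above, using raw `torch.distributions`; names and shapes are illustrative, not from this patch. Padding a parameter with size-1 event dims lets the backend broadcast it lazily instead of materializing `.expand`-ed tensors:

```python
import torch
import torch.distributions as td

loc = torch.randn(5, 4, 7)    # batch (5,) plus independent event dims (4, 7)
scale = torch.randn(5).exp()  # batch (5,) only; scalar output shape

indep_ndims = 2               # stands in for len(indep_shape) above
for _ in range(indep_ndims):
    scale = scale.unsqueeze(-1)  # (5,) -> (5, 1, 1); no .expand required

d = td.Independent(td.Normal(loc, scale), indep_ndims)
print(d.batch_shape, d.event_shape)  # torch.Size([5]) torch.Size([4, 7])
```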
diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py
index 23a06be69..ff1b099de 100644
--- a/funsor/jax/distributions.py
+++ b/funsor/jax/distributions.py
@@ -200,7 +200,7 @@ def _infer_param_domain(cls, name, raw_shape):
 
 
 ###########################################################
-# Converting distribution funsors to PyTorch distributions
+# Converting distribution funsors to NumPyro distributions
 ###########################################################
 
 # Convert Delta **distribution** to raw data
@@ -212,9 +212,17 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
 
 
 ###############################################
-# Converting PyTorch Distributions to funsors
+# Converting NumPyro Distributions to funsors
 ###############################################
 
+# TODO move these properties upstream to numpyro.distributions
+dist.Independent.has_rsample = property(lambda self: self.base_dist.has_rsample)
+dist.Independent.rsample = dist.Independent.sample
+dist.MaskedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
+dist.MaskedDistribution.rsample = dist.MaskedDistribution.sample
+dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
+dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample
+
 to_funsor.register(dist.Independent)(indepdist_to_funsor)
 if hasattr(dist, "MaskedDistribution"):
     to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index b93431ad1..1bcd0185a 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -137,7 +137,7 @@ def _infer_value_domain(**kwargs):
 @functools.lru_cache(maxsize=5000)
 def _infer_value_domain(cls, **kwargs):
     instance = cls.dist_class(**{k: dummy_numeric_array(domain) for k, domain in kwargs.items()}, validate_args=False)
-    return Reals[instance.event_shape]
+    return Reals[instance.batch_shape + instance.event_shape]
 
 
 # TODO fix Delta.arg_constraints["v"] to be a
diff --git a/test/test_distribution.py b/test/test_distribution.py
index 2f621779d..d72e83f4b 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -246,7 +246,7 @@ def dirichlet(concentration: Reals[event_shape],
     check_funsor(expected, inputs, Real)
     actual = dist.Dirichlet(concentration, value)
     check_funsor(actual, inputs, Real)
-    assert_close(actual, expected)
+    assert_close(actual, expected, atol=1e-3, rtol=1e-3)
 
 
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@@ -1123,3 +1123,66 @@ def test_gamma_poisson_conjugate(batch_shape):
     obs = Tensor(ops.astype(ops.astype(ops.exp(randn(batch_shape)), 'int32'), 'float32'), inputs)
     _assert_conjugate_density_ok(latent, conditional, obs)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
+@pytest.mark.parametrize('use_raw_scale', [False, True])
+def test_normal_event_dim_conversion(batch_shape, event_shape, use_raw_scale):
+
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
+
+    value = Variable("value", Reals[event_shape])
+    loc = Tensor(randn(batch_shape + event_shape), inputs)
+    scale = Tensor(ops.exp(randn(batch_shape)), inputs)
+    if use_raw_scale:
+        if batch_shape:
+            pytest.xfail(reason="raw scale is underspecified for nonempty batch_shape")
+        scale = scale.data
+
+    with interpretation(lazy):
+        actual = dist.Normal(loc=loc, scale=scale, value=value)
+
+    expected_inputs = inputs.copy()
+    expected_inputs.update({"value": Reals[event_shape]})
+    check_funsor(actual, expected_inputs, Real)
+
+    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
+    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
+    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]
+
+    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
+    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
+        funsor.to_data(data, name_to_dim=name_to_dim))
+    assert actual_log_prob.shape == expected_log_prob.shape
+    assert_close(actual_log_prob, expected_log_prob)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+@pytest.mark.parametrize('event_shape', [(4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
+def test_mvnormal_event_dim_conversion(batch_shape, event_shape):
+
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
+
+    value = Variable("value", Reals[event_shape])
+    loc = Tensor(randn(batch_shape + event_shape), inputs)
+    scale_tril = Tensor(random_scale_tril(batch_shape + event_shape + event_shape[-1:]), inputs)
+
+    with interpretation(lazy):
+        actual = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril, value=value)
+
+    expected_inputs = inputs.copy()
+    expected_inputs.update({"value": Reals[event_shape]})
+    check_funsor(actual, expected_inputs, Real)
+
+    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
+    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
+    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]
+
+    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
+    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
+        funsor.to_data(data, name_to_dim=name_to_dim))
+    assert actual_log_prob.shape == expected_log_prob.shape
+    assert_close(actual_log_prob, expected_log_prob)
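On the `use_raw_scale` xfail above: a raw array carries no dimension names, so for nonempty `batch_shape` funsor cannot tell whether its dims are batch inputs or output (event) dims. A small hedged sketch (torch backend assumed; names made up):

```python
from collections import OrderedDict

import torch
import funsor
from funsor.domains import Bint
from funsor.tensor import Tensor

funsor.set_backend("torch")

raw = torch.ones(5)
named = Tensor(raw, OrderedDict(i=Bint[5]))  # size-5 dim is the batch input "i"
plain = Tensor(raw)                          # size-5 dim is part of the output
print(named.output, plain.output)            # Real vs Reals[5]
```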