Description
Feature Summary
Currently, MixtureGeneral supports a custom support argument, as discussed in this issue comment. It would be helpful to extend similar functionality to TransformedDistribution.
Why is this needed?
In some use cases, I employ custom Transform objects that are well-defined mathematically and invertible. However, these transforms don't inherently enforce the desired output support constraints. As a result, samples from the transformed distribution can fall outside the intended codomain, which leads to inconsistencies in model behavior and inference.
A concrete example is the following transformation from a primary mass and mass ratio to component masses:
```python
# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0

import jax.numpy as jnp
from jax import Array
from numpyro.distributions import constraints
from numpyro.distributions.constraints import _SingletonConstraint
from numpyro.distributions.transforms import Transform


class _PositiveDecreasingVector(_SingletonConstraint):
    r"""Constrain values to be positive and decreasing, i.e. :math:`\forall i<j, x_i
    \geq x_j`.
    """

    event_dim = 1

    def __call__(self, x):
        # Non-increasing along the last axis (stands in for GWKokab's
        # ``decreasing_vector`` constraint, which is not shown in this issue).
        decreasing = jnp.all(x[..., :-1] >= x[..., 1:], axis=-1)
        # Every component strictly positive.
        positive = constraints.independent(constraints.positive, 1).check(x)
        return decreasing & positive

    def feasible_like(self, prototype):
        # A vector of ones is positive and (non-strictly) decreasing.
        return jnp.ones(prototype.shape, dtype=prototype.dtype)

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        return isinstance(other, _PositiveDecreasingVector)


# Defined before the transform, which uses it as a class attribute.
positive_decreasing_vector = _PositiveDecreasingVector()


class PrimaryMassAndMassRatioToComponentMassesTransform(Transform):
    r"""Transforms a primary mass and mass ratio to component masses.

    .. math::
        f: (m_1, q)\to (m_1, m_1q)

    .. math::
        f^{-1}: (m_1, m_2)\to (m_1, m_2/m_1)
    """

    domain = constraints.independent(
        constraints.interval(
            jnp.zeros((2,)), jnp.array([jnp.finfo(jnp.result_type(float)).max, 1.0])
        ),
        1,
    )
    r""":math:`\mathcal{D}(f) = \mathbb{R}_+\times[0, 1]`"""

    codomain = positive_decreasing_vector
    r""":math:`\mathcal{C}(f)=\{(m_1, m_2)\in\mathbb{R}^2_+\mid m_1\geq m_2>0\}`"""

    def __call__(self, x: Array):
        m1, q = jnp.unstack(x, axis=-1)
        m2 = jnp.multiply(m1, q)
        m1m2 = jnp.stack((m1, m2), axis=-1)
        return m1m2

    def _inverse(self, y: Array):
        m1, m2 = jnp.unstack(y, axis=-1)
        # q = m2 / m1 (GWKokab's ``mass_ratio`` helper, not shown in this issue).
        q = jnp.divide(m2, m1)
        m1q = jnp.stack((m1, q), axis=-1)
        return m1q

    def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None):
        r"""
        .. math::
            \ln\left(|\mathrm{det}(J_f)|\right) = \ln(|m_1|)
        """
        # J_f = [[1, 0], [q, m_1]], so det(J_f) = m_1.
        m1 = x[..., 0]
        return jnp.log(jnp.abs(m1))

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        if not isinstance(other, PrimaryMassAndMassRatioToComponentMassesTransform):
            return False
        return self.domain == other.domain
```

While this transformation is mathematically correct, TransformedDistribution does not currently provide a way to enforce the intended support, as MixtureGeneral does with its support argument.
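As a quick sanity check of the math in the transform's docstrings, the sketch below re-implements f, its inverse, and the closed-form log-determinant as plain JAX functions. The helper names (`forward`, `inverse`, `log_abs_det_jacobian`, `in_codomain`) are hypothetical and independent of the GWKokab and NumPyro classes. It compares the closed form against an autodiff Jacobian, checks the round trip, and evaluates two points on the edge of the declared domain. For example, q = 0 lies inside the domain's `[0, 1]` interval but maps to m_2 = 0, which fails the strict m_2 > 0 in the stated codomain.

```python
import jax
import jax.numpy as jnp


def forward(x):
    # f(m1, q) = (m1, m1 * q)
    m1, q = x[..., 0], x[..., 1]
    return jnp.stack((m1, m1 * q), axis=-1)


def inverse(y):
    # f^{-1}(m1, m2) = (m1, m2 / m1)
    m1, m2 = y[..., 0], y[..., 1]
    return jnp.stack((m1, m2 / m1), axis=-1)


def log_abs_det_jacobian(x):
    # Closed form from the docstring: ln|det J_f| = ln|m1|
    return jnp.log(jnp.abs(x[..., 0]))


def in_codomain(y):
    # m1 >= m2 > 0, written out directly for illustration
    return (y[..., 0] >= y[..., 1]) & jnp.all(y > 0, axis=-1)


x = jnp.array([30.0, 0.5])
J = jax.jacfwd(forward)(x)  # rows: [1, 0] and [q, m1]
_, autodiff_logdet = jnp.linalg.slogdet(J)
print(autodiff_logdet, log_abs_det_jacobian(x))
print(inverse(forward(x)))

# Points on the edge of the declared domain: q = 0 and q = 1.
edge = jnp.array([[10.0, 0.0], [10.0, 1.0]])
# First row is False (m2 = 0 is not > 0); second row is True (m1 == m2).
print(in_codomain(forward(edge)))
```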