
Support for custom support in TransformedDistribution #2043

@Qazalbash

Description

Feature Summary

Currently, MixtureGeneral accepts a custom support argument, as discussed in this issue comment. It would be helpful to extend the same functionality to TransformedDistribution.

Why is this needed?

In some use cases, I employ custom Transform objects that are mathematically well-defined and invertible. However, these transforms don't inherently enforce the desired support constraints on their output, so samples from the transformed distribution can fall outside the intended support, which leads to inconsistencies in model behavior and inference.

A concrete example is the following transformation from a primary mass $m_1$ and mass ratio $q \in (0, 1]$ to component masses $(m_1, m_2)$, where $m_2 = m_1 q$ and $m_2 \le m_1$:

# Copyright 2023 The GWKokab Authors
# SPDX-License-Identifier: Apache-2.0

import jax.numpy as jnp
from jax import Array
from numpyro.distributions import constraints
from numpyro.distributions.constraints import _SingletonConstraint, independent, positive
from numpyro.distributions.transforms import Transform

# `decreasing_vector` (a constraint checking x[..., i] >= x[..., j] for i < j)
# and `mass_ratio` (which returns m2 / m1) are GWKokab helpers, assumed in scope.


class _PositiveDecreasingVector(_SingletonConstraint):
    r"""Constrain values to be positive and decreasing, i.e. :math:`\forall i<j, x_i
    \geq x_j`.
    """

    event_dim = 1

    def __call__(self, x):
        return decreasing_vector.check(x) & independent(positive, 1).check(x)

    def feasible_like(self, prototype):
        return jnp.ones(prototype.shape, dtype=prototype.dtype)

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        return isinstance(other, _PositiveDecreasingVector)


# Instantiated before the transform below, which references it in its class body.
positive_decreasing_vector = _PositiveDecreasingVector()


class PrimaryMassAndMassRatioToComponentMassesTransform(Transform):
    r"""Transforms a primary mass and mass ratio to component masses.

    .. math::
        f: (m_1, q)\to (m_1, m_1q)

    .. math::
        f^{-1}: (m_1, m_2)\to (m_1, m_2/m_1)
    """

    domain = constraints.independent(
        constraints.interval(
            jnp.zeros((2,)), jnp.array([jnp.finfo(jnp.result_type(float)).max, 1.0])
        ),
        1,
    )
    r""":math:`\mathcal{D}(f) = \mathbb{R}_+\times[0, 1]`"""
    codomain = positive_decreasing_vector
    r""":math:`\mathcal{C}(f)=\{(m_1, m_2)\in\mathbb{R}^2_+\mid m_1\geq m_2>0\}`"""

    def __call__(self, x: Array):
        # Forward map: (m_1, q) -> (m_1, m_1 q).
        m1, q = jnp.unstack(x, axis=-1)
        m2 = jnp.multiply(m1, q)
        m1m2 = jnp.stack((m1, m2), axis=-1)
        return m1m2

    def _inverse(self, y: Array):
        # Inverse map: (m_1, m_2) -> (m_1, m_2 / m_1).
        m1, m2 = jnp.unstack(y, axis=-1)
        q = mass_ratio(m2=m2, m1=m1)
        m1q = jnp.stack((m1, q), axis=-1)
        return m1q

    def log_abs_det_jacobian(self, x: Array, y: Array, intermediates=None):
        r"""
        .. math::
            \ln\left(|\det(J_f)|\right) = \ln(|m_1|)
        """
        # J_f = [[1, 0], [q, m_1]], so |det J_f| = |m_1|.
        m1 = x[..., 0]
        return jnp.log(jnp.abs(m1))

    def tree_flatten(self):
        return (), ((), dict())

    def __eq__(self, other):
        if not isinstance(other, PrimaryMassAndMassRatioToComponentMassesTransform):
            return False
        return self.domain == other.domain
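
For concreteness, here is a minimal sketch of how the transform is used today. The Uniform base over $(m_1, q)$ and the bounds $m_{\mathrm{min}}, m_{\mathrm{max}}$ are illustrative assumptions, not GWKokab's actual model:

import jax.numpy as jnp
from jax import random
from numpyro.distributions import TransformedDistribution, Uniform

m_min, m_max = 5.0, 50.0  # illustrative bounds

# Base distribution over (m_1, q): m_1 in [m_min, m_max], q in [0, 1].
base = Uniform(jnp.array([m_min, 0.0]), jnp.array([m_max, 1.0])).to_event(1)

dist = TransformedDistribution(
    base, PrimaryMassAndMassRatioToComponentMassesTransform()
)

samples = dist.sample(random.PRNGKey(0), (1000,))
# dist.support is derived from the transform's codomain
# (positive_decreasing_vector), so nothing enforces m_2 >= m_min:
print(jnp.any(samples[..., 1] < m_min))  # frequently True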

While this transformation is mathematically correct, TransformedDistribution does not currently enforce the intended constraint, i.e., $m_{\mathrm{min}}\leq m_2 \leq m_1 \leq m_{\mathrm{max}}$, so it yields invalid samples unless post-processing or custom sampling is added. The condition $m_1, m_2\in[m_{\mathrm{min}}, m_{\mathrm{max}}]$ is not part of the transformation itself; an explicit support argument would let this constraint be enforced directly, just as MixtureGeneral already allows.
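
By analogy with MixtureGeneral, the request boils down to something like the following sketch, reusing the names from the example above. The support keyword on TransformedDistribution is hypothetical; it is exactly what this issue proposes:

# Hypothetical API (not in NumPyro today): an explicit `support` that
# overrides the codomain-derived default, mirroring MixtureGeneral's argument.
component_mass_support = constraints.independent(
    constraints.interval(m_min, m_max), 1
)
# In full generality one would also intersect this with
# positive_decreasing_vector to keep m_2 <= m_1.

dist = TransformedDistribution(
    base,
    PrimaryMassAndMassRatioToComponentMassesTransform(),
    support=component_mass_support,  # proposed keyword
)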
