Skip to content
5 changes: 3 additions & 2 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,18 @@ class DistributionT(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

def rsample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
def sample(
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
self, key: Optional[jax.dtypes.prng_key], sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...
def log_prob(self, value: ArrayLike) -> ArrayLike: ...
def cdf(self, value: ArrayLike) -> ArrayLike: ...
def icdf(self, q: ArrayLike) -> ArrayLike: ...
def entropy(self) -> ArrayLike: ...
def enumerate_support(self, expand: bool = True) -> ArrayLike: ...
def shape(self, sample_shape: tuple[int, ...] = ()) -> tuple[int, ...]: ...
def expand(self, batch_shape: tuple[int, ...]) -> "DistributionT": ...

@property
def batch_shape(self) -> tuple[int, ...]: ...
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def tree_flatten(self):
complex = _Complex()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent: Constraint = _Dependent()
dependent: _Dependent = _Dependent()
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@
from jax.scipy.stats import norm as jax_norm
from jax.typing import ArrayLike

from numpyro._typing import DistributionT
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just depend on Distribution, instead of DistributionT? We have some success in constraints.py and transforms.py, which do not use ConstraintT and TransformT. We want to reduce the scope of using DistributionT through the library. Ideally, it's great to not use those protocols at all to avoid confusion.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed them in 3303b4a :)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe grep DistributionT and remove at other scripts too? :)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Started on c27121b and fixing edge cases in the following commits :)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

from numpyro.distributions import constraints
from numpyro.distributions.discrete import _to_logits_bernoulli
from numpyro.distributions.distribution import (
Distribution,
DistributionT,
TransformedDistribution,
)
from numpyro.distributions.transforms import (
Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
from jax.scipy.special import expit, gammaincc, gammaln, logsumexp, xlog1py, xlogy
from jax.typing import ArrayLike

from numpyro._typing import DistributionT
from numpyro.distributions import constraints, transforms
from numpyro.distributions.distribution import Distribution, DistributionT
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
assert_one_of,
binary_cross_entropy_with_logits,
Expand Down
Loading
Loading