Skip to content

Commit

Permalink
NumPy 2.0 related fixes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666824480
  • Loading branch information
brianwa84 authored and tensorflower-gardener committed Aug 23, 2024
1 parent 30c737c commit 8a4555e
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
18 changes: 13 additions & 5 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
]

JAX_MODE = False
NUMPY_MODE = False
SKIP_DTYPE_CHECKS = False

# Singleton object representing "no value", in cases where "None" is meaningful.
Expand Down Expand Up @@ -1914,11 +1915,12 @@ def __str__(self):
maybe_dtype = ''
if self.forward_min_event_ndims == self.inverse_min_event_ndims:
maybe_min_ndims = ', min_event_ndims={}'.format(
self.forward_min_event_ndims)
_unwrap_event_ndims(self.forward_min_event_ndims))
else:
maybe_min_ndims = (
', forward_min_event_ndims={}, inverse_min_event_ndims={}'.format(
self.forward_min_event_ndims, self.inverse_min_event_ndims))
_unwrap_event_ndims(self.forward_min_event_ndims),
_unwrap_event_ndims(self.inverse_min_event_ndims)))
maybe_min_ndims = maybe_min_ndims.replace('\'', '')
return ('tfp.bijectors.{type_name}('
'"{self_name}"'
Expand Down Expand Up @@ -1949,8 +1951,10 @@ def __repr__(self):
type_name=type(self).__name__,
self_name=self.name or '<unknown>',
batch_shape=batch_shape_str,
forward_min_event_ndims=self.forward_min_event_ndims,
inverse_min_event_ndims=self.inverse_min_event_ndims,
forward_min_event_ndims=_unwrap_event_ndims(
self.forward_min_event_ndims),
inverse_min_event_ndims=_unwrap_event_ndims(
self.inverse_min_event_ndims),
dtype_x=_str_dtype(self.inverse_dtype()),
dtype_y=_str_dtype(self.forward_dtype())))

Expand Down Expand Up @@ -2026,7 +2030,7 @@ def check_valid_ndims(ndims, validate=True):
assertions = []

shape = ps.shape(ndims)
if not tf.is_tensor(shape):
if not tf.is_tensor(shape) or NUMPY_MODE or JAX_MODE:
if shape.tolist():
raise ValueError('Expected scalar, saw shape {}.'.format(shape))
elif validate:
Expand Down Expand Up @@ -2390,3 +2394,7 @@ def _str_tensorshape(x):
if tensorshape_util.rank(x) is None:
return '?'
return str(tensorshape_util.as_list(x)).replace('None', '?')


def _unwrap_event_ndims(ndims):
  """Coerces every entry of a (possibly nested) event-ndims structure to int.

  Ensures values print as plain Python ints (e.g. `1` rather than
  `np.int32(1)` under NumPy 2.x repr conventions) when formatting bijector
  string representations.

  Args:
    ndims: A structure (as understood by `nest`) of integer-like values.

  Returns:
    A structure of the same shape with each leaf converted via `int`.
  """
  return nest.map_structure(lambda nd: int(nd), ndims)
1 change: 1 addition & 0 deletions tensorflow_probability/python/bijectors/sinh_arcsinh.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(self,
forward_min_event_ndims=0,
validate_args=validate_args,
parameters=parameters,
dtype=dtype,
name=name)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def _get_support_bijectors(dists, xs=None, ys=None):
event_ndims = [0, 1, 0]
fldj = joint_bijector.forward_log_det_jacobian(xs, event_ndims)
fldj_fd = _finite_difference_ldj(bijectors, 'forward', xs, delta=0.01)
self.assertAllClose(self.evaluate(fldj), self.evaluate(fldj_fd), rtol=1e-5)
self.assertAllClose(self.evaluate(fldj), self.evaluate(fldj_fd), rtol=2e-5)

# Test inverse log det Jacobian via finite differences.
ildj = joint_bijector.inverse_log_det_jacobian(ys, event_ndims)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_log_prob_matches_linear_gaussian_ssm(self):
num_steps=7)

x = markov_chain.sample(5, seed=seed)
self.assertAllClose(lgssm.log_prob(x), markov_chain.log_prob(x), rtol=1e-4)
self.assertAllClose(lgssm.log_prob(x), markov_chain.log_prob(x), rtol=5e-4)

@test_util.numpy_disable_test_missing_functionality(
'JointDistributionNamedAutoBatched')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _invert_permutation(x, name=None):
def _l2_normalize(x, axis=None, epsilon=1e-12, name=None):  # pylint: disable=unused-argument
  """Normalizes `x` along `axis` to unit L2 norm.

  Args:
    x: Array-like input to normalize.
    axis: Dimension(s) along which to compute the norm; `None` normalizes
      over all elements.
    epsilon: Lower bound (on the *squared* norm scale) guarding against
      division by a near-zero norm.
    name: Unused; present for API compatibility with the TF op.

  Returns:
    `x` divided by `max(||x||_2, sqrt(epsilon))`, computed along `axis`.
  """
  x = _convert_to_tensor(x)
  denom = np.linalg.norm(x, ord=2, axis=_astuple(axis), keepdims=True)
  # Cast epsilon to the norm's dtype so NumPy 2.0 promotion rules do not
  # upcast the result (e.g. float32 -> float64).
  floor = np.sqrt(np.asarray(epsilon, dtype=denom.dtype))
  return x / np.maximum(denom, floor)


Expand Down

0 comments on commit 8a4555e

Please sign in to comment.