Skip to content

Commit 8d20563

Browse files
committed
Test against tensors with dynamic shapes
Some `tensorflow` to `prefer_static` replacement
1 parent 6a18217 commit 8d20563

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import tensorflow.compat.v2 as tf
2323

2424
if JAX_MODE or NUMPY_MODE:
25-
tnp = np
25+
numpy_ops = np
2626
else:
27-
import tensorflow.experimental.numpy as tnp
27+
from tensorflow.python.ops import numpy_ops
2828

2929
from tensorflow_probability.python.internal import assert_util
3030
from tensorflow_probability.python.internal import distribution_util
@@ -739,13 +739,12 @@ def windowed_variance(
739739
Then each element of `low_indices` and `high_indices` must be
740740
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
741741
742-
The shape of indices must be broadcastable with `x` unless the rank is lower
743-
than the rank of `x`, then the shape is expanded with extra inner dimensions
744-
to match the rank of `x`.
742+
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
745743
746-
In the special case where the rank of indices is one, i.e when
747-
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
748-
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
744+
If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
745+
with extra inner dimensions to match the rank of `x`. In the special
746+
case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
747+
the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
749748
750749
The default windows are
751750
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -801,10 +800,10 @@ def windowed_variance(
801800
def index_for_cumulative(indices):
802801
return tf.maximum(indices - 1, 0)
803802
cum_sums = tf.cumsum(x, axis=axis)
804-
sums = tnp.take_along_axis(
803+
sums = numpy_ops.take_along_axis(
805804
cum_sums, index_for_cumulative(indices), axis=axis)
806805
cum_variances = cumulative_variance(x, sample_axis=axis)
807-
variances = tnp.take_along_axis(
806+
variances = numpy_ops.take_along_axis(
808807
cum_variances, index_for_cumulative(indices), axis=axis)
809808

810809
# This formula is the binary accurate variance merge from [1],
@@ -860,13 +859,12 @@ def windowed_mean(
860859
Then each element of `low_indices` and `high_indices` must be
861860
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
862861
863-
The shape of indices must be broadcastable with `x` unless the rank is lower
864-
than the rank of `x`, then the shape is expanded with extra inner dimensions
865-
to match the rank of `x`.
862+
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
866863
867-
In the special case where the rank of indices is one, i.e when
868-
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
869-
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
864+
If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
865+
with extra inner dimensions to match the rank of `x`. In the special
866+
case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
867+
the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
870868
871869
The default windows are
872870
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -906,7 +904,7 @@ def windowed_mean(
906904
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
907905
(rank, 2))
908906
cum_sums = ps.pad(raw_cumsum, paddings)
909-
sums = tnp.take_along_axis(cum_sums, indices,
907+
sums = numpy_ops.take_along_axis(cum_sums, indices,
910908
axis=axis)
911909
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
912910
return tf.math.divide_no_nan(sums[1] - sums[0], counts)
@@ -915,7 +913,7 @@ def windowed_mean(
915913
def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
916914
"""Common argument defaulting logic for windowed statistics."""
917915
if high_indices is None:
918-
high_indices = tf.range(ps.shape(x)[axis]) + 1
916+
high_indices = ps.range(ps.shape(x)[axis]) + 1
919917
else:
920918
high_indices = tf.convert_to_tensor(high_indices)
921919
if low_indices is None:
@@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
941939
bc_shape = indices_shape
942940

943941
bc_shape = ps.concat([[2], bc_shape], axis=0)
944-
indices = tf.stack([low_indices, high_indices], axis=0)
942+
indices = ps.stack([low_indices, high_indices], axis=0)
945943
indices = ps.reshape(indices, bc_shape)
946944
x = tf.expand_dims(x, axis=0)
947945
axis += 1

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -735,17 +735,26 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
735735
indices = rng.randint(shape[axis] + 1, size=indice_shape)
736736
indices = np.sort(indices, axis=0)
737737
low_indices, high_indices = indices[0], indices[1]
738+
739+
tf_low_indices = self._make_dynamic_shape(low_indices)
740+
tf_high_indices = self._make_dynamic_shape(high_indices)
741+
tf_x = self._make_dynamic_shape(x)
742+
743+
a = window_func(tf_x, low_indices=tf_low_indices,
744+
high_indices=tf_high_indices, axis=axis)
745+
738746
low_indices = self._maybe_expand_dims_to_make_broadcastable(
739747
low_indices, x.shape, axis)
740748
high_indices = self._maybe_expand_dims_to_make_broadcastable(
741749
high_indices, x.shape, axis)
742-
a = window_func(x, low_indices=low_indices,
743-
high_indices=high_indices, axis=axis)
744750
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
745751
axis=axis)
746752
b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros
747753
self.assertAllClose(a, b)
748754

755+
def _make_dynamic_shape(self, x):
756+
return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))
757+
749758
def check_windowed(self, func, numpy_func):
750759
check_fn = functools.partial(self.check_gaussian_windowed,
751760
window_func=func, np_func=numpy_func)

0 commit comments

Comments
 (0)