22
22
import tensorflow .compat .v2 as tf
23
23
24
24
if JAX_MODE or NUMPY_MODE :
25
- tnp = np
25
+ numpy_ops = np
26
26
else :
27
- import tensorflow .experimental . numpy as tnp
27
+ from tensorflow .python . ops import numpy_ops
28
28
29
29
from tensorflow_probability .python .internal import assert_util
30
30
from tensorflow_probability .python .internal import distribution_util
@@ -739,13 +739,12 @@ def windowed_variance(
739
739
Then each element of `low_indices` and `high_indices` must be
740
740
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
741
741
742
- The shape of indices must be broadcastable with `x` unless the rank is lower
743
- than the rank of `x`, then the shape is expanded with extra inner dimensions
744
- to match the rank of `x`.
742
+ The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
745
743
746
- In the special case where the rank of indices is one, i.e when
747
- `rank(Bi) = rank(F) = 0`, the indices are reshaped to
748
- `[1] * rank(Bx) + [M] + [1] * rank(E)`.
744
+ If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
745
+ with extra inner dimensions to match the rank of `x`. In the special
746
+ case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
747
+ the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
749
748
750
749
The default windows are
751
750
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -801,10 +800,10 @@ def windowed_variance(
801
800
def index_for_cumulative (indices ):
802
801
return tf .maximum (indices - 1 , 0 )
803
802
cum_sums = tf .cumsum (x , axis = axis )
804
- sums = tnp .take_along_axis (
803
+ sums = numpy_ops .take_along_axis (
805
804
cum_sums , index_for_cumulative (indices ), axis = axis )
806
805
cum_variances = cumulative_variance (x , sample_axis = axis )
807
- variances = tnp .take_along_axis (
806
+ variances = numpy_ops .take_along_axis (
808
807
cum_variances , index_for_cumulative (indices ), axis = axis )
809
808
810
809
# This formula is the binary accurate variance merge from [1],
@@ -860,13 +859,12 @@ def windowed_mean(
860
859
Then each element of `low_indices` and `high_indices` must be
861
860
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
862
861
863
- The shape of indices must be broadcastable with `x` unless the rank is lower
864
- than the rank of `x`, then the shape is expanded with extra inner dimensions
865
- to match the rank of `x`.
862
+ The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
866
863
867
- In the special case where the rank of indices is one, i.e when
868
- `rank(Bi) = rank(F) = 0`, the indices are reshaped to
869
- `[1] * rank(Bx) + [M] + [1] * rank(E)`.
864
+ If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
865
+ with extra inner dimensions to match the rank of `x`. In the special
866
+ case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
867
+ the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
870
868
871
869
The default windows are
872
870
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
@@ -906,7 +904,7 @@ def windowed_mean(
906
904
paddings = ps .reshape (ps .one_hot (2 * axis , depth = 2 * rank , dtype = tf .int32 ),
907
905
(rank , 2 ))
908
906
cum_sums = ps .pad (raw_cumsum , paddings )
909
- sums = tnp .take_along_axis (cum_sums , indices ,
907
+ sums = numpy_ops .take_along_axis (cum_sums , indices ,
910
908
axis = axis )
911
909
counts = ps .cast (indices [1 ] - indices [0 ], dtype = sums .dtype )
912
910
return tf .math .divide_no_nan (sums [1 ] - sums [0 ], counts )
@@ -915,7 +913,7 @@ def windowed_mean(
915
913
def _prepare_window_args (x , low_indices = None , high_indices = None , axis = 0 ):
916
914
"""Common argument defaulting logic for windowed statistics."""
917
915
if high_indices is None :
918
- high_indices = tf .range (ps .shape (x )[axis ]) + 1
916
+ high_indices = ps .range (ps .shape (x )[axis ]) + 1
919
917
else :
920
918
high_indices = tf .convert_to_tensor (high_indices )
921
919
if low_indices is None :
@@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
941
939
bc_shape = indices_shape
942
940
943
941
bc_shape = ps .concat ([[2 ], bc_shape ], axis = 0 )
944
- indices = tf .stack ([low_indices , high_indices ], axis = 0 )
942
+ indices = ps .stack ([low_indices , high_indices ], axis = 0 )
945
943
indices = ps .reshape (indices , bc_shape )
946
944
x = tf .expand_dims (x , axis = 0 )
947
945
axis += 1
0 commit comments