Skip to content

Commit 4454b0c

Browse files
authored
Merge pull request #907 from emilyfertig/r0.10
For R0.10-rc1, with cherry-pick to fix convolutional layers.
2 parents db388e6 + 967443a commit 4454b0c

File tree

3 files changed

+41
-44
lines changed

3 files changed

+41
-44
lines changed

tensorflow_probability/python/layers/conv_variational.py

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,6 @@ def build(self, input_shape):
197197
self.kernel_prior = self.kernel_prior_fn(
198198
dtype, kernel_shape, 'kernel_prior',
199199
self.trainable, self.add_variable)
200-
self._built_kernel_divergence = False
201200

202201
if self.bias_posterior_fn is None:
203202
self.bias_posterior = None
@@ -212,7 +211,6 @@ def build(self, input_shape):
212211
self.bias_prior = self.bias_prior_fn(
213212
dtype, (self.filters,), 'bias_prior',
214213
self.trainable, self.add_variable)
215-
self._built_bias_divergence = False
216214

217215
self.input_spec = tf.keras.layers.InputSpec(
218216
ndim=self.rank + 2, axes={channel_axis: input_dim})
@@ -234,20 +232,16 @@ def call(self, inputs):
234232
outputs = self._apply_variational_bias(outputs)
235233
if self.activation is not None:
236234
outputs = self.activation(outputs)
237-
if not self._built_kernel_divergence:
238-
self._apply_divergence(self.kernel_divergence_fn,
239-
self.kernel_posterior,
240-
self.kernel_prior,
241-
self.kernel_posterior_tensor,
242-
name='divergence_kernel')
243-
self._built_kernel_divergence = True
244-
if not self._built_bias_divergence:
245-
self._apply_divergence(self.bias_divergence_fn,
246-
self.bias_posterior,
247-
self.bias_prior,
248-
self.bias_posterior_tensor,
249-
name='divergence_bias')
250-
self._built_bias_divergence = True
235+
self._apply_divergence(self.kernel_divergence_fn,
236+
self.kernel_posterior,
237+
self.kernel_prior,
238+
self.kernel_posterior_tensor,
239+
name='divergence_kernel')
240+
self._apply_divergence(self.bias_divergence_fn,
241+
self.bias_posterior,
242+
self.bias_prior,
243+
self.bias_posterior_tensor,
244+
name='divergence_bias')
251245
return outputs
252246

253247
def compute_output_shape(self, input_shape):

tensorflow_probability/python/layers/conv_variational_test.py

Lines changed: 30 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
# Dependency imports
2222
import numpy as np
2323

24-
import tensorflow.compat.v1 as tf1
2524
import tensorflow.compat.v2 as tf
2625
import tensorflow_probability as tfp
2726

@@ -180,12 +179,12 @@ def __call__(self, *args, **kwargs):
180179
@test_util.test_all_tf_execution_regimes
181180
class ConvVariational(object):
182181

183-
def maybe_transpose_inputs(self, inputs):
182+
def maybe_transpose_tensor(self, tensor):
184183
if self.data_format == 'channels_first':
185-
order = channels_last_to_first(list(range(inputs.shape.rank)))
186-
return tf.transpose(a=inputs, perm=order)
184+
order = channels_last_to_first(list(range(tensor.shape.rank)))
185+
return tf.transpose(a=tensor, perm=order)
187186
else:
188-
return inputs
187+
return tensor
189188

190189
def _testKerasLayer(self, layer_class): # pylint: disable=invalid-name
191190
def kernel_posterior_fn(dtype, shape, name, trainable, add_variable_fn):
@@ -241,7 +240,7 @@ def _testKLPenaltyKernel(self, layer_class): # pylint: disable=invalid-name
241240
elif layer_class in (tfp.layers.Convolution3DReparameterization,
242241
tfp.layers.Convolution3DFlipout):
243242
inputs = tf.random.uniform([2, 3, 3, 3, 1], seed=1)
244-
inputs = self.maybe_transpose_inputs(inputs)
243+
inputs = self.maybe_transpose_tensor(inputs)
245244

246245
# No keys.
247246
input_dependent_losses = layer.get_losses_for(inputs=None)
@@ -273,7 +272,7 @@ def _testKLPenaltyBoth(self, layer_class): # pylint: disable=invalid-name
273272
elif layer_class in (tfp.layers.Convolution3DReparameterization,
274273
tfp.layers.Convolution3DFlipout):
275274
inputs = tf.random.uniform([2, 3, 3, 3, 1], seed=1)
276-
inputs = self.maybe_transpose_inputs(inputs)
275+
inputs = self.maybe_transpose_tensor(inputs)
277276

278277
# No keys.
279278
input_dependent_losses = layer.get_losses_for(inputs=None)
@@ -307,7 +306,7 @@ def _testConvSetUp(self, layer_class, batch_size, depth=None,
307306
inputs = tf.random.uniform([batch_size, depth, height, width, channels],
308307
seed=seed())
309308
kernel_size = (2, 2, 2)
310-
inputs = self.maybe_transpose_inputs(inputs)
309+
inputs = self.maybe_transpose_tensor(inputs)
311310

312311
kernel_shape = kernel_size + (channels, filters)
313312
kernel_posterior = MockDistribution(
@@ -547,7 +546,7 @@ def _testRandomConvFlipout(self, layer_class): # pylint: disable=invalid-name
547546
inputs = tf.random.uniform([batch_size, depth, height, width, channels],
548547
seed=seed())
549548
kernel_size = (2, 2, 2)
550-
inputs = self.maybe_transpose_inputs(inputs)
549+
inputs = self.maybe_transpose_tensor(inputs)
551550

552551
kernel_shape = kernel_size + (channels, filters)
553552
bias_size = (filters,)
@@ -597,26 +596,30 @@ def _testRandomConvFlipout(self, layer_class): # pylint: disable=invalid-name
597596
np.prod(outputs_one_.shape))
598597

599598
def _testLayerInSequential(self, layer_class): # pylint: disable=invalid-name
600-
with self.cached_session() as sess:
601-
if layer_class in (tfp.layers.Convolution1DReparameterization,
602-
tfp.layers.Convolution1DFlipout):
603-
inputs = tf.random.uniform([2, 3, 1])
604-
elif layer_class in (tfp.layers.Convolution2DReparameterization,
605-
tfp.layers.Convolution2DFlipout):
606-
inputs = tf.random.uniform([2, 3, 3, 1])
607-
elif layer_class in (tfp.layers.Convolution3DReparameterization,
608-
tfp.layers.Convolution3DFlipout):
609-
inputs = tf.random.uniform([2, 3, 3, 3, 1])
610-
inputs = self.maybe_transpose_inputs(inputs)
599+
if layer_class in (tfp.layers.Convolution1DReparameterization,
600+
tfp.layers.Convolution1DFlipout):
601+
inputs = tf.random.uniform([2, 3, 1])
602+
outputs = tf.random.uniform([2, 1, 2])
603+
elif layer_class in (tfp.layers.Convolution2DReparameterization,
604+
tfp.layers.Convolution2DFlipout):
605+
inputs = tf.random.uniform([2, 3, 3, 1])
606+
outputs = tf.random.uniform([2, 1, 1, 2])
607+
elif layer_class in (tfp.layers.Convolution3DReparameterization,
608+
tfp.layers.Convolution3DFlipout):
609+
inputs = tf.random.uniform([2, 3, 3, 3, 1])
610+
outputs = tf.random.uniform([2, 1, 1, 1, 2])
611+
inputs = self.maybe_transpose_tensor(inputs)
612+
outputs = self.maybe_transpose_tensor(outputs)
613+
614+
net = tf.keras.Sequential([
615+
layer_class(filters=2, kernel_size=3, data_format=self.data_format),
616+
layer_class(filters=2, kernel_size=1, data_format=self.data_format)])
611617

612-
net = tf.keras.Sequential([
613-
layer_class(filters=2, kernel_size=3, data_format=self.data_format),
614-
layer_class(filters=2, kernel_size=1, data_format=self.data_format)])
615-
output = net(inputs)
618+
net.compile(loss='mse', optimizer='adam')
619+
net.fit(inputs, outputs, batch_size=2, epochs=3, steps_per_epoch=2)
616620

617-
# Verify that the network runs without errors
618-
sess.run(tf1.global_variables_initializer())
619-
sess.run(output)
621+
batch_output = self.evaluate(net(inputs))
622+
self.assertAllEqual(outputs.shape, batch_output.shape)
620623

621624
def testKerasLayerConvolution1DReparameterization(self):
622625
self._testKerasLayer(tfp.layers.Convolution1DReparameterization)

tensorflow_probability/python/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
# stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a
2525
# release branch, the current version is by default assumed to be a
2626
# 'development' version, labeled 'dev'.
27-
_VERSION_SUFFIX = 'rc0'
27+
_VERSION_SUFFIX = 'rc1'
2828

2929
# Example, '0.4.0-dev'
3030
__version__ = '.'.join([

0 commit comments

Comments
 (0)