|
@@ -21,7 +21,6 @@
 # Dependency imports
 import numpy as np
 
-import tensorflow.compat.v1 as tf1
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
 
@@ -180,12 +179,12 @@ def __call__(self, *args, **kwargs):
 @test_util.test_all_tf_execution_regimes
 class ConvVariational(object):
 
-  def maybe_transpose_inputs(self, inputs):
+  def maybe_transpose_tensor(self, tensor):
     if self.data_format == 'channels_first':
-      order = channels_last_to_first(list(range(inputs.shape.rank)))
-      return tf.transpose(a=inputs, perm=order)
+      order = channels_last_to_first(list(range(tensor.shape.rank)))
+      return tf.transpose(a=tensor, perm=order)
     else:
-      return inputs
+      return tensor
 
   def _testKerasLayer(self, layer_class):  # pylint: disable=invalid-name
     def kernel_posterior_fn(dtype, shape, name, trainable, add_variable_fn):
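
The hunk above renames the helper from `maybe_transpose_inputs` to `maybe_transpose_tensor`, since the rewritten `_testLayerInSequential` further down applies it to expected outputs as well as inputs. A minimal standalone sketch of what the helper does, assuming `channels_last_to_first` (defined elsewhere in this test file, not shown in the diff) simply moves the trailing channels axis to position 1:

```python
import tensorflow.compat.v2 as tf

def channels_last_to_first(perm):
  # Hypothetical stand-in for the helper defined earlier in the test file:
  # keep the batch axis first, move the trailing channels axis to position 1.
  return perm[:1] + perm[-1:] + perm[1:-1]

def maybe_transpose_tensor(tensor, data_format='channels_first'):
  # Same logic as the renamed test helper: NDHWC-style -> NCDHW-style
  # whenever the layer under test runs with data_format='channels_first'.
  if data_format == 'channels_first':
    order = channels_last_to_first(list(range(tensor.shape.rank)))
    return tf.transpose(a=tensor, perm=order)
  return tensor

x = tf.random.uniform([2, 3, 3, 3, 1])    # batch, depth, height, width, channels
print(maybe_transpose_tensor(x).shape)    # (2, 1, 3, 3, 3)
```
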
@@ -241,7 +240,7 @@ def _testKLPenaltyKernel(self, layer_class):  # pylint: disable=invalid-name
     elif layer_class in (tfp.layers.Convolution3DReparameterization,
                          tfp.layers.Convolution3DFlipout):
       inputs = tf.random.uniform([2, 3, 3, 3, 1], seed=1)
-    inputs = self.maybe_transpose_inputs(inputs)
+    inputs = self.maybe_transpose_tensor(inputs)
 
     # No keys.
     input_dependent_losses = layer.get_losses_for(inputs=None)
@@ -273,7 +272,7 @@ def _testKLPenaltyBoth(self, layer_class):  # pylint: disable=invalid-name
     elif layer_class in (tfp.layers.Convolution3DReparameterization,
                          tfp.layers.Convolution3DFlipout):
       inputs = tf.random.uniform([2, 3, 3, 3, 1], seed=1)
-    inputs = self.maybe_transpose_inputs(inputs)
+    inputs = self.maybe_transpose_tensor(inputs)
 
     # No keys.
     input_dependent_losses = layer.get_losses_for(inputs=None)
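
The two `_testKLPenalty*` hunks only swap in the renamed helper; the surrounding context checks that calling a variational convolution layer registers its KL penalty with Keras. A rough eager-mode illustration for the 3-D case (shapes taken from the diff; the expectation of exactly one loss term is an assumption based on the layer's default of a kernel divergence only):

```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

layer = tfp.layers.Convolution3DReparameterization(filters=2, kernel_size=3)
inputs = tf.random.uniform([2, 3, 3, 3, 1], seed=1)

_ = layer(inputs)         # calling the layer adds its divergence loss
print(len(layer.losses))  # expected: 1 (kernel KL term; no bias prior by default)
```
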
@@ -307,7 +306,7 @@ def _testConvSetUp(self, layer_class, batch_size, depth=None,
       inputs = tf.random.uniform([batch_size, depth, height, width, channels],
                                  seed=seed())
       kernel_size = (2, 2, 2)
-    inputs = self.maybe_transpose_inputs(inputs)
+    inputs = self.maybe_transpose_tensor(inputs)
 
     kernel_shape = kernel_size + (channels, filters)
     kernel_posterior = MockDistribution(
@@ -547,7 +546,7 @@ def _testRandomConvFlipout(self, layer_class):  # pylint: disable=invalid-name
       inputs = tf.random.uniform([batch_size, depth, height, width, channels],
                                  seed=seed())
       kernel_size = (2, 2, 2)
-    inputs = self.maybe_transpose_inputs(inputs)
+    inputs = self.maybe_transpose_tensor(inputs)
 
     kernel_shape = kernel_size + (channels, filters)
     bias_size = (filters,)
@@ -597,26 +596,30 @@ def _testRandomConvFlipout(self, layer_class):  # pylint: disable=invalid-name
                     np.prod(outputs_one_.shape))
 
   def _testLayerInSequential(self, layer_class):  # pylint: disable=invalid-name
-    with self.cached_session() as sess:
-      if layer_class in (tfp.layers.Convolution1DReparameterization,
-                         tfp.layers.Convolution1DFlipout):
-        inputs = tf.random.uniform([2, 3, 1])
-      elif layer_class in (tfp.layers.Convolution2DReparameterization,
-                           tfp.layers.Convolution2DFlipout):
-        inputs = tf.random.uniform([2, 3, 3, 1])
-      elif layer_class in (tfp.layers.Convolution3DReparameterization,
-                           tfp.layers.Convolution3DFlipout):
-        inputs = tf.random.uniform([2, 3, 3, 3, 1])
-      inputs = self.maybe_transpose_inputs(inputs)
+    if layer_class in (tfp.layers.Convolution1DReparameterization,
+                       tfp.layers.Convolution1DFlipout):
+      inputs = tf.random.uniform([2, 3, 1])
+      outputs = tf.random.uniform([2, 1, 2])
+    elif layer_class in (tfp.layers.Convolution2DReparameterization,
+                         tfp.layers.Convolution2DFlipout):
+      inputs = tf.random.uniform([2, 3, 3, 1])
+      outputs = tf.random.uniform([2, 1, 1, 2])
+    elif layer_class in (tfp.layers.Convolution3DReparameterization,
+                         tfp.layers.Convolution3DFlipout):
+      inputs = tf.random.uniform([2, 3, 3, 3, 1])
+      outputs = tf.random.uniform([2, 1, 1, 1, 2])
+    inputs = self.maybe_transpose_tensor(inputs)
+    outputs = self.maybe_transpose_tensor(outputs)
+
+    net = tf.keras.Sequential([
+        layer_class(filters=2, kernel_size=3, data_format=self.data_format),
+        layer_class(filters=2, kernel_size=1, data_format=self.data_format)])
 
-      net = tf.keras.Sequential([
-          layer_class(filters=2, kernel_size=3, data_format=self.data_format),
-          layer_class(filters=2, kernel_size=1, data_format=self.data_format)])
-      output = net(inputs)
+    net.compile(loss='mse', optimizer='adam')
+    net.fit(inputs, outputs, batch_size=2, epochs=3, steps_per_epoch=2)
 
-      # Verify that the network runs without errors
-      sess.run(tf1.global_variables_initializer())
-      sess.run(output)
+    batch_output = self.evaluate(net(inputs))
+    self.assertAllEqual(outputs.shape, batch_output.shape)
 
   def testKerasLayerConvolution1DReparameterization(self):
     self._testKerasLayer(tfp.layers.Convolution1DReparameterization)
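
Stripped of the test harness, the rewritten `_testLayerInSequential` amounts to: build a two-layer variational `Sequential`, train it briefly with `compile`/`fit`, then check that a forward pass produces the expected shape. A rough standalone analogue for the 1-D case, assuming the default `channels_last` data format so no transpose is needed:

```python
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

inputs = tf.random.uniform([2, 3, 1])    # batch=2, width=3, channels=1
outputs = tf.random.uniform([2, 1, 2])   # kernel_size=3 shrinks width to 1; filters=2

net = tf.keras.Sequential([
    tfp.layers.Convolution1DReparameterization(filters=2, kernel_size=3),
    tfp.layers.Convolution1DReparameterization(filters=2, kernel_size=1)])

net.compile(loss='mse', optimizer='adam')
net.fit(inputs, outputs, batch_size=2, epochs=3, steps_per_epoch=2)

batch_output = net(inputs)
assert batch_output.shape == outputs.shape    # (2, 1, 2)
```
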
|