Make conv2d_tranpose_test.py work with C API shapes enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 8 May 2018 00:24:28 +0000 (17:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 01:30:22 +0000 (18:30 -0700)
The C API provides more accurate shape information in many cases.

PiperOrigin-RevId: 195749030

tensorflow/python/kernel_tests/conv2d_transpose_test.py

index b692d3d..27804be 100644 (file)
@@ -23,6 +23,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_ops
@@ -292,6 +293,7 @@ class Conv2DTransposeTest(test.TestCase):
 
         self.assertAllClose(cache_values, value)
 
+  @test_util.enable_c_shapes
   def testConv2DTransposeShapeInference(self):
     # Test case for 8972
     initializer = random_ops.truncated_normal(
@@ -301,7 +303,8 @@ class Conv2DTransposeTest(test.TestCase):
     f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
     output = nn_ops.conv2d_transpose(
         x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
-    self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
+    self.assertEqual(output.get_shape().as_list(), [3, 10, 5, 5])
+
 
 if __name__ == "__main__":
   test.main()