from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
self.assertAllClose(cache_values, value)
+ @test_util.enable_c_shapes
def testConv2DTransposeShapeInference(self):
# Test case for 8972
initializer = random_ops.truncated_normal(
f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
output = nn_ops.conv2d_transpose(
x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
- self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
+ self.assertEqual(output.get_shape().as_list(), [3, 10, 5, 5])
+
if __name__ == "__main__":
test.main()