Add int64 support for output_shape of tf.nn.conv3d_transpose (#19248)
authorYong Tang <yong.tang.github@outlook.com>
Mon, 14 May 2018 18:35:40 +0000 (11:35 -0700)
committerRasmus Munk Larsen <rmlarsen@google.com>
Mon, 14 May 2018 18:35:40 +0000 (11:35 -0700)
* Add int64 support for output_shape of tf.nn.conv3d_transpose

This fix tries to address the issue raised in 18887 where
the output_shape of tf.nn.conv3d_transpose only support
int32 data types. The support of int64 has been added in this PR
with test case covered.

This fix fixes 18887.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update op registration for Conv3DBackpropInputV2

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test case for int64 support of output_shape with tf.nn.conv3d_transpose

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Update test case with both int32 and int64

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Fix pylint issue

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/core/kernels/conv_grad_ops_3d.cc
tensorflow/core/ops/nn_ops.cc
tensorflow/python/kernel_tests/conv3d_transpose_test.py

index 9edc6d4..980b106 100644 (file)
@@ -195,8 +195,8 @@ class Conv3DBackpropInputOp : public OpKernel {
     TensorShape input_shape;
     if (takes_shape_) {
       const Tensor& input_sizes = context->input(0);
-      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
-                                  input_sizes.vec<int32>(), &input_shape));
+      // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
     } else {
       input_shape = context->input(0).shape();
     }
index bb46daf..fc60e80 100644 (file)
@@ -547,7 +547,7 @@ REGISTER_OP("Conv3DBackpropFilter")
     });
 
 REGISTER_OP("Conv3DBackpropInputV2")
-    .Input("input_sizes: int32")
+    .Input("input_sizes: Tshape")
     .Input("filter: T")
     .Input("out_backprop: T")
     .Output("output: T")
@@ -556,6 +556,7 @@ REGISTER_OP("Conv3DBackpropInputV2")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnet3dDataFormatAttrString())
     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+    .Attr("Tshape: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle s;
       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
index 8973a45..289ae29 100644 (file)
@@ -131,6 +131,23 @@ class Conv3DTransposeTest(test.TestCase):
     nn_ops.conv3d_transpose(
         x_value, f_value, y_shape, strides, data_format='NCDHW')
 
+  def testConv3DTransposeOutputShapeType(self):
+    # Test case for GitHub issue 18887
+    for dtype in [dtypes.int32, dtypes.int64]:
+      with self.test_session():
+        x_shape = [2, 5, 6, 4, 3]
+        y_shape = [2, 5, 6, 4, 2]
+        f_shape = [3, 3, 3, 2, 3]
+        strides = [1, 1, 1, 1, 1]
+        x_value = constant_op.constant(
+            1.0, shape=x_shape, name="x", dtype=dtypes.float32)
+        f_value = constant_op.constant(
+            1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
+        output = nn_ops.conv3d_transpose(
+            x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
+            strides=strides, padding="SAME")
+        output.eval()
+
   def testConv3DTransposeValid(self):
     with self.test_session():
       strides = [1, 2, 2, 2, 1]