Do shape validation in ScatterNd kernel, not just the shape inference function.
authorAlexandre Passos <apassos@google.com>
Mon, 14 May 2018 21:58:17 +0000 (14:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 22:01:28 +0000 (15:01 -0700)
Fixes #18648

PiperOrigin-RevId: 196572262

tensorflow/core/kernels/scatter_nd_op.cc
tensorflow/python/kernel_tests/scatter_nd_ops_test.py

index 0caa7bd..8ef6e77 100644 (file)
@@ -62,14 +62,57 @@ class ScatterNdOp : public OpKernel {
     const Tensor& updates = c->input(1);
     const Tensor& shape_input = c->input(2);
 
-    OP_REQUIRES(c, shape_input.dims() == 1,
-                errors::InvalidArgument("Shape must be a vector"));
+    OP_REQUIRES(c, indices.shape().dims() >= 1,
+                errors::InvalidArgument(
+                    "Indices shape must have rank at least one. Found:",
+                    indices.shape().DebugString()));
+    OP_REQUIRES(c, updates.shape().dims() >= 1,
+                errors::InvalidArgument(
+                    "Updates shape must have rank at least one. Found:",
+                    updates.shape().DebugString()));
 
     auto vec = shape_input.flat<Index>();
     TensorShape shape;
     OP_REQUIRES_OK(c,
                    TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));
 
+    OP_REQUIRES(
+        c,
+        (shape.num_elements() > 0 || (indices.shape().num_elements() == 0 &&
+                                      updates.shape().num_elements() == 0)),
+        errors::InvalidArgument(
+            "Indices and updates specified for empty output shape"));
+
+    const int64 outer_dims = indices.shape().dims() - 1;
+
+    for (int i = 0; i < outer_dims; ++i) {
+      OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i),
+                  errors::InvalidArgument(
+                      "Outer dimensions of indices and update must match. "
+                      "Indices shape: ",
+                      indices.shape().DebugString(),
+                      ", updates shape:", updates.shape().DebugString()));
+    }
+
+    const int64 ix = indices.shape().dim_size(outer_dims);
+    OP_REQUIRES(
+        c, updates.shape().dims() - outer_dims == shape.dims() - ix,
+        errors::InvalidArgument("Inner dimensions of output shape must match "
+                                "inner dimensions of updates shape. Output: ",
+                                shape.DebugString(),
+                                " updates: ", updates.shape().DebugString()));
+    for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) {
+      OP_REQUIRES(
+          c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i),
+          errors::InvalidArgument(
+              "The inner ", shape.dims() - ix,
+              " dimensions of output.shape=", shape.DebugString(),
+              " must match the inner ", updates.shape().dims() - outer_dims,
+              " dimensions of updates.shape=", updates.shape().DebugString()));
+    }
+    OP_REQUIRES(c, shape_input.dims() == 1,
+                errors::InvalidArgument("Shape must be a vector"));
+
     Tensor out;
     OP_REQUIRES_OK(
         c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
index b7477a7..79fe927 100644 (file)
@@ -23,8 +23,11 @@ import functools
 import numpy as np
 
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
@@ -364,6 +367,15 @@ class ScatterNdTest(test.TestCase):
     del input_  # input_ is not used in scatter_nd
     return array_ops.scatter_nd(indices, updates, shape)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testInvalidShape(self):
+    # TODO(apassos) figure out how to unify these errors
+    with self.assertRaises(errors.InvalidArgumentError
+                           if context.executing_eagerly() else ValueError):
+      array_ops.scatter_nd(indices=[0],  # this should be indices=[[0]]
+                           updates=[0.0],
+                           shape=[1])
+
   def testString(self):
     indices = constant_op.constant([[4], [3], [1], [7]],
                                    dtype=dtypes.int32)