From 0775f684c51b6b2f24d58c116cc2073d53659e3c Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 14 May 2018 14:58:17 -0700 Subject: [PATCH] Do shape validation in ScatterNd kernel, not just the shape inference function. Fixes #18648 PiperOrigin-RevId: 196572262 --- tensorflow/core/kernels/scatter_nd_op.cc | 47 +++++++++++++++++++++- .../python/kernel_tests/scatter_nd_ops_test.py | 12 ++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 0caa7bd..8ef6e77 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -62,14 +62,57 @@ class ScatterNdOp : public OpKernel { const Tensor& updates = c->input(1); const Tensor& shape_input = c->input(2); - OP_REQUIRES(c, shape_input.dims() == 1, - errors::InvalidArgument("Shape must be a vector")); + OP_REQUIRES(c, indices.shape().dims() >= 1, + errors::InvalidArgument( + "Indices shape must have rank at least one. Found:", + indices.shape().DebugString())); + OP_REQUIRES(c, updates.shape().dims() >= 1, + errors::InvalidArgument( + "Updates shape must have rank at least one. Found:", + updates.shape().DebugString())); auto vec = shape_input.flat(); TensorShape shape; OP_REQUIRES_OK(c, TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape)); + OP_REQUIRES( + c, + (shape.num_elements() > 0 || (indices.shape().num_elements() == 0 && + updates.shape().num_elements() == 0)), + errors::InvalidArgument( + "Indices and updates specified for empty output shape")); + + const int64 outer_dims = indices.shape().dims() - 1; + + for (int i = 0; i < outer_dims; ++i) { + OP_REQUIRES(c, indices.shape().dim_size(i) == updates.shape().dim_size(i), + errors::InvalidArgument( + "Outer dimensions of indices and update must match. " + "Indices shape: ", + indices.shape().DebugString(), + ", updates shape:", updates.shape().DebugString())); + } + + const int64 ix = indices.shape().dim_size(outer_dims); + OP_REQUIRES( + c, updates.shape().dims() - outer_dims == shape.dims() - ix, + errors::InvalidArgument("Inner dimensions of output shape must match " + "inner dimensions of updates shape. Output: ", + shape.DebugString(), + " updates: ", updates.shape().DebugString())); + for (int i = 0; i + outer_dims < updates.shape().dims(); ++i) { + OP_REQUIRES( + c, updates.shape().dim_size(i + outer_dims) == shape.dim_size(ix + i), + errors::InvalidArgument( + "The inner ", shape.dims() - ix, + " dimensions of output.shape=", shape.DebugString(), + " must match the inner ", updates.shape().dims() - outer_dims, + " dimensions of updates.shape=", updates.shape().DebugString())); + } + OP_REQUIRES(c, shape_input.dims() == 1, + errors::InvalidArgument("Shape must be a vector")); + Tensor out; OP_REQUIRES_OK( c, functor::DoScatterNd( diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index b7477a7..79fe927 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -23,8 +23,11 @@ import functools import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops @@ -364,6 +367,15 @@ class ScatterNdTest(test.TestCase): del input_ # input_ is not used in scatter_nd return array_ops.scatter_nd(indices, updates, shape) + @test_util.run_in_graph_and_eager_modes() + def testInvalidShape(self): + # TODO(apassos) figure out how to unify these errors + with self.assertRaises(errors.InvalidArgumentError + if context.executing_eagerly() else ValueError): + array_ops.scatter_nd(indices=[0], # this should be indices=[[0]] + updates=[0.0], + shape=[1]) + def testString(self): indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) -- 2.7.4