Fix of issue #13164 (Merges #13382) (#16368)
authorRobin Richtsfeld <robin.richtsfeld@gmail.com>
Fri, 25 May 2018 23:38:33 +0000 (01:38 +0200)
committerRasmus Munk Larsen <rmlarsen@google.com>
Fri, 25 May 2018 23:38:33 +0000 (16:38 -0700)
* tf.gather int64 GPU, tf.gather_nd int32/int64 GPU, tf.scatter_nd int32 GPU

* Fix tf.gather test

12 files changed:
tensorflow/core/kernels/dense_update_functor_gpu.cu.cc
tensorflow/core/kernels/gather_functor.cc
tensorflow/core/kernels/gather_functor_gpu.cu.cc
tensorflow/core/kernels/gather_nd_op.cc
tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
tensorflow/core/kernels/gather_op.cc
tensorflow/core/kernels/scatter_nd_op.cc
tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
tensorflow/python/kernel_tests/gather_nd_op_test.py
tensorflow/python/kernel_tests/gather_op_test.py
tensorflow/python/kernel_tests/scatter_nd_ops_test.py
tensorflow/python/kernel_tests/scatter_ops_test.py

index 9a3b230..17a85d9 100644 (file)
@@ -57,6 +57,7 @@ struct DenseUpdate<GPUDevice, T, SUB> {
   template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
   template struct functor::DenseUpdate<GPUDevice, T, SUB>;
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+TF_CALL_int32(DEFINE_GPU_KERNELS);
 TF_CALL_int64(DEFINE_GPU_KERNELS);
 #undef DEFINE_GPU_KERNELS
 
index e6fefe6..5cd8e04 100644 (file)
@@ -37,6 +37,7 @@ namespace functor {
   DECLARE_GPU_SPECS_INDEX(T, int32); \
   DECLARE_GPU_SPECS_INDEX(T, int64)
 
+TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_complex64(DECLARE_GPU_SPECS);
 TF_CALL_complex128(DECLARE_GPU_SPECS);
index 39b6924..4563fc6 100644 (file)
@@ -31,6 +31,7 @@ typedef Eigen::GpuDevice GPUDevice;
   DEFINE_GPU_SPECS_INDEX(T, int32); \
   DEFINE_GPU_SPECS_INDEX(T, int64);
 
+TF_CALL_int64(DEFINE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
 TF_CALL_complex64(DEFINE_GPU_SPECS);
 TF_CALL_complex128(DEFINE_GPU_SPECS);
index 7e5a9e1..4e53291 100644 (file)
@@ -228,6 +228,8 @@ namespace functor {
   DECLARE_GPU_SPECS_INDEX(T, int32); \
   DECLARE_GPU_SPECS_INDEX(T, int64)
 
+TF_CALL_int32(DECLARE_GPU_SPECS);
+TF_CALL_int64(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_complex64(DECLARE_GPU_SPECS);
 TF_CALL_complex128(DECLARE_GPU_SPECS);
@@ -239,6 +241,8 @@ TF_CALL_complex128(DECLARE_GPU_SPECS);
 // Registration of the GPU implementations.
 #define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
 
+TF_CALL_int32(REGISTER_GATHER_ND_GPU);
+TF_CALL_int64(REGISTER_GATHER_ND_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
 TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
 TF_CALL_complex128(REGISTER_GATHER_ND_GPU);
index b03efc6..da8d2e9 100644 (file)
@@ -119,6 +119,8 @@ struct GatherNdSlice<GPUDevice, T, Index, IXDIM> {
   DEFINE_GPU_SPECS_INDEX(T, int32); \
   DEFINE_GPU_SPECS_INDEX(T, int64);
 
+TF_CALL_int32(DEFINE_GPU_SPECS);
+TF_CALL_int64(DEFINE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
 TF_CALL_complex64(DEFINE_GPU_SPECS);
 TF_CALL_complex128(DEFINE_GPU_SPECS);
index ef332eb..094504d 100644 (file)
@@ -153,6 +153,7 @@ TF_CALL_uint64(REGISTER_GATHER_CPU);
 // Registration of the GPU implementations.
 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
 
+TF_CALL_int64(REGISTER_GATHER_GPU);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
 TF_CALL_complex64(REGISTER_GATHER_GPU);
 TF_CALL_complex128(REGISTER_GATHER_GPU);
index 8ef6e77..ff38026 100644 (file)
@@ -300,6 +300,7 @@ TF_CALL_string(REGISTER_SCATTER_ND_CPU);
   REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
   REGISTER_SCATTER_ND_GPU(type);
 
+TF_CALL_int32(REGISTER_SCATTER_ND_ALL_GPU);
 // TODO(b/66916790): Support half types in ScatterNd.
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ALL_GPU);
 TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
@@ -314,6 +315,8 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
 #define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
   REGISTER_SCATTER_ND_UPDATE(type, SYCL);
 
+TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
+TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
 #undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
@@ -584,6 +587,7 @@ namespace functor {
   DECLARE_GPU_SPECS_INDEX(T, int32); \
   DECLARE_GPU_SPECS_INDEX(T, int64)
 
+TF_CALL_int32(DECLARE_GPU_SPECS);
 // TODO(b/66916790): Support half types in ScatterNd.
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_complex64(DECLARE_GPU_SPECS);
index a3c21ed..08b657f 100644 (file)
@@ -170,6 +170,7 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
   DECLARE_GPU_SPECS_INDEX(T, int32); \
   DECLARE_GPU_SPECS_INDEX(T, int64)
 
+TF_CALL_int32(DECLARE_GPU_SPECS);
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
 TF_CALL_complex64(DECLARE_GPU_SPECS);
 TF_CALL_complex128(DECLARE_GPU_SPECS);
index 91ebe8d..58e2a8a 100644 (file)
@@ -197,7 +197,21 @@ class GatherNdTest(test.TestCase):
     self.assertEqual(None, shape.ndims)
     self.assertEqual(None, shape[0].value)
 
-  def testBadIndices(self):
+  def testBadIndicesCPU(self):
+    with self.test_session(use_gpu=False):
+      params = [0, 1, 2]
+      indices = [[[0], [7]]]  # Make this one higher rank
+      gather_nd = array_ops.gather_nd(params, indices)
+      with self.assertRaisesOpError(
+          r"flat indices\[1, :\] = \[7\] does not index into param "
+          r"\(shape: \[3\]\)"):
+        gather_nd.eval()
+
+  def _disabledTestBadIndicesGPU(self):
+    # TODO disabled due to different behavior on GPU and CPU
+    # On GPU the bad indices do not raise error but fetch 0 values
+    if not test.is_gpu_available():
+      return
     with self.test_session(use_gpu=True):
       params = [0, 1, 2]
       indices = [[[0], [7]]]  # Make this one higher rank
@@ -207,7 +221,21 @@ class GatherNdTest(test.TestCase):
           r"\(shape: \[3\]\)"):
         gather_nd.eval()
 
-  def testBadIndicesWithSlices(self):
+  def testBadIndicesWithSlicesCPU(self):
+    with self.test_session(use_gpu=False):
+      params = [[0, 1, 2]]
+      indices = [[[0], [0], [1]]]  # Make this one higher rank
+      gather_nd = array_ops.gather_nd(params, indices)
+      with self.assertRaisesOpError(
+          r"flat indices\[2, :\] = \[1\] does not index into param "
+          r"\(shape: \[1,3\]\)"):
+        gather_nd.eval()
+
+  def _disabledTestBadIndicesWithSlicesGPU(self):
+    # TODO disabled due to different behavior on GPU and CPU
+    # On GPU the bad indices do not raise error but fetch 0 values
+    if not test.is_gpu_available():
+      return
     with self.test_session(use_gpu=True):
       params = [[0, 1, 2]]
       indices = [[[0], [0], [1]]]  # Make this one higher rank
index a2fcd75..033fa95 100644 (file)
@@ -27,7 +27,8 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.platform import test
 
-_TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128)
+_TEST_TYPES = (dtypes.int64, dtypes.float32,
+               dtypes.complex64, dtypes.complex128)
 
 
 class GatherTest(test.TestCase):
@@ -122,6 +123,9 @@ class GatherTest(test.TestCase):
                 gather, [tf_params, tf_indices, tf_axis], gather_grad)
             self.assertEqual(indices_grad, None)
             self.assertEqual(axis_grad, None)
+            if dtype.is_integer:
+              self.assertEqual(params_grad, None)
+              continue
             # For axis 0, we are able to create an efficient IndexedSlices for
             # the gradient.
             if axis == 0:
@@ -177,7 +181,19 @@ class GatherTest(test.TestCase):
     gather_t = array_ops.gather(params, indices, axis=axis)
     self.assertEqual(None, gather_t.shape)
 
-  def testBadIndices(self):
+  def testBadIndicesCPU(self):
+    with self.test_session(use_gpu=False):
+      params = [[0, 1, 2], [3, 4, 5]]
+      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
+        array_ops.gather(params, [[7]], axis=0).eval()
+      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
+        array_ops.gather(params, [[7]], axis=1).eval()
+
+  def _disabledTestBadIndicesGPU(self):
+    # TODO disabled due to different behavior on GPU and CPU
+    # On GPU the bad indices do not raise error but fetch 0 values
+    if not test.is_gpu_available():
+      return
     with self.test_session(use_gpu=True):
       params = [[0, 1, 2], [3, 4, 5]]
       with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
index 79fe927..faa4b49 100644 (file)
@@ -144,7 +144,9 @@ class StatefulScatterNdTest(test.TestCase):
         self.assertAllClose(new, ref_var.eval())
 
   def _VariableRankTests(self, np_scatter, tf_scatter):
-    for vtype in (np.float32, np.float64, np.complex64, np.complex128):
+    for vtype in (np.int32,
+                  np.float32, np.float64,
+                  np.complex64, np.complex128):
       for itype in (np.int32, np.int64):
         self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
 
@@ -221,7 +223,7 @@ class StatefulScatterNdTest(test.TestCase):
   #   self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div)
 
   def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
-    for vtype in (np.float32, np.float64):
+    for vtype in (np.int32, np.float32, np.float64):
       for itype in (np.int32, np.int64):
         self._VariableRankTest(
             np_scatter, tf_scatter, vtype, itype, repeat_indices=True)
index c70a4ff..1a0fa74 100644 (file)
@@ -159,7 +159,13 @@ class ScatterTest(test.TestCase):
 
           # Clips small values to avoid division by zero.
           def clip_small_values(x):
-            return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x
+            threshold = 1e-4
+            sign = np.sign(x)
+
+            if isinstance(x, np.int32):
+              threshold = 1
+              sign = np.random.choice([-1, 1])
+            return threshold * sign if np.abs(x) < threshold else x
 
           updates = np.vectorize(clip_small_values)(updates)
           old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype)
@@ -181,7 +187,11 @@ class ScatterTest(test.TestCase):
                          tf_scatter,
                          repeat_indices=False,
                          updates_are_scalar=False):
-    for vtype in (np.float32, np.float64):
+    vtypes = [np.float32, np.float64]
+    if tf_scatter != state_ops.scatter_div:
+      vtypes.append(np.int32)
+
+    for vtype in vtypes:
       for itype in (np.int32, np.int64):
         self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
                                updates_are_scalar)