Register bool scatter_update for resource variables
authorAlexandre Passos <apassos@google.com>
Mon, 7 May 2018 18:49:08 +0000 (11:49 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 7 May 2018 23:48:12 +0000 (16:48 -0700)
Fixes #17784

PiperOrigin-RevId: 195696915

tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/kernels/scatter_functor_gpu.cu.cc
tensorflow/python/kernel_tests/resource_variable_ops_test.py

index a8bcc7f..03cc414 100644 (file)
@@ -703,6 +703,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
 
 REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                         scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
+                        scatter_op::UpdateOp::ASSIGN);
 REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                         scatter_op::UpdateOp::ASSIGN);
 
@@ -728,6 +730,13 @@ REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                             .Device(DEVICE_GPU)
                             .HostMemory("resource")
+                            .TypeConstraint<bool>("dtype")
+                            .TypeConstraint<int32>("Tindices"),
+                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
+                                                scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("resource")
                             .HostMemory("indices")
                             .TypeConstraint<Variant>("dtype")
                             .TypeConstraint<int64>("Tindices"),
index 59911bf..bdc8785 100644 (file)
@@ -42,6 +42,8 @@ typedef Eigen::GpuDevice GPUDevice;
 
 DEFINE_GPU_SPECS(float);
 DEFINE_GPU_SPECS(double);
+DEFINE_GPU_SPECS_OP(bool, int32, scatter_op::UpdateOp::ASSIGN);
+DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN);
 // TODO(b/27222123): The following fails to compile due to lack of support for
 // fp16.
 // TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
index 9841922..3daf07e 100644 (file)
@@ -400,6 +400,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
               resource_variable_ops.var_is_initialized_op(abc.handle)),
           True)
 
+  def testScatterBool(self):
+    with context.eager_mode():
+      ref = resource_variable_ops.ResourceVariable(
+          [False, True, False], trainable=False)
+      indices = math_ops.range(3)
+      updates = constant_op.constant([True, True, True])
+      state_ops.scatter_update(ref, indices, updates)
+      self.assertAllEqual(ref.read_value(), [True, True, True])
+
   @test_util.run_in_graph_and_eager_modes()
   def testConstraintArg(self):
     constraint = lambda x: x