Register resource_scatter_update for string types.
authorAlexandre Passos <apassos@google.com>
Fri, 2 Feb 2018 19:24:11 +0000 (11:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:28:07 +0000 (11:28 -0800)
PiperOrigin-RevId: 184309674

tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/ops/resource_variable_ops.cc
tensorflow/python/kernel_tests/resource_variable_ops_test.py

index 9cc8e03..6ce53e7 100644 (file)
@@ -635,6 +635,9 @@ class ResourceScatterUpdateOp : public OpKernel {
 
 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
 
+REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
+                        scatter_op::UpdateOp::ASSIGN);
+
 // Registers GPU kernels.
 #if GOOGLE_CUDA
 #define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
index f6cfbf8..8dae7e1 100644 (file)
@@ -193,7 +193,7 @@ REGISTER_OP("ResourceScatterUpdate")
     .Input("resource: resource")
     .Input("indices: Tindices")
     .Input("updates: dtype")
-    .Attr("dtype: numbertype")
+    .Attr("dtype: type")
     .Attr("Tindices: {int32, int64}")
     .SetShapeFn([](InferenceContext* c) {
       ShapeAndType handle_shape_and_type;
index b4b5555..cd94579 100644 (file)
@@ -36,6 +36,7 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.util import compat
 
 
 @test_util.with_c_api
@@ -170,6 +171,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     self.assertEqual(self.evaluate(read), [[3]])
 
+  def testScatterUpdateString(self):
+    handle = resource_variable_ops.var_handle_op(
+        dtype=dtypes.string, shape=[1, 1])
+    self.evaluate(resource_variable_ops.assign_variable_op(
+        handle, constant_op.constant([["a"]], dtype=dtypes.string)))
+    self.evaluate(resource_variable_ops.resource_scatter_update(
+        handle, [0], constant_op.constant([["b"]], dtype=dtypes.string)))
+    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
+    self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
+                     compat.as_bytes("b"))
+
   # TODO(alive): get this to work in Eager mode.
   def testGPU(self):
     with self.test_session(use_gpu=True):