REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
// GPU kernel registrations for ResourceScatterUpdate (ASSIGN).
// The resource handle always lives in host memory. For the Variant
// registration the indices are additionally pinned to host memory —
// presumably because Variant tensors are handled host-side; confirm against
// the op's GPU implementation.
// NOTE(review): the second invocation was truncated in the original (the
// kernel-class argument and closing paren were missing); reconstructed here
// by symmetry with the bool registration above — verify against upstream.
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)
DEFINE_GPU_SPECS(float);
DEFINE_GPU_SPECS(double);
+DEFINE_GPU_SPECS_OP(bool, int32, scatter_op::UpdateOp::ASSIGN);
+DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN);
// TODO(b/27222123): The following fails to compile due to lack of support for
// fp16.
// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
resource_variable_ops.var_is_initialized_op(abc.handle)),
True)
+ def testScatterBool(self):
+ with context.eager_mode():
+ ref = resource_variable_ops.ResourceVariable(
+ [False, True, False], trainable=False)
+ indices = math_ops.range(3)
+ updates = constant_op.constant([True, True, True])
+ state_ops.scatter_update(ref, indices, updates)
+ self.assertAllEqual(ref.read_value(), [True, True, True])
+
@test_util.run_in_graph_and_eager_modes()
def testConstraintArg(self):
constraint = lambda x: x