Initial XLA support for TF eager. This is a prerequisite for the TF compiler's XLA support.
author    Mingsheng Hong <hongm@google.com>
Wed, 7 Feb 2018 20:22:55 +0000 (12:22 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Feb 2018 20:26:52 +0000 (12:26 -0800)
This CL adds XLA support for the following TFE_Ops:

1. A TF op such as MatMul, with full support for constant and resource params.

2. A TF_Function executed as a TFE_Op, where the function must have no constant
or resource params. Removing this restriction requires more discussion and is
deferred to later work.

PiperOrigin-RevId: 184877345
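
As a usage sketch (not part of this change), here is how a client built with
TENSORFLOW_EAGER_USE_XLA defined might opt a single op into XLA.
RunMatMulViaXla is a hypothetical helper, and error checking is trimmed:

#include "tensorflow/c/eager/c_api.h"

// Hypothetical helper: run one MatMul through XLA via the new C API call.
void RunMatMulViaXla(TFE_Context* ctx, TFE_TensorHandle* m) {
  TF_Status* status = TF_NewStatus();
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpAddInput(matmul, m, status);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));
  // Added in this CL: route the subsequent TFE_Execute() through _XlaLaunch.
  TFE_OpSetXLACompilation(matmul, 1);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(matmul, retvals, &num_retvals, status);
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(retvals[0]);
  TF_DeleteStatus(status);
}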

tensorflow/c/BUILD
tensorflow/c/c_test_util.cc
tensorflow/c/c_test_util.h
tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api.h
tensorflow/c/eager/c_api_internal.h
tensorflow/c/eager/c_api_test.cc

diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index c46cb32..314cbc6 100644
@@ -135,6 +135,10 @@ tf_cuda_library(
     testonly = 1,
     srcs = ["c_test_util.cc"],
     hdrs = ["c_test_util.h"],
+    visibility = [
+        "//learning/brain:__subpackages__",
+        "//tensorflow:__subpackages__",
+    ],
     deps = [
         ":c_api",
         "//tensorflow/core:lib",
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 37439ff..3c1d5b5 100644
@@ -124,8 +124,9 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
   return Const(tensor.get(), graph, s, name);
 }
 
-void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
-               const char* name, TF_Operation** op, bool check) {
+void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                 TF_Status* s, const char* name, TF_Operation** op,
+                 bool check) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
   TF_AddInputList(desc, add_inputs, 2);
@@ -139,14 +140,14 @@ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                   TF_Status* s, const char* name) {
   TF_Operation* op;
-  AddHelper(l, r, graph, s, name, &op, true);
+  AddOpHelper(l, r, graph, s, name, &op, true);
   return op;
 }
 
 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                          TF_Status* s, const char* name) {
   TF_Operation* op;
-  AddHelper(l, r, graph, s, name, &op, false);
+  AddOpHelper(l, r, graph, s, name, &op, false);
   return op;
 }
 
@@ -160,6 +161,26 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
   return TF_FinishOperation(desc, s);
 }
 
+void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
+                    TF_Graph* graph, TF_Status* s, const char* name,
+                    TF_Operation** op, bool check) {
+  TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
+  TF_AddInput(desc, {l, 0});
+  TF_AddInput(desc, {r, 0});
+  *op = TF_FinishOperation(desc, s);
+  if (check) {
+    ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    ASSERT_NE(*op, nullptr);
+  }
+}
+
+TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                  TF_Status* s, const char* name) {
+  TF_Operation* op;
+  BinaryOpHelper("Min", l, r, graph, s, name, &op, true);
+  return op;
+}
+
 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
                   const char* name) {
   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
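
BinaryOpHelper above generalizes AddOpHelper to any two-input, single-output
op. As a hypothetical illustration (not part of this change), another wrapper
in the same style would only need the op name changed:

TF_Operation* Sub(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
                  TF_Status* s, const char* name = "sub") {
  TF_Operation* op;
  // "Sub" is the registered name of the elementwise subtraction op.
  BinaryOpHelper("Sub", l, r, graph, s, name, &op, true);
  return op;
}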
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index 6acc2fe..77520be 100644
@@ -69,6 +69,9 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
                   const char* name = "add");
 
+TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+                  TF_Status* s, const char* name = "min");
+
 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
                   const char* name = "neg");
 
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 3505f70..e55cb67 100644
@@ -72,6 +72,7 @@ tf_cuda_cc_test(
     ],
     deps = [
         ":c_api",
+        "//tensorflow/c:c_test_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 3a6d2ce..9cd1acc 100644
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -47,19 +48,23 @@ using tensorflow::int64;
 using tensorflow::string;
 
 namespace {
-bool IsCPU(tensorflow::Device* d) {
+bool IsCPU(const tensorflow::Device* d) {
   return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
 }
 
-bool IsXLA(tensorflow::Device* d) {
+bool IsXLA(const tensorflow::Device* d) {
   if (d == nullptr) return false;
   const auto& device_type = d->attributes().device_type();
   return device_type.find("XLA") != std::string::npos;
 }
 
-string DeviceName(tensorflow::Device* d) {
+string DeviceName(const tensorflow::Device* d) {
   return (d == nullptr) ? "cpu:0" : d->name();
 }
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+std::atomic_int_fast64_t func_id_generator(0);
+#endif  // TENSORFLOW_EAGER_USE_XLA
 }  // namespace
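
A note on func_id_generator above: OpToFunction below consumes it to mint a
unique name for each synthesized wrapper function, so repeated XLA executions
of the same primitive op do not collide in the function library:

  // e.g. with op_def.name() == "MatMul":
  //   first wrap  -> function named "MatMul0"
  //   second wrap -> function named "MatMul1"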
 
 extern "C" {
@@ -281,6 +286,14 @@ const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
   return device->name().c_str();
 }
 
+void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
+  op->use_xla = enable;
+#ifndef TENSORFLOW_EAGER_USE_XLA
+  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
+                  "built with XLA support.";
+#endif  // TENSORFLOW_EAGER_USE_XLA
+}
+
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
   // Questionable heuristic ...
   //
@@ -523,6 +536,228 @@ tensorflow::Status ValidateInputTypeAndPlacement(
   }
   return tensorflow::Status::OK();
 }
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+// Synthesizes and returns a wrapper function over `op`, which must be a
+// primitive op (e.g. matmul).
+//
+// The wrapper function conforms to the function signature expected by
+// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
+// resources>. For example, if the op has input params <Const1, Arg2, Const3,
+// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
+// Resource4> as the input params to the synthesized function.
+//
+// It populates `const_input_types`, `arg_input_types` and
+// `op_input_to_func_input` based on the reordering results, so that the caller
+// can use them to build an _XlaLaunchOp. On error, it returns nullptr and sets
+// `status` accordingly.
+const tensorflow::FunctionDef* OpToFunction(
+    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
+    std::vector<TF_DataType>* arg_input_types,
+    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
+    TF_Status* status) {
+  DCHECK(!op->is_function());
+
+  tensorflow::FunctionDef fdef;
+
+  // Get the OpDef of the op we are trying to encapsulate.
+  TFE_Context* ctx = op->ctx;
+  const tensorflow::OpRegistrationData* op_data;
+  {
+    tensorflow::tf_shared_lock l(ctx->functions_mu);
+    status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
+    if (!status->status.ok()) {
+      return nullptr;
+    }
+  }
+  const tensorflow::OpDef& op_def = op_data->op_def;
+
+  tensorflow::OpDef* signature = fdef.mutable_signature();
+
+  // Handle constant inputs.
+  const std::unordered_set<string> const_inputs(
+      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));
+
+  // First add placeholders for the input args, so that we can refer to them by
+  // position in the next loop. Also tally up the resource inputs.
+  int num_resource_inputs = 0;
+  for (int i = 0; i < op_def.input_arg_size(); ++i) {
+    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
+      ++num_resource_inputs;
+    }
+    signature->add_input_arg();
+  }
+
+  // Now we map the input params from `op_def` to `signature`, where the param
+  // ordering for `signature` is: <constants, args, resources>.
+  int const_index = 0;
+  int arg_index = const_inputs.size();
+  int resource_index = op_def.input_arg_size() - num_resource_inputs;
+  for (int i = 0; i < op_def.input_arg_size(); ++i) {
+    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
+    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
+    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
+      VLOG(1) << "For const input, mapping op input " << i << " to func input "
+              << const_index;
+      (*op_input_to_func_input)[i] = const_index;
+      func_input_arg = signature->mutable_input_arg(const_index++);
+      const_input_types->push_back(
+          static_cast<TF_DataType>(op->inputs[i].dtype()));
+    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
+      VLOG(1) << "For resource input, mapping op input " << i
+              << " to func input " << resource_index;
+      (*op_input_to_func_input)[i] = resource_index;
+      func_input_arg = signature->mutable_input_arg(resource_index++);
+    } else {
+      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
+              << arg_index;
+      (*op_input_to_func_input)[i] = arg_index;
+      func_input_arg = signature->mutable_input_arg(arg_index++);
+      arg_input_types->push_back(
+          static_cast<TF_DataType>(op->inputs[i].dtype()));
+    }
+
+    func_input_arg->set_name(op_input_arg.name());
+    func_input_arg->set_type(op->inputs[i].dtype());
+  }
+  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
+
+  // Resource args are at the end of the function input params, and we should
+  // have iterated over all of them.
+  DCHECK_EQ(signature->input_arg_size(), resource_index);
+
+  // Make the synthesized function's name unique.
+  signature->set_name(tensorflow::strings::StrCat(
+      op_def.name(), func_id_generator.fetch_add(1)));
+
+  // Add the node def and set its input names to match op_def's names.
+  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
+  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
+  *fdef.add_node_def() = ndef;
+  for (int i = 0; i < op_def.input_arg_size(); ++i) {
+    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
+  }
+  VLOG(1) << "Added NodeDef: " << fdef.DebugString();
+
+  // Fix the output names and set output types.
+  for (int i = 0; i < op_def.output_arg_size(); ++i) {
+    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
+    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
+    const string& out_tensor_name = tensorflow::strings::StrCat(
+        ndef.name(), ":", op_def_arg.name(), ":", 0);
+    arg->set_name(op_def_arg.name());
+    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
+    const string& type_attr = op_def_arg.type_attr();
+    if (!type_attr.empty()) {
+      auto i = ndef.attr().find(type_attr);
+      if (i == ndef.attr().end()) {
+        status->status = tensorflow::errors::InvalidArgument(
+            tensorflow::strings::StrCat("Could not find attr ", type_attr,
+                                        " in NodeDef ", ndef.DebugString()));
+        return nullptr;
+      }
+      arg->set_type(i->second.type());
+    }
+  }
+  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();
+
+  tensorflow::mutex_lock l(ctx->functions_mu);
+  status->status = ctx->func_lib_def.AddFunctionDef(fdef);
+  if (!status->status.ok()) return nullptr;
+  const auto ret = ctx->func_lib_def.Find(signature->name());
+  DCHECK(ret != nullptr);
+  return ret;
+}
+
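To make the reordering concrete, here is a worked trace of OpToFunction's
mapping loop on the five-input example from its comment (two compile-time
constants and one resource, so const_index starts at 0, arg_index at 2, and
resource_index at 4):

  op inputs:              <Const1, Arg2, Const3, Resource4, Arg5>
  op_input_to_func_input: {0 -> 0, 1 -> 2, 2 -> 1, 3 -> 4, 4 -> 3}
  function inputs:        <Const1, Const3, Arg2, Arg5, Resource4>
  const_input_types:      <dtype(Const1), dtype(Const3)>
  arg_input_types:        <dtype(Arg2), dtype(Arg5)>

After the loop, resource_index has advanced to 5, the total input count, which
is what the DCHECK_EQ on the signature verifies.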
+// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
+// via XLA.
+std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
+  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
+  auto launch_op =
+      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  if (op->device) {
+    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+
+  const tensorflow::FunctionDef* fdef;
+  {
+    tensorflow::tf_shared_lock l(op->ctx->functions_mu);
+    fdef = op->ctx->func_lib_def.Find(op->name);
+  }
+  std::vector<TF_DataType> const_input_types;
+  std::vector<TF_DataType> arg_input_types;
+  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
+  if (fdef == nullptr) {
+    // See if this is a primitive op, and if so create a function for it, so
+    // that _XlaLaunchOp can access it.
+    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
+                        &op_input_to_func_input, status);
+    if (!status->status.ok()) return nullptr;
+  } else {
+    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
+    // functions, so we need to find another way to handle constant inputs.
+    for (int i = const_input_types.size();
+         i < fdef->signature().input_arg_size(); ++i) {
+      VLOG(1) << "Adding Targs from input arg " << i;
+      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
+      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
+    }
+  }
+  DCHECK(fdef != nullptr);
+
+  // Copy inputs and their devices.
+  // Since input param reordering may have occurred between `op` and `launch_op`
+  // via `op_input_to_func_input`, adjust the actual inputs accordingly.
+  launch_op->inputs = op->inputs;
+  launch_op->input_devices = op->input_devices;
+  if (!op_input_to_func_input.empty()) {
+    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
+    if (!op->input_devices.empty()) {
+      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
+    }
+    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
+      VLOG(1) << "mapping op input " << i << " to func input "
+              << op_input_to_func_input[i];
+
+      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
+      if (!op->input_devices.empty()) {
+        launch_op->input_devices[op_input_to_func_input[i]] =
+            op->input_devices[i];
+      }
+    }
+  }
+  launch_op->attrs.NumInputs(op->inputs.size());
+
+  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
+                        const_input_types.size());
+
+  // Set Targs and Nresources attrs.
+  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
+                        arg_input_types.size());
+  const int num_resource_inputs = fdef->signature().input_arg_size() -
+                                  const_input_types.size() -
+                                  arg_input_types.size();
+  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);
+
+  // Set Tresults attr.
+  std::vector<TF_DataType> tresults;
+  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
+    tresults.push_back(static_cast<TF_DataType>(arg.type()));
+  }
+  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
+                        tresults.size());
+
+  // Set function attr.
+  tensorflow::AttrValue attr_value;
+  tensorflow::NameAttrList* func = attr_value.mutable_func();
+  func->set_name(fdef->signature().name());
+  launch_op->attrs.Set("function", attr_value);
+
+  return launch_op;
+}
+#endif  // TENSORFLOW_EAGER_USE_XLA
 }  // namespace
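
Putting the attr-setting code together, a sketch of the _XlaLaunch node that
BuildXlaLaunch would synthesize for a float MatMul (which has no compile-time
constant or resource inputs; the function name assumes it is the first op
wrapped in the process):

  // Illustrative attrs on the synthesized _XlaLaunch op:
  //   Tconstants = []
  //   Targs      = [DT_FLOAT, DT_FLOAT]
  //   Nresources = 0
  //   Tresults   = [DT_FLOAT]
  //   function   = "MatMul0"  (name minted via func_id_generator)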
 
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
@@ -531,6 +766,18 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
   tensorflow::Device* device =
       (op->device == nullptr) ? ctx->devices()[0] : op->device;
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+  std::unique_ptr<TFE_Op> xla_launch_op;
+  if (op->use_xla && op->name != "_XlaLaunch") {
+    xla_launch_op = BuildXlaLaunch(op, status);
+    if (!status->status.ok()) {
+      return;
+    }
+    op = xla_launch_op.get();
+  }
+#endif  // TENSORFLOW_EAGER_USE_XLA
+
   std::vector<tensorflow::Tensor> outputs(1);
   const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
   tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 6a2aff1..9506cf7 100644
@@ -158,6 +158,13 @@ TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
 TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
                                                   TF_Status* status);
 
+// When 'enable' is set to 1, and if the TensorFlow library is built with XLA
+// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
+//
+// If the library is not built with XLA support, this call is a no-op.
+TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
+                                                   unsigned char enable);
+
 TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
 
 TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index f2abffb..7b9f1db 100644
@@ -121,6 +121,7 @@ struct TFE_Op {
   std::vector<tensorflow::Tensor> inputs;
   std::vector<tensorflow::Device*> input_devices;
   tensorflow::Device* device;
+  bool use_xla = false;
 };
 
 #endif  // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index b0409af..4a3ecbc 100644
@@ -60,6 +60,38 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   return op;
 }
 
+TFE_TensorHandle* TestAxisTensorHandle() {
+  int64_t dims[] = {1};
+  int data[] = {1};
+  TF_Tensor* t = TF_AllocateTensor(
+      TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TF_Status* status = TF_NewStatus();
+  TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
+TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
+              TFE_TensorHandle* axis) {
+  TF_Status* status = TF_NewStatus();
+
+  TFE_Op* op = TFE_NewOp(ctx, "Min", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, input, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, axis, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpSetAttrBool(op, "keep_dims", 1);
+  TFE_OpSetAttrType(op, "Tidx", TF_INT32);
+  TF_DeleteStatus(status);
+  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
+
+  return op;
+}
+
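For reference on the expected values in the Min tests below:
TestMatrixTensorHandle() holds the 2x2 float matrix [[1, 2], [3, 4]]
(consistent with the MatMul product [[7, 10], [15, 22]] checked elsewhere in
this file), so with axis = [1] and keep_dims = true:

  // Min([[1, 2],
  //      [3, 4]], axis=[1], keep_dims=true) -> [[1],
  //                                             [3]]

which is why the tests expect output[0] == 1 and output[1] == 3.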
 // If there is a GPU device, returns true and sets 'gpu_device_name'
 // accordingly.
 bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
@@ -410,7 +442,7 @@ TEST(CAPI, SetAndGetOpDevices) {
   TF_DeleteStatus(status);
 }
 
-TEST(CAPI, Execute) {
+TEST(CAPI, Execute_MatMul_CPU) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
@@ -443,6 +475,117 @@ TEST(CAPI, Execute) {
   TF_DeleteStatus(status);
 }
 
+TEST(CAPI, Execute_Min_CPU) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* input = TestMatrixTensorHandle();
+  TFE_TensorHandle* axis = TestAxisTensorHandle();
+  TFE_Op* minOp = MinOp(ctx, input, axis);
+  TFE_TensorHandle* retvals[2] = {nullptr};
+  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(minOp);
+  TFE_DeleteTensorHandle(input);
+  TFE_DeleteTensorHandle(axis);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float output[2] = {0};
+  EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
+  memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(1, output[0]);
+  EXPECT_EQ(3, output[1]);
+  TF_DeleteStatus(status);
+}
+
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST(CAPI, Execute_MatMul_XLA_CPU) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+  TFE_Op* matmul = MatMulOp(ctx, m, m);
+
+  TFE_OpSetXLACompilation(matmul, true);
+
+  TFE_TensorHandle* retvals[2] = {nullptr};
+  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  // As of this CL, running a primitive TF operator via XLA is supported.
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_DeleteOp(matmul);
+  TFE_DeleteTensorHandle(m);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  EXPECT_EQ(1, num_retvals);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(7, product[0]);
+  EXPECT_EQ(10, product[1]);
+  EXPECT_EQ(15, product[2]);
+  EXPECT_EQ(22, product[3]);
+
+  TF_DeleteStatus(status);
+}
+
+TEST(CAPI, Execute_Min_XLA_CPU) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* input = TestMatrixTensorHandle();
+  TFE_TensorHandle* axis = TestAxisTensorHandle();
+  TFE_Op* minOp = MinOp(ctx, input, axis);
+
+  TFE_OpSetXLACompilation(minOp, true);
+
+  TFE_TensorHandle* retvals[2] = {nullptr};
+  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_Execute(minOp, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteOp(minOp);
+  TFE_DeleteTensorHandle(input);
+  TFE_DeleteTensorHandle(axis);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  TFE_DeleteTensorHandle(retvals[0]);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float output[2] = {0};
+  EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
+  memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(1, output[0]);
+  EXPECT_EQ(3, output[1]);
+  TF_DeleteStatus(status);
+}
+#endif  // TENSORFLOW_EAGER_USE_XLA
+
 TEST(CAPI, ExecuteWithTracing) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -484,7 +627,7 @@ TEST(CAPI, ExecuteWithTracing) {
   TF_DeleteStatus(status);
 }
 
-TEST(CAPI, Function) {
+TEST(CAPI, Function_ident_CPU) {
   // First create a simple identity function.
   TF_Graph* function_graph = TF_NewGraph();
   TF_OperationDescription* arg_descr =
@@ -545,6 +688,72 @@ TEST(CAPI, Function) {
   TF_DeleteStatus(status);
 }
 
+#ifdef TENSORFLOW_EAGER_USE_XLA
+TEST(CAPI, Function_ident_XLA_CPU) {
+  // First create a simple identity function.
+  TF_Graph* function_graph = TF_NewGraph();
+  TF_OperationDescription* arg_descr =
+      TF_NewOperation(function_graph, "Placeholder", "arg");
+  TF_SetAttrType(arg_descr, "dtype", TF_INT32);
+  TF_Status* status = TF_NewStatus();
+  TF_Operation* arg = TF_FinishOperation(arg_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_OperationDescription* id_descr =
+      TF_NewOperation(function_graph, "Identity", "id");
+  TF_SetAttrType(id_descr, "T", TF_INT32);
+  TF_AddInput(id_descr, {arg, 0});
+  TF_Operation* id = TF_FinishOperation(id_descr, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_Output input{arg, 0};
+  TF_Output output{id, 0};
+  TF_Function* fn =
+      TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
+                         &output, nullptr, nullptr, "test", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteGraph(function_graph);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+  TFE_ContextAddFunction(ctx, fn, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteFunction(fn);
+
+  TF_Tensor* t =
+      TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+  *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteTensor(t);
+
+  TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_OpAddInput(op, h, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+
+  // Now run it via XLA.
+  TFE_OpSetXLACompilation(op, true);
+
+  std::vector<TFE_TensorHandle*> result;
+  result.push_back(nullptr);
+  int num_retvals = 1;
+  TFE_Execute(op, result.data(), &num_retvals, status);
+  TFE_DeleteOp(op);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  ASSERT_EQ(num_retvals, 1);
+
+  TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+  TFE_DeleteTensorHandle(h);
+  TF_DeleteTensor(r);
+  TFE_DeleteTensorHandle(result[0]);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+#endif  // TENSORFLOW_EAGER_USE_XLA
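
Tying this test back to BuildXlaLaunch: since "ident" is already registered as
a function, the function branch is taken, where Tconstants is left empty and
every input is treated as a plain arg (per the TODO about
CompileTimeConstantInputs not working for functions). For this one-arg int32
identity, the synthesized launch op would plausibly carry:

  //   Tconstants = []
  //   Targs      = [DT_INT32]
  //   Nresources = 0
  //   Tresults   = [DT_INT32]
  //   function   = "ident"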
+
 string MatMulFunction() {
   tensorflow::FunctionDef def;
   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(