Extend TF Eager C API to allow asynchronous execution.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 19:44:29 +0000 (12:44 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 19:51:15 +0000 (12:51 -0700)
PiperOrigin-RevId: 188763442

tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api.h
tensorflow/c/eager/c_api_internal.h
tensorflow/c/eager/c_api_test.cc
tensorflow/c/eager/runtime.h
tensorflow/python/eager/core_test.py
tensorflow/python/eager/pywrap_tensor.cc
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/lib/core/py_func.cc

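The diffs below add an asynchronous execution mode to the eager C API: a per-context TFE_Executor runs queued TFE_Nodes on a background thread, and tensor handles become "non-ready" placeholders until their producing node has run. As a rough usage sketch of the new public entry points (illustration only, not part of this change; TFE_NewContextOptions, TF_NewStatus, and the other helpers shown are pre-existing C API functions):

// Sketch: create a context with async execution enabled by default,
// dispatch some ops, then synchronize and recover from any async error.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

void ExampleAsyncContext() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, 1);  // async is now the default mode
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // ... enqueue ops via TFE_Execute(); returned handles may be "non-ready" ...

  // Block until everything dispatched so far has actually executed.
  TFE_ContextAsyncWait(ctx, status);
  if (TF_GetCode(status) != TF_OK) {
    // A pending op failed; its outputs are corrupt, but the context can be
    // reused after clearing the error state.
    TFE_ContextAsyncClearError(ctx);
  }

  // Optionally switch just this thread back to synchronous execution.
  TFE_ContextSetAsyncForThread(ctx, 0, status);

  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}
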
tensorflow/c/eager/BUILD
index e55cb67..3046d90 100644
@@ -58,6 +58,7 @@ tf_cuda_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:framework_lite",
+        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
     ],
 )
tensorflow/c/eager/c_api.cc
index b9a47ea..56cec2d 100644
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/public/version.h"
@@ -67,6 +68,7 @@ string DeviceName(const tensorflow::Device* d) {
 #ifdef TENSORFLOW_EAGER_USE_XLA
 std::atomic_int_fast64_t func_id_generator(0);
 #endif  // TENSORFLOW_EAGER_USE_XLA
+
 }  // namespace
 
 TFE_ContextDevicePlacementPolicy PlacementPolicy(
@@ -90,11 +92,33 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
   TF_SetConfig(&options->session_options, proto, proto_len, status);
 }
 
+void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
+                                unsigned char async) {
+  options->async = async;
+}
 void TFE_ContextOptionsSetDevicePlacementPolicy(
     TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
   options->policy = policy;
 }
 
+TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
+                                                        unsigned char async,
+                                                        TF_Status* status) {
+  {
+    tensorflow::mutex_lock l(ctx->async_map_mu);
+    ctx->thread_local_async[std::this_thread::get_id()] = async;
+  }
+  if (async) {
+    ctx->executor.EnableAsync();
+  } else {
+    // TODO(agarwal): Currently we add a wait here to handle cases where a sync
+    // op has a control dependency on an async op, and the latter has not
+    // executed yet. This wait can be removed by storing all the control inputs
+    // and waiting for them when executing ops.
+    status->status = ctx->executor.WaitForAllPendingNodes();
+  }
+}
+
 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
@@ -113,7 +137,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
 }
 
 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
-  status->status = tensorflow::Status::OK();
+  status->status = ctx->executor.WaitForAllPendingNodes();
   {
     tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
@@ -139,6 +163,9 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
   ctx->thread_local_policies[std::this_thread::get_id()] = policy;
 }
 
+// Note: this function looks up a thread local policy. So it should be called in
+// the appropriate client thread. In particular, in async mode, it may not be
+// safe to call this function from the async TFE_Executor threads.
 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
     TFE_Context* ctx) {
   tensorflow::mutex_lock ml(ctx->policy_map_mu);
@@ -150,6 +177,18 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
   return ctx->policy;
 }
 
+void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
+  status->status = ctx->executor.WaitForAllPendingNodes();
+}
+
+void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
+  status->status = ctx->executor.status();
+}
+
+void TFE_ContextAsyncClearError(TFE_Context* ctx) {
+  ctx->executor.ClearError();
+}
+
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -157,56 +196,70 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   return new TFE_TensorHandle(tensor, nullptr, nullptr);
 }
 
-void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
+void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
+  DCHECK(h);
+  h->Unref();
+}
 
 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
-  return static_cast<TF_DataType>(h->t.dtype());
+  return static_cast<TF_DataType>(h->dtype);
 }
 
 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
-  status->status = tensorflow::Status::OK();
-  return h->t.dims();
+  const tensorflow::Tensor* t = nullptr;
+  status->status = h->Tensor(&t);
+  return t == nullptr ? 0 : t->dims();
 }
 
 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                             TF_Status* status) {
-  status->status = tensorflow::Status::OK();
-  return h->t.dim_size(dim_index);
+  const tensorflow::Tensor* t = nullptr;
+  status->status = h->Tensor(&t);
+  return t == nullptr ? 0 : t->dim_size(dim_index);
 }
 
 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
-  status->status = tensorflow::Status::OK();
-  return (h->op_device == nullptr)
-             ? "/job:localhost/replica:0/task:0/device:CPU:0"
-             : h->op_device->name().c_str();
+  tensorflow::Device* d = nullptr;
+  status->status = h->OpDevice(&d);
+  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+                        : d->name().c_str();
 }
 
 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
-  if (!IsCPU(h->d)) {
+  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
+  tensorflow::Device* d = nullptr;
+  tensorflow::Device* op_device = nullptr;
+  const tensorflow::Tensor* t = nullptr;
+  status->status = h->TensorAndDevice(&t, &d, &op_device);
+  if (!status->status.ok()) return nullptr;
+  if (!IsCPU(d)) {
     TF_SetStatus(status, TF_UNIMPLEMENTED,
                  tensorflow::strings::StrCat(
                      "TFE_TensorHandle can be resolved iff it is on CPU (this "
                      "handle is on ",
-                     h->d->name(),
+                     d->name(),
                      "). Consider using TFE_TensorHandleCopyToDevice to get a "
                      "copy of the tensor on CPU")
                      .c_str());
     return nullptr;
   }
-  return tensorflow::TF_TensorFromTensor(h->t, status);
+  return tensorflow::TF_TensorFromTensor(*t, status);
 }
+}  // extern "C"
 
-TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
-                                               TFE_Context* ctx,
-                                               const char* device_name,
-                                               TF_Status* status) {
-  tensorflow::Device* dstd = ctx->devices[0];
-  if (device_name != nullptr && strlen(device_name) > 0) {
-    status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
-    if (!status->status.ok()) return nullptr;
-  }
+namespace {
 
-  tensorflow::Device* srcd = h->d == nullptr ? ctx->devices[0] : h->d;
+tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
+                                            TFE_Context* ctx,
+                                            tensorflow::Device* dstd,
+                                            TFE_TensorHandle** output) {
+  const tensorflow::Tensor* src = nullptr;
+  tensorflow::Device* srcd = nullptr;
+  // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
+  // nullptr.
+  tensorflow::Device* src_opd = nullptr;
+  TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
+  if (srcd == nullptr) srcd = ctx->devices[0];
   bool is_same_device =
       (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
   const bool dst_cpu = IsCPU(dstd);
@@ -216,18 +269,15 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   const bool both_on_cpu = src_cpu && dst_cpu;
   if (is_same_device || both_on_cpu) {
     dstd = dst_cpu ? nullptr : dstd;
-    return new TFE_TensorHandle(h->t, dstd, dstd);
+    *output = new TFE_TensorHandle(*src, dstd, dstd);
+    return tensorflow::Status::OK();
   }
-  tensorflow::Tensor* src = &(h->t);
   if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                    !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
-    TF_SetStatus(
-        status, TF_INVALID_ARGUMENT,
-        tensorflow::strings::StrCat("Can't copy Tensor with type ",
-                                    tensorflow::DataTypeString(src->dtype()),
-                                    " to device ", DeviceName(dstd), ".")
-            .c_str());
-    return nullptr;
+    return tensorflow::errors::InvalidArgument(
+        "Can't copy Tensor with type ",
+        tensorflow::DataTypeString(src->dtype()), " to device ",
+        DeviceName(dstd), ".");
   }
   tensorflow::AllocatorAttributes attr;
   if (src->dtype() == tensorflow::DT_VARIANT) {
@@ -236,7 +286,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
   if (src->shape().num_elements() == 0) {
     dstd = dst_cpu ? nullptr : dstd;
-    return new TFE_TensorHandle(dst, dstd, dstd);
+    *output = new TFE_TensorHandle(dst, dstd, dstd);
+    return tensorflow::Status::OK();
   }
   tensorflow::DeviceContext* src_device_context = nullptr;
   if (!src_cpu) {
@@ -253,21 +304,26 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   // With that setup, Sync()ing across all 3 streams should be sufficient
   // but more than necessary (since it waits for operations that might have
   // nothing to do with this tensor to complete).
-  status->status = srcd->Sync();
+  TF_RETURN_IF_ERROR(srcd->Sync());
   tensorflow::Notification n;
+  tensorflow::Status status;
   tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                  srcd, dstd, tensorflow::AllocatorAttributes(),
                                  tensorflow::AllocatorAttributes(), src, &dst,
-                                 [status, &n](const tensorflow::Status& s) {
-                                   status->status = s;
+                                 [&status, &n](const tensorflow::Status& s) {
+                                   status = s;
                                    n.Notify();
                                  });
   n.WaitForNotification();
-  return (TF_GetCode(status) == TF_OK)
-             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd,
-                                    dst_cpu ? nullptr : dstd)
-             : nullptr;
+  if (status.ok()) {
+    dstd = dst_cpu ? nullptr : dstd;
+    *output = new TFE_TensorHandle(dst, dstd, dstd);
+  }
+  return status;
 }
+}  // namespace
+
+extern "C" {
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
@@ -311,16 +367,19 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
 }
 
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
-  // Questionable heuristic ...
-  // - If a device was explicitly set on the op, always use that.
-  // - If not, place on the first non-host device seen.
-  if (op->device == nullptr && !IsCPU(h->d)) {
-    op->device = h->d;
+  if (op->device == nullptr) {
+    // Questionable heuristic ...
+    // - If a device was explicitly set on the op, always use that.
+    // - If not, place on the first non-host device seen.
+    tensorflow::Device* d = nullptr;
+    // TODO(agarwal): This call may block if h is not ready. Avoid this if
+    // possible.
+    status->status = h->Device(&d);
+    if (!status->status.ok()) return;
+    if (!IsCPU(d)) op->device = d;
   }
-  if (!status->status.ok()) return;
-  op->inputs.push_back(h->t);
-  op->input_devices.push_back(h->d);
-  op->input_op_devices.push_back(h->op_device);
+  h->Ref();
+  op->inputs.push_back(h);
   op->attrs.NumInputs(op->inputs.size());
 }
 
@@ -482,14 +541,14 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                 tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                     funcs.get(), num_values));
 }
+}  // extern "C"
 
 namespace {
 
 tensorflow::Status ValidateInputTypeAndPlacement(
     TFE_Context* ctx, tensorflow::Device* host_device,
     tensorflow::Device* op_device, TFE_Op* op,
-    const tensorflow::OpKernel* kernel,
-    std::vector<TFE_TensorHandle*>* copied_tensors) {
+    const tensorflow::OpKernel* kernel) {
   const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
   if (memtypes.size() != op->inputs.size()) {
     return tensorflow::errors::InvalidArgument(
@@ -498,14 +557,17 @@ tensorflow::Status ValidateInputTypeAndPlacement(
   for (int i = 0; i < op->inputs.size(); ++i) {
     const tensorflow::Device* expected_device =
         memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
+    TFE_TensorHandle* handle = op->inputs[i];
+    tensorflow::Device* handle_device = nullptr;
+    TF_RETURN_IF_ERROR(handle->Device(&handle_device));
     const tensorflow::Device* actual_device =
-        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
+        handle_device == nullptr ? host_device : handle_device;
     if (expected_device != actual_device) {
       switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
         case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
           // TODO(xpan): See if we could bubble python related error up
           // to python level.
-          if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
+          if (handle->dtype == tensorflow::DT_INT32) {
             // Note: enabling silent copies of int32 tensors to match behavior
             // of graph mode.
             break;
@@ -536,36 +598,245 @@ tensorflow::Status ValidateInputTypeAndPlacement(
       }
       // We are only here if the policy is warn or silent copies, so we should
       // trigger a copy.
-      TFE_TensorHandle original{op->inputs[i], op->input_devices[i],
-                                op->device};
       TF_Status* s = TF_NewStatus();
       TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
-          &original, ctx, expected_device->name().c_str(), s);
-      if (!s->status.ok()) {
-        tensorflow::Status status = s->status;
-        delete s;
+          handle, ctx, expected_device->name().c_str(), s);
+      tensorflow::Status status = s->status;
+      TF_DeleteStatus(s);
+      if (!status.ok()) {
+        if (copied_tensor != nullptr) copied_tensor->Unref();
         return tensorflow::errors::Internal(
             "Failed copying input tensor from ", actual_device->name(), " to ",
             expected_device->name(), " in order to run ", op->name, ": ",
             status.error_message());
       }
-      op->inputs[i] = copied_tensor->t;
-      copied_tensors->push_back(copied_tensor);
-      op->input_devices[i] = copied_tensor->d;
-      delete s;
+      handle->Unref();
+      handle = copied_tensor;
+      op->inputs[i] = copied_tensor;
     }
-    if (op->inputs[i].dtype() != kernel->input_type(i)) {
+    if (handle->dtype != kernel->input_type(i)) {
       return tensorflow::errors::InvalidArgument(
           "cannot compute ", op->name, " as input #", i,
           " was expected to be a ",
           tensorflow::DataTypeString(kernel->input_type(i)),
-          " tensor but is a ",
-          tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
+          " tensor but is a ", tensorflow::DataTypeString(handle->dtype),
+          " tensor");
     }
   }
   return tensorflow::Status::OK();
 }
 
+tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
+                                 TFE_Context* ctx, TF_Status* status) {
+  tensorflow::DeviceSet ds;
+  for (tensorflow::Device* d : ctx->devices) {
+    ds.AddDevice(d);
+  }
+  tensorflow::DeviceTypeVector final_devices;
+  status->status = tensorflow::SupportedDeviceTypesForNode(
+      ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+  if (final_devices.empty()) {
+    status->status = tensorflow::errors::Internal(
+        "Could not find valid device for node ", ndef.DebugString());
+    return nullptr;
+  }
+  for (tensorflow::Device* d : ctx->devices) {
+    if (d->device_type() == final_devices[0].type_string()) {
+      return d;
+    }
+  }
+  status->status = tensorflow::errors::Unknown(
+      "Could not find a device for node ", ndef.DebugString());
+  return nullptr;
+}
+
+tensorflow::Status Execute(
+    TFE_Context* ctx, tensorflow::Device* device,
+    const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
+    tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
+    TFE_TensorHandle** retvals, int num_retvals) {
+  if (!ctx->soft_placement && device == nullptr) {
+    // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
+    device = ctx->devices[0];
+  }
+
+  if (device == nullptr) {
+    // TODO(apassos) debug how the assignment below might return a different
+    // device from the one requested above.
+    device = kernel->device();
+  }
+
+  std::vector<tensorflow::Tensor> outputs(1);
+  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
+  output_memory_types = &kernel->kernel()->output_memory_types();
+  std::vector<tensorflow::Tensor> inputs(op_inputs.size());
+  for (int i = 0; i < op_inputs.size(); ++i) {
+    const tensorflow::Tensor* input_tensor = nullptr;
+    TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
+    inputs[i] = *input_tensor;
+  }
+  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
+  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
+  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
+  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
+  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
+  // This is quite subtle. Re-work things to make this better?  (Would it make
+  // sense for FunctionLibraryRuntime to ensure thread-safe access to
+  // FunctionLibraryDefinition?).  TODO(apassos) figure out how to record stats
+  // for ops which are a part of functions.
+  // TODO(agarwal): change Run to take vector of handles ?
+  TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats));
+  if (maybe_stats != nullptr) {
+    maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
+                                       maybe_stats->all_start_micros());
+    tensorflow::mutex_lock ml(ctx->metadata_mu);
+    if (ctx->should_store_metadata.load()) {
+      auto* step_stats = ctx->run_metadata.mutable_step_stats();
+      // Lazily initialize the RunMetadata with information about all devices if
+      // this is the first call.
+      while (step_stats->dev_stats_size() < ctx->devices.size()) {
+        step_stats->add_dev_stats();
+      }
+      // Find the current device's index.
+      int device_idx = 0;
+      for (int i = 0; i < ctx->devices.size(); ++i) {
+        if (ctx->devices[i] == device) {
+          device_idx = i;
+          break;
+        }
+      }
+      // Populate the device stats for this device.
+      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
+      dev_stats->set_device(device->name());
+      *dev_stats->add_node_stats() = *maybe_stats;
+    }
+  }
+  if (num_retvals != outputs.size()) {
+    return tensorflow::errors::InvalidArgument(
+        "Expecting ", num_retvals, " outputs but got ", outputs.size());
+  }
+  tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
+  for (int i = 0; i < num_retvals; ++i) {
+    tensorflow::Device* d = op_device;
+    if (d != nullptr && output_memory_types != nullptr &&
+        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
+      d = nullptr;
+    }
+    if (retvals[i] == nullptr) {
+      retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device);
+    } else {
+      retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
+    }
+  }
+  return tensorflow::Status::OK();
+}
+
+// TODO(agarwal): move TFE_Executor and TFE_Node related code to a separate
+// file.
+class ExecuteNode : public TFE_Node {
+ public:
+  ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel,
+              tensorflow::NodeExecStats* maybe_stats,
+              const tensorflow::DataTypeVector& output_dtypes,
+              TFE_TensorHandle** retvals, int num_retvals)
+      : TFE_Node(op->ctx->executor.NextId()),
+        ctx_(op->ctx),
+        op_device_(op->device),
+        inputs_(op->inputs),
+        kernel_(kernel),
+        maybe_stats_(maybe_stats),
+        retvals_(num_retvals) {
+    for (auto handle : inputs_) {
+      handle->Ref();
+    }
+    TFE_Context* ctx = op->ctx;
+    for (int i = 0; i < num_retvals; ++i) {
+      TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx);
+      h->Ref();
+      retvals[i] = h;
+      retvals_[i] = h;
+    }
+  }
+
+  ~ExecuteNode() override {
+    for (auto handle : inputs_) {
+      handle->Unref();
+    }
+    for (auto handle : retvals_) {
+      handle->Unref();
+    }
+  }
+
+  tensorflow::Status Run() override {
+    const tensorflow::Status status =
+        Execute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
+                retvals_.begin(), retvals_.size());
+    if (status.ok()) {
+      return status;
+    } else {
+      return tensorflow::Status(
+          status.code(),
+          tensorflow::strings::StrCat("Got error, \"", status.error_message(),
+                                      "\" while executing kernel ",
+                                      kernel_->kernel()->def().DebugString()));
+    }
+  }
+
+ private:
+  TFE_Context* ctx_;
+  tensorflow::Device* op_device_;
+  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_;
+  tensorflow::KernelAndDevice* kernel_;
+  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
+  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_;
+};
+
+class CopyToDeviceNode : public TFE_Node {
+ public:
+  CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
+                   TFE_Context* ctx)
+      : TFE_Node(ctx->executor.NextId()),
+        src_(src),
+        dstd_(dstd),
+        ctx_(ctx),
+        dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) {
+    src_->Ref();
+    dst_->Ref();
+  }
+
+  ~CopyToDeviceNode() override {
+    src_->Unref();
+    dst_->Unref();
+  }
+
+  tensorflow::Status Run() override {
+    TFE_TensorHandle* temp = nullptr;
+    TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
+    const tensorflow::Tensor* tensor = nullptr;
+    tensorflow::Device* device = nullptr;
+    tensorflow::Device* op_device = nullptr;
+    tensorflow::Status status =
+        temp->TensorAndDevice(&tensor, &device, &op_device);
+    // `temp` is a ready handle. So the following call should return OK.
+    TF_DCHECK_OK(status) << status.error_message();
+    DCHECK(tensor);
+    dst_->SetTensorAndDevice(*tensor, device, op_device);
+    temp->Unref();
+    return tensorflow::Status::OK();
+  }
+
+  TFE_TensorHandle* dst() { return dst_; }
+
+ private:
+  TFE_TensorHandle* src_;
+  tensorflow::Device* dstd_;
+  TFE_Context* ctx_;
+  TFE_TensorHandle* dst_;
+};
+
 #ifdef TENSORFLOW_EAGER_USE_XLA
 // Synthesizes and returns a wrapper function over `op`, which must be a
 // primitive op (e.g. matmul).
@@ -631,7 +902,7 @@ const tensorflow::FunctionDef* OpToFunction(
       (*op_input_to_func_input)[i] = const_index;
       func_input_arg = signature->mutable_input_arg(const_index++);
       const_input_types->push_back(
-          static_cast<TF_DataType>(op->inputs[i].dtype()));
+          static_cast<TF_DataType>(op->inputs[i]->dtype));
     } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
       VLOG(1) << "For resource input, mapping op input " << i
               << " to func input " << resource_index;
@@ -643,11 +914,11 @@ const tensorflow::FunctionDef* OpToFunction(
       (*op_input_to_func_input)[i] = arg_index;
       func_input_arg = signature->mutable_input_arg(arg_index++);
       arg_input_types->push_back(
-          static_cast<TF_DataType>(op->inputs[i].dtype()));
+          static_cast<TF_DataType>(op->inputs[i]->dtype));
     }
 
     func_input_arg->set_name(op_input_arg.name());
-    func_input_arg->set_type(op->inputs[i].dtype());
+    func_input_arg->set_type(op->inputs[i]->dtype);
   }
   VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();
 
@@ -740,22 +1011,16 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
   // Since input param reordering may have occurred between `op` and `launch_op`
   // via `op_input_to_func_input`, adjust the actual inputs accordingly.
   launch_op->inputs = op->inputs;
-  launch_op->input_devices = op->input_devices;
-  launch_op->input_op_devices = op->input_op_devices;
+  for (TFE_TensorHandle* h : launch_op->inputs) {
+    h->Ref();
+  }
   if (!op_input_to_func_input.empty()) {
     DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
-    if (!op->input_devices.empty()) {
-      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
-    }
     for (int i = 0; i < op_input_to_func_input.size(); ++i) {
       VLOG(1) << "mapping op input " << i << " to func input "
               << op_input_to_func_input[i];
 
       launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
-      if (!op->input_devices.empty()) {
-        launch_op->input_devices[op_input_to_func_input[i]] =
-            op->input_devices[i];
-      }
     }
   }
   launch_op->attrs.NumInputs(op->inputs.size());
@@ -789,37 +1054,17 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
 }
 #endif  // TENSORFLOW_EAGER_USE_XLA
 
-tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
-                                 TFE_Context* ctx, TF_Status* status) {
-  tensorflow::DeviceSet ds;
-  for (tensorflow::Device* d : ctx->devices) {
-    ds.AddDevice(d);
-  }
-  tensorflow::DeviceTypeVector final_devices;
-  status->status = tensorflow::SupportedDeviceTypesForNode(
-      ds.PrioritizedDeviceTypeList(), ndef, &final_devices);
-  if (!status->status.ok()) {
-    return nullptr;
-  }
-  if (final_devices.empty()) {
-    status->status = tensorflow::errors::Internal(
-        "Could not find valid device for node ", ndef.DebugString());
-    return nullptr;
-  }
-  for (tensorflow::Device* d : ctx->devices) {
-    if (d->device_type() == final_devices[0].type_string()) {
-      return d;
-    }
-  }
-  status->status = tensorflow::errors::Unknown(
-      "Could not find a device for node ", ndef.DebugString());
-  return nullptr;
-}
-
 }  // namespace
 
+extern "C" {
+
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
+  TFE_Context* ctx = op->ctx;
+  status->status = ctx->executor.status();
+  if (!status->status.ok()) {
+    return;
+  }
 #ifdef TENSORFLOW_EAGER_USE_XLA
   std::unique_ptr<TFE_Op> xla_launch_op;
   if (op->use_xla && op->name != "_XlaLaunch") {
@@ -830,31 +1075,29 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     op = xla_launch_op.get();
   }
 #endif  // TENSORFLOW_EAGER_USE_XLA
-  TFE_Context* ctx = op->ctx;
-  tensorflow::Device* device = op->device;
   // Ensure all resource-touching ops run in the device the resource is,
   // regardless of anything else that has been specified. This is identical to
   // the graph mode behavior.
   for (int i = 0; i < op->inputs.size(); ++i) {
-    if (op->inputs[i].dtype() == tensorflow::DT_RESOURCE &&
-        op->input_op_devices[i] != device) {
-      tensorflow::Device* d = op->input_op_devices[i] == nullptr
-                                  ? ctx->devices[0]
-                                  : op->input_op_devices[i];
+    tensorflow::Device* input_op_device = nullptr;
+    status->status = op->inputs[i]->OpDevice(&input_op_device);
+    if (!status->status.ok()) return;
+    if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE &&
+        input_op_device != op->device) {
+      tensorflow::Device* d =
+          input_op_device == nullptr ? ctx->devices[0] : input_op_device;
       VLOG(1) << "Changing device of operation " << op->name << " to "
               << d->name() << " because input #" << i
               << " is a resource in this device.";
-      device = d;
       op->device = d;
     }
   }
+  tensorflow::Device* device = op->device;
   if (!ctx->soft_placement && device == nullptr) {
     // TODO(ashankar): ASSUMPTION: ctx->devices[0] is always CPU
     device = ctx->devices[0];
   }
 
-  std::vector<tensorflow::Tensor> outputs(1);
-  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
   tensorflow::Fprint128 cache_key =
       op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name());
   tensorflow::KernelAndDevice* kernel;
@@ -879,8 +1122,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // Knowledge of the implementation of Init (and in-turn
     // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
     // will be accessed, so grab on to the lock.
-    // See WARNING comment below - would be nice to rework to avoid this
-    // subtlety.
+    // See WARNING comment in Execute (before kernel->Run) - would be nice to
+    // rework to avoid this subtlety.
     tensorflow::tf_shared_lock l(ctx->functions_mu);
     status->status =
         tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
@@ -903,29 +1146,30 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
     tensorflow::DataTypeVector input_dtypes;
     status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
-                                       kernel->output_dtypes());
+                                       kernel->mutable_output_dtypes());
     if (!status->status.ok()) {
       return;
     }
     tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
+  const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
+  if (output_dtypes.size() != *num_retvals) {
+    TF_SetStatus(status, TF_INVALID_ARGUMENT,
+                 tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
+                                             " outputs, but *num_retvals is ",
+                                             *num_retvals)
+                     .c_str());
+    return;
+  }
   if (device == nullptr) {
     // TODO(apassos) debug how the assignment below might return a different
     // device from the one requested above.
     device = kernel->device();
   }
-
-  std::vector<TFE_TensorHandle*> copied_tensors;
-  status->status = ValidateInputTypeAndPlacement(
-      ctx, ctx->devices[0], device, op, kernel->kernel(), &copied_tensors);
-  output_memory_types = &kernel->kernel()->output_memory_types();
-  if (!status->status.ok()) {
-    for (auto* t : copied_tensors) {
-      TFE_DeleteTensorHandle(t);
-    }
-    return;
-  }
+  status->status = ValidateInputTypeAndPlacement(ctx, ctx->devices[0], device,
+                                                 op, kernel->kernel());
+  if (!status->status.ok()) return;
   std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
   if (ctx->should_store_metadata.load()) {
     maybe_stats.reset(new tensorflow::NodeExecStats);
@@ -935,53 +1179,47 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
     // TODO(apassos) track referenced tensors
   }
-  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
-  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
-  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
-  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
-  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
-  // This is quite subtle. Re-work things to make this better?  (Would it make
-  // sense for FunctionLibraryRuntime to ensure thread-safe access to
-  // FunctionLibraryDefinition?).  TODO(apassos) figure out how to record stats
-  // for ops which are a part of functions.
-  status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
-  for (auto* t : copied_tensors) {
-    TFE_DeleteTensorHandle(t);
-  }
-  if (!status->status.ok()) return;
-  if (maybe_stats != nullptr) {
-    maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
-                                       maybe_stats->all_start_micros());
-    tensorflow::mutex_lock ml(ctx->metadata_mu);
-    if (ctx->should_store_metadata.load()) {
-      auto* step_stats = ctx->run_metadata.mutable_step_stats();
-      // Lazily initialize the RunMetadata with information about all devices if
-      // this is the first call.
-      while (step_stats->dev_stats_size() < ctx->devices.size()) {
-        step_stats->add_dev_stats();
-      }
-      // Find the current device's index.
-      int device_idx = 0;
-      for (int i = 0; i < ctx->devices.size(); ++i) {
-        if (ctx->devices[i] == device) {
-          device_idx = i;
-          break;
-        }
-      }
-      // Populate the device stats for this device.
-      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
-      dev_stats->set_device(device->name());
-      *dev_stats->add_node_stats() = *maybe_stats;
+  if (ctx->Async()) {
+    // Note that for async mode, execution order will make sure that all
+    // input handles are ready before executing them.
+    // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
+    TFE_Node* node = new ExecuteNode(op, kernel, maybe_stats.release(),
+                                     output_dtypes, retvals, *num_retvals);
+    ctx->executor.Add(node);
+  } else {
+    // Execute checks if retvals[i] is nullptr or not to figure if it needs to
+    // allocate it.
+    for (int i = 0; i < *num_retvals; ++i) {
+      retvals[i] = nullptr;
     }
+    status->status = Execute(op->ctx, op->device, op->inputs, kernel,
+                             maybe_stats.get(), retvals, *num_retvals);
   }
-  *num_retvals = std::min<int>(*num_retvals, outputs.size());
-  for (int i = 0; i < *num_retvals; ++i) {
-    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
-    if (d != nullptr && output_memory_types != nullptr &&
-        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
-      d = nullptr;
-    }
-    retvals[i] = new TFE_TensorHandle(outputs[i], d, device);
+}
+
+TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
+                                               TFE_Context* ctx,
+                                               const char* device_name,
+                                               TF_Status* status) {
+  status->status = ctx->executor.status();
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+  tensorflow::Device* dstd = ctx->devices[0];
+  if (device_name != nullptr && strlen(device_name) > 0) {
+    status->status = ctx->device_manager->LookupDevice(device_name, &dstd);
+    if (!status->status.ok()) return nullptr;
+  }
+  if (ctx->Async()) {
+    // Note that `h` may not be currently ready. However execution order will
+    // make sure that `h` is ready before the copy is actually done.
+    CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
+    ctx->executor.Add(node);
+    return node->dst();
+  } else {
+    TFE_TensorHandle* output = nullptr;
+    status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
+    return output;
   }
 }
 
@@ -1004,6 +1242,16 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
   status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
 }
 
+void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
+  ctx->should_store_metadata.store(true);
+}
+
+void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
+  tensorflow::mutex_lock ml(ctx->metadata_mu);
+  ctx->should_store_metadata.store(false);
+  ctx->run_metadata.Clear();
+}
+
 }  // extern "C"
 
 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
@@ -1012,27 +1260,24 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
 
 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
     TFE_TensorHandle* h, TF_Status* status) {
-  if (h->d != nullptr) {
+  tensorflow::Device* d = nullptr;
+  tensorflow::Device* op_device = nullptr;
+  const tensorflow::Tensor* t = nullptr;
+  status->status = h->TensorAndDevice(&t, &d, &op_device);
+  if (!status->status.ok()) return nullptr;
+  if (d != nullptr) {
     status->status = tensorflow::errors::FailedPrecondition(
         "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
         "a tensorflow::Tensor");
     return nullptr;
   }
-  return &h->t;
-}
-
-void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
-  ctx->should_store_metadata.store(true);
-}
-
-void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
-  tensorflow::mutex_lock ml(ctx->metadata_mu);
-  ctx->should_store_metadata.store(false);
-  ctx->run_metadata.Clear();
+  return t;
 }
 
 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                   TF_Status* status) {
+  TFE_ContextAsyncWait(ctx, status);
+  if (!status->status.ok()) return;
   tensorflow::mutex_lock ml(ctx->metadata_mu);
   status->status = MessageToBuffer(ctx->run_metadata, buf);
   ctx->run_metadata.Clear();
@@ -1108,3 +1353,208 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
   }
 }
 }  // namespace tensorflow
+
+TFE_Node::TFE_Node(tensorflow::uint64 id) : id(id) {}
+
+TFE_Executor::~TFE_Executor() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  thread_done_ = true;
+  nodes_pending_.notify_all();
+}
+
+tensorflow::uint64 TFE_Executor::NextId() {
+  tensorflow::mutex_lock l(next_id_mutex_);
+  return next_id_++;
+}
+
+void TFE_Executor::EnableAsync() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (thread_ == nullptr) {
+    thread_.reset(tensorflow::Env::Default()->StartThread(
+        tensorflow::ThreadOptions(), "eager_async_executor",
+        std::bind(&TFE_Executor::Run, this)));
+  }
+}
+
+void TFE_Executor::Add(TFE_Node* node) {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  DCHECK(thread_) << "EnableAsync should have been called before Add";
+  if (!status_.ok()) {
+    delete node;
+    return;
+  }
+  int qlen = node_queue_.size();
+  if (qlen > 0) {
+    if (node_queue_.back()->id >= node->id) {
+      status_ = tensorflow::errors::InvalidArgument(
+          "Inserting TFE_Node with non-increasing ids:", node_queue_.back()->id,
+          " vs ", node->id);
+      delete node;
+      return;
+    }
+    node_queue_.push(node);
+  } else {
+    node_queue_.push(node);
+    nodes_pending_.notify_all();
+  }
+}
+
+tensorflow::Status TFE_Executor::WaitFor(tensorflow::uint64 node_id) {
+  return WaitImpl(false, node_id);
+}
+
+tensorflow::Status TFE_Executor::WaitForAllPendingNodes() {
+  return WaitImpl(true, 0);
+}
+
+tensorflow::Status TFE_Executor::WaitImpl(bool wait_all,
+                                          tensorflow::uint64 node_id) {
+  tensorflow::condition_variable cond;
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  // Don't wait if an error is already set.
+  if (!status_.ok()) return status_;
+  if (node_queue_.empty()) return tensorflow::Status::OK();
+  if (wait_all) {
+    node_id = node_queue_.back()->id;
+  } else if (node_id < node_queue_.front()->id) {
+    // Note that we are relying on the ops being dispatched sequentially from
+    // the queue.
+    return tensorflow::Status::OK();
+  }
+  node_done_notifications_.insert(std::make_pair(node_id, &cond));
+  cond.wait(l);
+  // Note that we could be woken up if an error occurs, even though the node has
+  // not actually executed.
+  return status_;
+}
+
+void TFE_Executor::ClearError() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (status_.ok()) return;
+  // If an error was set, node_done_notifications_ and node_queue_ should have
+  // been cleared, and no new entries should have been added since.
+  DCHECK(node_done_notifications_.empty());
+  DCHECK(node_queue_.empty());
+  status_ = tensorflow::Status::OK();
+  nodes_pending_.notify_all();
+}
+
+tensorflow::Status TFE_Executor::status() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  return status_;
+}
+
+void TFE_Executor::Run() {
+  while (true) {
+    std::unique_ptr<TFE_Node> curr_node;
+    {
+      tensorflow::mutex_lock l(node_queue_mutex_);
+      while (node_queue_.empty() || !status_.ok()) {
+        if (thread_done_) return;
+        nodes_pending_.wait(l);
+      }
+      curr_node.reset(node_queue_.front());
+    }
+    tensorflow::Status status = curr_node->Run();
+    const bool ok = status.ok();
+    tensorflow::mutex_lock l(node_queue_mutex_);
+    node_queue_.pop();
+    if (!ok) {
+      status_ = status;
+      // TODO(agarwal): mark all affected handles as corrupted before clearing
+      // this queue.
+      // We remove any pending ops so that we don't try to execute them if
+      // ClearError is called.
+      for (int i = 0; i < node_queue_.size(); ++i) {
+        delete node_queue_.front();
+        node_queue_.pop();
+      }
+    }
+    if (!node_done_notifications_.empty()) {
+      tensorflow::uint64 node_id = curr_node->id;
+      // Note that we notify all waiting threads in case an error has occurred.
+      // These calling threads are responsible for checking status_ before
+      // proceeding.
+      const auto range = ok ? node_done_notifications_.equal_range(node_id)
+                            : make_pair(node_done_notifications_.begin(),
+                                        node_done_notifications_.end());
+      for (auto it = range.first; it != range.second; ++it) {
+        it->second->notify_all();
+      }
+      node_done_notifications_.erase(range.first, range.second);
+    }
+  }
+}
+
+bool TFE_Context::Async() const {
+  tensorflow::mutex_lock l(async_map_mu);
+  return tensorflow::gtl::FindWithDefault(
+      thread_local_async, std::this_thread::get_id(), async_default);
+}
+
+bool TFE_TensorHandle::IsReady() {
+  if (node_id == 0) return true;
+  tensorflow::mutex_lock l(ctx_mutex_);
+  return ctx_ == nullptr;
+}
+
+tensorflow::Status TFE_TensorHandle::WaitReady() {
+  if (node_id == 0) return tensorflow::Status::OK();
+  TFE_Executor* executor = nullptr;
+  {
+    tensorflow::mutex_lock l(ctx_mutex_);
+    if (ctx_ == nullptr) return tensorflow::Status::OK();
+    executor = &ctx_->executor;
+  }
+  return executor->WaitFor(node_id);
+}
+
+tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) {
+  TF_RETURN_IF_ERROR(WaitReady());
+  DCHECK(IsReady());
+  *t = &tensor_;
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) {
+  TF_RETURN_IF_ERROR(WaitReady());
+  DCHECK(IsReady());
+  *d = device_;
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) {
+  TF_RETURN_IF_ERROR(WaitReady());
+  DCHECK(IsReady());
+  *d = op_device_;
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_TensorHandle::TensorAndDevice(
+    const tensorflow::Tensor** tensor, tensorflow::Device** device,
+    tensorflow::Device** op_device) {
+  TF_RETURN_IF_ERROR(WaitReady());
+  DCHECK(IsReady());
+  *tensor = &tensor_;
+  *device = device_;
+  *op_device = op_device_;
+  return tensorflow::Status::OK();
+}
+
+void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
+                                          tensorflow::Device* device,
+                                          tensorflow::Device* op_device) {
+  tensorflow::mutex_lock l(ctx_mutex_);
+  DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should only be called "
+                              << "on non-ready handles.";
+  ctx_ = nullptr;
+  tensor_ = tensor;
+  device_ = device;
+  op_device_ = op_device;
+}
+
+TFE_Op::~TFE_Op() {
+  for (TFE_TensorHandle* h : inputs) {
+    h->Unref();
+  }
+}
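For orientation before the header changes below: the TFE_Executor consumes heap-allocated TFE_Nodes strictly in increasing id order, taking ownership of each node and deleting it after Run() returns. A hypothetical minimal node, following the same pattern as ExecuteNode and CopyToDeviceNode above (illustration only, not part of this change; assumes <functional> is available alongside the internal headers):

// Illustration only: a trivial TFE_Node that runs a caller-supplied callback.
class CallbackNode : public TFE_Node {
 public:
  CallbackNode(TFE_Context* ctx, std::function<tensorflow::Status()> fn)
      : TFE_Node(ctx->executor.NextId()), fn_(std::move(fn)) {}

  // Runs on the executor thread; the executor deletes this node afterwards.
  tensorflow::Status Run() override { return fn_(); }

 private:
  std::function<tensorflow::Status()> fn_;
};

// Enqueued in id order on the owning context's executor, e.g.:
//   ctx->executor.Add(new CallbackNode(ctx, [] { return tensorflow::Status::OK(); }));
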
tensorflow/c/eager/c_api.h
index 9610ca1..316006b 100644
@@ -75,6 +75,11 @@ typedef enum TFE_ContextDevicePlacementPolicy {
   TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
 } TFE_ContextDevicePlacementPolicy;
 
+// Sets the default execution mode (sync/async). Note that this can be
+// overridden per thread using TFE_ContextSetAsyncForThread.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
+                                                      unsigned char async);
+
 TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
     TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
 
@@ -110,6 +115,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
 TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
 TFE_ContextGetDevicePlacementPolicy(TFE_Context*);
 
+// Overrides the execution mode (sync/async) for the current thread.
+TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
+                                                        unsigned char async,
+                                                        TF_Status* status);
+
+// Causes the calling thread to block till all ops dispatched in async mode
+// have been executed. Note that "execution" here refers to kernel execution /
+// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+// that lower level device queues (like GPU streams) have been flushed.
+//
+// This call may not block for execution of ops enqueued concurrently with this
+// call.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*,
+                                                TF_Status* status);
+
+// When an error happens, any pending operations are discarded and newly issued
+// ops return an error. This call clears the error state and re-enables
+// execution of newly issued ops.
+//
+// Note that outputs of discarded ops remain in a corrupt state and should not
+// be used for future calls.
+// TODO(agarwal): mark the affected handles and raise errors if they are used.
+TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*);
+
 // A handle to a tensor on a device.
 //
 // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
@@ -119,15 +148,21 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
 
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
                                                             TF_Status* status);
+// Indicates that the caller will not be using `h` any more.
 TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
 TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
+// This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
                                                   TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
                                                   int dim_index,
                                                   TF_Status* status);
+// This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
     TFE_TensorHandle* h, TF_Status* status);
+
+// This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
                                                          TF_Status* status);
 
@@ -137,6 +172,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
 // that shares the underlying buffer. Otherwise, it currently requires at least
 // one of the source or destination devices to be CPU (i.e., for the source or
 // destination tensor to be placed in host memory).
+// If async execution is enabled, the copy may be enqueued and the call will
+// return a "non-ready" handle. Otherwise, this function returns after the
+// copy has been done.
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
     TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
     TF_Status* status);
@@ -157,6 +195,7 @@ typedef struct TFE_Op TFE_Op;
 TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
                                         const char* op_or_function_name,
                                         TF_Status* status);
+
 TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
 
 TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
@@ -242,13 +281,20 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
                                                      int num_values);
 
 // Execute the operation defined by 'op' and return handles to computed
-// tensors in 'retvals'.
+// tensors in `retvals`.
+//
+// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
+// '*num_retvals' should be set to the size of this array. It is an error if
+// the number of outputs is different from *num_retvals.
 //
-// 'retvals' must point to a pre-allocated array of TFE_TensorHandle*
-// and '*num_retvals' should be set to the size of this array.
+// If async execution is enabled, the call may simply enqueue the execution
+// and return "non-ready" handles in `retvals`. Note that any handles contained
+// in 'op' should not be mutated till the kernel execution actually finishes.
 //
-// On return, 'num_retvals' will be set to the actual number of outputs
-// returned by the operation.
+// For sync execution, if any of the inputs to `op` are not ready, this call
+// will block till they become ready and then return when the kernel execution
+// is done.
+// TODO(agarwal): change num_retvals to int from int*.
 TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
                                        int* num_retvals, TF_Status* status);
 
@@ -274,6 +320,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx);
 // Populates the passed-in buffer with a serialized RunMetadata protocol buffer
 // containing any run metadata information accumulated so far and clears this
 // information.
+// If async mode is enabled, this call blocks till all currently pending ops are
+// done.
 TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
                                                         TF_Buffer* buf,
                                                         TF_Status* status);
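Putting the TFE_Execute contract above together with the handle-resolution semantics, a caller under async mode might look roughly like this (illustration only; "MatMul" and the input handles `a` and `b` are placeholders, and every function used is declared in this header or in c_api.h):

// Sketch: run one op with async enabled, then force the (possibly
// non-ready) result. Error handling trimmed for brevity.
void ExampleMatMul(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b,
                   TF_Status* status) {
  TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(matmul, a, status);
  TFE_OpAddInput(matmul, b, status);
  TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(a));

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  // In async mode this may only enqueue the kernel and return a "non-ready"
  // handle in retvals[0].
  TFE_Execute(matmul, retvals, &num_retvals, status);
  TFE_DeleteOp(matmul);
  if (TF_GetCode(status) != TF_OK) return;

  // Resolving blocks until the producing node has executed (or an error is set).
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  TFE_DeleteTensorHandle(retvals[0]);
  if (t != nullptr) TF_DeleteTensor(t);
}
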
tensorflow/c/eager/c_api_internal.h
index 49b9434..8dba12f 100644
@@ -19,7 +19,9 @@ limitations under the License.
 
 #include <algorithm>
 #include <cstddef>
+#include <map>
 #include <memory>
+#include <queue>
 #include <string>
 #include <thread>
 #include <vector>
@@ -31,14 +33,113 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/public/version.h"
 
+// A unit of execution for the TFE_Executor class below. Example subclasses
+// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
+// device to another.
+class TFE_Node {
+ public:
+  explicit TFE_Node(tensorflow::uint64 id);
+
+  virtual ~TFE_Node() {}
+
+  // Runs the computation corresponding to this node and blocks till the
+  // execution is done.
+  virtual tensorflow::Status Run() = 0;
+
+  // An id unique to the TFE_Context under which this node is created. Allocated
+  // monotonically.
+  const tensorflow::uint64 id;
+};
+
+// A class for handling async execution (see TFE_ContextSetAsync).
+// Note that this class is thread-safe.
+// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
+// device of the input handle. Fix that.
+// TODO(agarwal): On error, mark all affected handles as corrupted.
+// TODO(agarwal): Implement support for control dependencies.
+// TODO(agarwal): Support out-of-order execution and dispatching multiple
+// TFE_Node in parallel.
+// TODO(agarwal): Implement optimizations over TFE_Node traces.
+class TFE_Executor {
+ public:
+  ~TFE_Executor();
+
+  // This is called whenever async mode is enabled. Note that it may be called
+  // multiple times as different calling threads may switch async mode on or off
+  // independently.
+  void EnableAsync();
+
+  // Helper function to create monotonically increasing ids unique to this
+  // object.
+  tensorflow::uint64 NextId();
+
+  // Schedules `node` for execution.
+  // Note that Add must be called in monotonically increasing order of node->id.
+  void Add(TFE_Node* node);
+
+  // Causes the caller to block till node with id `node_id` has finished
+  // execution.
+  tensorflow::Status WaitFor(tensorflow::uint64 node_id);
+
+  // Blocks till all currently pending ops are done.
+  tensorflow::Status WaitForAllPendingNodes();
+
+  // Clears all currently set errors which re-enables async execution.
+  void ClearError();
+
+  // Returns Status based on any errors that occurred during async execution.
+  tensorflow::Status status();
+
+ private:
+  // Starts execution of pending TFE_Nodes. This function loops till
+  // thread_done_ is set to true. If any errors are encontered, these are set
+  // inside `status_`. The loop blocks anytime there are no pending nodes, or if
+  // `status_` is not ok.
+  void Run();
+
+  tensorflow::Status WaitImpl(bool wait_all, tensorflow::uint64 node_id);
+
+  tensorflow::mutex node_queue_mutex_;
+
+  // Used to signal that some TFE_Nodes are pending execution.
+  tensorflow::condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
+
+  // Queue of pending TFE_Nodes.
+  std::queue<TFE_Node*> node_queue_ GUARDED_BY(node_queue_mutex_);
+
+  // `status_` is set based on any errors raised during execution of a TFE_Node.
+  // It remains set until ClearError is called.
+  tensorflow::Status status_ GUARDED_BY(node_queue_mutex_);
+
+  // Map from id of a TFE_Node to condition_variables (not owned by the map).
+  // These condition_variables are notified and removed when that TFE_Node is
+  // done executing, or if an error is found in execution of any TFE_Node.
+  std::multimap<tensorflow::uint64, tensorflow::condition_variable*>
+      node_done_notifications_ GUARDED_BY(node_queue_mutex_);
+
+  // Thread object that calls the `Run` method. Currently we use only one thread
+  // for executing the TFE_Nodes one-by-one.
+  std::unique_ptr<tensorflow::Thread> thread_ GUARDED_BY(node_queue_mutex_);
+
+  // Indicates that `thread_` should stop as soon as it is done executing the
+  // current TFE_Node.
+  bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
+
+  tensorflow::mutex next_id_mutex_;
+  tensorflow::uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1;
+};
+
 struct TFE_ContextOptions {
   TF_SessionOptions session_options;
+  // true if async execution is enabled.
+  bool async = false;
   TFE_ContextDevicePlacementPolicy policy{
       TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32};
 };
@@ -60,7 +161,10 @@ struct TFE_Context {
             device_manager.get(), opts.session_options.options.env,
             TF_GRAPH_DEF_VERSION, &func_lib_def, {})),
         log_device_placement(
-            opts.session_options.options.config.log_device_placement()) {}
+            opts.session_options.options.config.log_device_placement()),
+        async_default(opts.async) {
+    if (async_default) executor.EnableAsync();
+  }
 
   const bool soft_placement;
   const TFE_ContextDevicePlacementPolicy policy;
@@ -98,29 +202,99 @@ struct TFE_Context {
   std::atomic<bool> should_store_metadata{false};
   tensorflow::mutex metadata_mu;
   tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
-
   const bool log_device_placement;
+  // TFE_Executor for async execution.
+  TFE_Executor executor;
+
+  // Returns true if running in asynchronous mode.
+  bool Async() const;
+
+  // True if the default execution mode is async. Note that this value can be
+  // overridden per thread via `thread_local_async`.
+  const bool async_default;
+  mutable tensorflow::mutex async_map_mu;
+  std::unordered_map<std::thread::id, bool> thread_local_async
+      GUARDED_BY(async_map_mu);
 };
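Given `async_default` and the per-thread map above, `Async()` presumably reduces to a map lookup with a default, e.g. via gtl::FindWithDefault from map_util.h (which c_api.cc already includes). A sketch, not the committed definition:

bool TFE_Context::Async() const {
  tensorflow::mutex_lock l(async_map_mu);
  // Fall back to the context-wide default when this thread has no override.
  return tensorflow::gtl::FindWithDefault(
      thread_local_async, std::this_thread::get_id(), async_default);
}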
 
-struct TFE_TensorHandle {
+struct TFE_TensorHandle : public tensorflow::core::RefCounted {
+ public:
   TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
                    tensorflow::Device* op_device)
-      : t(t), d(d), op_device(op_device) {}
+      : dtype(t.dtype()),
+        node_id(0),
+        tensor_(t),
+        device_(d),
+        op_device_(op_device),
+        ctx_(nullptr) {}
+
+  TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
+                   TFE_Context* ctx)
+      : dtype(dtype),
+        node_id(node_id),
+        tensor_(dtype),
+        device_(nullptr),
+        op_device_(nullptr),
+        ctx_(ctx) {
+    DCHECK_GT(node_id, 0);
+  }
+
+  ~TFE_TensorHandle() override {}
+
+  tensorflow::Status Tensor(const tensorflow::Tensor** t);
+
+  tensorflow::Status Device(tensorflow::Device** d);
 
-  tensorflow::Tensor t;
-  // TODO(ashankar): d == nullptr iff local CPU
-  // This was expedient, but perhaps worth revisiting ('d' should always be a
-  // valid pointer?)
+  tensorflow::Status OpDevice(tensorflow::Device** d);
+
+  tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor,
+                                     tensorflow::Device** device,
+                                     tensorflow::Device** op_device);
+
+  // Note that this can be called at most once, and only on a non-ready
+  // handle; it makes the handle ready.
+  void SetTensorAndDevice(const tensorflow::Tensor& tensor,
+                          tensorflow::Device* device,
+                          tensorflow::Device* op_device);
+
+  // dtype for the handle. It must be the same as t.dtype() once the handle is
+  // ready.
+  const tensorflow::DataType dtype;
+
+ private:
+  // If the contents of the Tensor pointed to by this handle are yet to be
+  // computed by a TFE_Node, this function blocks until that computation is
+  // done and the handle becomes "ready".
+  tensorflow::Status WaitReady();
+
+  bool IsReady();
+
+  // Id for the TFE_Node that will compute the value pointed to by this handle.
+  // If the value is 0, the handle is already ready; the converse need not
+  // hold.
+  const tensorflow::uint64 node_id;
+
+  tensorflow::Tensor tensor_;
+
+  // TODO(ashankar): device_ == nullptr iff local CPU
+  // This was expedient, but perhaps worth revisiting ('device_' should always
+  // be a valid pointer?)
   // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
   // provided with the appropriate TFE_Context.
   //
-  // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
+  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
   // TFE_TensorHandle does not outlive the TFE_Context from which it came?
-  tensorflow::Device* d;
+  tensorflow::Device* device_;
+
+  // Device on which the op producing this tensor was executed. Equals
+  // device_ for constant tensors.
+  tensorflow::Device* op_device_;
 
-  // Device in which the op producing this tensor was executed. Equals to d for
-  // constant tensors.
-  tensorflow::Device* op_device;
+  tensorflow::mutex ctx_mutex_;
+
+  // `ctx_` is only guaranteed to be set if the handle is not "ready", which is
+  // typically the case when the handle was produced during async execution.
+  // The `ctx_` object is not owned and must outlive this handle.
+  TFE_Context* ctx_ GUARDED_BY(ctx_mutex_);
 };
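To make the "ready" protocol concrete, here is a hedged sketch of how IsReady(), WaitReady() and Tensor() could fit together. It assumes SetTensorAndDevice() clears ctx_ once the value is set (consistent with the comment above) and that TF_RETURN_IF_ERROR / DCHECK are available in c_api.cc; the committed bodies may differ.

bool TFE_TensorHandle::IsReady() {
  tensorflow::mutex_lock l(ctx_mutex_);
  // A cleared ctx_ is taken to mean the tensor and devices have been set.
  return ctx_ == nullptr;
}

tensorflow::Status TFE_TensorHandle::WaitReady() {
  if (node_id == 0) return tensorflow::Status::OK();
  TFE_Context* ctx = nullptr;
  {
    tensorflow::mutex_lock l(ctx_mutex_);
    ctx = ctx_;
  }
  if (ctx == nullptr) return tensorflow::Status::OK();
  // Blocks until the TFE_Node producing this handle has run, or returns the
  // executor's stored error if async execution has failed.
  return ctx->executor.WaitFor(node_id);
}

tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) {
  TF_RETURN_IF_ERROR(WaitReady());
  DCHECK(IsReady());
  *t = &tensor_;
  return tensorflow::Status::OK();
}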
 
 struct TFE_Op {
@@ -129,15 +303,15 @@ struct TFE_Op {
   TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
       : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
 
+  ~TFE_Op();
+
   bool const is_function() const { return attr_types == nullptr; }
 
   TFE_Context* ctx;  // Must outlive the TFE_Op.
   const tensorflow::string name;
   tensorflow::AttrBuilder attrs;
   const tensorflow::AttrTypeMap* attr_types;
-  std::vector<tensorflow::Tensor> inputs;
-  std::vector<tensorflow::Device*> input_devices;
-  std::vector<tensorflow::Device*> input_op_devices;
+  tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs;
   tensorflow::Device* device;
   bool use_xla = false;
 };
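With `inputs` now holding ref-counted TFE_TensorHandle* instead of tensors by value, the newly declared ~TFE_Op() presumably exists to drop those references. A plausible sketch (not necessarily the committed body):

TFE_Op::~TFE_Op() {
  // Release the references taken when the inputs were added to this op.
  for (TFE_TensorHandle* h : inputs) {
    if (h != nullptr) h->Unref();
  }
}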
index 00fb7e6..927d119 100644 (file)
@@ -29,6 +29,20 @@ using tensorflow::string;
 
 namespace {
 
+TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
+  int64_t dims[] = {2, 2};
+  double data[] = {1.0, 2.0, 3.0, 4.0};
+  TF_Tensor* t = TF_AllocateTensor(
+      TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TF_Status* status = TF_NewStatus();
+  TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
 TFE_TensorHandle* TestMatrixTensorHandle() {
   int64_t dims[] = {2, 2};
   float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
@@ -43,6 +57,20 @@ TFE_TensorHandle* TestMatrixTensorHandle() {
   return th;
 }
 
+TFE_TensorHandle* TestMatrixTensorHandle3X2() {
+  int64_t dims[] = {3, 2};
+  float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TF_Tensor* t = TF_AllocateTensor(
+      TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
+  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
+  TF_Status* status = TF_NewStatus();
+  TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   TF_Status* status = TF_NewStatus();
 
@@ -139,10 +167,12 @@ void BM_InitOp(int iters) {
 }
 BENCHMARK(BM_InitOp);
 
-void BM_Execute(int iters) {
+void BM_Execute(int iters, int async) {
   tensorflow::testing::StopTiming();
+  tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute");
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -156,6 +186,9 @@ void BM_Execute(int iters) {
     TFE_Execute(matmul, &retvals[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
+  if (async) {
+    TFE_ContextAsyncWait(ctx, status);
+  }
   tensorflow::testing::StopTiming();
   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(m);
@@ -163,7 +196,7 @@ void BM_Execute(int iters) {
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TF_DeleteStatus(status);
 }
-BENCHMARK(BM_Execute);
+BENCHMARK(BM_Execute)->Arg(0)->Arg(1);
 
 TEST(CAPI, Context) {
   TF_Status* status = TF_NewStatus();
@@ -205,10 +238,11 @@ TEST(CAPI, TensorHandle) {
   TFE_DeleteTensorHandle(h);
 }
 
-TEST(CAPI, TensorHandleCopyBetweenDevices) {
+void TensorHandleCopyBetweenDevices(bool async) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
   TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -274,10 +308,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
 
-TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
+TEST(CAPI, TensorHandleCopyBetweenDevices) {
+  TensorHandleCopyBetweenDevices(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) {
+  TensorHandleCopyBetweenDevices(true);
+}
+
+void TensorHandleCopyBetweenDevicesError(bool async) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  const char* kErrorDevice = "NoSuchDevice:0";
+  TFE_TensorHandle* hdevice =
+      TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
+  EXPECT_NE(TF_OK, TF_GetCode(status.get()));
+  const char* msg = "NoSuchDevice:0 unknown device";
+  EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr)
+      << TF_Message(status.get());
+  TF_SetStatus(status.get(), TF_OK, "");
+  const char* kCPUDevice = "CPU:0";
+  TFE_TensorHandle* hcopy =
+      TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_ContextAsyncWait(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get()));
+  TFE_DeleteTensorHandle(hcopy);
+  TFE_DeleteTensorHandle(hcpu);
+  if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice);
+  TFE_DeleteContext(ctx, status.get());
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesError) {
+  TensorHandleCopyBetweenDevicesError(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) {
+  TensorHandleCopyBetweenDevicesError(true);
+}
+
+void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
   TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -332,11 +412,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
 
-TEST(CAPI, TensorHandleSilentCopy) {
+TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
+  TensorHandleCopyBetweenTwoGPUDevices(false);
+}
+
+TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
+  TensorHandleCopyBetweenTwoGPUDevices(true);
+}
+
+void TensorHandleSilentCopy(bool async) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
   TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@@ -366,14 +455,20 @@ TEST(CAPI, TensorHandleSilentCopy) {
 
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
+  TFE_ContextAsyncWait(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContext(ctx, status.get());
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
 
-TEST(CAPI, TensorHandleSilentCopyLocal) {
+TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
+TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
+
+void TensorHandleSilentCopyLocal(bool async) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                              TFE_DEVICE_PLACEMENT_EXPLICIT);
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
@@ -407,11 +502,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
 
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
+  TFE_ContextAsyncWait(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContext(ctx, status.get());
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
+TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
+TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
+  TensorHandleSilentCopyLocal(true);
+}
 
-TEST(CAPI, SetAndGetOpDevices) {
+void SetAndGetOpDevices(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
@@ -442,27 +543,27 @@ TEST(CAPI, SetAndGetOpDevices) {
   TF_DeleteStatus(status);
 }
 
-TEST(CAPI, Execute_MatMul_CPU) {
+void Execute_MatMul_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
-  TFE_TensorHandle* retvals[2] = {nullptr};
-  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(m);
-  TFE_DeleteContext(ctx, status);
-  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  ASSERT_EQ(1, num_retvals);
 
   TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteTensorHandle(retvals[0]);
+  TFE_DeleteContext(ctx, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   float product[4] = {0};
   EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -474,7 +575,101 @@ TEST(CAPI, Execute_MatMul_CPU) {
   EXPECT_EQ(22, product[3]);
   TF_DeleteStatus(status);
 }
+TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); }
+TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); }
+
+void Execute_MatMul_CPU_Runtime_Error(bool async) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m1 = TestMatrixTensorHandle();
+  TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2();
+  TFE_Op* matmul = MatMulOp(ctx, m1, m2);
+  TFE_Op* matmul2 = MatMulOp(ctx, m1, m1);
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  TFE_DeleteOp(matmul);
+  if (!async) {
+    EXPECT_NE(TF_OK, TF_GetCode(status));
+  } else {
+    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+    EXPECT_NE(TF_OK, TF_GetCode(status));
+    EXPECT_EQ(nullptr, t);
+    const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]";
+    EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr)
+        << TF_Message(status);
+    // Since the error has not been cleared, the following execution (which is
+    // valid on its own) will still fail.
+    TF_SetStatus(status, TF_OK, "");
+    TFE_DeleteTensorHandle(retvals[0]);
+    retvals[0] = nullptr;
+    TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
+    EXPECT_NE(TF_OK, TF_GetCode(status));
+    TFE_ContextAsyncClearError(ctx);
+    TFE_ContextAsyncWait(ctx, status);
+    EXPECT_EQ(TF_OK, TF_GetCode(status));
+  }
+  // Following works in async mode since TFE_ContextAsyncClearError was called.
+  TF_SetStatus(status, TF_OK, "");
+  if (retvals[0] != nullptr) {
+    TFE_DeleteTensorHandle(retvals[0]);
+  }
+  retvals[0] = nullptr;
+  TFE_Execute(matmul2, &retvals[0], &num_retvals, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status));
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status));
+  TF_DeleteTensor(t);
+  TFE_DeleteOp(matmul2);
+  TFE_DeleteTensorHandle(m1);
+  TFE_DeleteTensorHandle(m2);
+  TFE_DeleteTensorHandle(retvals[0]);
+  TFE_DeleteContext(ctx, status);
+  TF_DeleteStatus(status);
+}
+TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) {
+  Execute_MatMul_CPU_Runtime_Error(false);
+}
+TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) {
+  Execute_MatMul_CPU_Runtime_Error(true);
+}
+
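Stripped of assertions, the runtime-error test above encodes the client-facing recovery pattern for async errors. The fragment below is a condensed, slightly rearranged restatement of that test; `ctx`, `status`, `matmul`, `matmul2`, `retvals` and `num_retvals` are assumed to be set up as above, and the deletion/reset of `retvals[0]` between executions is elided.

// In async mode TFE_Execute usually returns before the op has actually run,
// so the shape error surfaces later (when a handle is resolved, or here at an
// explicit sync point).
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
TFE_ContextAsyncWait(ctx, status);        // surfaces the pending async error
if (TF_GetCode(status) != TF_OK) {
  // Until the error is cleared, further async ops on this context keep
  // failing even if they are valid on their own.
  TFE_ContextAsyncClearError(ctx);
  TF_SetStatus(status, TF_OK, "");
}
// After clearing, execution proceeds normally again.
TFE_Execute(matmul2, &retvals[0], &num_retvals, status);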
+void Execute_MatMul_CPU_Type_Error(bool async) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m1 = TestMatrixTensorHandle();
+  TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
+  TFE_Op* matmul = MatMulOp(ctx, m1, m2);
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
+  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+  EXPECT_NE(TF_OK, TF_GetCode(status));
+  TFE_DeleteOp(matmul);
+  TFE_DeleteTensorHandle(m1);
+  TFE_DeleteTensorHandle(m2);
+  if (retvals[0] != nullptr) {
+    TFE_DeleteTensorHandle(retvals[0]);
+  }
+  TFE_DeleteContext(ctx, status);
+  TF_DeleteStatus(status);
+}
 
+TEST(CAPI, Execute_MatMul_CPU_Type_Error) {
+  Execute_MatMul_CPU_Type_Error(false);
+}
+TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) {
+  Execute_MatMul_CPU_Type_Error(true);
+}
 TEST(CAPI, Execute_Min_CPU) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -485,8 +680,8 @@ TEST(CAPI, Execute_Min_CPU) {
   TFE_TensorHandle* input = TestMatrixTensorHandle();
   TFE_TensorHandle* axis = TestAxisTensorHandle();
   TFE_Op* minOp = MinOp(ctx, input, axis);
-  TFE_TensorHandle* retvals[2] = {nullptr};
-  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
   TFE_Execute(minOp, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteOp(minOp);
@@ -509,9 +704,10 @@ TEST(CAPI, Execute_Min_CPU) {
 }
 
 #ifdef TENSORFLOW_EAGER_USE_XLA
-TEST(CAPI, Execute_MatMul_XLA_CPU) {
+void Execute_MatMul_XLA_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -521,15 +717,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) {
 
   TFE_OpSetXLACompilation(matmul, true);
 
-  TFE_TensorHandle* retvals[2] = {nullptr};
-  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
   // Running a primitive TF operator via XLA is not yet supported.
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   TFE_DeleteOp(matmul);
   TFE_DeleteTensorHandle(m);
-  TFE_DeleteContext(ctx, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   EXPECT_EQ(1, num_retvals);
@@ -545,13 +740,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) {
   EXPECT_EQ(10, product[1]);
   EXPECT_EQ(15, product[2]);
   EXPECT_EQ(22, product[3]);
-
+  TFE_DeleteContext(ctx, status);
   TF_DeleteStatus(status);
 }
+TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
+TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
 
-TEST(CAPI, Execute_Min_XLA_CPU) {
+void Execute_Min_XLA_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -562,14 +760,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) {
 
   TFE_OpSetXLACompilation(minOp, true);
 
-  TFE_TensorHandle* retvals[2] = {nullptr};
-  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
   TFE_Execute(minOp, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteOp(minOp);
   TFE_DeleteTensorHandle(input);
   TFE_DeleteTensorHandle(axis);
-  TFE_DeleteContext(ctx, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   ASSERT_EQ(1, num_retvals);
 
@@ -582,13 +779,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) {
   TF_DeleteTensor(t);
   EXPECT_EQ(1, output[0]);
   EXPECT_EQ(3, output[1]);
+  TFE_DeleteContext(ctx, status);
   TF_DeleteStatus(status);
 }
+TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
+TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
 #endif  // TENSORFLOW_EAGER_USE_XLA
 
-TEST(CAPI, ExecuteWithTracing) {
+void ExecuteWithTracing(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   TFE_ContextEnableRunMetadata(ctx);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -596,8 +797,8 @@ TEST(CAPI, ExecuteWithTracing) {
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
-  TFE_TensorHandle* retvals[2] = {nullptr};
-  int num_retvals = 2;  // Should be reduced to 1 by the TFE_Execute call.
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteOp(matmul);
@@ -609,12 +810,12 @@ TEST(CAPI, ExecuteWithTracing) {
   EXPECT_TRUE(
       rm.ParseFromString({reinterpret_cast<const char*>(b->data), b->length}));
   TF_DeleteBuffer(b);
-  TFE_DeleteContext(ctx, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   ASSERT_EQ(1, num_retvals);
 
   TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
   TFE_DeleteTensorHandle(retvals[0]);
+  TFE_DeleteContext(ctx, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   float product[4] = {0};
   EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
@@ -626,6 +827,8 @@ TEST(CAPI, ExecuteWithTracing) {
   EXPECT_EQ(22, product[3]);
   TF_DeleteStatus(status);
 }
+TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); }
+TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); }
 
 TEST(CAPI, Function_ident_CPU) {
   // First create a simple identity function.
@@ -657,32 +860,37 @@ TEST(CAPI, Function_ident_CPU) {
   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
   TF_DeleteFunction(fn);
 
-  TF_Tensor* t =
-      TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
-  *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
-  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteTensor(t);
+  for (bool async : {false, true, false}) {
+    TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
+                                 status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
+    TF_Tensor* t =
+        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TF_DeleteTensor(t);
 
-  TFE_Op* op = TFE_NewOp(ctx, "ident", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_OpAddInput(op, h, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_OpAddInput(op, h, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
 
-  std::vector<TFE_TensorHandle*> result;
-  result.push_back(nullptr);
-  int num_retvals = 1;
-  TFE_Execute(op, result.data(), &num_retvals, status);
-  TFE_DeleteOp(op);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  ASSERT_EQ(num_retvals, 1);
+    std::vector<TFE_TensorHandle*> result;
+    result.push_back(nullptr);
+    int num_retvals = 1;
+    TFE_Execute(op, result.data(), &num_retvals, status);
+    TFE_DeleteOp(op);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    ASSERT_EQ(num_retvals, 1);
 
-  TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
-  TFE_DeleteTensorHandle(h);
-  TF_DeleteTensor(r);
-  TFE_DeleteTensorHandle(result[0]);
+    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+    TFE_DeleteTensorHandle(h);
+    TF_DeleteTensor(r);
+    TFE_DeleteTensorHandle(result[0]);
+  }
   TFE_DeleteContext(ctx, status);
   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
   TF_DeleteStatus(status);
@@ -719,35 +927,40 @@ TEST(CAPI, Function_ident_XLA_CPU) {
   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
   TF_DeleteFunction(fn);
 
-  TF_Tensor* t =
-      TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
-  *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
-  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TF_DeleteTensor(t);
+  for (bool async : {false, true, false}) {
+    TFE_ContextSetAsyncForThread(ctx, static_cast<unsigned char>(async),
+                                 status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK);
+    TF_Tensor* t =
+        TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
+    *reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
+    TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TF_DeleteTensor(t);
 
-  TFE_Op* op = TFE_NewOp(ctx, "ident", status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  TFE_OpAddInput(op, h, status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_Op* op = TFE_NewOp(ctx, "ident", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_OpAddInput(op, h, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
 
-  // Now run it via XLA.
-  TFE_OpSetXLACompilation(op, true);
+    // Now run it via XLA.
+    TFE_OpSetXLACompilation(op, true);
 
-  std::vector<TFE_TensorHandle*> result;
-  result.push_back(nullptr);
-  int num_retvals = 1;
-  TFE_Execute(op, result.data(), &num_retvals, status);
-  TFE_DeleteOp(op);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  ASSERT_EQ(num_retvals, 1);
+    std::vector<TFE_TensorHandle*> result;
+    result.push_back(nullptr);
+    int num_retvals = 1;
+    TFE_Execute(op, result.data(), &num_retvals, status);
+    TFE_DeleteOp(op);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    ASSERT_EQ(num_retvals, 1);
 
-  TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
-  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
-  EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
-  TFE_DeleteTensorHandle(h);
-  TF_DeleteTensor(r);
-  TFE_DeleteTensorHandle(result[0]);
+    TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
+    TFE_DeleteTensorHandle(h);
+    TF_DeleteTensor(r);
+    TFE_DeleteTensorHandle(result[0]);
+  }
   TFE_DeleteContext(ctx, status);
   ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
   TF_DeleteStatus(status);
@@ -788,9 +1001,10 @@ string MatMulFunction() {
   return def.SerializeAsString();
 }
 
-TEST(CAPI, FunctionDefAndExecute) {
+void FunctionDefAndExecute(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -827,11 +1041,16 @@ TEST(CAPI, FunctionDefAndExecute) {
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TF_DeleteStatus(status);
 }
+TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); }
+TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); }
 
-void BM_ExecuteFunction(int iters) {
+void BM_ExecuteFunction(int iters, int async) {
   tensorflow::testing::StopTiming();
+  tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync"
+                                      : "ExecuteFunction");
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TFE_DeleteContextOptions(opts);
@@ -853,6 +1072,9 @@ void BM_ExecuteFunction(int iters) {
     TFE_Execute(matmul, &retval[0], &num_retvals, status);
     CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   }
+  if (async) {
+    TFE_ContextAsyncWait(ctx, status);
+  }
   tensorflow::testing::StopTiming();
   TFE_DeleteTensorHandle(m);
   TFE_DeleteTensorHandle(retval[0]);
@@ -860,7 +1082,7 @@ void BM_ExecuteFunction(int iters) {
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
   TF_DeleteStatus(status);
 }
-BENCHMARK(BM_ExecuteFunction);
+BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1);
 
 TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
                                  TF_Status* status) {
index 985ed96..ad16f65 100644 (file)
@@ -185,7 +185,8 @@ class KernelAndDevice {
 
   Device* device() const { return device_; }
 
-  DataTypeVector* output_dtypes() { return &output_dtypes_; }
+  DataTypeVector* mutable_output_dtypes() { return &output_dtypes_; }
+  const DataTypeVector& output_dtypes() { return output_dtypes_; }
 
  private:
   std::unique_ptr<OpKernel> kernel_;
index d504ca0..012c68f 100644 (file)
@@ -250,13 +250,23 @@ class TFETest(test_util.TensorFlowTestCase):
 
   def testExecuteTooManyNumOutputs(self):
     # num_outputs provided is 50, but only one output is produced.
-    # That should be okay.
-    product = execute(
-        b'Mul',
-        num_outputs=50,
-        inputs=[constant_op.constant(3), constant_op.constant(5)],
-        attrs=('T', dtypes.int32.as_datatype_enum))[0]
-    self.assertAllEqual(15, product)
+    with self.assertRaises(errors.InvalidArgumentError):
+      _ = execute(
+          b'Mul',
+          num_outputs=50,
+          inputs=[constant_op.constant(3),
+                  constant_op.constant(5)],
+          attrs=('T', dtypes.int32.as_datatype_enum))[0]
+
+  def testExecuteTooFewNumOutputs(self):
+    # num_outputs provided is 0, but one output is produced.
+    with self.assertRaises(errors.InvalidArgumentError):
+      _ = execute(
+          b'Mul',
+          num_outputs=0,
+          inputs=[constant_op.constant(3),
+                  constant_op.constant(5)],
+          attrs=('T', dtypes.int32.as_datatype_enum))[0]
 
   def testMatMulGPU(self):
     if not context.context().num_gpus():
index 8338bc4..105c09e 100644 (file)
@@ -340,8 +340,10 @@ void EagerTensor_dealloc(EagerTensor* self) {
   Py_DECREF(self->handle_data);
   Py_DECREF(self->keras_mask);
   Py_DECREF(self->tensor_shape);
-  TFE_DeleteTensorHandle(self->handle);
-  self->handle = nullptr;
+  if (self->handle != nullptr) {
+    TFE_DeleteTensorHandle(self->handle);
+    self->handle = nullptr;
+  }
   // We have the global interpreter lock, so use this chance to perform delayed
   // refcount decrements.
   tensorflow::ClearDecrefCache();
index fcb0452..fe9785d 100644 (file)
@@ -1012,7 +1012,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
     tensorflow::int64 id = EagerTensor_id(tensor);
-    return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
+    const tensorflow::Tensor* tensor_value = nullptr;
+    const tensorflow::Status status = t->Tensor(&tensor_value);
+    if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+      return tensorflow::eager::TapeTensor{id, t->dtype,
+                                           tensorflow::TensorShape({})};
+    } else {
+      return tensorflow::eager::TapeTensor{id, t->dtype,
+                                           tensor_value->shape()};
+    }
   }
   tensorflow::int64 id = FastTensorId(tensor);
   if (PyErr_Occurred()) {
index 343415b..02eafd4 100644 (file)
@@ -164,9 +164,9 @@ bool IsSingleNone(PyObject* obj) {
 }
 
 // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
-void ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
-                                  Tensor* output_tensor) {
-  *output_tensor = EagerTensor_Handle(eager_tensor)->t;
+tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+                                                const Tensor** output_tensor) {
+  return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor);
 }
 
 // Calls the registered py function through the trampoline.
@@ -220,7 +220,9 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
       if (call->eager) {
         const PyObject* item = PyList_GetItem(result, i);
         if (EagerTensor_CheckExact(item)) {
-          ExtractTensorFromEagerTensor(item, &t);
+          const Tensor* tensor = nullptr;
+          s = ExtractTensorFromEagerTensor(item, &tensor);
+          if (s.ok()) t = *tensor;
         } else {
           s = errors::FailedPrecondition(
               "Expected EagerTensor, found PyObject of type: ",
@@ -238,10 +240,10 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
   } else if (EagerTensor_CheckExact(result) || result == Py_None) {
     // result is an `EagerTensor` or `None`.
     DCHECK(call->eager);
-    Tensor t;
     if (result != Py_None) {
-      ExtractTensorFromEagerTensor(result, &t);
-      call->out.push_back(t);
+      const Tensor* t = nullptr;
+      s = ExtractTensorFromEagerTensor(result, &t);
+      if (s.ok()) call->out.push_back(*t);
     }
   } else if (PyArray_Check(result)) {
     // `result` is a NumPy array.