Moves TFE_Executor to tensorflow::EagerExecutor in tensorflow/core/common_runtime...
authorAlexandre Passos <apassos@google.com>
Mon, 19 Mar 2018 20:43:50 +0000 (13:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 20:51:08 +0000 (13:51 -0700)
PiperOrigin-RevId: 189634404

tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api_internal.h
tensorflow/core/BUILD
tensorflow/core/common_runtime/eager/eager_executor.cc [new file with mode: 0644]
tensorflow/core/common_runtime/eager/eager_executor.h [new file with mode: 0644]

index 3046d90..73a3450 100644 (file)
@@ -27,6 +27,7 @@ tf_cuda_library(
             ":runtime",
             "//tensorflow/c:c_api",
             "//tensorflow/c:c_api_internal",
+            "//tensorflow/core:core_cpu",
             "//tensorflow/core:core_cpu_internal",
             "//tensorflow/core:framework",
             "//tensorflow/core:framework_internal",
@@ -54,6 +55,7 @@ tf_cuda_library(
         ":runtime",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
+        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
index 455bc19..4e5703f 100644 (file)
@@ -165,7 +165,7 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
 
 // Note: this function looks up a thread local policy. So it should be called in
 // the appropriate client thread. In particular, in async mode, it may not be
-// safe to call this function from the async TFE_Executor threads.
+// safe to call this function from the async EagerExecutor threads.
 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
     TFE_Context* ctx) {
   tensorflow::mutex_lock ml(ctx->policy_map_mu);
@@ -731,15 +731,15 @@ tensorflow::Status Execute(
   return tensorflow::Status::OK();
 }
 
-// TODO(agarwal): move TFE_Executor and TFE_Node related code to a separate
+// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate
 // file.
-class ExecuteNode : public TFE_Node {
+class ExecuteNode : public tensorflow::EagerNode {
  public:
   ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel,
               tensorflow::NodeExecStats* maybe_stats,
               const tensorflow::DataTypeVector& output_dtypes,
               TFE_TensorHandle** retvals, int num_retvals)
-      : TFE_Node(op->ctx->executor.NextId()),
+      : tensorflow::EagerNode(op->ctx->executor.NextId()),
         ctx_(op->ctx),
         op_device_(op->device),
         inputs_(op->inputs),
@@ -791,11 +791,11 @@ class ExecuteNode : public TFE_Node {
   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_;
 };
 
-class CopyToDeviceNode : public TFE_Node {
+class CopyToDeviceNode : public tensorflow::EagerNode {
  public:
   CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
                    TFE_Context* ctx)
-      : TFE_Node(ctx->executor.NextId()),
+      : tensorflow::EagerNode(ctx->executor.NextId()),
         src_(src),
         dstd_(dstd),
         ctx_(ctx),
@@ -1182,8 +1182,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // Note that for async mode, execution order will make sure that all
     // input handles are ready before executing them.
     // TODO(agarwal): Consider executing "cheap" kernels inline for performance.
-    TFE_Node* node = new ExecuteNode(op, kernel, maybe_stats.release(),
-                                     output_dtypes, retvals, *num_retvals);
+    tensorflow::EagerNode* node =
+        new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
+                        retvals, *num_retvals);
     ctx->executor.Add(node);
   } else {
     // Execute checks if retvals[i] is nullptr or not to figure if it needs to
@@ -1214,8 +1215,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
     // make sure that `h` is ready before the copy is actually done.
     CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
     TFE_TensorHandle* output = node->dst();
-    // Note that calling Add makes `node` accessible by the TFE_Executor thread.
-    // So further accesses need to be thread-safe.
+    // Note that calling Add makes `node` accessible by the EagerExecutor
+    // thread. So further accesses need to be thread-safe.
     ctx->executor.Add(node);
     return output;
   } else {
@@ -1356,137 +1357,6 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
 }
 }  // namespace tensorflow
 
-TFE_Node::TFE_Node(tensorflow::uint64 id) : id(id) {}
-
-TFE_Executor::~TFE_Executor() {
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  thread_done_ = true;
-  nodes_pending_.notify_all();
-}
-
-tensorflow::uint64 TFE_Executor::NextId() {
-  tensorflow::mutex_lock l(next_id_mutex_);
-  return next_id_++;
-}
-
-void TFE_Executor::EnableAsync() {
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  if (thread_ == nullptr) {
-    thread_.reset(tensorflow::Env::Default()->StartThread(
-        tensorflow::ThreadOptions(), "eager_async_executor",
-        std::bind(&TFE_Executor::Run, this)));
-  }
-}
-
-void TFE_Executor::Add(TFE_Node* node) {
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  DCHECK(thread_) << "EnableAsync should have been called before Add";
-  if (!status_.ok()) {
-    delete node;
-    return;
-  }
-  int qlen = node_queue_.size();
-  if (qlen > 0) {
-    if (node_queue_.back()->id >= node->id) {
-      status_ = tensorflow::errors::InvalidArgument(
-          "Inserting TFE_Node with non-increasing ids:", node_queue_.back()->id,
-          " vs ", node->id);
-      delete node;
-      return;
-    }
-    node_queue_.push(node);
-  } else {
-    node_queue_.push(node);
-    nodes_pending_.notify_all();
-  }
-}
-
-tensorflow::Status TFE_Executor::WaitFor(tensorflow::uint64 node_id) {
-  return WaitImpl(false, node_id);
-}
-
-tensorflow::Status TFE_Executor::WaitForAllPendingNodes() {
-  return WaitImpl(true, 0);
-}
-
-tensorflow::Status TFE_Executor::WaitImpl(bool wait_all,
-                                          tensorflow::uint64 node_id) {
-  tensorflow::condition_variable cond;
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  // Don't wait if an error is already set.
-  if (!status_.ok()) return status_;
-  if (node_queue_.empty()) return tensorflow::Status::OK();
-  if (wait_all) {
-    node_id = node_queue_.back()->id;
-  } else if (node_id < node_queue_.front()->id) {
-    // Note that we are relying on the ops being dispatched sequentially from
-    // the queue.
-    return tensorflow::Status::OK();
-  }
-  node_done_notifications_.insert(std::make_pair(node_id, &cond));
-  cond.wait(l);
-  // Note that we could be woken up if an error occurs, even though the node has
-  // not actually executed.
-  return status_;
-}
-
-void TFE_Executor::ClearError() {
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  if (status_.ok()) return;
-  // If an error was set, node_done_notifications_ and node_queue_ should have
-  // been cleared, and no new entries should have been added since.
-  DCHECK(node_done_notifications_.empty());
-  DCHECK(node_queue_.empty());
-  status_ = tensorflow::Status::OK();
-  nodes_pending_.notify_all();
-}
-
-tensorflow::Status TFE_Executor::status() {
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  return status_;
-}
-
-void TFE_Executor::Run() {
-  while (true) {
-    std::unique_ptr<TFE_Node> curr_node;
-    {
-      tensorflow::mutex_lock l(node_queue_mutex_);
-      while (node_queue_.empty() || !status_.ok()) {
-        if (thread_done_) return;
-        nodes_pending_.wait(l);
-      }
-      curr_node.reset(node_queue_.front());
-    }
-    tensorflow::Status status = curr_node->Run();
-    const bool ok = status.ok();
-    tensorflow::mutex_lock l(node_queue_mutex_);
-    node_queue_.pop();
-    if (!ok) {
-      status_ = status;
-      // TODO(agarwal): mark all affected handles as corrupted before clearing
-      // this queue.
-      // We remove any pending ops so that we don't try to execute them if
-      // ClearError is called.
-      for (int i = 0; i < node_queue_.size(); ++i) {
-        delete node_queue_.front();
-        node_queue_.pop();
-      }
-    }
-    if (!node_done_notifications_.empty()) {
-      tensorflow::uint64 node_id = curr_node->id;
-      // Note that we notify all waiting threads in case an error has occurred.
-      // These calling threads are responsible for checking status_ before
-      // proceeding.
-      const auto range = ok ? node_done_notifications_.equal_range(node_id)
-                            : make_pair(node_done_notifications_.begin(),
-                                        node_done_notifications_.end());
-      for (auto it = range.first; it != range.second; ++it) {
-        it->second->notify_all();
-      }
-      node_done_notifications_.erase(range.first, range.second);
-    }
-  }
-}
 
 bool TFE_Context::Async() const {
   tensorflow::mutex_lock l(async_map_mu);
@@ -1502,7 +1372,7 @@ bool TFE_TensorHandle::IsReady() {
 
 tensorflow::Status TFE_TensorHandle::WaitReady() {
   if (node_id == 0) return tensorflow::Status::OK();
-  TFE_Executor* executor = nullptr;
+  tensorflow::EagerExecutor* executor = nullptr;
   {
     tensorflow::mutex_lock l(ctx_mutex_);
     if (ctx_ == nullptr) return tensorflow::Status::OK();
index 8dba12f..1edbe81 100644 (file)
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/runtime.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
@@ -40,101 +41,6 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/public/version.h"
 
-// A unit of execution for the TFE_Executor class below. Example subclasses
-// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
-// device to another.
-class TFE_Node {
- public:
-  explicit TFE_Node(tensorflow::uint64 id);
-
-  virtual ~TFE_Node() {}
-
-  // Runs the computation corresponding to this node and blocks till the
-  // execution is done.
-  virtual tensorflow::Status Run() = 0;
-
-  // An id unique to the TFE_Context under which this node is created. Allocated
-  // monotonically.
-  const tensorflow::uint64 id;
-};
-
-// A class for handling async execution (see TFE_ContextSetAsync).
-// Note that this class is thread-safe.
-// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
-// device of the input handle. Fix that.
-// TODO(agarwal): On error, mark all affected handles as corrupted.
-// TODO(agarwal): Implement support for control dependencies.
-// TODO(agarwal): Support out-of-order execution and dispatching multiple
-// TFE_Node in parallel.
-// TODO(agarwal): Implement optimizations over TFE_Node traces.
-class TFE_Executor {
- public:
-  ~TFE_Executor();
-
-  // This is called whenever async mode is enabled. Note that it may be called
-  // multiple times as different calling threads may switch async mode on or off
-  // independently.
-  void EnableAsync();
-
-  // Helper function to create monotonically increasing ids unique to this
-  // object.
-  tensorflow::uint64 NextId();
-
-  // Schedules `node` for execution.
-  // Note that Add must be called in monotonically increasing order of node->id.
-  void Add(TFE_Node* node);
-
-  // Causes the caller to block till node with id `node_id` has finished
-  // execution.
-  tensorflow::Status WaitFor(tensorflow::uint64 node_id);
-
-  // Blocks till all currently pending ops are done.
-  tensorflow::Status WaitForAllPendingNodes();
-
-  // Clears all currently set errors which re-enables async execution.
-  void ClearError();
-
-  // Returns Status based on any errors that occurred during async execution.
-  tensorflow::Status status();
-
- private:
-  // Starts execution of pending TFE_Nodes. This function loops till
-  // thread_done_ is set to true. If any errors are encontered, these are set
-  // inside `status_`. The loop blocks anytime there are no pending nodes, or if
-  // `status_` is not ok.
-  void Run();
-
-  tensorflow::Status WaitImpl(bool wait_all, tensorflow::uint64 node_id);
-
-  tensorflow::mutex node_queue_mutex_;
-
-  // Used to signal that some TFE_Nodes are pending execution.
-  tensorflow::condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
-
-  // Queue of pending TFE_Nodes.
-  std::queue<TFE_Node*> node_queue_ GUARDED_BY(node_queue_mutex_);
-
-  // `status_` is set based on any errors raised during execution of a TFE_Node.
-  // It remains set until ClearError is called.
-  tensorflow::Status status_ GUARDED_BY(node_queue_mutex_);
-
-  // Map from id of a TFE_Node to condition_variables (not owned by the map).
-  // These condition_variables are notified and removed when that TFE_Node is
-  // done executing, or if an error is found in execution of any TFE_Node.
-  std::multimap<tensorflow::uint64, tensorflow::condition_variable*>
-      node_done_notifications_ GUARDED_BY(node_queue_mutex_);
-
-  // Thread object that calls the `Run` method. Currently we use only one thread
-  // for executing the TFE_Nodes one-by-one.
-  std::unique_ptr<tensorflow::Thread> thread_ GUARDED_BY(node_queue_mutex_);
-
-  // Indicates that `thread_` should stop as soon as it is done executing the
-  // current TFE_Node.
-  bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
-
-  tensorflow::mutex next_id_mutex_;
-  tensorflow::uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1;
-};
 
 struct TFE_ContextOptions {
   TF_SessionOptions session_options;
@@ -203,8 +109,8 @@ struct TFE_Context {
   tensorflow::mutex metadata_mu;
   tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu);
   const bool log_device_placement;
-  // TFE_Executor for async execution.
-  TFE_Executor executor;
+  // EagerExecutor for async execution.
+  tensorflow::EagerExecutor executor;
 
   // True if running in asynchronous mode.
   bool Async() const;
@@ -263,13 +169,13 @@ struct TFE_TensorHandle : public tensorflow::core::RefCounted {
 
  private:
   // If the contents of the Tensor pointed to by this handle is yet to be
-  // computed by a TFE_Node, this function will block till that compuatation is
+  // computed by a EagerNode, this function will block till that compuatation is
   // done and the handle is "ready".
   tensorflow::Status WaitReady();
 
   bool IsReady();
 
-  // Id for the TFE_Node that will compute the value pointed to by this handle.
+  // Id for the EagerNode that will compute the value pointed to by this handle.
   // If the value is 0, the handle is already ready, but not vice-versa.
   const tensorflow::uint64 node_id;
 
index df44857..cf29444 100644 (file)
@@ -793,6 +793,7 @@ tf_cuda_library(
     hdrs = [
         "common_runtime/device.h",
         "common_runtime/device_factory.h",
+        "common_runtime/eager/eager_executor.h",
         "common_runtime/optimization_registry.h",
         "common_runtime/shape_refiner.h",
         "graph/algorithm.h",
@@ -2141,6 +2142,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/stats_publisher_interface.h",
     "common_runtime/step_stats_collector.h",
     "common_runtime/threadpool_device.h",
+    "common_runtime/eager/eager_executor.h",
     "graph/gradients.h",
     "graph/quantize_training.h",
 ] + if_mkl(["graph/mkl_graph_util.h"])
@@ -2160,6 +2162,7 @@ tf_cuda_library(
         "common_runtime/device_factory.cc",
         "common_runtime/device_mgr.cc",
         "common_runtime/device_set.cc",
+        "common_runtime/eager/eager_executor.cc",
         "common_runtime/executor.cc",
         "common_runtime/function.cc",
         "common_runtime/graph_optimizer.cc",
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
new file mode 100644 (file)
index 0000000..b699036
--- /dev/null
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+
+namespace tensorflow {
+
+EagerNode::EagerNode(tensorflow::uint64 id) : id(id) {}
+
+EagerExecutor::~EagerExecutor() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  thread_done_ = true;
+  nodes_pending_.notify_all();
+}
+
+tensorflow::uint64 EagerExecutor::NextId() {
+  tensorflow::mutex_lock l(next_id_mutex_);
+  return next_id_++;
+}
+
+void EagerExecutor::EnableAsync() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (thread_ == nullptr) {
+    thread_.reset(tensorflow::Env::Default()->StartThread(
+        tensorflow::ThreadOptions(), "eager_async_executor",
+        std::bind(&EagerExecutor::Run, this)));
+  }
+}
+
+void EagerExecutor::Add(EagerNode* node) {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  DCHECK(thread_) << "EnableAsync should have been called before Add";
+  if (!status_.ok()) {
+    delete node;
+    return;
+  }
+  int64 qlen = node_queue_.size();
+  if (qlen > 0) {
+    if (node_queue_.back()->id >= node->id) {
+      status_ = tensorflow::errors::InvalidArgument(
+          "Inserting EagerNode with non-increasing ids:",
+          node_queue_.back()->id, " vs ", node->id);
+      delete node;
+      return;
+    }
+    node_queue_.push(node);
+  } else {
+    node_queue_.push(node);
+    nodes_pending_.notify_all();
+  }
+}
+
+tensorflow::Status EagerExecutor::WaitFor(tensorflow::uint64 node_id) {
+  return WaitImpl(false, node_id);
+}
+
+tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
+  return WaitImpl(true, 0);
+}
+
+tensorflow::Status EagerExecutor::WaitImpl(bool wait_all,
+                                           tensorflow::uint64 node_id) {
+  tensorflow::condition_variable cond;
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  // Don't wait if an error is already set.
+  if (!status_.ok()) return status_;
+  if (node_queue_.empty()) return tensorflow::Status::OK();
+  if (wait_all) {
+    node_id = node_queue_.back()->id;
+  } else if (node_id < node_queue_.front()->id) {
+    // Note that we are relying on the ops being dispatched sequentially from
+    // the queue.
+    return tensorflow::Status::OK();
+  }
+  node_done_notifications_.insert(std::make_pair(node_id, &cond));
+  cond.wait(l);
+  // Note that we could be woken up if an error occurs, even though the node has
+  // not actually executed.
+  return status_;
+}
+
+void EagerExecutor::ClearError() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (status_.ok()) return;
+  // If an error was set, node_done_notifications_ and node_queue_ should have
+  // been cleared, and no new entries should have been added since.
+  DCHECK(node_done_notifications_.empty());
+  DCHECK(node_queue_.empty());
+  status_ = tensorflow::Status::OK();
+  nodes_pending_.notify_all();
+}
+
+tensorflow::Status EagerExecutor::status() {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  return status_;
+}
+
+void EagerExecutor::Run() {
+  while (true) {
+    std::unique_ptr<EagerNode> curr_node;
+    {
+      tensorflow::mutex_lock l(node_queue_mutex_);
+      while (node_queue_.empty() || !status_.ok()) {
+        if (thread_done_) return;
+        nodes_pending_.wait(l);
+      }
+      curr_node.reset(node_queue_.front());
+    }
+    tensorflow::Status status = curr_node->Run();
+    const bool ok = status.ok();
+    tensorflow::mutex_lock l(node_queue_mutex_);
+    node_queue_.pop();
+    if (!ok) {
+      status_ = status;
+      // TODO(agarwal): mark all affected handles as corrupted before clearing
+      // this queue.
+      // We remove any pending ops so that we don't try to execute them if
+      // ClearError is called.
+      for (int i = 0; i < node_queue_.size(); ++i) {
+        delete node_queue_.front();
+        node_queue_.pop();
+      }
+    }
+    if (!node_done_notifications_.empty()) {
+      tensorflow::uint64 node_id = curr_node->id;
+      // Note that we notify all waiting threads in case an error has occurred.
+      // These calling threads are responsible for checking status_ before
+      // proceeding.
+      const auto range = ok ? node_done_notifications_.equal_range(node_id)
+                            : make_pair(node_done_notifications_.begin(),
+                                        node_done_notifications_.end());
+      for (auto it = range.first; it != range.second; ++it) {
+        it->second->notify_all();
+      }
+      node_done_notifications_.erase(range.first, range.second);
+    }
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
new file mode 100644 (file)
index 0000000..021daeb
--- /dev/null
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// A unit of execution for the EagerExecutor class below. Example subclasses
+// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one
+// device to another.
+class EagerNode {
+ public:
+  explicit EagerNode(uint64 id);
+
+  virtual ~EagerNode() {}
+
+  // Runs the computation corresponding to this node and blocks till the
+  // execution is done.
+  virtual Status Run() = 0;
+
+  // An id unique to the TFE_Context under which this node is created. Allocated
+  // monotonically.
+  const uint64 id;
+};
+
+// A class for handling async execution (see TFE_ContextSetAsync).
+// Note that this class is thread-safe.
+// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the
+// device of the input handle. Fix that.
+// TODO(agarwal): On error, mark all affected handles as corrupted.
+// TODO(agarwal): Implement support for control dependencies.
+// TODO(agarwal): Support out-of-order execution and dispatching multiple
+// EagerNode in parallel.
+// TODO(agarwal): Implement optimizations over EagerNode traces.
+class EagerExecutor {
+ public:
+  ~EagerExecutor();
+
+  // This is called whenever async mode is enabled. Note that it may be called
+  // multiple times as different calling threads may switch async mode on or off
+  // independently.
+  void EnableAsync();
+
+  // Helper function to create monotonically increasing ids unique to this
+  // object.
+  uint64 NextId();
+
+  // Schedules `node` for execution.
+  // Note that Add must be called in monotonically increasing order of node->id.
+  void Add(EagerNode* node);
+
+  // Causes the caller to block till node with id `node_id` has finished
+  // execution.
+  Status WaitFor(uint64 node_id);
+
+  // Blocks till all currently pending ops are done.
+  Status WaitForAllPendingNodes();
+
+  // Clears all currently set errors which re-enables async execution.
+  void ClearError();
+
+  // Returns Status based on any errors that occurred during async execution.
+  Status status();
+
+ private:
+  // Starts execution of pending EagerNodes. This function loops till
+  // thread_done_ is set to true. If any errors are encontered, these are set
+  // inside `status_`. The loop blocks anytime there are no pending nodes, or if
+  // `status_` is not ok.
+  void Run();
+
+  Status WaitImpl(bool wait_all, uint64 node_id);
+
+  mutex node_queue_mutex_;
+
+  // Used to signal that some EagerNodes are pending execution.
+  condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
+
+  // Queue of pending EagerNodes.
+  std::queue<EagerNode*> node_queue_ GUARDED_BY(node_queue_mutex_);
+
+  // `status_` is set based on any errors raised during execution of a
+  // EagerNode.  It remains set until ClearError is called.
+  Status status_ GUARDED_BY(node_queue_mutex_);
+
+  // Map from id of a EagerNode to condition_variables (not owned by the map).
+  // These condition_variables are notified and removed when that EagerNode is
+  // done executing, or if an error is found in execution of any EagerNode.
+  std::multimap<uint64, condition_variable*> node_done_notifications_
+      GUARDED_BY(node_queue_mutex_);
+
+  // Thread object that calls the `Run` method. Currently we use only one thread
+  // for executing the EagerNodes one-by-one.
+  std::unique_ptr<Thread> thread_ GUARDED_BY(node_queue_mutex_);
+
+  // Indicates that `thread_` should stop as soon as it is done executing the
+  // current EagerNode.
+  bool thread_done_ GUARDED_BY(node_queue_mutex_) = false;
+
+  mutex next_id_mutex_;
+  uint64 next_id_ GUARDED_BY(next_id_mutex_) = 1;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_