Uses a thread pool for graph functions in eager mode with inter_op_parallelism_threads.
authorAlexandre Passos <apassos@google.com>
Mon, 26 Feb 2018 23:37:40 +0000 (15:37 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187092622

tensorflow/c/eager/BUILD
tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api_internal.h
tensorflow/c/eager/runtime.cc
tensorflow/c/eager/runtime.h
tensorflow/c/eager/runtime_test.cc

index e55cb67..16a2a15 100644 (file)
@@ -21,6 +21,7 @@ tf_cuda_library(
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
+            "//tensorflow/core:lib",
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//conditions:default": [
index bebb63c..b233dd5 100644 (file)
@@ -818,8 +818,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // See WARNING comment below - would be nice to rework to avoid this
     // subtlety.
     tensorflow::tf_shared_lock l(ctx->functions_mu);
-    status->status =
-        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+    status->status = tensorflow::KernelAndDevice::Init(
+        ndef, ctx->func_lib(device), &ctx->runner, kernel);
     if (!status->status.ok()) {
       delete kernel;
       return;
index 3356054..29944df 100644 (file)
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -45,7 +46,15 @@ struct TFE_ContextOptions {
 
 struct TFE_Context {
   explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
-      : policy(opts.policy),
+      : thread_pool(new tensorflow::thread::ThreadPool(
+            opts.session_options.options.env, "EagerCompute",
+            opts.session_options.options.config
+                        .inter_op_parallelism_threads() != 0
+                ? opts.session_options.options.config
+                      .inter_op_parallelism_threads()
+                : tensorflow::port::NumSchedulableCPUs())),
+        runner([this](std::function<void()> f) { thread_pool->Schedule(f); }),
+        policy(opts.policy),
         session(s),
         rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
         pflr(new tensorflow::ProcessFunctionLibraryRuntime(
@@ -54,6 +63,9 @@ struct TFE_Context {
         log_device_placement(
             opts.session_options.options.config.log_device_placement()) {}
 
+  const std::unique_ptr<tensorflow::thread::ThreadPool> thread_pool;
+  std::function<void(std::function<void()>)> runner;
+
   const TFE_ContextDevicePlacementPolicy policy;
 
   // Note: we cannot use C++11 thread_local here as there is no concept of a
index 4bf24fe..b961842 100644 (file)
@@ -255,17 +255,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
   out->device_ = device;
   out->kernel_.reset(k);
   out->flib_ = nullptr;
+  out->runner_ = nullptr;
+  out->default_runner_ = [](std::function<void()> f) { f(); };
   return s;
 }
 
 // static
 Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                             std::function<void(std::function<void()>)>* runner,
                              KernelAndDevice* out) {
   OpKernel* k = nullptr;
   Status s = flib->CreateKernel(ndef, &k);
   out->device_ = flib->device();
   out->kernel_.reset(k);
   out->flib_ = flib;
+  out->runner_ = runner;
+  out->default_runner_ = [](std::function<void()> f) { f(); };
   return s;
 }
 
@@ -296,10 +301,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
   if (stats != nullptr) {
     params.track_allocations = true;
   }
-  // TODO(apassos): use a thread pool.
-  std::function<void(std::function<void()>)> runner =
-      [](std::function<void()> f) { f(); };
-  params.runner = &runner;
+  if (runner_ == nullptr) {
+    params.runner = &default_runner_;
+  } else {
+    params.runner = runner_;
+  }
 
   OpKernelContext context(&params);
   device_->Compute(kernel_.get(), &context);
index 7fede4d..fa5f839 100644 (file)
@@ -169,6 +169,7 @@ class KernelAndDevice {
   // the FunctionLibraryRuntime is pushed on to the caller (see locking in
   // c_api.cc).
   static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                     std::function<void(std::function<void()>)>* runner,
                      KernelAndDevice* out);
   // TODO(ashankar): Remove this
   static Status InitOp(Device* device, const NodeDef& ndef,
@@ -188,6 +189,8 @@ class KernelAndDevice {
  private:
   std::unique_ptr<OpKernel> kernel_;
   Device* device_;
+  std::function<void(std::function<void()>)>* runner_;
+  std::function<void(std::function<void()>)> default_runner_;
   FunctionLibraryRuntime* flib_;
   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
   Rendezvous* rendez_;
index 6431530..ab0b535 100644 (file)
@@ -92,8 +92,8 @@ TEST(KernelAndDevice, Run) {
                    .BuildNodeDef());
   TestEnv env;
   KernelAndDevice kernel(nullptr);
-  Status s =
-      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
+  Status s = KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                   nullptr, &kernel);
   ASSERT_TRUE(s.ok()) << s;
   std::vector<Tensor> outputs;
   s = kernel.Run(&inputs, &outputs, nullptr);
@@ -158,8 +158,8 @@ void BM_KernelAndDeviceInit(int iters) {
   KernelAndDevice k(nullptr);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(
-        KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
+    TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                      nullptr, &k));
   }
 }
 BENCHMARK(BM_KernelAndDeviceInit);
@@ -179,8 +179,8 @@ void BM_KernelAndDeviceRun(int iters) {
                    .BuildNodeDef());
   TestEnv env;
   KernelAndDevice kernel(nullptr);
-  TF_CHECK_OK(
-      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
+  TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                    nullptr, &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr));