From acf78b20f71dd8c3a928b1f12ea4de6f5028fc48 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Mon, 26 Feb 2018 15:37:40 -0800
Subject: [PATCH] Uses a thread pool for graph functions in eager mode with
 inter_op_parallelism_threads.

PiperOrigin-RevId: 187092622
---
 tensorflow/c/eager/BUILD            |  1 +
 tensorflow/c/eager/c_api.cc         |  4 ++--
 tensorflow/c/eager/c_api_internal.h | 14 +++++++++++++-
 tensorflow/c/eager/runtime.cc       | 14 ++++++++++----
 tensorflow/c/eager/runtime.h        |  3 +++
 tensorflow/c/eager/runtime_test.cc  | 12 ++++++------
 6 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index e55cb67..16a2a15 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -21,6 +21,7 @@ tf_cuda_library(
     visibility = ["//visibility:public"],
     deps = select({
         "//tensorflow:android": [
+            "//tensorflow/core:lib",
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//conditions:default": [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index bebb63c7..b233dd5 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -818,8 +818,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     // See WARNING comment below - would be nice to rework to avoid this
     // subtlety.
     tensorflow::tf_shared_lock l(ctx->functions_mu);
-    status->status =
-        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
+    status->status = tensorflow::KernelAndDevice::Init(
+        ndef, ctx->func_lib(device), &ctx->runner, kernel);
     if (!status->status.ok()) {
       delete kernel;
       return;
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 3356054..29944df 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -45,7 +46,15 @@ struct TFE_ContextOptions {
 
 struct TFE_Context {
   explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
-      : policy(opts.policy),
+      : thread_pool(new tensorflow::thread::ThreadPool(
+            opts.session_options.options.env, "EagerCompute",
+            opts.session_options.options.config
+                        .inter_op_parallelism_threads() != 0
+                ? opts.session_options.options.config
+                      .inter_op_parallelism_threads()
+                : tensorflow::port::NumSchedulableCPUs())),
+        runner([this](std::function<void()> f) { thread_pool->Schedule(f); }),
+        policy(opts.policy),
         session(s),
         rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
         pflr(new tensorflow::ProcessFunctionLibraryRuntime(
@@ -54,6 +63,9 @@ struct TFE_Context {
         log_device_placement(
             opts.session_options.options.config.log_device_placement()) {}
 
+  const std::unique_ptr<tensorflow::thread::ThreadPool> thread_pool;
+  std::function<void(std::function<void()>)> runner;
+
   const TFE_ContextDevicePlacementPolicy policy;
 
   // Note: we cannot use C++11 thread_local here as there is no concept of a
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index 4bf24fe..b961842 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -255,17 +255,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
   out->device_ = device;
   out->kernel_.reset(k);
   out->flib_ = nullptr;
+  out->runner_ = nullptr;
+  out->default_runner_ = [](std::function<void()> f) { f(); };
   return s;
 }
 
 // static
 Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                             std::function<void(std::function<void()>)>* runner,
                              KernelAndDevice* out) {
   OpKernel* k = nullptr;
   Status s = flib->CreateKernel(ndef, &k);
   out->device_ = flib->device();
   out->kernel_.reset(k);
   out->flib_ = flib;
+  out->runner_ = runner;
+  out->default_runner_ = [](std::function<void()> f) { f(); };
   return s;
 }
 
@@ -296,10 +301,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
   if (stats != nullptr) {
     params.track_allocations = true;
   }
-  // TODO(apassos): use a thread pool.
-  std::function<void(std::function<void()>)> runner =
-      [](std::function<void()> f) { f(); };
-  params.runner = &runner;
+  if (runner_ == nullptr) {
+    params.runner = &default_runner_;
+  } else {
+    params.runner = runner_;
+  }
 
   OpKernelContext context(&params);
   device_->Compute(kernel_.get(), &context);
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index 7fede4d..fa5f839 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -169,6 +169,7 @@ class KernelAndDevice {
   // the FunctionLibraryRuntime is pushed on to the caller (see locking in
   // c_api.cc).
   static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                     std::function<void(std::function<void()>)>* runner,
                      KernelAndDevice* out);
   // TODO(ashankar): Remove this
   static Status InitOp(Device* device, const NodeDef& ndef,
@@ -188,6 +189,8 @@ class KernelAndDevice {
  private:
   std::unique_ptr<OpKernel> kernel_;
   Device* device_;
+  std::function<void(std::function<void()>)>* runner_;
+  std::function<void(std::function<void()>)> default_runner_;
   FunctionLibraryRuntime* flib_;
   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
   Rendezvous* rendez_;
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
index 6431530..ab0b535 100644
--- a/tensorflow/c/eager/runtime_test.cc
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -92,8 +92,8 @@ TEST(KernelAndDevice, Run) {
                    .BuildNodeDef());
   TestEnv env;
   KernelAndDevice kernel(nullptr);
-  Status s =
-      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
+  Status s = KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                   nullptr, &kernel);
   ASSERT_TRUE(s.ok()) << s;
   std::vector<Tensor> outputs;
   s = kernel.Run(&inputs, &outputs, nullptr);
@@ -158,8 +158,8 @@ void BM_KernelAndDeviceInit(int iters) {
   KernelAndDevice k(nullptr);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(
-        KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
+    TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                      nullptr, &k));
   }
 }
 BENCHMARK(BM_KernelAndDeviceInit);
@@ -179,8 +179,8 @@ void BM_KernelAndDeviceRun(int iters) {
                    .BuildNodeDef());
   TestEnv env;
   KernelAndDevice kernel(nullptr);
-  TF_CHECK_OK(
-      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
+  TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
+                                    nullptr, &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr));
-- 
2.7.4
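
The following standalone sketch (not part of the patch above) illustrates the pattern the change introduces: a context-owned runner std::function that schedules closures on a thread pool, handed to kernels as a nullable pointer, with a default inline runner used when none is supplied. SimpleThreadPool and Kernel are hypothetical stand-ins for tensorflow::thread::ThreadPool and KernelAndDevice; the sketch assumes only the C++ standard library.

// Standalone illustration of the "runner" dispatch pattern from the patch.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal fixed-size thread pool, a stand-in for tensorflow::thread::ThreadPool.
class SimpleThreadPool {
 public:
  explicit SimpleThreadPool(int num_threads) {
    for (int i = 0; i < num_threads; ++i) {
      workers_.emplace_back([this] {
        while (true) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> l(mu_);
            cv_.wait(l, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) return;
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }
  ~SimpleThreadPool() {
    {
      std::lock_guard<std::mutex> l(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) w.join();  // drains queued work, then joins
  }
  void Schedule(std::function<void()> f) {
    {
      std::lock_guard<std::mutex> l(mu_);
      tasks_.push(std::move(f));
    }
    cv_.notify_one();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mu_;
  std::condition_variable cv_;
  bool stop_ = false;
};

// Mirrors KernelAndDevice: an optional pointer to a shared runner plus a
// default runner that executes the closure inline on the calling thread.
class Kernel {
 public:
  explicit Kernel(std::function<void(std::function<void()>)>* runner)
      : runner_(runner),
        default_runner_([](std::function<void()> f) { f(); }) {}

  void Run(std::function<void()> work) {
    // Same selection logic as KernelAndDevice::Run in the patch.
    auto* runner = runner_ == nullptr ? &default_runner_ : runner_;
    (*runner)(std::move(work));
  }

 private:
  std::function<void(std::function<void()>)>* runner_;
  std::function<void(std::function<void()>)> default_runner_;
};

int main() {
  SimpleThreadPool pool(/*num_threads=*/4);
  // Context-owned runner, analogous to TFE_Context::runner in the patch.
  std::function<void(std::function<void()>)> runner =
      [&pool](std::function<void()> f) { pool.Schedule(std::move(f)); };

  Kernel pooled(&runner);    // work is dispatched to the thread pool
  Kernel inline_k(nullptr);  // work runs inline, like the pre-patch behavior

  pooled.Run([] { std::cout << "ran on a pool thread\n"; });
  inline_k.Run([] { std::cout << "ran inline\n"; });
  return 0;  // pool destructor drains remaining tasks before joining
}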