Fixes a bug where the ProcFLR doesn't lookup existing instantiations in the
authorRohan Jain <rohanj@google.com>
Fri, 9 Mar 2018 00:45:45 +0000 (16:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 00:52:20 +0000 (16:52 -0800)
distributed (ClusterFLR) case. As a result multiple instantiations for the same
function were happening.

PiperOrigin-RevId: 188411978

tensorflow/core/BUILD
tensorflow/core/common_runtime/process_function_library_runtime.cc
tensorflow/core/common_runtime/process_function_library_runtime.h
tensorflow/core/common_runtime/process_function_library_runtime_test.cc

index 0fbe4eb..f2b0d54 100644 (file)
@@ -3156,6 +3156,7 @@ tf_cc_test(
         ":core_cpu",
         ":core_cpu_internal",
         ":framework",
+        ":lib",
         ":test",
         ":test_main",
         ":testlib",
index 929f5c6..44dc6f9 100644 (file)
@@ -25,6 +25,19 @@ namespace tensorflow {
 
 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
 
+Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
+    DistributedFunctionLibraryRuntime* parent, const string& function_name,
+    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+    const FunctionLibraryRuntime::InstantiateOptions& options) {
+  mutex_lock l(mu_);
+  if (!init_started_) {
+    init_started_ = true;
+    init_result_ = parent->Instantiate(function_name, lib_def, attrs, options,
+                                       &local_handle_);
+  }
+  return init_result_;
+}
+
 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
     const FunctionLibraryDefinition* lib_def,
@@ -167,7 +180,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
     if (function_data_.count(h) != 0) return h;
   }
   h = next_handle_;
-  function_data_.insert({h, FunctionData(device_name, local_handle)});
+  FunctionData* fd = new FunctionData(device_name, local_handle);
+  function_data_[h] = std::unique_ptr<FunctionData>(fd);
   table_[function_key] = h;
   next_handle_++;
   return h;
@@ -196,19 +210,19 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
   if (function_data_.count(handle) == 0) {
     return kInvalidLocalHandle;
   }
-  const FunctionData& function_data = function_data_[handle];
-  if (function_data.target_device != device_name) {
+  FunctionData* function_data = function_data_[handle].get();
+  if (function_data->target_device() != device_name) {
     return kInvalidLocalHandle;
   }
-  return function_data.local_handle;
+  return function_data->local_handle();
 }
 
 string ProcessFunctionLibraryRuntime::GetDeviceName(
     FunctionLibraryRuntime::Handle handle) {
   mutex_lock l(mu_);
   CHECK_EQ(1, function_data_.count(handle));
-  const FunctionData& function_data = function_data_[handle];
-  return function_data.target_device;
+  FunctionData* function_data = function_data_[handle].get();
+  return function_data->target_device();
 }
 
 Status ProcessFunctionLibraryRuntime::Instantiate(
@@ -225,11 +239,26 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
         "Currently don't support instantiating functions on device: ",
         options.target);
   }
-  FunctionLibraryRuntime::Handle cluster_handle;
-  TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
-                                          options, &cluster_handle));
+
   string function_key = Canonicalize(function_name, attrs);
-  *handle = AddHandle(function_key, options.target, cluster_handle);
+  FunctionData* f;
+  {
+    mutex_lock l(mu_);
+    FunctionLibraryRuntime::Handle h =
+        gtl::FindWithDefault(table_, function_key, kInvalidHandle);
+    if (h == kInvalidHandle || function_data_.count(h) == 0) {
+      h = next_handle_;
+      FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
+      function_data_[h] = std::unique_ptr<FunctionData>(fd);
+      table_[function_key] = h;
+      next_handle_++;
+    }
+    f = function_data_[h].get();
+    *handle = h;
+  }
+  TF_RETURN_IF_ERROR(
+      f->DistributedInit(parent_, function_name, *lib_def_, attrs, options));
+
   return Status::OK();
 }
 
@@ -247,7 +276,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle(
   {
     mutex_lock l(mu_);
     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
-    target_device = function_data_[handle].target_device;
+    target_device = function_data_[handle]->target_device();
   }
   flr = GetFLR(target_device);
   if (flr != nullptr) {
@@ -276,8 +305,8 @@ void ProcessFunctionLibraryRuntime::Run(
       done(errors::NotFound("Handle: ", handle, " not found."));
       return;
     }
-    target_device = function_data_[handle].target_device;
-    local_handle = function_data_[handle].local_handle;
+    target_device = function_data_[handle]->target_device();
+    local_handle = function_data_[handle]->local_handle();
   }
   flr = GetFLR(target_device);
   if (flr != nullptr) {
index 0473e16..10619ba 100644 (file)
@@ -145,14 +145,31 @@ class ProcessFunctionLibraryRuntime {
 
   mutable mutex mu_;
 
-  struct FunctionData {
-    const string target_device;
-    const FunctionLibraryRuntime::LocalHandle local_handle;
-
+  class FunctionData {
+   public:
     FunctionData(const string& target_device,
                  FunctionLibraryRuntime::LocalHandle local_handle)
-        : target_device(target_device), local_handle(local_handle) {}
-    FunctionData() : FunctionData("", -1) {}
+        : target_device_(target_device), local_handle_(local_handle) {}
+
+    string target_device() { return target_device_; }
+
+    FunctionLibraryRuntime::LocalHandle local_handle() { return local_handle_; }
+
+    // Initializes the FunctionData object by potentially making an Initialize
+    // call to the DistributedFunctionLibraryRuntime.
+    Status DistributedInit(
+        DistributedFunctionLibraryRuntime* parent, const string& function_name,
+        const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+        const FunctionLibraryRuntime::InstantiateOptions& options);
+
+   private:
+    mutex mu_;
+
+    const string target_device_;
+    FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+    bool init_started_ GUARDED_BY(mu_) = false;
+    Status init_result_ GUARDED_BY(mu_);
+    Notification init_done_;
   };
 
   const DeviceMgr* const device_mgr_;
@@ -160,7 +177,8 @@ class ProcessFunctionLibraryRuntime {
   // Holds all the function invocations here.
   std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
       GUARDED_BY(mu_);
-  std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData>
+  std::unordered_map<FunctionLibraryRuntime::Handle,
+                     std::unique_ptr<FunctionData>>
       function_data_ GUARDED_BY(mu_);
   std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
   int next_handle_ GUARDED_BY(mu_);
index 439ba1c..ab1f919 100644 (file)
@@ -19,9 +19,11 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/function_testlib.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
@@ -29,8 +31,32 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+class TestClusterFLR : public DistributedFunctionLibraryRuntime {
+ public:
+  TestClusterFLR() {}
+
+  Status Instantiate(const string& function_name,
+                     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+                     const FunctionLibraryRuntime::InstantiateOptions& options,
+                     FunctionLibraryRuntime::LocalHandle* handle) {
+    mutex_lock l(mu_);
+    *handle = next_handle_;
+    next_handle_++;
+    return Status::OK();
+  }
+
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle,
+           gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+           FunctionLibraryRuntime::DoneCallback done) {}
+
+ private:
+  mutex mu_;
+  int next_handle_ GUARDED_BY(mu_) = 0;
+};
+
 class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
- protected:
+ public:
   void Init(const std::vector<FunctionDef>& flib) {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
@@ -42,12 +68,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
     for (const auto& fdef : flib) *(proto.add_function()) = fdef;
     lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
     OptimizerOptions opts;
+    cluster_flr_.reset(new TestClusterFLR());
     proc_flr_.reset(new ProcessFunctionLibraryRuntime(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
-        opts, nullptr /* cluster_flr */));
+        opts, cluster_flr_.get()));
     rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
   }
 
+  Status Instantiate(
+      const string& name, test::function::Attrs attrs,
+      const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
+      FunctionLibraryRuntime::Handle* handle) {
+    return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle);
+  }
+
   Status Run(const string& name, FunctionLibraryRuntime::Options opts,
              test::function::Attrs attrs,
              const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
@@ -106,6 +140,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
   std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+  std::unique_ptr<TestClusterFLR> cluster_flr_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
   IntraProcessRendezvous* rendezvous_;
 };
@@ -250,5 +285,60 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
   rendezvous_->Unref();
 }
 
+TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
+  Init({test::function::FindDevice()});
+  FunctionLibraryRuntime::Options opts;
+  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
+  opts.rendezvous = rendezvous_;
+  opts.remote_execution = true;
+  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
+  FunctionLibraryRuntime::Handle h;
+  TF_CHECK_OK(Instantiate("FindDevice",
+                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+                          instantiate_opts, &h));
+  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+                   "/job:b/replica:0/task:0/device:CPU:0", h));
+  TF_CHECK_OK(Instantiate("FindDevice",
+                          {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+                          instantiate_opts, &h));
+  EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+                   "/job:b/replica:0/task:0/device:CPU:0", h));
+  instantiate_opts.target = "/job:c/replica:0/task:0/device:CPU:0";
+  TF_CHECK_OK(Instantiate("FindDevice",
+                          {{"_target", "/job:c/replica:0/task:0/device:CPU:0"}},
+                          instantiate_opts, &h));
+  EXPECT_EQ(1, proc_flr_->GetHandleOnDevice(
+                   "/job:c/replica:0/task:0/device:CPU:0", h));
+  rendezvous_->Unref();
+}
+
+TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
+  Init({test::function::FindDevice()});
+  FunctionLibraryRuntime::Options opts;
+  opts.source_device = "/job:a/replica:0/task:0/cpu:0";
+  opts.rendezvous = rendezvous_;
+  opts.remote_execution = true;
+  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+  instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
+
+  thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+  auto fn = [this, &instantiate_opts]() {
+    FunctionLibraryRuntime::Handle h;
+    TF_CHECK_OK(Instantiate(
+        "FindDevice", {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+        instantiate_opts, &h));
+    EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+                     "/job:b/replica:0/task:0/device:CPU:0", h));
+  };
+
+  for (int i = 0; i < 100; ++i) {
+    tp->Schedule(fn);
+  }
+  delete tp;
+
+  rendezvous_->Unref();
+}
+
 }  // anonymous namespace
 }  // namespace tensorflow