[Functions] Fix unbounded memory growth in FunctionLibraryRuntime.
authorDerek Murray <mrry@google.com>
Wed, 9 May 2018 16:42:18 +0000 (09:42 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 17:52:05 +0000 (10:52 -0700)
A recent change modified the behavior of `FunctionLibraryRuntimeImpl::ReleaseHandle()`
so that it no longer freed the memory associated with an instantiated function. Since
we rely on instantiating and releasing a potentially large number of instances of the
same function in tf.data to isolate the (e.g. random number generator) state in each
instance, this change meant that the memory consumption could grow without bound in
a simple program like:

```python

  ds = tf.data.Dataset.from_tensors(0).repeat(None)

  # The function `lambda y: y + 1` would be instantiated for each element in the input.
  ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).map(
    lambda y: y + tf.random_uniform([], minval=0, maxval=10, dtype=tf.int32)))

  iterator = ds.make_one_shot_iterator()
  next_elem = iterator.get_next()

  with tf.Session() as sess:
    while True:
      sess.run(next_elem)
```

PiperOrigin-RevId: 195983977

tensorflow/core/common_runtime/function.cc
tensorflow/core/common_runtime/function_test.cc
tensorflow/core/common_runtime/function_threadpool_test.cc
tensorflow/core/common_runtime/process_function_library_runtime.cc
tensorflow/core/common_runtime/process_function_library_runtime.h
tensorflow/core/common_runtime/process_function_library_runtime_test.cc

index bf05f6f..d05564e 100644 (file)
@@ -208,19 +208,19 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 
   // The instantiated and transformed function is encoded as a Graph
   // object, and an executor is created for the graph.
-  struct Item : public core::RefCounted {
-    bool invalidated = false;
+  struct Item {
+    uint64 instantiation_counter = 0;
     const Graph* graph = nullptr;                            // Owned by exec.
     const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
     FunctionBody* func_graph = nullptr;
     Executor* exec = nullptr;
 
-    ~Item() override {
+    ~Item() {
       delete this->func_graph;
       delete this->exec;
     }
   };
-  std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
+  std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);
 
   ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.
 
@@ -284,9 +284,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
   }
 }
 
-FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
-  for (auto p : items_) p.second->Unref();
-}
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}
 
 // An asynchronous op kernel which executes an instantiated function
 // defined in a library.
@@ -490,30 +488,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
   options_copy.target = device_name_;
   const string key = Canonicalize(function_name, attrs, options_copy);
 
-  Handle found_handle = kInvalidHandle;
   {
     mutex_lock l(mu_);
-    found_handle = parent_->GetHandle(key);
-    if (found_handle != kInvalidHandle) {
+    *handle = parent_->GetHandle(key);
+    if (*handle != kInvalidHandle) {
       FunctionLibraryRuntime::LocalHandle handle_on_device =
-          parent_->GetHandleOnDevice(device_name_, found_handle);
+          parent_->GetHandleOnDevice(device_name_, *handle);
       if (handle_on_device == kInvalidLocalHandle) {
         return errors::Internal("LocalHandle not found for handle ", *handle,
                                 ".");
       }
-      auto iter = items_.find(handle_on_device);
-      if (iter == items_.end()) {
+      auto item_handle = items_.find(handle_on_device);
+      if (item_handle == items_.end()) {
         return errors::Internal("LocalHandle ", handle_on_device,
-                                " for handle ", found_handle,
+                                " for handle ", *handle,
                                 " not found in items.");
       }
-      Item* item = iter->second;
-      if (!item->invalidated) {
-        *handle = found_handle;
-        return Status::OK();
-      }
-      // *item is invalidated. Fall through and instantiate the given
-      // function_name/attrs/option again.
+      ++item_handle->second->instantiation_counter;
+      return Status::OK();
     }
   }
 
@@ -545,16 +537,18 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
 
   {
     mutex_lock l(mu_);
-    Handle found_handle_again = parent_->GetHandle(key);
-    if (found_handle_again != found_handle) {
+    *handle = parent_->GetHandle(key);
+    if (*handle != kInvalidHandle) {
       delete fbody;
-      *handle = found_handle_again;
+      ++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
+            ->instantiation_counter;
     } else {
       *handle = parent_->AddHandle(key, device_name_, next_handle_);
       Item* item = new Item;
       item->func_graph = fbody;
       item->overlay_lib = options.overlay_lib;
-      items_.insert({next_handle_, item});
+      item->instantiation_counter = 1;
+      items_.emplace(next_handle_, std::unique_ptr<Item>(item));
       next_handle_++;
     }
   }
@@ -565,12 +559,17 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
   if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
     return parent_->ReleaseHandle(handle);
   }
+
   LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
   CHECK_NE(h, kInvalidLocalHandle);
   mutex_lock l(mu_);
   CHECK_EQ(1, items_.count(h));
-  Item* item = items_[h];
-  item->invalidated = true;  // Reinstantiate later.
+  std::unique_ptr<Item>& item = items_[h];
+  --item->instantiation_counter;
+  if (item->instantiation_counter == 0) {
+    items_.erase(h);
+    TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
+  }
   return Status::OK();
 }
 
@@ -680,7 +679,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
       return errors::NotFound("Function handle ", handle,
                               " is not valid. Likely an internal error.");
     }
-    *item = items_[local_handle];
+    *item = items_[local_handle].get();
     if ((*item)->exec != nullptr) {
       return Status::OK();
     }
@@ -731,7 +730,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
   // computation is done and stored in *rets, we send the return values back
   // to the source_device (caller) so that the ProcFLR can receive them later.
   std::vector<Tensor>* remote_args = new std::vector<Tensor>;
-  item->Ref();
   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
       source_device, target_device, "arg_", src_incarnation, args.size(),
       device_context, {}, rendezvous, remote_args,
@@ -743,7 +741,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
           s = frame->SetArgs(*remote_args);
         }
         if (!s.ok()) {
-          item->Unref();
           delete frame;
           delete remote_args;
           delete exec_args;
@@ -751,10 +748,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
           return;
         }
         item->exec->RunAsync(
-            *exec_args, [item, frame, rets, done, source_device, target_device,
+            *exec_args, [frame, rets, done, source_device, target_device,
                          target_incarnation, rendezvous, device_context,
                          remote_args, exec_args](const Status& status) {
-              core::ScopedUnref unref(item);
               Status s = status;
               if (s.ok()) {
                 s = frame->ConsumeRetvals(rets);
@@ -840,13 +836,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
     return;
   }
 
-  item->Ref();
   item->exec->RunAsync(
       // Executor args
       *exec_args,
       // Done callback.
-      [item, frame, rets, done, exec_args](const Status& status) {
-        core::ScopedUnref unref(item);
+      [frame, rets, done, exec_args](const Status& status) {
         Status s = status;
         if (s.ok()) {
           s = frame->ConsumeRetvals(rets);
@@ -906,7 +900,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
   exec_args->runner = *run_opts.runner;
   exec_args->call_frame = frame;
 
-  item->Ref();
   item->exec->RunAsync(
       // Executor args
       *exec_args,
@@ -915,7 +908,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
           [item, frame, exec_args](DoneCallback done,
                                    // Start unbound arguments.
                                    const Status& status) {
-            core::ScopedUnref unref(item);
             delete exec_args;
             done(status);
           },
index 373fc64..61b2f0e 100644 (file)
@@ -231,8 +231,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
       return status;
     }
     FunctionLibraryRuntime::Options opts;
-    TF_RETURN_IF_ERROR(Run(flr, handle, opts, args, rets, add_runner));
-    return flr->ReleaseHandle(handle);
+    status = Run(flr, handle, opts, args, rets, add_runner);
+    if (!status.ok()) return status;
+
+    // Release the handle and try running again. It should not succeed.
+    status = flr->ReleaseHandle(handle);
+    if (!status.ok()) return status;
+
+    Status status2 = Run(flr, handle, opts, args, std::move(rets));
+    EXPECT_TRUE(errors::IsInvalidArgument(status2));
+    EXPECT_TRUE(
+        str_util::StrContains(status2.error_message(), "remote execution."));
+
+    return status;
   }
 
   Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
@@ -293,8 +304,16 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
       *rets[i] = retvals[i];
     }
 
-    // Release the handle.
-    return flr->ReleaseHandle(handle);
+    // Release the handle and try running again. It should not succeed.
+    status = flr->ReleaseHandle(handle);
+    if (!status.ok()) return status;
+
+    Status status2 = Run(flr, handle, opts, args, std::move(rets));
+    EXPECT_TRUE(errors::IsInvalidArgument(status2));
+    EXPECT_TRUE(
+        str_util::StrContains(status2.error_message(), "remote execution."));
+
+    return status;
   }
 
   std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
index 98dac38..2d09e83 100644 (file)
@@ -144,7 +144,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
       return status;
     }
     FunctionLibraryRuntime::Options opts;
-    return Run(flr, handle, opts, args, std::move(rets), add_runner);
+    status = Run(flr, handle, opts, args, rets, add_runner);
+    if (!status.ok()) return status;
+
+    // Release the handle and try running again. It should not succeed.
+    status = flr->ReleaseHandle(handle);
+    if (!status.ok()) return status;
+
+    Status status2 = Run(flr, handle, opts, args, std::move(rets));
+    EXPECT_TRUE(errors::IsInvalidArgument(status2));
+    EXPECT_TRUE(
+        str_util::StrContains(status2.error_message(), "remote execution."));
+
+    return status;
   }
 
   Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
index 668ce87..729312a 100644 (file)
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 
@@ -183,8 +184,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
     FunctionLibraryRuntime::LocalHandle local_handle) {
   mutex_lock l(mu_);
   auto h = next_handle_;
-  FunctionData* fd = new FunctionData(device_name, local_handle);
-  function_data_[h] = std::unique_ptr<FunctionData>(fd);
+  function_data_[h] = MakeUnique<FunctionData>(
+      device_name, local_handle, function_key);
   table_[function_key] = h;
   next_handle_++;
   return h;
@@ -247,8 +248,8 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
         gtl::FindWithDefault(table_, function_key, kInvalidHandle);
     if (h == kInvalidHandle || function_data_.count(h) == 0) {
       h = next_handle_;
-      FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
-      function_data_[h] = std::unique_ptr<FunctionData>(fd);
+      function_data_[h] = MakeUnique<FunctionData>(
+          options.target, kInvalidHandle, function_key);
       table_[function_key] = h;
       next_handle_++;
     }
@@ -263,6 +264,14 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
   return Status::OK();
 }
 
+Status ProcessFunctionLibraryRuntime::RemoveHandle(
+    FunctionLibraryRuntime::Handle handle) {
+  mutex_lock l(mu_);
+  table_.erase(function_data_[handle]->function_key());
+  function_data_.erase(handle);
+  return Status::OK();
+}
+
 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
     FunctionLibraryRuntime::Handle handle) {
   FunctionLibraryRuntime* flr = nullptr;
index 05e5770..69381dd 100644 (file)
@@ -134,6 +134,9 @@ class ProcessFunctionLibraryRuntime {
   // of the device where the function is registered.
   string GetDeviceName(FunctionLibraryRuntime::Handle handle);
 
+  // Removes handle from the state owned by this object.
+  Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+
   Status Clone(Env* env, int graph_def_version,
                const OptimizerOptions& optimizer_options,
                CustomKernelCreator custom_kernel_creator,
@@ -147,10 +150,14 @@ class ProcessFunctionLibraryRuntime {
   class FunctionData {
    public:
     FunctionData(const string& target_device,
-                 FunctionLibraryRuntime::LocalHandle local_handle)
-        : target_device_(target_device), local_handle_(local_handle) {}
+                 FunctionLibraryRuntime::LocalHandle local_handle,
+                 const string& function_key)
+        : target_device_(target_device),
+          local_handle_(local_handle),
+          function_key_(function_key) {}
 
     string target_device() { return target_device_; }
+    const string& function_key() { return function_key_; }
 
     FunctionLibraryRuntime::LocalHandle local_handle() {
       mutex_lock l(mu_);
@@ -169,6 +176,7 @@ class ProcessFunctionLibraryRuntime {
 
     const string target_device_;
     FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+    const string function_key_;
     bool init_started_ GUARDED_BY(mu_) = false;
     Status init_result_ GUARDED_BY(mu_);
     Notification init_done_;
index cc10e77..4fbf2ab 100644 (file)
@@ -119,13 +119,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
 
     EXPECT_GE(call_count, 1);  // Test runner is used.
 
-    // Release the handle and then try running the function.  It
-    // should still succeed.
+    // Release the handle and then try running the function. It shouldn't
+    // succeed.
     status = proc_flr_->ReleaseHandle(handle);
     if (!status.ok()) {
       return status;
     }
-
     Notification done2;
     proc_flr_->Run(opts, handle, args, &out,
                    [&status, &done2](const Status& s) {
@@ -133,7 +132,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
                      done2.Notify();
                    });
     done2.WaitForNotification();
-    return status;
+    EXPECT_TRUE(errors::IsNotFound(status));
+    EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
+
+    return Status::OK();
   }
 
   std::vector<Device*> devices_;