[tf.data] Fix use-after-free bug when closing down an input pipeline.
author Derek Murray <mrry@google.com>
Mon, 5 Feb 2018 21:43:52 +0000 (13:43 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 21:51:04 +0000 (13:51 -0800)
This fix affects the distributed runtime; DirectSession use is unaffected.

Before this change, an iterator that used a background prefetching
thread could attempt to use a FunctionLibraryRuntime captured from a
subgraph that had since been deregistered, at which point that
FunctionLibraryRuntime had already been deleted. This change
introduces a mechanism for cloning the necessary parts of the
FunctionLibraryRuntime so that the clone can be owned by the
IteratorResource, decoupling its lifetime from the subgraph's.
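
As a minimal stand-alone C++ sketch of the ownership pattern this
change adopts (the types below are hypothetical stand-ins, not the
real TensorFlow classes): rather than borrowing a raw pointer into
per-subgraph state, the long-lived resource owns a clone of that
state.

    #include <iostream>
    #include <memory>
    #include <string>

    // Hypothetical stand-in for FunctionLibraryRuntime; the real one is
    // owned by per-subgraph session state and dies on deregistration.
    struct Runtime {
      std::string functions;  // stand-in for the function library state
      // Copy the state the caller needs into an object the caller owns,
      // mirroring the FunctionLibraryRuntime::Clone() added here.
      std::unique_ptr<Runtime> Clone() const {
        return std::unique_ptr<Runtime>(new Runtime(*this));
      }
    };

    // Before: the iterator borrowed a raw pointer, which dangles once
    // the subgraph is deregistered and the original Runtime is deleted.
    struct UnsafeIterator {
      const Runtime* borrowed;
    };

    // After: the iterator owns a clone, so tearing down the subgraph
    // cannot invalidate the state a prefetching thread is still using.
    struct SafeIterator {
      std::unique_ptr<Runtime> owned;
      explicit SafeIterator(const Runtime& rt) : owned(rt.Clone()) {}
    };

    int main() {
      SafeIterator it = [] {
        Runtime subgraph_runtime{"square_3"};
        return SafeIterator(subgraph_runtime);  // clone taken here
      }();  // subgraph_runtime destroyed, as on subgraph deregistration
      std::cout << it.owned->functions << "\n";  // safe: we own the clone
      return 0;
    }

In the real change, IteratorHandleOp takes this path by calling
context->function_library()->Clone(&flib_def, &pflr, &lib) whenever it
is not already creating a private FLR (see the iterator_ops.cc hunk
below).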

PiperOrigin-RevId: 184579490

tensorflow/core/common_runtime/function.cc
tensorflow/core/common_runtime/graph_optimizer.h
tensorflow/core/common_runtime/process_function_library_runtime.cc
tensorflow/core/common_runtime/process_function_library_runtime.h
tensorflow/core/framework/function.h
tensorflow/core/kernels/data/iterator_ops.cc
tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py

diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index e1b5404..d349d2b 100644
@@ -182,6 +182,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 
   string DebugString(Handle h) override;
 
+  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+               FunctionLibraryRuntime** out_flr) override;
+
  private:
   typedef FunctionLibraryRuntimeImpl ME;
 
@@ -894,6 +898,21 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
   }
 }
 
+Status FunctionLibraryRuntimeImpl::Clone(
+    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+    FunctionLibraryRuntime** out_flr) {
+  TF_RETURN_IF_ERROR(
+      parent_->Clone(env_, graph_def_version_, optimizer_.options(),
+                     custom_kernel_creator_, out_lib_def, out_pflr));
+  *out_flr = (*out_pflr)->GetFLR(device_->name());
+  if (*out_flr != nullptr) {
+    return Status::OK();
+  } else {
+    return errors::Internal("Cloning FunctionLibraryRuntime failed.");
+  }
+}
+
 namespace {
 
 struct CustomCreatorSingleton {
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 8477cea..8024628 100644
@@ -52,6 +52,8 @@ class GraphOptimizer {
           shape_map,
       const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
 
+  const OptimizerOptions& options() const { return opts_; }
+
  private:
   OptimizerOptions opts_;
 
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index dd4bf6a..41e1ce8 100644
@@ -350,4 +350,16 @@ void ProcessFunctionLibraryRuntime::Run(
   done(errors::Internal("Could not find device"));
 }
 
+Status ProcessFunctionLibraryRuntime::Clone(
+    Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
+    CustomKernelCreator custom_kernel_creator,
+    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) {
+  out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
+  out_pflr->reset(new ProcessFunctionLibraryRuntime(
+      device_mgr_, env, graph_def_version, out_lib_def->get(),
+      optimizer_options, std::move(custom_kernel_creator), parent_));
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 9c9c92f..4296f94 100644
@@ -145,6 +145,12 @@ class ProcessFunctionLibraryRuntime {
   // Removes handle from the state owned by this object.
   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
 
+  Status Clone(Env* env, int graph_def_version,
+               const OptimizerOptions& optimizer_options,
+               CustomKernelCreator custom_kernel_creator,
+               std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr);
+
   friend class FunctionLibraryRuntimeImpl;
 
   mutable mutex mu_;
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 7d0e156..e270011 100644
@@ -35,6 +35,7 @@ namespace tensorflow {
 class CancellationManager;
 class GraphDef;
 class OpKernel;
+class ProcessFunctionLibraryRuntime;
 class ResourceMgr;
 class Rendezvous;
 class ScopedStepContainer;
@@ -535,6 +536,10 @@ class FunctionLibraryRuntime {
   virtual int graph_def_version() = 0;
 
   typedef uint64 LocalHandle;
+
+  virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+                       std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+                       FunctionLibraryRuntime** out_flr) = 0;
 };
 
 // Returns a canonicalized string for the instantiation of the
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index dd5f4a4..8a420ac 100644
@@ -459,7 +459,7 @@ class IteratorHandleOp : public OpKernel {
     {
       mutex_lock l(mu_);
       if (resource_ == nullptr) {
-        FunctionLibraryRuntime* lib = context->function_library();
+        FunctionLibraryRuntime* lib = nullptr;
         std::unique_ptr<DeviceMgr> device_mgr(nullptr);
         std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
@@ -469,6 +469,9 @@ class IteratorHandleOp : public OpKernel {
         // is sufficient demand, but it will require a significant refactoring.
         if (!name_.empty()) {
           lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
+        } else {
+          OP_REQUIRES_OK(context, context->function_library()->Clone(
+                                      &flib_def, &pflr, &lib));
         }
 
         ResourceMgr* mgr = context->resource_manager();
@@ -538,7 +541,7 @@ class IteratorHandleOp : public OpKernel {
     // Wrap the existing device in order to see any captured resources
     // in its resource manager. The existing device will outlive the
     // IteratorResource, because we are storing the IteratorResource
-    // in that device's resourc manager.
+    // in that device's resource manager.
     Device* wrapped_device = RenamedDevice::NewRenamedDevice(
         ctx->device()->name(), down_cast<Device*>(ctx->device()),
         false /* owns_underlying */, false /* isolate_session_state */);
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
index 2c65c49..25c91b4 100644
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
@@ -30,6 +32,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
@@ -140,6 +143,33 @@ class IteratorClusterTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  def testImplicitDisposeParallelMapDataset(self):
+    # Tests that a parallel map dataset is cleaned up correctly even when
+    # the pipeline does not run it to exhaustion.
+    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
+    # RepeatDataset(None) -> PrefetchDataset(10000).
+    worker, _ = test_util.create_local_cluster(1, 1)
+
+    components = (np.arange(1000),
+                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+                  np.array(37.0) * np.arange(1000))
+
+    def _map_fn(x, y, z):
+      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+    dataset = (
+        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+        .repeat(None).prefetch(10000))
+
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with session.Session(worker[0].target) as sess:
+      sess.run(init_op)
+      for _ in range(3):
+        sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()