From: Derek Murray
Date: Mon, 5 Feb 2018 21:43:52 +0000 (-0800)
Subject: [tf.data] Fix use-after-free bug when closing down an input pipeline.
X-Git-Tag: upstream/v1.7.0~31^2~1002
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1bbfc0c9cebd1808fa46024f738e56d10d57d97e;p=platform%2Fupstream%2Ftensorflow.git

[tf.data] Fix use-after-free bug when closing down an input pipeline.

This fix affects the distributed runtime; DirectSession use is unaffected.

Before this change, an iterator that used a background prefetching thread
might attempt to use a captured FunctionLibraryRuntime from a subgraph that
had been deregistered (and hence its FunctionLibraryRuntime would have been
deleted). This change introduces a mechanism for "cloning" the necessary
parts of the FunctionLibraryRuntime so that it can be owned by the
IteratorResource.

PiperOrigin-RevId: 184579490
---
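Note (not part of the patch): the new Clone() API is easiest to read as an
ownership transfer. The sketch below shows the pattern the IteratorResource
relies on; the class OwnedIteratorState and everything inside it are
hypothetical names invented for illustration, and only
FunctionLibraryRuntime::Clone() and the cloned types come from this change.
It assumes the usual TensorFlow headers and would compile only inside the
TensorFlow codebase.

    // Minimal sketch of the ownership pattern Clone() enables. Member order
    // matters: flib_def_ is declared first so it is destroyed last; pflr_
    // owns the FunctionLibraryRuntime that flr_ merely borrows.
    class OwnedIteratorState {
     public:
      static Status Create(FunctionLibraryRuntime* ctx_flr,
                           std::unique_ptr<OwnedIteratorState>* out) {
        auto state =
            std::unique_ptr<OwnedIteratorState>(new OwnedIteratorState);
        // Clone so that background threads use a runtime owned here, not
        // one owned by a subgraph that may be deregistered at any time.
        TF_RETURN_IF_ERROR(
            ctx_flr->Clone(&state->flib_def_, &state->pflr_, &state->flr_));
        *out = std::move(state);
        return Status::OK();
      }

      FunctionLibraryRuntime* flr() { return flr_; }

     private:
      OwnedIteratorState() = default;
      std::unique_ptr<FunctionLibraryDefinition> flib_def_;  // destroyed last
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;  // owns *flr_
      FunctionLibraryRuntime* flr_ = nullptr;                // borrowed
    };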
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index e1b5404..d349d2b 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -182,6 +182,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 
   string DebugString(Handle h) override;
 
+  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+               FunctionLibraryRuntime** out_flr) override;
+
  private:
   typedef FunctionLibraryRuntimeImpl ME;
 
@@ -894,6 +898,21 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
   }
 }
 
+Status FunctionLibraryRuntimeImpl::Clone(
+    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+    FunctionLibraryRuntime** out_flr) {
+  TF_RETURN_IF_ERROR(
+      parent_->Clone(env_, graph_def_version_, optimizer_.options(),
+                     custom_kernel_creator_, out_lib_def, out_pflr));
+  *out_flr = (*out_pflr)->GetFLR(device_->name());
+  if (*out_flr != nullptr) {
+    return Status::OK();
+  } else {
+    return errors::Internal("Cloning FunctionLibraryRuntime failed.");
+  }
+}
+
 namespace {
 
 struct CustomCreatorSingleton {
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 8477cea..8024628 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -52,6 +52,8 @@ class GraphOptimizer {
           shape_map,
       const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
 
+  const OptimizerOptions& options() { return opts_; }
+
  private:
   OptimizerOptions opts_;
 
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index dd4bf6a..41e1ce8 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -350,4 +350,16 @@ void ProcessFunctionLibraryRuntime::Run(
   done(errors::Internal("Could not find device"));
 }
 
+Status ProcessFunctionLibraryRuntime::Clone(
+    Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
+    CustomKernelCreator custom_kernel_creator,
+    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) {
+  out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
+  out_pflr->reset(new ProcessFunctionLibraryRuntime(
+      device_mgr_, env, graph_def_version, out_lib_def->get(),
+      optimizer_options, std::move(custom_kernel_creator), parent_));
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 9c9c92f..4296f94 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -145,6 +145,12 @@ class ProcessFunctionLibraryRuntime {
   // Removes handle from the state owned by this object.
   Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
 
+  Status Clone(Env* env, int graph_def_version,
+               const OptimizerOptions& optimizer_options,
+               CustomKernelCreator custom_kernel_creator,
+               std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr);
+
   friend class FunctionLibraryRuntimeImpl;
 
   mutable mutex mu_;
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 7d0e156..e270011 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -35,6 +35,7 @@ namespace tensorflow {
 class CancellationManager;
 class GraphDef;
 class OpKernel;
+class ProcessFunctionLibraryRuntime;
 class ResourceMgr;
 class Rendezvous;
 class ScopedStepContainer;
@@ -535,6 +536,10 @@ class FunctionLibraryRuntime {
   virtual int graph_def_version() = 0;
 
   typedef uint64 LocalHandle;
+
+  virtual Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
+                       std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
+                       FunctionLibraryRuntime** out_flr) = 0;
 };
 
 // Returns a canonicalized string for the instantiation of the
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index dd5f4a4..8a420ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -459,7 +459,7 @@ class IteratorHandleOp : public OpKernel {
     {
       mutex_lock l(mu_);
       if (resource_ == nullptr) {
-        FunctionLibraryRuntime* lib = context->function_library();
+        FunctionLibraryRuntime* lib;
         std::unique_ptr<DeviceMgr> device_mgr(nullptr);
         std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
@@ -469,6 +469,9 @@
         // is sufficient demand, but it will require a significant refactoring.
         if (!name_.empty()) {
           lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
+        } else {
+          OP_REQUIRES_OK(context, context->function_library()->Clone(
+                                      &flib_def, &pflr, &lib));
         }
 
         ResourceMgr* mgr = context->resource_manager();
@@ -538,7 +541,7 @@ class IteratorHandleOp : public OpKernel {
       // Wrap the existing device in order to see any captured resources
       // in its resource manager. The existing device will outlive the
       // IteratorResource, because we are storing the IteratorResource
-      // in that device's resourc manager.
+      // in that device's resource manager.
       Device* wrapped_device = RenamedDevice::NewRenamedDevice(
           ctx->device()->name(), down_cast<Device*>(ctx->device()),
           false /* owns_underlying */, false /* isolate_session_state */);
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
index 2c65c49..25c91b4 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
@@ -30,6 +32,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
@@ -140,6 +143,33 @@ class IteratorClusterTest(test.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       sess.run(get_next)
 
+  def testImplicitDisposeParallelMapDataset(self):
+    # Tests whether a parallel map dataset will be cleaned up correctly when
+    # the pipeline does not run it until exhaustion.
+    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
+    # RepeatDataset(None) -> PrefetchDataset(10000).
+    worker, _ = test_util.create_local_cluster(1, 1)
+
+    components = (np.arange(1000),
+                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+                  np.array(37.0) * np.arange(1000))
+
+    def _map_fn(x, y, z):
+      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+    dataset = (
+        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+        .repeat(None).prefetch(10000))
+
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with session.Session(worker[0].target) as sess:
+      sess.run(init_op)
+      for _ in range(3):
+        sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()
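Note (not part of the patch): the new test exercises the shutdown path rather
than the data path. The pipeline prefetches up to 10000 elements in a
background thread, but the session consumes only three before exiting, so the
worker tears down the subgraph while the prefetch thread is still filling its
buffer. Before this change, that teardown deleted the FunctionLibraryRuntime
the thread was using, the use-after-free described in the commit message;
with the cloned runtime owned by the IteratorResource, the teardown is safe.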