Fixes a race condition in function instantiation.
author Derek Murray <mrry@google.com>
Sun, 11 Mar 2018 22:38:16 +0000 (15:38 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sun, 11 Mar 2018 22:42:32 +0000 (15:42 -0700)
Previously, if the same function was being concurrently instantiated
and released, the following interleaving could occur:

1. Thread one could begin to instantiate the function, determine
   that it already existed in the runtime, then be preempted.
2. Thread two could release the handle on the function, causing it to
   be freed and removed from the `FunctionLibraryRuntimeImpl::items_` map.
3. Thread one could then resume, incorrectly assume that the function
   still existed, and fail to find it in the
   `FunctionLibraryRuntimeImpl::items_` map, causing a segfault when it
   attempted to increment the refcount on an uninitialized object.

PiperOrigin-RevId: 188661500

tensorflow/core/common_runtime/function.cc
tensorflow/python/data/kernel_tests/filter_dataset_op_test.py

index effe53c..37c59a1 100644 (file)
@@ -496,11 +496,26 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
   InstantiateOptions options_copy(options);
   options_copy.target = device_name_;
   const string key = Canonicalize(function_name, attrs, options_copy);
-  *handle = parent_->GetHandle(key);
-  if (*handle != kInvalidHandle) {
+
+  {
     mutex_lock l(mu_);
-    items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
-    return Status::OK();
+    *handle = parent_->GetHandle(key);
+    if (*handle != kInvalidHandle) {
+      FunctionLibraryRuntime::LocalHandle handle_on_device =
+          parent_->GetHandleOnDevice(device_name_, *handle);
+      if (handle_on_device == kInvalidLocalHandle) {
+        return errors::Internal("LocalHandle not found for handle ", *handle,
+                                ".");
+      }
+      auto item_handle = items_.find(handle_on_device);
+      if (item_handle == items_.end()) {
+        return errors::Internal("LocalHandle ", handle_on_device,
+                                " for handle ", *handle,
+                                " not found in items.");
+      }
+      item_handle->second->Ref();
+      return Status::OK();
+    }
   }
 
   Status s;
@@ -553,6 +568,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
   }
 
   LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
+  CHECK_NE(h, kInvalidLocalHandle);
   mutex_lock l(mu_);
   CHECK_EQ(1, items_.count(h));
   Item* item = items_[h];
index 2c71723..4f2216f 100644 (file)
@@ -176,6 +176,14 @@ class FilterDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  def testParallelFilters(self):
+    dataset = dataset_ops.Dataset.range(10).filter(
+        lambda x: math_ops.equal(x % 2, 0))
+    iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
+    next_elements = [iterator.get_next() for iterator in iterators]
+    with self.test_session() as sess:
+      self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
+
 
 class FilterDatasetBenchmark(test.Benchmark):