Add AnonymousIteratorHandleOp for non-shared Iterator resources
authorAllen Lavoie <allenl@google.com>
Tue, 29 May 2018 17:34:20 +0000 (10:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 17:36:51 +0000 (10:36 -0700)
Fixes Iterator cleanup when executing eagerly. DestroyResourceOp will now remove the last reference from the Iterator resource when it runs (after the last Python reference to an EagerIterator is removed).

Previously EagerIterator used IteratorHandleOp to create resource handles, which used one kernel per (unique) shared name since the shared name was an attribute. These kernels each held a reference to their resource, which kept it alive indefinitely.

Fixes #19499.

PiperOrigin-RevId: 198417997

tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/data/iterator_ops.cc
tensorflow/core/ops/dataset_ops.cc
tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
tensorflow/python/data/ops/iterator_ops.py

diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt
new file mode 100644 (file)
index 0000000..d8c2ed4
--- /dev/null
@@ -0,0 +1,13 @@
+op {
+  graph_op_name: "AnonymousIterator"
+  out_arg {
+    name: "handle"
+    description: <<END
+A handle to the iterator that can be passed to a "MakeIterator" or
+"IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents
+resource sharing by name, and does not keep a reference to the resource
+container.
+END
+  }
+  summary: "A container for an iterator resource."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/python_api/api_def_AnonymousIterator.pbtxt
new file mode 100644 (file)
index 0000000..98b7def
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "AnonymousIterator"
+  visibility: HIDDEN
+}
index b6bf0ec..87bc8eb 100644 (file)
@@ -438,6 +438,9 @@ class IteratorStateVariant {
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                        kIteratorVariantTypeName);
 
+// Note that IteratorHandleOp holds a reference to the resource it creates. If
+// cleaning up resources with DestroyResourceOp is important, consider creating
+// resource containers with AnonymousIteratorHandleOp instead.
 class IteratorHandleOp : public OpKernel {
  public:
   explicit IteratorHandleOp(OpKernelConstruction* ctx)
@@ -574,6 +577,75 @@ class IteratorHandleOp : public OpKernel {
   string name_;
 };
 
+// Like IteratorHandleOp, but creates handles which are never shared, and does
+// not hold a reference to these handles. The latter is important for eager
+// execution, since OpKernel instances generally live as long as the program
+// running them.
+class AnonymousIteratorHandleOp : public OpKernel {
+ public:
+  explicit AnonymousIteratorHandleOp(OpKernelConstruction* context)
+      : OpKernel(context), graph_def_version_(context->graph_def_version()) {
+    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
+    OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    FunctionLibraryRuntime* lib;
+    std::unique_ptr<DeviceMgr> device_mgr(nullptr);
+    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+    OP_REQUIRES_OK(context,
+                   context->function_library()->Clone(&flib_def, &pflr, &lib));
+
+    ResourceMgr* mgr = context->resource_manager();
+
+    const string container_name = "AnonymousIterator";
+    string unique_name;
+    {
+      mutex_lock l(static_resource_lookup_mutex_);
+      while (true) {  // Find an unused name
+        IteratorResource* existing_resource = nullptr;
+        unique_name = strings::StrCat("AnonymousIterator", current_id_++);
+        Status status = mgr->Lookup<IteratorResource>(
+            container_name, unique_name, &existing_resource);
+        if (status.code() == error::NOT_FOUND) {
+          break;
+        }
+        OP_REQUIRES_OK(context, status);
+        existing_resource->Unref();
+      }
+      IteratorResource* new_resource = new IteratorResource(
+          output_dtypes_, output_shapes_, graph_def_version_,
+          std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
+      // Create the resource with our chosen name under the resource lookup
+      // mutex to avoid another kernel racily creating a resource with this
+      // name.
+      OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
+                                  container_name, unique_name, new_resource));
+    }
+    OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+                                context, 0, container_name, unique_name,
+                                MakeTypeIndex<IteratorResource>()));
+  }
+
+ private:
+  // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
+  // instances.
+  static mutex static_resource_lookup_mutex_;
+  // current_id_ is just a hint for creating unique names. If it turns out
+  // there's a collision (e.g. because another AnonymousIteratorHandleOp
+  // instance is generating handles) we'll just skip that id.
+  static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
+  DataTypeVector output_dtypes_;
+  std::vector<PartialTensorShape> output_shapes_;
+  const int graph_def_version_;
+};
+
+// Static initializers for AnonymousIteratorHandleOp id counting.
+mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
+    LINKER_INITIALIZED};
+int64 AnonymousIteratorHandleOp::current_id_(0);
+
 class MakeIteratorOp : public OpKernel {
  public:
   explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1066,6 +1138,8 @@ class DeserializeIteratorOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                         MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
+                        AnonymousIteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                         ToSingleElementOp);
 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
index 576946e..6d7d863 100644 (file)
@@ -564,6 +564,12 @@ REGISTER_OP("Iterator")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("AnonymousIterator")
+    .Output("handle: resource")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("MakeIterator")
     .Input("dataset: variant")
     .Input("iterator: resource")
index 1ddedfd..e99f0a2 100644 (file)
@@ -24,6 +24,7 @@ import zlib
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -38,6 +39,13 @@ from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
 
+try:
+  import psutil  # pylint: disable=g-import-not-at-top
+  psutil_import_succeeded = True
+except ImportError:
+  psutil_import_succeeded = False
+
+
 class TextLineDatasetTest(test.TestCase):
 
   def _lineText(self, f, l):
@@ -162,6 +170,34 @@ class TextLineDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(iterator.get_next())
 
+  def testIteratorResourceCleanup(self):
+    filename = os.path.join(self.get_temp_dir(), "text.txt")
+    with open(filename, "wt") as f:
+      for i in range(3):
+        f.write("%d\n" % (i,))
+    with context.eager_mode():
+      first_iterator = iter(readers.TextLineDataset(filename))
+      self.assertEqual(b"0", next(first_iterator).numpy())
+      second_iterator = iter(readers.TextLineDataset(filename))
+      self.assertEqual(b"0", next(second_iterator).numpy())
+      # Eager kernel caching is based on op attributes, which includes the
+      # Dataset's output shape. Create a different kernel to test that they
+      # don't create resources with the same names.
+      different_kernel_iterator = iter(
+          readers.TextLineDataset(filename).repeat().batch(16))
+      self.assertEqual([16], next(different_kernel_iterator).shape)
+      # Remove our references to the Python Iterator objects, which (assuming no
+      # reference cycles) is enough to trigger DestroyResourceOp and close the
+      # partially-read files.
+      del first_iterator
+      del second_iterator
+      del different_kernel_iterator
+      if not psutil_import_succeeded:
+        self.skipTest(
+            "psutil is required to check that we've closed our files.")
+      open_files = psutil.Process().open_files()
+      self.assertNotIn(filename, [open_file.path for open_file in open_files])
+
 
 class FixedLengthRecordReaderTest(test.TestCase):
 
index fd16427..b6dba4e 100644 (file)
@@ -471,9 +471,7 @@ class EagerIterator(object):
           sparse.as_dense_types(self._output_types, self._output_classes))
       self._flat_output_shapes = nest.flatten(
           sparse.as_dense_shapes(self._output_shapes, self._output_classes))
-      self._resource = gen_dataset_ops.iterator(
-          shared_name="",
-          container=_generate_shared_name("eageriterator"),
+      self._resource = gen_dataset_ops.anonymous_iterator(
           output_types=self._flat_output_types,
           output_shapes=self._flat_output_shapes)
       gen_dataset_ops.make_iterator(ds_variant, self._resource)