Fix another eager PyObject leak
authorAllen Lavoie <allenl@google.com>
Mon, 12 Mar 2018 20:00:24 +0000 (13:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:07:14 +0000 (13:07 -0700)
Shockingly this one was also due to PySequence_GetItem.

PiperOrigin-RevId: 188765548

tensorflow/python/framework/test_util.py
tensorflow/python/framework/test_util_test.py
tensorflow/python/kernel_tests/constant_op_test.py
tensorflow/python/layers/core_test.py
tensorflow/python/lib/core/py_seq_tensor.cc

index fde9c85..c4952cf 100644 (file)
@@ -434,6 +434,32 @@ def with_c_api(cls):
   return cls
 
 
+def assert_no_new_pyobjects_executing_eagerly(f):
+  """Decorator for asserting that no new Python objects persist after a test.
+
+  Runs the test multiple times executing eagerly, first as a warmup and then
+  several times to let objects accumulate. The warmup helps ignore caches which
+  do not grow as the test is run repeatedly.
+
+  Useful for checking that there are no missing Py_DECREFs in the C exercised by
+  a bit of Python.
+  """
+  def decorator(self, **kwargs):
+    """Warms up, gets an object count, runs the test, checks for new objects."""
+    with context.eager_mode():
+      gc.disable()
+      f(self, **kwargs)
+      gc.collect()
+      previous_count = len(gc.get_objects())
+      for _ in range(3):
+        f(self, **kwargs)
+      gc.collect()
+      # There should be no new Python objects hanging around.
+      new_count = len(gc.get_objects())
+      self.assertEqual(previous_count, new_count)
+      gc.enable()
+  return decorator
+
 def assert_no_new_tensors(f):
   """Decorator for asserting that no new Tensors persist after a test.
 
index 20d8160..02ffa93 100644 (file)
@@ -448,6 +448,26 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
 
     LeakedTensorTest().test_has_no_leak()
 
+  def test_no_new_objects_decorator(self):
+
+    class LeakedObjectTest(object):
+
+      def __init__(inner_self):  # pylint: disable=no-self-argument
+        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name
+        inner_self.accumulation = []
+
+      @test_util.assert_no_new_pyobjects_executing_eagerly
+      def test_has_leak(self):
+        self.accumulation.append([1.])
+
+      @test_util.assert_no_new_pyobjects_executing_eagerly
+      def test_has_no_leak(self):
+        self.not_accumulating = [1.]
+
+    with self.assertRaises(AssertionError):
+      LeakedObjectTest().test_has_leak()
+
+    LeakedObjectTest().test_has_no_leak()
 
 if __name__ == "__main__":
   googletest.main()
index 16e5634..ffbdb0e 100644 (file)
@@ -30,6 +30,7 @@ from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import logging_ops
@@ -180,6 +181,11 @@ class ConstantTest(test.TestCase):
           shape=[2, 3, 5])
     self.assertEqual(c.get_shape(), [2, 3, 5])
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testEagerMemory(self):
+    """Tests PyObject refs are managed correctly when executing eagerly."""
+    constant_op.constant([[1.]])
+
   def testImplicitShapeNumPy(self):
     with ops.Graph().as_default():
       c = constant_op.constant(
index 7d74046..cf45b07 100644 (file)
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import gc
 
 import numpy as np
 
@@ -84,27 +83,13 @@ class DenseTest(test.TestCase):
     self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias.name, 'my_dense/bias:0')
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
   def testNoEagerLeak(self):
     # Tests that repeatedly constructing and building a Layer does not leak
     # Python objects.
-    def _test_fn():
-      inputs = random_ops.random_uniform((5, 4), seed=1)
-      core_layers.Dense(5)(inputs)
-      core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)
-
-    with context.eager_mode():
-      _test_fn()  # warmup
-      gc.disable()
-      gc.collect()
-      object_count = len(gc.get_objects())
-      for _ in range(100):
-        _test_fn()
-      gc.collect()
-      self.assertLessEqual(
-          len(gc.get_objects()),
-          # DEBUG_SAVEALL messes with this slightly.
-          object_count + 1)
-      gc.enable()
+    inputs = random_ops.random_uniform((5, 4), seed=1)
+    core_layers.Dense(5)(inputs)
+    core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)
 
   @test_util.run_in_graph_and_eager_modes()
   def testCallTensorDot(self):
index 317bdc2..8247d35 100644 (file)
@@ -84,6 +84,7 @@ bool IsPyDimension(PyObject* obj) {
 }
 
 Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
+  std::vector<Safe_PyObjectPtr> refs_to_clean;
   while (true) {
     // We test strings first, in case a string is considered a sequence.
     if (IsPyString(obj)) {
@@ -93,6 +94,7 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
       if (length > 0) {
         shape->AddDim(length);
         obj = PySequence_GetItem(obj, 0);
+        refs_to_clean.push_back(make_safe(obj));
         continue;
       } else if (length == 0) {
         shape->AddDim(length);
@@ -167,14 +169,15 @@ const char ErrorFoundFloat[] =
     if (shape.dims() > 1) {                                               \
       /* Iterate over outer dim, and recursively convert each element. */ \
       const int64 s = shape.dim_size(0);                                  \
-      if (TF_PREDICT_FALSE(s != PySequence_Length(obj))) {                \
+      Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));         \
+      if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {   \
         return ErrorRectangular;                                          \
       }                                                                   \
       TensorShape rest = shape;                                           \
       rest.RemoveDim(0);                                                  \
       for (int64 i = 0; i < s; ++i) {                                     \
-        const char* error =                                               \
-            FUNCTION##Helper(PySequence_GetItem(obj, i), rest, buf);      \
+        const char* error = FUNCTION##Helper(                             \
+            PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf);           \
         if (TF_PREDICT_FALSE(error != nullptr)) return error;             \
       }                                                                   \
     } else {                                                              \