Fixed memory leak with py_func (#18292) (#19085)
authorvoegtlel <l.voegtle@gmx.de>
Tue, 29 May 2018 16:15:10 +0000 (18:15 +0200)
committerDerek Murray <derek.murray@gmail.com>
Tue, 29 May 2018 16:15:10 +0000 (09:15 -0700)
* Fixing memory leak with py_func (#18292)

* Fixed memory leak with py_func (#18292)

tensorflow/python/kernel_tests/py_func_test.py
tensorflow/python/ops/script_ops.py

index b9f44d7..c899945 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gc
 import re
 
 import numpy as np
@@ -432,13 +433,29 @@ class PyFuncTest(test.TestCase):
 
   # ----- Tests shared by py_func and eager_py_func -----
   def testCleanup(self):
-    for _ in xrange(1000):
-      g = ops.Graph()
-      with g.as_default():
-        c = constant_op.constant([1.], dtypes.float32)
-        _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
-        _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
-    self.assertTrue(script_ops._py_funcs.size() < 100)
+    # Delete everything created by previous tests to avoid side effects.
+    ops.reset_default_graph()
+    gc.collect()
+    initial_size = script_ops._py_funcs.size()
+    # Encapsulate the graph generation, so locals can be deleted.
+    def make_graphs():
+      for _ in xrange(1000):
+        g = ops.Graph()
+        with g.as_default():
+          c = constant_op.constant([1.], dtypes.float32)
+          _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
+          _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
+          # These ops have a reference to 'c' which has a reference to the graph.
+          # Checks if the functions are being deleted though the graph is referenced from them.
+          # (see #18292)
+          _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+          _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+    # Call garbage collector to enforce deletion.
+    make_graphs()
+    ops.reset_default_graph()
+    gc.collect()
+    self.assertEqual(initial_size, script_ops._py_funcs.size())
 
   # ----- Tests for eager_py_func -----
   @test_util.run_in_graph_and_eager_modes()
index f87c5dc..16c7321 100644 (file)
@@ -24,6 +24,7 @@ import threading
 
 # Used by py_util.cc to get tracebacks.
 import traceback  # pylint: disable=unused-import
+import weakref
 
 import numpy as np
 import six
@@ -88,11 +89,14 @@ class FuncRegistry(object):
   def __init__(self):
     self._lock = threading.Lock()
     self._unique_id = 0  # GUARDED_BY(self._lock)
-    self._funcs = {}
+    # Only store weakrefs to the funtions. The strong reference is stored in
+    # the graph.
+    self._funcs = weakref.WeakValueDictionary()
 
   def insert(self, func):
     """Registers `func` and returns a unique token for this entry."""
     token = self._next_unique_token()
+    # Store a weakref to the function
     self._funcs[token] = func
     return token
 
@@ -145,7 +149,7 @@ class FuncRegistry(object):
     Raises:
       ValueError: if no function is registered for `token`.
     """
-    func = self._funcs[token]
+    func = self._funcs.get(token, None)
     if func is None:
       raise ValueError("callback %s is not found" % token)
     if isinstance(func, EagerFunc):
@@ -180,19 +184,6 @@ _py_funcs = FuncRegistry()
 pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
 
 
-class CleanupFunc(object):
-  """A helper class to remove a registered function from _py_funcs."""
-
-  def __init__(self, token):
-    self._token = token
-
-  def __del__(self):
-    if _py_funcs is not None:
-      # If _py_funcs is None, the program is most likely in shutdown, and the
-      # _py_funcs object has been destroyed already.
-      _py_funcs.remove(self._token)
-
-
 def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
   """See documentation for py_func and eager_py_func."""
 
@@ -216,17 +207,15 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
     # bound to that of the outer graph instead.
     graph = graph._outer_graph
 
-  cleanup = CleanupFunc(token)
-
   # TODO(zhifengc): Consider adding a Graph method to collect
   # `cleanup` objects in one of its member.
-  if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
-    graph._cleanup_py_funcs_used_in_graph = []
+  if not hasattr(graph, "_py_funcs_used_in_graph"):
+    graph._py_funcs_used_in_graph = []
 
-  # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
-  # will be destroyed and their __del__ will remove the 'token' from
-  # the funcs registry.
-  graph._cleanup_py_funcs_used_in_graph.append(cleanup)
+  # Store a reference to the function in the graph to ensure it stays alive
+  # as long as the graph lives. When the graph is destroyed, the function
+  # is left to the garbage collector for destruction as well.
+  graph._py_funcs_used_in_graph.append(func)
   # pylint: enable=protected-access
 
   if eager: