From 78446bda808364c8b2c0a87d122763aecc7caea2 Mon Sep 17 00:00:00 2001 From: voegtlel Date: Tue, 29 May 2018 18:15:10 +0200 Subject: [PATCH] Fixed memory leak with py_func (#18292) (#19085) * Fixing memory leak with py_func (#18292) * Fixed memory leak with py_func (#18292) --- tensorflow/python/kernel_tests/py_func_test.py | 31 +++++++++++++++++------ tensorflow/python/ops/script_ops.py | 35 +++++++++----------------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index b9f44d7..c899945 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gc import re import numpy as np @@ -432,13 +433,29 @@ class PyFuncTest(test.TestCase): # ----- Tests shared by py_func and eager_py_func ----- def testCleanup(self): - for _ in xrange(1000): - g = ops.Graph() - with g.as_default(): - c = constant_op.constant([1.], dtypes.float32) - _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) - _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) - self.assertTrue(script_ops._py_funcs.size() < 100) + # Delete everything created by previous tests to avoid side effects. + ops.reset_default_graph() + gc.collect() + initial_size = script_ops._py_funcs.size() + # Encapsulate the graph generation, so locals can be deleted. + def make_graphs(): + for _ in xrange(1000): + g = ops.Graph() + with g.as_default(): + c = constant_op.constant([1.], dtypes.float32) + _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) + # These ops have a reference to 'c' which has a reference to the graph. + # Checks if the functions are being deleted though the graph is referenced from them. + # (see #18292) + _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) + + # Call garbage collector to enforce deletion. + make_graphs() + ops.reset_default_graph() + gc.collect() + self.assertEqual(initial_size, script_ops._py_funcs.size()) # ----- Tests for eager_py_func ----- @test_util.run_in_graph_and_eager_modes() diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index f87c5dc..16c7321 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -24,6 +24,7 @@ import threading # Used by py_util.cc to get tracebacks. import traceback # pylint: disable=unused-import +import weakref import numpy as np import six @@ -88,11 +89,14 @@ class FuncRegistry(object): def __init__(self): self._lock = threading.Lock() self._unique_id = 0 # GUARDED_BY(self._lock) - self._funcs = {} + # Only store weakrefs to the funtions. The strong reference is stored in + # the graph. + self._funcs = weakref.WeakValueDictionary() def insert(self, func): """Registers `func` and returns a unique token for this entry.""" token = self._next_unique_token() + # Store a weakref to the function self._funcs[token] = func return token @@ -145,7 +149,7 @@ class FuncRegistry(object): Raises: ValueError: if no function is registered for `token`. """ - func = self._funcs[token] + func = self._funcs.get(token, None) if func is None: raise ValueError("callback %s is not found" % token) if isinstance(func, EagerFunc): @@ -180,19 +184,6 @@ _py_funcs = FuncRegistry() pywrap_tensorflow.InitializePyTrampoline(_py_funcs) -class CleanupFunc(object): - """A helper class to remove a registered function from _py_funcs.""" - - def __init__(self, token): - self._token = token - - def __del__(self): - if _py_funcs is not None: - # If _py_funcs is None, the program is most likely in shutdown, and the - # _py_funcs object has been destroyed already. - _py_funcs.remove(self._token) - - def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): """See documentation for py_func and eager_py_func.""" @@ -216,17 +207,15 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): # bound to that of the outer graph instead. graph = graph._outer_graph - cleanup = CleanupFunc(token) - # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. - if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"): - graph._cleanup_py_funcs_used_in_graph = [] + if not hasattr(graph, "_py_funcs_used_in_graph"): + graph._py_funcs_used_in_graph = [] - # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph - # will be destroyed and their __del__ will remove the 'token' from - # the funcs registry. - graph._cleanup_py_funcs_used_in_graph.append(cleanup) + # Store a reference to the function in the graph to ensure it stays alive + # as long as the graph lives. When the graph is destroyed, the function + # is left to the garbage collector for destruction as well. + graph._py_funcs_used_in_graph.append(func) # pylint: enable=protected-access if eager: -- 2.7.4