from __future__ import division
from __future__ import print_function
+import gc
import re
import numpy as np
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
- for _ in xrange(1000):
- g = ops.Graph()
- with g.as_default():
- c = constant_op.constant([1.], dtypes.float32)
- _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
- _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
- self.assertTrue(script_ops._py_funcs.size() < 100)
+ # Delete everything created by previous tests to avoid side effects.
+ ops.reset_default_graph()
+ gc.collect()
+ initial_size = script_ops._py_funcs.size()
+ # Encapsulate the graph generation, so locals can be deleted.
+ def make_graphs():
+ for _ in xrange(1000):
+ g = ops.Graph()
+ with g.as_default():
+ c = constant_op.constant([1.], dtypes.float32)
+ _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
+ _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
+ # These ops have a reference to 'c' which has a reference to the graph.
+ # Checks if the functions are being deleted though the graph is referenced from them.
+ # (see #18292)
+ _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+ _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
+
+ # Call garbage collector to enforce deletion.
+ make_graphs()
+ ops.reset_default_graph()
+ gc.collect()
+ self.assertEqual(initial_size, script_ops._py_funcs.size())
# ----- Tests for eager_py_func -----
@test_util.run_in_graph_and_eager_modes()
# Used by py_util.cc to get tracebacks.
import traceback # pylint: disable=unused-import
+import weakref
import numpy as np
import six
def __init__(self):
self._lock = threading.Lock()
self._unique_id = 0 # GUARDED_BY(self._lock)
- self._funcs = {}
+ # Only store weakrefs to the funtions. The strong reference is stored in
+ # the graph.
+ self._funcs = weakref.WeakValueDictionary()
def insert(self, func):
"""Registers `func` and returns a unique token for this entry."""
token = self._next_unique_token()
+ # Store a weakref to the function
self._funcs[token] = func
return token
Raises:
ValueError: if no function is registered for `token`.
"""
- func = self._funcs[token]
+ func = self._funcs.get(token, None)
if func is None:
raise ValueError("callback %s is not found" % token)
if isinstance(func, EagerFunc):
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
-class CleanupFunc(object):
- """A helper class to remove a registered function from _py_funcs."""
-
- def __init__(self, token):
- self._token = token
-
- def __del__(self):
- if _py_funcs is not None:
- # If _py_funcs is None, the program is most likely in shutdown, and the
- # _py_funcs object has been destroyed already.
- _py_funcs.remove(self._token)
-
-
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
"""See documentation for py_func and eager_py_func."""
# bound to that of the outer graph instead.
graph = graph._outer_graph
- cleanup = CleanupFunc(token)
-
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
- if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
- graph._cleanup_py_funcs_used_in_graph = []
+ if not hasattr(graph, "_py_funcs_used_in_graph"):
+ graph._py_funcs_used_in_graph = []
- # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
- # will be destroyed and their __del__ will remove the 'token' from
- # the funcs registry.
- graph._cleanup_py_funcs_used_in_graph.append(cleanup)
+ # Store a reference to the function in the graph to ensure it stays alive
+ # as long as the graph lives. When the graph is destroyed, the function
+ # is left to the garbage collector for destruction as well.
+ graph._py_funcs_used_in_graph.append(func)
# pylint: enable=protected-access
if eager: