Fix TF_ImportGraphDefResults and TF_Function leaks in Python API.
authorSkye Wanderman-Milne <skyewm@google.com>
Thu, 5 Apr 2018 21:22:54 +0000 (14:22 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 21:25:18 +0000 (14:25 -0700)
PiperOrigin-RevId: 191797853

tensorflow/python/eager/function.py
tensorflow/python/eager/graph_callable.py
tensorflow/python/framework/c_api_util.py
tensorflow/python/framework/function.py
tensorflow/python/framework/importer.py
tensorflow/python/framework/ops.py

index 711eddc..61859d6 100644 (file)
@@ -294,7 +294,7 @@ class _EagerDefinedFunction(object):
     self.signature = function_def.signature
     self.grad_func_name = None
     self.python_grad_func = None
-    self._c_func = fn
+    self._c_func = c_api_util.ScopedTFFunction(fn)
     self._grad_func = None
 
 
@@ -661,7 +661,7 @@ def _defun_internal(name, func, args, kwds):
   if context.executing_eagerly():
     for f in tmp_graph._functions.values():  # pylint: disable=protected-access
       # TODO(ashankar): What about the gradient registry?
-      _register(f._c_func)  # pylint: disable=protected-access
+      _register(f._c_func.func)  # pylint: disable=protected-access
   return GraphModeFunction(
       fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
       func_outputs, output_shapes, variables)
index ee5d87f..d40ea98 100644 (file)
@@ -325,7 +325,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
   # Also, what about the gradient registry of these functions? Those need to be
   # addressed as well.
   for f in tmp_graph._functions.values():  # pylint: disable=protected-access
-    function._register(f._c_func)  # pylint: disable=protected-access
+    function._register(f._c_func.func)  # pylint: disable=protected-access
   initializer_function = function.GraphModeFunction(
       initialization_name,
       placeholder_inputs,
index 4356a53..7bbe318 100644 (file)
@@ -63,6 +63,32 @@ class ScopedTFImportGraphDefOptions(object):
       c_api.TF_DeleteImportGraphDefOptions(self.options)
 
 
+class ScopedTFImportGraphDefResults(object):
+  """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
+
+  def __init__(self, results):
+    self.results = results
+
+  def __del__(self):
+    # Note: when we're destructing the global context (i.e when the process is
+    # terminating) we can have already deleted other modules.
+    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
+      c_api.TF_DeleteImportGraphDefResults(self.results)
+
+
+class ScopedTFFunction(object):
+  """Wrapper around TF_Function that handles deletion."""
+
+  def __init__(self, func):
+    self.func = func
+
+  def __del__(self):
+    # Note: when we're destructing the global context (i.e when the process is
+    # terminating) we can have already deleted other modules.
+    if c_api is not None and c_api.TF_DeleteFunction is not None:
+      c_api.TF_DeleteFunction(self.func)
+
+
 @tf_contextlib.contextmanager
 def tf_buffer(data=None):
   """Context manager that creates and deletes TF_Buffer.
index c5caf9e..9570f00 100644 (file)
@@ -274,7 +274,7 @@ class _DefinedFunction(object):
     self._create_definition_if_needed()
     if self._c_func:
       with c_api_util.tf_buffer() as buf:
-        c_api.TF_FunctionToFunctionDef(self._c_func, buf)
+        c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
         fdef = function_pb2.FunctionDef()
         proto_data = c_api.TF_GetBuffer(buf)
         fdef.ParseFromString(compat.as_bytes(proto_data))
@@ -397,7 +397,7 @@ class _DefinedFunction(object):
                       if self._out_names else [])
       description = self._func.__doc__ or None
       # pylint: disable=protected-access
-      self._c_func = c_api.TF_GraphToFunction_wrapper(
+      c_func = c_api.TF_GraphToFunction_wrapper(
           temp_graph._c_graph,
           base_func_name,
           self._func_name is None,  # append_hash_to_fn_name
@@ -407,6 +407,7 @@ class _DefinedFunction(object):
           output_names,
           None,  # opts
           description)
+      self._c_func = c_api_util.ScopedTFFunction(c_func)
       # pylint: enable=protected-access
       self._set_c_attrs(kwargs_attr)
 
@@ -429,7 +430,7 @@ class _DefinedFunction(object):
       serialized = attr_value.SerializeToString()
       # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
       # It might be worth creating a convenient way to re-use the same status.
-      c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
+      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                          serialized)
 
   def _create_hash_str(self, input_arg, output_arg, node_def):
@@ -825,7 +826,8 @@ def _from_definition(fdef, grad_func=None):
   # pylint: disable=protected-access
   if ops._USE_C_API:
     serialized = fdef.SerializeToString()
-    result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+    c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+    result._c_func = c_api_util.ScopedTFFunction(c_func)
     result._extra_inputs = []
   else:
     result._definition = fdef
index 23f529b..8beb74d 100644 (file)
@@ -487,6 +487,7 @@ def import_graph_def(graph_def,
         try:
           results = c_api.TF_GraphImportGraphDefWithResults(
               graph._c_graph, serialized, options)  # pylint: disable=protected-access
+          results = c_api_util.ScopedTFImportGraphDefResults(results)
         except errors.InvalidArgumentError as e:
           # Convert to ValueError for backwards compatibility.
           raise ValueError(str(e))
@@ -515,7 +516,7 @@ def import_graph_def(graph_def,
     # they are likely to be due to a typo.
     missing_unused_input_keys = (
         c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
-            results))
+            results.results))
     if missing_unused_input_keys:
       missing_unused_input_keys = [
           compat.as_str(s) for s in missing_unused_input_keys
@@ -527,7 +528,7 @@ def import_graph_def(graph_def,
     if return_elements is None:
       return None
     else:
-      return _GatherReturnElements(return_elements, graph, results)
+      return _GatherReturnElements(return_elements, graph, results.results)
 
   else:
     g = graph
index 2d55f98..84366e2 100644 (file)
@@ -3216,9 +3216,11 @@ class Graph(object):
       # as this will be unnecessary.
       if not function._c_func:
         serialized = function.definition.SerializeToString()
-        function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
-      gradient = function._grad_func._c_func if function._grad_func else None
-      c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
+        c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+        function._c_func = c_api_util.ScopedTFFunction(c_func)
+      gradient = (function._grad_func._c_func.func if function._grad_func
+                  else None)
+      c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
     else:
       # If there is already a function with the same name, raise an error
       # if bodies are different. Else, do nothing. The C API version above