self.signature = function_def.signature
self.grad_func_name = None
self.python_grad_func = None
- self._c_func = fn
+ self._c_func = c_api_util.ScopedTFFunction(fn)
self._grad_func = None
if context.executing_eagerly():
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
- _register(f._c_func) # pylint: disable=protected-access
+ _register(f._c_func.func) # pylint: disable=protected-access
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
func_outputs, output_shapes, variables)
# Also, what about the gradient registry of these functions? Those need to be
# addressed as well.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register(f._c_func) # pylint: disable=protected-access
+ function._register(f._c_func.func) # pylint: disable=protected-access
initializer_function = function.GraphModeFunction(
initialization_name,
placeholder_inputs,
c_api.TF_DeleteImportGraphDefOptions(self.options)
+class ScopedTFImportGraphDefResults(object):
+ """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
+
+ def __init__(self, results):
+ self.results = results
+
+ def __del__(self):
+ # Note: when we're destructing the global context (i.e when the process is
+ # terminating) we can have already deleted other modules.
+ if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
+ c_api.TF_DeleteImportGraphDefResults(self.results)
+
+
+class ScopedTFFunction(object):
+ """Wrapper around TF_Function that handles deletion."""
+
+ def __init__(self, func):
+ self.func = func
+
+ def __del__(self):
+ # Note: when we're destructing the global context (i.e when the process is
+ # terminating) we can have already deleted other modules.
+ if c_api is not None and c_api.TF_DeleteFunction is not None:
+ c_api.TF_DeleteFunction(self.func)
+
+
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
self._create_definition_if_needed()
if self._c_func:
with c_api_util.tf_buffer() as buf:
- c_api.TF_FunctionToFunctionDef(self._c_func, buf)
+ c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
fdef = function_pb2.FunctionDef()
proto_data = c_api.TF_GetBuffer(buf)
fdef.ParseFromString(compat.as_bytes(proto_data))
if self._out_names else [])
description = self._func.__doc__ or None
# pylint: disable=protected-access
- self._c_func = c_api.TF_GraphToFunction_wrapper(
+ c_func = c_api.TF_GraphToFunction_wrapper(
temp_graph._c_graph,
base_func_name,
self._func_name is None, # append_hash_to_fn_name
output_names,
None, # opts
description)
+ self._c_func = c_api_util.ScopedTFFunction(c_func)
# pylint: enable=protected-access
self._set_c_attrs(kwargs_attr)
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
- c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
+ c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
serialized)
def _create_hash_str(self, input_arg, output_arg, node_def):
# pylint: disable=protected-access
if ops._USE_C_API:
serialized = fdef.SerializeToString()
- result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ result._c_func = c_api_util.ScopedTFFunction(c_func)
result._extra_inputs = []
else:
result._definition = fdef
try:
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options) # pylint: disable=protected-access
+ results = c_api_util.ScopedTFImportGraphDefResults(results)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
# they are likely to be due to a typo.
missing_unused_input_keys = (
c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
- results))
+ results.results))
if missing_unused_input_keys:
missing_unused_input_keys = [
compat.as_str(s) for s in missing_unused_input_keys
if return_elements is None:
return None
else:
- return _GatherReturnElements(return_elements, graph, results)
+ return _GatherReturnElements(return_elements, graph, results.results)
else:
g = graph
# as this will be unnecessary.
if not function._c_func:
serialized = function.definition.SerializeToString()
- function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- gradient = function._grad_func._c_func if function._grad_func else None
- c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ function._c_func = c_api_util.ScopedTFFunction(c_func)
+ gradient = (function._grad_func._c_func.func if function._grad_func
+ else None)
+ c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
else:
# If there is already a function with the same name, raise an error
# if bodies are different. Else, do nothing. The C API version above