From 00ee6689bb838f45a393d4fbca11ad10018a382a Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 17 May 2018 09:28:35 -0700 Subject: [PATCH] Improvements to function._FuncGraph. * Adds 'inputs', 'outputs', and 'name' fields to _FuncGraph. This allows _FuncGraph to encapsulate all the information needed to convert it to a FunctionDef. * Refactors logic for converting a Python callable to a _FuncGraph into a new method, func_graph_from_py_func(). These changes are in preparation for converting tf.cond to emit an If op. By exposing _FuncGraph functionality outside of _DefinedFunction, _FuncGraphs can be used to represent functions that are manipulated (e.g. to output intermediate tensors) before being converted to FunctionDef protos. PiperOrigin-RevId: 197003496 --- tensorflow/python/framework/function.py | 132 +++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 45 deletions(-) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 94c37d6..6882b44 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -258,12 +258,10 @@ class _DefinedFunction(object): # another reference to _definition.signature self._op_def = None - self._args = [] assert isinstance(input_types, (list, tuple)) - for i in range(len(input_types)): - argname = argnames[i] if i < len(argnames) else ("arg%d" % i) - argtype = input_types[i] - self._args.append((argname, argtype)) + self._arg_types = input_types + self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i) + for i in range(len(input_types))] @property def name(self): @@ -336,42 +334,11 @@ class _DefinedFunction(object): if self._definition is not None or self._c_func is not None: return - # Create the func_def object. - temp_graph = _FuncGraph(capture_by_value=self._capture_by_value) - with temp_graph.as_default(), ops.device(self._caller_device): - # List of placeholders for the function_def. 
- inputs = [] - for (argname, argtype) in self._args: - argholder = array_ops.placeholder(argtype, name=argname) - inputs.append(argholder) - # Call func and gather the output tensors. - with vs.variable_scope("", custom_getter=temp_graph.getvar): - outputs = self._func(*inputs) - - # There is no way of distinguishing between a function not returning - # anything and a function returning None in Python. - # We need to allow the former and ideally want to forbid the latter as - # it is most likely user error. - # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to - # allow users to explicitly mark the function as not returning anything. - # For now, we allow a single None return and interpret it as a function - # with no output. - if outputs is None: - outputs = [] - else: - # If func only returned one value, make it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - if any([_ is None for _ in outputs]): - raise ValueError("Function can not return None.") - # Ensures each output is a Tensor in the function graph. 
- outputs = [ops.convert_to_tensor(t) for t in outputs] - outputs = [ - temp_graph.capture(t) if t.graph is not temp_graph else t - for t in outputs - ] + temp_graph = func_graph_from_py_func( + self._func, self._arg_names, self._arg_types, self._func_name, + self._capture_by_value, self._caller_device) + self._extra_inputs = temp_graph.extra_inputs - inputs.extend(temp_graph.extra_args) # pylint: disable=protected-access self._sub_functions = temp_graph._functions # pylint: enable=protected-access @@ -390,8 +357,8 @@ class _DefinedFunction(object): self._definition = graph_to_function_def.graph_to_function_def( temp_graph, temp_graph.get_operations(), - inputs, - outputs, + temp_graph.inputs, + temp_graph.outputs, out_names=self._out_names) for k in kwargs_attr: @@ -421,8 +388,8 @@ class _DefinedFunction(object): base_func_name, self._func_name is None, # append_hash_to_fn_name None, # opers - [t._as_tf_output() for t in inputs], - [t._as_tf_output() for t in outputs], + [t._as_tf_output() for t in temp_graph.inputs], + [t._as_tf_output() for t in temp_graph.outputs], output_names, None, # opts description) @@ -653,16 +620,33 @@ class _FuncGraph(ops.Graph): function argument and the caller passes in the captured tensor. """ - def __init__(self, capture_by_value, *args, **kwargs): + def __init__(self, name, capture_by_value, *args, **kwargs): super(_FuncGraph, self).__init__(*args, **kwargs) self._capture_by_value = capture_by_value self._building_function = True self._outer_graph = ops.get_default_graph() self._vscope = vs.get_variable_scope() self._old_custom_getter = self._vscope.custom_getter + + # The name of the function. + self.name = name + # Placeholder tensors representing the inputs to this function. The tensors + # are in this _FuncGraph. + self.inputs = [] + # Tensors that will be returned by this function. The tensors are in this + # _FuncGraph. + self.outputs = [] + # Maps external tensor -> internal tensor (e.g. input placeholder). 
self._captured = {} + # The external tensors that have been captured as inputs and must be passed + # to this function (empty if capturing by value, otherwise these are the + # keys of _captured). self.extra_inputs = [] + # Input placeholders that have been added for captured values (empty if capturing + # by value). self.extra_args = [] + # Captured variables. + # TODO(skyewm): is this needed? self.extra_vars = [] def getvar( @@ -742,6 +726,7 @@ class _FuncGraph(ops.Graph): else: ph._handle_data = tensor._handle_data # pylint: enable=protected-access + self.inputs.append(ph) self._captured[tensor] = ph self.extra_args.append(ph) if _is_guaranteed_const(tensor): @@ -780,6 +765,63 @@ class _FuncGraph(ops.Graph): return captured_op +def func_graph_from_py_func(func, arg_names, arg_types, name=None, + capture_by_value=False, device=None): + """Returns a _FuncGraph generated from `func`. + + Args: + func: A Python callable which constructs a TF function body. The arguments + must correspond to `arg_types`. Returns a value or list/tuple of values. + No returned value can be None. + arg_names: A sequence of strings for the function argument names. + arg_types: A sequence of the function's argument types. + name: The function name. If None, the name is derived from `func`. + capture_by_value: boolean. If True, captured values will be copied into the + function body. + device: device name or function. + + Returns: + A _FuncGraph. + + Raises: + ValueError: if func returns None. + """ + if not name: + name = _get_func_name(func) + func_graph = _FuncGraph(name, capture_by_value) + with func_graph.as_default(), ops.device(device): + # Create placeholders for the function arguments. + for (argname, argtype) in zip(arg_names, arg_types): + argholder = array_ops.placeholder(argtype, name=argname) + func_graph.inputs.append(argholder) + # Call func and gather the output tensors. 
+ with vs.variable_scope("", custom_getter=func_graph.getvar): + outputs = func(*func_graph.inputs) + + # There is no way of distinguishing between a function not returning + # anything and a function returning None in Python. + # We need to allow the former and ideally want to forbid the latter as + # it is most likely user error. + # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to + # allow users to explicitly mark the function as not returning anything. + # For now, we allow a single None return and interpret it as a function + # with no output. + if outputs is None: + outputs = [] + else: + # If func only returned one value, make it a tuple. + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + if any([_ is None for _ in outputs]): + raise ValueError("Function can not return None.") + # Ensures each output is a Tensor in the function graph. + outputs = [ops.convert_to_tensor(t) for t in outputs] + outputs = [func_graph.capture(t) if t.graph is not func_graph else t + for t in outputs] + func_graph.outputs = outputs + return func_graph + + def _is_guaranteed_const(tensor): """Determines whether `tensor` is guaranteed to be a constant. -- 2.7.4