Add utility for converting FunctionDef to GraphDef and _FuncGraph.
author    Saurabh Saxena <srbs@google.com>
          Thu, 31 May 2018 23:06:15 +0000 (16:06 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Thu, 31 May 2018 23:15:46 +0000 (16:15 -0700)
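
A minimal usage sketch (assuming a Graph `g` with placeholder inputs `x` and
`y` and an output tensor `out`; these names are illustrative):

    fdef = graph_to_function_def.graph_to_function_def(
        g, g.get_operations(), [x, y], [out])
    func_graph = function_def_to_graph.function_def_to_graph(fdef)
    # func_graph.inputs and func_graph.outputs now hold the corresponding
    # tensors of the reconstructed _FuncGraph.
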
PiperOrigin-RevId: 198795625

tensorflow/python/BUILD
tensorflow/python/framework/function_def_to_graph.py [new file with mode: 0644]
tensorflow/python/framework/function_def_to_graph_test.py [new file with mode: 0644]

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b15c529..569403f 100644
@@ -718,6 +718,38 @@ py_library(
 )
 
 py_library(
+    name = "function_def_to_graph",
+    srcs = ["framework/function_def_to_graph.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework",
+        ":function",
+        ":op_def_registry",
+        ":tensor_shape",
+        ":versions",
+        "//tensorflow/core:protos_all_py",
+    ],
+)
+
+py_test(
+    name = "function_def_to_graph_test",
+    size = "small",
+    srcs = ["framework/function_def_to_graph_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],
+    deps = [
+        ":array_ops",
+        ":client_testlib",
+        ":dtypes",
+        ":framework_ops",
+        ":function_def_to_graph",
+        ":graph_to_function_def",
+        ":math_ops",
+        ":test_ops",
+    ],
+)
+
+py_library(
     name = "graph_util",
     srcs = [
         "framework/graph_util.py",
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
new file mode 100644
index 0000000..4fecc41
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -0,0 +1,189 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utlity to convert FunctionDef to GraphDef and Graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.framework import versions_pb2
+from tensorflow.python.framework import function
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import versions
+
+
+def function_def_to_graph(fdef, input_shapes=None):
+  """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+
+  The returned _FuncGraph's `name`, `inputs`, and `outputs` fields will be
+  set. The input tensors are represented as placeholders.
+
+  Note: `_FuncGraph._captured` and other capture-related fields are not set
+  and may be set by the caller.
+
+  Args:
+    fdef: FunctionDef.
+    input_shapes: Optional. A list of TensorShape objects representing the
+      shapes of the function inputs. If specified, its length must match the
+      length of `fdef.signature.input_arg`. If a shape is None, the
+      corresponding input placeholder will have an unknown shape.
+
+  Returns:
+    A _FuncGraph.
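+
+  Example (a minimal sketch; assumes `fdef` came from a graph whose inputs
+  were float32 placeholders named "x" and "y"):
+
+    func_graph = function_def_to_graph(fdef)
+    with tf.Session(graph=func_graph) as sess:
+      outputs = sess.run(func_graph.outputs,
+                         feed_dict={"x:0": 2.0, "y:0": 3.0})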
+  """
+  func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False)  # pylint: disable=protected-access
+  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
+      fdef, input_shapes)
+
+  with func_graph.as_default():
+    # Add all function nodes to the graph.
+    importer.import_graph_def(graph_def, name="")
+
+    # Initialize fields specific to _FuncGraph.
+
+    # inputs
+    input_tensor_names = [
+        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
+    ]
+    func_graph.inputs = [
+        func_graph.get_tensor_by_name(name) for name in input_tensor_names
+    ]
+
+    # outputs
+    output_tensor_names = [
+        nested_to_flat_tensor_name[fdef.ret[arg.name]]
+        for arg in fdef.signature.output_arg
+    ]
+    func_graph.outputs = [
+        func_graph.get_tensor_by_name(name) for name in output_tensor_names
+    ]
+
+  return func_graph
+
+
+def function_def_to_graph_def(fdef, input_shapes=None):
+  """Convert a FunctionDef to a GraphDef.
+
+  Steps:
+  1. Creates placeholder nodes corresponding to inputs in
+     `FunctionDef.signature.input_arg`.
+  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
+  3. Renames the inputs of all nodes to use the GraphDef convention instead of
+     the FunctionDef one. See the comment on `FunctionDef.node_def` for how
+     tensor naming in FunctionDefs differs from that in GraphDefs.
+
+  Args:
+    fdef: FunctionDef.
+    input_shapes: Optional. A list of TensorShape objects representing the
+      shapes of the function inputs. If specified, its length must match the
+      length of `fdef.signature.input_arg`. If a shape is None, the
+      corresponding input placeholder will have an unknown shape.
+
+  Returns:
+    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
+    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).
+
+  Raises:
+    ValueError: If the length of input_shapes does not match the number of
+      input_args or if the FunctionDef is invalid.
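+
+  Example of the returned name mapping (a sketch; assumes a node "foo" whose
+  single output arg "d" produces two tensors):
+
+    {"foo:d:0": "foo:0", "foo:d:1": "foo:1"}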
+  """
+  graph_def = graph_pb2.GraphDef()
+  graph_def.versions.CopyFrom(
+      versions_pb2.VersionDef(
+          producer=versions.GRAPH_DEF_VERSION,
+          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
+
+  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
+    raise ValueError(
+        "Length of input_shapes must match the number of input_args. "
+        "len(input_shapes): {} len(input_arg): {}".format(
+            len(input_shapes), len(fdef.signature.input_arg)))
+
+  # 1. Create placeholders for input nodes.
+  for i, arg_def in enumerate(fdef.signature.input_arg):
+    node_def = graph_def.node.add()
+    node_def.name = arg_def.name
+    node_def.op = "Placeholder"
+    node_def.attr["dtype"].type = arg_def.type
+    if input_shapes and input_shapes[i] is not None:
+      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())
+
+  # 2. Copy all body NodeDefs to the GraphDef.
+  graph_def.node.extend(fdef.node_def)
+
+  # 3. Perform the renaming.
+
+  # Build the tensor name mapping, then flatten the tensor names.
+  # See the comment on `FunctionDef.node_def` for how tensor naming in
+  # FunctionDefs differs from that in GraphDefs.
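+  # For example, if node "n" has output args "a" (producing two tensors) and
+  # "b" (producing one), then "n:a:1" maps to "n:1" and "n:b:0" maps to "n:2".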
+  nested_to_flat_tensor_name = {}
+
+  for arg_def in fdef.signature.input_arg:
+    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+
+  for node_def in fdef.node_def:
+    op_def = op_def_registry.get_registered_ops().get(node_def.op)
+    if not op_def:
+      # TODO(b/80470245): Support functions that refer to other functions.
+      raise NotImplementedError(
+          "No op registered for {}; it may be a function. "
+          "function_def_to_graph_def currently does not support converting "
+          "functions with references to other graph functions.".format(
+              node_def.op))
+
+    for attr in op_def.attr:
+      if attr.type in ("func", "list(func)"):
+        # TODO(b/80470245): Support functions that refer to other functions.
+        raise NotImplementedError(
+            "Unsupported attr {} with type {} in op {}. "
+            "function_def_to_graph_def currently does not support converting "
+            "functions with references to other graph functions.".format(
+                attr.name, attr.type, op_def.name))
+
+    # Iterate over output_args in op_def to build the map.
+    # Index of the output tensor in the flattened list of *all* output
+    # tensors of the op.
+    flattened_index = 0
+    for arg_def in op_def.output_arg:
+      num_args = _get_num_args(arg_def, node_def)
+      for i in range(num_args):
+        # Map tensor names from "node_name:output_arg_name:index" to
+        # "node_name:flattened_index".
+        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
+        flat_name = "{}:{}".format(node_def.name, flattened_index)
+        nested_to_flat_tensor_name[nested_name] = flat_name
+        flattened_index += 1
+
+  # Update inputs of all nodes in graph.
+  for node_def in graph_def.node:
+    for i in range(len(node_def.input)):
+      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
+
+  return graph_def, nested_to_flat_tensor_name
+
+
+# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
+def _get_num_args(arg_def, node_def):
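+  """Returns the number of tensors corresponding to `arg_def` in `node_def`.
+
+  This is 1 unless the arg is repeated (`number_attr`) or takes a list of
+  types (`type_list_attr`).
+  """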
+  if arg_def.number_attr:
+    return node_def.attr[arg_def.number_attr].i
+  elif arg_def.type_list_attr:
+    return len(node_def.attr[arg_def.type_list_attr].list.type)
+  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
+    return 1
+  else:
+    raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
new file mode 100644
index 0000000..0f4e6ef
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -0,0 +1,184 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.function_def_to_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FunctionDefToGraphTest(test.TestCase):
+
+  def _build_function_def(self):
+    with ops.Graph().as_default() as g:
+      # Inputs
+      x = array_ops.placeholder(dtypes.float32, name="x")
+      y = array_ops.placeholder(dtypes.float32, name="y")
+
+      # Outputs
+      sum_squares = math_ops.add_n(
+          [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares")
+      sum_cubes = math_ops.add_n(
+          [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes")
+    fdef = graph_to_function_def.graph_to_function_def(
+        g,
+        g.get_operations(),
+        [x, y],  # Inputs
+        [sum_squares, sum_cubes])  # Outputs.
+    fdef.signature.name = "_whats_in_a_name"
+    return fdef
+
+  def testInputsAndOutputs(self):
+    fdef = self._build_function_def()
+    g = function_def_to_graph.function_def_to_graph(fdef)
+    self.assertEqual(g.name, "_whats_in_a_name")
+    with self.test_session(graph=g) as sess:
+      inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
+      self.assertSequenceEqual(inputs, [2.0, 3.0])
+      outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
+      self.assertSequenceEqual(outputs, [13.0, 35.0])
+
+  def testShapes(self):
+    fdef = self._build_function_def()
+
+    g = function_def_to_graph.function_def_to_graph(fdef)
+    self.assertIsNone(g.inputs[0].shape.dims)  # Unknown dims.
+    self.assertIsNone(g.inputs[1].shape.dims)  # Unknown dims.
+    self.assertIsNone(g.outputs[0].shape.dims)  # Unknown dims.
+    self.assertIsNone(g.outputs[1].shape.dims)  # Unknown dims.
+
+    g = function_def_to_graph.function_def_to_graph(
+        fdef, input_shapes=[tensor_shape.vector(5),
+                            tensor_shape.vector(5)])
+    self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
+    self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
+    self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
+    self.assertSequenceEqual(g.outputs[1].shape.dims, [5])
+
+    g = function_def_to_graph.function_def_to_graph(
+        fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
+    self.assertIsNone(g.inputs[0].shape.dims)
+    self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
+    self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
+    self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7])
+
+    # Should raise a ValueError if the length of input_shapes does not match
+    # the number of input args in FunctionDef.signature.input_arg.
+    with self.assertRaises(ValueError):
+      g = function_def_to_graph.function_def_to_graph(
+          fdef, input_shapes=[tensor_shape.matrix(5, 7)])
+
+
+class FunctionDefToGraphDefTest(test.TestCase):
+
+  def _build_function_def(self):
+    with ops.Graph().as_default() as g:
+      # Inputs:    x    y    z
+      #            |\   |   /
+      #            | \  |  /
+      #            |  foo_1     list_output
+      #            |   / \       /       \
+      #            | d_1 e_1  a:1        a:0
+      #            |  \   |   /           |
+      #            |   \  |  /            |
+      #            |    foo_2             |
+      #            |     / \              |
+      # Outputs:   x   d_2 e_2           a:0
+
+      x = array_ops.placeholder(dtypes.float32, name="x")
+      y = array_ops.placeholder(dtypes.int32, name="y")
+      z = array_ops.placeholder(dtypes.int32, name="z")
+
+      d_1, e_1 = test_ops._op_def_lib.apply_op(  # pylint: disable=protected-access
+          "Foo1", name="foo_1", a=x, b=y, c=z)
+
+      list_output0, list_output1 = test_ops.list_output(
+          T=[dtypes.int32, dtypes.int32], name="list_output")
+
+      d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2")
+
+    fdef = graph_to_function_def.graph_to_function_def(
+        g,
+        g.get_operations(),
+        [x, y, z],  # Inputs
+        [x, d_2, e_2, list_output0])  # Outputs.
+
+    # Assert that the FunctionDef was correctly built.
+    assert len(fdef.node_def) == 3  # 2 Foo1 nodes and 1 ListOutput node.
+    assert fdef.node_def[0].op == "Foo1"
+    assert fdef.node_def[0].input == ["x", "y", "z"]
+    assert fdef.node_def[1].op == "ListOutput"
+    assert not fdef.node_def[1].input
+    assert fdef.node_def[2].op == "Foo1"
+    assert fdef.node_def[2].input == [
+        "foo_1:d:0", "foo_1:e:0", "list_output:a:1"
+    ]
+    return fdef
+
+  def testTensorNames(self):
+    fdef = self._build_function_def()
+    g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef)
+
+    # Verify that inputs of body nodes are correctly renamed.
+    # foo_1
+    self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"])
+    # foo_2
+    self.assertSequenceEqual(g.node[5].input,
+                             ["foo_1:0", "foo_1:1", "list_output:1"])
+
+    # Verify that the `tensor_name_map` has the correct mapping.
+    self.assertDictEqual(
+        tensor_name_map, {
+            "x": "x:0",
+            "y": "y:0",
+            "z": "z:0",
+            "foo_1:d:0": "foo_1:0",
+            "foo_1:e:0": "foo_1:1",
+            "list_output:a:0": "list_output:0",
+            "list_output:a:1": "list_output:1",
+            "foo_2:d:0": "foo_2:0",
+            "foo_2:e:0": "foo_2:1",
+        })
+
+  def testShapes(self):
+    fdef = self._build_function_def()
+    g, _ = function_def_to_graph.function_def_to_graph_def(
+        fdef,
+        input_shapes=[tensor_shape.scalar(),
+                      tensor_shape.vector(5), None])
+    self.assertEqual("shape" in g.node[0].attr, True)
+    self.assertSequenceEqual(
+        tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), [])
+    self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+    self.assertEqual("shape" in g.node[1].attr, True)
+    self.assertSequenceEqual(
+        tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5])
+    self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+    self.assertFalse("shape" in g.node[2].attr)
+
+
+if __name__ == "__main__":
+  test.main()