Create an interface to create hints for future toco conversions.
Author:    Andrew Selle <aselle@google.com>
Date:      Tue, 30 Jan 2018 17:55:38 +0000 (09:55 -0800)
Committer: TensorFlower Gardener <gardener@tensorflow.org>
Commit date: Tue, 30 Jan 2018 18:07:34 +0000 (10:07 -0800)
Specifically, tf.contrib.lite.OpHint can create "breadcrumb"
hints that describe encapsulation of multiple TensorFlow ops
that make up a TensorFlow lite builtin or custom op. These
can later be replaced with stub versions in a GraphDef or
SavedModel.

PiperOrigin-RevId: 183846742

tensorflow/contrib/framework/__init__.py
tensorflow/contrib/lite/python/BUILD
tensorflow/contrib/lite/python/lite.py
tensorflow/contrib/lite/python/lite_test.py
tensorflow/contrib/lite/python/op_hint.py [new file with mode: 0644]

index 673c517..503b868 100644 (file)
@@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide.
 @@assign_from_values_fn
 @@create_global_step
 @@filter_variables
+@@fuse_op
 @@get_global_step
 @@get_or_create_global_step
 @@get_local_variables
index 3d6a3ec..2d8c49b 100644 (file)
@@ -13,6 +13,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":op_hint",
         "//tensorflow/contrib/lite/toco:model_flags_proto_py",
         "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
         "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco",
@@ -20,6 +21,17 @@ py_library(
     ],
 )
 
+py_library(
+    name = "op_hint",
+    srcs = ["op_hint.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/python:platform",
+    ],
+)
+
 py_test(
     name = "lite_test",
     srcs = ["lite_test.py"],
@@ -27,6 +39,7 @@ py_test(
     tags = ["no_oss"],
     deps = [
         ":lite",
+        ":op_hint",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
index 3c36977..5d2f216 100644 (file)
@@ -18,16 +18,21 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
 
 @@toco_convert
 @@toco_convert_protos
+@@OpHint
+@@convert_op_hints_to_stubs
 
 """
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
 import os
 import subprocess
 import tempfile
 
+# pylint: disable=unused-import
+from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
+from tensorflow.contrib.lite.python.op_hint import OpHint
+# pylint: enable=unused-import
 from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
index 7d55f3f..b8b4510 100644 (file)
@@ -18,10 +18,14 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base
 from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -35,7 +39,8 @@ class LiteTest(test_util.TensorFlowTestCase):
     # Try running on valid graph
     result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
     self.assertTrue(result)
-    # TODO(aselle): remove tests that fail.
+    # TODO(aselle): remove tests that fail (we must get TOCO to not fatal
+    # all the time).
     # Try running on identity graph (known fail)
     # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
     #   result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
@@ -51,5 +56,116 @@ class LiteTest(test_util.TensorFlowTestCase):
                                quantized_input_stats=[(0., 1.)])
     self.assertTrue(result)
 
+
+class LiteTestOpHint(test_util.TensorFlowTestCase):
+  """Test the hint to stub functionality."""
+
+  def _getGraphOpTypes(self, graphdef, output_nodes):
+    """Returns used op types in `graphdef` reachable from `output_nodes`.
+
+    This is used to check that after the stub transformation the expected
+    nodes are there. Typically use this with self.assertCountEqual(...).
+
+    NOTE: this is not an exact test that the graph is the correct output, but
+      it balances compact expressibility of test with sanity checking.
+
+    Args:
+      graphdef: TensorFlow proto graphdef.
+      output_nodes: A list of output node names that we need to reach.
+
+    Returns:
+      A set of node types reachable from `output_nodes`.
+    """
+    name_to_input_name, name_to_node, _ = (
+        _extract_graph_summary(graphdef))
+    # Find all nodes that are needed by the outputs
+    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
+    return set([name_to_node[node_name].op for node_name in used_node_names])
+
+  def _countIdentities(self, nodes):
+    """Count the number of "Identity" op types in the list of proto nodes.
+
+    Args:
+      nodes: NodeDefs of the graph.
+
+    Returns:
+      The number of nodes with op type "Identity" found.
+    """
+    return len([x for x in nodes if x.op == "Identity"])
+
+  def testSwishLiteHint(self):
+    """Makes a custom op swish and makes sure it gets converted as a unit."""
+    image = array_ops.constant([1., 2., 3., 4.])
+    swish_scale = array_ops.constant(1.0)
+
+    def _swish(input_tensor, scale):
+      custom = lite.OpHint("cool_activation")
+      input_tensor, scale = custom.add_inputs(input_tensor, scale)
+      output = math_ops.sigmoid(input_tensor) * input_tensor * scale
+      output, = custom.add_outputs(output)
+      return output
+    output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
+
+    with self.test_session() as sess:
+      # check if identities have been put into the graph (2 input, 1 output,
+      # and 1 final output).
+      self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
+
+      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+
+      self.assertCountEqual(
+          self._getGraphOpTypes(
+              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+          ["cool_activation", "Const", "Identity"])
+
+  def testScaleAndBiasAndIdentity(self):
+    """This tests a scaled add which has 3 inputs and 2 outputs."""
+    a = array_ops.constant(1.)
+    x = array_ops.constant([2., 3.])
+    b = array_ops.constant([4., 5.])
+
+    def _scaled_and_bias_and_identity(a, x, b):
+      custom = lite.OpHint("scale_and_bias_and_identity")
+      a, x, b = custom.add_inputs(a, x, b)
+      return custom.add_outputs(a * x + b, x)
+    output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
+                                name="ModelOutput")
+
+    with self.test_session() as sess:
+      # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
+      # +1 for the final output
+      self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
+
+      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+
+      self.assertCountEqual(
+          self._getGraphOpTypes(
+              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+          ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
+
+  def testTwoFunctions(self):
+    """Tests if two functions are converted correctly."""
+    a = array_ops.constant([1.])
+    b = array_ops.constant([1.])
+    def _double_values(x):
+      custom = lite.OpHint("add_test")
+      x = custom.add_inputs(x)
+      output = math_ops.multiply(x, x)
+      output, = custom.add_outputs(output)
+      return output
+    output = array_ops.identity(
+        math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
+
+    with self.test_session() as sess:
+      # make sure one identity for each input (2) and output (2) => 2 + 2
+      # +1 for the final output
+      self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
+      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+      self.assertCountEqual(
+          self._getGraphOpTypes(
+              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+          ["add_test", "Const", "Identity", "Add"])
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
new file mode 100644 (file)
index 0000000..7c587e3
--- /dev/null
@@ -0,0 +1,291 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Define tflite op hints (intrinsic operations).
+
+This essentially allows defining a TensorFlow API for tflite operations in
+Python with hints on how they are represented in TensorFlow Lite. This basically
+is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution
+graph and is useful for LSTMs and other complicated TensorFlow constructions
+that are difficult to pattern match in TOCO, but are represented by a single
+accelerated tflite op.
+
+Example:
+  def tflite_cool_activation(input):
+    # A cool activation function.
+    custom = tf.contrib.lite.OpHint("cool_activation")
+    input = custom.add_inputs(input)
+    output = tf.sigmoid(input) * input
+    custom.add_outputs(output)
+    return output
+
+  image = tf.placeholder(tf.float32, (1, 16, 16, 1))
+  output = tf.identity(tflite_cool_activation(image))
+
+  session = tf.Session()
+
+  graphdef_to_convert = tf.contrib.lite.convert_op_hints_to_stubs(session)
+  tflite_graph = tf.contrib.lite.toco_convert(graphdef_to_convert,
+                                              [image], [output])
+  with open("/tmp/graph.fb", "wb") as fp:
+    fp.write(tflite_graph)
+
+How does it work?:
+
+OpHint is a helper that you use when defining a vanilla python function.
+It allows you to wrap arguments with tf.identities with some custom attributes.
+These attributes allow you to find the original block of ops that was created.
+For example, if you use cool_activation above you essentially get:
+
+a_input = tf.identity()
+result = tf.multiply(tf.sigmoid(a_input), a_input)
+output = tf.identity()
+
+a_input, output are identities that have parameters representing
+what argument they are, what the name of the function they should turn into
+in tf lite as well as a guid that uniquely identifies a particular invocation.
+
+Once you have built your whole tensorflow graph, you can run it and train it
+as usual, but after you have done that, you need to convert the graph into
+a form that replaces these subgraphs wrapped in identities to stub ops. These
+ops don't actually exist in the normal TensorFlow runtime, but will be
+understood by toco later.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections as _collections
+import itertools as _itertools
+import uuid as _uuid
+
+from tensorflow.contrib import framework as _framework
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.ops import array_ops as _array_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+class OpHint(object):
+  """A class that helps build tflite function invocations.
+
+  It allows you to take a bunch of TensorFlow ops and annotate the construction
+  such that toco knows how to convert it to tflite. This embeds a pseudo
+  function in a TensorFlow graph. This allows embedding high-level API usage
+  information in a lower level TensorFlow implementation so that an alternative
+  implementation can be substituted later.
+
+  Essentially, any "input" into this pseudo op is fed into an identity, and
+  attributes are added to that input before being used by the constituent ops
+  that make up the pseudo op. A similar process is done to any output that
+  is to be exported from the current op.
+
+  TODO(aselle): When TensorFlow functions functionality works for arbitrary
+  constructs, this mechanism can be retired and changed to use python defun's.
+  """
+
+  # Attr constants that are used for representation in the GraphDef
+  FUNCTION_NAME_ATTR = "_tflite_function_name"
+  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
+  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
+  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
+
+  def __init__(self, function_name, **kwargs):
+    """Create a OpHint.
+
+    Args:
+      function_name: Name of the function (the custom op name in tflite)
+      **kwargs: Keyword arguments of any constant attributes for the function.
+    """
+    self._function_name = function_name
+    self._unique_function_id = _uuid.uuid1().hex  # TODO(aselle): Unique enough?
+    self._curr_input_index = 0
+    self._curr_output_index = 0
+    self._attrs_to_store_later = kwargs
+    self._stored_attrs = False
+
+  def _setattr(self, dest_op, name, value):
+    tensor_value = _ops.convert_to_tensor(value)
+    dest_op.op.node_def.attr[name].tensor.CopyFrom(
+        tensor_value.op.node_def.attr["value"].tensor)
+
+  def add_inputs(self, *args):
+    """Add a sequence of inputs to the function invocation.
+
+    Args:
+      *args: List of inputs to be converted (should be Tf.Tensor).
+    Returns:
+      Wrapped inputs (identity standins that have additional metadata). These
+      are also tf.Tensor's.
+    """
+
+    def augmented_identity(arg):
+      identity_op = _array_ops.identity(arg)
+      attr = identity_op.op.node_def.attr
+      attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
+      attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
+      attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i = self._curr_input_index
+      self._curr_input_index += 1
+      return identity_op
+
+    return [augmented_identity(arg) for arg in args]
+
+  def add_outputs(self, *args):
+    """Add a sequence of outputs to the function invocation.
+
+    Args:
+      *args: List of outputs to be converted (should be tf.Tensor).
+    Returns:
+      Wrapped outputs (identity standins that have additional metadata). These
+      are also tf.Tensor's.
+    """
+
+    def augmented_identity(arg):
+      identity_op = _array_ops.identity(arg)
+      attr = identity_op.op.node_def.attr
+      attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
+      attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
+      attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i = self._curr_output_index
+      self._curr_output_index += 1
+      return identity_op
+
+    wrapped_outputs = [augmented_identity(arg) for arg in args]
+
+    if not self._stored_attrs:
+      for key, value in self._attrs_to_store_later.items():
+        self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value)
+      self._stored_attrs = True
+
+    return wrapped_outputs
+
+
+class _LiteFuncCall(object):
+  """Represent a TensorFlow Lite custom function.
+
+  This is used to accumulate found hints in the graphdef into a single
+  conceptual unit.
+
+  Properties:
+    self.inputs: inputs to the op (hash from index # to argument)
+    self.outputs: outputs to the op (hash from index # to argument)
+    self.function_name: the tflite custom op name to use
+    self.uuid: a unique call id for this particular call  (i.e.
+      multiple function calls would have the same function_name but different
+      uuids).
+    self.params: A param name to key value for op constant data. I.e. for
+      axis on a reduction, strides on a convolution, etc.
+  """
+
+  def __init__(self):
+    self.inputs = {}
+    self.outputs = {}
+    self.function_name = None
+    self.uuid = None
+    self.params = {}
+
+  def __str__(self):
+    return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % (
+        self.function_name, self.uuid, self.inputs, self.outputs)
+
+
+def _find_all_hints_in_graph_def(session):
+  """Look at the current default graph and return a list of LiteFuncCall objs.
+
+  Args:
+    session: A TensorFlow session that contains the graph to convert.
+  Returns:
+    A dict mapping each function invocation's uuid to a `_LiteFuncCall`
+    object describing that invocation.
+  """
+  func_calls = _collections.defaultdict(_LiteFuncCall)
+  seen_ops = set()
+
+  for op in session.graph.get_operations():
+    for operand in _itertools.chain(op.inputs, op.outputs):
+      if operand in seen_ops:
+        continue
+      seen_ops.add(operand)
+      attr = operand.op.node_def.attr
+      if OpHint.FUNCTION_UUID_ATTR not in attr:
+        continue
+      uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
+      call_def = func_calls[uuid]
+      call_def.uuid = uuid
+      if OpHint.FUNCTION_UUID_ATTR in attr:
+        call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+        if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
+          call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand
+        if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
+          call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand
+
+      for a in attr:
+        if a.startswith("_tflite_attr_"):
+          # TODO(aselle): Remember the attribute tensors so we can put them
+          # in collapse.
+          call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor
+
+  return func_calls
+
+
+def _tensor_name_base(full_tensor_name):
+  """Removes the output-slot suffix from a tensor's name.
+
+  e.g. a tensor named "foo:3" yields "foo"
+
+  Args:
+    full_tensor_name: A tf.Tensor whose `.name` carries an output-slot
+      suffix (this is what tensorflow introspection gives).
+  Returns:
+    The node name without any output-slot suffix.
+  """
+  return full_tensor_name.name.split(":")[0]
+
+
+def convert_op_hints_to_stubs(session):
+  """Converts a graphdef with LiteOp hints into stub operations.
+
+  This is used to prepare for toco conversion of complex intrinsic usages.
+
+  Args:
+    session: A TensorFlow session that contains the graph to convert.
+  Returns:
+    A new graphdef with all ops contained in OpHints being replaced by
+    a single op call with the right parameters.
+  """
+  hints = _find_all_hints_in_graph_def(session)
+  current_graph_def = session.graph_def
+  for call in hints.values():
+    input_names = [None] * len(call.inputs)
+    output_names = [None] * len(call.outputs)
+    output_dtypes = [None] * len(call.outputs)
+    output_quantized = False
+    for input_index, tensor in call.inputs.items():
+      input_names[input_index] = _tensor_name_base(tensor)
+    for output_index, tensor in call.outputs.items():
+      output_names[output_index] = _tensor_name_base(tensor)
+      output_dtypes[output_index] = tensor.dtype.as_datatype_enum
+    # TODO(aselle): Support quantized flag properly
+    current_graph_def = _framework.fuse_op(
+        current_graph_def, input_names, output_names, output_dtypes,
+        output_quantized, call.uuid, call.function_name)
+    for node in current_graph_def.node:
+      if node.name == call.uuid:
+        for param, tensor in call.params.items():
+          node.attr[param].tensor.CopyFrom(tensor)
+  return current_graph_def
+
+
+_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"]
+remove_undocumented(__name__, _allowed_symbols)