Enable TOCO pip command line binding.
authorNupur Garg <nupurgarg@google.com>
Thu, 31 May 2018 00:54:02 +0000 (17:54 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 00:56:47 +0000 (17:56 -0700)
PiperOrigin-RevId: 198649827

12 files changed:
tensorflow/contrib/lite/python/BUILD
tensorflow/contrib/lite/python/convert_saved_model.py
tensorflow/contrib/lite/python/convert_saved_model_test.py
tensorflow/contrib/lite/python/lite.py
tensorflow/contrib/lite/python/lite_test.py
tensorflow/contrib/lite/python/tflite_convert.py [new file with mode: 0644]
tensorflow/contrib/lite/toco/g3doc/python_api.md
tensorflow/contrib/lite/toco/python/BUILD
tensorflow/contrib/lite/toco/python/toco_wrapper.py [deleted file]
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/setup.py

index a40e512..7e6ff6c 100644 (file)
@@ -36,6 +36,16 @@ py_test(
     ],
 )
 
+py_binary(
+    name = "tflite_convert",
+    srcs = ["tflite_convert.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":lite",
+    ],
+)
+
 py_library(
     name = "lite",
     srcs = ["lite.py"],
@@ -125,6 +135,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":convert",
         "//tensorflow/contrib/saved_model:saved_model_py",
         "//tensorflow/python:graph_util",
         "//tensorflow/python:platform",
@@ -164,11 +175,3 @@ py_test(
         "//tensorflow/python/saved_model",
     ],
 )
-
-# Transitive dependencies of this target will be included in the pip package.
-py_library(
-    name = "tf_lite_py_pip",
-    deps = [
-        ":convert_saved_model",
-    ],
-)
index 54fec9d..b952a72 100644 (file)
@@ -18,31 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.lite.python.convert import tensor_name
 from tensorflow.contrib.saved_model.python.saved_model import reader
 from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import graph_util as tf_graph_util
 from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import loader
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-
-
-def _write_and_flush_file(file_path, data_str):
-  """Writes data to file path.
-
-  Args:
-    file_path: Full path of the file to store data in.
-    data_str: Data represented as a string.
-
-  Returns: None.
-  """
-  with gfile.Open(file_path, "wb") as data_file:
-    data_file.write(data_str)
-    data_file.flush()
 
 
 def _log_tensor_details(tensor_info):
@@ -167,29 +151,10 @@ def _get_tensors(graph, signature_def_tensor_names=None,
   """
   tensors = []
   if user_tensor_names:
-    # Get the list of all of the tensors with and without the tensor index.
-    all_tensor_names = [
-        tensor.name for op in graph.get_operations() for tensor in op.outputs
-    ]
-    all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
-
     # Sort the tensor names.
     user_tensor_names = sorted(user_tensor_names)
 
-    # Get the tensors associated with the tensor names.
-    tensors = []
-    invalid_tensors = []
-    for name in user_tensor_names:
-      if name not in all_tensor_names_only:
-        invalid_tensors.append(name)
-      else:
-        idx = all_tensor_names_only.index(name)
-        tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
-
-    # Throw ValueError if any user input names are not valid tensors.
-    if invalid_tensors:
-      raise ValueError("Invalid tensors '{}' were found.".format(
-          ",".join(invalid_tensors)))
+    tensors = get_tensors_from_tensor_names(graph, user_tensor_names)
   elif signature_def_tensor_names:
     tensors = [
         graph.get_tensor_by_name(name)
@@ -204,6 +169,58 @@ def _get_tensors(graph, signature_def_tensor_names=None,
   return tensors
 
 
+def get_tensors_from_tensor_names(graph, tensor_names):
+  """Gets the Tensors associated with the `tensor_names` in the provided graph.
+
+  Args:
+    graph: TensorFlow Graph.
+    tensor_names: List of strings that represent names of tensors in the graph.
+
+  Returns:
+    A list of Tensor objects in the same order the names are provided.
+
+  Raises:
+    ValueError:
+      tensor_names contains an invalid tensor name.
+  """
+  # Get the list of all of the tensors.
+  tensor_name_to_tensor = {
+      tensor_name(tensor): tensor for op in graph.get_operations()
+      for tensor in op.values()
+  }
+
+  # Get the tensors associated with tensor_names.
+  tensors = []
+  invalid_tensors = []
+  for name in tensor_names:
+    tensor = tensor_name_to_tensor.get(name)
+    if tensor is None:
+      invalid_tensors.append(name)
+    else:
+      tensors.append(tensor)
+
+  # Throw ValueError if any user input names are not valid tensors.
+  if invalid_tensors:
+    raise ValueError("Invalid tensors '{}' were found.".format(
+        ",".join(invalid_tensors)))
+  return tensors
+
+
+def set_tensor_shapes(tensors, shapes):
+  """Sets Tensor shape for each tensor if the shape is defined.
+
+  Args:
+    tensors: TensorFlow ops.Tensor.
+    shapes: Dict of strings representing input tensor names to list of
+      integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+  """
+  if shapes:
+    for tensor in tensors:
+      shape = shapes.get(tensor.name)
+      if shape is not None:
+        tensor.set_shape(shapes[tensor.name])
+
+
 def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                        output_arrays, tag_set, signature_key):
   """Converts a SavedModel to a frozen graph.
@@ -211,15 +228,14 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
   Args:
     saved_model_dir: SavedModel directory to convert.
     input_arrays: List of input tensors to freeze graph with. Uses input arrays
-      from SignatureDef when none are provided. (default None)
-    input_shapes: Map of strings representing input tensor names to list of
+      from SignatureDef when none are provided.
+    input_shapes: Dict of strings representing input tensor names to list of
       integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
       Automatically determined when input shapes is None (e.g., {"foo" : None}).
-      (default None)
     output_arrays: List of output tensors to freeze graph with. Uses output
-      arrays from SignatureDef when none are provided. (default None)
+      arrays from SignatureDef when none are provided.
     tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
-      analyze. All tags in the tag set must be present. (default "serve")
+      analyze. All tags in the tag set must be present.
     signature_key: Key identifying SignatureDef containing inputs and outputs.
 
   Returns:
@@ -233,14 +249,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
       signature_key is not in the MetaGraphDef.
       input_shapes does not match the length of input_arrays.
       input_arrays or output_arrays are not valid.
-      Unable to load Session.
   """
-  # Set default values for inputs if they are set to None.
-  if signature_key is None:
-    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-  if tag_set is None:
-    tag_set = set([tag_constants.SERVING])
-
   # Read SignatureDef.
   meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
   signature_def = _get_signature_def(meta_graph, signature_key)
@@ -255,19 +264,10 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
     # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
     in_tensors = _get_tensors(graph, inputs, input_arrays)
     out_tensors = _get_tensors(graph, outputs, output_arrays)
-
-    # Gets fully defined tensor shape.
-    for tensor in in_tensors:
-      if (input_shapes and tensor.name in input_shapes and
-          input_shapes[tensor.name] is not None):
-        shape = input_shapes[tensor.name]
-      else:
-        shape = tensor.get_shape().as_list()
-      tensor.set_shape(shape)
+    set_tensor_shapes(in_tensors, input_shapes)
 
     output_names = [node.split(":")[0] for node in outputs]
     frozen_graph_def = tf_graph_util.convert_variables_to_constants(
         sess, graph.as_graph_def(), output_names)
 
     return frozen_graph_def, in_tensors, out_tensors
-  raise ValueError("Unable to load Session.")
index f69381d..80e5dc6 100644 (file)
@@ -41,9 +41,58 @@ from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import saved_model
 from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import training as train
 
 
+class TensorFunctionsTest(test_util.TensorFlowTestCase):
+
+  def testGetTensorsValid(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + in_tensor
+    sess = session.Session()
+
+    tensors = convert_saved_model.get_tensors_from_tensor_names(
+        sess.graph, ["Placeholder"])
+    self.assertEqual("Placeholder:0", tensors[0].name)
+
+  def testGetTensorsInvalid(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + in_tensor
+    sess = session.Session()
+
+    with self.assertRaises(ValueError) as error:
+      convert_saved_model.get_tensors_from_tensor_names(sess.graph,
+                                                        ["invalid-input"])
+    self.assertEqual("Invalid tensors 'invalid-input' were found.",
+                     str(error.exception))
+
+  def testSetTensorShapeValid(self):
+    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+    convert_saved_model.set_tensor_shapes([tensor],
+                                          {"Placeholder:0": [5, 3, 5]})
+    self.assertEqual([5, 3, 5], tensor.shape.as_list())
+
+  def testSetTensorShapeInvalid(self):
+    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+    convert_saved_model.set_tensor_shapes([tensor],
+                                          {"invalid-input": [5, 3, 5]})
+    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+  def testSetTensorShapeEmpty(self):
+    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+    convert_saved_model.set_tensor_shapes([tensor], {})
+    self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+
 class FreezeSavedModelTest(test_util.TensorFlowTestCase):
 
   def _createSimpleSavedModel(self, shape):
@@ -93,6 +142,10 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase):
                          output_arrays=None,
                          tag_set=None,
                          signature_key=None):
+    if tag_set is None:
+      tag_set = set([tag_constants.SERVING])
+    if signature_key is None:
+      signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
         saved_model_dir=saved_model_dir,
         input_arrays=input_arrays,
@@ -390,7 +443,7 @@ class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
         input_arrays=None,
         input_shapes=None,
         output_arrays=["Softmax"],
-        tag_set=None,
+        tag_set=set([tag_constants.SERVING]),
         signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
 
     self.assertTrue(result)
index f7f2d40..6510d74 100644 (file)
@@ -33,15 +33,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from google.protobuf import text_format as _text_format
+from google.protobuf.message import DecodeError
 from tensorflow.contrib.lite.python import lite_constants as constants
 from tensorflow.contrib.lite.python.convert import tensor_name
 from tensorflow.contrib.lite.python.convert import toco_convert
 from tensorflow.contrib.lite.python.convert import toco_convert_protos  # pylint: disable=unused-import
 from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
+from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names
+from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes
 from tensorflow.contrib.lite.python.interpreter import Interpreter  # pylint: disable=unused-import
 from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs  # pylint: disable=unused-import
 from tensorflow.contrib.lite.python.op_hint import OpHint  # pylint: disable=unused-import
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python.client import session as _session
 from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.framework.importer import import_graph_def
 from tensorflow.python.ops.variables import global_variables_initializer
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import tag_constants
@@ -55,13 +62,15 @@ class TocoConverter(object):
 
   Attributes:
 
-    inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
-      (default FLOAT)
-    output_format: Type of data to write (currently must be TFLITE or
-      GRAPHVIZ_DOT). (default TFLITE)
+    inference_type: Target data type of arrays in the output file. Currently
+      must be `{FLOAT, QUANTIZED_UINT8}`.  (default FLOAT)
+    output_format: Output file format. Currently must be `{TFLITE,
+      GRAPHVIZ_DOT}`. (default TFLITE)
     quantized_input_stats: The mean and std deviation of training data for each
       input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
-      (default None)
+      Dict of strings representing input tensor names to a tuple of integers
+      representing the quantization stats (e.g., {"foo" : (0., 1.)}).
+      (default {})
     drop_control_dependency: Boolean indicating whether to drop control
       dependencies silently. This is due to TFLite not supporting control
       dependencies. (default True)
@@ -70,11 +79,17 @@ class TocoConverter(object):
 
   Example usage:
 
-    # Converting a frozen graph.
+    # Converting a GraphDef from session.
     converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
     tflite_model = converter.convert()
     open("converted_model.tflite", "wb").write(tflite_model)
 
+    # Converting a GraphDef from file.
+    converter = lite.TocoConverter.from_flatbuffer_file(
+      graph_def_file, input_arrays, output_arrays)
+    tflite_model = converter.convert()
+    open("converted_model.tflite", "wb").write(tflite_model)
+
     # Converting a SavedModel.
     converter = lite.TocoConverter.from_saved_model(saved_model_dir)
     tflite_model = converter.convert()
@@ -95,16 +110,12 @@ class TocoConverter(object):
     self._output_tensors = output_tensors
     self.inference_type = constants.FLOAT
     self.output_format = constants.TFLITE
-    self.quantized_input_stats = None
+    self.quantized_input_stats = {}
     self.drop_control_dependency = True
     self.allow_custom_ops = False
 
   @classmethod
-  def from_session(cls,
-                   sess,
-                   input_tensors,
-                   output_tensors,
-                   freeze_variables=False):
+  def from_session(cls, sess, input_tensors, output_tensors):
     """Creates a TocoConverter class from a TensorFlow Session.
 
     Args:
@@ -112,56 +123,102 @@ class TocoConverter(object):
       input_tensors: List of input tensors. Type and shape are computed using
         `foo.get_shape()` and `foo.dtype`.
       output_tensors: List of output tensors (only .name is used from this).
-      freeze_variables: Boolean indicating whether the variables need to be
-        converted into constants via the freeze_graph.py script.
-        (default False)
 
     Returns:
       TocoConverter class.
     """
+    graph_def = _freeze_graph(sess, output_tensors)
+    return cls(graph_def, input_tensors, output_tensors)
+
+  @classmethod
+  def from_flatbuffer_file(cls,
+                           graph_def_file,
+                           input_arrays,
+                           output_arrays,
+                           input_shapes=None):
+    """Creates a TocoConverter class from a file containing a GraphDef.
+
+    Args:
+      graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+      input_arrays: List of input tensors to freeze graph with.
+      output_arrays: List of output tensors to freeze graph with.
+      input_shapes: Dict of strings representing input tensor names to list of
+        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+        Automatically determined when input shapes is None (e.g., {"foo" :
+        None}). (default None)
 
-    # Get GraphDef.
-    if freeze_variables:
+    Returns:
+      TocoConverter class.
+
+    Raises:
+      ValueError:
+        Unable to parse input file.
+        The graph is not frozen.
+        input_arrays or output_arrays contains an invalid tensor name.
+    """
+    with _session.Session() as sess:
       sess.run(global_variables_initializer())
-      output_arrays = [tensor_name(tensor) for tensor in output_tensors]
-      graph_def = tf_graph_util.convert_variables_to_constants(
-          sess, sess.graph_def, output_arrays)
-    else:
-      graph_def = sess.graph_def
 
-    # Create TocoConverter class.
-    return cls(graph_def, input_tensors, output_tensors)
+      # Read GraphDef from file.
+      graph_def = _graph_pb2.GraphDef()
+      with open(graph_def_file, "rb") as f:
+        file_content = f.read()
+      try:
+        graph_def.ParseFromString(file_content)
+      except (_text_format.ParseError, DecodeError):
+        try:
+          print("Ignore 'tcmalloc: large alloc' warnings.")
+          _text_format.Merge(file_content, graph_def)
+        except (_text_format.ParseError, DecodeError):
+          raise ValueError(
+              "Unable to parse input file '{}'.".format(graph_def_file))
+      sess.graph.as_default()
+      import_graph_def(graph_def, name="")
+
+      # Get input and output tensors.
+      input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+      output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+      set_tensor_shapes(input_tensors, input_shapes)
+
+      # Check if graph is frozen.
+      if not _is_frozen_graph(sess):
+        raise ValueError("Please freeze the graph using freeze_graph.py")
+
+      # Create TocoConverter class.
+      return cls(sess.graph_def, input_tensors, output_tensors)
 
   @classmethod
-  def from_saved_model(
-      cls,
-      saved_model_dir,
-      input_arrays=None,
-      input_shapes=None,
-      output_arrays=None,
-      tag_set=None,
-      signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
+  def from_saved_model(cls,
+                       saved_model_dir,
+                       input_arrays=None,
+                       input_shapes=None,
+                       output_arrays=None,
+                       tag_set=None,
+                       signature_key=None):
     """Creates a TocoConverter class from a SavedModel.
 
     Args:
       saved_model_dir: SavedModel directory to convert.
       input_arrays: List of input tensors to freeze graph with. Uses input
         arrays from SignatureDef when none are provided. (default None)
-      input_shapes: Map of strings representing input tensor names to list of
-        integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+      input_shapes: Dict of strings representing input tensor names to list of
+        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
         Automatically determined when input shapes is None (e.g., {"foo" :
         None}). (default None)
       output_arrays: List of output tensors to freeze graph with. Uses output
         arrays from SignatureDef when none are provided. (default None)
       tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
-        analyze. All tags in the tag set must be present. (default "serve")
+        analyze. All tags in the tag set must be present. (default set("serve"))
       signature_key: Key identifying SignatureDef containing inputs and outputs.
+        (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
 
     Returns:
       TocoConverter class.
     """
     if tag_set is None:
       tag_set = set([tag_constants.SERVING])
+    if signature_key is None:
+      signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
     result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                                 output_arrays, tag_set, signature_key)
@@ -189,6 +246,24 @@ class TocoConverter(object):
       elif shape[0] is None:
         self._set_batch_size(batch_size=1)
 
+    # Get quantization stats. Ensures there is one stat per name if the stats
+    # are specified.
+    if self.quantized_input_stats:
+      quantized_stats = []
+      invalid_stats = []
+      for tensor in self._input_tensors:
+        name = tensor_name(tensor)
+        if name in self.quantized_input_stats:
+          quantized_stats.append(self.quantized_input_stats[name])
+        else:
+          invalid_stats.append(name)
+
+      if invalid_stats:
+        raise ValueError("Quantization input stats are not available for input "
+                         "tensors '{0}'.".format(",".join(invalid_stats)))
+    else:
+      quantized_stats = None
+
     # Converts model.
     result = toco_convert(
         input_data=self._graph_def,
@@ -197,7 +272,7 @@ class TocoConverter(object):
         inference_type=self.inference_type,
         input_format=constants.TENSORFLOW_GRAPHDEF,
         output_format=self.output_format,
-        quantized_input_stats=self.quantized_input_stats,
+        quantized_input_stats=quantized_stats,
         drop_control_dependency=self.drop_control_dependency)
     return result
 
@@ -212,3 +287,43 @@ class TocoConverter(object):
       shape = tensor.get_shape().as_list()
       shape[0] = batch_size
       tensor.set_shape(shape)
+
+
+def _is_frozen_graph(sess):
+  """Determines if the graph is frozen.
+
+  Determines if a graph has previously been frozen by checking for any
+  operations of type Variable*. If variables are found, the graph is not frozen.
+
+  Args:
+    sess: TensorFlow Session.
+
+  Returns:
+    Bool.
+  """
+  for op in sess.graph.get_operations():
+    if op.type.startswith("Variable"):
+      return False
+  return True
+
+
+def _freeze_graph(sess, output_tensors):
+  """Returns a frozen GraphDef.
+
+  Freezes a graph with Variables in it. Otherwise the existing GraphDef is
+  returned.
+
+  Args:
+    sess: TensorFlow Session.
+    output_tensors: List of output tensors (only .name is used from this).
+
+  Returns:
+    Frozen GraphDef.
+  """
+  if not _is_frozen_graph(sess):
+    sess.run(global_variables_initializer())
+    output_arrays = [tensor_name(tensor) for tensor in output_tensors]
+    return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def,
+                                                        output_arrays)
+  else:
+    return sess.graph_def
index 2f3105f..28386ec 100644 (file)
@@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
 
 
 class FromSessionTest(test_util.TensorFlowTestCase):
@@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
   def testQuantization(self):
-    in_tensor = array_ops.placeholder(
-        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+    in_tensor_1 = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+    in_tensor_2 = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
     out_tensor = array_ops.fake_quant_with_min_max_args(
-        in_tensor + in_tensor, min=0., max=1., name='output')
+        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
     sess = session.Session()
 
     # Convert model and ensure model is not None.
-    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+    converter = lite.TocoConverter.from_session(
+        sess, [in_tensor_1, in_tensor_2], [out_tensor])
     converter.inference_type = lite_constants.QUANTIZED_UINT8
-    converter.quantized_input_stats = [(0., 1.)]  # mean, std_dev
+    converter.quantized_input_stats = {
+        'inputA': (0., 1.),
+        'inputB': (0., 1.)
+    }  # mean, std_dev
     tflite_model = converter.convert()
     self.assertTrue(tflite_model)
 
@@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     interpreter.allocate_tensors()
 
     input_details = interpreter.get_input_details()
-    self.assertEqual(1, len(input_details))
-    self.assertEqual('input', input_details[0]['name'])
+    self.assertEqual(2, len(input_details))
+    self.assertEqual('inputA', input_details[0]['name'])
     self.assertEqual(np.uint8, input_details[0]['dtype'])
     self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
     self.assertEqual((1., 0.),
                      input_details[0]['quantization'])  # scale, zero_point
 
+    self.assertEqual('inputB', input_details[1]['name'])
+    self.assertEqual(np.uint8, input_details[1]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+    self.assertEqual((1., 0.),
+                     input_details[1]['quantization'])  # scale, zero_point
+
     output_details = interpreter.get_output_details()
     self.assertEqual(1, len(output_details))
     self.assertEqual('output', output_details[0]['name'])
@@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
 
+  def testQuantizationInvalid(self):
+    in_tensor_1 = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+    in_tensor_2 = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+    out_tensor = array_ops.fake_quant_with_min_max_args(
+        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+    sess = session.Session()
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_session(
+        sess, [in_tensor_1, in_tensor_2], [out_tensor])
+    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
+    with self.assertRaises(ValueError) as error:
+      converter.convert()
+    self.assertEqual(
+        'Quantization input stats are not available for input tensors '
+        '\'inputB\'.', str(error.exception))
+
   def testBatchSizeInvalid(self):
     in_tensor = array_ops.placeholder(
         shape=[None, 16, 16, 3], dtype=dtypes.float32)
@@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     sess = session.Session()
 
     # Convert model and ensure model is not None.
-    converter = lite.TocoConverter.from_session(
-        sess, [in_tensor], [out_tensor], freeze_variables=True)
+    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
     tflite_model = converter.convert()
     self.assertTrue(tflite_model)
 
@@ -188,6 +221,135 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     self.assertTrue(graphviz_output)
 
 
+class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+  def testFloat(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_flatbuffer_file(
+        graph_def_file, ['Placeholder'], ['add'])
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('add', output_details[0]['name'])
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+  def testFloatWithShapesArray(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_flatbuffer_file(
+        graph_def_file, ['Placeholder'], ['add'],
+        input_shapes={'Placeholder': [1, 16, 16, 3]})
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+
+  def testFreezeGraph(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    var = variable_scope.get_variable(
+        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + var
+    sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+    write_graph(sess.graph_def, '', graph_def_file, False)
+
+    # Ensure the graph with variables cannot be converted.
+    with self.assertRaises(ValueError) as error:
+      lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+                                              ['add'])
+    self.assertEqual('Please freeze the graph using freeze_graph.py',
+                     str(error.exception))
+
+  def testPbtxt(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    _ = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Write graph to file.
+    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
+    write_graph(sess.graph_def, '', graph_def_file, True)
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_flatbuffer_file(
+        graph_def_file, ['Placeholder'], ['add'])
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.float32, input_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('add', output_details[0]['name'])
+    self.assertEqual(np.float32, output_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+  def testInvalidFile(self):
+    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
+    with gfile.Open(graph_def_file, 'wb') as temp_file:
+      temp_file.write('bad data')
+      temp_file.flush()
+
+    # Attempts to convert the invalid model.
+    with self.assertRaises(ValueError) as error:
+      lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+                                              ['add'])
+    self.assertEqual(
+        'Unable to parse input file \'{}\'.'.format(graph_def_file),
+        str(error.exception))
+
+
 class FromSavedModelTest(test_util.TensorFlowTestCase):
 
   def _createSavedModel(self, shape):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
new file mode 100644 (file)
index 0000000..79be5cd
--- /dev/null
@@ -0,0 +1,273 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python command line interface for running TOCO."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.platform import app
+
+
+def _parse_array(values):
+  if values:
+    return values.split(",")
+
+
+def _parse_int_array(values):
+  if values:
+    return [int(val) for val in values.split(",")]
+
+
+def _parse_set(values):
+  if values:
+    return set(values.split(","))
+
+
+def _get_toco_converter(flags):
+  """Makes a TocoConverter object based on the flags provided.
+
+  Args:
+    flags: argparse.Namespace object containing TFLite flags.
+
+  Returns:
+    TocoConverter object.
+  """
+  # Parse input and output arrays.
+  input_arrays = _parse_array(flags.input_arrays)
+  input_shapes = None
+  if flags.input_shapes:
+    input_shapes_list = [
+        _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+    ]
+    input_shapes = dict(zip(input_arrays, input_shapes_list))
+  output_arrays = _parse_array(flags.output_arrays)
+
+  converter_kwargs = {
+      "input_arrays": input_arrays,
+      "input_shapes": input_shapes,
+      "output_arrays": output_arrays
+  }
+
+  # Create TocoConverter.
+  if flags.graph_def_file:
+    converter_fn = lite.TocoConverter.from_flatbuffer_file
+    converter_kwargs["graph_def_file"] = flags.graph_def_file
+  elif flags.saved_model_dir:
+    converter_fn = lite.TocoConverter.from_saved_model
+    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
+    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
+    converter_kwargs["signature_key"] = flags.saved_model_signature_key
+
+  return converter_fn(**converter_kwargs)
+
+
+def _convert_model(flags):
+  """Calls function to convert the TensorFlow model into a TFLite model.
+
+  Args:
+    flags: argparse.Namespace object.
+  """
+  # Create converter.
+  converter = _get_toco_converter(flags)
+  if flags.inference_type:
+    converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type)
+  if flags.output_format:
+    converter.output_format = _toco_flags_pb2.FileFormat.Value(
+        flags.output_format)
+
+  if flags.mean_values and flags.std_dev_values:
+    input_arrays = _parse_array(flags.input_arrays)
+    std_dev_values = _parse_int_array(flags.std_dev_values)
+    mean_values = _parse_int_array(flags.mean_values)
+    quant_stats = zip(mean_values, std_dev_values)
+    converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
+
+  if flags.drop_control_dependency:
+    converter.drop_control_dependency = flags.drop_control_dependency
+  if flags.allow_custom_ops:
+    converter.allow_custom_ops = flags.allow_custom_ops
+
+  # Convert model.
+  output_data = converter.convert()
+  with open(flags.output_file, "wb") as f:
+    f.write(output_data)
+
+
+def _check_flags(flags, unparsed):
+  """Checks the parsed and unparsed flags to ensure they are valid.
+
+  Displays warnings for unparsed flags. Raises an error for parsed flags that
+  don't meet the required conditions.
+
+  Args:
+    flags: argparse.Namespace object containing TFLite flags.
+    unparsed: List of unparsed flags.
+
+  Raises:
+    ValueError: Invalid flags.
+  """
+  # Check unparsed flags for common mistakes based on previous TOCO.
+  if unparsed:
+    print("tflite_convert: warning: Unable to parse following flags "
+          "'{}'".format(",".join(unparsed)))
+    for flag in unparsed:
+      if "--input_file=" in flag:
+        print("tflite_convert: warning: Use --graph_def_file instead of "
+              "--input_file")
+      if "--std_values=" in flag:
+        print("tflite_convert: warning: Use --std_dev_values instead of "
+              "--std_values")
+
+  # Check that flags are valid.
+  if flags.graph_def_file and (not flags.input_arrays or
+                               not flags.output_arrays):
+    raise ValueError("--input_arrays and --output_arrays are required with "
+                     "--graph_def_file")
+
+  if flags.input_shapes:
+    if not flags.input_arrays:
+      raise ValueError("--input_shapes must be used with --input_arrays")
+    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
+      raise ValueError("--input_shapes and --input_arrays must have the same "
+                       "number of items")
+
+  if flags.std_dev_values or flags.mean_values:
+    if bool(flags.std_dev_values) != bool(flags.mean_values):
+      raise ValueError("--std_dev_values and --mean_values must be used "
+                       "together")
+    if not flags.input_arrays:
+      raise ValueError("--std_dev_values and --mean_values must be used with "
+                       "--input_arrays")
+    if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or
+        flags.std_dev_values.count(",") != flags.input_arrays.count(",")):
+      raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
+                       "must have the same number of items")
+
+
+def run_main(_):
+  """Main in toco_convert.py."""
+  parser = argparse.ArgumentParser(
+      description=("Command line tool to run TensorFlow Lite Optimizing "
+                   "Converter (TOCO)."))
+
+  # Output file flag.
+  parser.add_argument(
+      "--output_file",
+      type=str,
+      help="Full filepath of the output file.",
+      required=True)
+
+  # Input file flags.
+  input_file_group = parser.add_mutually_exclusive_group(required=True)
+  input_file_group.add_argument(
+      "--graph_def_file",
+      type=str,
+      help="Full filepath of file containing TensorFlow GraphDef.")
+  input_file_group.add_argument(
+      "--saved_model_dir",
+      type=str,
+      help="Full filepath of directory containing the SavedModel.")
+
+  # Model format flags.
+  parser.add_argument(
+      "--output_format",
+      type=str,
+      choices=["TFLITE", "GRAPHVIZ_DOT"],
+      help="Output file format.")
+  parser.add_argument(
+      "--inference_type",
+      type=str,
+      choices=["FLOAT", "QUANTIZED_UINT8"],
+      help="Target data type of arrays in the output file.")
+
+  # Input and output arrays flags.
+  parser.add_argument(
+      "--input_arrays",
+      type=str,
+      help="Names of the output arrays, comma-separated.")
+  parser.add_argument(
+      "--input_shapes",
+      type=str,
+      help="Shapes corresponding to --input_arrays, colon-separated.")
+  parser.add_argument(
+      "--output_arrays",
+      type=str,
+      help="Names of the output arrays, comma-separated.")
+
+  # SavedModel related flags.
+  parser.add_argument(
+      "--saved_model_tag_set",
+      type=str,
+      help=("Set of tags identifying the MetaGraphDef within the SavedModel "
+            "to analyze. All tags must be present. (default \"serve\")"))
+  parser.add_argument(
+      "--saved_model_signature_key",
+      type=str,
+      help=("Key identifying SignatureDef containing inputs and outputs. "
+            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
+
+  # Quantization flags.
+  parser.add_argument(
+      "--std_dev_values",
+      type=str,
+      help=("Standard deviation of training data for each input tensor, "
+            "comma-separated. Used for quantization. (default None)"))
+  parser.add_argument(
+      "--mean_values",
+      type=str,
+      help=("Mean of training data for each input tensor, comma-separated. "
+            "Used for quantization. (default None)"))
+
+  # Graph manipulation flags.
+  parser.add_argument(
+      "--drop_control_dependency",
+      type=bool,
+      help=("Boolean indicating whether to drop control dependencies silently. "
+            "This is due to TensorFlow Lite not supporting control "
+            "dependencies. (default True)"))
+  parser.add_argument(
+      "--allow_custom_ops",
+      type=bool,
+      help=("Boolean indicating whether to allow custom operations. When false "
+            "any unknown operation is an error. When true, custom ops are "
+            "created for any op that is unknown. The developer will need to "
+            "provide these to the TensorFlow Lite runtime with a custom "
+            "resolver. (default False)"))
+
+  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
+  try:
+    _check_flags(tflite_flags, unparsed)
+  except ValueError as e:
+    parser.print_usage()
+    file_name = os.path.basename(sys.argv[0])
+    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
+    sys.exit(1)
+  _convert_model(tflite_flags)
+
+
+def main():
+  app.run(main=run_main, argv=sys.argv[:1])
+
+
+if __name__ == "__main__":
+  main()
index 29a83bd..e5f6a0b 100644 (file)
@@ -12,8 +12,8 @@ Table of contents:
 *   [High-level overview](#high-level-overview)
 *   [API](#api)
 *   [Basic examples](#basic)
-    *   [Exporting a GraphDef with constants](#basic-graphdef-const)
-    *   [Exporting a GraphDef with variables](#basic-graphdef-var)
+    *   [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
+    *   [Exporting a GraphDef from file](#basic-graphdef-file)
     *   [Exporting a SavedModel](#basic-savedmodel)
 *   [Complex examples](#complex)
     *   [Exporting a quantized GraphDef](#complex-quant)
@@ -50,17 +50,17 @@ possible.
 The following section shows examples of how to convert a basic float-point model
 from each of the supported data formats into a TensorFlow Lite FlatBuffers.
 
-### Exporting a GraphDef with constants <a name="basic-graphdef-const"></a>
+### Exporting a GraphDef from tf.Session <a name="basic-graphdef-sess"></a>
 
-The following example shows how to convert a TensorFlow GraphDef with constants
-into a TensorFlow Lite FlatBuffer.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer from a `tf.Session` object.
 
 ```python
 import tensorflow as tf
 
 img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
-val = img + const
+var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + var
 out = tf.identity(val, name="out")
 
 with tf.Session() as sess:
@@ -69,25 +69,28 @@ with tf.Session() as sess:
   open("converted_model.tflite", "wb").write(tflite_model)
 ```
 
-### Exporting a GraphDef with variables <a name="basic-graphdef-var"></a>
+### Exporting a GraphDef from file <a name="basic-graphdef-file"></a>
 
-If a model has variables, they need to be turned into constants through a
-process known as freezing. It can be accomplished by setting `freeze_variables`
-to `True` as shown in the example below.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and
+`.pbtxt` files are accepted.
+
+The example uses
+[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
+The function only supports GraphDefs frozen via
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
 
 ```python
 import tensorflow as tf
 
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + var
-out = tf.identity(val, name="out")
+graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
+input_arrays = ["input"]
+output_arrays = ["MobilenetV1/Predictions/Softmax"]
 
-with tf.Session() as sess:
-  converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out],
-                                                        freeze_variables=True)
-  tflite_model = converter.convert()
-  open("converted_model.tflite", "wb").write(tflite_model)
+converter = tf.contrib.lite.TocoConverter.from_flatbuffer_file(
+  graph_def_file, input_arrays, output_arrays)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
 ```
 
 ### Exporting a SavedModel <a name="basic-savedmodel"></a>
@@ -111,8 +114,8 @@ available by running `help(tf.contrib.lite.TocoConverter)`.
 ## Complex examples <a name="complex"></a>
 
 For models where the default value of the attributes is not sufficient, the
-variables values should be set before calling `convert()`. In order to call any
-constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
+attribute's values should be set before calling `convert()`. In order to call
+any constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
 `QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python
 terminal for detailed documentation on the attributes.
 
@@ -135,7 +138,7 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
 with tf.Session() as sess:
   converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
   converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
-  converter.quantized_input_stats = [(0., 1.)]  # mean, std_dev
+  converter.quantized_input_stats = {"img" : (0., 1.)}  # mean, std_dev
   tflite_model = converter.convert()
   open("converted_model.tflite", "wb").write(tflite_model)
 ```
index 8cac568..a954f1d 100644 (file)
@@ -41,12 +41,6 @@ py_binary(
     ],
 )
 
-py_binary(
-    name = "toco_wrapper",
-    srcs = ["toco_wrapper.py"],
-    srcs_version = "PY2AND3",
-)
-
 tf_py_test(
     name = "toco_from_protos_test",
     srcs = ["toco_from_protos_test.py"],
diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
deleted file mode 100644 (file)
index 6d6b500..0000000
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Wrapper for runninmg toco binary embedded in pip site-package.
-
-NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/.
-It can only install Python "console-scripts." This will work as a console
-script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS).
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-
-def main():
-  # Pip installs the binary in aux-bin off of main site-package install.
-  # Just find it and exec, passing all arguments in the process.
-  # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary.
-  print("""TOCO from pip install is currently not working on command line.
-Please use the python TOCO API or use
-bazel run tensorflow/contrib/lite:toco -- <args> from a TensorFlow source dir.
-""")
-  sys.exit(1)
-  # TODO(aselle): Replace this when we find a way to run toco without
-  # blowing up executable size.
-  # binary = os.path.join(tf.__path__[0], 'aux-bin/toco')
-  # os.execvp(binary, sys.argv)
index 677ea65..e113565 100644 (file)
@@ -173,9 +173,7 @@ sh_binary(
         "//conditions:default": COMMON_PIP_DEPS + [
             ":simple_console",
             "//tensorflow/contrib/lite/python:interpreter_test_data",
-            "//tensorflow/contrib/lite/python:tf_lite_py_pip",
-            "//tensorflow/contrib/lite/toco:toco",
-            "//tensorflow/contrib/lite/toco/python:toco_wrapper",
+            "//tensorflow/contrib/lite/python:tflite_convert",
             "//tensorflow/contrib/lite/toco/python:toco_from_protos",
         ],
     }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
index 1a83c6e..0c4065b 100755 (executable)
@@ -148,9 +148,7 @@ function main() {
     fi
     mkdir "${TMPDIR}/tensorflow/aux-bin"
     # Install toco as a binary in aux-bin.
-    # TODO(aselle): Re-enable this when we find a way to do it without doubling
-    # the whl size (over the limit).
-    # cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/
+    cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/
   fi
 
   # protobuf pip package doesn't ship with header files. Copy the headers
index 70e6662..d25a9e7 100644 (file)
@@ -95,7 +95,8 @@ if sys.version_info < (3, 4):
 CONSOLE_SCRIPTS = [
     'freeze_graph = tensorflow.python.tools.freeze_graph:run_main',
     'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main',
-    'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main',
+    'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main',
+    'toco = tensorflow.contrib.lite.python.tflite_convert:main',
     'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',
     # We need to keep the TensorBoard command, even though the console script
     # is now declared by the tensorboard pip package. If we remove the