Add attributes to TFLite Python API.
authorNupur Garg <nupurgarg@google.com>
Thu, 31 May 2018 20:58:32 +0000 (13:58 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 21:01:07 +0000 (14:01 -0700)
PiperOrigin-RevId: 198774775

tensorflow/contrib/lite/python/convert.py
tensorflow/contrib/lite/python/lite.py
tensorflow/contrib/lite/python/lite_test.py
tensorflow/contrib/lite/python/tflite_convert.py

index c0926d2..0819475 100644 (file)
@@ -115,11 +115,15 @@ def toco_convert(input_data,
                  input_tensors,
                  output_tensors,
                  inference_type=lite_constants.FLOAT,
+                 inference_input_type=None,
                  input_format=lite_constants.TENSORFLOW_GRAPHDEF,
                  output_format=lite_constants.TFLITE,
                  quantized_input_stats=None,
+                 default_ranges_stats=None,
                  drop_control_dependency=True,
-                 allow_custom_ops=False):
+                 reorder_across_fake_quant=False,
+                 allow_custom_ops=False,
+                 change_concat_input_ranges=False):
   """Convert a model using TOCO from `input_format` to `output_format`.
 
   Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -130,18 +134,41 @@ def toco_convert(input_data,
     input_tensors: List of input tensors. Type and shape are computed using
       `foo.get_shape()` and `foo.dtype`.
     output_tensors: List of output tensors (only .name is used from this).
-    inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
-    input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
-    output_format: Type of data to write (currently must be TFLITE or
-      GRAPHVIZ_DOT)
-    quantized_input_stats: For each member of input_tensors the mean and
-      std deviation of training data. Only needed if `inference_type` is
-      `QUANTIZED_UINT8`.
-    drop_control_dependency: Drops control dependencies silently. This is due
-      to tf lite not supporting control dependencies.
+    inference_type: Target data type of arrays in the output file. Currently
+      must be `{FLOAT, QUANTIZED_UINT8}`.  (default FLOAT)
+    inference_input_type: Target data type of input arrays. Allows for a
+      different type for input arrays in the case of quantization. Currently
+      must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+    input_format: Type of data to read Currently must be
+      `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
+    output_format: Output file format. Currently must be `{TFLITE,
+      GRAPHVIZ_DOT}`. (default TFLITE)
+    quantized_input_stats: Dict of strings representing input tensor names
+      mapped to tuple of integers representing the mean and standard deviation
+      of the training data (e.g., {"foo" : (0., 1.)}). Only need if
+      `inference_type` is `QUANTIZED_UINT8`. (default None)
+    default_ranges_stats: Tuple of integers representing (min, max) range values
+      for all arrays without a specified range. Intended for experimenting with
+      quantization via "dummy quantization". (default None)
+    drop_control_dependency: Boolean indicating whether to drop control
+      dependencies silently. This is due to TFLite not supporting control
+      dependencies. (default True)
+    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+      nodes in unexpected locations. Used when the location of the FakeQuant
+      nodes is preventing graph transformations necessary to convert the graph.
+      Results in a graph that differs from the quantized training graph,
+      potentially causing differing arithmetic behavior. (default False)
+    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+      inputs and outputs of the concat operator for quantized models. Changes
+      the ranges of concat operator overlap when true. (default False)
+    allow_custom_ops: Boolean indicating whether to allow custom operations.
+      When false any unknown operation is an error. When true, custom ops are
+      created for any op that is unknown. The developer will need to provide
+      these to the TensorFlow Lite runtime with a custom resolver.
+      (default False)
 
   Returns:
-    The converted data. For example if tflite was the destination, then
+    The converted data. For example if TFLite was the destination, then
     this will be a tflite flatbuffer in a bytes array.
 
   Raises:
@@ -152,10 +179,18 @@ def toco_convert(input_data,
   toco = _toco_flags_pb2.TocoFlags()
   toco.input_format = input_format
   toco.output_format = output_format
-  toco.drop_control_dependency = drop_control_dependency
-  model = _model_flags_pb2.ModelFlags()
   toco.inference_type = inference_type
+  if inference_input_type:
+    toco.inference_input_type = inference_input_type
+  toco.drop_control_dependency = drop_control_dependency
+  toco.reorder_across_fake_quant = reorder_across_fake_quant
   toco.allow_custom_ops = allow_custom_ops
+  if default_ranges_stats:
+    toco.default_ranges_min = default_ranges_stats[0]
+    toco.default_ranges_max = default_ranges_stats[1]
+
+  model = _model_flags_pb2.ModelFlags()
+  model.change_concat_input_ranges = change_concat_input_ranges
   for idx, input_tensor in enumerate(input_tensors):
     if input_tensor.dtype == _dtypes.float32:
       tflite_input_type = lite_constants.FLOAT
@@ -163,6 +198,8 @@ def toco_convert(input_data,
       tflite_input_type = lite_constants.INT32
     elif input_tensor.dtype == _dtypes.int64:
       tflite_input_type = lite_constants.INT64
+    elif input_tensor.dtype == _dtypes.uint8:
+      tflite_input_type = lite_constants.QUANTIZED_UINT8
     # TODO(aselle): Insert strings when they are available
     else:
       raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
index 6510d74..d55d8a6 100644 (file)
@@ -64,17 +64,33 @@ class TocoConverter(object):
 
     inference_type: Target data type of arrays in the output file. Currently
       must be `{FLOAT, QUANTIZED_UINT8}`.  (default FLOAT)
+    inference_input_type: Target data type of input arrays. Allows for a
+      different type for input arrays in the case of quantization. Currently
+      must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
     output_format: Output file format. Currently must be `{TFLITE,
       GRAPHVIZ_DOT}`. (default TFLITE)
-    quantized_input_stats: The mean and std deviation of training data for each
-      input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
-      Dict of strings representing input tensor names to a tuple of integers
-      representing the quantization stats (e.g., {"foo" : (0., 1.)}).
-      (default {})
+    quantized_input_stats: Dict of strings representing input tensor names
+      mapped to tuple of integers representing the mean and standard deviation
+      of the training data (e.g., {"foo" : (0., 1.)}). Only need if
+      `inference_type` is `QUANTIZED_UINT8`. (default {})
+    default_ranges_stats: Tuple of integers representing (min, max) range values
+      for all arrays without a specified range. Intended for experimenting with
+      quantization via "dummy quantization". (default None)
     drop_control_dependency: Boolean indicating whether to drop control
       dependencies silently. This is due to TFLite not supporting control
       dependencies. (default True)
+    reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+      nodes in unexpected locations. Used when the location of the FakeQuant
+      nodes is preventing graph transformations necessary to convert the graph.
+      Results in a graph that differs from the quantized training graph,
+      potentially causing differing arithmetic behavior. (default False)
+    change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+      inputs and outputs of the concat operator for quantized models. Changes
+      the ranges of concat operator overlap when true. (default False)
     allow_custom_ops: Boolean indicating whether to allow custom operations.
+      When false any unknown operation is an error. When true, custom ops are
+      created for any op that is unknown. The developer will need to provide
+      these to the TensorFlow Lite runtime with a custom resolver.
       (default False)
 
   Example usage:
@@ -109,9 +125,13 @@ class TocoConverter(object):
     self._input_tensors = input_tensors
     self._output_tensors = output_tensors
     self.inference_type = constants.FLOAT
+    self.inference_input_type = None
     self.output_format = constants.TFLITE
     self.quantized_input_stats = {}
+    self.default_ranges_stats = None
     self.drop_control_dependency = True
+    self.reorder_across_fake_quant = False
+    self.change_concat_input_ranges = False
     self.allow_custom_ops = False
 
   @classmethod
@@ -270,10 +290,15 @@ class TocoConverter(object):
         input_tensors=self._input_tensors,
         output_tensors=self._output_tensors,
         inference_type=self.inference_type,
+        inference_input_type=self.inference_input_type,
         input_format=constants.TENSORFLOW_GRAPHDEF,
         output_format=self.output_format,
         quantized_input_stats=quantized_stats,
-        drop_control_dependency=self.drop_control_dependency)
+        default_ranges_stats=self.default_ranges_stats,
+        drop_control_dependency=self.drop_control_dependency,
+        reorder_across_fake_quant=self.reorder_across_fake_quant,
+        change_concat_input_ranges=self.change_concat_input_ranges,
+        allow_custom_ops=self.allow_custom_ops)
     return result
 
   def _set_batch_size(self, batch_size):
index 28386ec..1b0cdb9 100644 (file)
@@ -220,6 +220,67 @@ class FromSessionTest(test_util.TensorFlowTestCase):
     graphviz_output = converter.convert()
     self.assertTrue(graphviz_output)
 
+  def testInferenceInputType(self):
+    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+    out_tensor = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+    converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.uint8, input_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('add', output_details[0]['name'])
+    self.assertEqual(np.uint8, output_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+  def testDefaultRangesStats(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    out_tensor = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+    converter.inference_type = lite_constants.QUANTIZED_UINT8
+    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
+    converter.default_ranges_stats = (0, 6)  # min, max
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Check values from converted model.
+    interpreter = Interpreter(model_content=tflite_model)
+    interpreter.allocate_tensors()
+
+    input_details = interpreter.get_input_details()
+    self.assertEqual(1, len(input_details))
+    self.assertEqual('Placeholder', input_details[0]['name'])
+    self.assertEqual(np.uint8, input_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+    self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+    output_details = interpreter.get_output_details()
+    self.assertEqual(1, len(output_details))
+    self.assertEqual('add', output_details[0]['name'])
+    self.assertEqual(np.uint8, output_details[0]['dtype'])
+    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+    self.assertTrue(output_details[0]['quantization'][0] > 0)  # scale
+
 
 class FromFlatbufferFile(test_util.TensorFlowTestCase):
 
index 79be5cd..38068be 100644 (file)
@@ -91,6 +91,9 @@ def _convert_model(flags):
   converter = _get_toco_converter(flags)
   if flags.inference_type:
     converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type)
+  if flags.inference_input_type:
+    converter.inference_input_type = _types_pb2.IODataType.Value(
+        flags.inference_input_type)
   if flags.output_format:
     converter.output_format = _toco_flags_pb2.FileFormat.Value(
         flags.output_format)
@@ -101,9 +104,16 @@ def _convert_model(flags):
     mean_values = _parse_int_array(flags.mean_values)
     quant_stats = zip(mean_values, std_dev_values)
     converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
+  if flags.default_ranges_min and flags.default_ranges_max:
+    converter.default_ranges_stats = (flags.default_ranges_min,
+                                      flags.default_ranges_max)
 
   if flags.drop_control_dependency:
     converter.drop_control_dependency = flags.drop_control_dependency
+  if flags.reorder_across_fake_quant:
+    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
+  if flags.change_concat_input_ranges:
+    converter.change_concat_input_ranges = flags.change_concat_input_ranges
   if flags.allow_custom_ops:
     converter.allow_custom_ops = flags.allow_custom_ops
 
@@ -116,8 +126,8 @@ def _convert_model(flags):
 def _check_flags(flags, unparsed):
   """Checks the parsed and unparsed flags to ensure they are valid.
 
-  Displays warnings for unparsed flags. Raises an error for parsed flags that
-  don't meet the required conditions.
+  Raises an error if previously support unparsed flags are found. Raises an
+  error for parsed flags that don't meet the required conditions.
 
   Args:
     flags: argparse.Namespace object containing TFLite flags.
@@ -126,17 +136,20 @@ def _check_flags(flags, unparsed):
   Raises:
     ValueError: Invalid flags.
   """
+
   # Check unparsed flags for common mistakes based on previous TOCO.
+  def _get_message_unparsed(flag, orig_flag, new_flag):
+    if flag.startswith(orig_flag):
+      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
+    return ""
+
   if unparsed:
-    print("tflite_convert: warning: Unable to parse following flags "
-          "'{}'".format(",".join(unparsed)))
+    output = ""
     for flag in unparsed:
-      if "--input_file=" in flag:
-        print("tflite_convert: warning: Use --graph_def_file instead of "
-              "--input_file")
-      if "--std_values=" in flag:
-        print("tflite_convert: warning: Use --std_dev_values instead of "
-              "--std_values")
+      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
+      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
+      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
+    raise ValueError(output)
 
   # Check that flags are valid.
   if flags.graph_def_file and (not flags.input_arrays or
@@ -163,6 +176,10 @@ def _check_flags(flags, unparsed):
       raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
                        "must have the same number of items")
 
+  if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
+    raise ValueError("--default_ranges_min and --default_ranges_max must be "
+                     "used together")
+
 
 def run_main(_):
   """Main in toco_convert.py."""
@@ -199,6 +216,12 @@ def run_main(_):
       type=str,
       choices=["FLOAT", "QUANTIZED_UINT8"],
       help="Target data type of arrays in the output file.")
+  parser.add_argument(
+      "--inference_input_type",
+      type=str,
+      choices=["FLOAT", "QUANTIZED_UINT8"],
+      help=("Target data type of input arrays. Allows for a different type for "
+            "input arrays in the case of quantization."))
 
   # Input and output arrays flags.
   parser.add_argument(
@@ -218,12 +241,13 @@ def run_main(_):
   parser.add_argument(
       "--saved_model_tag_set",
       type=str,
-      help=("Set of tags identifying the MetaGraphDef within the SavedModel "
-            "to analyze. All tags must be present. (default \"serve\")"))
+      help=("Comma-separated set of tags identifying the MetaGraphDef within "
+            "the SavedModel to analyze. All tags must be present. "
+            "(default \"serve\")"))
   parser.add_argument(
       "--saved_model_signature_key",
       type=str,
-      help=("Key identifying SignatureDef containing inputs and outputs. "
+      help=("Key identifying the SignatureDef containing inputs and outputs. "
             "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
 
   # Quantization flags.
@@ -237,14 +261,41 @@ def run_main(_):
       type=str,
       help=("Mean of training data for each input tensor, comma-separated. "
             "Used for quantization. (default None)"))
+  parser.add_argument(
+      "--default_ranges_min",
+      type=int,
+      help=("Default value for min bound of min/max range values used for all "
+            "arrays without a specified range, Intended for experimenting with "
+            "quantization via \"dummy quantization\". (default None)"))
+  parser.add_argument(
+      "--default_ranges_max",
+      type=int,
+      help=("Default value for max bound of min/max range values used for all "
+            "arrays without a specified range, Intended for experimenting with "
+            "quantization via \"dummy quantization\". (default None)"))
 
   # Graph manipulation flags.
   parser.add_argument(
       "--drop_control_dependency",
       type=bool,
       help=("Boolean indicating whether to drop control dependencies silently. "
-            "This is due to TensorFlow Lite not supporting control "
-            "dependencies. (default True)"))
+            "This is due to TensorFlow not supporting control dependencies. "
+            "(default True)"))
+  parser.add_argument(
+      "--reorder_across_fake_quant",
+      type=bool,
+      help=("Boolean indicating whether to reorder FakeQuant nodes in "
+            "unexpected locations. Used when the location of the FakeQuant "
+            "nodes is preventing graph transformations necessary to convert "
+            "the graph. Results in a graph that differs from the quantized "
+            "training graph, potentially causing differing arithmetic "
+            "behavior. (default False)"))
+  parser.add_argument(
+      "--change_concat_input_ranges",
+      type=bool,
+      help=("Boolean to change behavior of min/max ranges for inputs and "
+            "outputs of the concat operator for quantized models. Changes the "
+            "ranges of concat operator overlap when true. (default False)"))
   parser.add_argument(
       "--allow_custom_ops",
       type=bool,