From: Nupur Garg Date: Thu, 31 May 2018 20:58:32 +0000 (-0700) Subject: Add attributes to TFLite Python API. X-Git-Tag: upstream/v1.9.0_rc1~26^2~6^2~27 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d3b5b07e7810782c3760468312f9cace10b89073;p=platform%2Fupstream%2Ftensorflow.git Add attributes to TFLite Python API. PiperOrigin-RevId: 198774775 --- diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c0926d2..0819475 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -115,11 +115,15 @@ def toco_convert(input_data, input_tensors, output_tensors, inference_type=lite_constants.FLOAT, + inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, output_format=lite_constants.TFLITE, quantized_input_stats=None, + default_ranges_stats=None, drop_control_dependency=True, - allow_custom_ops=False): + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False): """Convert a model using TOCO from `input_format` to `output_format`. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -130,18 +134,41 @@ def toco_convert(input_data, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of arrays in the output file. 
Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + input_format: Type of data to read Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only need if + `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. 
The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) Returns: - The converted data. For example if tflite was the destination, then + The converted data. For example if TFLite was the destination, then this will be a tflite flatbuffer in a bytes array. Raises: @@ -152,10 +179,18 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format - toco.drop_control_dependency = drop_control_dependency - model = _model_flags_pb2.ModelFlags() toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type + toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + + model = _model_flags_pb2.ModelFlags() + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): if input_tensor.dtype == _dtypes.float32: tflite_input_type = lite_constants.FLOAT @@ -163,6 +198,8 @@ def toco_convert(input_data, tflite_input_type = lite_constants.INT32 elif input_tensor.dtype == _dtypes.int64: tflite_input_type = lite_constants.INT64 + elif input_tensor.dtype == _dtypes.uint8: + tflite_input_type = lite_constants.QUANTIZED_UINT8 # TODO(aselle): Insert strings when they are available else: raise ValueError("Tensors %s not known type %r" % (input_tensor.name, diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 6510d74..d55d8a6 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -64,17 +64,33 @@ class TocoConverter(object): inference_type: Target data type of arrays in the output file. Currently must be `{FLOAT, QUANTIZED_UINT8}`. 
(default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) - quantized_input_stats: The mean and std deviation of training data for each - input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`. - Dict of strings representing input tensor names to a tuple of integers - representing the quantization stats (e.g., {"foo" : (0., 1.)}). - (default {}) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only needed if + `inference_type` is `QUANTIZED_UINT8`. (default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) drop_control_dependency: Boolean indicating whether to drop control dependencies silently. This is due to TFLite not supporting control dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. 
When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. (default False) Example usage: @@ -109,9 +125,13 @@ class TocoConverter(object): self._input_tensors = input_tensors self._output_tensors = output_tensors self.inference_type = constants.FLOAT + self.inference_input_type = None self.output_format = constants.TFLITE self.quantized_input_stats = {} + self.default_ranges_stats = None self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False self.allow_custom_ops = False @classmethod @@ -270,10 +290,15 @@ class TocoConverter(object): input_tensors=self._input_tensors, output_tensors=self._output_tensors, inference_type=self.inference_type, + inference_input_type=self.inference_input_type, input_format=constants.TENSORFLOW_GRAPHDEF, output_format=self.output_format, quantized_input_stats=quantized_stats, - drop_control_dependency=self.drop_control_dependency) + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops) return result def _set_batch_size(self, batch_size): diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 28386ec..1b0cdb9 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -220,6 +220,67 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + def testInferenceInputType(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. 
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. 
+ interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + class FromFlatbufferFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 79be5cd..38068be 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -91,6 +91,9 @@ def _convert_model(flags): converter = _get_toco_converter(flags) if flags.inference_type: converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) if flags.output_format: converter.output_format = _toco_flags_pb2.FileFormat.Value( flags.output_format) @@ -101,9 +104,16 @@ def _convert_model(flags): mean_values = _parse_int_array(flags.mean_values) quant_stats = zip(mean_values, std_dev_values) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if flags.default_ranges_min and flags.default_ranges_max: + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) if flags.drop_control_dependency: converter.drop_control_dependency = flags.drop_control_dependency + if 
flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops @@ -116,8 +126,8 @@ def _convert_model(flags): def _check_flags(flags, unparsed): """Checks the parsed and unparsed flags to ensure they are valid. - Displays warnings for unparsed flags. Raises an error for parsed flags that - don't meet the required conditions. + Raises an error if previously supported unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. Args: flags: argparse.Namespace object containing TFLite flags. @@ -126,17 +136,20 @@ def _check_flags(flags, unparsed): Raises: ValueError: Invalid flags. """ + # Check unparsed flags for common mistakes based on previous TOCO. + def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + if unparsed: - print("tflite_convert: warning: Unable to parse following flags " - "'{}'".format(",".join(unparsed))) + output = "" for flag in unparsed: - if "--input_file=" in flag: - print("tflite_convert: warning: Use --graph_def_file instead of " - "--input_file") - if "--std_values=" in flag: - print("tflite_convert: warning: Use --std_dev_values instead of " - "--std_values") + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + raise ValueError(output) # Check that flags are valid. 
if flags.graph_def_file and (not flags.input_arrays or @@ -163,6 +176,10 @@ def _check_flags(flags, unparsed): raise ValueError("--std_dev_values, --mean_values, and --input_arrays " "must have the same number of items") + if bool(flags.default_ranges_min) != bool(flags.default_ranges_max): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + def run_main(_): """Main in toco_convert.py.""" @@ -199,6 +216,12 @@ def run_main(_): type=str, choices=["FLOAT", "QUANTIZED_UINT8"], help="Target data type of arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str, + choices=["FLOAT", "QUANTIZED_UINT8"], + help=("Target data type of input arrays. Allows for a different type for " + "input arrays in the case of quantization.")) # Input and output arrays flags. parser.add_argument( @@ -218,12 +241,13 @@ def run_main(_): parser.add_argument( "--saved_model_tag_set", type=str, - help=("Set of tags identifying the MetaGraphDef within the SavedModel " - "to analyze. All tags must be present. (default \"serve\")")) + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) parser.add_argument( "--saved_model_signature_key", type=str, - help=("Key identifying SignatureDef containing inputs and outputs. " + help=("Key identifying the SignatureDef containing inputs and outputs. " "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) # Quantization flags. @@ -237,14 +261,41 @@ def run_main(_): type=str, help=("Mean of training data for each input tensor, comma-separated. " "Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". 
(default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range. Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) # Graph manipulation flags. parser.add_argument( "--drop_control_dependency", type=bool, help=("Boolean indicating whether to drop control dependencies silently. " - "This is due to TensorFlow Lite not supporting control " - "dependencies. (default True)")) + "This is due to TensorFlow Lite not supporting control " + "dependencies. (default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + type=bool, + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior. (default False)")) + parser.add_argument( + "--change_concat_input_ranges", + type=bool, + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) parser.add_argument( "--allow_custom_ops", type=bool,