],
)
+py_binary(
+ name = "tflite_convert",
+ srcs = ["tflite_convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite",
+ ],
+)
+
py_library(
name = "lite",
srcs = ["lite.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
"//tensorflow/python:platform",
"//tensorflow/python/saved_model",
],
)
-
-# Transitive dependencies of this target will be included in the pip package.
-py_library(
- name = "tf_lite_py_pip",
- deps = [
- ":convert_saved_model",
- ],
-)
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-
-
-def _write_and_flush_file(file_path, data_str):
- """Writes data to file path.
-
- Args:
- file_path: Full path of the file to store data in.
- data_str: Data represented as a string.
-
- Returns: None.
- """
- with gfile.Open(file_path, "wb") as data_file:
- data_file.write(data_str)
- data_file.flush()
def _log_tensor_details(tensor_info):
"""
tensors = []
if user_tensor_names:
- # Get the list of all of the tensors with and without the tensor index.
- all_tensor_names = [
- tensor.name for op in graph.get_operations() for tensor in op.outputs
- ]
- all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
-
# Sort the tensor names.
user_tensor_names = sorted(user_tensor_names)
- # Get the tensors associated with the tensor names.
- tensors = []
- invalid_tensors = []
- for name in user_tensor_names:
- if name not in all_tensor_names_only:
- invalid_tensors.append(name)
- else:
- idx = all_tensor_names_only.index(name)
- tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
-
- # Throw ValueError if any user input names are not valid tensors.
- if invalid_tensors:
- raise ValueError("Invalid tensors '{}' were found.".format(
- ",".join(invalid_tensors)))
+ tensors = get_tensors_from_tensor_names(graph, user_tensor_names)
elif signature_def_tensor_names:
tensors = [
graph.get_tensor_by_name(name)
return tensors
+def get_tensors_from_tensor_names(graph, tensor_names):
+ """Gets the Tensors associated with the `tensor_names` in the provided graph.
+
+ Args:
+ graph: TensorFlow Graph.
+ tensor_names: List of strings that represent names of tensors in the graph.
+
+ Returns:
+ A list of Tensor objects in the same order the names are provided.
+
+ Raises:
+ ValueError:
+ tensor_names contains an invalid tensor name.
+ """
+ # Get the list of all of the tensors.
+ tensor_name_to_tensor = {
+ tensor_name(tensor): tensor for op in graph.get_operations()
+ for tensor in op.values()
+ }
+
+ # Get the tensors associated with tensor_names.
+ tensors = []
+ invalid_tensors = []
+ for name in tensor_names:
+ tensor = tensor_name_to_tensor.get(name)
+ if tensor is None:
+ invalid_tensors.append(name)
+ else:
+ tensors.append(tensor)
+
+ # Throw ValueError if any user input names are not valid tensors.
+ if invalid_tensors:
+ raise ValueError("Invalid tensors '{}' were found.".format(
+ ",".join(invalid_tensors)))
+ return tensors
+
+
+def set_tensor_shapes(tensors, shapes):
+ """Sets Tensor shape for each tensor if the shape is defined.
+
+ Args:
+ tensors: TensorFlow ops.Tensor.
+ shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+ """
+ if shapes:
+ for tensor in tensors:
+ shape = shapes.get(tensor.name)
+ if shape is not None:
+ tensor.set_shape(shapes[tensor.name])
+
+
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
+ from SignatureDef when none are provided.
+ input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
+ arrays from SignatureDef when none are provided.
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present.
signature_key: Key identifying SignatureDef containing inputs and outputs.
Returns:
signature_key is not in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
- Unable to load Session.
"""
- # Set default values for inputs if they are set to None.
- if signature_key is None:
- signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- if tag_set is None:
- tag_set = set([tag_constants.SERVING])
-
# Read SignatureDef.
meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
signature_def = _get_signature_def(meta_graph, signature_key)
# TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
-
- # Gets fully defined tensor shape.
- for tensor in in_tensors:
- if (input_shapes and tensor.name in input_shapes and
- input_shapes[tensor.name] is not None):
- shape = input_shapes[tensor.name]
- else:
- shape = tensor.get_shape().as_list()
- tensor.set_shape(shape)
+ set_tensor_shapes(in_tensors, input_shapes)
output_names = [node.split(":")[0] for node in outputs]
frozen_graph_def = tf_graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), output_names)
return frozen_graph_def, in_tensors, out_tensors
- raise ValueError("Unable to load Session.")
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import training as train
+class TensorFunctionsTest(test_util.TensorFlowTestCase):
+
+ def testGetTensorsValid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tensors = convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, ["Placeholder"])
+ self.assertEqual("Placeholder:0", tensors[0].name)
+
+ def testGetTensorsInvalid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ with self.assertRaises(ValueError) as error:
+ convert_saved_model.get_tensors_from_tensor_names(sess.graph,
+ ["invalid-input"])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ def testSetTensorShapeValid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"Placeholder:0": [5, 3, 5]})
+ self.assertEqual([5, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeInvalid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"invalid-input": [5, 3, 5]})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeEmpty(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor], {})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+
class FreezeSavedModelTest(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
output_arrays=None,
tag_set=None,
signature_key=None):
+ if tag_set is None:
+ tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
saved_model_dir=saved_model_dir,
input_arrays=input_arrays,
input_arrays=None,
input_shapes=None,
output_arrays=["Softmax"],
- tag_set=None,
+ tag_set=set([tag_constants.SERVING]),
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
self.assertTrue(result)
from __future__ import division
from __future__ import print_function
+from google.protobuf import text_format as _text_format
+from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
+from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names
+from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes
from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
Attributes:
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- (default FLOAT)
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT). (default TFLITE)
+ inference_type: Target data type of arrays in the output file. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ output_format: Output file format. Currently must be `{TFLITE,
+ GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: The mean and std deviation of training data for each
input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
- (default None)
+ Dict of strings representing input tensor names to a tuple of integers
+ representing the quantization stats (e.g., {"foo" : (0., 1.)}).
+ (default {})
drop_control_dependency: Boolean indicating whether to drop control
dependencies silently. This is due to TFLite not supporting control
dependencies. (default True)
Example usage:
- # Converting a frozen graph.
+ # Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
+ # Converting a GraphDef from file.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, input_arrays, output_arrays)
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
+
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self._output_tensors = output_tensors
self.inference_type = constants.FLOAT
self.output_format = constants.TFLITE
- self.quantized_input_stats = None
+ self.quantized_input_stats = {}
self.drop_control_dependency = True
self.allow_custom_ops = False
@classmethod
- def from_session(cls,
- sess,
- input_tensors,
- output_tensors,
- freeze_variables=False):
+ def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
Args:
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- freeze_variables: Boolean indicating whether the variables need to be
- converted into constants via the freeze_graph.py script.
- (default False)
Returns:
TocoConverter class.
"""
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
+
+ @classmethod
+ def from_flatbuffer_file(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a GraphDef.
+
+ Args:
+ graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
- # Get GraphDef.
- if freeze_variables:
+ Returns:
+ TocoConverter class.
+
+ Raises:
+ ValueError:
+ Unable to parse input file.
+ The graph is not frozen.
+ input_arrays or output_arrays contains an invalid tensor name.
+ """
+ with _session.Session() as sess:
sess.run(global_variables_initializer())
- output_arrays = [tensor_name(tensor) for tensor in output_tensors]
- graph_def = tf_graph_util.convert_variables_to_constants(
- sess, sess.graph_def, output_arrays)
- else:
- graph_def = sess.graph_def
- # Create TocoConverter class.
- return cls(graph_def, input_tensors, output_tensors)
+ # Read GraphDef from file.
+ graph_def = _graph_pb2.GraphDef()
+ with open(graph_def_file, "rb") as f:
+ file_content = f.read()
+ try:
+ graph_def.ParseFromString(file_content)
+ except (_text_format.ParseError, DecodeError):
+ try:
+ print("Ignore 'tcmalloc: large alloc' warnings.")
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise ValueError(
+ "Unable to parse input file '{}'.".format(graph_def_file))
+ sess.graph.as_default()
+ import_graph_def(graph_def, name="")
+
+ # Get input and output tensors.
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py")
+
+ # Create TocoConverter class.
+ return cls(sess.graph_def, input_tensors, output_tensors)
@classmethod
- def from_saved_model(
- cls,
- saved_model_dir,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
"""Creates a TocoConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input
arrays from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present. (default set("serve"))
signature_key: Key identifying SignatureDef containing inputs and outputs.
+ (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
TocoConverter class.
"""
if tag_set is None:
tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key)
elif shape[0] is None:
self._set_batch_size(batch_size=1)
+ # Get quantization stats. Ensures there is one stat per name if the stats
+ # are specified.
+ if self.quantized_input_stats:
+ quantized_stats = []
+ invalid_stats = []
+ for tensor in self._input_tensors:
+ name = tensor_name(tensor)
+ if name in self.quantized_input_stats:
+ quantized_stats.append(self.quantized_input_stats[name])
+ else:
+ invalid_stats.append(name)
+
+ if invalid_stats:
+ raise ValueError("Quantization input stats are not available for input "
+ "tensors '{0}'.".format(",".join(invalid_stats)))
+ else:
+ quantized_stats = None
+
# Converts model.
result = toco_convert(
input_data=self._graph_def,
inference_type=self.inference_type,
input_format=constants.TENSORFLOW_GRAPHDEF,
output_format=self.output_format,
- quantized_input_stats=self.quantized_input_stats,
+ quantized_input_stats=quantized_stats,
drop_control_dependency=self.drop_control_dependency)
return result
shape = tensor.get_shape().as_list()
shape[0] = batch_size
tensor.set_shape(shape)
+
+
+def _is_frozen_graph(sess):
+ """Determines if the graph is frozen.
+
+ Determines if a graph has previously been frozen by checking for any
+ operations of type Variable*. If variables are found, the graph is not frozen.
+
+ Args:
+ sess: TensorFlow Session.
+
+ Returns:
+ Bool.
+ """
+ for op in sess.graph.get_operations():
+ if op.type.startswith("Variable"):
+ return False
+ return True
+
+
+def _freeze_graph(sess, output_tensors):
+ """Returns a frozen GraphDef.
+
+ Freezes a graph with Variables in it. Otherwise the existing GraphDef is
+ returned.
+
+ Args:
+ sess: TensorFlow Session.
+ output_tensors: List of output tensors (only .name is used from this).
+
+ Returns:
+ Frozen GraphDef.
+ """
+ if not _is_frozen_graph(sess):
+ sess.run(global_variables_initializer())
+ output_arrays = [tensor_name(tensor) for tensor in output_tensors]
+ return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def,
+ output_arrays)
+ else:
+ return sess.graph_def
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testQuantization(self):
- in_tensor = array_ops.placeholder(
- shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
out_tensor = array_ops.fake_quant_with_min_max_args(
- in_tensor + in_tensor, min=0., max=1., name='output')
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {
+ 'inputA': (0., 1.),
+ 'inputB': (0., 1.)
+ } # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.uint8, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((1., 0.),
+ input_details[1]['quantization']) # scale, zero_point
+
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizationInvalid(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ out_tensor = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual(
+ 'Quantization input stats are not available for input tensors '
+ '\'inputB\'.', str(error.exception))
+
def testBatchSizeInvalid(self):
in_tensor = array_ops.placeholder(
shape=[None, 16, 16, 3], dtype=dtypes.float32)
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
- sess, [in_tensor], [out_tensor], freeze_variables=True)
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
self.assertTrue(graphviz_output)
+class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFloatWithShapesArray(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'],
+ input_shapes={'Placeholder': [1, 16, 16, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+
+ def testFreezeGraph(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ var = variable_scope.get_variable(
+ 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + var
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Ensure the graph with variables cannot be converted.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual('Please freeze the graph using freeze_graph.py',
+ str(error.exception))
+
+ def testPbtxt(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
+ write_graph(sess.graph_def, '', graph_def_file, True)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testInvalidFile(self):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
+ with gfile.Open(graph_def_file, 'wb') as temp_file:
+ temp_file.write('bad data')
+ temp_file.flush()
+
+ # Attempts to convert the invalid model.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual(
+ 'Unable to parse input file \'{}\'.'.format(graph_def_file),
+ str(error.exception))
+
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
def _createSavedModel(self, shape):
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python command line interface for running TOCO."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.platform import app
+
+
+def _parse_array(values):
+ if values:
+ return values.split(",")
+
+
+def _parse_int_array(values):
+ if values:
+ return [int(val) for val in values.split(",")]
+
+
+def _parse_set(values):
+ if values:
+ return set(values.split(","))
+
+
+def _get_toco_converter(flags):
+ """Makes a TocoConverter object based on the flags provided.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+
+ Returns:
+ TocoConverter object.
+ """
+ # Parse input and output arrays.
+ input_arrays = _parse_array(flags.input_arrays)
+ input_shapes = None
+ if flags.input_shapes:
+ input_shapes_list = [
+ _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+ ]
+ input_shapes = dict(zip(input_arrays, input_shapes_list))
+ output_arrays = _parse_array(flags.output_arrays)
+
+ converter_kwargs = {
+ "input_arrays": input_arrays,
+ "input_shapes": input_shapes,
+ "output_arrays": output_arrays
+ }
+
+ # Create TocoConverter.
+ if flags.graph_def_file:
+ converter_fn = lite.TocoConverter.from_flatbuffer_file
+ converter_kwargs["graph_def_file"] = flags.graph_def_file
+ elif flags.saved_model_dir:
+ converter_fn = lite.TocoConverter.from_saved_model
+ converter_kwargs["saved_model_dir"] = flags.saved_model_dir
+ converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
+ converter_kwargs["signature_key"] = flags.saved_model_signature_key
+
+ return converter_fn(**converter_kwargs)
+
+
+def _convert_model(flags):
+ """Calls function to convert the TensorFlow model into a TFLite model.
+
+ Args:
+ flags: argparse.Namespace object.
+ """
+ # Create converter.
+ converter = _get_toco_converter(flags)
+ if flags.inference_type:
+ converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type)
+ if flags.output_format:
+ converter.output_format = _toco_flags_pb2.FileFormat.Value(
+ flags.output_format)
+
+ if flags.mean_values and flags.std_dev_values:
+ input_arrays = _parse_array(flags.input_arrays)
+ std_dev_values = _parse_int_array(flags.std_dev_values)
+ mean_values = _parse_int_array(flags.mean_values)
+ quant_stats = zip(mean_values, std_dev_values)
+ converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
+
+ if flags.drop_control_dependency:
+ converter.drop_control_dependency = flags.drop_control_dependency
+ if flags.allow_custom_ops:
+ converter.allow_custom_ops = flags.allow_custom_ops
+
+ # Convert model.
+ output_data = converter.convert()
+ with open(flags.output_file, "wb") as f:
+ f.write(output_data)
+
+
+def _check_flags(flags, unparsed):
+ """Checks the parsed and unparsed flags to ensure they are valid.
+
+ Displays warnings for unparsed flags. Raises an error for parsed flags that
+ don't meet the required conditions.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+ unparsed: List of unparsed flags.
+
+ Raises:
+ ValueError: Invalid flags.
+ """
+ # Check unparsed flags for common mistakes based on previous TOCO.
+ if unparsed:
+ print("tflite_convert: warning: Unable to parse following flags "
+ "'{}'".format(",".join(unparsed)))
+ for flag in unparsed:
+ if "--input_file=" in flag:
+ print("tflite_convert: warning: Use --graph_def_file instead of "
+ "--input_file")
+ if "--std_values=" in flag:
+ print("tflite_convert: warning: Use --std_dev_values instead of "
+ "--std_values")
+
+ # Check that flags are valid.
+ if flags.graph_def_file and (not flags.input_arrays or
+ not flags.output_arrays):
+ raise ValueError("--input_arrays and --output_arrays are required with "
+ "--graph_def_file")
+
+ if flags.input_shapes:
+ if not flags.input_arrays:
+ raise ValueError("--input_shapes must be used with --input_arrays")
+ if flags.input_shapes.count(":") != flags.input_arrays.count(","):
+ raise ValueError("--input_shapes and --input_arrays must have the same "
+ "number of items")
+
+ if flags.std_dev_values or flags.mean_values:
+ if bool(flags.std_dev_values) != bool(flags.mean_values):
+ raise ValueError("--std_dev_values and --mean_values must be used "
+ "together")
+ if not flags.input_arrays:
+ raise ValueError("--std_dev_values and --mean_values must be used with "
+ "--input_arrays")
+ if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or
+ flags.std_dev_values.count(",") != flags.input_arrays.count(",")):
+ raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
+ "must have the same number of items")
+
+
+def run_main(_):
+ """Main in toco_convert.py."""
+ parser = argparse.ArgumentParser(
+ description=("Command line tool to run TensorFlow Lite Optimizing "
+ "Converter (TOCO)."))
+
+ # Output file flag.
+ parser.add_argument(
+ "--output_file",
+ type=str,
+ help="Full filepath of the output file.",
+ required=True)
+
+ # Input file flags.
+ input_file_group = parser.add_mutually_exclusive_group(required=True)
+ input_file_group.add_argument(
+ "--graph_def_file",
+ type=str,
+ help="Full filepath of file containing TensorFlow GraphDef.")
+ input_file_group.add_argument(
+ "--saved_model_dir",
+ type=str,
+ help="Full filepath of directory containing the SavedModel.")
+
+ # Model format flags.
+ parser.add_argument(
+ "--output_format",
+ type=str,
+ choices=["TFLITE", "GRAPHVIZ_DOT"],
+ help="Output file format.")
+ parser.add_argument(
+ "--inference_type",
+ type=str,
+ choices=["FLOAT", "QUANTIZED_UINT8"],
+ help="Target data type of arrays in the output file.")
+
+ # Input and output arrays flags.
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.")
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.")
+
+ # SavedModel related flags.
+ parser.add_argument(
+ "--saved_model_tag_set",
+ type=str,
+ help=("Set of tags identifying the MetaGraphDef within the SavedModel "
+ "to analyze. All tags must be present. (default \"serve\")"))
+ parser.add_argument(
+ "--saved_model_signature_key",
+ type=str,
+ help=("Key identifying SignatureDef containing inputs and outputs. "
+ "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
+
+ # Quantization flags.
+ parser.add_argument(
+ "--std_dev_values",
+ type=str,
+ help=("Standard deviation of training data for each input tensor, "
+ "comma-separated. Used for quantization. (default None)"))
+ parser.add_argument(
+ "--mean_values",
+ type=str,
+ help=("Mean of training data for each input tensor, comma-separated. "
+ "Used for quantization. (default None)"))
+
+ # Graph manipulation flags.
+ parser.add_argument(
+ "--drop_control_dependency",
+ type=bool,
+ help=("Boolean indicating whether to drop control dependencies silently. "
+ "This is due to TensorFlow Lite not supporting control "
+ "dependencies. (default True)"))
+ parser.add_argument(
+ "--allow_custom_ops",
+ type=bool,
+ help=("Boolean indicating whether to allow custom operations. When false "
+ "any unknown operation is an error. When true, custom ops are "
+ "created for any op that is unknown. The developer will need to "
+ "provide these to the TensorFlow Lite runtime with a custom "
+ "resolver. (default False)"))
+
+ tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
+ try:
+ _check_flags(tflite_flags, unparsed)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = os.path.basename(sys.argv[0])
+ sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
+ sys.exit(1)
+ _convert_model(tflite_flags)
+
+
+def main():
+ app.run(main=run_main, argv=sys.argv[:1])
+
+
+if __name__ == "__main__":
+ main()
* [High-level overview](#high-level-overview)
* [API](#api)
* [Basic examples](#basic)
- * [Exporting a GraphDef with constants](#basic-graphdef-const)
- * [Exporting a GraphDef with variables](#basic-graphdef-var)
+ * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
+ * [Exporting a GraphDef from file](#basic-graphdef-file)
* [Exporting a SavedModel](#basic-savedmodel)
* [Complex examples](#complex)
* [Exporting a quantized GraphDef](#complex-quant)
The following section shows examples of how to convert a basic float-point model
from each of the supported data formats into a TensorFlow Lite FlatBuffers.
-### Exporting a GraphDef with constants <a name="basic-graphdef-const"></a>
+### Exporting a GraphDef from tf.Session <a name="basic-graphdef-sess"></a>
-The following example shows how to convert a TensorFlow GraphDef with constants
-into a TensorFlow Lite FlatBuffer.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer from a `tf.Session` object.
```python
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
-val = img + const
+var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + var
out = tf.identity(val, name="out")
with tf.Session() as sess:
open("converted_model.tflite", "wb").write(tflite_model)
```
-### Exporting a GraphDef with variables <a name="basic-graphdef-var"></a>
+### Exporting a GraphDef from file <a name="basic-graphdef-file"></a>
-If a model has variables, they need to be turned into constants through a
-process known as freezing. It can be accomplished by setting `freeze_variables`
-to `True` as shown in the example below.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and
+`.pbtxt` files are accepted.
+
+The example uses
+[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
+The function only supports GraphDefs frozen via
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
```python
import tensorflow as tf
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + var
-out = tf.identity(val, name="out")
+graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
+input_arrays = ["input"]
+output_arrays = ["MobilenetV1/Predictions/Softmax"]
-with tf.Session() as sess:
- converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out],
- freeze_variables=True)
- tflite_model = converter.convert()
- open("converted_model.tflite", "wb").write(tflite_model)
+converter = tf.contrib.lite.TocoConverter.from_flatbuffer_file(
+ graph_def_file, input_arrays, output_arrays)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
```
### Exporting a SavedModel <a name="basic-savedmodel"></a>
## Complex examples <a name="complex"></a>
For models where the default value of the attributes is not sufficient, the
-variables values should be set before calling `convert()`. In order to call any
-constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
+attribute's values should be set before calling `convert()`. In order to call
+any constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python
terminal for detailed documentation on the attributes.
with tf.Session() as sess:
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
],
)
-py_binary(
- name = "toco_wrapper",
- srcs = ["toco_wrapper.py"],
- srcs_version = "PY2AND3",
-)
-
tf_py_test(
name = "toco_from_protos_test",
srcs = ["toco_from_protos_test.py"],
+++ /dev/null
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Wrapper for runninmg toco binary embedded in pip site-package.
-
-NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/.
-It can only install Python "console-scripts." This will work as a console
-script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS).
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-
-def main():
- # Pip installs the binary in aux-bin off of main site-package install.
- # Just find it and exec, passing all arguments in the process.
- # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary.
- print("""TOCO from pip install is currently not working on command line.
-Please use the python TOCO API or use
-bazel run tensorflow/contrib/lite:toco -- <args> from a TensorFlow source dir.
-""")
- sys.exit(1)
- # TODO(aselle): Replace this when we find a way to run toco without
- # blowing up executable size.
- # binary = os.path.join(tf.__path__[0], 'aux-bin/toco')
- # os.execvp(binary, sys.argv)
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
"//tensorflow/contrib/lite/python:interpreter_test_data",
- "//tensorflow/contrib/lite/python:tf_lite_py_pip",
- "//tensorflow/contrib/lite/toco:toco",
- "//tensorflow/contrib/lite/toco/python:toco_wrapper",
+ "//tensorflow/contrib/lite/python:tflite_convert",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
],
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
fi
mkdir "${TMPDIR}/tensorflow/aux-bin"
# Install toco as a binary in aux-bin.
- # TODO(aselle): Re-enable this when we find a way to do it without doubling
- # the whl size (over the limit).
- # cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/
+ cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/
fi
# protobuf pip package doesn't ship with header files. Copy the headers
CONSOLE_SCRIPTS = [
'freeze_graph = tensorflow.python.tools.freeze_graph:run_main',
'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main',
- 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main',
+ 'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main',
+ 'toco = tensorflow.contrib.lite.python.tflite_convert:main',
'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',
# We need to keep the TensorBoard command, even though the console script
# is now declared by the tensorboard pip package. If we remove the