],
)
+py_binary(
+ name = "convert_savedmodel",
+ srcs = ["convert_savedmodel.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite",
+ "//tensorflow/contrib/saved_model:saved_model_py",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ ],
+)
+
+py_test(
+ name = "convert_savedmodel_test",
+ srcs = ["convert_savedmodel_test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":convert_savedmodel",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/saved_model",
+ ],
+)
+
+# Transitive dependencies of this target will be included in the pip package.
+py_library(
+ name = "tf_lite_py_pip",
+ deps = [
+ ":convert_savedmodel",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""TensorFlow Lite flatbuffer generation from saved_models.
+
+Example:
+
+bazel run //tensorflow/contrib/lite/python:convert_savedmodel -- \
+ --saved_model_dir=/tmp/test_saved_model/1519865537 \
+ --output_tflite=/tmp/test.lite
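+
+Or from Python directly (a sketch using this module's convert() function):
+
+ from tensorflow.contrib.lite.python import convert_savedmodel
+
+ tflite_model = convert_savedmodel.convert(
+ saved_model_dir="/tmp/test_saved_model/1519865537",
+ output_tflite="/tmp/test.lite")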
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.saved_model.python.saved_model import reader
+from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
+
+flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.")
+flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.")
+flags.DEFINE_string("output_arrays", None,
+ "List of output tensor names, the default value is None, "
+ "which means the conversion will keep all outputs.")
+flags.DEFINE_integer("batch_size", 1,
+ "If input tensor shape has None at first dimension, "
+ "e.g. (None,224,224,3), replace None with batch_size.")
+flags.DEFINE_string("tag_set", tag_constants.SERVING,
+ "Group of tag(s) of the MetaGraphDef in the saved_model, "
+ "in string format, separated by ','. For tag-set contains "
+ "multiple tags, all tags must be passed in.")
+flags.DEFINE_string("signature_key",
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ "This is signature key to extract inputs, outputs.")
+
+
+def log_tensor_details(tensor_info):
+ """Log tensor details: name, shape, and type."""
+ for key, val in tensor_info.items():
+ dtype = types_pb2.DataType.Name(val.dtype)
+ if val.tensor_shape.unknown_rank:
+ shape = "unknown_rank"
+ else:
+ dims = [str(dim.size) for dim in val.tensor_shape.dim]
+ shape = "({})".format(", ".join(dims))
+
+ logging.info("Tensor's key in savedmodel's tensor_map: %s", key)
+ logging.info(" tensor name: %s, shape: %s, type: %s", val.name, shape,
+ dtype)
+
+
+def get_meta_graph_def(saved_model_dir, tag_set):
+ """Validate savedmodel and extract MetaGraphDef.
+
+ Args:
+ saved_model_dir: Savedmodel path to convert.
+ tag_set: Set of tag(s) of the MetaGraphDef to load.
+
+ Returns:
+ The meta_graph_def used for tflite conversion.
+
+ Raises:
+ ValueError: No valid MetaGraphDef for given tag_set.
+ """
+ saved_model = reader.read_saved_model(saved_model_dir)
+ tag_sets = []
+ result_meta_graph_def = None
+ for meta_graph_def in saved_model.meta_graphs:
+ meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags)
+ tag_sets.append(meta_graph_tag_set)
+ if meta_graph_tag_set == tag_set:
+ result_meta_graph_def = meta_graph_def
+ logging.info("The given SavedModel contains the following tags: %s", tag_sets)
+ if result_meta_graph_def is not None:
+ return result_meta_graph_def
+ else:
+ raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
+ "values are '{}'. ".format(tag_set, tag_sets))
+
+
+def get_signature_def(meta_graph, signature_key):
+ """Get the signature def from meta_graph with given signature_key.
+
+ Args:
+ meta_graph: meta_graph_def.
+ signature_key: signature_def in the meta_graph_def.
+
+ Returns:
+ The signature_def used for tflite conversion.
+
+ Raises:
+ ValueError: Given signature_key is not valid for this meta_graph.
+ """
+ signature_def_map = meta_graph.signature_def
+ signature_def_keys = set(signature_def_map.keys())
+ logging.info(
+ "The given SavedModel MetaGraphDef contains SignatureDefs with the "
+ "following keys: %s", signature_def_keys)
+ if signature_key not in signature_def_keys:
+ raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible "
+ "values are '{}'. ".format(signature_key,
+ signature_def_keys))
+ signature_def = signature_def_utils.get_signature_def_by_key(
+ meta_graph, signature_key)
+ return signature_def
+
+
+def get_inputs_outputs(signature_def):
+ """Get inputs and outputs from signature def.
+
+ Args:
+ signature_def: signature def in the meta_graph_def used for conversion.
+
+ Returns:
+ The inputs and outputs in the graph for conversion.
+ """
+ inputs_tensor_info = signature_def.inputs
+ outputs_tensor_info = signature_def.outputs
+ logging.info("input tensors info: ")
+ log_tensor_details(inputs_tensor_info)
+ logging.info("output tensors info: ")
+ log_tensor_details(outputs_tensor_info)
+
+ def gather_names(tensor_info):
+ return [tensor_info[key].name for key in tensor_info]
+
+ inputs = gather_names(inputs_tensor_info)
+ outputs = gather_names(outputs_tensor_info)
+ return inputs, outputs
+
+
+def convert(saved_model_dir,
+ output_tflite=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ batch_size=1):
+ """Convert a savedmodel to tflite flatbuffer.
+
+ Args:
+ saved_model_dir: Saved model directory to convert.
+ output_tflite: File path to write result flatbuffer.
+ output_arrays: Comma-separated list of output tensor names. Defaults to
+ None, which keeps all output tensors. It can also be used to filter out
+ tensors produced by ops that tflite does not yet support (e.g., Argmax).
+ tag_set: Set of tags identifying the MetaGraphDef to load from the
+ saved_model.
+ signature_key: Signature key used to extract inputs and outputs.
+ batch_size: Batch size to substitute when an input tensor's shape has
+ None as its first dimension, e.g. (None, 224, 224, 3).
+
+ Returns:
+ The converted data, e.g. a tflite flatbuffer in a bytes array.
+
+ Raises:
+ ValueError: If tag_set does not match any meta_graph_def in the
+ saved_model, signature_key is missing from the matching meta_graph_def,
+ an input shape has None beyond the 1st dimension, e.g. (1, None, None,
+ 3), or the given output_arrays match no output tensors.
+ """
+ if tag_set is None:
+ tag_set = set([tag_constants.SERVING])
+
+ meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = get_signature_def(meta_graph, signature_key)
+ inputs, outputs = get_inputs_outputs(signature_def)
+
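+ # Load the savedmodel into a fresh graph so the signature's tensor names
+ # can be resolved with graph.get_tensor_by_name().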
+ graph = ops.Graph()
+ with session.Session(graph=graph) as sess:
+
+ loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
+
+ in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs]
+
+ # Users can use output_arrays to filter output tensors for conversion.
+ # If output_arrays is None, all output tensors are kept. In the future, we
+ # may use the list of ops supported by tflite to automatically filter out
+ # outputs produced by unsupported or custom ops.
+ # TODO(zhixianyan): Use tflite supported Op list to filter outputs.
+ if output_arrays is not None:
+ output_arrays = output_arrays.split(",")
+ out_tensors = [
+ graph.get_tensor_by_name(output)
+ for output in outputs
+ if output.split(":")[0] in output_arrays
+ ]
+ else:
+ out_tensors = [graph.get_tensor_by_name(output) for output in outputs]
+
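+ # Tensor names have the form "op_name:output_index"; freezing expects
+ # bare op names.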
+ output_names = [node.split(":")[0] for node in outputs]
+
+ if not out_tensors:
+ raise ValueError(
+ "No valid output tensors for '{}', possible values are '{}'".format(
+ output_arrays, output_names))
+
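+ # Freeze variables into constants so toco receives a self-contained
+ # GraphDef.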
+ frozen_graph_def = tf_graph_util.convert_variables_to_constants(
+ sess, graph.as_graph_def(), output_names)
+
+ # Toco requires fully defined tensor shapes. For an input tensor whose
+ # shape has None as its first dimension, e.g. (None, 224, 224, 3), replace
+ # the None with the given batch size. Shapes with additional Nones, e.g.
+ # (None, None, None, 3), may still be replaceable and convertible, but
+ # require further investigation.
+ # TODO(zhixianyan): Add supports for input tensor with more None in shape.
+ for i, tensor in enumerate(in_tensors):
+ shape = tensor.get_shape().as_list()
+ if shape[0] is None:
+ shape[0] = batch_size
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension, as batch_size. Tensor "
+ "'{}' has shape '{}' with None in another dimension.".format(
+ inputs[i], shape))
+ tensor.set_shape(shape)
+
+ result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors)
+
+ if output_tflite is not None:
+ with gfile.Open(output_tflite, "wb") as f:
+ f.write(result)
+ logging.info("Successfully converted to: %s", output_tflite)
+
+ return result
+
+
+def main(_):
+ convert(
+ saved_model_dir=flags.FLAGS.saved_model_dir,
+ output_tflite=flags.FLAGS.output_tflite,
+ output_arrays=flags.FLAGS.output_arrays,
+ batch_size=flags.FLAGS.batch_size,
+ tag_set=set(flags.FLAGS.tag_set.split(",")),
+ signature_key=flags.FLAGS.signature_key)
+
+
+if __name__ == "__main__":
+ app.run(main)
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TF Lite SavedModel Conversion test cases.
+
+ - test on generated saved_models from simple graphs (sanity check)
+ - test mnist savedmodel generated on-the-fly
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from tensorflow.contrib.lite.python import convert_savedmodel
+from tensorflow.python import estimator
+from tensorflow.python import keras
+from tensorflow.python import layers
+from tensorflow.python import losses
+from tensorflow.python import nn
+from tensorflow.python import saved_model
+from tensorflow.python import train
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
+
+ def _createSimpleSavedModel(self, shape):
+ """Create a simple savedmodel on the fly."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
+ with session.Session() as sess:
+ in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ inputs = {"x": in_tensor}
+ outputs = {"y": out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ return saved_model_dir
+
+ def testSimpleSavedModel(self):
+ """Test a simple savedmodel created on the fly."""
+ # Create a simple savedmodel
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ # Convert to tflite
+ result = convert_savedmodel.convert(saved_model_dir=saved_model_dir)
+ self.assertTrue(result)
+
+ def testSimpleSavedModelWithNoneBatchSizeInShape(self):
+ """Test a simple savedmodel, with None in input tensor's shape."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
+ result = convert_savedmodel.convert(saved_model_dir=saved_model_dir)
+ self.assertTrue(result)
+
+ def testSimpleSavedModelWithMoreNoneInShape(self):
+ """Test a simple savedmodel, fail as more None in input shape."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3])
+ # Convert to tflite: this should raise ValueError, as 3rd dim is None.
+ with self.assertRaises(ValueError):
+ convert_savedmodel.convert(saved_model_dir=saved_model_dir)
+
+ def testSimpleSavedModelWithWrongSignatureKey(self):
+ """Test a simple savedmodel, fail as given signature is invalid."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ # Convert to tflite: this should raise ValueError, as
+ # signature_key does not exist in the saved_model.
+ with self.assertRaises(ValueError):
+ convert_savedmodel.convert(
+ saved_model_dir=saved_model_dir, signature_key="wrong-key")
+
+ def testSimpleSavedModelWithWrongOutputArray(self):
+ """Test a simple savedmodel, fail as given output_arrays is invalid."""
+ # Create a simple savedmodel
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ # Convert to tflite: this should raise ValueError, as
+ # output_arrays is not valid for the saved_model.
+ with self.assertRaises(ValueError):
+ convert_savedmodel.convert(
+ saved_model_dir=saved_model_dir, output_arrays="wrong-output")
+
+ def testMultipleMetaGraphDef(self):
+ """Test saved model with multiple MetaGraphDef."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd")
+ builder = saved_model.builder.SavedModelBuilder(saved_model_dir)
+ with session.Session(graph=ops.Graph()) as sess:
+ # MetaGraphDef 1
+ in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor)
+ sig_input_tensor_signature = {"x": sig_input_tensor}
+ sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor)
+ sig_output_tensor_signature = {"y": sig_output_tensor}
+ predict_signature_def = (
+ saved_model.signature_def_utils.build_signature_def(
+ sig_input_tensor_signature, sig_output_tensor_signature,
+ saved_model.signature_constants.PREDICT_METHOD_NAME))
+ signature_def_map = {
+ saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ predict_signature_def
+ }
+ builder.add_meta_graph_and_variables(
+ sess,
+ tags=[saved_model.tag_constants.SERVING, "additional_test_tag"],
+ signature_def_map=signature_def_map)
+ # MetaGraphDef 2
+ builder.add_meta_graph(tags=["tflite"])
+ builder.save(True)
+
+ # Convert to tflite
+ convert_savedmodel.convert(
+ saved_model_dir=saved_model_dir,
+ tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"]))
+
+
+class Model(keras.Model):
+ """Model to recognize digits in the MNIST dataset.
+
+ Train and export a savedmodel, used by testTrainedMnistSavedModel.
+
+ Network structure is equivalent to:
+ https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
+ and
+ https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
+
+ But written as a keras.Model using the layers API.
+ """
+
+ def __init__(self, data_format):
+ """Creates a model for classifying a hand-written digit.
+
+ Args:
+ data_format: Either "channels_first" or "channels_last".
+ "channels_first" is typically faster on GPUs while "channels_last" is
+ typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ """
+ super(Model, self).__init__()
+ self._input_shape = [-1, 28, 28, 1]
+
+ self.conv1 = layers.Conv2D(
+ 32, 5, padding="same", data_format=data_format, activation=nn.relu)
+ self.conv2 = layers.Conv2D(
+ 64, 5, padding="same", data_format=data_format, activation=nn.relu)
+ self.fc1 = layers.Dense(1024, activation=nn.relu)
+ self.fc2 = layers.Dense(10)
+ self.dropout = layers.Dropout(0.4)
+ self.max_pool2d = layers.MaxPooling2D(
+ (2, 2), (2, 2), padding="same", data_format=data_format)
+
+ def __call__(self, inputs, training):
+ """Add operations to classify a batch of input images.
+
+ Args:
+ inputs: A Tensor representing a batch of input images.
+ training: A boolean. Set to True to add operations required only when
+ training the classifier.
+
+ Returns:
+ A logits Tensor with shape [<batch_size>, 10].
+ """
+ y = array_ops.reshape(inputs, self._input_shape)
+ y = self.conv1(y)
+ y = self.max_pool2d(y)
+ y = self.conv2(y)
+ y = self.max_pool2d(y)
+ y = layers.flatten(y)
+ y = self.fc1(y)
+ y = self.dropout(y, training=training)
+ return self.fc2(y)
+
+
+def model_fn(features, labels, mode, params):
+ """The model_fn argument for creating an Estimator."""
+ model = Model(params["data_format"])
+ image = features
+ if isinstance(image, dict):
+ image = features["image"]
+
+ if mode == estimator.ModeKeys.PREDICT:
+ logits = model(image, training=False)
+ predictions = {
+ "classes": math_ops.argmax(logits, axis=1),
+ "probabilities": nn.softmax(logits),
+ }
+ return estimator.EstimatorSpec(
+ mode=estimator.ModeKeys.PREDICT,
+ predictions=predictions,
+ export_outputs={
+ "classify": estimator.export.PredictOutput(predictions)
+ })
+
+ elif mode == estimator.ModeKeys.TRAIN:
+ optimizer = train.AdamOptimizer(learning_rate=1e-4)
+
+ logits = model(image, training=True)
+ loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+ return estimator.EstimatorSpec(
+ mode=estimator.ModeKeys.TRAIN,
+ loss=loss,
+ train_op=optimizer.minimize(loss, train.get_or_create_global_step()))
+
+ elif mode == estimator.ModeKeys.EVAL:
+ logits = model(image, training=False)
+ loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+ return estimator.EstimatorSpec(
+ mode=estimator.ModeKeys.EVAL,
+ loss=loss,
+ eval_metric_ops={
+ "accuracy":
+ metrics.accuracy(
+ labels=labels, predictions=math_ops.argmax(logits, axis=1)),
+ })
+
+
+def dummy_input_fn():
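+ # 100 random flattened 28x28 "images" with integer labels in [0, 9).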
+ image = random_ops.random_uniform([100, 784])
+ labels = random_ops.random_uniform([100, 1], maxval=9, dtype=dtypes.int32)
+ return image, labels
+
+
+class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
+
+ def testTrainedMnistSavedModel(self):
+ """Test mnist savedmodel, trained with dummy data and small steps."""
+ # Build classifier
+ classifier = estimator.Estimator(
+ model_fn=model_fn,
+ params={
+ "data_format": "channels_last" # tflite format
+ })
+
+ # Train and pred for serving
+ classifier.train(input_fn=dummy_input_fn, steps=2)
+ image = array_ops.placeholder(dtypes.float32, [None, 28, 28])
+ pred_input_fn = estimator.export.build_raw_serving_input_receiver_fn({
+ "image": image,
+ })
+
+ # Export savedmodel
+ saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel")
+ classifier.export_savedmodel(saved_model_dir, pred_input_fn)
+
+ # Convert to tflite and test output
+ saved_model_name = os.listdir(saved_model_dir)[0]
+ saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name)
+ output_tflite = saved_model_final_dir + ".lite"
+ # TODO(zhixianyan): no need to limit output_arrays to 'Softmax'
+ # once b/74205001 fixed and argmax implemented in tflite.
+ result = convert_savedmodel.convert(
+ saved_model_dir=saved_model_final_dir,
+ output_arrays="Softmax",
+ output_tflite=output_tflite)
+
+ self.assertTrue(result)
+
+
+if __name__ == "__main__":
+ test.main()
"//tensorflow/contrib/keras:keras",
"//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
"//tensorflow/contrib/lite/python:interpreter_test_data",
+ "//tensorflow/contrib/lite/python:tf_lite_py_pip",
"//tensorflow/contrib/lite/toco:toco",
"//tensorflow/contrib/lite/toco/python:toco_wrapper",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",