From 80cac4d951d78660b32ccfde9d91ff4050e5d477 Mon Sep 17 00:00:00 2001 From: Zhixian Yan Date: Thu, 22 Mar 2018 17:59:19 -0700 Subject: [PATCH] Convert saved_model to tflite flatbuffer. PiperOrigin-RevId: 190156133 --- tensorflow/contrib/lite/python/BUILD | 35 +++ .../contrib/lite/python/convert_savedmodel.py | 261 +++++++++++++++++++ .../contrib/lite/python/convert_savedmodel_test.py | 276 +++++++++++++++++++++ tensorflow/tools/pip_package/BUILD | 1 + 4 files changed, 573 insertions(+) create mode 100644 tensorflow/contrib/lite/python/convert_savedmodel.py create mode 100644 tensorflow/contrib/lite/python/convert_savedmodel_test.py diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index d6f3921..ce1a81d 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -84,6 +84,41 @@ py_test( ], ) +py_binary( + name = "convert_savedmodel", + srcs = ["convert_savedmodel.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + "//tensorflow/contrib/saved_model:saved_model_py", + "//tensorflow/python:graph_util", + "//tensorflow/python/tools:freeze_graph_lib", + ], +) + +py_test( + name = "convert_savedmodel_test", + srcs = ["convert_savedmodel_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":convert_savedmodel", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/saved_model", + ], +) + +# Transitive dependencies of this target will be included in the pip package. 
+py_library( + name = "tf_lite_py_pip", + deps = [ + ":convert_savedmodel", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/python/convert_savedmodel.py b/tensorflow/contrib/lite/python/convert_savedmodel.py new file mode 100644 index 0000000..d39e1a1 --- /dev/null +++ b/tensorflow/contrib/lite/python/convert_savedmodel.py @@ -0,0 +1,261 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""TensorFlow Lite flatbuffer generation from saved_models. 
+ +Example: + +bazel run third_party/tensorflow/contrib/lite/python:convert_savedmodel -- \ + --saved_model_dir=/tmp/test_saved_model/1519865537 \ + --output_tflite=/tmp/test.lite + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.saved_model.python.saved_model import reader +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.core.framework import types_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework import ops +from tensorflow.python.platform import app +from tensorflow.python.platform import flags +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants + +flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.") +flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.") +flags.DEFINE_string("output_arrays", None, + "List of output tensor names, the default value is None, " + "which means the conversion will keep all outputs.") +flags.DEFINE_integer("batch_size", 1, + "If input tensor shape has None at first dimension, " + "e.g. (None,224,224,3), replace None with batch_size.") +flags.DEFINE_string("tag_set", tag_constants.SERVING, + "Group of tag(s) of the MetaGraphDef in the saved_model, " + "in string format, separated by ','. 
For tag-set contains " + "multiple tags, all tags must be passed in.") +flags.DEFINE_string("signature_key", + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + "This is signature key to extract inputs, outputs.") + + +def log_tensor_details(tensor_info): + """Log tensor details: name, shape, and type.""" + for key in tensor_info: + val = tensor_info[key] + dtype = types_pb2.DataType.Name(val.dtype) + if val.tensor_shape.unknown_rank: + shape = "unknown_rank" + else: + dims = [str(dim.size) for dim in val.tensor_shape.dim] + shape = "({})".format(", ".join(dims)) + + logging.info("Tensor's key in savedmodel's tensor_map: %s", key) + logging.info(" tensor name: %s, shape: %s, type: %s", val.name, shape, + dtype) + + +def get_meta_graph_def(saved_model_dir, tag_set): + """Validate savedmodel and extract MetaGraphDef. + + Args: + saved_model_dir: Savedmodel path to convert. + tag_set: Set of tag(s) of the MetaGraphDef to load. + + Returns: + The meta_graph_def used for tflite conversion. + + Raises: + ValueError: No valid MetaGraphDef for given tag_set. + """ + saved_model = reader.read_saved_model(saved_model_dir) + tag_sets = [] + result_meta_graph_def = None + for meta_graph_def in saved_model.meta_graphs: + meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) + tag_sets.append(meta_graph_tag_set) + if meta_graph_tag_set == tag_set: + result_meta_graph_def = meta_graph_def + logging.info("The given SavedModel contains the following tags: %s", tag_sets) + if result_meta_graph_def is not None: + return result_meta_graph_def + else: + raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " + "values are '{}'. ".format(tag_set, tag_sets)) + + +def get_signature_def(meta_graph, signature_key): + """Get the signature def from meta_graph with given signature_key. + + Args: + meta_graph: meta_graph_def. + signature_key: signature_def in the meta_graph_def. + + Returns: + The signature_def used for tflite conversion. 
+ + Raises: + ValueError: Given signature_key is not valid for this meta_graph. + """ + signature_def_map = meta_graph.signature_def + signature_def_keys = set(signature_def_map.keys()) + logging.info( + "The given SavedModel MetaGraphDef contains SignatureDefs with the " + "following keys: %s", signature_def_keys) + if signature_key not in signature_def_keys: + raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " + "values are '{}'. ".format(signature_key, + signature_def_keys)) + signature_def = signature_def_utils.get_signature_def_by_key( + meta_graph, signature_key) + return signature_def + + +def get_inputs_outputs(signature_def): + """Get inputs and outputs from signature def. + + Args: + signature_def: signature def in the meta_graph_def for conversion. + + Returns: + The inputs and outputs in the graph for conversion. + """ + inputs_tensor_info = signature_def.inputs + outputs_tensor_info = signature_def.outputs + logging.info("input tensors info: ") + log_tensor_details(inputs_tensor_info) + logging.info("output tensors info: ") + log_tensor_details(outputs_tensor_info) + + def gather_names(tensor_info): + return [tensor_info[key].name for key in tensor_info] + + inputs = gather_names(inputs_tensor_info) + outputs = gather_names(outputs_tensor_info) + return inputs, outputs + + +def convert(saved_model_dir, + output_tflite=None, + output_arrays=None, + tag_set=None, + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + batch_size=1): + """Convert a savedmodel to tflite flatbuffer. + + Args: + saved_model_dir: Saved model directory to convert. + output_tflite: File path to write result flatbuffer. + output_arrays: Comma-separated string of output tensor names. The default, + None, keeps all output tensors. This is also used to filter out tensors + produced by Ops currently not supported in tflite (e.g., Argmax). + tag_set: This is the set of tags to get meta_graph_def in saved_model. 
+ signature_key: This is the signature key to extract inputs, outputs. + batch_size: If input tensor shape has None at first dimension, + e.g. (None,224,224,3), replace None with batch_size. + + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + ValueError: If tag_set does not indicate any meta_graph_def in saved_model, + or signature_key is not in relevant meta_graph_def, + or input shape has None beyond 1st dimension, e.g., (1,None, None, 3), + or given output_arrays are not valid causing empty outputs. + """ + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + + meta_graph = get_meta_graph_def(saved_model_dir, tag_set) + signature_def = get_signature_def(meta_graph, signature_key) + inputs, outputs = get_inputs_outputs(signature_def) + + graph = ops.Graph() + with session.Session(graph=graph) as sess: + + loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) + + in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs] + + # Users can use output_arrays to filter output tensors for conversion. + # If output_arrays is None, we keep all output tensors. In future, we may + # use tflite supported Op list and check whether op is custom Op to + # automatically filter output arrays. + # TODO(zhixianyan): Use tflite supported Op list to filter outputs. 
+ if output_arrays is not None: + output_arrays = output_arrays.split(",") + out_tensors = [ + graph.get_tensor_by_name(output) + for output in outputs + if output.split(":")[0] in output_arrays + ] + else: + out_tensors = [graph.get_tensor_by_name(output) for output in outputs] + + output_names = [node.split(":")[0] for node in outputs] + + if not out_tensors: + raise ValueError( + "No valid output tensors for '{}', possible values are '{}'".format( + output_arrays, output_names)) + + frozen_graph_def = tf_graph_util.convert_variables_to_constants( + sess, graph.as_graph_def(), output_names) + + # Toco requires fully defined tensor shapes. For an input tensor with None + # in its shape, e.g., (None, 224, 224, 3), we replace the first None with + # the given batch size. Shapes with more Nones, e.g. (None, None, None, 3), + # might also be converted this way, but that needs further investigation. + # TODO(zhixianyan): Add support for input tensors with more None in shape. + for i in range(len(in_tensors)): + shape = in_tensors[i].get_shape().as_list() + if shape[0] is None: + shape[0] = batch_size + if None in shape[1:]: + raise ValueError( + "Only support None shape at 1st dim as batch_size. But tensor " + "'{}' 's shape '{}' has None at other dimension. 
".format( + inputs[i], shape)) + in_tensors[i].set_shape(shape) + + result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors) + + if output_tflite is not None: + with gfile.Open(output_tflite, "wb") as f: + f.write(result) + logging.info("Successfully converted to: %s", output_tflite) + + return result + + +def main(_): + convert( + saved_model_dir=flags.FLAGS.saved_model_dir, + output_tflite=flags.FLAGS.output_tflite, + output_arrays=flags.FLAGS.output_arrays, + batch_size=flags.FLAGS.batch_size, + tag_set=set(flags.FLAGS.tag_set.split(",")), + signature_key=flags.FLAGS.signature_key) + + +if __name__ == "__main__": + app.run(main) diff --git a/tensorflow/contrib/lite/python/convert_savedmodel_test.py b/tensorflow/contrib/lite/python/convert_savedmodel_test.py new file mode 100644 index 0000000..70cff9e --- /dev/null +++ b/tensorflow/contrib/lite/python/convert_savedmodel_test.py @@ -0,0 +1,276 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TF Lite SavedModel Conversion test cases. 
+ + - test on generated saved_models from simple graphs (sanity check) + - test mnist savedmodel generated on-the-fly + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from tensorflow.contrib.lite.python import convert_savedmodel +from tensorflow.python import estimator +from tensorflow.python import keras +from tensorflow.python import layers +from tensorflow.python import losses +from tensorflow.python import nn +from tensorflow.python import saved_model +from tensorflow.python import train +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): + + def _createSimpleSavedModel(self, shape): + """Create a simple savedmodel on the fly.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + inputs = {"x": in_tensor} + outputs = {"y": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def testSimpleSavedModel(self): + """Test a simple savedmodel created on the fly.""" + # Create a simple savedmodel + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite + result = convert_savedmodel.convert(saved_model_dir=saved_model_dir) + self.assertTrue(result) + + def testSimpleSavedModelWithNoneBatchSizeInShape(self): + """Test a simple savedmodel, with None in input tensor's shape.""" + saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) 
+ result = convert_savedmodel.convert(saved_model_dir=saved_model_dir) + self.assertTrue(result) + + def testSimpleSavedModelWithMoreNoneInShape(self): + """Test a simple savedmodel, fail as more None in input shape.""" + saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) + # Convert to tflite: this should raise ValueError, as 3rd dim is None. + with self.assertRaises(ValueError): + convert_savedmodel.convert(saved_model_dir=saved_model_dir) + + def testSimpleSavedModelWithWrongSignatureKey(self): + """Test a simple savedmodel, fail as given signature is invalid.""" + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite: this should raise ValueError, as + # signature_key does not exist in the saved_model. + with self.assertRaises(ValueError): + convert_savedmodel.convert( + saved_model_dir=saved_model_dir, signature_key="wrong-key") + + def testSimpleSavedModelWithWrongOutputArray(self): + """Test a simple savedmodel, fail as given output_arrays is invalid.""" + # Create a simple savedmodel + saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) + # Convert to tflite: this should raise ValueError, as + # output_arrays is not valid for the saved_model. 
+ with self.assertRaises(ValueError): + convert_savedmodel.convert( + saved_model_dir=saved_model_dir, output_arrays="wrong-output") + + def testMultipleMetaGraphDef(self): + """Test saved model with multiple MetaGraphDef.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") + builder = saved_model.builder.SavedModelBuilder(saved_model_dir) + with session.Session(graph=ops.Graph()) as sess: + # MetaGraphDef 1 + in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor) + sig_input_tensor_signature = {"x": sig_input_tensor} + sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor) + sig_output_tensor_signature = {"y": sig_output_tensor} + predict_signature_def = ( + saved_model.signature_def_utils.build_signature_def( + sig_input_tensor_signature, sig_output_tensor_signature, + saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_map = { + saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + predict_signature_def + } + builder.add_meta_graph_and_variables( + sess, + tags=[saved_model.tag_constants.SERVING, "additional_test_tag"], + signature_def_map=signature_def_map) + # MetaGraphDef 2 + builder.add_meta_graph(tags=["tflite"]) + builder.save(True) + + # Convert to tflite + convert_savedmodel.convert( + saved_model_dir=saved_model_dir, + tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) + + +class Model(keras.Model): + """Model to recognize digits in the MNIST dataset. + + Train and export savedmodel, used for testOnflyTrainMnistSavedModel + + Network structure is equivalent to: + https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py + and + https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py + + But written as a ops.keras.Model using the layers API. 
+ """ + + def __init__(self, data_format): + """Creates a model for classifying a hand-written digit. + + Args: + data_format: Either "channels_first" or "channels_last". + "channels_first" is typically faster on GPUs while "channels_last" is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Model, self).__init__() + self._input_shape = [-1, 28, 28, 1] + + self.conv1 = layers.Conv2D( + 32, 5, padding="same", data_format=data_format, activation=nn.relu) + self.conv2 = layers.Conv2D( + 64, 5, padding="same", data_format=data_format, activation=nn.relu) + self.fc1 = layers.Dense(1024, activation=nn.relu) + self.fc2 = layers.Dense(10) + self.dropout = layers.Dropout(0.4) + self.max_pool2d = layers.MaxPooling2D( + (2, 2), (2, 2), padding="same", data_format=data_format) + + def __call__(self, inputs, training): + """Add operations to classify a batch of input images. + + Args: + inputs: A Tensor representing a batch of input images. + training: A boolean. Set to True to add operations required only when + training the classifier. + + Returns: + A logits Tensor with shape [batch_size, 10]. 
+ """ + y = array_ops.reshape(inputs, self._input_shape) + y = self.conv1(y) + y = self.max_pool2d(y) + y = self.conv2(y) + y = self.max_pool2d(y) + y = layers.flatten(y) + y = self.fc1(y) + y = self.dropout(y, training=training) + return self.fc2(y) + + +def model_fn(features, labels, mode, params): + """The model_fn argument for creating an Estimator.""" + model = Model(params["data_format"]) + image = features + if isinstance(image, dict): + image = features["image"] + + if mode == estimator.ModeKeys.PREDICT: + logits = model(image, training=False) + predictions = { + "classes": math_ops.argmax(logits, axis=1), + "probabilities": nn.softmax(logits), + } + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.PREDICT, + predictions=predictions, + export_outputs={ + "classify": estimator.export.PredictOutput(predictions) + }) + + elif mode == estimator.ModeKeys.TRAIN: + optimizer = train.AdamOptimizer(learning_rate=1e-4) + + logits = model(image, training=True) + loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.TRAIN, + loss=loss, + train_op=optimizer.minimize(loss, train.get_or_create_global_step())) + + elif mode == estimator.ModeKeys.EVAL: + logits = model(image, training=False) + loss = losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) + return estimator.EstimatorSpec( + mode=estimator.ModeKeys.EVAL, + loss=loss, + eval_metric_ops={ + "accuracy": + ops.metrics.accuracy( + labels=labels, predictions=math_ops.argmax(logits, axis=1)), + }) + + +def dummy_input_fn(): + image = random_ops.random_uniform([100, 784]) + labels = random_ops.random_uniform([100, 1], maxval=9, dtype=dtypes.int32) + return image, labels + + +class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): + + def testTrainedMnistSavedModel(self): + """Test mnist savedmodel, trained with dummy data and small steps.""" + # Build classifier + classifier = estimator.Estimator( + 
model_fn=model_fn, + params={ + "data_format": "channels_last" # tflite format + }) + + # Train and pred for serving + classifier.train(input_fn=dummy_input_fn, steps=2) + image = array_ops.placeholder(dtypes.float32, [None, 28, 28]) + pred_input_fn = estimator.export.build_raw_serving_input_receiver_fn({ + "image": image, + }) + + # Export savedmodel + saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel") + classifier.export_savedmodel(saved_model_dir, pred_input_fn) + + # Convert to tflite and test output + saved_model_name = os.listdir(saved_model_dir)[0] + saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) + output_tflite = os.path.join(saved_model_dir, + saved_model_final_dir + ".lite") + # TODO(zhixianyan): no need to limit output_arrays to `Softmax' + # once b/74205001 fixed and argmax implemented in tflite. + result = convert_savedmodel.convert( + saved_model_dir=saved_model_final_dir, + output_arrays="Softmax", + output_tflite=output_tflite) + + self.assertTrue(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d55a883..8a80d64 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -158,6 +158,7 @@ sh_binary( "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", "//tensorflow/contrib/lite/python:interpreter_test_data", + "//tensorflow/contrib/lite/python:tf_lite_py_pip", "//tensorflow/contrib/lite/toco:toco", "//tensorflow/contrib/lite/toco/python:toco_wrapper", "//tensorflow/contrib/lite/toco/python:toco_from_protos", -- 2.7.4