Generate example for basic lstm cell in tflite
authorZhixian Yan <zhixianyan@google.com>
Thu, 22 Feb 2018 20:26:22 +0000 (12:26 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 20:34:23 +0000 (12:34 -0800)
PiperOrigin-RevId: 186656247

tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/testing/parse_testdata.cc

index 1ccf7d4..b5960d6 100644 (file)
@@ -34,6 +34,7 @@ gen_zipped_test_files(
         "l2norm.zip",
         "local_response_norm.zip",
         "log_softmax.zip",
+        "lstm.zip",
         "max_pool.zip",
         "mean.zip",
         "mul.zip",
index 2cbac7c..2481add 100644 (file)
@@ -46,6 +46,7 @@ from google.protobuf import text_format
 # TODO(aselle): switch to TensorFlow's resource_loader
 from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
 from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.ops import rnn
 
 parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
 parser.add_argument("output_path",
@@ -108,11 +109,23 @@ KNOWN_BUGS = {
 }
 
 
+class ExtraTocoOptions(object):
+  """Additonal toco options besides input, output, shape."""
+
+  def __init__(self):
+    # Whether to ignore control dependency nodes.
+    self.drop_control_dependency = False
+    # Allow custom ops in the toco conversion.
+    self.allow_custom_ops = False
+    # Rnn states that are used to support rnn / lstm cells.
+    self.rnn_states = None
+
+
 def toco_options(data_types,
                  input_arrays,
                  output_arrays,
                  shapes,
-                 drop_control_dependency):
+                 extra_toco_options=ExtraTocoOptions()):
   """Create TOCO options to process a model.
 
   Args:
@@ -120,8 +133,7 @@ def toco_options(data_types,
     input_arrays: names of the input tensors
     output_arrays: name of the output tensors
     shapes: shapes of the input tensors
-    drop_control_dependency: whether to ignore control dependency nodes.
-
+    extra_toco_options: additional toco options
   Returns:
     the options in a string.
   """
@@ -137,37 +149,15 @@ def toco_options(data_types,
        " --input_arrays=%s" % ",".join(input_arrays) +
        " --input_shapes=%s" % shape_str +
        " --output_arrays=%s" % ",".join(output_arrays))
-  if drop_control_dependency:
+  if extra_toco_options.drop_control_dependency:
     s += " --drop_control_dependency"
+  if extra_toco_options.allow_custom_ops:
+    s += " --allow_custom_ops"
+  if extra_toco_options.rnn_states:
+    s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
   return s
 
 
-def write_toco_options(filename,
-                       data_types,
-                       input_arrays,
-                       output_arrays,
-                       shapes,
-                       drop_control_dependency=False):
-  """Create TOCO options to process a model.
-
-  Args:
-    filename: Filename to write the options to.
-    data_types: input and inference types used by TOCO.
-    input_arrays: names of the input tensors
-    output_arrays: names of the output tensors
-    shapes: shapes of the input tensors
-    drop_control_dependency: whether to ignore control dependency nodes.
-  """
-  with open(filename, "w") as fp:
-    fp.write(
-        toco_options(
-            data_types=data_types,
-            input_arrays=input_arrays,
-            output_arrays=output_arrays,
-            shapes=shapes,
-            drop_control_dependency=drop_control_dependency))
-
-
 def write_examples(fp, examples):
   """Given a list `examples`, write a text format representation.
 
@@ -285,12 +275,14 @@ def make_control_dep_tests(zip_path):
     return [input_values], sess.run(
         outputs, feed_dict=dict(zip(inputs, [input_values])))
 
+  extra_toco_options = ExtraTocoOptions()
+  extra_toco_options.drop_control_dependency = True
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
-                    drop_control_dependency=True)
+                    extra_toco_options)
 
 
 def toco_convert(graph_def_str, input_tensors, output_tensors,
-                 drop_control_dependency=False):
+                 extra_toco_options):
   """Convert a model's graph def into a tflite model.
 
   NOTE: this currently shells out to the toco binary, but we would like
@@ -298,9 +290,9 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
 
   Args:
     graph_def_str: Graph def proto in serialized string format.
-    input_tensors: List of input tensor tuples `(name, shape, type)`
-    output_tensors: List of output tensors (names)
-    drop_control_dependency: whether to ignore control dependency nodes.
+    input_tensors: List of input tensor tuples `(name, shape, type)`.
+    output_tensors: List of output tensors (names).
+    extra_toco_options: Additional toco options.
 
   Returns:
     output tflite model, log_txt from conversion
@@ -312,7 +304,7 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
       input_arrays=[x[0] for x in input_tensors],
       shapes=[x[1] for x in input_tensors],
       output_arrays=output_tensors,
-      drop_control_dependency=drop_control_dependency)
+      extra_toco_options=extra_toco_options)
 
   with tempfile.NamedTemporaryFile() as graphdef_file, \
        tempfile.NamedTemporaryFile() as output_file, \
@@ -341,7 +333,8 @@ def make_zip_of_tests(zip_path,
                       test_parameters,
                       make_graph,
                       make_test_inputs,
-                      drop_control_dependency=False):
+                      extra_toco_options=ExtraTocoOptions(),
+                      use_frozen_graph=False):
   """Helper to make a zip file of a bunch of TensorFlow models.
 
   This does a cartestian product of the dictionary of test_parameters and
@@ -359,7 +352,9 @@ def make_zip_of_tests(zip_path,
       `[input1, input2, ...], [output1, output2, ...]`
     make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
       `output_tensors` and returns tuple `(input_values, output_values)`.
-    drop_control_dependency: whether to ignore control dependency nodes.
+    extra_toco_options: Additional toco options.
+    use_frozen_graph: Whether or not freeze graph before toco converter.
+
   Raises:
     RuntimeError: if there are toco errors that can't be ignored.
   """
@@ -419,21 +414,25 @@ def make_zip_of_tests(zip_path,
           return None, report
         report["toco"] = report_lib.FAILED
         report["tf"] = report_lib.SUCCESS
-
         # Convert graph to toco
+        input_tensors = [(input_tensor.name.split(":")[0],
+                          input_tensor.get_shape(), input_tensor.dtype)
+                         for input_tensor in inputs]
+        output_tensors = [normalize_output_name(out.name) for out in outputs]
+        graph_def = freeze_graph(
+            sess,
+            tf.global_variables() + inputs +
+            outputs) if use_frozen_graph else sess.graph_def
         tflite_model_binary, toco_log = toco_convert(
-            sess.graph_def.SerializeToString(),
-            [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
-              input_tensor.dtype) for input_tensor in inputs],
-            [normalize_output_name(out.name) for out in outputs],
-            drop_control_dependency)
+            graph_def.SerializeToString(), input_tensors, output_tensors,
+            extra_toco_options)
         report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
                           else report_lib.FAILED)
         report["toco_log"] = toco_log
 
         if FLAGS.save_graphdefs:
           archive.writestr(label + ".pb",
-                           text_format.MessageToString(sess.graph_def),
+                           text_format.MessageToString(graph_def),
                            zipfile.ZIP_DEFLATED)
 
         if tflite_model_binary:
@@ -1761,6 +1760,84 @@ def make_strided_slice_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_lstm_tests(zip_path):
+  """Make a set of tests to do basic Lstm cell."""
+
+  test_parameters = [
+      {
+          "dtype": [tf.float32],
+          "num_batchs": [1],
+          "time_step_size": [1],
+          "input_vec_size": [3],
+          "num_cells": [4],
+      },
+  ]
+
+  def build_graph(parameters):
+    """Build a simple graph with BasicLSTMCell."""
+
+    num_batchs = parameters["num_batchs"]
+    time_step_size = parameters["time_step_size"]
+    input_vec_size = parameters["input_vec_size"]
+    num_cells = parameters["num_cells"]
+    inputs_after_split = []
+    for i in xrange(time_step_size):
+      one_timestamp_input = tf.placeholder(
+          dtype=parameters["dtype"],
+          name="split_{}".format(i),
+          shape=[num_batchs, input_vec_size])
+      inputs_after_split.append(one_timestamp_input)
+    # Currently lstm identifier has a few limitations: only supports
+    # forget_bias == 0, inner state activiation == tanh.
+    # TODO(zhixianyan): Add another test with forget_bias == 1.
+    # TODO(zhixianyan): Add another test with relu as activation.
+    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
+        num_cells, forget_bias=0.0, state_is_tuple=True)
+    cell_outputs, _ = rnn.static_rnn(
+        lstm_cell, inputs_after_split, dtype=tf.float32)
+    out = cell_outputs[-1]
+    return inputs_after_split, [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    """Feed inputs, assign vairables, and freeze graph."""
+
+    with tf.variable_scope("", reuse=True):
+      kernel = tf.get_variable("rnn/basic_lstm_cell/kernel")
+      bias = tf.get_variable("rnn/basic_lstm_cell/bias")
+      kernel_values = create_tensor_data(
+          parameters["dtype"], [kernel.shape[0], kernel.shape[1]], -1, 1)
+      bias_values = create_tensor_data(parameters["dtype"], [bias.shape[0]], 0,
+                                       1)
+      sess.run(tf.group(kernel.assign(kernel_values), bias.assign(bias_values)))
+
+    num_batchs = parameters["num_batchs"]
+    time_step_size = parameters["time_step_size"]
+    input_vec_size = parameters["input_vec_size"]
+    input_values = []
+    for _ in xrange(time_step_size):
+      tensor_data = create_tensor_data(parameters["dtype"],
+                                       [num_batchs, input_vec_size], 0, 1)
+      input_values.append(tensor_data)
+    out = sess.run(outputs, feed_dict=dict(zip(inputs, input_values)))
+    return input_values, out
+
+  # TODO(zhixianyan): Automatically generate rnn_states for lstm cell.
+  extra_toco_options = ExtraTocoOptions()
+  extra_toco_options.rnn_states = (
+      "{state_array:rnn/BasicLSTMCellZeroState/zeros,"
+      "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4},"
+      "{state_array:rnn/BasicLSTMCellZeroState/zeros_1,"
+      "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}")
+
+  make_zip_of_tests(
+      zip_path,
+      test_parameters,
+      build_graph,
+      build_inputs,
+      extra_toco_options,
+      use_frozen_graph=True)
+
+
 def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
   """Given an input perform a sequence of TensorFlow ops to produce l2pool."""
   return tf.sqrt(tf.nn.avg_pool(
@@ -1850,6 +1927,7 @@ def main(unused_args):
         "strided_slice.zip": make_strided_slice_tests,
         "exp.zip": make_exp_tests,
         "log_softmax.zip": make_log_softmax_tests,
+        "lstm.zip": make_lstm_tests,
     }
     out = FLAGS.zip_to_output
     bin_path = FLAGS.toco
index 89a5841..976363f 100644 (file)
@@ -266,6 +266,7 @@ INSTANTIATE_TESTS(sub)
 INSTANTIATE_TESTS(split)
 INSTANTIATE_TESTS(div)
 INSTANTIATE_TESTS(transpose)
+INSTANTIATE_TESTS(lstm)
 INSTANTIATE_TESTS(mean)
 INSTANTIATE_TESTS(squeeze)
 INSTANTIATE_TESTS(strided_slice)
index c8f2e49..389688d 100644 (file)
@@ -192,27 +192,25 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
   int model_outputs = interpreter->outputs().size();
   TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
   for (size_t i = 0; i < interpreter->outputs().size(); i++) {
+    bool tensors_differ = false;
     int output_index = interpreter->outputs()[i];
     if (const float* data = interpreter->typed_tensor<float>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
         float computed = data[idx];
         float reference = example.outputs[0].flat_data[idx];
         float diff = std::abs(computed - reference);
-        bool error_is_large = false;
         // For very small numbers, try absolute error, otherwise go with
         // relative.
-        if (std::abs(reference) < kRelativeThreshold) {
-          error_is_large = (diff > kAbsoluteThreshold);
-        } else {
-          error_is_large = (diff > kRelativeThreshold * std::abs(reference));
-        }
-        if (error_is_large) {
+        bool local_tensors_differ =
+            std::abs(reference) < kRelativeThreshold
+                ? diff > kAbsoluteThreshold
+                : diff > kRelativeThreshold * std::abs(reference);
+        if (local_tensors_differ) {
           fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
                   i, idx, data[idx], reference);
-          return kTfLiteError;
+          tensors_differ = local_tensors_differ;
         }
       }
-      fprintf(stderr, "\n");
     } else if (const int32_t* data =
                    interpreter->typed_tensor<int32_t>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
@@ -221,10 +219,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
         if (std::abs(computed - reference) > 0) {
           fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
                   i, idx, computed, reference);
-          return kTfLiteError;
+          tensors_differ = true;
         }
       }
-      fprintf(stderr, "\n");
     } else if (const int64_t* data =
                    interpreter->typed_tensor<int64_t>(output_index)) {
       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
@@ -235,14 +232,15 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
                   "output[%zu][%zu] did not match %" PRId64
                   " vs reference %" PRId64 "\n",
                   i, idx, computed, reference);
-          return kTfLiteError;
+          tensors_differ = true;
         }
       }
-      fprintf(stderr, "\n");
     } else {
       fprintf(stderr, "output[%zu] was not float or int data\n", i);
       return kTfLiteError;
     }
+    fprintf(stderr, "\n");
+    if (tensors_differ) return kTfLiteError;
   }
   return kTfLiteOk;
 }