Fix the LSTM test in TFLite.
authorYu-Cheng Ling <ycling@google.com>
Tue, 22 May 2018 23:31:32 +0000 (16:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 23:34:15 +0000 (16:34 -0700)
PiperOrigin-RevId: 197643581

tensorflow/contrib/lite/build_def.bzl
tensorflow/contrib/lite/testing/tflite_driver.cc
tensorflow/contrib/lite/testing/tflite_driver.h

index 9bfc0a0..c8820ab 100644 (file)
@@ -212,12 +212,13 @@ def generated_test_models():
         "global_batch_norm",
         "greater",
         "greater_equal",
-        "l2_pool",
         "l2norm",
+        "l2_pool",
         "less",
         "less_equal",
         "local_response_norm",
         "log_softmax",
+        "lstm",
         "max_pool",
         "maximum",
         "mean",
index 58fe5bd..1f07068 100644 (file)
@@ -143,6 +143,7 @@ void TfLiteDriver::AllocateTensors() {
       Invalidate("Failed to allocate tensors");
       return;
     }
+    ResetLSTMStateTensors();
     must_allocate_tensors_ = false;
   }
 }
@@ -281,5 +282,24 @@ bool TfLiteDriver::CheckResults() {
   return success;
 }
 
+void TfLiteDriver::ResetLSTMStateTensors() {
+  // This is a workaround for initializing state tensors for LSTM.
+  // TODO(ycling): Refactoring and find a better way to initialize state
+  // tensors. Maybe write the reset instructions into the test data.
+  for (auto node_index : interpreter_->execution_plan()) {
+    const auto& node_and_reg = interpreter_->node_and_registration(node_index);
+    const auto& node = node_and_reg->first;
+    const auto& registration = node_and_reg->second;
+    if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
+        node.outputs->size >= 2) {
+      // The first 2 outputs of LSTM are state tensors.
+      for (int i = 0; i < 2; ++i) {
+        int node_index = node.outputs->data[i];
+        ResetTensor(node_index);
+      }
+    }
+  }
+}
+
 }  // namespace testing
 }  // namespace tflite
index 02b7de1..5493ba3 100644 (file)
@@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner {
   string ReadOutput(int id) override { return "no-op"; }
 
  private:
+  void ResetLSTMStateTensors();
+
   class Expectation;
 
   bool use_nnapi_ = false;