Invalidate("Failed to allocate tensors");
return;
}
+ ResetLSTMStateTensors();
must_allocate_tensors_ = false;
}
}
return success;
}
+void TfLiteDriver::ResetLSTMStateTensors() {
+ // This is a workaround for initializing state tensors for LSTM.
+ // TODO(ycling): Refactoring and find a better way to initialize state
+ // tensors. Maybe write the reset instructions into the test data.
+ for (auto node_index : interpreter_->execution_plan()) {
+ const auto& node_and_reg = interpreter_->node_and_registration(node_index);
+ const auto& node = node_and_reg->first;
+ const auto& registration = node_and_reg->second;
+ if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
+ node.outputs->size >= 2) {
+ // The first 2 outputs of LSTM are state tensors.
+ for (int i = 0; i < 2; ++i) {
+ int node_index = node.outputs->data[i];
+ ResetTensor(node_index);
+ }
+ }
+ }
+}
+
} // namespace testing
} // namespace tflite