Improve flatbuffer verification.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Mar 2018 00:23:20 +0000 (17:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 00:27:51 +0000 (17:27 -0700)
PiperOrigin-RevId: 189668634

tensorflow/contrib/lite/toco/tflite/BUILD
tensorflow/contrib/lite/toco/tflite/import.cc
tensorflow/contrib/lite/toco/tflite/import_test.cc
tensorflow/contrib/lite/tools/verifier.cc
tensorflow/contrib/lite/tools/verifier.h
tensorflow/contrib/lite/tools/verifier_test.cc

index a2b8145..9d3e1da 100644 (file)
@@ -115,9 +115,11 @@ cc_library(
     deps = [
         ":operator",
         ":types",
+        "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite/schema:schema_fbs",
         "//tensorflow/contrib/lite/toco:model",
         "//tensorflow/contrib/lite/toco:tooling_util",
+        "//tensorflow/contrib/lite/tools:verifier",
         "@flatbuffers",
     ],
 )
index 867395e..c0e7ab2 100644 (file)
@@ -15,10 +15,12 @@ limitations under the License.
 #include "tensorflow/contrib/lite/toco/tflite/import.h"
 
 #include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/model.h"
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
 #include "tensorflow/contrib/lite/toco/tflite/types.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/tools/verifier.h"
 
 namespace toco {
 
@@ -171,10 +173,11 @@ bool Verify(const void* buf, size_t len) {
 
 std::unique_ptr<Model> Import(const ModelFlags& model_flags,
                               const string& input_file_contents) {
-  if (!Verify(input_file_contents.data(), input_file_contents.size())) {
+  ::tflite::AlwaysTrueResolver r;
+  if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(),
+                        r, ::tflite::DefaultErrorReporter())) {
     LOG(FATAL) << "Invalid flatbuffer.";
   }
-
   const ::tflite::Model* input_model =
       ::tflite::GetModel(input_file_contents.data());
 
index 937a291..edd22f7 100644 (file)
@@ -36,12 +36,13 @@ class ImportTest : public ::testing::Test {
     return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
                                  sizeof(T) * data.size());
   }
+
   Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
     auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
-    auto buf1 =
-        ::tflite::CreateBuffer(builder_, CreateDataVector<float>({1.0f, 2.0f}));
+    auto buf1 = ::tflite::CreateBuffer(
+        builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
     auto buf2 =
-        ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f}));
+        ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
     return builder_.CreateVector(
         std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
   }
@@ -53,10 +54,10 @@ class ImportTest : public ::testing::Test {
         /*max=*/builder_.CreateVector<float>({0.2f}),
         /*scale=*/builder_.CreateVector<float>({0.3f}),
         /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
-    auto t1 = ::tflite::CreateTensor(builder_,
-                                     builder_.CreateVector<int>({1, 2, 3, 4}),
-                                     ::tflite::TensorType_FLOAT32, 1,
-                                     builder_.CreateString("tensor_one"), q);
+    auto t1 =
+        ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
+                               ::tflite::TensorType_FLOAT32, 1,
+                               builder_.CreateString("tensor_one"), q);
     auto t2 =
         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
                                ::tflite::TensorType_FLOAT32, 2,
@@ -65,18 +66,26 @@ class ImportTest : public ::testing::Test {
         std::vector<Offset<::tflite::Tensor>>({t1, t2}));
   }
 
+  Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
+      std::initializer_list<::tflite::BuiltinOperator> op_codes) {
+    std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
+    for (auto op : op_codes) {
+      op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
+    }
+    return builder_.CreateVector(op_codes_vector);
+  }
+
   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
-    auto c1 = ::tflite::CreateOperatorCode(
-        builder_, ::tflite::BuiltinOperator_MAX_POOL_2D, 0);
-    auto c2 = ::tflite::CreateOperatorCode(
-        builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
-    return builder_.CreateVector(
-        std::vector<Offset<::tflite::OperatorCode>>({c1, c2}));
+    return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
+                         ::tflite::BuiltinOperator_CONV_2D});
   }
 
-  Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
-    auto is = builder_.CreateVector<int>({0});
-    auto os = builder_.CreateVector<int>({1});
+  Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
+      std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
+    auto is = builder_.CreateVector<int>(inputs);
+    if (inputs.size() == 0) is = 0;
+    auto os = builder_.CreateVector<int>(outputs);
+    if (outputs.size() == 0) os = 0;
     auto op = ::tflite::CreateOperator(
         builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
         ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
@@ -87,6 +96,10 @@ class ImportTest : public ::testing::Test {
     return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
   }
 
+  Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
+    return BuildOperators({0}, {1});
+  }
+
   Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
       Offset<Vector<Offset<::tflite::Tensor>>> tensors,
       Offset<Vector<Offset<::tflite::Operator>>> operators,
@@ -154,9 +167,9 @@ TEST_F(ImportTest, Tensors) {
   Array& a1 = model->GetArray("tensor_one");
   EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
   EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
-              ElementsAre(1.0f, 2.0f));
+              ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
   ASSERT_TRUE(a1.has_shape());
-  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4));
+  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));
 
   const auto& mm = a1.minmax;
   ASSERT_TRUE(mm.get());
@@ -169,6 +182,63 @@ TEST_F(ImportTest, Tensors) {
   EXPECT_EQ(100, q->zero_point);
 }
 
+TEST_F(ImportTest, NoBuffers) {
+  auto buffers = 0;
+  auto tensors = BuildTensors();
+  auto opcodes = BuildOpCodes();
+  auto operators = BuildOperators();
+  auto subgraphs = BuildSubGraphs(tensors, operators);
+  auto comment = builder_.CreateString("");
+  ::tflite::FinishModelBuffer(
+      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+                                      subgraphs, comment, buffers));
+  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
+               "Missing 'buffers' section.");
+}
+
+TEST_F(ImportTest, NoInputs) {
+  auto buffers = BuildBuffers();
+  auto tensors = BuildTensors();
+  auto opcodes = BuildOpCodes();
+  auto operators = BuildOperators({}, {1});
+  auto subgraphs = BuildSubGraphs(tensors, operators);
+  auto comment = builder_.CreateString("");
+  ::tflite::FinishModelBuffer(
+      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+                                      subgraphs, comment, buffers));
+  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
+               "Missing 'inputs' for operator.");
+}
+
+TEST_F(ImportTest, NoOutputs) {
+  auto buffers = BuildBuffers();
+  auto tensors = BuildTensors();
+  auto opcodes = BuildOpCodes();
+  auto operators = BuildOperators({0}, {});
+  auto subgraphs = BuildSubGraphs(tensors, operators);
+  auto comment = builder_.CreateString("");
+  ::tflite::FinishModelBuffer(
+      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+                                      subgraphs, comment, buffers));
+  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
+               "Missing 'outputs' for operator.");
+}
+
+TEST_F(ImportTest, InvalidOpCode) {
+  auto buffers = BuildBuffers();
+  auto tensors = BuildTensors();
+  auto opcodes = BuildOpCodes({static_cast<::tflite::BuiltinOperator>(-1),
+                               ::tflite::BuiltinOperator_CONV_2D});
+  auto operators = BuildOperators();
+  auto subgraphs = BuildSubGraphs(tensors, operators);
+  auto comment = builder_.CreateString("");
+  ::tflite::FinishModelBuffer(
+      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+                                      subgraphs, comment, buffers));
+  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
+               "Operator id '-1' is out of range.");
+}
+
 TEST_F(ImportTest, MultipleSubGraphs) {
   auto buffers = BuildBuffers();
   auto tensors = BuildTensors();
index 59c7420..8818a7d 100644 (file)
@@ -148,11 +148,52 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer,
   // TODO(yichengfan): verify quantized tensors.
 }
 
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+bool VerifyOperators(const Vector<Offset<Operator>>& operators,
+                     ErrorReporter* error_reporter) {
+  for (const auto& op : operators) {
+    if (!op->inputs()) {
+      ReportError(error_reporter, "Missing 'inputs' for operator.");
+      return false;
+    }
+    if (!op->outputs()) {
+      ReportError(error_reporter, "Missing 'outputs' for operator.");
+      return false;
+    }
+  }
+  return true;
+}
+
+bool VerifySubGraphs(const Model& model, ErrorReporter* error_reporter) {
+  if (!model.subgraphs()) {
+    ReportError(error_reporter, "Missing 'subgraphs' section.");
+    return false;
+  }
+  for (const auto& subgraph : *model.subgraphs()) {
+    if (!subgraph->operators()) {
+      ReportError(error_reporter, "Missing 'operators' section in subgraph.");
+      return false;
+    }
+
+    if (!VerifyOperators(*subgraph->operators(), error_reporter)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 // Verifies tensors have valid properties and legit buffer if set.
 bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
   if (!model.subgraphs()) {
     return true;
   }
+  if (!model.buffers()) {
+    ReportError(error_reporter, "Missing 'buffers' section.");
+    return false;
+  }
+
   for (const auto& subgraph : *model.subgraphs()) {
     if (!subgraph->tensors()) {
       continue;
@@ -167,19 +208,23 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
         return false;
       }
       auto* buffer = model.buffers()->Get(tensor->buffer());
-      if (!buffer || !buffer->data()) {
+      if (!buffer) {
         ReportError(error_reporter, "Tensor buffer %d not set",
                     tensor->buffer());
         return false;
       }
 
-      if (tensor->type() == TensorType_STRING) {
-        if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
-          return false;
-        }
-      } else {
-        if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
-          return false;
+      // Many transient tensors don't have data in the flatbuffer. Their
+      // buffers will be allocated by the interpreter at run-time.
+      if (buffer->data()) {
+        if (tensor->type() == TensorType_STRING) {
+          if (!VerifyStringTensorBuffer(*buffer, error_reporter)) {
+            return false;
+          }
+        } else {
+          if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) {
+            return false;
+          }
         }
       }
     }
@@ -193,6 +238,13 @@ bool VerifyOps(const Model& model, const OpResolver& resolver,
     return true;
   }
   for (const auto& opcode : *model.operator_codes()) {
+    if (opcode->builtin_code() < BuiltinOperator_MIN ||
+        opcode->builtin_code() > BuiltinOperator_MAX) {
+      ReportError(error_reporter, "Operator id '%d' is out of range.",
+                  opcode->builtin_code());
+      return false;
+    }
+
     if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
       if (!resolver.FindOp(opcode->custom_code()->c_str())) {
         ReportError(error_reporter, "Unsupported custom op: %s",
@@ -223,6 +275,9 @@ bool Verify(const void* buf, size_t len, const OpResolver& resolver,
     ReportError(error_reporter, "Invalid model version %d", model->version());
     return false;
   }
+  if (!VerifySubGraphs(*model, error_reporter)) {
+    return false;
+  }
   if (!VerifyTensors(*model, error_reporter)) {
     return false;
   }
index c2ee112..b7ce4e8 100644 (file)
@@ -23,6 +23,21 @@ limitations under the License.
 
 namespace tflite {
 
+class AlwaysTrueResolver : public OpResolver {
+ public:
+  AlwaysTrueResolver() {}
+  TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override {
+    static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
+                                                   nullptr};
+    return &null_registration;
+  }
+  TfLiteRegistration* FindOp(const char* op) const override {
+    static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr,
+                                                   nullptr};
+    return &null_registration;
+  }
+};
+
 // Verifies the integrity of a Tensorflow Lite flatbuffer model file.
 // Currently, it verifies:
 // * The file is following a legit flatbuffer schema.
index b3e611f..03b93af 100644 (file)
@@ -113,8 +113,8 @@ TEST(VerifyModel, TestEmptyModel) {
                            /*description=*/0, /*buffers=*/0);
   ::tflite::FinishModelBuffer(builder, model);
 
-  ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(),
-                     MutableOpResolver{}, DefaultErrorReporter()));
+  ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
+                      MutableOpResolver{}, DefaultErrorReporter()));
 }
 
 TEST(VerifyModel, TestSimpleModel) {