Run flatbuffer verifier before reading a TFLITE file into toco.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Mar 2018 22:13:53 +0000 (15:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 22:20:23 +0000 (15:20 -0700)
PiperOrigin-RevId: 189649236

tensorflow/contrib/lite/toco/tflite/import.cc
tensorflow/contrib/lite/toco/tflite/import_test.cc

index e16784f..867395e 100644 (file)
@@ -162,8 +162,19 @@ void ImportIOTensors(const ::tflite::Model& input_model,
   }
 }
 
+namespace {
+bool Verify(const void* buf, size_t len) {
+  ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
+  return ::tflite::VerifyModelBuffer(verifier);
+}
+}  // namespace
+
 std::unique_ptr<Model> Import(const ModelFlags& model_flags,
                               const string& input_file_contents) {
+  if (!Verify(input_file_contents.data(), input_file_contents.size())) {
+    LOG(FATAL) << "Invalid flatbuffer.";
+  }
+
   const ::tflite::Model* input_model =
       ::tflite::GetModel(input_file_contents.data());
 
index f25b170..937a291 100644 (file)
@@ -66,15 +66,43 @@ class ImportTest : public ::testing::Test {
   }
 
   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
-    auto c1 =
-        ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM,
-                                     builder_.CreateString("custom_op_one"));
+    auto c1 = ::tflite::CreateOperatorCode(
+        builder_, ::tflite::BuiltinOperator_MAX_POOL_2D, 0);
     auto c2 = ::tflite::CreateOperatorCode(
         builder_, ::tflite::BuiltinOperator_CONV_2D, 0);
     return builder_.CreateVector(
         std::vector<Offset<::tflite::OperatorCode>>({c1, c2}));
   }
 
+  Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
+    auto is = builder_.CreateVector<int>({0});
+    auto os = builder_.CreateVector<int>({1});
+    auto op = ::tflite::CreateOperator(
+        builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
+        ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
+                                      ::tflite::ActivationFunctionType_NONE)
+            .Union(),
+        /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS);
+
+    return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
+  }
+
+  Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
+      Offset<Vector<Offset<::tflite::Tensor>>> tensors,
+      Offset<Vector<Offset<::tflite::Operator>>> operators,
+      int num_sub_graphs = 1) {
+    std::vector<int32_t> inputs = {0};
+    std::vector<int32_t> outputs = {1};
+    std::vector<Offset<::tflite::SubGraph>> v;
+    for (int i = 0; i < num_sub_graphs; ++i) {
+      v.push_back(::tflite::CreateSubGraph(
+          builder_, tensors, builder_.CreateVector(inputs),
+          builder_.CreateVector(outputs), operators,
+          builder_.CreateString("subgraph")));
+    }
+    return builder_.CreateVector(v);
+  }
+
   // This is a very simplistic model. We are not interested in testing all the
   // details here, since tf.mini's testing framework will be exercising all the
   // conversions multiple times, and the conversion of operators is tested by
@@ -83,14 +111,13 @@ class ImportTest : public ::testing::Test {
     auto buffers = BuildBuffers();
     auto tensors = BuildTensors();
     auto opcodes = BuildOpCodes();
-
-    auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0);
-    std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vector(
-        {subgraph});
-    auto subgraphs = builder_.CreateVector(subgraph_vector);
+    auto operators = BuildOperators();
+    auto subgraphs = BuildSubGraphs(tensors, operators);
     auto s = builder_.CreateString("");
-    builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
-                                          opcodes, subgraphs, s, buffers));
+
+    ::tflite::FinishModelBuffer(
+        builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
+                                        opcodes, subgraphs, s, buffers));
 
     input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
   }
@@ -99,7 +126,6 @@ class ImportTest : public ::testing::Test {
                   builder_.GetSize());
   }
   flatbuffers::FlatBufferBuilder builder_;
-  // const uint8_t* buffer_ = nullptr;
   const ::tflite::Model* input_model_ = nullptr;
 };
 
@@ -116,7 +142,7 @@ TEST_F(ImportTest, LoadOperatorsTable) {
 
   details::OperatorsTable operators;
   details::LoadOperatorsTable(*input_model_, &operators);
-  EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D"));
+  EXPECT_THAT(operators, ElementsAre("MAX_POOL_2D", "CONV_2D"));
 }
 
 TEST_F(ImportTest, Tensors) {
@@ -143,13 +169,17 @@ TEST_F(ImportTest, Tensors) {
   EXPECT_EQ(100, q->zero_point);
 }
 
-TEST_F(ImportTest, NoSubGraphs) {
+TEST_F(ImportTest, MultipleSubGraphs) {
   auto buffers = BuildBuffers();
+  auto tensors = BuildTensors();
   auto opcodes = BuildOpCodes();
-  auto subgraphs = 0;  // no subgraphs in this model
+  auto operators = BuildOperators();
+  auto subgraphs = BuildSubGraphs(tensors, operators, 2);
   auto comment = builder_.CreateString("");
-  builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
-                                        opcodes, subgraphs, comment, buffers));
+  ::tflite::FinishModelBuffer(
+      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
+                                      subgraphs, comment, buffers));
+
   input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
 
   EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),