Check if ops used in the model are supported by op resolver
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Feb 2018 19:24:59 +0000 (11:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 19:28:07 +0000 (11:28 -0800)
PiperOrigin-RevId: 185716870

tensorflow/contrib/lite/tools/BUILD
tensorflow/contrib/lite/tools/verifier.cc
tensorflow/contrib/lite/tools/verifier.h
tensorflow/contrib/lite/tools/verifier_test.cc

index 6786b16..999ccf2 100644 (file)
@@ -112,6 +112,7 @@ cc_test(
     size = "small",
     srcs = ["verifier_test.cc"],
     deps = [
+        ":mutable_op_resolver",
         ":verifier",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:schema_fbs_version",
index 726e2aa..59c7420 100644 (file)
@@ -155,11 +155,11 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
   }
   for (const auto& subgraph : *model.subgraphs()) {
     if (!subgraph->tensors()) {
-      return true;
+      continue;
     }
     for (const auto& tensor : *subgraph->tensors()) {
       if (!tensor->buffer()) {
-        return true;
+        continue;
       }
       if (tensor->buffer() >= model.buffers()->size()) {
         ReportError(error_reporter, "Invalid tensor buffer index: %d",
@@ -187,9 +187,33 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) {
   return true;
 }
 
+bool VerifyOps(const Model& model, const OpResolver& resolver,
+               ErrorReporter* error_reporter) {
+  if (!model.operator_codes()) {
+    return true;
+  }
+  for (const auto& opcode : *model.operator_codes()) {
+    if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
+      if (!resolver.FindOp(opcode->custom_code()->c_str())) {
+        ReportError(error_reporter, "Unsupported custom op: %s",
+                    opcode->custom_code()->c_str());
+        return false;
+      }
+    } else {
+      if (!resolver.FindOp(opcode->builtin_code())) {
+        ReportError(error_reporter, "Unsupported builtin op: %s",
+                    EnumNameBuiltinOperator(opcode->builtin_code()));
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
 }  // namespace
 
-bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) {
+bool Verify(const void* buf, size_t len, const OpResolver& resolver,
+            ErrorReporter* error_reporter) {
   const Model* model = VerifyFlatbufferAndGetModel(buf, len);
   if (model == nullptr) {
     ReportError(error_reporter, "Invalid flatbuffer format");
@@ -202,6 +226,9 @@ bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) {
   if (!VerifyTensors(*model, error_reporter)) {
     return false;
   }
+  if (!VerifyOps(*model, resolver, error_reporter)) {
+    return false;
+  }
   return true;
 }
 }  // namespace tflite
index d2bf3c9..c2ee112 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include <stdio.h>
 
 #include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/model.h"
 
 namespace tflite {
 
@@ -26,7 +27,9 @@ namespace tflite {
 // Currently, it verifies:
 // * The file is following a legit flatbuffer schema.
 // * The model is in supported version.
-bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter);
+// * All ops used in the model are supported by OpResolver.
+bool Verify(const void* buf, size_t len, const OpResolver& resolver,
+            ErrorReporter* error_reporter);
 
 }  // namespace tflite
 
index 87f6854..b3e611f 100644 (file)
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/contrib/lite/error_reporter.h"
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
 #include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
 #include "tensorflow/contrib/lite/tools/verifier.h"
 #include "tensorflow/contrib/lite/version.h"
 #include "tensorflow/core/framework/numeric_types.h"
@@ -40,6 +41,19 @@ class TfLiteFlatbufferModelBuilder {
         CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
   }
 
+  TfLiteFlatbufferModelBuilder(const std::vector<BuiltinOperator>& builtin_ops,
+                               const std::vector<string>& custom_ops) {
+    buffers_.push_back(
+        CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
+
+    for (const auto& iter : builtin_ops) {
+      resolver_.AddBuiltin(iter, &fake_op_);
+    }
+    for (const auto& iter : custom_ops) {
+      resolver_.AddCustom(iter.data(), &fake_op_);
+    }
+  }
+
   void AddTensor(const std::vector<int>& shape, tflite::TensorType type,
                  const std::vector<uint8_t>& buffer, const char* name) {
     int buffer_index = 0;
@@ -79,11 +93,13 @@ class TfLiteFlatbufferModelBuilder {
 
   bool Verify() {
     return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(),
-                          DefaultErrorReporter());
+                          resolver_, DefaultErrorReporter());
   }
 
  private:
   FlatBufferBuilder builder_;
+  MutableOpResolver resolver_;
+  TfLiteRegistration fake_op_;
   std::vector<Offset<Operator>> operators_;
   std::vector<Offset<OperatorCode>> operator_codes_;
   std::vector<Offset<Tensor>> tensors_;
@@ -98,11 +114,11 @@ TEST(VerifyModel, TestEmptyModel) {
   ::tflite::FinishModelBuffer(builder, model);
 
   ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(),
-                     DefaultErrorReporter()));
+                     MutableOpResolver{}, DefaultErrorReporter()));
 }
 
 TEST(VerifyModel, TestSimpleModel) {
-  TfLiteFlatbufferModelBuilder builder;
+  TfLiteFlatbufferModelBuilder builder({}, {"test"});
   builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test");
   builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input");
   builder.AddTensor(
@@ -116,7 +132,8 @@ TEST(VerifyModel, TestSimpleModel) {
 
 TEST(VerifyModel, TestCorruptedData) {
   std::string model = "123";
-  ASSERT_FALSE(Verify(model.data(), model.size(), /*error_reporter=*/nullptr));
+  ASSERT_FALSE(Verify(model.data(), model.size(), MutableOpResolver{},
+                      /*error_reporter=*/nullptr));
 }
 
 TEST(VerifyModel, TestUnsupportedVersion) {
@@ -125,7 +142,7 @@ TEST(VerifyModel, TestUnsupportedVersion) {
                            /*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
   ::tflite::FinishModelBuffer(builder, model);
   ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
-                      DefaultErrorReporter()));
+                      MutableOpResolver{}, DefaultErrorReporter()));
 }
 
 TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
@@ -140,7 +157,7 @@ TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
   for (int i = 0; i < model_content.size(); i++) {
     model_content[i] = (model_content[i] + 137) % 255;
     EXPECT_FALSE(Verify(model_content.data(), model_content.size(),
-                        DefaultErrorReporter()))
+                        MutableOpResolver{}, DefaultErrorReporter()))
         << "Fail at position: " << i;
   }
 }
@@ -188,7 +205,7 @@ TEST(VerifyModel, TensorBufferIsNotValid) {
 
   ::tflite::FinishModelBuffer(builder, model);
   ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
-                      DefaultErrorReporter()));
+                      MutableOpResolver{}, DefaultErrorReporter()));
 }
 
 TEST(VerifyModel, StringTensorHasInvalidNumString) {
@@ -229,6 +246,37 @@ TEST(VerifyModel, StringTensorIsLargerThanRequired) {
   ASSERT_FALSE(builder.Verify());
 }
 
+TEST(VerifyModel, AllOpsAreSupported) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"CustomOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr);
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "CustomOp");
+  builder.FinishModel({}, {});
+  ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, UseUnsupportedBuiltinOps) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_SUB}, {"CustomOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr);
+  builder.FinishModel({}, {});
+  ASSERT_FALSE(builder.Verify());
+}
+
+TEST(VerifyModel, UseUnsupportedCustomOps) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"NewOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "Not supported");
+  builder.FinishModel({}, {});
+  ASSERT_FALSE(builder.Verify());
+}
+
 // TODO(yichengfan): make up malicious files to test with.
 
 }  // namespace tflite