}
for (const auto& subgraph : *model.subgraphs()) {
if (!subgraph->tensors()) {
- return true;
+ continue;
}
for (const auto& tensor : *subgraph->tensors()) {
if (!tensor->buffer()) {
- return true;
+ continue;
}
if (tensor->buffer() >= model.buffers()->size()) {
ReportError(error_reporter, "Invalid tensor buffer index: %d",
return true;
}
+// Verifies that every operator code declared by the model can be resolved by
+// `resolver`: custom ops are looked up by their custom_code name, builtin ops
+// by their enum value. Returns true if all ops resolve (or the model declares
+// none); reports an error and returns false at the first unsupported op.
+bool VerifyOps(const Model& model, const OpResolver& resolver,
+               ErrorReporter* error_reporter) {
+  // A model with no operator-code table trivially passes.
+  if (!model.operator_codes()) {
+    return true;
+  }
+  for (const auto& opcode : *model.operator_codes()) {
+    if (opcode->builtin_code() == BuiltinOperator_CUSTOM) {
+      // NOTE(review): assumes custom_code() is non-null whenever builtin_code
+      // is CUSTOM -- a malformed flatbuffer could violate this; confirm the
+      // schema/flatbuffer verification upstream guarantees it.
+      if (!resolver.FindOp(opcode->custom_code()->c_str())) {
+        ReportError(error_reporter, "Unsupported custom op: %s",
+                    opcode->custom_code()->c_str());
+        return false;
+      }
+    } else {
+      if (!resolver.FindOp(opcode->builtin_code())) {
+        ReportError(error_reporter, "Unsupported builtin op: %s",
+                    EnumNameBuiltinOperator(opcode->builtin_code()));
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
} // namespace
-bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) {
+bool Verify(const void* buf, size_t len, const OpResolver& resolver,
+ ErrorReporter* error_reporter) {
const Model* model = VerifyFlatbufferAndGetModel(buf, len);
if (model == nullptr) {
ReportError(error_reporter, "Invalid flatbuffer format");
if (!VerifyTensors(*model, error_reporter)) {
return false;
}
+ if (!VerifyOps(*model, resolver, error_reporter)) {
+ return false;
+ }
return true;
}
} // namespace tflite
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#include "tensorflow/contrib/lite/tools/verifier.h"
#include "tensorflow/contrib/lite/version.h"
#include "tensorflow/core/framework/numeric_types.h"
CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
}
+  // Constructs a builder whose embedded resolver registers the given builtin
+  // ops and custom op names, all backed by the same placeholder registration
+  // (the verifier only performs op lookup, never execution).
+  TfLiteFlatbufferModelBuilder(const std::vector<BuiltinOperator>& builtin_ops,
+                               const std::vector<string>& custom_ops) {
+    // Mirror the default constructor: start with a single empty buffer.
+    buffers_.push_back(
+        CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
+
+    for (const auto& iter : builtin_ops) {
+      resolver_.AddBuiltin(iter, &fake_op_);
+    }
+    for (const auto& iter : custom_ops) {
+      resolver_.AddCustom(iter.data(), &fake_op_);
+    }
+  }
+
void AddTensor(const std::vector<int>& shape, tflite::TensorType type,
const std::vector<uint8_t>& buffer, const char* name) {
int buffer_index = 0;
+  // Runs the verifier over the flatbuffer built so far, forwarding the
+  // resolver populated by the op-registering constructor.
  bool Verify() {
    return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(),
-                          DefaultErrorReporter());
+                          resolver_, DefaultErrorReporter());
  }
private:
FlatBufferBuilder builder_;
+ MutableOpResolver resolver_;
+ TfLiteRegistration fake_op_;
std::vector<Offset<Operator>> operators_;
std::vector<Offset<OperatorCode>> operator_codes_;
std::vector<Offset<Tensor>> tensors_;
::tflite::FinishModelBuffer(builder, model);
ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(),
- DefaultErrorReporter()));
+ MutableOpResolver{}, DefaultErrorReporter()));
}
TEST(VerifyModel, TestSimpleModel) {
- TfLiteFlatbufferModelBuilder builder;
+ TfLiteFlatbufferModelBuilder builder({}, {"test"});
builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test");
builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input");
builder.AddTensor(
TEST(VerifyModel, TestCorruptedData) {
  std::string model = "123";
-  ASSERT_FALSE(Verify(model.data(), model.size(), /*error_reporter=*/nullptr));
+  // A null error reporter must be tolerated on the failure path.
+  ASSERT_FALSE(Verify(model.data(), model.size(), MutableOpResolver{},
+                      /*error_reporter=*/nullptr));
}
TEST(VerifyModel, TestUnsupportedVersion) {
/*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
::tflite::FinishModelBuffer(builder, model);
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
- DefaultErrorReporter()));
+ MutableOpResolver{}, DefaultErrorReporter()));
}
TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
for (int i = 0; i < model_content.size(); i++) {
model_content[i] = (model_content[i] + 137) % 255;
EXPECT_FALSE(Verify(model_content.data(), model_content.size(),
- DefaultErrorReporter()))
+ MutableOpResolver{}, DefaultErrorReporter()))
<< "Fail at position: " << i;
}
}
::tflite::FinishModelBuffer(builder, model);
ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(),
- DefaultErrorReporter()));
+ MutableOpResolver{}, DefaultErrorReporter()));
}
TEST(VerifyModel, StringTensorHasInvalidNumString) {
ASSERT_FALSE(builder.Verify());
}
+// Success path: the resolver registers exactly the ops the model uses
+// (builtin ADD and custom "CustomOp"), so verification must pass.
+TEST(VerifyModel, AllOpsAreSupported) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"CustomOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr);
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "CustomOp");
+  builder.FinishModel({}, {});
+  // Fix: ASSERT_FALSE contradicted the test's name and setup -- every op in
+  // the model is registered above, so Verify() is expected to succeed. The
+  // unsupported-op failure cases are covered by the two tests that follow.
+  ASSERT_TRUE(builder.Verify());
+}
+
+// Only SUB is registered as a builtin, but the model uses ADD, so
+// verification must fail with an unsupported-builtin-op error.
+TEST(VerifyModel, UseUnsupportedBuiltinOps) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_SUB}, {"CustomOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr);
+  builder.FinishModel({}, {});
+  ASSERT_FALSE(builder.Verify());
+}
+
+// The resolver registers custom op "NewOp", but the model asks for
+// "Not supported", so verification must fail with an unsupported-custom-op
+// error.
+TEST(VerifyModel, UseUnsupportedCustomOps) {
+  TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"NewOp"});
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2");
+  builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output");
+  builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "Not supported");
+  builder.FinishModel({}, {});
+  ASSERT_FALSE(builder.Verify());
+}
+
// TODO(yichengfan): make up malicious files to test with.
} // namespace tflite