From 41335abb46f80ca644b5738550daef6136ba5476 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Mar 2018 17:23:20 -0700 Subject: [PATCH] Improve flatbuffer verification. PiperOrigin-RevId: 189668634 --- tensorflow/contrib/lite/toco/tflite/BUILD | 2 + tensorflow/contrib/lite/toco/tflite/import.cc | 7 +- tensorflow/contrib/lite/toco/tflite/import_test.cc | 106 +++++++++++++++++---- tensorflow/contrib/lite/tools/verifier.cc | 71 ++++++++++++-- tensorflow/contrib/lite/tools/verifier.h | 15 +++ tensorflow/contrib/lite/tools/verifier_test.cc | 4 +- 6 files changed, 175 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index a2b8145..9d3e1da 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -115,9 +115,11 @@ cc_library( deps = [ ":operator", ":types", + "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/toco:model", "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/contrib/lite/tools:verifier", "@flatbuffers", ], ) diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index 867395e..c0e7ab2 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/tflite/import.h" #include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" #include "tensorflow/contrib/lite/toco/tflite/types.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/contrib/lite/tools/verifier.h" namespace toco { @@ -171,10 +173,11 @@ bool Verify(const void* buf, size_t len) { std::unique_ptr Import(const ModelFlags& model_flags, const string& input_file_contents) { - if (!Verify(input_file_contents.data(), input_file_contents.size())) { + ::tflite::AlwaysTrueResolver r; + if (!::tflite::Verify(input_file_contents.data(), input_file_contents.size(), + r, ::tflite::DefaultErrorReporter())) { LOG(FATAL) << "Invalid flatbuffer."; } - const ::tflite::Model* input_model = ::tflite::GetModel(input_file_contents.data()); diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc index 937a291..edd22f7 100644 --- a/tensorflow/contrib/lite/toco/tflite/import_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -36,12 +36,13 @@ class ImportTest : public ::testing::Test { return builder_.CreateVector(reinterpret_cast(data.data()), sizeof(T) * data.size()); } + Offset>> BuildBuffers() { auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector({})); - auto buf1 = - ::tflite::CreateBuffer(builder_, CreateDataVector({1.0f, 2.0f})); + auto buf1 = ::tflite::CreateBuffer( + builder_, CreateDataVector({1.0f, 2.0f, 3.0f, 4.0f})); auto buf2 = - ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f})); + ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f, 4.0f})); return builder_.CreateVector( std::vector>({buf0, buf1, buf2})); } @@ -53,10 +54,10 @@ class ImportTest : public ::testing::Test { /*max=*/builder_.CreateVector({0.2f}), /*scale=*/builder_.CreateVector({0.3f}), /*zero_point=*/builder_.CreateVector({100ll})); - auto t1 = ::tflite::CreateTensor(builder_, - builder_.CreateVector({1, 2, 3, 4}), - ::tflite::TensorType_FLOAT32, 1, - builder_.CreateString("tensor_one"), q); + auto t1 = + ::tflite::CreateTensor(builder_, builder_.CreateVector({1, 2, 2}), + ::tflite::TensorType_FLOAT32, 1, + builder_.CreateString("tensor_one"), q); auto t2 = ::tflite::CreateTensor(builder_, builder_.CreateVector({2, 1}), ::tflite::TensorType_FLOAT32, 2, @@ -65,18 +66,26 @@ class ImportTest : public ::testing::Test { std::vector>({t1, t2})); } + Offset>> BuildOpCodes( + std::initializer_list<::tflite::BuiltinOperator> op_codes) { + std::vector> op_codes_vector; + for (auto op : op_codes) { + op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0)); + } + return builder_.CreateVector(op_codes_vector); + } + Offset>> BuildOpCodes() { - auto c1 = ::tflite::CreateOperatorCode( - builder_, ::tflite::BuiltinOperator_MAX_POOL_2D, 0); - auto c2 = ::tflite::CreateOperatorCode( - builder_, ::tflite::BuiltinOperator_CONV_2D, 0); - return builder_.CreateVector( - std::vector>({c1, c2})); + return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D, + ::tflite::BuiltinOperator_CONV_2D}); } - Offset>> BuildOperators() { - auto is = builder_.CreateVector({0}); - auto os = builder_.CreateVector({1}); + Offset>> BuildOperators( + std::initializer_list inputs, std::initializer_list outputs) { + auto is = builder_.CreateVector(inputs); + if (inputs.size() == 0) is = 0; + auto os = builder_.CreateVector(outputs); + if (outputs.size() == 0) os = 0; auto op = ::tflite::CreateOperator( builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions, ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1, @@ -87,6 +96,10 @@ class ImportTest : public ::testing::Test { return builder_.CreateVector(std::vector>({op})); } + Offset>> BuildOperators() { + return BuildOperators({0}, {1}); + } + Offset>> BuildSubGraphs( Offset>> tensors, Offset>> operators, @@ -154,9 +167,9 @@ TEST_F(ImportTest, Tensors) { Array& a1 = model->GetArray("tensor_one"); EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); EXPECT_THAT(a1.GetBuffer().data, - ElementsAre(1.0f, 2.0f)); + ElementsAre(1.0f, 2.0f, 3.0f, 4.0f)); ASSERT_TRUE(a1.has_shape()); - EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2)); const auto& mm = a1.minmax; ASSERT_TRUE(mm.get()); @@ -169,6 +182,63 @@ TEST_F(ImportTest, Tensors) { EXPECT_EQ(100, q->zero_point); } +TEST_F(ImportTest, NoBuffers) { + auto buffers = 0; + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'buffers' section."); +} + +TEST_F(ImportTest, NoInputs) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators({}, {1}); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'inputs' for operator."); +} + +TEST_F(ImportTest, NoOutputs) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes(); + auto operators = BuildOperators({0}, {}); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Missing 'outputs' for operator."); +} + +TEST_F(ImportTest, InvalidOpCode) { + auto buffers = BuildBuffers(); + auto tensors = BuildTensors(); + auto opcodes = BuildOpCodes({static_cast<::tflite::BuiltinOperator>(-1), + ::tflite::BuiltinOperator_CONV_2D}); + auto operators = BuildOperators(); + auto subgraphs = BuildSubGraphs(tensors, operators); + auto comment = builder_.CreateString(""); + ::tflite::FinishModelBuffer( + builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs, comment, buffers)); + EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()), + "Operator id '-1' is out of range."); +} + TEST_F(ImportTest, MultipleSubGraphs) { auto buffers = BuildBuffers(); auto tensors = BuildTensors(); diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc index 59c7420..8818a7d 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -148,11 +148,52 @@ bool VerifyNumericTensorBuffer(const Tensor& tensor, const Buffer& buffer, // TODO(yichengfan): verify quantized tensors. } +using flatbuffers::Offset; +using flatbuffers::Vector; + +bool VerifyOperators(const Vector>& operators, + ErrorReporter* error_reporter) { + for (const auto& op : operators) { + if (!op->inputs()) { + ReportError(error_reporter, "Missing 'inputs' for operator."); + return false; + } + if (!op->outputs()) { + ReportError(error_reporter, "Missing 'outputs' for operator."); + return false; + } + } + return true; +} + +bool VerifySubGraphs(const Model& model, ErrorReporter* error_reporter) { + if (!model.subgraphs()) { + ReportError(error_reporter, "Missing 'subgraphs' section."); + return false; + } + for (const auto& subgraph : *model.subgraphs()) { + if (!subgraph->operators()) { + ReportError(error_reporter, "Missing 'operators' section in subgraph."); + return false; + } + + if (!VerifyOperators(*subgraph->operators(), error_reporter)) { + return false; + } + } + return true; +} + // Verifies tensors have valid properties and legit buffer if set. bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { if (!model.subgraphs()) { return true; } + if (!model.buffers()) { + ReportError(error_reporter, "Missing 'buffers' section."); + return false; + } + for (const auto& subgraph : *model.subgraphs()) { if (!subgraph->tensors()) { continue; @@ -167,19 +208,23 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { return false; } auto* buffer = model.buffers()->Get(tensor->buffer()); - if (!buffer || !buffer->data()) { + if (!buffer) { ReportError(error_reporter, "Tensor buffer %d not set", tensor->buffer()); return false; } - if (tensor->type() == TensorType_STRING) { - if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { - return false; - } - } else { - if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { - return false; + // Many transient tensors don't have data in the flatbuffer. Their + // buffers will be allocated by the interpreter at run-time. + if (buffer->data()) { + if (tensor->type() == TensorType_STRING) { + if (!VerifyStringTensorBuffer(*buffer, error_reporter)) { + return false; + } + } else { + if (!VerifyNumericTensorBuffer(*tensor, *buffer, error_reporter)) { + return false; + } } } } @@ -193,6 +238,13 @@ bool VerifyOps(const Model& model, const OpResolver& resolver, return true; } for (const auto& opcode : *model.operator_codes()) { + if (opcode->builtin_code() < BuiltinOperator_MIN || + opcode->builtin_code() > BuiltinOperator_MAX) { + ReportError(error_reporter, "Operator id '%d' is out of range.", + opcode->builtin_code()); + return false; + } + if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { if (!resolver.FindOp(opcode->custom_code()->c_str())) { ReportError(error_reporter, "Unsupported custom op: %s", @@ -223,6 +275,9 @@ bool Verify(const void* buf, size_t len, const OpResolver& resolver, ReportError(error_reporter, "Invalid model version %d", model->version()); return false; } + if (!VerifySubGraphs(*model, error_reporter)) { + return false; + } if (!VerifyTensors(*model, error_reporter)) { return false; } diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h index c2ee112..b7ce4e8 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -23,6 +23,21 @@ limitations under the License. namespace tflite { +class AlwaysTrueResolver : public OpResolver { + public: + AlwaysTrueResolver() {} + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, + nullptr}; + return &null_registration; + } + TfLiteRegistration* FindOp(const char* op) const override { + static TfLiteRegistration null_registration = {nullptr, nullptr, nullptr, + nullptr}; + return &null_registration; + } +}; + // Verifies the integrity of a Tensorflow Lite flatbuffer model file. // Currently, it verifies: // * The file is following a legit flatbuffer schema. diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index b3e611f..03b93af 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -113,8 +113,8 @@ TEST(VerifyModel, TestEmptyModel) { /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); - ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), - MutableOpResolver{}, DefaultErrorReporter())); + ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), + MutableOpResolver{}, DefaultErrorReporter())); } TEST(VerifyModel, TestSimpleModel) { -- 2.7.4