2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
11 using armnnTfLiteParser::TfLiteParser;
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
// Test fixture for the GetBuffer tests, derived from ParserFlatbuffersFixture.
// Its constructor embeds a JSON description of a model. The fragments visible
// here show a single CONV_2D operator code, tensors named inputTensor,
// outputTensor and filterTensor, Conv2DOptions with fused activation NONE, and
// one buffer whose data is 2,1,0, 6,2,1, 4,1,2.
// NOTE(review): only part of the JSON model is visible in this excerpt; the
// tensor-to-buffer mapping and the remaining buffers should be confirmed
// against the full file.
15 struct GetBufferFixture : public ParserFlatbuffersFixture
17 explicit GetBufferFixture()
22 "operator_codes": [ { "builtin_code": "CONV_2D" } ],
26 "shape": [ 1, 3, 3, 1 ],
29 "name": "inputTensor",
38 "shape": [ 1, 1, 1, 1 ],
41 "name": "outputTensor",
50 "shape": [ 1, 3, 3, 1 ],
53 "name": "filterTensor",
69 "builtin_options_type": "Conv2DOptions",
74 "fused_activation_function": "NONE"
76 "custom_options_format": "FLEXBUFFERS"
83 { "data": [ 2,1,0, 6,2,1, 4,1,2 ], },
91 void CheckBufferContents(const TfLiteParser::ModelPtr& model,
92 std::vector<int32_t> bufferValues, size_t bufferIndex)
94 for(long unsigned int i=0; i<bufferValues.size(); i++)
96 BOOST_CHECK_EQUAL(TfLiteParser::GetBuffer(model, bufferIndex)->data[i], bufferValues[i]);
101 BOOST_FIXTURE_TEST_CASE(GetBufferCheckContents, GetBufferFixture)
103 //Check contents of buffer are correct
104 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
105 std::vector<int32_t> bufferValues = {2,1,0,6,2,1,4,1,2};
106 CheckBufferContents(model, bufferValues, 2);
109 BOOST_FIXTURE_TEST_CASE(GetBufferCheckEmpty, GetBufferFixture)
111 //Check if test fixture buffers are empty or not
112 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
113 BOOST_CHECK(TfLiteParser::GetBuffer(model, 0)->data.empty());
114 BOOST_CHECK(TfLiteParser::GetBuffer(model, 1)->data.empty());
115 BOOST_CHECK(!TfLiteParser::GetBuffer(model, 2)->data.empty());
116 BOOST_CHECK(TfLiteParser::GetBuffer(model, 3)->data.empty());
119 BOOST_FIXTURE_TEST_CASE(GetBufferCheckParseException, GetBufferFixture)
121 //Check if armnn::ParseException thrown when invalid buffer index used
122 TfLiteParser::ModelPtr model = TfLiteParser::LoadModelFromBinary(m_GraphBinary.data(), m_GraphBinary.size());
123 BOOST_CHECK_THROW(TfLiteParser::GetBuffer(model, 4)->data.empty(), armnn::Exception);
126 BOOST_AUTO_TEST_SUITE_END()