2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
// Boost.Test suite grouping the TensorFlow-parser Conv2D tests in this file.
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
// Fixture that builds a small TensorFlow graph as prototext and registers it via
// ParserPrototxtFixture::SetupSingleInputSingleOutput. The graph contains:
//   - a Placeholder named "graphInput";
//   - a weights node named "Const_1";
//   - a node named "potato" that takes "graphInput" and "Const_1" as inputs and
//     carries data_format, padding, strides, (optionally) dilations and
//     use_cudnn_on_gpu attrs.
// The Const_1 tensor_content bytes decode to {0.5f, 1.0f, 0.5f} when read as
// little-endian float32. That decoding matches the expected values in the
// RunTest calls of this file.
// NOTE(review): the dtype/shape lines of Const_1 and the op line of "potato" are
// not visible in this excerpt — confirm.
14 struct Convolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
// Convenience constructor: stride 1 and dilation 0 (no "dilations" attr emitted).
16 explicit Convolution2dFixture(const char* paddingType)
17 : Convolution2dFixture(paddingType, 1)
20 // dilation: 0 - dilations attribute is not included;
21 // dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
// paddingType: text spliced into the "padding" attr value (tests pass "SAME" / "VALID").
// stride:      text spliced into the "strides" attr; only 1 or 2 is accepted (asserted below).
22 explicit Convolution2dFixture(const char* paddingType, int stride, int dilation = 0)
24 std::string strideString = std::to_string(stride);
25 std::string dilationString = std::to_string(dilation);
// Graph text: input placeholder, weights constant, then the "potato" node and its attrs.
26 m_Prototext = "node { \n"
27 " name: \"graphInput\" \n"
28 " op: \"Placeholder\" \n"
44 " name: \"Const_1\" \n"
71 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
77 " name: \"potato\" \n"
79 " input: \"graphInput\" \n"
80 " input: \"Const_1\" \n"
88 " key: \"data_format\" \n"
94 " key: \"padding\" \n"
// Insert the caller-supplied padding type, then close the padding attr.
97 m_Prototext.append(paddingType);
98 m_Prototext.append("\"\n"
102 " key: \"strides\" \n"
// Insert the stride value into the strides list.
// NOTE(review): which list positions receive it is in lines not shown here.
108 m_Prototext.append(strideString);
109 m_Prototext.append(" \n"
// "dilations" attr: dilationString is appended twice, forming the two middle
// entries of [1,v,v,1] described in the constructor comment.
// NOTE(review): the dilation > 0 guard is not visible in this excerpt — confirm.
117 m_Prototext.append(" attr { \n"
118 " key: \"dilations\" \n"
123 m_Prototext.append(dilationString);
124 m_Prototext.append(" \n"
126 m_Prototext.append(dilationString);
127 m_Prototext.append(" \n"
// Trailing use_cudnn_on_gpu attr.
133 m_Prototext.append(" attr { \n"
134 " key: \"use_cudnn_on_gpu\" \n"
141 // Manual height computation based on stride parameter.
142 BOOST_ASSERT_MSG(stride == 1 || stride==2, "Add support for strides other than 1 or 2.");
// Input shape starts as {1,2,3,1}.
// NOTE(review): the stride-2 tests feed 9 input values, so the elided lines
// presumably enlarge dims for stride 2 — confirm.
143 unsigned int dims[] = {1,2,3,1};
149 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims), "graphInput", "potato");
// SAME padding, stride 1, no "dilations" attr (dilation defaults to 0).
154 struct Convolution2dSameFixture : Convolution2dFixture
156 Convolution2dSameFixture() : Convolution2dFixture("SAME", 1){}
// SAME padding at stride 1: six input values in, six outputs expected.
158 BOOST_FIXTURE_TEST_CASE(ParseConv2DSame, Convolution2dSameFixture)
160 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
// VALID padding, stride 1, no "dilations" attr.
163 struct Convolution2dValidFixture : Convolution2dFixture
165 Convolution2dValidFixture() : Convolution2dFixture("VALID", 1){}
// VALID padding at stride 1: six input values reduce to two outputs.
167 BOOST_FIXTURE_TEST_CASE(ParseConv2DValid, Convolution2dValidFixture)
169 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
// SAME padding, stride 2, no "dilations" attr.
173 struct Convolution2dStride2SameFixture : Convolution2dFixture
175 Convolution2dStride2SameFixture() : Convolution2dFixture("SAME", 2){}
// SAME padding at stride 2: nine input values in, six outputs expected.
177 BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Same, Convolution2dStride2SameFixture)
179 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
// VALID padding, stride 2, no "dilations" attr.
183 struct Convolution2dStride2ValidFixture : Convolution2dFixture
185 Convolution2dStride2ValidFixture() : Convolution2dFixture("VALID", 2){}
// VALID padding at stride 2: nine input values reduce to three outputs.
187 BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Valid, Convolution2dStride2ValidFixture)
189 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
// SAME padding, stride 1, dilation 1: emits an explicit dilations attr of [1,1,1,1].
193 struct Convolution2dDilation1Fixture : Convolution2dFixture
195 Convolution2dDilation1Fixture() : Convolution2dFixture("SAME", 1, 1){}
// Explicit unit dilation must not change the result.
// The input and expected output are identical to ParseConv2DSame.
197 BOOST_FIXTURE_TEST_CASE(ParseConv2DDilation1, Convolution2dDilation1Fixture)
199 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
// Checks that parsing fails with armnn::ParseException when the convolution node
// carries a "dilations" attr other than [1,1,1,1].
// NOTE(review): the dilations values sit in lines not shown here; the test name
// suggests 2 — confirm.
202 BOOST_AUTO_TEST_CASE(ParseConv2DDilation2)
204 const char* prototext = ""
206 " name: \"graphInput\"\n"
207 " op: \"Placeholder\"\n"
223 " name: \"Const_1\"\n"
250 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
256 " name: \"potato\"\n"
258 " input: \"graphInput\"\n"
259 " input: \"Const_1\"\n"
267 " key: \"data_format\"\n"
273 " key: \"padding\"\n"
279 " key: \"strides\"\n"
290 " key: \"dilations\"\n"
301 " key: \"use_cudnn_on_gpu\"\n"
// Placeholder "graphInput" is given shape {1, 3, 3, 1}; "potato" is the requested output.
308 std::map<std::string, armnn::TensorShape> inputShapes;
309 armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
310 inputShapes["graphInput"] = tensorShape;
311 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
// The exception text must match exactly.
// NOTE(review): strcmp needs <cstring>, which is not among the visible includes.
// An exact-match comparison will also break if the parser ever appends location
// info to the message; a substring check would be more robust.
312 BOOST_CHECK_EXCEPTION(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }),
313 armnn::ParseException,
314 [] (armnn::ParseException const& ex)->bool
316 return strcmp(ex.what(),
317 "ArmNN only supports Convolution layers with dilations [1,1,1,1]") == 0;
// Closes the TensorflowParser test suite opened near the top of this file.
322 BOOST_AUTO_TEST_SUITE_END()