6 #include <boost/test/unit_test.hpp> 18 explicit Convolution2dFixture(
const std::string& dataLayout,
const std::string& paddingType)
19 : Convolution2dFixture(dataLayout, paddingType, 1)
24 explicit Convolution2dFixture(
const std::string& dataLayout,
const std::string& paddingType,
25 int stride,
int dilation = 0)
27 std::string strideString (
" i: 1 \n" 29 if (dataLayout ==
"NHWC")
31 strideString.append(
" i: " + std::to_string(stride) +
" \n" 36 strideString.append(
" i: 1 \n" 37 " i: " + std::to_string(stride) +
" \n");
40 std::string dilationString = std::to_string(dilation);
42 " name: \"graphInput\" \n" 43 " op: \"Placeholder\" \n" 59 " name: \"Const_1\" \n" 86 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n" 92 " name: \"potato\" \n" 94 " input: \"graphInput\" \n" 95 " input: \"Const_1\" \n" 103 " key: \"data_format\" \n" 111 " key: \"padding\" \n" 119 " key: \"strides\" \n" 131 " key: \"dilations\" \n" 147 " key: \"use_cudnn_on_gpu\" \n" 155 BOOST_ASSERT_MSG(stride == 1 || stride == 2,
"Add support for strides other than 1 or 2.");
156 std::array<unsigned int, 4> dims;
157 if (dataLayout ==
"NHWC")
159 dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
163 dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
171 struct Convolution2dNhwcSameFixture : Convolution2dFixture
173 Convolution2dNhwcSameFixture() : Convolution2dFixture(
"NHWC",
"SAME", 1){}
177 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
180 struct Convolution2dNchwSameFixture : Convolution2dFixture
182 Convolution2dNchwSameFixture() : Convolution2dFixture(
"NCHW",
"SAME", 1){}
186 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
190 struct Convolution2dNhwcValidFixture : Convolution2dFixture
192 Convolution2dNhwcValidFixture() : Convolution2dFixture(
"NHWC",
"VALID", 1){}
196 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
199 struct Convolution2dNchwValidFixture : Convolution2dFixture
201 Convolution2dNchwValidFixture() : Convolution2dFixture(
"NCHW",
"VALID", 1){}
205 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
209 struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
211 Convolution2dStride2NhwcSameFixture() : Convolution2dFixture(
"NHWC",
"SAME", 2){}
215 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
218 struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
220 Convolution2dStride2NchwSameFixture() : Convolution2dFixture(
"NCHW",
"SAME", 2){}
224 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
228 struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
230 Convolution2dStride2NhwcValidFixture() : Convolution2dFixture(
"NHWC",
"VALID", 2){}
234 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
237 struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
239 Convolution2dStride2NchwValidFixture() : Convolution2dFixture(
"NCHW",
"VALID", 2){}
243 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
247 struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
249 Convolution2dDilation1NhwcFixture() : Convolution2dFixture(
"NHWC",
"SAME", 1, 1){}
253 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
256 struct Convolution2dDilation1NchwFixture : Convolution2dFixture
258 Convolution2dDilation1NchwFixture() : Convolution2dFixture(
"NCHW",
"SAME", 1, 1){}
262 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
268 const char* prototext =
"" 270 " name: \"graphInput\"\n" 271 " op: \"Placeholder\"\n" 287 " name: \"Const_1\"\n" 314 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n" 320 " name: \"potato\"\n" 322 " input: \"graphInput\"\n" 323 " input: \"Const_1\"\n" 331 " key: \"data_format\"\n" 337 " key: \"padding\"\n" 343 " key: \"strides\"\n" 354 " key: \"dilations\"\n" 365 " key: \"use_cudnn_on_gpu\"\n" 372 std::map<std::string, armnn::TensorShape> inputShapes;
374 inputShapes[
"graphInput"] = tensorShape;
376 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, {
"potato" }),
armnn::ParseException);
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
std::unique_ptr< ITfParser, void(*)(ITfParser *parser)> ITfParserPtr
BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
BOOST_AUTO_TEST_SUITE_END()
static ITfParserPtr Create()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.
BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)