2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
12 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 PoolingMainFixture(const std::string& dataType, const std::string& op)
19 producer_version: "2.5.1"
28 elem_type: )" + dataType + R"(
50 op_type: )" + op + R"(
101 struct MaxPoolValidFixture : PoolingMainFixture
103 MaxPoolValidFixture() : PoolingMainFixture("FLOAT", "\"MaxPool\"") {
108 struct MaxPoolInvalidFixture : PoolingMainFixture
110 MaxPoolInvalidFixture() : PoolingMainFixture("FLOAT16", "\"MaxPool\"") { }
// MaxPool over the inputs {1, 2, 3, -4} is expected to produce the single value 3,
// which is the largest input.
// NOTE(review): this presumes the pooling window covers all four inputs. The window and
// stride attributes live in the PoolingMainFixture prototxt; confirm there.
113 BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
115 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
118 struct AvgPoolValidFixture : PoolingMainFixture
120 AvgPoolValidFixture() : PoolingMainFixture("FLOAT", "\"AveragePool\"") {
125 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
127 PoolingWithPadFixture()
131 producer_name: "CNTK"
132 producer_version: "2.5.1"
163 op_type: "AveragePool"
185 name: "count_include_pad"
// AveragePool over the inputs {1, 2, 3, -4} is expected to produce 0.5,
// which is the mean of all four inputs: (1 + 2 + 3 - 4) / 4.
// NOTE(review): the expected value is written as a double literal (0.5), while other tests
// in this file use float literals (e.g. 3.0f). Consider 0.5f for consistency.
220 BOOST_FIXTURE_TEST_CASE(AveragePoolValid, AvgPoolValidFixture)
222 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
// AveragePool with padding: the same inputs {1, 2, 3, -4} (sum 2) are expected to produce 1/8.
// That is lower than the unpadded mean (0.5).
// NOTE(review): this is consistent with padded positions being counted in the divisor, since the
// fixture sets a "count_include_pad" attribute. The exact window and pad values are in the
// PoolingWithPadFixture prototxt; confirm there.
225 BOOST_FIXTURE_TEST_CASE(ValidAvgWithPadTest, PoolingWithPadFixture)
227 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
230 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
236 producer_name: "CNTK"
237 producer_version: "2.5.1"
268 op_type: "GlobalAveragePool"
// GlobalAveragePool: the expected outputs are the means of {1, 2, 3, 4} (10/4) and
// {5, 6, 7, 8} (26/4).
// NOTE(review): presumably the eight inputs form two channels of four spatial elements each.
// Confirm against the tensor shape declared in the GlobalAvgFixture prototxt.
300 BOOST_FIXTURE_TEST_CASE(GlobalAvgTest, GlobalAvgFixture)
302 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
// MaxPoolInvalidFixture declares the tensor element type as FLOAT16.
// This test asserts that Setup(), which parses the model, throws armnn::ParseException for it.
305 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeMaxPool, MaxPoolInvalidFixture)
307 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
310 BOOST_AUTO_TEST_SUITE_END()