Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / Pooling.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12
13 struct Pooling2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
14 {
15     explicit Pooling2dFixture(const char* poolingtype)
16     {
17         m_Prototext =  "node {\n"
18             "  name: \"Placeholder\"\n"
19             "  op: \"Placeholder\"\n"
20             "  attr {\n"
21             "    key: \"dtype\"\n"
22             "    value {\n"
23             "      type: DT_FLOAT\n"
24             "    }\n"
25             "  }\n"
26             "  attr {\n"
27             "    key: \"value\"\n"
28             "    value {\n"
29             "      tensor {\n"
30             "        dtype: DT_FLOAT\n"
31             "        tensor_shape {\n"
32             "        }\n"
33             "      }\n"
34             "    }\n"
35             "   }\n"
36             "  }\n"
37             "node {\n"
38             "  name: \"";
39         m_Prototext.append(poolingtype);
40         m_Prototext.append("\"\n"
41                                "  op: \"");
42         m_Prototext.append(poolingtype);
43         m_Prototext.append("\"\n"
44                                "  input: \"Placeholder\"\n"
45                                "  attr {\n"
46                                "    key: \"T\"\n"
47                                "    value {\n"
48                                "      type: DT_FLOAT\n"
49                                "    }\n"
50                                "  }\n"
51                                "  attr {\n"
52                                "    key: \"data_format\"\n"
53                                "    value {\n"
54                                "      s: \"NHWC\"\n"
55                                "    }\n"
56                                "  }\n"
57                                "  attr {\n"
58                                "    key: \"ksize\"\n"
59                                "    value {\n"
60                                "      list {\n"
61                                "        i: 1\n"
62                                "        i: 2\n"
63                                "        i: 2\n"
64                                "        i: 1\n"
65                                "      }\n"
66                                "    }\n"
67                                "  }\n"
68                                "  attr {\n"
69                                "    key: \"padding\"\n"
70                                "    value {\n"
71                                "      s: \"VALID\"\n"
72                                "    }\n"
73                                "  }\n"
74                                "  attr {\n"
75                                "    key: \"strides\"\n"
76                                "    value {\n"
77                                "      list {\n"
78                                "        i: 1\n"
79                                "        i: 1\n"
80                                "        i: 1\n"
81                                "        i: 1\n"
82                                "      }\n"
83                                "    }\n"
84                                "  }\n"
85                                "}\n");
86
87         SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
88     }
89 };
90
91
92 struct MaxPoolFixture : Pooling2dFixture
93 {
94     MaxPoolFixture() : Pooling2dFixture("MaxPool") {}
95 };
96 BOOST_FIXTURE_TEST_CASE(ParseMaxPool, MaxPoolFixture)
97 {
98     RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
99 }
100
101
102 struct AvgPoolFixture : Pooling2dFixture
103 {
104     AvgPoolFixture() : Pooling2dFixture("AvgPool") {}
105 };
106 BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture)
107 {
108     RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
109 }
110
111
112 BOOST_AUTO_TEST_SUITE_END()