Release 18.08
[platform/upstream/armnn.git] / src / armnnOnnxParser / test / Pooling.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include  "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12 struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
14     PoolingMainFixture(const std::string& dataType, const std::string& op)
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: )" + dataType + R"(
29                             shape {
30                               dim {
31                                 dim_value: 1
32                               }
33                               dim {
34                                 dim_value: 1
35                               }
36                               dim {
37                                 dim_value: 2
38                               }
39                               dim {
40                                 dim_value: 2
41                               }
42                             }
43                           }
44                         }
45                       }
46                      node {
47                          input: "Input"
48                          output: "Output"
49                          name: "Pooling"
50                          op_type: )" + op + R"(
51                          attribute {
52                            name: "kernel_shape"
53                            ints: 2
54                            ints: 2
55                            type: INTS
56                          }
57                          attribute {
58                            name: "strides"
59                            ints: 1
60                            ints: 1
61                            type: INTS
62                          }
63                          attribute {
64                            name: "pads"
65                            ints: 0
66                            ints: 0
67                            ints: 0
68                            ints: 0
69                            type: INTS
70                          }
71                       }
72                       output {
73                           name: "Output"
74                           type {
75                              tensor_type {
76                                elem_type: FLOAT
77                                shape {
78                                    dim {
79                                        dim_value: 1
80                                    }
81                                    dim {
82                                        dim_value: 1
83                                    }
84                                    dim {
85                                        dim_value: 1
86                                    }
87                                    dim {
88                                        dim_value: 1
89                                    }
90                                }
91                             }
92                         }
93                         }
94                     }
95                    opset_import {
96                       version: 7
97                     })";
98     }
99 };
100
101 struct MaxPoolValidFixture : PoolingMainFixture
102 {
103     MaxPoolValidFixture() : PoolingMainFixture("FLOAT", "\"MaxPool\"") {
104         Setup();
105     }
106 };
107
108 struct MaxPoolInvalidFixture : PoolingMainFixture
109 {
110     MaxPoolInvalidFixture() : PoolingMainFixture("FLOAT16", "\"MaxPool\"") { }
111 };
112
113 BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
114 {
115     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
116 }
117
118 struct AvgPoolValidFixture : PoolingMainFixture
119 {
120     AvgPoolValidFixture() : PoolingMainFixture("FLOAT", "\"AveragePool\"") {
121         Setup();
122     }
123 };
124
125 struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
126 {
127     PoolingWithPadFixture()
128     {
129         m_Prototext = R"(
130                    ir_version: 3
131                    producer_name:  "CNTK"
132                    producer_version:  "2.5.1"
133                    domain:  "ai.cntk"
134                    model_version: 1
135                    graph {
136                      name:  "CNTKGraph"
137                      input {
138                         name: "Input"
139                         type {
140                           tensor_type {
141                             elem_type: FLOAT
142                             shape {
143                               dim {
144                                 dim_value: 1
145                               }
146                               dim {
147                                 dim_value: 1
148                               }
149                               dim {
150                                 dim_value: 2
151                               }
152                               dim {
153                                 dim_value: 2
154                               }
155                             }
156                           }
157                         }
158                       }
159                      node {
160                          input: "Input"
161                          output: "Output"
162                          name: "Pooling"
163                          op_type: "AveragePool"
164                          attribute {
165                            name: "kernel_shape"
166                            ints: 4
167                            ints: 4
168                            type: INTS
169                          }
170                          attribute {
171                            name: "strides"
172                            ints: 1
173                            ints: 1
174                            type: INTS
175                          }
176                          attribute {
177                            name: "pads"
178                            ints: 1
179                            ints: 1
180                            ints: 1
181                            ints: 1
182                            type: INTS
183                          }
184                          attribute {
185                            name: "count_include_pad"
186                            i: 1
187                            type: INT
188                          }
189                       }
190                       output {
191                           name: "Output"
192                           type {
193                              tensor_type {
194                                elem_type: FLOAT
195                                shape {
196                                    dim {
197                                        dim_value: 1
198                                    }
199                                    dim {
200                                        dim_value: 1
201                                    }
202                                    dim {
203                                        dim_value: 1
204                                    }
205                                    dim {
206                                        dim_value: 1
207                                    }
208                                }
209                             }
210                         }
211                         }
212                     }
213                    opset_import {
214                       version: 7
215                     })";
216         Setup();
217     }
218 };
219
220 BOOST_FIXTURE_TEST_CASE(AveragePoolValid, AvgPoolValidFixture)
221 {
222     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
223 }
224
225 BOOST_FIXTURE_TEST_CASE(ValidAvgWithPadTest, PoolingWithPadFixture)
226 {
227     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
228 }
229
230 struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
231 {
232     GlobalAvgFixture()
233     {
234         m_Prototext = R"(
235                    ir_version: 3
236                    producer_name:  "CNTK"
237                    producer_version:  "2.5.1"
238                    domain:  "ai.cntk"
239                    model_version: 1
240                    graph {
241                      name:  "CNTKGraph"
242                      input {
243                         name: "Input"
244                         type {
245                           tensor_type {
246                             elem_type: FLOAT
247                             shape {
248                               dim {
249                                 dim_value: 1
250                               }
251                               dim {
252                                 dim_value: 2
253                               }
254                               dim {
255                                 dim_value: 2
256                               }
257                               dim {
258                                 dim_value: 2
259                               }
260                             }
261                           }
262                         }
263                       }
264                      node {
265                          input: "Input"
266                          output: "Output"
267                          name: "Pooling"
268                          op_type: "GlobalAveragePool"
269                       }
270                       output {
271                           name: "Output"
272                           type {
273                              tensor_type {
274                                elem_type: FLOAT
275                                shape {
276                                    dim {
277                                        dim_value: 1
278                                    }
279                                    dim {
280                                        dim_value: 2
281                                    }
282                                    dim {
283                                        dim_value: 1
284                                    }
285                                    dim {
286                                        dim_value: 1
287                                    }
288                                }
289                             }
290                         }
291                         }
292                     }
293                    opset_import {
294                       version: 7
295                     })";
296         Setup();
297     }
298 };
299
300 BOOST_FIXTURE_TEST_CASE(GlobalAvgTest, GlobalAvgFixture)
301 {
302     RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
303 }
304
305 BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeMaxPool, MaxPoolInvalidFixture)
306 {
307    BOOST_CHECK_THROW(Setup(), armnn::ParseException);
308 }
309
310 BOOST_AUTO_TEST_SUITE_END()