Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / Convolution2d.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
11
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14 struct Convolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
15 {
16     explicit Convolution2dFixture(const char* paddingType)
17     : Convolution2dFixture(paddingType, 1)
18     {}
19
20     // dilation: 0 - dilations attribute is not included;
21     // dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
22     explicit Convolution2dFixture(const char* paddingType, int stride, int dilation = 0)
23     {
24         std::string strideString = std::to_string(stride);
25         std::string dilationString = std::to_string(dilation);
26         m_Prototext = "node { \n"
27             "    name: \"graphInput\" \n"
28             "    op: \"Placeholder\" \n"
29             "    attr { \n"
30             "      key: \"dtype\" \n"
31             "      value { \n"
32             "        type: DT_FLOAT \n"
33             "      } \n"
34             "    } \n"
35             "    attr { \n"
36             "      key: \"shape\" \n"
37             "      value { \n"
38             "        shape { \n"
39             "        } \n"
40             "      } \n"
41             "    } \n"
42             "  } \n"
43             "  node { \n"
44             "  name: \"Const_1\" \n"
45             "  op: \"Const\" \n"
46             "  attr { \n"
47             "    key: \"dtype\" \n"
48             "    value { \n"
49             "      type: DT_FLOAT \n"
50             "    } \n"
51             "  } \n"
52             "  attr { \n"
53             "    key: \"value\" \n"
54             "    value { \n"
55             "      tensor { \n"
56             "        dtype: DT_FLOAT \n"
57             "        tensor_shape { \n"
58             "          dim { \n"
59             "            size: 1 \n"
60             "          } \n"
61             "          dim { \n"
62             "            size: 3 \n"
63             "          } \n"
64             "          dim { \n"
65             "            size: 1 \n"
66             "          } \n"
67             "          dim { \n"
68             "            size: 1 \n"
69             "          } \n"
70             "        } \n"
71             "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
72             "      } \n"
73             "    } \n"
74             "  } \n"
75             "} \n"
76             "node { \n"
77             "  name: \"potato\" \n"
78             "  op: \"Conv2D\" \n"
79             "  input: \"graphInput\" \n"
80             "  input: \"Const_1\" \n"
81             "  attr { \n"
82             "    key: \"T\" \n"
83             "    value { \n"
84             "      type: DT_FLOAT \n"
85             "    } \n"
86             "  } \n"
87             "  attr { \n"
88             "    key: \"data_format\" \n"
89             "    value { \n"
90             "      s: \"NHWC\" \n"
91             "    } \n"
92             "  } \n"
93             "  attr { \n"
94             "    key: \"padding\" \n"
95             "    value { \n"
96             "      s: \"";
97         m_Prototext.append(paddingType);
98         m_Prototext.append("\"\n"
99                            "    } \n"
100                            "  } \n"
101                            "  attr { \n"
102                            "    key: \"strides\" \n"
103                            "    value { \n"
104                            "      list { \n"
105                            "        i: 1 \n"
106                            "        i: 1 \n"
107                            "        i: ");
108         m_Prototext.append(strideString);
109         m_Prototext.append(" \n"
110                            "        i: 1 \n"
111                            "      } \n"
112                            "    } \n"
113                            "  } \n");
114
115         if (dilation > 0)
116         {
117             m_Prototext.append("  attr { \n"
118                                "    key: \"dilations\" \n"
119                                "    value { \n"
120                                "      list { \n"
121                                "        i: 1 \n"
122                                "        i: ");
123             m_Prototext.append(dilationString);
124             m_Prototext.append(" \n"
125                                "        i: ");
126             m_Prototext.append(dilationString);
127             m_Prototext.append(" \n"
128                                "        i: 1 \n"
129                                "      } \n"
130                                "    } \n"
131                                "  } \n");
132         }
133         m_Prototext.append("  attr { \n"
134                            "    key: \"use_cudnn_on_gpu\" \n"
135                            "    value { \n"
136                            "      b: false \n"
137                            "    } \n"
138                            "  } \n"
139                            "} \n");
140
141         // Manual height computation based on stride parameter.
142         BOOST_ASSERT_MSG(stride == 1 || stride==2, "Add support for strides other than 1 or 2.");
143         unsigned int dims[] = {1,2,3,1};
144         if (stride == 2)
145         {
146             dims[1]=3;
147         }
148
149         SetupSingleInputSingleOutput(armnn::TensorShape(4, dims), "graphInput", "potato");
150     }
151 };
152
153
154 struct Convolution2dSameFixture : Convolution2dFixture
155 {
156     Convolution2dSameFixture() : Convolution2dFixture("SAME", 1){}
157 };
158 BOOST_FIXTURE_TEST_CASE(ParseConv2DSame, Convolution2dSameFixture)
159 {
160     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
161 }
162
163 struct Convolution2dValidFixture : Convolution2dFixture
164 {
165     Convolution2dValidFixture() : Convolution2dFixture("VALID", 1){}
166 };
167 BOOST_FIXTURE_TEST_CASE(ParseConv2DValid, Convolution2dValidFixture)
168 {
169     RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
170 }
171
172
173 struct Convolution2dStride2SameFixture : Convolution2dFixture
174 {
175     Convolution2dStride2SameFixture() : Convolution2dFixture("SAME", 2){}
176 };
177 BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Same, Convolution2dStride2SameFixture)
178 {
179     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
180 }
181
182
183 struct Convolution2dStride2ValidFixture : Convolution2dFixture
184 {
185     Convolution2dStride2ValidFixture() : Convolution2dFixture("VALID", 2){}
186 };
187 BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Valid, Convolution2dStride2ValidFixture)
188 {
189     RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
190 }
191
192
193 struct Convolution2dDilation1Fixture : Convolution2dFixture
194 {
195     Convolution2dDilation1Fixture() : Convolution2dFixture("SAME", 1, 1){}
196 };
197 BOOST_FIXTURE_TEST_CASE(ParseConv2DDilation1, Convolution2dDilation1Fixture)
198 {
199     RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
200 }
201
202 BOOST_AUTO_TEST_CASE(ParseConv2DDilation2)
203 {
204     const char* prototext = ""
205         "node {\n"
206         "  name: \"graphInput\"\n"
207         "  op: \"Placeholder\"\n"
208         "  attr {\n"
209         "    key: \"dtype\"\n"
210         "    value {\n"
211         "      type: DT_FLOAT\n"
212         "    }\n"
213         "  }\n"
214         "  attr {\n"
215         "    key: \"shape\"\n"
216         "    value {\n"
217         "      shape {\n"
218         "      }\n"
219         "    }\n"
220         "  }\n"
221         "}\n"
222         "node {\n"
223         "  name: \"Const_1\"\n"
224         "  op: \"Const\"\n"
225         "  attr {\n"
226         "    key: \"dtype\"\n"
227         "    value {\n"
228         "      type: DT_FLOAT\n"
229         "    }\n"
230         "  }\n"
231         "  attr {\n"
232         "    key: \"value\"\n"
233         "    value {\n"
234         "      tensor {\n"
235         "        dtype: DT_FLOAT\n"
236         "        tensor_shape {\n"
237         "          dim {\n"
238         "            size: 1\n"
239         "          }\n"
240         "          dim {\n"
241         "            size: 3\n"
242         "          }\n"
243         "          dim {\n"
244         "            size: 1\n"
245         "          }\n"
246         "          dim {\n"
247         "            size: 1\n"
248         "          }\n"
249         "        }\n"
250         "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
251         "      }\n"
252         "    }\n"
253         "  }\n"
254         "}\n"
255         "node {\n"
256         "  name: \"potato\"\n"
257         "  op: \"Conv2D\"\n"
258         "  input: \"graphInput\"\n"
259         "  input: \"Const_1\"\n"
260         "  attr {\n"
261         "    key: \"T\"\n"
262         "    value {\n"
263         "      type: DT_FLOAT\n"
264         "    }\n"
265         "  }\n"
266         "  attr {\n"
267         "    key: \"data_format\"\n"
268         "    value {\n"
269         "      s: \"NHWC\"\n"
270         "    }\n"
271         "  }\n"
272         "  attr {\n"
273         "    key: \"padding\"\n"
274         "    value {\n"
275         "      s: \"SAME\"\n"
276         "    }\n"
277         "  }\n"
278         "  attr {\n"
279         "    key: \"strides\"\n"
280         "    value {\n"
281         "      list {\n"
282         "        i: 1\n"
283         "        i: 1\n"
284         "        i: 1\n"
285         "        i: 1\n"
286         "      }\n"
287         "    }\n"
288         "  }\n"
289         "  attr {\n"
290         "    key: \"dilations\"\n"
291         "    value {\n"
292         "      list {\n"
293         "        i: 1\n"
294         "        i: 2\n"
295         "        i: 2\n"
296         "        i: 1\n"
297         "      }\n"
298         "    }\n"
299         "  }\n"
300         "  attr {\n"
301         "    key: \"use_cudnn_on_gpu\"\n"
302         "    value {\n"
303         "      b: false\n"
304         "    }\n"
305         "  }\n"
306         "}\n";
307
308     std::map<std::string, armnn::TensorShape> inputShapes;
309     armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
310     inputShapes["graphInput"] = tensorShape;
311     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
312     BOOST_CHECK_EXCEPTION(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }),
313                           armnn::ParseException,
314                           [] (armnn::ParseException const& ex)->bool
315                           {
316                                 return strcmp(ex.what(),
317                                               "ArmNN only supports Convolution layers with dilations [1,1,1,1]") == 0;
318                           });
319 }
320
321
322 BOOST_AUTO_TEST_SUITE_END()