// Release 18.03
// [platform/upstream/armnn.git] / src / armnnTfParser / test / DepthwiseConvolution2d.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9 #include <string>
10 #include <iostream>
11
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14 struct DepthwiseConvolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
15 {
16     explicit DepthwiseConvolution2dFixture(const char* paddingType)
17     {
18         m_Prototext = "node { \n"
19                       "    name: \"graphInput\" \n"
20                       "    op: \"Placeholder\" \n"
21                       "    attr { \n"
22                       "      key: \"dtype\" \n"
23                       "      value { \n"
24                       "        type: DT_FLOAT \n"
25                       "      } \n"
26                       "    } \n"
27                       "    attr { \n"
28                       "      key: \"value\" \n"
29                       "      value { \n"
30                       "        tensor { \n"
31                       "          dtype: DT_FLOAT \n"
32                       "          tensor_shape { \n"
33                       "            dim { \n"
34                       "              size: 1 \n"
35                       "            } \n"
36                       "            dim { \n"
37                       "              size: 1 \n"
38                       "            } \n"
39                       "            dim { \n"
40                       "              size: 3 \n"
41                       "            } \n"
42                       "            dim { \n"
43                       "              size: 3 \n"
44                       "            } \n"
45                       "          } \n"
46                       "          tensor_content: \"\\000\\000\\200?\\000\\000\\000@\\000\\000@@\\000\\000\\200@"
47                       "\\000\\000\\240@\\000\\000\\300@\\000\\000\\340@\\000\\000\\000A\\000\\000\\020A\" \n"
48                       "        } \n"
49                       "      } \n"
50                       "    } \n"
51                       "  } \n"
52                       "  node { \n"
53                       "  name: \"Const_1\" \n"
54                       "  op: \"Const\" \n"
55                       "  attr { \n"
56                       "    key: \"dtype\" \n"
57                       "    value { \n"
58                       "      type: DT_FLOAT \n"
59                       "    } \n"
60                       "  } \n"
61                       "  attr { \n"
62                       "    key: \"value\" \n"
63                       "    value { \n"
64                       "      tensor { \n"
65                       "        dtype: DT_FLOAT \n"
66                       "        tensor_shape { \n"
67                       "          dim { \n"
68                       "            size: 1 \n"
69                       "          } \n"
70                       "          dim { \n"
71                       "            size: 3 \n"
72                       "          } \n"
73                       "          dim { \n"
74                       "            size: 3 \n"
75                       "          } \n"
76                       "          dim { \n"
77                       "            size: 3 \n"
78                       "          } \n"
79                       "        } \n"
80                       "        tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
81                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
82                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
83                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
84                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
85                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
86                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
87                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
88                       "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
89                       "      } \n"
90                       "    } \n"
91                       "  } \n"
92                       "} \n"
93                       "node { \n"
94                       "  name: \"potato\" \n"
95                       "  op: \"DepthwiseConv2dNative\" \n"
96                       "  input: \"graphInput\" \n"
97                       "  input: \"Const_1\" \n"
98                       "  attr { \n"
99                       "    key: \"T\" \n"
100                       "    value { \n"
101                       "      type: DT_FLOAT \n"
102                       "    } \n"
103                       "  } \n"
104                       "  attr { \n"
105                       "    key: \"data_format\" \n"
106                       "    value { \n"
107                       "      s: \"NHWC\" \n"
108                       "    } \n"
109                       "  } \n"
110                       "  attr { \n"
111                       "    key: \"padding\" \n"
112                       "    value { \n"
113                       "      s: \"";
114         m_Prototext.append(paddingType);
115         m_Prototext.append("\"\n"
116                       "    } \n"
117                       "  } \n"
118                       "  attr { \n"
119                       "    key: \"strides\" \n"
120                       "    value { \n"
121                       "      list { \n"
122                       "        i: 1 \n"
123                       "        i: 1 \n"
124                       "        i: 1 \n"
125                       "        i: 1 \n"
126                       "      } \n"
127                       "    } \n"
128                       "  } \n"
129                       "  attr { \n"
130                       "    key: \"use_cudnn_on_gpu\" \n"
131                       "    value { \n"
132                       "      b: false \n"
133                       "    } \n"
134                       "  } \n"
135                       "} \n");
136
137         SetupSingleInputSingleOutput({ 1, 1, 3, 3 }, "graphInput", "potato");
138     }
139 };
140
141 struct DepthwiseConvolution2dSameFixture : DepthwiseConvolution2dFixture
142 {
143     DepthwiseConvolution2dSameFixture() : DepthwiseConvolution2dFixture("SAME") { }
144 };
145
146 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DSame, DepthwiseConvolution2dSameFixture)
147 {
148     RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
149                { 2.5f, 5.f,  2.5f, 3.5f, 7.f,  3.5f, 4.5f, 9.f,  4.5f,
150                  6.f,  12.f, 6.f,  7.5f, 15.f, 7.5f, 9.f,  18.f, 9.f,
151                  5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f});
152 }
153
154 struct DepthwiseConvolution2dValidFixture : DepthwiseConvolution2dFixture
155 {
156     DepthwiseConvolution2dValidFixture() : DepthwiseConvolution2dFixture("VALID") { }
157 };
158
159 BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DValid, DepthwiseConvolution2dValidFixture)
160 {
161     RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
162                { 6.f,  12.f, 6.f,  7.5f, 15.f, 7.5f, 9.f,  18.f, 9.f });  // output expected data
163 }
164
165
166 BOOST_AUTO_TEST_SUITE_END()