Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / Concat.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 struct ConcatFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14     explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15                            unsigned int concatDim)
16     {
17         m_Prototext = R"(
18         node {
19           name: "graphInput0"
20           op: "Placeholder"
21           attr {
22             key: "dtype"
23             value {
24               type: DT_FLOAT
25             }
26           }
27           attr {
28             key: "shape"
29             value {
30               shape {
31               }
32             }
33           }
34         }
35         node {
36           name: "graphInput1"
37           op: "Placeholder"
38           attr {
39             key: "dtype"
40             value {
41               type: DT_FLOAT
42             }
43           }
44           attr {
45             key: "shape"
46             value {
47               shape {
48               }
49             }
50           }
51         }
52         node {
53           name: "concat/axis"
54           op: "Const"
55           attr {
56             key: "dtype"
57             value {
58               type: DT_INT32
59             }
60           }
61           attr {
62             key: "value"
63             value {
64               tensor {
65                 dtype: DT_INT32
66                 tensor_shape {
67                 }
68                 int_val: )";
69
70         m_Prototext += std::to_string(concatDim);
71
72         m_Prototext += R"(
73               }
74             }
75           }
76         }
77         node {
78           name: "concat"
79           op: "ConcatV2"
80           input: "graphInput0"
81           input: "graphInput1"
82           input: "concat/axis"
83           attr {
84             key: "N"
85             value {
86               i: 2
87             }
88           }
89           attr {
90             key: "T"
91             value {
92               type: DT_FLOAT
93             }
94           }
95           attr {
96             key: "Tidx"
97             value {
98               type: DT_FLOAT
99             }
100           }
101         }
102         )";
103
104         Setup({{"graphInput0", inputShape0 },
105                {"graphInput1", inputShape1 }}, {"concat"});
106     }
107 };
108
109 struct ConcatFixtureNCHW : ConcatFixture
110 {
111     ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
112 };
113
114 struct ConcatFixtureNHWC : ConcatFixture
115 {
116     ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
117 };
118
119 BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
120 {
121     RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
122                 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
123                {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}});
124 }
125
126 BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
127 {
128     RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
129                 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
130                {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}});
131 }
132
133 struct ConcatFixtureDim1 : ConcatFixture
134 {
135     ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
136 };
137
138 struct ConcatFixtureDim3 : ConcatFixture
139 {
140     ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
141 };
142
143 BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
144 {
145     RunTest<4>({ { "graphInput0", {  0.0,  1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0, 10.0, 11.0,
146                                      12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } },
147                  { "graphInput1", {  50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
148                                      62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } },
149                { { "concat",      {  0.0,  1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0, 10.0, 11.0,
150                                      12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
151                                      50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
152                                      62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } });
153 }
154
155 BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
156 {
157     RunTest<4>({ { "graphInput0", {  0.0, 1.0, 2.0, 3.0,
158                                      4.0, 5.0, 6.0, 7.0,
159                                      8.0, 9.0, 10.0, 11.0,
160                                      12.0, 13.0, 14.0, 15.0,
161                                      16.0, 17.0, 18.0, 19.0,
162                                      20.0, 21.0, 22.0, 23.0 } },
163                  { "graphInput1", {  50.0, 51.0, 52.0, 53.0,
164                                      54.0, 55.0, 56.0, 57.0,
165                                      58.0, 59.0, 60.0, 61.0,
166                                      62.0, 63.0, 64.0, 65.0,
167                                      66.0, 67.0, 68.0, 69.0,
168                                      70.0, 71.0, 72.0, 73.0 } } },
169                { { "concat",      {  0.0,  1.0,  2.0,  3.0,
170                                      50.0, 51.0, 52.0, 53.0,
171                                      4.0,  5.0,  6.0,  7.0,
172                                      54.0, 55.0, 56.0, 57.0,
173                                      8.0,  9.0,  10.0, 11.0,
174                                      58.0, 59.0, 60.0, 61.0,
175                                      12.0, 13.0, 14.0, 15.0,
176                                      62.0, 63.0, 64.0, 65.0,
177                                      16.0, 17.0, 18.0, 19.0,
178                                      66.0, 67.0, 68.0, 69.0,
179                                      20.0, 21.0, 22.0, 23.0,
180                                      70.0, 71.0, 72.0, 73.0 } } });
181 }
182
183 BOOST_AUTO_TEST_SUITE_END()