2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
// Shared fixture for the TensorFlow Concat parser tests.
// Builds a graph description in m_Prototext with two inputs ("graphInput0", "graphInput1")
// and a single output node named "concat", then hands it to Setup() with the given shapes.
//   inputShape0 / inputShape1 - tensor shapes registered for graphInput0 / graphInput1.
//   concatDim                 - concatenation axis; written into the prototxt as decimal text.
// NOTE(review): most of the prototxt text is not visible in this excerpt; the exact TF op
// emitted (presumably ConcatV2 with a Const axis input) should be confirmed against the full file.
12 struct ConcatFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
14 explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15 unsigned int concatDim)
70 m_Prototext += std::to_string(concatDim); // splice the axis value into the graph text
104 Setup({{"graphInput0", inputShape0 },
105 {"graphInput1", inputShape1 }}, {"concat"});
109 struct ConcatFixtureNCHW : ConcatFixture
111 ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
114 struct ConcatFixtureNHWC : ConcatFixture
116 ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
119 BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
121 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
122 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
123 {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}});
126 BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
128 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
129 {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
130 {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}});
133 struct ConcatFixtureDim1 : ConcatFixture
135 ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
138 struct ConcatFixtureDim3 : ConcatFixture
140 ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
143 BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
145 RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
146 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } },
147 { "graphInput1", { 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
148 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } },
149 { { "concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
150 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
151 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
152 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } });
155 BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
157 RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0,
159 8.0, 9.0, 10.0, 11.0,
160 12.0, 13.0, 14.0, 15.0,
161 16.0, 17.0, 18.0, 19.0,
162 20.0, 21.0, 22.0, 23.0 } },
163 { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
164 54.0, 55.0, 56.0, 57.0,
165 58.0, 59.0, 60.0, 61.0,
166 62.0, 63.0, 64.0, 65.0,
167 66.0, 67.0, 68.0, 69.0,
168 70.0, 71.0, 72.0, 73.0 } } },
169 { { "concat", { 0.0, 1.0, 2.0, 3.0,
170 50.0, 51.0, 52.0, 53.0,
172 54.0, 55.0, 56.0, 57.0,
173 8.0, 9.0, 10.0, 11.0,
174 58.0, 59.0, 60.0, 61.0,
175 12.0, 13.0, 14.0, 15.0,
176 62.0, 63.0, 64.0, 65.0,
177 16.0, 17.0, 18.0, 19.0,
178 66.0, 67.0, 68.0, 69.0,
179 20.0, 21.0, 22.0, 23.0,
180 70.0, 71.0, 72.0, 73.0 } } });
183 BOOST_AUTO_TEST_SUITE_END()