2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include <armnnDeserializer/IDeserializer.hpp>
12 BOOST_AUTO_TEST_SUITE(Deserializer)
14 struct GatherFixture : public ParserFlatbuffersSerializeFixture
16 explicit GatherFixture(const std::string& inputShape,
17 const std::string& indicesShape,
18 const std::string& input1Content,
19 const std::string& outputShape,
20 const std::string& axis,
21 const std::string dataType,
22 const std::string constDataType)
30 layer_type: "InputLayer",
36 layerName: "InputLayer",
40 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
45 dimensions: )" + inputShape + R"(,
46 dataType: )" + dataType + R"(
51 layer_type: "ConstantLayer",
55 layerName: "ConstantLayer",
56 layerType: "Constant",
60 dimensions: )" + indicesShape + R"(,
67 dimensions: )" + indicesShape + R"(,
68 dataType: )" + dataType + R"(
70 data_type: )" + constDataType + R"(,
72 data: )" + input1Content + R"(,
76 layer_type: "GatherLayer",
80 layerName: "GatherLayer",
85 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
89 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
94 dimensions: )" + outputShape + R"(,
95 dataType: )" + dataType + R"(
103 layer_type: "OutputLayer",
109 layerName: "OutputLayer",
113 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
118 dimensions: )" + outputShape + R"(,
119 dataType: )" + dataType + R"(
130 struct SimpleGatherFixtureFloat32 : GatherFixture
132 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
133 "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
136 BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
138 RunTest<4, armnn::DataType::Float32>(0,
139 {{"InputLayer", { 1, 2, 3,
145 {{"OutputLayer", { 7, 8, 9,
159 BOOST_AUTO_TEST_SUITE_END()