2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
// Opens the Boost.Test suite named "Deserializer"; the test cases below are registered in it.
13 BOOST_AUTO_TEST_SUITE(Deserializer)
/// Test fixture for a Gather network, derived from ParserFlatbuffersSerializeFixture.
///
/// The constructor splices its string arguments into a textual network description
/// that lists four layers: "InputLayer", "ConstantLayer", "GatherLayer" and "OutputLayer".
/// (The opening of the string literal and what it is assigned to are not visible in this
/// excerpt - presumably a member of the base fixture; confirm against the full file.)
///
/// Wiring visible in the description:
///   - GatherLayer has two connections, from sourceLayerIndex 0 and sourceLayerIndex 1.
///   - OutputLayer has one connection, from sourceLayerIndex 2.
///
/// @param inputShape     Text substituted for the `dimensions:` field in the InputLayer section.
/// @param indicesShape   Text substituted for both `dimensions:` fields in the ConstantLayer section.
/// @param input1Content  Text substituted for the `data:` field in the ConstantLayer section.
/// @param outputShape    Text substituted for the `dimensions:` fields in the GatherLayer and
///                       OutputLayer sections.
/// @param axis           Text substituted for the `axis:` field in the GatherLayer section.
/// @param dataType       Text substituted for every `dataType:` field (input, constant, gather, output).
/// @param constDataType  Text substituted for the `data_type:` field in the ConstantLayer section.
///
/// NOTE(review): `dataType` and `constDataType` are taken by value (`const std::string`) while
/// the other parameters are `const std::string&`; a const reference would avoid a copy and
/// match the rest of the signature.
/// NOTE(review): the constant tensor's `dataType:` is filled from the shared `dataType` argument
/// rather than from a type of its own - looks unintended for an indices tensor; confirm against
/// the schema and the full file.
15 struct GatherFixture : public ParserFlatbuffersSerializeFixture
17 explicit GatherFixture(const std::string& inputShape,
18 const std::string& indicesShape,
19 const std::string& input1Content,
20 const std::string& outputShape,
21 const std::string& axis,
22 const std::string dataType,
23 const std::string constDataType)
31 layer_type: "InputLayer",
37 layerName: "InputLayer",
41 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
46 dimensions: )" + inputShape + R"(,
47 dataType: )" + dataType + R"(
52 layer_type: "ConstantLayer",
56 layerName: "ConstantLayer",
57 layerType: "Constant",
61 dimensions: )" + indicesShape + R"(,
68 dimensions: )" + indicesShape + R"(,
69 dataType: )" + dataType + R"(
71 data_type: )" + constDataType + R"(,
73 data: )" + input1Content + R"(,
77 layer_type: "GatherLayer",
81 layerName: "GatherLayer",
86 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
90 connection: {sourceLayerIndex:1, outputSlotIndex:0 }
95 dimensions: )" + outputShape + R"(,
96 dataType: )" + dataType + R"(
100 axis: )" + axis + R"(
104 layer_type: "OutputLayer",
110 layerName: "OutputLayer",
114 connection: {sourceLayerIndex:2, outputSlotIndex:0 },
119 dimensions: )" + outputShape + R"(,
120 dataType: )" + dataType + R"(
/// Concrete GatherFixture configuration used by the GatherFloat32 test case.
///
/// Arguments forwarded to GatherFixture, in positional order:
///   inputShape    = "[ 3, 2, 3 ]"
///   indicesShape  = "[ 2, 3 ]"
///   input1Content = "[1, 2, 1, 2, 1, 0]"  (six values, matching the 2x3 indices shape)
///   outputShape   = "[ 2, 3, 2, 3 ]"
///   axis          = "0"
///   dataType      = "Float32"
///   constDataType = "IntData"
131 struct SimpleGatherFixtureFloat32 : GatherFixture
133 SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
134 "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
// Test case GatherFloat32, run with the SimpleGatherFixtureFloat32 fixture.
// Calls RunTest<4, armnn::DataType::Float32> with a first argument of 0 and two maps of
// values keyed by layer name: one for "InputLayer" and one for "OutputLayer".
// Presumably these are the input data and the expected output, and the template argument 4
// is the tensor rank - RunTest is declared outside this excerpt, so confirm against the
// base fixture. Most of the tensor data rows are not visible in this excerpt.
137 BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
139 RunTest<4, armnn::DataType::Float32>(0,
140 {{"InputLayer", { 1, 2, 3,
146 {{"OutputLayer", { 7, 8, 9,
// Closes the "Deserializer" Boost.Test suite opened near the top of this file.
160 BOOST_AUTO_TEST_SUITE_END()