2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
12 BOOST_AUTO_TEST_SUITE(Deserializer)
// Test fixture that assembles the text of a three-layer graph
// (InputLayer -> StridedSliceLayer -> OutputLayer) by concatenating string
// fragments, then calls SetupSingleInputSingleOutput("InputLayer", "OutputLayer").
//
// Every constructor argument is spliced verbatim into the graph text, so each
// one must already be formatted the way that position in the text requires
// (the derived fixtures pass shapes as e.g. "[ 3, 2, 3, 1 ]").
//   inputShape  - written as `dimensions` for the InputLayer entry.
//   begin, stride, beginMask, endMask, shrinkAxisMask, ellipsisMask,
//   newAxisMask, dataLayout
//               - written to the fields of the same name in the
//                 StridedSliceLayer entry.
//   outputShape - written as `dimensions` for both the StridedSliceLayer and
//                 the OutputLayer entries.
//   dataType    - written as `dataType` next to each of the three
//                 `dimensions` entries.
// The StridedSliceLayer connection uses sourceLayerIndex 0 and the OutputLayer
// connection uses sourceLayerIndex 1.
// NOTE(review): `end` is accepted as a parameter, but no `end:` field fed from
// it appears next to `begin:` / `stride:` here - confirm it is serialized.
14 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
16 explicit StridedSliceFixture(const std::string& inputShape,
17 const std::string& begin,
18 const std::string& end,
19 const std::string& stride,
20 const std::string& beginMask,
21 const std::string& endMask,
22 const std::string& shrinkAxisMask,
23 const std::string& ellipsisMask,
24 const std::string& newAxisMask,
25 const std::string& dataLayout,
26 const std::string& outputShape,
27 const std::string& dataType)
35 layer_type: "InputLayer",
41 layerName: "InputLayer",
45 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
50 dimensions: )" + inputShape + R"(,
51 dataType: )" + dataType + R"(
59 layer_type: "StridedSliceLayer",
63 layerName: "StridedSliceLayer",
64 layerType: "StridedSlice",
67 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
72 dimensions: )" + outputShape + R"(,
73 dataType: )" + dataType + R"(
78 begin: )" + begin + R"(,
80 stride: )" + stride + R"(,
81 beginMask: )" + beginMask + R"(,
82 endMask: )" + endMask + R"(,
83 shrinkAxisMask: )" + shrinkAxisMask + R"(,
84 ellipsisMask: )" + ellipsisMask + R"(,
85 newAxisMask: )" + newAxisMask + R"(,
86 dataLayout: )" + dataLayout + R"(,
91 layer_type: "OutputLayer",
97 layerName: "OutputLayer",
101 connection: {sourceLayerIndex:1, outputSlotIndex:0 },
106 dimensions: )" + outputShape + R"(,
107 dataType: )" + dataType + R"(
117 SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
// Concrete StridedSliceFixture used by the SimpleStridedSliceFloat32 test.
// Its default constructor forwards fixed arguments to the base fixture; the
// first one sets the input shape to "[ 3, 2, 3, 1 ]" (18 elements).
121 struct SimpleStridedSliceFixture : StridedSliceFixture
123 SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
// Runs the SimpleStridedSliceFixture graph through RunTest with 4 dimensions
// and Float32 data. The first list holds 18 values, matching the element
// count of the fixture's [ 3, 2, 3, 1 ] input shape; the second list
// (1, 1, 5, 5) is presumably the expected output - confirm against RunTest.
137 BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)
139 RunTest<4, armnn::DataType::Float32>(0,
141 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
142 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
143 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
146 1.0f, 1.0f, 5.0f, 5.0f
// Concrete StridedSliceFixture used by the StridedSliceMaskFloat32 test.
// Its default constructor forwards fixed arguments to the base fixture; the
// first one sets the input shape to "[ 3, 2, 3, 1 ]" (18 elements).
150 struct StridedSliceMaskFixture : StridedSliceFixture
152 StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
// Runs the StridedSliceMaskFixture graph through RunTest with 4 dimensions
// and Float32 data. Both value lists contain the same 18 numbers, so the
// second list (presumably the expected output - confirm against RunTest)
// is identical to the first.
166 BOOST_FIXTURE_TEST_CASE(StridedSliceMaskFloat32, StridedSliceMaskFixture)
168 RunTest<4, armnn::DataType::Float32>(0,
170 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
171 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
172 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
175 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
176 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
181 BOOST_AUTO_TEST_SUITE_END()