6 #include <boost/test/unit_test.hpp> 8 #include "../Deserializer.hpp" 16 explicit StridedSliceFixture(
const std::string& inputShape,
17 const std::string& begin,
18 const std::string& end,
19 const std::string& stride,
20 const std::string& beginMask,
21 const std::string& endMask,
22 const std::string& shrinkAxisMask,
23 const std::string& ellipsisMask,
24 const std::string& newAxisMask,
25 const std::string& dataLayout,
26 const std::string& outputShape,
27 const std::string& dataType)
35 layer_type: "InputLayer", 41 layerName: "InputLayer", 45 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 50 dimensions: )" + inputShape + R"(, 51 dataType: )" + dataType + R"( 59 layer_type: "StridedSliceLayer", 63 layerName: "StridedSliceLayer", 64 layerType: "StridedSlice", 67 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 72 dimensions: )" + outputShape + R"(, 73 dataType: )" + dataType + R"( 78 begin: )" + begin + R"(, 80 stride: )" + stride + R"(, 81 beginMask: )" + beginMask + R"(, 82 endMask: )" + endMask + R"(, 83 shrinkAxisMask: )" + shrinkAxisMask + R"(, 84 ellipsisMask: )" + ellipsisMask + R"(, 85 newAxisMask: )" + newAxisMask + R"(, 86 dataLayout: )" + dataLayout + R"(, 91 layer_type: "OutputLayer", 97 layerName: "OutputLayer", 101 connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 106 dimensions: )" + outputShape + R"(, 107 dataType: )" + dataType + R"( 121 struct SimpleStridedSliceFixture : StridedSliceFixture
123 SimpleStridedSliceFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
139 RunTest<4, armnn::DataType::Float32>(0,
141 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
142 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
143 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
146 1.0f, 1.0f, 5.0f, 5.0f
150 struct StridedSliceMaskFixture : StridedSliceFixture
152 StridedSliceMaskFixture() : StridedSliceFixture(
"[ 3, 2, 3, 1 ]",
168 RunTest<4, armnn::DataType::Float32>(0,
170 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
171 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
172 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
175 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
176 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)
BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)