IVGCVSW-4531 Fix for failing strided slice NNT/VTS tests on Android R
authorDavid Monahan <david.monahan@arm.com>
Fri, 13 Mar 2020 07:52:54 +0000 (07:52 +0000)
committerJim Flynn <jim.flynn@arm.com>
Fri, 13 Mar 2020 19:09:07 +0000 (19:09 +0000)
Signed-off-by: David Monahan <david.monahan@arm.com>
Change-Id: I7f6932f0d21b5678ab9822b0fc69f589bbbb07e5

src/armnn/layers/StridedSliceLayer.cpp
src/backends/backendsCommon/test/EndToEndTestImpl.hpp
src/backends/cl/test/ClEndToEndTests.cpp
src/backends/neon/test/NeonEndToEndTests.cpp
src/backends/reference/test/RefEndToEndTests.cpp

index b9c337188cfed08cef6b7d9e23fe26dfe0b60360..dd4f942086f5436e156ef171a5993f465dd6cbce 100644 (file)
@@ -52,15 +52,21 @@ std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
 
     for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
     {
+        int stride = m_Param.m_Stride[i];
+        int start = m_Param.GetStartForAxis(inputShape, i);
+        int stop = m_Param.GetStopForAxis(inputShape, i, start);
+
         if (m_Param.m_ShrinkAxisMask & (1 << i))
         {
+            // Don't take a slice from an axis being shrunk
+            if (m_Param.m_End[i] >= 2)
+            {
+                throw LayerValidationException(
+                    "StridedSlice: Attempting to take slice from an axis being shrunk");
+            }
             continue;
         }
 
-        int stride = m_Param.m_Stride[i];
-        int start = m_Param.GetStartForAxis(inputShape, i);
-        int stop = m_Param.GetStopForAxis(inputShape, i, start);
-
         int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
                                    ((start - stop) - stride - 1) / -stride;
 
index 358f4e3fc288e6885d53a2296be33b7ac8df8014..4221f626da6cfcf7afb6b2cad18876229b9670b4 100644 (file)
@@ -766,4 +766,40 @@ inline void ExportOutputWithSeveralOutputSlotConnectionsTest(std::vector<Backend
     BOOST_TEST(found != std::string::npos);
 }
 
+inline void StridedSliceInvalidSliceEndToEndTest(std::vector<BackendId> backends)
+{
+    using namespace armnn;
+
+    // Create runtime in which test will run
+    IRuntime::CreationOptions options;
+    IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+    // build up the structure of the network
+    INetworkPtr net(INetwork::Create());
+
+    IConnectableLayer* input = net->AddInputLayer(0);
+
+    // Configure a strided slice with a stride the same size as the input but with a ShrinkAxisMask on the first
+    // dim of the output to make it too small to hold the specified slice.
+    StridedSliceDescriptor descriptor;
+    descriptor.m_Begin          = {0, 0};
+    descriptor.m_End            = {2, 3};
+    descriptor.m_Stride         = {1, 1};
+    descriptor.m_BeginMask      = 0;
+    descriptor.m_EndMask        = 0;
+    descriptor.m_ShrinkAxisMask = 1;
+    IConnectableLayer* stridedSlice = net->AddStridedSliceLayer(descriptor);
+
+    IConnectableLayer* output0 = net->AddOutputLayer(0);
+
+    input->GetOutputSlot(0).Connect(stridedSlice->GetInputSlot(0));
+    stridedSlice->GetOutputSlot(0).Connect(output0->GetInputSlot(0));
+
+    input->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 2, 3 }, DataType::Float32));
+    stridedSlice->GetOutputSlot(0).SetTensorInfo(TensorInfo({ 3 }, DataType::Float32));
+
+    // Attempt to optimize the network and check that the correct exception is thrown
+    BOOST_CHECK_THROW(Optimize(*net, backends, runtime->GetDeviceSpec()), armnn::LayerValidationException);
+}
+
 } // anonymous namespace
index 10abcc9fe6d9eb954581d52c47963063daa71dbd..05f9538d6ff7fd92f223609d025b90fd2aaa5037 100644 (file)
@@ -126,6 +126,11 @@ BOOST_AUTO_TEST_CASE(DequantizeEndToEndOffsetTest)
     DequantizeEndToEndOffset<armnn::DataType::QAsymmU8>(defaultBackends);
 }
 
+BOOST_AUTO_TEST_CASE(ClStridedSliceInvalidSliceEndToEndTest)
+{
+    StridedSliceInvalidSliceEndToEndTest(defaultBackends);
+}
+
 BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndTest)
 {
     const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0,  1, 1, 1, 1,
index abded6491517c75edf4fce4b77914512b6bb8016..081b8af60acb3789ef23c9dcea38b4333ceefc81 100644 (file)
@@ -513,6 +513,11 @@ BOOST_AUTO_TEST_CASE(NeonArgMinAxis3TestQuantisedAsymm8)
     ArgMinAxis3EndToEnd<armnn::DataType::QAsymmU8>(defaultBackends);
 }
 
+BOOST_AUTO_TEST_CASE(NeonStridedSliceInvalidSliceEndToEndTest)
+{
+    StridedSliceInvalidSliceEndToEndTest(defaultBackends);
+}
+
 BOOST_AUTO_TEST_CASE(NeonDetectionPostProcessRegularNmsTest, * boost::unit_test::disabled())
 {
     std::vector<float> boxEncodings({
index 54a68810f6fd33e8e36c5fee7e6aab69acbde028..bdda12f392b87425c69b29a570b3ec3ca81d42e2 100644 (file)
@@ -1210,6 +1210,11 @@ BOOST_AUTO_TEST_CASE(RefExportOutputWithSeveralOutputSlotConnectionsTest)
     ExportOutputWithSeveralOutputSlotConnectionsTest(defaultBackends);
 }
 
+BOOST_AUTO_TEST_CASE(RefStridedSliceInvalidSliceEndToEndTest)
+{
+    StridedSliceInvalidSliceEndToEndTest(defaultBackends);
+}
+
 #endif
 
 BOOST_AUTO_TEST_SUITE_END()