IVGCVSW-3536 Add Axis parameter to reference Softmax implementation
author: Francis Murtagh <francis.murtagh@arm.com>
Tue, 23 Jul 2019 08:50:50 +0000 (09:50 +0100)
committer: Francis Murtagh <francis.murtagh@arm.com>
Tue, 23 Jul 2019 08:50:56 +0000 (09:50 +0100)
 * Add Axis parameter to Softmax Descriptor
 * Add new reference implementation for Softmax using Axis parameter
 * Add unit tests to cover each Axis

Change-Id: Iafac2275d2212337456f2b1b56b0f76f77fb9543
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
include/armnn/Descriptors.hpp
src/backends/backendsCommon/test/LayerTests.cpp
src/backends/backendsCommon/test/LayerTests.hpp
src/backends/backendsCommon/test/SoftmaxTestImpl.hpp
src/backends/reference/test/RefLayerTests.cpp
src/backends/reference/workloads/RefSoftmaxWorkload.cpp
src/backends/reference/workloads/Softmax.cpp
src/backends/reference/workloads/Softmax.hpp

index 377f070..9630d86 100644 (file)
@@ -49,9 +49,15 @@ struct PermuteDescriptor
 /// A SoftmaxDescriptor for the SoftmaxLayer.
 struct SoftmaxDescriptor
 {
-    SoftmaxDescriptor() : m_Beta(1.0f) {}
+    SoftmaxDescriptor()
+    : m_Beta(1.0f)
+    , m_Axis(-1)
+    {}
+
     /// Exponentiation value.
-    float              m_Beta;
+    float m_Beta;
+    /// Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed on.
+    int m_Axis;
 };
 
 /// @brief An OriginsDescriptor for the ConcatLayer.
index d6e0e87..b40a3f5 100644 (file)
@@ -77,6 +77,36 @@ static std::vector<float> ConvInput3x8x16({
 // 2-channel bias used by a number of Conv2d tests.
 static std::vector<float> Bias2({0, 2});
 
+struct Simple3dSoftmaxOutputData
+{
+    const std::vector<float> outputData =
+            {
+                0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+                0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
+            };
+
+    const armnn::TensorShape inputShape{ 1, 8, 1 };
+
+    const std::vector<float> inputData =
+            {
+                    0.f, 1.f, 0.f, 0.f,
+                    .5f, 0.f, 0.f, 0.f,
+            };
+};
+
+struct Simple4dSoftmaxData
+{
+    const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
+
+    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
+                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
+    const std::vector<float> inputData =
+            {
+                    0.f, 1.f, 0.f, 0.f,
+                    .5f, 0.f, 0.f, 0.f
+            };
+};
+
 // Helper function that returns either Bias2 or an empty vector depending on whether bias is enabled.
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 boost::multi_array<T, 1> GetBias2(bool biasEnabled, float qScale)
@@ -1647,12 +1677,117 @@ LayerTestResult<float,2> SimpleSoftmaxTest(
     return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
 }
 
+LayerTestResult<float,2> SimpleAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, axis);
+}
+
 LayerTestResult<float,3> Simple3dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             data.inputShape, data.outputData, data.inputData);
+}
+
+LayerTestResult<float,3> Simple3dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -3:
+    case 0:
+        {
+            inputShape = {5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 1:
+        {
+            inputShape = {2, 5, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+
+                            17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+        break;
+        }
+    case -1:
+    case 2:
+        {
+            inputShape = {2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    }
+
+    return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta,
+                                                             inputShape, outputData, inputData, axis);
 }
 
 LayerTestResult<float,4> Simple4dSoftmaxTest(
@@ -1660,7 +1795,167 @@ LayerTestResult<float,4> Simple4dSoftmaxTest(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, data.inputShape,
+                                                             data.outputData, data.inputData);
+}
+
+LayerTestResult<float,4> Simple4dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -4:
+    case 0:
+        {
+            inputShape = {5, 2, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f,
+                            16.0f, -2.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 14.0f, -4.0f,
+                            14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.643914213228014f,
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.236882800924671f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.087144312427294f,
+
+                            0.087144312427294f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f, 7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -3:
+    case 1:
+        {
+            inputShape = {2, 5, 2, 2};
+
+            inputData =
+                    {
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f,
+                            17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
+                            15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f,
+
+
+                            0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
+                            0.236882800924671f,
+                            0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
+                            0.087144312427294f,
+                            0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
+                            0.032058600957022f,
+                            0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    case -2:
+    case 2:
+        {
+        inputShape = {2, 2, 5, 2};
+
+        inputData =
+                {
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                };
+
+        outputData =
+                {
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f,
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f,
+
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f,
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f
+                };
+        break;
+        }
+    case -1:
+    case 3:
+        {
+            inputShape = {2, 2, 2, 5};
+
+            inputData =
+                    {
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
+                            17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                    };
+
+            outputData =
+                    {
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f,
+                            0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                            7.246299848982885e-08f
+                    };
+            break;
+        }
+    }
+
+    return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, beta, inputShape,
+                                                             outputData, inputData, axis);
 }
 
 LayerTestResult<uint8_t,2> SimpleSoftmaxUint8Test(
@@ -1676,7 +1971,9 @@ LayerTestResult<uint8_t,3> Simple3dSoftmaxUint8Test(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
@@ -1684,7 +1981,10 @@ LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,2> SimpleSoftmaxUint16Test(
@@ -1700,7 +2000,9 @@ LayerTestResult<int16_t,3> Simple3dSoftmaxUint16Test(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple3dSoftmaxOutputData data;
+    return Simple3dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
@@ -1708,7 +2010,10 @@ LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta)
 {
-    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta);
+    Simple4dSoftmaxData data;
+
+    return Simple4dSoftmaxTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager, beta,
+                                                                     data.inputShape, data.outputData, data.inputData);
 }
 
 LayerTestResult<float,4> CompareNormalizationTest(
index d99e3b4..913c3a6 100644 (file)
@@ -472,16 +472,34 @@ LayerTestResult<float, 2> SimpleSoftmaxTest(
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     float beta);
 
+LayerTestResult<float, 2> SimpleAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<float, 3> Simple3dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta);
 
+LayerTestResult<float, 3> Simple3dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<float, 4> Simple4dSoftmaxTest(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         float beta);
 
+LayerTestResult<float, 4> Simple4dAxisSoftmaxTest(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis);
+
 LayerTestResult<uint8_t, 2> SimpleSoftmaxUint8Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
index 8081950..983a53b 100644 (file)
@@ -25,7 +25,9 @@ LayerTestResult<T, n> SimpleSoftmaxBaseTestImpl(
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     float beta,
     const armnn::TensorShape& inputShape,
-    const std::vector<float>& outputData)
+    const std::vector<float>& outputData,
+    const std::vector<float>& inputData,
+    int axis = 1)
 {
     using std::exp;
 
@@ -47,16 +49,14 @@ LayerTestResult<T, n> SimpleSoftmaxBaseTestImpl(
 
     // Each row is independently softmax'd.
     auto input = MakeTensor<T, n>(inputTensorInfo, std::vector<T>(
-        QuantizedVector<T>(qScale, qOffset, {
-            0.f, 1.f, 0.f, 0.f,
-            .5f, 0.f, 0.f, 0.f,
-        })));
+        QuantizedVector<T>(qScale, qOffset, inputData)));
 
     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
 
     armnn::SoftmaxQueueDescriptor data;
     data.m_Parameters.m_Beta = beta;
+    data.m_Parameters.m_Axis = axis;
 
     armnn::WorkloadInfo info;
     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
@@ -100,33 +100,98 @@ LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
     const std::vector<float> outputData = { x0[0] / sum0, x0[1] / sum0, x0[2] / sum0, x0[3] / sum0,
                                             x1[0] / sum1, x1[1] / sum1, x1[2] / sum1, x1[3] / sum1 };
 
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    const std::vector<float> inputData =
+            {
+                0.f, 1.f, 0.f, 0.f,
+                .5f, 0.f, 0.f, 0.f,
+            };
+
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData);
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
+        armnn::IWorkloadFactory& workloadFactory,
+        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+        float beta,
+        int axis)
+{
+    armnn::TensorShape inputShape;
+    std::vector<float> inputData;
+    std::vector<float> outputData;
+    switch (axis)
+    {
+    case -2:
+    case 0:
+        {
+        inputShape = {5, 2};
+
+        inputData =
+                {
+                        17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
+                };
+
+        outputData =
+                {
+                        0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
+                        0.087144312427294f,
+                        0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
+                        7.246299848982885e-08f
+                };
+        break;
+        }
+    case -1:
+    case 1:
+        {
+        inputShape = {2, 5};
+
+        inputData =
+                {
+                        17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
+                };
+
+        outputData =
+                {
+                        0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                        7.246299848982885e-08f,
+                        0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
+                        7.246299848982885e-08f
+                };
+        break;
+        }
+    }
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 3> Simple3dSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-    float beta)
+    float beta,
+    const armnn::TensorShape& inputShape,
+    const std::vector<float>& outputData,
+    const std::vector<float>& inputData,
+    int axis = 1)
 {
-    const armnn::TensorShape inputShape{ 1, 8, 1 };
-    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
-                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
-
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> Simple4dSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
-    float beta)
+    float beta,
+    const armnn::TensorShape& inputShape,
+    const std::vector<float>& outputData,
+    const std::vector<float>& inputData,
+    int axis = 1)
 {
-    const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
-    const std::vector<float> outputData = { 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
-                                            0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f };
 
-    return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, beta, inputShape, outputData);
+    return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, beta,
+                                                   inputShape, outputData, inputData, axis);
 }
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
index 8af42ea..5cb8042 100644 (file)
@@ -374,6 +374,30 @@ ARMNN_AUTO_TEST_CASE(SimpleSoftmaxUint16, SimpleSoftmaxUint16Test, 1.0f)
 ARMNN_AUTO_TEST_CASE(Simple3dSoftmaxUint16, Simple3dSoftmaxUint16Test, 1.0f)
 ARMNN_AUTO_TEST_CASE(Simple4dSoftmaxUint16, Simple4dSoftmaxUint16Test, 1.0f)
 
+ARMNN_AUTO_TEST_CASE(Simple2dAxis0Softmax, SimpleAxisSoftmaxTest, 1.0f, 0)
+ARMNN_AUTO_TEST_CASE(Simple2dAxis1Softmax, SimpleAxisSoftmaxTest, 1.0f, 1)
+
+ARMNN_AUTO_TEST_CASE(Simple2dAxis0NegSoftmax, SimpleAxisSoftmaxTest, 1.0f, -2)
+ARMNN_AUTO_TEST_CASE(Simple2dAxis1NegSoftmax, SimpleAxisSoftmaxTest, 1.0f, -1)
+
+ARMNN_AUTO_TEST_CASE(Simple3dAxis0Softmax, Simple3dAxisSoftmaxTest, 1.0f, 0)
+ARMNN_AUTO_TEST_CASE(Simple3dAxis1Softmax, Simple3dAxisSoftmaxTest, 1.0f, 1)
+ARMNN_AUTO_TEST_CASE(Simple3dAxis2Softmax, Simple3dAxisSoftmaxTest, 1.0f, 2)
+
+ARMNN_AUTO_TEST_CASE(Simple3dAxis0NegSoftmax, Simple3dAxisSoftmaxTest, 1.0f, -3)
+ARMNN_AUTO_TEST_CASE(Simple3dAxis1NegSoftmax, Simple3dAxisSoftmaxTest, 1.0f, -2)
+ARMNN_AUTO_TEST_CASE(Simple3dAxis2NegSoftmax, Simple3dAxisSoftmaxTest, 1.0f, -1)
+
+ARMNN_AUTO_TEST_CASE(Simple4dAxis0Softmax, Simple4dAxisSoftmaxTest, 1.0f, 0)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis1Softmax, Simple4dAxisSoftmaxTest, 1.0f, 1)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis2Softmax, Simple4dAxisSoftmaxTest, 1.0f, 2)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis3Softmax, Simple4dAxisSoftmaxTest, 1.0f, 3)
+
+ARMNN_AUTO_TEST_CASE(Simple4dAxis0NegSoftmax, Simple4dAxisSoftmaxTest, 1.0f, -4)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis1NegSoftmax, Simple4dAxisSoftmaxTest, 1.0f, -3)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis2NegSoftmax, Simple4dAxisSoftmaxTest, 1.0f, -2)
+ARMNN_AUTO_TEST_CASE(Simple4dAxis3NegSoftmax, Simple4dAxisSoftmaxTest, 1.0f, -1)
+
 // Sigmoid Activation
 ARMNN_AUTO_TEST_CASE(SimpleSigmoid, SimpleSigmoidTest)
 ARMNN_AUTO_TEST_CASE(SimpleSigmoidUint8, SimpleSigmoidUint8Test)
index b176667..0f6f837 100644 (file)
@@ -34,6 +34,7 @@ void RefSoftmaxWorkload::Execute() const
     Softmax(decoder,
             encoder,
             inputTensorInfo,
-            m_Data.m_Parameters.m_Beta);
+            m_Data.m_Parameters.m_Beta,
+            m_Data.m_Parameters.m_Axis);
 }
 } //namespace armnn
index 6cb219a..ec4fdb8 100644 (file)
 namespace armnn
 {
 
+unsigned int GetNumElementsBetween(const TensorShape& shape,
+                                   unsigned int firstAxisInclusive,
+                                   unsigned int lastAxisExclusive)
+{
+    BOOST_ASSERT(0 <= firstAxisInclusive);
+    BOOST_ASSERT(firstAxisInclusive <= lastAxisExclusive);
+    BOOST_ASSERT(lastAxisExclusive <= shape.GetNumDimensions());
+    unsigned int count = 1;
+    for (unsigned int i = firstAxisInclusive; i < lastAxisExclusive; i++)
+    {
+        count *= shape[i];
+    }
+    return count;
+}
+
 /// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
-void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta)
+void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
 {
-    unsigned int numChannels = inputTensorInfo.GetShape()[1];
+    BOOST_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
+                     "Required axis index greater than number of dimensions.");
+    BOOST_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
+                     "Required axis index lower than negative of the number of dimensions");
+
+    unsigned int uAxis = axis < 0  ?
+                         inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
+                         : static_cast<unsigned int>(axis);
 
-    for (unsigned int n = 0; n < inputTensorInfo.GetShape()[0]; n++)
+    const TensorShape& inputShape = inputTensorInfo.GetShape();
+    const unsigned int outerSize  = GetNumElementsBetween(inputShape, 0, uAxis);
+    const unsigned int axisSize   = inputShape[uAxis];
+    const unsigned int innerSize  = GetNumElementsBetween(inputShape, uAxis + 1, inputShape.GetNumDimensions());
+
+    for (unsigned int outer = 0; outer < outerSize; ++outer)
     {
-        // Find maximum channel.
-        in[n * numChannels];
-        float max = in.Get();
-        for (unsigned int c = 1; c < numChannels; c++)
+        unsigned int inputBeginIdx  = outer * axisSize * innerSize;
+        unsigned int inputEndIdx    = inputBeginIdx + axisSize * innerSize;
+        unsigned int outputBeginIdx = outer * axisSize * innerSize;
+
+        for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
         {
-            in[n * numChannels + c];
-            float val = in.Get();
-            if (val > max)
+            // Find max
+            float maxValue = std::numeric_limits<float>::lowest();
+            for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
             {
-                max = val;
+                in[iter];
+                maxValue = std::max(maxValue, in.Get());
             }
-        }
 
-        // Exponentiate all values and sum.
-        std::vector<float> exponentials(numChannels);
-        float              sum = 0.0f;
-        for (unsigned int c = 0; c < numChannels; c++)
-        {
-            in[n * numChannels + c];
-            float val = in.Get();
-            exponentials[c] = expf((val - max) * beta);
-            sum += exponentials[c];
-        }
+            // Compute sum
+            float sum = 0.0f;
+            for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
+            {
+                in[iter];
+                sum += std::exp((in.Get() - maxValue) * beta);
+            }
 
-        // Divide exponentials by sum to give outputs.
-        for (unsigned int c = 0; c < numChannels; c++)
-        {
-            out[n * numChannels + c];
-            out.Set(exponentials[c] / sum);
+            // Compute result
+            unsigned int outputIter = outputBeginIdx;
+            out[outputIter];
+            for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
+            {
+                out[outputIter];
+                in[iter];
+                out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
+            }
         }
     }
 }
index 3876293..25c7449 100644 (file)
@@ -12,6 +12,6 @@ namespace armnn
 {
 
 /// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
-void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta);
+void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis = -1);
 
 } //namespace armnn