IVGCVSW-1843 : replacing trivial arithmetic helpers
authorDavid Beck <david.beck@arm.com>
Tue, 11 Sep 2018 15:37:14 +0000 (16:37 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Mon, 1 Oct 2018 13:56:47 +0000 (14:56 +0100)
Change-Id: Iddf637694f1a3a7ef00f006a41b8044a35c7e73c

21 files changed:
Android.mk
CMakeLists.txt
src/armnn/backends/RefWorkloads.hpp
src/armnn/backends/RefWorkloads/Addition.cpp [deleted file]
src/armnn/backends/RefWorkloads/Addition.hpp [deleted file]
src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp [new file with mode: 0644]
src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp [new file with mode: 0644]
src/armnn/backends/RefWorkloads/Division.cpp [deleted file]
src/armnn/backends/RefWorkloads/Division.hpp [deleted file]
src/armnn/backends/RefWorkloads/Multiplication.cpp [deleted file]
src/armnn/backends/RefWorkloads/Multiplication.hpp [deleted file]
src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefMultiplicationFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
src/armnn/backends/RefWorkloads/Subtraction.cpp [deleted file]
src/armnn/backends/RefWorkloads/Subtraction.hpp [deleted file]

index 9c2373678dd23cb12001f2a97d71c71c80e4bfba..9c4db74d1a5afdeab15fcf6405f54fe87cd3a75a 100644 (file)
@@ -128,16 +128,14 @@ LOCAL_SRC_FILES := \
         src/armnn/backends/RefWorkloads/RefSoftmaxFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefActivationFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp \
-        src/armnn/backends/RefWorkloads/Multiplication.cpp \
         src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp \
         src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefBatchNormalizationFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/Broadcast.cpp \
-        src/armnn/backends/RefWorkloads/Addition.cpp \
+        src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp \
         src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp \
-        src/armnn/backends/RefWorkloads/Subtraction.cpp \
         src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp \
         src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefFakeQuantizationFloat32Workload.cpp \
@@ -170,7 +168,6 @@ LOCAL_SRC_FILES := \
         src/armnn/backends/RefWorkloads/RefPermuteWorkload.cpp \
         src/armnn/backends/RefWorkloads/RefConvertFp16ToFp32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefConvertFp32ToFp16Workload.cpp \
-        src/armnn/backends/RefWorkloads/Division.cpp \
         src/armnn/backends/RefWorkloads/RefDivisionFloat32Workload.cpp \
         src/armnn/backends/RefWorkloads/RefDivisionUint8Workload.cpp \
         src/armnn/backends/MemCopyWorkload.cpp \
index 777c3153e64247279756f42298e7c6ce8c207b9c..9c2685c96d886f3ca1b7b3b4ab8b2ab67525c0c5 100644 (file)
@@ -186,14 +186,12 @@ list(APPEND armnn_sources
     src/armnn/backends/RefWorkloads/Broadcast.cpp
     src/armnn/backends/RefWorkloads/RefMergerUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefConstantUint8Workload.hpp
-    src/armnn/backends/RefWorkloads/Addition.cpp
-    src/armnn/backends/RefWorkloads/Addition.hpp
+    src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp
+    src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp
     src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefAdditionUint8Workload.hpp
     src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/RefAdditionFloat32Workload.hpp
-    src/armnn/backends/RefWorkloads/Subtraction.cpp
-    src/armnn/backends/RefWorkloads/Subtraction.hpp
     src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefSubtractionUint8Workload.hpp
     src/armnn/backends/RefWorkloads/RefSubtractionFloat32Workload.cpp
@@ -210,12 +208,8 @@ list(APPEND armnn_sources
     src/armnn/backends/RefWorkloads/RefActivationFloat32Workload.cpp
     src/armnn/backends/RefWorkloads/RefBatchNormalizationUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefResizeBilinearUint8Workload.hpp
-    src/armnn/backends/RefWorkloads/Multiplication.cpp
-    src/armnn/backends/RefWorkloads/Division.cpp
-    src/armnn/backends/RefWorkloads/Division.hpp
     src/armnn/backends/RefWorkloads/RefMultiplicationUint8Workload.cpp
     src/armnn/backends/RefWorkloads/RefL2NormalizationFloat32Workload.hpp
-    src/armnn/backends/RefWorkloads/Multiplication.hpp
     src/armnn/backends/RefWorkloads/RefActivationUint8Workload.hpp
     src/armnn/backends/RefWorkloads/RefBaseConstantWorkload.cpp
     src/armnn/backends/RefWorkloads/RefResizeBilinearFloat32Workload.cpp
index 910610c72eb093f6df0761804efb22b52c179bc0..e58d4accbb8a1a40da6484a882dde848718bbe10 100644 (file)
@@ -6,7 +6,7 @@
 #pragma once
 
 #include "backends/RefWorkloads/RefConstantUint8Workload.hpp"
-#include "backends/RefWorkloads/Addition.hpp"
+#include "backends/RefWorkloads/ArithmeticFunction.hpp"
 #include "backends/RefWorkloads/ConvImpl.hpp"
 #include "backends/RefWorkloads/RefMultiplicationUint8Workload.hpp"
 #include "backends/RefWorkloads/RefBaseConstantWorkload.hpp"
@@ -14,7 +14,6 @@
 #include "backends/RefWorkloads/RefSplitterUint8Workload.hpp"
 #include "backends/RefWorkloads/RefResizeBilinearUint8Workload.hpp"
 #include "backends/RefWorkloads/RefL2NormalizationFloat32Workload.hpp"
-#include "backends/RefWorkloads/Multiplication.hpp"
 #include "backends/RefWorkloads/RefActivationUint8Workload.hpp"
 #include "backends/RefWorkloads/RefPooling2dFloat32Workload.hpp"
 #include "backends/RefWorkloads/RefWorkloadUtils.hpp"
diff --git a/src/armnn/backends/RefWorkloads/Addition.cpp b/src/armnn/backends/RefWorkloads/Addition.cpp
deleted file mode 100644 (file)
index 33d5bd5..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Addition.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseAddition(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
-{
-    for (unsigned int i = 0; i < numElements; ++i)
-    {
-        outData[i] = inData0[i] + inData1[i];
-    }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Addition(const TensorShape& inShape0,
-              const TensorShape& inShape1,
-              const TensorShape& outShape,
-              const float* inData0,
-              const float* inData1,
-              float* outData)
-{
-    if (inShape0 == inShape1)
-    {
-        ElementwiseAddition(inShape0.GetNumElements(), inData0, inData1, outData);
-    }
-    else
-    {
-        BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::plus<float>(), 0, inData0, inData1, outData);
-    }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Addition.hpp b/src/armnn/backends/RefWorkloads/Addition.hpp
deleted file mode 100644 (file)
index dcbd499..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Addition(const TensorShape& inShape0,
-              const TensorShape& inShape1,
-              const TensorShape& outShape,
-              const float* inData0,
-              const float* inData1,
-              float* outData);
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp b/src/armnn/backends/RefWorkloads/ArithmeticFunction.cpp
new file mode 100644 (file)
index 0000000..fede138
--- /dev/null
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ArithmeticFunction.hpp"
+#include "Broadcast.hpp"
+#include <functional>
+
+namespace armnn
+{
+
+template <typename Functor>
+ArithmeticFunction<Functor>::ArithmeticFunction(const TensorShape& inShape0,
+                                                const TensorShape& inShape1,
+                                                const TensorShape& outShape,
+                                                const float* inData0,
+                                                const float* inData1,
+                                                float* outData)
+{
+    BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
+}
+
+} //namespace armnn
+
+template struct armnn::ArithmeticFunction<std::plus<float>>;
+template struct armnn::ArithmeticFunction<std::minus<float>>;
+template struct armnn::ArithmeticFunction<std::multiplies<float>>;
+template struct armnn::ArithmeticFunction<std::divides<float>>;
diff --git a/src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp b/src/armnn/backends/RefWorkloads/ArithmeticFunction.hpp
new file mode 100644 (file)
index 0000000..eafb644
--- /dev/null
@@ -0,0 +1,24 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+namespace armnn
+{
+
+template <typename Functor>
+struct ArithmeticFunction
+{
+    ArithmeticFunction(const TensorShape& inShape0,
+                       const TensorShape& inShape1,
+                       const TensorShape& outShape,
+                       const float* inData0,
+                       const float* inData1,
+                       float* outData);
+};
+
+} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Division.cpp b/src/armnn/backends/RefWorkloads/Division.cpp
deleted file mode 100644 (file)
index cc7f7c9..0000000
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Division.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-#include <cmath>
-
-namespace
-{
-
-void ElementwiseDivision(unsigned int numElements,
-                         const float* inData0,
-                         const float* inData1,
-                         float* outData)
-{
-    for (unsigned int i = 0; i < numElements; ++i)
-    {
-        if (inData1[i] != 0.0f)
-        {
-            outData[i] = inData0[i] / inData1[i];
-        }
-        else if (inData0[i] == 0.0f)
-        {
-            if (!std::signbit(inData1[i]))
-            {
-                outData[i]= NAN;
-            }
-            else
-            {
-                outData[i]= -NAN;
-            }
-        }
-        else if (inData0[i] < 0.0f)
-        {
-            if (!std::signbit(inData1[i]))
-            {
-                outData[i] = -INFINITY;
-            }
-            else
-            {
-                outData[i] = INFINITY;
-            }
-        }
-        else
-        {
-            if (!std::signbit(inData1[i]))
-            {
-                outData[i] = INFINITY;
-            }
-            else
-            {
-                outData[i] = -INFINITY;
-            }
-        }
-    }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Division(const TensorShape& inShape0,
-              const TensorShape& inShape1,
-              const TensorShape& outShape,
-              const float* inData0,
-              const float* inData1,
-              float* outData)
-{
-    if (inShape0 == inShape1)
-    {
-        ElementwiseDivision(inShape0.GetNumElements(), inData0, inData1, outData);
-    }
-    else
-    {
-        BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::divides<float>(),
-                                                           0,
-                                                           inData0,
-                                                           inData1,
-                                                           outData);
-    }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Division.hpp b/src/armnn/backends/RefWorkloads/Division.hpp
deleted file mode 100644 (file)
index b83c77f..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-    void Division(const TensorShape& inShape0,
-                  const TensorShape& inShape1,
-                  const TensorShape& outShape,
-                  const float* inData0,
-                  const float* inData1,
-                  float* outData);
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.cpp b/src/armnn/backends/RefWorkloads/Multiplication.cpp
deleted file mode 100644 (file)
index ae6446a..0000000
+++ /dev/null
@@ -1,52 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Multiplication.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseMultiplication(unsigned int numElements,
-                               const float* inData0,
-                               const float* inData1,
-                               float* outData)
-{
-    for (unsigned int i = 0; i < numElements; ++i)
-    {
-        outData[i] = inData0[i] * inData1[i];
-    }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Multiplication(const TensorShape& inShape0,
-                    const TensorShape& inShape1,
-                    const TensorShape& outShape,
-                    const float* inData0,
-                    const float* inData1,
-                    float* outData)
-{
-    if (inShape0 == inShape1)
-    {
-        ElementwiseMultiplication(inShape0.GetNumElements(), inData0, inData1, outData);
-    }
-    else
-    {
-        BroadcastLoop(inShape0, inShape1, outShape).Unroll(
-            std::multiplies<float>(),
-            0,
-            inData0,
-            inData1,
-            outData);
-    }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Multiplication.hpp b/src/armnn/backends/RefWorkloads/Multiplication.hpp
deleted file mode 100644 (file)
index 58ad7b4..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Multiplication(const TensorShape& inShape0,
-                    const TensorShape& inShape1,
-                    const TensorShape& outShape,
-                    const float* inData0,
-                    const float* inData1,
-                    float* outData);
-
-} //namespace armnn
index c2a5b5fcbd6fb283eb33c054ac5c58608affdb90..21c7533c0f1ed7bb05fe76402cadfeaa9161646b 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefAdditionFloat32Workload.hpp"
 
-#include "Addition.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefAdditionFloat32Workload::Execute() const
     const float* inData1 = GetInputTensorDataFloat(1, m_Data);
     float* outData = GetOutputTensorDataFloat(0, m_Data);
 
-    Addition(inShape0, inShape1, outShape, inData0, inData1, outData);
+    ArithmeticFunction<std::plus<float>>(inShape0, inShape1, outShape, inData0, inData1, outData);
 }
 
 } //namespace armnn
index 2999be92402b084977142fa64c677ddc04726262..116a5f14cba55506c8b4bb3db3c50cb333168175 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefAdditionUint8Workload.hpp"
 
-#include "Addition.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -28,12 +28,12 @@ void RefAdditionUint8Workload::Execute() const
 
     std::vector<float> results(outputInfo.GetNumElements());
 
-    Addition(inputInfo0.GetShape(),
-             inputInfo1.GetShape(),
-             outputInfo.GetShape(),
-             dequant0.data(),
-             dequant1.data(),
-             results.data());
+    ArithmeticFunction<std::plus<float>>(inputInfo0.GetShape(),
+                                         inputInfo1.GetShape(),
+                                         outputInfo.GetShape(),
+                                         dequant0.data(),
+                                         dequant1.data(),
+                                         results.data());
 
     Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
 }
index 81f4645cbc2b5839b6b6c529d4e1e108dbe390c8..28c90610ded1dd52c374f66d52d421ed041481df 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefDivisionFloat32Workload.hpp"
 
-#include "Division.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefDivisionFloat32Workload::Execute() const
     const float* inputData0 = GetInputTensorDataFloat(0, m_Data);
     const float* inputData1 = GetInputTensorDataFloat(1, m_Data);
 
-    Division(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
+    ArithmeticFunction<std::divides<float>>(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
 }
 
 } //namespace armnn
index a6ed770c40c493789c8070a16824c077ac6f133c..d10d8741377dd8f9e72523cc85049afc69cbf8a4 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefDivisionUint8Workload.hpp"
 
-#include "Division.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -27,9 +27,13 @@ void RefDivisionUint8Workload::Execute() const
     auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
 
     std::vector<float> results(outputInfo.GetNumElements());
-    Division(
-        inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
-        dequant0.data(), dequant1.data(),results.data());
+
+    ArithmeticFunction<std::divides<float>>(inputInfo0.GetShape(),
+                                            inputInfo1.GetShape(),
+                                            outputInfo.GetShape(),
+                                            dequant0.data(),
+                                            dequant1.data(),
+                                            results.data());
 
     Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
 }
index 022cca70e74d0dcb7ab6973665be9dad09b1ff0c..0b36f0ff00ae17648339ec90c374a6dbbdc503af 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefMultiplicationFloat32Workload.hpp"
 
-#include "Multiplication.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefMultiplicationFloat32Workload::Execute() const
     const float* inputData0 = GetInputTensorDataFloat(0, m_Data);
     const float* inputData1 = GetInputTensorDataFloat(1, m_Data);
 
-    Multiplication(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
+    ArithmeticFunction<std::multiplies<float>>(inShape0, inShape1, outShape, inputData0, inputData1, outputData);
 }
 
 } //namespace armnn
index 8e0a617bf5ec3d7d9eda0d86d2b00cc3ed051de7..b929a538087c984b05b15dad16a611629720b074 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefMultiplicationUint8Workload.hpp"
 
-#include "Multiplication.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -27,9 +27,13 @@ void RefMultiplicationUint8Workload::Execute() const
     auto dequant1 = Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1);
 
     std::vector<float> results(outputInfo.GetNumElements());
-    Multiplication(
-        inputInfo0.GetShape(), inputInfo1.GetShape(), outputInfo.GetShape(),
-        dequant0.data(), dequant1.data(),results.data());
+
+    ArithmeticFunction<std::multiplies<float>>(inputInfo0.GetShape(),
+                                               inputInfo1.GetShape(),
+                                               outputInfo.GetShape(),
+                                               dequant0.data(),
+                                               dequant1.data(),
+                                               results.data());
 
    Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
 }
index 4440eedab7899eb4f906d25acc03653213007558..f1840c347b13db317b3a594f6210843ca6b0e3a0 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefSubtractionFloat32Workload.hpp"
 
-#include "Subtraction.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -25,7 +25,7 @@ void RefSubtractionFloat32Workload::Execute() const
     const float* inData1 = GetInputTensorDataFloat(1, m_Data);
     float* outData = GetOutputTensorDataFloat(0, m_Data);
 
-    Subtraction(inShape0, inShape1, outShape, inData0, inData1, outData);
+    ArithmeticFunction<std::minus<float>>(inShape0, inShape1, outShape, inData0, inData1, outData);
 }
 
 } //namespace armnn
index 8066762e485cdabb271ca08efa5a7421d7bc5564..1affbdd8b185bc7021398160e7b9c6fbe6a4de2f 100644 (file)
@@ -5,7 +5,7 @@
 
 #include "RefSubtractionUint8Workload.hpp"
 
-#include "Subtraction.hpp"
+#include "ArithmeticFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 
 #include "Profiling.hpp"
@@ -28,12 +28,12 @@ void RefSubtractionUint8Workload::Execute() const
 
     std::vector<float> results(outputInfo.GetNumElements());
 
-    Subtraction(inputInfo0.GetShape(),
-                inputInfo1.GetShape(),
-                outputInfo.GetShape(),
-                dequant0.data(),
-                dequant1.data(),
-                results.data());
+    ArithmeticFunction<std::minus<float>>(inputInfo0.GetShape(),
+                                          inputInfo1.GetShape(),
+                                          outputInfo.GetShape(),
+                                          dequant0.data(),
+                                          dequant1.data(),
+                                          results.data());
 
     Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
 }
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.cpp b/src/armnn/backends/RefWorkloads/Subtraction.cpp
deleted file mode 100644 (file)
index f25c8ad..0000000
+++ /dev/null
@@ -1,44 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "Subtraction.hpp"
-#include "Broadcast.hpp"
-
-#include <functional>
-
-namespace
-{
-
-void ElementwiseSubtraction(unsigned int numElements, const float* inData0, const float* inData1, float* outData)
-{
-    for (unsigned int i = 0; i < numElements; ++i)
-    {
-        outData[i] = inData0[i] - inData1[i];
-    }
-}
-
-} // namespace
-
-namespace armnn
-{
-
-void Subtraction(const TensorShape& inShape0,
-                 const TensorShape& inShape1,
-                 const TensorShape& outShape,
-                 const float* inData0,
-                 const float* inData1,
-                 float* outData)
-{
-    if (inShape0 == inShape1)
-    {
-        ElementwiseSubtraction(inShape0.GetNumElements(), inData0, inData1, outData);
-    }
-    else
-    {
-        BroadcastLoop(inShape0, inShape1, outShape).Unroll(std::minus<float>(), 0, inData0, inData1, outData);
-    }
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/RefWorkloads/Subtraction.hpp b/src/armnn/backends/RefWorkloads/Subtraction.hpp
deleted file mode 100644 (file)
index 3956797..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include <armnn/Tensor.hpp>
-
-namespace armnn
-{
-
-void Subtraction(const TensorShape& inShape0,
-                 const TensorShape& inShape1,
-                 const TensorShape& outShape,
-                 const float* inData0,
-                 const float* inData1,
-                 float* outData);
-
-} //namespace armnn