IVGCVSW-2957 MergerLayer subtensor optimization now backend agnostic
author Derek Lamberti <derek.lamberti@arm.com>
Mon, 15 Apr 2019 17:37:35 +0000 (18:37 +0100)
committer derek.lamberti <derek.lamberti@arm.com>
Tue, 16 Apr 2019 13:50:11 +0000 (13:50 +0000)
+ Update clframework pin
+ Cl and Neon Merger workloads updated to use the MemoryLayout-agnostic API
+ Workloads use the sub-tensor optimization only if ALL input tensors are sub-tensors
+ Refactor LayerSupportCommon code to be a bit more succinct
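
Below, a minimal self-contained sketch of the all-or-nothing gating described in the third bullet; FakeHandle, TryCreateSubTensor and SubstituteAllOrNothing are illustrative stand-ins, not armnn API.

    #include <memory>
    #include <vector>

    // Illustrative stand-in for an armnn tensor handle; not the real ITensorHandle API.
    struct FakeHandle {};

    // Hypothetical helper: returns nullptr when a sub-tensor cannot be created
    // (e.g. mismatched quantization parameters or an invalid shape/origin).
    std::unique_ptr<FakeHandle> TryCreateSubTensor(bool creatable)
    {
        return creatable ? std::make_unique<FakeHandle>() : nullptr;
    }

    // Collect a candidate sub-tensor for every input first; only if ALL of them
    // succeed are the inputs substituted, otherwise the layer keeps full tensors.
    bool SubstituteAllOrNothing(const std::vector<bool>& inputsCreatable)
    {
        std::vector<std::unique_ptr<FakeHandle>> subTensors;
        subTensors.reserve(inputsCreatable.size());
        for (bool creatable : inputsCreatable)
        {
            auto subTensor = TryCreateSubTensor(creatable);
            if (!subTensor)
            {
                return false; // one failure disables the optimization for the whole layer
            }
            subTensors.push_back(std::move(subTensor));
        }
        // The real code would now substitute each input's output handler with its sub-tensor.
        return true;
    }

    int main()
    {
        SubstituteAllOrNothing({true, true});  // optimization applies
        SubstituteAllOrNothing({true, false}); // falls back to full input tensors
    }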

Change-Id: Ib61ad4ccbd767e924dff07e61022e0cda4069828
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
13 files changed:
include/armnn/Tensor.hpp
scripts/get_compute_library.sh
src/armnn/LayerSupportCommon.hpp
src/armnn/Tensor.cpp
src/armnn/layers/MergerLayer.cpp
src/backends/cl/ClLayerSupport.cpp
src/backends/cl/ClWorkloadFactory.cpp
src/backends/cl/workloads/ClMergerWorkload.cpp
src/backends/cl/workloads/ClMergerWorkload.hpp
src/backends/neon/NeonLayerSupport.cpp
src/backends/neon/NeonWorkloadFactory.cpp
src/backends/neon/workloads/NeonMergerWorkload.cpp
src/backends/neon/workloads/NeonMergerWorkload.hpp

include/armnn/Tensor.hpp
index 503c161..160ccca 100644 (file)
@@ -80,7 +80,11 @@ public:
     int32_t GetQuantizationOffset() const           { return m_Quantization.m_Offset; }
     void SetQuantizationScale(float scale)          { m_Quantization.m_Scale = scale; }
     void SetQuantizationOffset(int32_t offset)      { m_Quantization.m_Offset = offset; }
-    bool IsQuantized() const                        { return m_DataType == DataType::QuantisedAsymm8; }
+    bool IsQuantized() const                        { return m_DataType == DataType::QuantisedAsymm8 ||
+                                                             m_DataType == DataType::QuantisedSymm16; }
+
+    /// Check that the types are the same and, if quantized, that the quantization parameters are the same.
+    bool IsTypeSpaceMatch(const TensorInfo& other) const;
 
     unsigned int GetNumBytes() const;
 
scripts/get_compute_library.sh
index f3d1a8c..8a35bd3 100755 (executable)
@@ -10,7 +10,7 @@ CMD=$( basename $0 )
 # DEFAULT_CLFRAMEWORKREVISION="branches/arm_compute_19_02" # Release 19.02
 #
 # For pinning to a revision use this:
-DEFAULT_CLFRAMEWORKREVISION="a4bba9c594c4022c9f85192bb8fd3593ad1a8d3c" # COMPMID-1995: Fix 32-bit NEDepthwiseConvolution errors.
+DEFAULT_CLFRAMEWORKREVISION="9e4824c909b14dbaf7106e9527b0ffa22ef09bdc"
 
 usage() {
     echo "Usage: $CMD (Use the default clframework SHA)"
src/armnn/LayerSupportCommon.hpp
index 70b5f18..3e2a124 100644 (file)
 namespace armnn
 {
 
+template<typename T, typename V>
+void SetValueChecked(Optional<T&> optionalRef, V&& val)
+{
+    if (optionalRef)
+    {
+        optionalRef.value() = val;
+    }
+}
+
 template<typename Float16Func, typename Float32Func, typename Uint8Func, typename Int32Func, typename BooleanFunc,
          typename ... Params>
 bool IsSupportedForDataTypeGeneric(Optional<std::string&> reasonIfUnsupported,
@@ -55,80 +64,56 @@ bool FalseFunc(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 template<typename ... Params>
 bool FalseFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float16 data type";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float16 data type");
     return false;
 }
 
 template<typename ... Params>
 bool FalseFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float32 data type";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float32 data type");
     return false;
 }
 
 template<typename ... Params>
 bool FalseFuncU8(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with 8-bit data type";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with 8-bit data type");
     return false;
 }
 
 template<typename ... Params>
 bool FalseFuncI32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with int32 data type";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with int32 data type");
     return false;
 }
 
 template<typename ... Params>
 bool FalseInputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float32 data type input";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float32 data type input");
     return false;
 }
 
 template<typename ... Params>
 bool FalseInputFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float16 data type input";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float16 data type input");
     return false;
 }
 
 template<typename ... Params>
 bool FalseOutputFuncF32(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float32 data type output";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float32 data type output");
     return false;
 }
 
 template<typename ... Params>
 bool FalseOutputFuncF16(Optional<std::string&> reasonIfUnsupported, Params&&... params)
 {
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "Layer is not supported with float16 data type output";
-    }
+    SetValueChecked(reasonIfUnsupported, "Layer is not supported with float16 data type output");
     return false;
 }
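
For context, a hedged sketch of the caller side of these predicates: the Optional<std::string&> carries the failure reason back only when the caller supplied a string to fill (assuming armnn::Optional from armnn/Optional.hpp behaves as used above; FalseFuncExample is a made-up stand-in for the helpers in this file).

    #include <armnn/Optional.hpp>
    #include <iostream>
    #include <string>

    // Toy predicate in the style of the FalseFunc* helpers above.
    template<typename ... Params>
    bool FalseFuncExample(armnn::Optional<std::string&> reasonIfUnsupported, Params&&...)
    {
        // Same shape as SetValueChecked: only write when the caller asked for a reason.
        if (reasonIfUnsupported)
        {
            reasonIfUnsupported.value() = "Layer is not supported in this example";
        }
        return false;
    }

    int main()
    {
        std::string reason;
        if (!FalseFuncExample(armnn::Optional<std::string&>(reason)))
        {
            std::cout << reason << std::endl;     // caller wanted the reason
        }
        FalseFuncExample(armnn::EmptyOptional()); // caller not interested; nothing is written
    }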
 
src/armnn/Tensor.cpp
index da19e5b..614abc7 100644 (file)
@@ -188,6 +188,20 @@ unsigned int TensorInfo::GetNumBytes() const
     return GetDataTypeSize(m_DataType) * GetNumElements();
 }
 
+bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const
+{
+    bool match = true;
+
+    match &= m_DataType == other.m_DataType;
+
+    if (IsQuantized())
+    {
+        match &= GetQuantizationScale() == other.GetQuantizationScale() &&
+                 GetQuantizationOffset() == other.GetQuantizationOffset();
+    }
+    return match;
+}
+
 // ---
 // --- BaseTensor
 // ---
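
A small usage sketch of the new IsTypeSpaceMatch helper, assuming the 19.x constructor and DataType names shown above: the shape may differ, only the data type and quantization parameters are compared.

    #include <armnn/Tensor.hpp>
    #include <iostream>

    int main()
    {
        using namespace armnn;

        // Two quantized tensors that share data type, scale and offset...
        TensorInfo a(TensorShape({1, 3, 4, 4}), DataType::QuantisedAsymm8, 0.1f, 128);
        TensorInfo b(TensorShape({1, 5, 4, 4}), DataType::QuantisedAsymm8, 0.1f, 128);
        // ...and one with a different scale.
        TensorInfo c(TensorShape({1, 3, 4, 4}), DataType::QuantisedAsymm8, 0.2f, 128);

        std::cout << a.IsTypeSpaceMatch(b) << std::endl; // 1: shapes differ, type space matches
        std::cout << a.IsTypeSpaceMatch(c) << std::endl; // 0: quantization scale differs
    }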
src/armnn/layers/MergerLayer.cpp
index f87f349..c674f64 100644 (file)
@@ -36,14 +36,12 @@ std::unique_ptr<IWorkload> MergerLayer::CreateWorkload(const Graph& graph, const
 
 void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
 {
-    //If sub tensors are supported than the merger
+    //If sub tensors are supported then the merger
     //just needs to make sure that the outputs of the prev layer
     //are made subtensors of the output of the merger layer.
     m_OutputHandlers[0].CreateTensorHandles(factory);
 
-    unsigned int innerAxis = m_Param.GetNumDimensions() - m_Param.GetConcatAxis();
-
-    if (factory.SupportsSubTensors() && innerAxis != 1)
+    if (factory.SupportsSubTensors())
     {
         std::queue<MergerLayer*> m_MergerLayers;
 
@@ -52,23 +50,65 @@ void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fact
         {
             MergerLayer* currentLayer = m_MergerLayers.front();
             ITensorHandle* parentTensor = currentLayer->GetOutputHandler(0).GetData();
-
+            const TensorInfo& parentInfo = currentLayer->GetOutputHandler(0).GetTensorInfo();
             m_MergerLayers.pop();
 
             const unsigned int numInputSlots = currentLayer->GetNumInputSlots();
+
+            // First go through all the input slots and verify that we can sub-tensor all the inputs.
+            std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
+            subTensors.reserve(numInputSlots);
             for (unsigned int i = 0; i < numInputSlots; ++i)
             {
                 OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
+                const TensorInfo& info = slot->GetTensorInfo();
+
+                auto CreateSubTensor = [&]()
+                {
+                    // Make sure quantization parameters are in the same space
+                    if (parentInfo.IsTypeSpaceMatch(info))
+                    {
+                        return factory.CreateSubTensorHandle(*parentTensor,
+                                                             info.GetShape(),
+                                                             currentLayer->m_Param.GetViewOrigin(i));
+                    }
+                    return std::unique_ptr<ITensorHandle>();
+                };
+
+                auto subTensor = CreateSubTensor();
+                if (!subTensor)
+                {
+                    break; //Failed to create a valid sub-tensor, so stop trying with the rest of the inputs.
+                }
+                else
+                {
+                    subTensors.push_back(std::move(subTensor)); // store the valid sub-tensor.
+                }
+            }
+
+            // Ensure that ALL inputs can be substituted with valid sub-tensors
+            if (subTensors.size() < numInputSlots)
+            {
+                continue; // Don't optimize this Merge layer with sub-tensors
+            }
+
+            // Substitute input tensors with sub-tensors by replacing the output tensors on the connected layers.
+            unsigned int i=0;
+            for (auto& subTensor : subTensors)
+            {
+                OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
                 OutputHandler& outputHandler = slot->GetOutputHandler();
-                outputHandler.SetData(factory.CreateSubTensorHandle(*parentTensor,
-                                                                    outputHandler.GetTensorInfo().GetShape(),
-                                                                    currentLayer->m_Param.GetViewOrigin(i)));
+
+                BOOST_ASSERT_MSG(subTensor, "MergerLayer: Expected a valid sub-tensor for substitution.");
+                outputHandler.SetData(std::move(subTensor));
 
                 Layer& inputLayer = slot->GetOwningLayer();
                 if (inputLayer.GetType() == LayerType::Merger)
                 {
+                    // Continue with the substitution if the connected inputs are also merger layers
                     m_MergerLayers.push(boost::polymorphic_downcast<MergerLayer*>(&inputLayer));
                 }
+                ++i;
             }
         }
     }
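
The substitution above relies on each view origin locating an input inside the concatenated output. A plain-C++ illustration of that mapping for a channel-axis concat (no armnn API involved, values chosen for the example):

    #include <array>
    #include <cstdio>

    int main()
    {
        // Concatenating two NCHW tensors [1,3,4,4] and [1,5,4,4] along axis 1 (channels)
        // yields [1,8,4,4]. The view origin of input i is its offset inside the output:
        std::array<unsigned int, 4> origin0 = {0, 0, 0, 0}; // first input starts at channel 0
        std::array<unsigned int, 4> origin1 = {0, 3, 0, 0}; // second input starts at channel 3

        // These origins are what MergerLayer passes to CreateSubTensorHandle above,
        // so each input becomes a window into the parent output tensor.
        std::printf("channel offsets: %u and %u\n", origin0[1], origin1[1]);
    }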
src/backends/cl/ClLayerSupport.cpp
index cfc0f11..a5c5f2b 100644 (file)
@@ -416,7 +416,14 @@ bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inpu
                                        const OriginsDescriptor& descriptor,
                                        Optional<std::string&> reasonIfUnsupported) const
 {
-    if(descriptor.GetNumDimensions() - descriptor.GetConcatAxis() == 1)
+    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
+    {
+        SetValueChecked(reasonIfUnsupported, "Cl Merger: Concat axis > Number of dimensions.");
+        return false;
+    }
+
+    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
+    if(concatInnerAxis < 3) // Width, height, or channels
     {
         FORWARD_WORKLOAD_VALIDATE_FUNC(ClMergerWorkloadValidate,
                                        reasonIfUnsupported,
@@ -424,12 +431,24 @@ bool ClLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> inpu
                                        output,
                                        descriptor);
     }
-    else
+    else if (concatInnerAxis == 3)
+    {
+        // We rely on the sub-tensor optimization to handle the batch dimension for 4D tensors. If we can't use
+        // sub-tensors for this then we can't support it. Here is where we check that the sub-tensors will work.
+        for (auto& input : inputs)
+        {
+            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
+            {
+                SetValueChecked(reasonIfUnsupported, "Cl Merger: Types and quantization parameters must match.");
+                return false;
+            }
+        }
+        return true; // Sub-tensors support concat along batch
+    }
+    else // > 4 dimensions not supported.
     {
-        return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                        inputs[0]->GetDataType(),
-                                        &TrueFunc<>,
-                                        &TrueFunc<>);
+        SetValueChecked(reasonIfUnsupported, "Cl Merger: Maximum of 4 dimensions supported.");
+        return false;
     }
 }
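
The axis translation is easiest to see with a worked example: armnn's OriginsDescriptor counts the concat axis from the outermost dimension, while arm_compute indexes dimensions from the innermost, hence concatInnerAxis = numDimensions - concatAxis - 1. A standalone sketch:

    #include <cstdio>

    // Same arithmetic as concatInnerAxis / CalcAxis in this change.
    unsigned int ToAclAxis(unsigned int numDimensions, unsigned int concatAxis)
    {
        return numDimensions - concatAxis - 1;
    }

    int main()
    {
        // 4D NCHW tensor on the armnn side: N = axis 0, C = axis 1, H = axis 2, W = axis 3.
        std::printf("concat over W -> acl axis %u\n", ToAclAxis(4, 3)); // 0: handled by the concat function
        std::printf("concat over C -> acl axis %u\n", ToAclAxis(4, 1)); // 2: handled by the concat function
        std::printf("concat over N -> acl axis %u\n", ToAclAxis(4, 0)); // 3: only via the sub-tensor path
    }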
 
src/backends/cl/ClWorkloadFactory.cpp
index d41a7e5..e4097a1 100644 (file)
@@ -113,6 +113,12 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorH
         coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex]));
     }
 
+    const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
+    if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
+    {
+        return nullptr;
+    }
+
     return std::make_unique<ClSubTensorHandle>(
         boost::polymorphic_downcast<IClTensorHandle*>(&parent), shape, coords);
 }
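
The new guard returns nullptr rather than constructing an out-of-bounds sub-tensor. Conceptually the condition reduces to origin plus shape staying inside the parent in every dimension, which this simplified standalone sketch illustrates (it is not the arm_compute implementation):

    #include <array>
    #include <cstdio>

    // Simplified stand-in for the validity condition enforced by
    // arm_compute::error_on_invalid_subtensor: every dimension of the requested
    // window must fit inside the parent tensor.
    bool SubTensorFitsParent(const std::array<unsigned int, 4>& parentShape,
                             const std::array<unsigned int, 4>& subShape,
                             const std::array<unsigned int, 4>& origin)
    {
        for (std::size_t d = 0; d < parentShape.size(); ++d)
        {
            if (origin[d] + subShape[d] > parentShape[d])
            {
                return false; // would address memory outside the parent: no sub-tensor, fall back
            }
        }
        return true;
    }

    int main()
    {
        std::array<unsigned int, 4> parent = {1, 8, 4, 4};
        std::printf("%d\n", SubTensorFitsParent(parent, {1, 5, 4, 4}, {0, 3, 0, 0})); // 1: fits
        std::printf("%d\n", SubTensorFitsParent(parent, {1, 5, 4, 4}, {0, 4, 0, 0})); // 0: overruns channels
    }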
src/backends/cl/workloads/ClMergerWorkload.cpp
index e06d8c5..610acb9 100644 (file)
@@ -9,16 +9,25 @@
 #include <cl/ClTensorHandle.hpp>
 #include <cl/ClLayerSupport.hpp>
 
+#include <arm_compute/core/Types.h>
+
 #include <boost/polymorphic_pointer_cast.hpp>
 
 namespace armnn
 {
 using namespace armcomputetensorutils;
 
+namespace
+{
+size_t CalcAxis(const MergerDescriptor& desc)
+{
+    return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1;
+}
+} //namespace
+
 arm_compute::Status ClMergerWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
                                              const TensorInfo& output,
                                              const MergerDescriptor& descriptor)
-
 {
     std::vector<arm_compute::TensorInfo> aclInputs;
     for (const TensorInfo* input : inputs)
@@ -27,59 +36,65 @@ arm_compute::Status ClMergerWorkloadValidate(const std::vector<const TensorInfo*
         aclInputs.emplace_back(aclInputInfo);
     }
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-    arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
-
     std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
     for (arm_compute::ITensorInfo& input : aclInputs)
     {
         aclInputPtrs.emplace_back(&input);
     }
 
+    size_t aclAxis = CalcAxis(descriptor);
     return arm_compute::CLConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis);
-
 }
 
 ClMergerWorkload::ClMergerWorkload(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info)
 : BaseWorkload<MergerQueueDescriptor>(descriptor, info)
 {
-    m_Execute = true;
+    bool allInputsAreSubtensors = true;
 
-    unsigned int innerAxisOrder = descriptor.m_Parameters.GetNumDimensions() - descriptor.m_Parameters.GetConcatAxis();
+    // Check that all inputs are sub-tensors
+    for (auto input : descriptor.m_Inputs)
+    {
+        if (!input->GetParent())
+        {
+            // Non sub-tensor input found so we need to execute the merger function
+            allInputsAreSubtensors = false;
+            break;
+        }
+    }
 
-    if (innerAxisOrder != 1)
+    if (allInputsAreSubtensors)
     {
-        m_Execute = false;
+        // Can skip configuring the merger function since it's not executed
         return;
     }
 
     std::vector<arm_compute::ICLTensor *> aclInputs;
-    arm_compute::DataLayout aclDataLayout = ConvertDataLayout(armnn::DataLayout::NCHW);
     for (auto input : m_Data.m_Inputs)
     {
         arm_compute::ICLTensor& aclInput  = boost::polymorphic_pointer_downcast<IClTensorHandle>(input)->GetTensor();
-        aclInput.info()->set_data_layout(aclDataLayout);
         aclInputs.emplace_back(&aclInput);
     }
     arm_compute::ICLTensor& output = boost::polymorphic_pointer_downcast<IClTensorHandle>(
                                                                          m_Data.m_Outputs[0])->GetTensor();
-    output.info()->set_data_layout(aclDataLayout);
-
-    arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
 
-    m_Layer.configure(aclInputs, &output, aclAxis);
+    // Create the layer function
+    m_Layer.reset(new arm_compute::CLConcatenateLayer());
 
-    m_Layer.prepare();
+    // Configure input and output tensors
+    size_t aclAxis = CalcAxis(descriptor.m_Parameters);
+    m_Layer->configure(aclInputs, &output, aclAxis);
 
+    // Prepare
+    m_Layer->prepare();
 }
 
 void ClMergerWorkload::Execute() const
 {
-    if (m_Execute)
+    if (m_Layer)
     {
         ARMNN_SCOPED_PROFILING_EVENT_CL("ClMergerWorkload_Execute");
-        m_Layer.run();
+        m_Layer->run();
     }
-
 }
 
 } //namespace armnn
\ No newline at end of file
src/backends/cl/workloads/ClMergerWorkload.hpp
index 8189a1b..1c2f823 100644 (file)
@@ -24,8 +24,7 @@ public:
     void Execute() const override;
 
 private:
-    mutable arm_compute::CLConcatenateLayer m_Layer;
-    bool m_Execute;
+    mutable std::unique_ptr<arm_compute::CLConcatenateLayer> m_Layer;
 };
 
 } //namespace armnn
src/backends/neon/NeonLayerSupport.cpp
index 46a7e6f..898660c 100644 (file)
@@ -52,10 +52,7 @@ bool IsNeonBackendSupported(Optional<std::string&> reasonIfUnsupported)
 #if defined(ARMCOMPUTENEON_ENABLED)
     return true;
 #else
-    if (reasonIfUnsupported)
-    {
-        reasonIfUnsupported.value() = "The armnn library has been built without NEON support";
-    }
+    SetValueChecked(reasonIfUnsupported, "The armnn library has been built without NEON support");
     return false;
 #endif
 }
@@ -304,7 +301,14 @@ bool NeonLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> in
                                          const OriginsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const
 {
-    if(descriptor.GetNumDimensions() - descriptor.GetConcatAxis() == 1)
+    if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
+    {
+        SetValueChecked(reasonIfUnsupported, "Neon Merger: Concat axis > Number of dimensions.");
+        return false;
+    }
+
+    unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
+    if(concatInnerAxis < 3) // Width, height, or channels
     {
         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMergerWorkloadValidate,
                                        reasonIfUnsupported,
@@ -312,13 +316,23 @@ bool NeonLayerSupport::IsMergerSupported(const std::vector<const TensorInfo*> in
                                        output,
                                        descriptor);
     }
-    else
-     {
-         return IsSupportedForDataTypeNeon(reasonIfUnsupported,
-                                           inputs[0]->GetDataType(),
-                                           &TrueFunc<>,
-                                           &TrueFunc<>);
-      }
+    else if (concatInnerAxis == 3)
+    {
+        for (auto& input : inputs)
+        {
+            if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
+            {
+                SetValueChecked(reasonIfUnsupported, "Neon Merger: Types and quantization parameters must match.");
+                return false;
+            }
+        }
+        return true; // Sub-tensors support concat along batch
+    }
+    else // > 4 dimensions not supported.
+    {
+        SetValueChecked(reasonIfUnsupported, "Neon Merger: Maximum of 4 dimensions supported.");
+        return false;
+    }
 }
 
 bool NeonLayerSupport::IsMinimumSupported(const TensorInfo& input0,
src/backends/neon/NeonWorkloadFactory.cpp
index 101e59d..8db5f9a 100644 (file)
@@ -61,6 +61,12 @@ std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateSubTensorHandle(ITenso
         coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex]));
     }
 
+    const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
+    if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
+    {
+        return nullptr;
+    }
+
     return std::make_unique<NeonSubTensorHandle>(
         boost::polymorphic_downcast<INeonTensorHandle*>(&parent), shape, coords);
 }
src/backends/neon/workloads/NeonMergerWorkload.cpp
index be096b4..64d4d93 100644 (file)
 #include <backendsCommon/CpuTensorHandle.hpp>
 #include <neon/NeonTensorHandle.hpp>
 
-#include <arm_compute/runtime/NEON/functions/NEConcatenateLayer.h>
+
 
 namespace armnn
 {
 using namespace armcomputetensorutils;
 
+namespace
+{
+size_t CalcAxis(const armnn::MergerDescriptor& desc)
+{
+    return (desc.GetNumDimensions() - desc.GetConcatAxis()) - 1;
+}
+} //namespace
+
 arm_compute::Status NeonMergerWorkloadValidate(const std::vector<const TensorInfo*>& inputs,
                                                const TensorInfo& output,
                                                const MergerDescriptor& descriptor)
@@ -25,60 +33,66 @@ arm_compute::Status NeonMergerWorkloadValidate(const std::vector<const TensorInf
     std::vector<arm_compute::TensorInfo> aclInputs;
     for (const TensorInfo* input : inputs)
     {
-       arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(*input, armnn::DataLayout::NCHW);
-       aclInputs.emplace_back(aclInputInfo);
+        arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(*input, armnn::DataLayout::NCHW);
+        aclInputs.emplace_back(aclInputInfo);
     }
     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-    arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
-
     std::vector<arm_compute::ITensorInfo*> aclInputPtrs;
     for (arm_compute::ITensorInfo& input : aclInputs)
     {
         aclInputPtrs.emplace_back(&input);
     }
 
+    size_t aclAxis = CalcAxis(descriptor);
     return arm_compute::NEConcatenateLayer::validate(aclInputPtrs, &aclOutputInfo, aclAxis);
-
 }
 
 NeonMergerWorkload::NeonMergerWorkload(
 const MergerQueueDescriptor& descriptor, const WorkloadInfo& info)
         : BaseWorkload<MergerQueueDescriptor>(descriptor, info)
 {
-    m_Execute = true;
+    bool allInputsAreSubtensors = true;
 
-    unsigned int innerAxisOrder = descriptor.m_Parameters.GetNumDimensions() - descriptor.m_Parameters.GetConcatAxis();
+    // Check that all inputs are sub-tensors
+    for (auto input : descriptor.m_Inputs)
+    {
+        if (!input->GetParent())
+        {
+            // Non sub-tensor input found so we need to execute the merger function
+            allInputsAreSubtensors = false;
+            break;
+        }
+    }
 
-    if (innerAxisOrder != 1)
+    if (allInputsAreSubtensors)
     {
-        m_Execute = false;
+        // Can skip configuring the merger function since it's not executed
         return;
     }
 
     std::vector<arm_compute::ITensor *> aclInputs;
-    arm_compute::DataLayout aclDataLayout = ConvertDataLayout(armnn::DataLayout::NCHW);
     for (auto input : m_Data.m_Inputs)
     {
         arm_compute::ITensor& aclInput  = boost::polymorphic_pointer_downcast<INeonTensorHandle>(input)->GetTensor();
-        aclInput.info()->set_data_layout(aclDataLayout);
         aclInputs.emplace_back(&aclInput);
     }
     arm_compute::ITensor& output = boost::polymorphic_pointer_downcast<INeonTensorHandle>(
-                                                                       m_Data.m_Outputs[0])->GetTensor();
-    output.info()->set_data_layout(aclDataLayout);
+        m_Data.m_Outputs[0])->GetTensor();
 
-    arm_compute::DataLayoutDimension aclAxis = arm_compute::DataLayoutDimension::WIDTH;
+    // Create the layer function
+    m_Layer.reset(new arm_compute::NEConcatenateLayer());
 
-    auto layer = std::make_unique<arm_compute::NEConcatenateLayer>();
-    layer->configure(aclInputs, &output, aclAxis);
-    m_Layer.reset(layer.release());
+    // Configure input and output tensors
+    size_t aclAxis = CalcAxis(descriptor.m_Parameters);
+    m_Layer->configure(aclInputs, &output, aclAxis);
 
+    // Prepare
     m_Layer->prepare();
 }
 
 void NeonMergerWorkload::Execute() const
 {
-    if (m_Execute)
+    if (m_Layer)
     {
         ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonMergerWorkload_Execute");
         m_Layer->run();
src/backends/neon/workloads/NeonMergerWorkload.hpp
index 3432c62..1dd9309 100644 (file)
@@ -9,7 +9,8 @@
 
 #include <arm_compute/core/Error.h>
 #include <arm_compute/runtime/IFunction.h>
-#
+#include <arm_compute/runtime/NEON/functions/NEConcatenateLayer.h>
+
 #include <memory>
 
 namespace armnn
@@ -27,9 +28,7 @@ public:
     void Execute() const override;
 
 private:
-    std::unique_ptr<arm_compute::IFunction> m_Layer;
-    bool m_Execute;
-
+    std::unique_ptr<arm_compute::NEConcatenateLayer> m_Layer;
 };
 
 } //namespace armnn