Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonTensorHandle.hpp
index 684a5e1..3818d2c 100644 (file)
@@ -7,11 +7,14 @@
 #include "OutputHandler.hpp"
 #include "ArmComputeTensorUtils.hpp"
 
+#include <arm_compute/runtime/MemoryGroup.h>
+#include <arm_compute/runtime/IMemoryGroup.h>
 #include <arm_compute/runtime/Tensor.h>
 #include <arm_compute/runtime/SubTensor.h>
 #include <arm_compute/core/TensorShape.h>
 #include <arm_compute/core/Coordinates.h>
 
+#include <boost/polymorphic_pointer_cast.hpp>
 
 namespace armnn
 {
@@ -22,6 +25,7 @@ public:
     virtual arm_compute::ITensor& GetTensor() = 0;
     virtual arm_compute::ITensor const& GetTensor() const = 0;
     virtual arm_compute::DataType GetDataType() const = 0;
+    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
 };
 
 class NeonTensorHandle : public INeonTensorHandle
@@ -34,47 +38,100 @@ public:
 
     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
+
     virtual void Allocate() override
     {
         armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
     };
 
+    virtual void Manage() override
+    {
+        BOOST_ASSERT(m_MemoryGroup != nullptr);
+        m_MemoryGroup->manage(&m_Tensor);
+    }
+
     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
 
+    virtual ITensorHandle* GetParent() const override { return nullptr; }
+
     virtual arm_compute::DataType GetDataType() const override
     {
         return m_Tensor.info()->data_type();
     }
 
+    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
+    {
+        m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
+    }
+
+    virtual const void* Map(bool /* blocking = true */) const override
+    {
+        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+    }
+    virtual void Unmap() const override {}
+
+
+    TensorShape GetStrides() const override
+    {
+        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+    }
+
+    TensorShape GetShape() const override
+    {
+        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+    }
+
 private:
     arm_compute::Tensor m_Tensor;
+    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
 };
 
 class NeonSubTensorHandle : public INeonTensorHandle
 {
 public:
-    NeonSubTensorHandle(arm_compute::ITensor& parent,
-        const arm_compute::TensorShape& shape,
-        const arm_compute::Coordinates& coords)
-     : m_Tensor(&parent, shape, coords)
+    NeonSubTensorHandle(INeonTensorHandle* parent,
+                        const arm_compute::TensorShape& shape,
+                        const arm_compute::Coordinates& coords)
+     : m_Tensor(&parent->GetTensor(), shape, coords)
     {
+        parentHandle = parent;
     }
 
     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
-    virtual void Allocate() override
-    {
-    };
+
+    virtual void Allocate() override {}
+    virtual void Manage() override {}
 
     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
 
+    virtual ITensorHandle* GetParent() const override { return parentHandle; }
+
     virtual arm_compute::DataType GetDataType() const override
     {
         return m_Tensor.info()->data_type();
     }
 
+    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+    virtual const void* Map(bool /* blocking = true */) const override
+    {
+        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+    }
+    virtual void Unmap() const override {}
+
+    TensorShape GetStrides() const override
+    {
+        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+    }
+
+    TensorShape GetShape() const override
+    {
+        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+    }
 private:
-    arm_compute::SubTensor m_Tensor;   
+    arm_compute::SubTensor m_Tensor;
+    ITensorHandle* parentHandle = nullptr;
 };
 
 }