Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "OutputHandler.hpp"
8 #include "ArmComputeTensorUtils.hpp"
9
10 #include <arm_compute/runtime/MemoryGroup.h>
11 #include <arm_compute/runtime/IMemoryGroup.h>
12 #include <arm_compute/runtime/Tensor.h>
13 #include <arm_compute/runtime/SubTensor.h>
14 #include <arm_compute/core/TensorShape.h>
15 #include <arm_compute/core/Coordinates.h>
16
17 #include <boost/polymorphic_pointer_cast.hpp>
18
19 namespace armnn
20 {
21
22 class INeonTensorHandle : public ITensorHandle
23 {
24 public:
25     virtual arm_compute::ITensor& GetTensor() = 0;
26     virtual arm_compute::ITensor const& GetTensor() const = 0;
27     virtual arm_compute::DataType GetDataType() const = 0;
28     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
29 };
30
31 class NeonTensorHandle : public INeonTensorHandle
32 {
33 public:
34     NeonTensorHandle(const TensorInfo& tensorInfo)
35     {
36         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37     }
38
39     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
40     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
41
42     virtual void Allocate() override
43     {
44         armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
45     };
46
47     virtual void Manage() override
48     {
49         BOOST_ASSERT(m_MemoryGroup != nullptr);
50         m_MemoryGroup->manage(&m_Tensor);
51     }
52
53     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
54
55     virtual ITensorHandle* GetParent() const override { return nullptr; }
56
57     virtual arm_compute::DataType GetDataType() const override
58     {
59         return m_Tensor.info()->data_type();
60     }
61
62     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
63     {
64         m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
65     }
66
67     virtual const void* Map(bool /* blocking = true */) const override
68     {
69         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
70     }
71     virtual void Unmap() const override {}
72
73
74     TensorShape GetStrides() const override
75     {
76         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
77     }
78
79     TensorShape GetShape() const override
80     {
81         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
82     }
83
84 private:
85     arm_compute::Tensor m_Tensor;
86     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
87 };
88
89 class NeonSubTensorHandle : public INeonTensorHandle
90 {
91 public:
92     NeonSubTensorHandle(INeonTensorHandle* parent,
93                         const arm_compute::TensorShape& shape,
94                         const arm_compute::Coordinates& coords)
95      : m_Tensor(&parent->GetTensor(), shape, coords)
96     {
97         parentHandle = parent;
98     }
99
100     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
101     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
102
103     virtual void Allocate() override {}
104     virtual void Manage() override {}
105
106     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
107
108     virtual ITensorHandle* GetParent() const override { return parentHandle; }
109
110     virtual arm_compute::DataType GetDataType() const override
111     {
112         return m_Tensor.info()->data_type();
113     }
114
115     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
116
117     virtual const void* Map(bool /* blocking = true */) const override
118     {
119         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
120     }
121     virtual void Unmap() const override {}
122
123     TensorShape GetStrides() const override
124     {
125         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
126     }
127
128     TensorShape GetShape() const override
129     {
130         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
131     }
132 private:
133     arm_compute::SubTensor m_Tensor;
134     ITensorHandle* parentHandle = nullptr;
135 };
136
137 }