IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / NeonTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <backendsCommon/OutputHandler.hpp>
8 #include <aclCommon/ArmComputeTensorUtils.hpp>
9
10 #include <arm_compute/runtime/MemoryGroup.h>
11 #include <arm_compute/runtime/IMemoryGroup.h>
12 #include <arm_compute/runtime/Tensor.h>
13 #include <arm_compute/runtime/SubTensor.h>
14 #include <arm_compute/core/TensorShape.h>
15 #include <arm_compute/core/Coordinates.h>
16
17 #include <boost/polymorphic_pointer_cast.hpp>
18
19 namespace armnn
20 {
21
22 class INeonTensorHandle : public ITensorHandle
23 {
24 public:
25     virtual arm_compute::ITensor& GetTensor() = 0;
26     virtual arm_compute::ITensor const& GetTensor() const = 0;
27     virtual arm_compute::DataType GetDataType() const = 0;
28     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
29 };
30
31 class NeonTensorHandle : public INeonTensorHandle
32 {
33 public:
34     NeonTensorHandle(const TensorInfo& tensorInfo)
35     {
36         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
37     }
38
39     NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
40     {
41         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
42     }
43
44     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
45     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
46
47     virtual void Allocate() override
48     {
49         armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
50     };
51
52     virtual void Manage() override
53     {
54         BOOST_ASSERT(m_MemoryGroup != nullptr);
55         m_MemoryGroup->manage(&m_Tensor);
56     }
57
58     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
59
60     virtual ITensorHandle* GetParent() const override { return nullptr; }
61
62     virtual arm_compute::DataType GetDataType() const override
63     {
64         return m_Tensor.info()->data_type();
65     }
66
67     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
68     {
69         m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
70     }
71
72     virtual const void* Map(bool /* blocking = true */) const override
73     {
74         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
75     }
76     virtual void Unmap() const override {}
77
78
79     TensorShape GetStrides() const override
80     {
81         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
82     }
83
84     TensorShape GetShape() const override
85     {
86         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
87     }
88
89 private:
90     arm_compute::Tensor m_Tensor;
91     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
92 };
93
94 class NeonSubTensorHandle : public INeonTensorHandle
95 {
96 public:
97     NeonSubTensorHandle(INeonTensorHandle* parent,
98                         const arm_compute::TensorShape& shape,
99                         const arm_compute::Coordinates& coords)
100      : m_Tensor(&parent->GetTensor(), shape, coords)
101     {
102         parentHandle = parent;
103     }
104
105     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
106     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
107
108     virtual void Allocate() override {}
109     virtual void Manage() override {}
110
111     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
112
113     virtual ITensorHandle* GetParent() const override { return parentHandle; }
114
115     virtual arm_compute::DataType GetDataType() const override
116     {
117         return m_Tensor.info()->data_type();
118     }
119
120     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
121
122     virtual const void* Map(bool /* blocking = true */) const override
123     {
124         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
125     }
126     virtual void Unmap() const override {}
127
128     TensorShape GetStrides() const override
129     {
130         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
131     }
132
133     TensorShape GetShape() const override
134     {
135         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
136     }
137 private:
138     arm_compute::SubTensor m_Tensor;
139     ITensorHandle* parentHandle = nullptr;
140 };
141
142 }