IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / cl / ClTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <backendsCommon/OutputHandler.hpp>
8 #include <aclCommon/ArmComputeTensorUtils.hpp>
9
10 #include <arm_compute/runtime/CL/CLTensor.h>
11 #include <arm_compute/runtime/CL/CLSubTensor.h>
12 #include <arm_compute/runtime/CL/CLMemoryGroup.h>
13 #include <arm_compute/runtime/IMemoryGroup.h>
14 #include <arm_compute/core/TensorShape.h>
15 #include <arm_compute/core/Coordinates.h>
16
17 #include <boost/polymorphic_pointer_cast.hpp>
18
19 namespace armnn
20 {
21
22
23 class IClTensorHandle : public ITensorHandle
24 {
25 public:
26     virtual arm_compute::ICLTensor& GetTensor() = 0;
27     virtual arm_compute::ICLTensor const& GetTensor() const = 0;
28     virtual arm_compute::DataType GetDataType() const = 0;
29     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
30 };
31
32 class ClTensorHandle : public IClTensorHandle
33 {
34 public:
35     ClTensorHandle(const TensorInfo& tensorInfo)
36     {
37         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
38     }
39
40     ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
41     {
42         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
43     }
44
45     arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
46     arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
47     virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
48
49     virtual void Manage() override
50     {
51         assert(m_MemoryGroup != nullptr);
52         m_MemoryGroup->manage(&m_Tensor);
53     }
54
55     virtual const void* Map(bool blocking = true) const override
56     {
57         const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
58         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
59     }
60     virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
61
62     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
63
64     virtual ITensorHandle* GetParent() const override { return nullptr; }
65
66     virtual arm_compute::DataType GetDataType() const override
67     {
68         return m_Tensor.info()->data_type();
69     }
70
71     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
72     {
73         m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
74     }
75
76     TensorShape GetStrides() const override
77     {
78         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
79     }
80
81     TensorShape GetShape() const override
82     {
83         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
84     }
85 private:
86     arm_compute::CLTensor m_Tensor;
87     std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
88 };
89
90 class ClSubTensorHandle : public IClTensorHandle
91 {
92 public:
93     ClSubTensorHandle(IClTensorHandle* parent,
94                       const arm_compute::TensorShape& shape,
95                       const arm_compute::Coordinates& coords)
96     : m_Tensor(&parent->GetTensor(), shape, coords)
97     {
98         parentHandle = parent;
99     }
100
101     arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
102     arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
103
104     virtual void Allocate() override {}
105     virtual void Manage() override {}
106
107     virtual const void* Map(bool blocking = true) const override
108     {
109         const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
110         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
111     }
112     virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
113
114     virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
115
116     virtual ITensorHandle* GetParent() const override { return parentHandle; }
117
118     virtual arm_compute::DataType GetDataType() const override
119     {
120         return m_Tensor.info()->data_type();
121     }
122
123     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
124
125     TensorShape GetStrides() const override
126     {
127         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
128     }
129
130     TensorShape GetShape() const override
131     {
132         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
133     }
134
135 private:
136     mutable arm_compute::CLSubTensor m_Tensor;
137     ITensorHandle* parentHandle = nullptr;
138
139 };
140
141 }