d08b79f9a6faa95fda7f612958d2fe7a0d91c3a7
[platform/upstream/armnn.git] / src / backends / cl / ClTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <backendsCommon/OutputHandler.hpp>
8 #include <aclCommon/ArmComputeTensorHandle.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10
11 #include <Half.hpp>
12
13 #include <arm_compute/runtime/CL/CLTensor.h>
14 #include <arm_compute/runtime/CL/CLSubTensor.h>
15 #include <arm_compute/runtime/CL/CLMemoryGroup.h>
16 #include <arm_compute/runtime/IMemoryGroup.h>
17 #include <arm_compute/core/TensorShape.h>
18 #include <arm_compute/core/Coordinates.h>
19
20 #include <boost/polymorphic_pointer_cast.hpp>
21
22 namespace armnn
23 {
24
25
26 class IClTensorHandle : public IAclTensorHandle
27 {
28 public:
29     virtual arm_compute::ICLTensor& GetTensor() = 0;
30     virtual arm_compute::ICLTensor const& GetTensor() const = 0;
31     virtual arm_compute::DataType GetDataType() const = 0;
32     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
33 };
34
35 class ClTensorHandle : public IClTensorHandle
36 {
37 public:
38     ClTensorHandle(const TensorInfo& tensorInfo)
39     {
40         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
41     }
42
43     ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
44     {
45         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
46     }
47
48     arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
49     arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
50     virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
51
52     virtual void Manage() override
53     {
54         assert(m_MemoryGroup != nullptr);
55         m_MemoryGroup->manage(&m_Tensor);
56     }
57
58     virtual const void* Map(bool blocking = true) const override
59     {
60         const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
61         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
62     }
63
64     virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
65
66     virtual ITensorHandle* GetParent() const override { return nullptr; }
67
68     virtual arm_compute::DataType GetDataType() const override
69     {
70         return m_Tensor.info()->data_type();
71     }
72
73     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
74     {
75         m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
76     }
77
78     TensorShape GetStrides() const override
79     {
80         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
81     }
82
83     TensorShape GetShape() const override
84     {
85         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
86     }
87
88 private:
89     // Only used for testing
90     void CopyOutTo(void* memory) const override
91     {
92         const_cast<armnn::ClTensorHandle*>(this)->Map(true);
93         switch(this->GetDataType())
94         {
95             case arm_compute::DataType::F32:
96                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
97                                                                  static_cast<float*>(memory));
98                 break;
99             case arm_compute::DataType::U8:
100             case arm_compute::DataType::QASYMM8:
101                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
102                                                                  static_cast<uint8_t*>(memory));
103                 break;
104             case arm_compute::DataType::F16:
105                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106                                                                  static_cast<armnn::Half*>(memory));
107                 break;
108             case arm_compute::DataType::S16:
109             case arm_compute::DataType::QSYMM16:
110                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
111                                                                  static_cast<int16_t*>(memory));
112                 break;
113             default:
114             {
115                 throw armnn::UnimplementedException();
116             }
117         }
118         const_cast<armnn::ClTensorHandle*>(this)->Unmap();
119     }
120
121     // Only used for testing
122     void CopyInFrom(const void* memory) override
123     {
124         this->Map(true);
125         switch(this->GetDataType())
126         {
127             case arm_compute::DataType::F32:
128                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
129                                                                  this->GetTensor());
130                 break;
131             case arm_compute::DataType::U8:
132             case arm_compute::DataType::QASYMM8:
133                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
134                                                                  this->GetTensor());
135                 break;
136             case arm_compute::DataType::F16:
137                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
138                                                                  this->GetTensor());
139                 break;
140             case arm_compute::DataType::S16:
141             case arm_compute::DataType::QSYMM16:
142                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
143                                                                  this->GetTensor());
144                 break;
145             default:
146             {
147                 throw armnn::UnimplementedException();
148             }
149         }
150         this->Unmap();
151     }
152
153     arm_compute::CLTensor m_Tensor;
154     std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
155 };
156
157 class ClSubTensorHandle : public IClTensorHandle
158 {
159 public:
160     ClSubTensorHandle(IClTensorHandle* parent,
161                       const arm_compute::TensorShape& shape,
162                       const arm_compute::Coordinates& coords)
163     : m_Tensor(&parent->GetTensor(), shape, coords)
164     {
165         parentHandle = parent;
166     }
167
168     arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
169     arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
170
171     virtual void Allocate() override {}
172     virtual void Manage() override {}
173
174     virtual const void* Map(bool blocking = true) const override
175     {
176         const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
177         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
178     }
179     virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
180
181     virtual ITensorHandle* GetParent() const override { return parentHandle; }
182
183     virtual arm_compute::DataType GetDataType() const override
184     {
185         return m_Tensor.info()->data_type();
186     }
187
188     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
189
190     TensorShape GetStrides() const override
191     {
192         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
193     }
194
195     TensorShape GetShape() const override
196     {
197         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
198     }
199
200 private:
201     // Only used for testing
202     void CopyOutTo(void* memory) const override
203     {
204         const_cast<ClSubTensorHandle*>(this)->Map(true);
205         switch(this->GetDataType())
206         {
207             case arm_compute::DataType::F32:
208                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
209                                                                  static_cast<float*>(memory));
210                 break;
211             case arm_compute::DataType::U8:
212             case arm_compute::DataType::QASYMM8:
213                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
214                                                                  static_cast<uint8_t*>(memory));
215                 break;
216             case arm_compute::DataType::F16:
217                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
218                                                                  static_cast<armnn::Half*>(memory));
219                 break;
220             case arm_compute::DataType::S16:
221             case arm_compute::DataType::QSYMM16:
222                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
223                                                                  static_cast<int16_t*>(memory));
224                 break;
225             default:
226             {
227                 throw armnn::UnimplementedException();
228             }
229         }
230         const_cast<ClSubTensorHandle*>(this)->Unmap();
231     }
232
233     // Only used for testing
234     void CopyInFrom(const void* memory) override
235     {
236         this->Map(true);
237         switch(this->GetDataType())
238         {
239             case arm_compute::DataType::F32:
240                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
241                                                                  this->GetTensor());
242                 break;
243             case arm_compute::DataType::U8:
244             case arm_compute::DataType::QASYMM8:
245                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
246                                                                  this->GetTensor());
247                 break;
248             case arm_compute::DataType::F16:
249                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
250                                                                  this->GetTensor());
251                 break;
252             case arm_compute::DataType::S16:
253             case arm_compute::DataType::QSYMM16:
254                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
255                                                                  this->GetTensor());
256                 break;
257             default:
258             {
259                 throw armnn::UnimplementedException();
260             }
261         }
262         this->Unmap();
263     }
264
265     mutable arm_compute::CLSubTensor m_Tensor;
266     ITensorHandle* parentHandle = nullptr;
267 };
268
269 } // namespace armnn