2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <backendsCommon/OutputHandler.hpp>
8 #include <aclCommon/ArmComputeTensorHandle.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
13 #include <arm_compute/runtime/CL/CLTensor.h>
14 #include <arm_compute/runtime/CL/CLSubTensor.h>
15 #include <arm_compute/runtime/CL/CLMemoryGroup.h>
16 #include <arm_compute/runtime/IMemoryGroup.h>
17 #include <arm_compute/core/TensorShape.h>
18 #include <arm_compute/core/Coordinates.h>
20 #include <boost/polymorphic_pointer_cast.hpp>
26 class IClTensorHandle : public IAclTensorHandle
29 virtual arm_compute::ICLTensor& GetTensor() = 0;
30 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
31 virtual arm_compute::DataType GetDataType() const = 0;
32 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
35 class ClTensorHandle : public IClTensorHandle
38 ClTensorHandle(const TensorInfo& tensorInfo)
40 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
43 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
45 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
48 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
49 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
50 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
52 virtual void Manage() override
54 assert(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
58 virtual const void* Map(bool blocking = true) const override
60 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
61 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
64 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
66 virtual ITensorHandle* GetParent() const override { return nullptr; }
68 virtual arm_compute::DataType GetDataType() const override
70 return m_Tensor.info()->data_type();
73 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
75 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
78 TensorShape GetStrides() const override
80 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
83 TensorShape GetShape() const override
85 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
89 // Only used for testing
90 void CopyOutTo(void* memory) const override
92 const_cast<armnn::ClTensorHandle*>(this)->Map(true);
93 switch(this->GetDataType())
95 case arm_compute::DataType::F32:
96 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
97 static_cast<float*>(memory));
99 case arm_compute::DataType::U8:
100 case arm_compute::DataType::QASYMM8:
101 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
102 static_cast<uint8_t*>(memory));
104 case arm_compute::DataType::F16:
105 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
106 static_cast<armnn::Half*>(memory));
108 case arm_compute::DataType::S16:
109 case arm_compute::DataType::QSYMM16:
110 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
111 static_cast<int16_t*>(memory));
115 throw armnn::UnimplementedException();
118 const_cast<armnn::ClTensorHandle*>(this)->Unmap();
121 // Only used for testing
122 void CopyInFrom(const void* memory) override
125 switch(this->GetDataType())
127 case arm_compute::DataType::F32:
128 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
131 case arm_compute::DataType::U8:
132 case arm_compute::DataType::QASYMM8:
133 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
136 case arm_compute::DataType::F16:
137 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
140 case arm_compute::DataType::S16:
141 case arm_compute::DataType::QSYMM16:
142 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
147 throw armnn::UnimplementedException();
153 arm_compute::CLTensor m_Tensor;
154 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
// Tensor handle that views a region of a parent handle's CL tensor through arm_compute::CLSubTensor.
// Allocate(), Manage() and SetMemoryGroup() are no-ops here; storage presumably belongs to the parent (confirm).
157 class ClSubTensorHandle : public IClTensorHandle
// Creates a view of `shape` at `coords` inside parent->GetTensor().
// `parent` is kept as a non-owning raw pointer and returned by GetParent();
// it presumably must outlive this handle (TODO confirm ownership with callers).
160 ClSubTensorHandle(IClTensorHandle* parent,
161 const arm_compute::TensorShape& shape,
162 const arm_compute::Coordinates& coords)
163 : m_Tensor(&parent->GetTensor(), shape, coords)
// NOTE(review): could be set in the member initializer list alongside m_Tensor.
165 parentHandle = parent;
168 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
169 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
// No-ops: a sub-tensor does not allocate or manage memory itself.
171 virtual void Allocate() override {}
172 virtual void Manage() override {}
// Maps the sub-tensor and returns a pointer to its first element
// (buffer start plus the first-element byte offset reported by the tensor info).
174 virtual const void* Map(bool blocking = true) const override
176 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
177 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
// Releases a mapping obtained through Map().
179 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
// Returns the handle passed to the constructor.
181 virtual ITensorHandle* GetParent() const override { return parentHandle; }
183 virtual arm_compute::DataType GetDataType() const override
185 return m_Tensor.info()->data_type();
// Ignored: memory management is not performed for sub-tensors.
188 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
// Sub-tensor strides in bytes, as an armnn TensorShape.
190 TensorShape GetStrides() const override
192 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
// Sub-tensor dimensions, as an armnn TensorShape.
195 TensorShape GetShape() const override
197 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
201 // Only used for testing
// Copies the sub-tensor's contents into `memory`, typed by GetDataType().
// U8/QASYMM8 share the uint8_t copy; S16/QSYMM16 share the int16_t copy;
// any other type throws armnn::UnimplementedException.
// NOTE(review): Map() is called before the switch and Unmap() after it, with no try/catch visible;
// the throw on an unsupported type therefore skips Unmap() and leaves the tensor mapped.
// Also, Map()/Unmap() are const members, so the const_cast of `this` is redundant.
202 void CopyOutTo(void* memory) const override
204 const_cast<ClSubTensorHandle*>(this)->Map(true);
205 switch(this->GetDataType())
207 case arm_compute::DataType::F32:
208 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
209 static_cast<float*>(memory));
211 case arm_compute::DataType::U8:
212 case arm_compute::DataType::QASYMM8:
213 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
214 static_cast<uint8_t*>(memory));
216 case arm_compute::DataType::F16:
217 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
218 static_cast<armnn::Half*>(memory));
220 case arm_compute::DataType::S16:
221 case arm_compute::DataType::QSYMM16:
222 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
223 static_cast<int16_t*>(memory));
227 throw armnn::UnimplementedException();
230 const_cast<ClSubTensorHandle*>(this)->Unmap();
233 // Only used for testing
// Fills the sub-tensor from `memory`, typed by GetDataType(); other types throw armnn::UnimplementedException.
// NOTE(review): unlike CopyOutTo, no Map()/Unmap() around the copy is visible here, and the copy
// destination argument is not shown. Writing through the tensor's host pointer presumably requires
// the tensor to be mapped; confirm, and unmap on the throw path as well.
234 void CopyInFrom(const void* memory) override
237 switch(this->GetDataType())
239 case arm_compute::DataType::F32:
240 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
243 case arm_compute::DataType::U8:
244 case arm_compute::DataType::QASYMM8:
245 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
248 case arm_compute::DataType::F16:
249 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
252 case arm_compute::DataType::S16:
253 case arm_compute::DataType::QSYMM16:
254 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
259 throw armnn::UnimplementedException();
// NOTE(review): `mutable` looks redundant, since Map()/Unmap() already const_cast m_Tensor.
265 mutable arm_compute::CLSubTensor m_Tensor;
// Non-owning pointer to the parent handle. Unlike other members it has no m_ prefix.
266 ITensorHandle* parentHandle = nullptr;