2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <backends/OutputHandler.hpp>
8 #include <backends/aclCommon/ArmComputeTensorUtils.hpp>
10 #include <arm_compute/runtime/CL/CLTensor.h>
11 #include <arm_compute/runtime/CL/CLSubTensor.h>
12 #include <arm_compute/runtime/CL/CLMemoryGroup.h>
13 #include <arm_compute/runtime/IMemoryGroup.h>
14 #include <arm_compute/core/TensorShape.h>
15 #include <arm_compute/core/Coordinates.h>
17 #include <boost/polymorphic_pointer_cast.hpp>
23 class IClTensorHandle : public ITensorHandle
26 virtual arm_compute::ICLTensor& GetTensor() = 0;
27 virtual arm_compute::ICLTensor const& GetTensor() const = 0;
28 virtual arm_compute::DataType GetDataType() const = 0;
29 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
32 class ClTensorHandle : public IClTensorHandle
35 ClTensorHandle(const TensorInfo& tensorInfo)
37 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
40 ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
42 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45 arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
46 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
47 virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
49 virtual void Manage() override
51 assert(m_MemoryGroup != nullptr);
52 m_MemoryGroup->manage(&m_Tensor);
55 virtual const void* Map(bool blocking = true) const override
57 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
58 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
60 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
62 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
64 virtual ITensorHandle* GetParent() const override { return nullptr; }
66 virtual arm_compute::DataType GetDataType() const override
68 return m_Tensor.info()->data_type();
71 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
73 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::CLMemoryGroup>(memoryGroup);
76 TensorShape GetStrides() const override
78 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
81 TensorShape GetShape() const override
83 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
86 arm_compute::CLTensor m_Tensor;
87 std::shared_ptr<arm_compute::CLMemoryGroup> m_MemoryGroup;
90 class ClSubTensorHandle : public IClTensorHandle
93 ClSubTensorHandle(IClTensorHandle* parent,
94 const arm_compute::TensorShape& shape,
95 const arm_compute::Coordinates& coords)
96 : m_Tensor(&parent->GetTensor(), shape, coords)
98 parentHandle = parent;
101 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
102 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
104 virtual void Allocate() override {}
105 virtual void Manage() override {}
107 virtual const void* Map(bool blocking = true) const override
109 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
110 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
112 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
114 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::CL; }
116 virtual ITensorHandle* GetParent() const override { return parentHandle; }
118 virtual arm_compute::DataType GetDataType() const override
120 return m_Tensor.info()->data_type();
123 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
125 TensorShape GetStrides() const override
127 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
130 TensorShape GetShape() const override
132 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
136 mutable arm_compute::CLSubTensor m_Tensor;
137 ITensorHandle* parentHandle = nullptr;