2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <backends/OutputHandler.hpp>
8 #include <backends/aclCommon/ArmComputeTensorUtils.hpp>
10 #include <arm_compute/runtime/MemoryGroup.h>
11 #include <arm_compute/runtime/IMemoryGroup.h>
12 #include <arm_compute/runtime/Tensor.h>
13 #include <arm_compute/runtime/SubTensor.h>
14 #include <arm_compute/core/TensorShape.h>
15 #include <arm_compute/core/Coordinates.h>
17 #include <boost/polymorphic_pointer_cast.hpp>
22 class INeonTensorHandle : public ITensorHandle
25 virtual arm_compute::ITensor& GetTensor() = 0;
26 virtual arm_compute::ITensor const& GetTensor() const = 0;
27 virtual arm_compute::DataType GetDataType() const = 0;
28 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
31 class NeonTensorHandle : public INeonTensorHandle
34 NeonTensorHandle(const TensorInfo& tensorInfo)
36 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
39 NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
41 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
44 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
45 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
47 virtual void Allocate() override
49 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
52 virtual void Manage() override
54 BOOST_ASSERT(m_MemoryGroup != nullptr);
55 m_MemoryGroup->manage(&m_Tensor);
58 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
60 virtual ITensorHandle* GetParent() const override { return nullptr; }
62 virtual arm_compute::DataType GetDataType() const override
64 return m_Tensor.info()->data_type();
67 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
69 m_MemoryGroup = boost::polymorphic_pointer_downcast<arm_compute::MemoryGroup>(memoryGroup);
72 virtual const void* Map(bool /* blocking = true */) const override
74 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
76 virtual void Unmap() const override {}
79 TensorShape GetStrides() const override
81 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
84 TensorShape GetShape() const override
86 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
90 arm_compute::Tensor m_Tensor;
91 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
94 class NeonSubTensorHandle : public INeonTensorHandle
97 NeonSubTensorHandle(INeonTensorHandle* parent,
98 const arm_compute::TensorShape& shape,
99 const arm_compute::Coordinates& coords)
100 : m_Tensor(&parent->GetTensor(), shape, coords)
102 parentHandle = parent;
105 arm_compute::ITensor& GetTensor() override { return m_Tensor; }
106 arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
108 virtual void Allocate() override {}
109 virtual void Manage() override {}
111 virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
113 virtual ITensorHandle* GetParent() const override { return parentHandle; }
115 virtual arm_compute::DataType GetDataType() const override
117 return m_Tensor.info()->data_type();
120 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
122 virtual const void* Map(bool /* blocking = true */) const override
124 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
126 virtual void Unmap() const override {}
128 TensorShape GetStrides() const override
130 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
133 TensorShape GetShape() const override
135 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
138 arm_compute::SubTensor m_Tensor;
139 ITensorHandle* parentHandle = nullptr;