2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "CpuTensorHandleFwd.hpp"
9 #include <armnn/TypesUtils.hpp>
11 #include <backendsCommon/OutputHandler.hpp>
18 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
19 class ConstCpuTensorHandle : public ITensorHandle
23 const T* GetConstTensor() const
25 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
26 return reinterpret_cast<const T*>(m_Memory);
29 const TensorInfo& GetTensorInfo() const
34 virtual ITensorHandle::Type GetType() const override
36 return ITensorHandle::Cpu;
39 virtual void Manage() override {}
41 virtual ITensorHandle* GetParent() const override { return nullptr; }
43 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
44 virtual void Unmap() const override {}
46 TensorShape GetStrides() const override
48 TensorShape shape(m_TensorInfo.GetShape());
49 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
50 auto runningSize = size;
51 std::vector<unsigned int> strides(shape.GetNumDimensions());
52 auto lastIdx = shape.GetNumDimensions()-1;
53 for (unsigned int i=0; i < lastIdx ; i++)
55 strides[lastIdx-i] = runningSize;
56 runningSize *= shape[lastIdx-i];
58 strides[0] = runningSize;
59 return TensorShape(shape.GetNumDimensions(), strides.data());
61 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
64 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
66 void SetConstMemory(const void* mem) { m_Memory = mem; }
69 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
70 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
72 TensorInfo m_TensorInfo;
76 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
77 class CpuTensorHandle : public ConstCpuTensorHandle
83 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
84 return reinterpret_cast<T*>(m_MutableMemory);
88 CpuTensorHandle(const TensorInfo& tensorInfo);
90 void SetMemory(void* mem)
92 m_MutableMemory = mem;
93 SetConstMemory(m_MutableMemory);
98 CpuTensorHandle(const CpuTensorHandle& other) = delete;
99 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
100 void* m_MutableMemory;
103 // A CpuTensorHandle that owns the wrapped memory region.
104 class ScopedCpuTensorHandle : public CpuTensorHandle
107 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
109 // Copies contents from Tensor.
110 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
112 // Copies contents from ConstCpuTensorHandle
113 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
115 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
116 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
117 ~ScopedCpuTensorHandle();
119 virtual void Allocate() override;
122 void CopyFrom(const ScopedCpuTensorHandle& other);
123 void CopyFrom(const void* srcMemory, unsigned int numBytes);
126 // A CpuTensorHandle that wraps an already allocated memory region.
128 // Clients must make sure the passed in memory region stays alive for the lifetime of
129 // the PassthroughCpuTensorHandle instance.
131 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
132 class PassthroughCpuTensorHandle : public CpuTensorHandle
135 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
136 : CpuTensorHandle(tensorInfo)
141 virtual void Allocate() override;
144 // A ConstCpuTensorHandle that wraps an already allocated memory region.
146 // This allows users to pass in const memory to a network.
147 // Clients must make sure the passed in memory region stays alive for the lifetime of
148 // the PassthroughCpuTensorHandle instance.
150 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
151 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
154 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
155 : ConstCpuTensorHandle(tensorInfo)
160 virtual void Allocate() override;
164 // Template specializations.
167 const void* ConstCpuTensorHandle::GetConstTensor() const;
170 void* CpuTensorHandle::GetTensor() const;