2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "CpuTensorHandleFwd.hpp"
8 #include "CompatibleTypes.hpp"
10 #include <armnn/TypesUtils.hpp>
12 #include <backendsCommon/OutputHandler.hpp>
19 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
20 class ConstCpuTensorHandle : public ITensorHandle
// GetConstTensor<T>(): returns the wrapped memory (m_Memory) reinterpreted as a
// pointer to const T, after a BOOST_ASSERT that T is compatible with the tensor's
// data type.
// NOTE(review): BOOST_ASSERT is normally compiled out in NDEBUG/release builds, so
// the type check is presumably debug-only — confirm against the build configuration.
24 const T* GetConstTensor() const
26 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
27 return reinterpret_cast<const T*>(m_Memory);
// Accessor for the tensor's metadata; its data type drives the check in GetConstTensor.
30 const TensorInfo& GetTensorInfo() const
// No-op override in this class.
35 virtual void Manage() override {}
// Always returns nullptr: this handle reports no parent handle.
37 virtual ITensorHandle* GetParent() const override { return nullptr; }
// Map returns the wrapped pointer directly and ignores the blocking flag;
// Unmap is correspondingly a no-op.
39 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
40 virtual void Unmap() const override {}
// Computes row-major strides. The last dimension's stride is the element size
// (from GetDataTypeSize, presumably in bytes). Each earlier dimension's stride is the
// following dimension's stride multiplied by that following dimension's extent.
// NOTE(review): a zero-dimensional shape is not handled. In that case `strides` is
// empty, `lastIdx = GetNumDimensions()-1` wraps around (the count is compared against
// unsigned int), and `strides[0]` indexes past the end. Consider guarding
// GetNumDimensions() == 0.
42 TensorShape GetStrides() const override
44 TensorShape shape(m_TensorInfo.GetShape());
45 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
46 auto runningSize = size;
47 std::vector<unsigned int> strides(shape.GetNumDimensions());
48 auto lastIdx = shape.GetNumDimensions()-1;
// Walk from the innermost dimension outwards, accumulating the running stride.
49 for (unsigned int i=0; i < lastIdx ; i++)
51 strides[lastIdx-i] = runningSize;
52 runningSize *= shape[lastIdx-i];
// The outermost dimension's stride is the product accumulated by the loop.
54 strides[0] = runningSize;
55 return TensorShape(shape.GetNumDimensions(), strides.data());
// Returns the shape stored in m_TensorInfo.
57 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
// Constructor is only declared here; its definition lives elsewhere.
60 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
// Sets the pointer returned by GetConstTensor() and Map(). CpuTensorHandle::SetMemory
// calls this to publish its memory to the const view.
62 void SetConstMemory(const void* mem) { m_Memory = mem; }
65 // Only used for testing
// Both copy hooks are empty no-ops in this class.
66 void CopyOutTo(void *) const override {}
67 void CopyInFrom(const void*) override {}
// Copy construction and copy assignment are disabled.
69 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
70 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
// Metadata (shape and data type) describing the wrapped memory.
72 TensorInfo m_TensorInfo;
// Explicit specialization of GetConstTensor for void: declared here, defined elsewhere.
// NOTE(review): presumably it returns the raw pointer without the CompatibleTypes
// check, since void has no matching data type — confirm in the implementation file.
77 const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
79 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
80 class CpuTensorHandle : public ConstCpuTensorHandle
// GetTensor<T>(): mutable counterpart of GetConstTensor. It asserts that T is
// compatible with the tensor's data type, then returns m_MutableMemory reinterpreted
// as T*.
// NOTE(review): the accessor is const-qualified (see the void specialization below)
// yet hands out a non-const pointer, so a const handle does not protect the data.
86 BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
87 return reinterpret_cast<T*>(m_MutableMemory);
// Constructor is only declared here; its definition lives elsewhere.
91 CpuTensorHandle(const TensorInfo& tensorInfo);
// Stores the writable pointer and mirrors it into the base class via SetConstMemory,
// so after this call the const and mutable views refer to the same memory.
93 void SetMemory(void* mem)
95 m_MutableMemory = mem;
96 SetConstMemory(m_MutableMemory);
// Copy construction and copy assignment are disabled.
101 CpuTensorHandle(const CpuTensorHandle& other) = delete;
102 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
// Writable alias of the memory that SetMemory also publishes to the base class.
103 void* m_MutableMemory;
// Explicit specialization of GetTensor for void: declared here, defined elsewhere.
107 void* CpuTensorHandle::GetTensor<void>() const;
109 // A CpuTensorHandle that owns the wrapped memory region.
110 class ScopedCpuTensorHandle : public CpuTensorHandle
// Creates a handle described by tensorInfo; definition not visible here.
113 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
115 // Copies contents from the given ConstTensor.
116 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
118 // Copies contents from the given ConstCpuTensorHandle.
119 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
// The copy constructor, copy assignment and destructor are user-declared and defined
// elsewhere.
// NOTE(review): declaring them suppresses the implicit move operations, so moving a
// ScopedCpuTensorHandle falls back to the copy constructor/assignment.
121 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
122 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
123 ~ScopedCpuTensorHandle();
// Overrides the base-class Allocate(); definition not visible here.
125 virtual void Allocate() override;
128 // Only used for testing
129 void CopyOutTo(void* memory) const override;
130 void CopyInFrom(const void* memory) override;
// Copy helpers. They are presumably shared by the copying constructors and the
// assignment operator above — confirm in the implementation file.
132 void CopyFrom(const ScopedCpuTensorHandle& other);
133 void CopyFrom(const void* srcMemory, unsigned int numBytes);
136 // A CpuTensorHandle that wraps an already allocated memory region.
138 // Clients must make sure the passed-in memory region stays alive for the lifetime of
139 // the PassthroughCpuTensorHandle instance.
141 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
142 class PassthroughCpuTensorHandle : public CpuTensorHandle
// Wraps caller-provided memory; tensorInfo is forwarded to the CpuTensorHandle base.
// NOTE(review): the constructor body is not visible here. Presumably it registers mem
// via SetMemory(mem) — confirm.
145 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
146 : CpuTensorHandle(tensorInfo)
// Overrides the base-class Allocate(); definition not visible here.
151 virtual void Allocate() override;
154 // A ConstCpuTensorHandle that wraps an already allocated memory region.
156 // This allows users to pass in const memory to a network.
157 // Clients must make sure the passed-in memory region stays alive for the lifetime of
158 // the ConstPassthroughCpuTensorHandle instance.
160 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
161 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
// Wraps caller-provided read-only memory; tensorInfo is forwarded to the
// ConstCpuTensorHandle base.
// NOTE(review): the constructor body is not visible here. Presumably it registers mem
// via SetConstMemory(mem) — confirm.
164 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
165 : ConstCpuTensorHandle(tensorInfo)
// Overrides the base-class Allocate(); definition not visible here.
170 virtual void Allocate() override;
174 // Template specializations.
// Declarations of the void-pointer variants of GetConstTensor and GetTensor.
// NOTE(review): these appear to redeclare the GetConstTensor<void>/GetTensor<void>
// specializations already declared after each class definition. Repeating such a
// declaration is legal but redundant.
177 const void* ConstCpuTensorHandle::GetConstTensor() const;
180 void* CpuTensorHandle::GetTensor() const;