2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include "CpuTensorHandleFwd.hpp"
8 #include <armnn/TypesUtils.hpp>
9 #include <backends/OutputHandler.hpp>
16 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
17 class ConstCpuTensorHandle : public ITensorHandle
// Returns the wrapped memory as a read-only pointer to T.
// Asserts that T matches the data type recorded in GetTensorInfo(), so callers must request
// the tensor's actual element type.
// NOTE(review): the `template <typename T>` header of this member is not visible in this view.
21 const T* GetConstTensor() const
23 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
// Pure pointer reinterpretation of the stored buffer; no data is copied or converted.
24 return reinterpret_cast<const T*>(m_Memory);
// Description (shape and data type) of the wrapped tensor; GetConstTensor() checks its data type.
// NOTE(review): the body is not visible in this view; presumably it returns m_TensorInfo.
27 const TensorInfo& GetTensorInfo() const
32 virtual ITensorHandle::Type GetType() const override
34 return ITensorHandle::Cpu;
37 virtual void Manage() override {}
39 virtual ITensorHandle* GetParent() const override { return nullptr; }
41 virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
42 virtual void Unmap() const override {}
44 TensorShape GetStrides() const override
46 TensorShape shape(m_TensorInfo.GetShape());
47 auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
48 auto runningSize = size;
49 std::vector<unsigned int> strides(shape.GetNumDimensions());
50 auto lastIdx = shape.GetNumDimensions()-1;
51 for (unsigned int i=0; i < lastIdx ; i++)
53 strides[lastIdx-i] = runningSize;
54 runningSize *= shape[lastIdx-i];
56 strides[0] = runningSize;
57 return TensorShape(shape.GetNumDimensions(), strides.data());
59 TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
// Takes only the tensor description; the backing memory is attached separately via
// SetConstMemory().
62 ConstCpuTensorHandle(const TensorInfo& tensorInfo);
// Stores the given pointer as the read-only buffer; the memory itself is not copied.
64 void SetConstMemory(const void* mem) { m_Memory = mem; }
// Copy construction and copy assignment are disabled.
67 ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
68 ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
// Shape and data type of the wrapped tensor.
// NOTE(review): m_Memory (read by GetConstTensor()/Map(), written by SetConstMemory()) is declared
// outside this view.
70 TensorInfo m_TensorInfo;
74 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
75 class CpuTensorHandle : public ConstCpuTensorHandle
// Writable counterpart of GetConstTensor(): returns the mutable buffer as T*, asserting that T
// matches the data type recorded in GetTensorInfo().
// NOTE(review): the template header and signature of this member are not visible in this view;
// presumably `template <typename T> T* GetTensor() const`, matching the void specialization.
81 BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
82 return reinterpret_cast<T*>(m_MutableMemory);
// Takes only the tensor description; the backing memory is attached separately via SetMemory().
86 CpuTensorHandle(const TensorInfo& tensorInfo);
88 void SetMemory(void* mem)
90 m_MutableMemory = mem;
91 SetConstMemory(m_MutableMemory);
// Copy construction and copy assignment are disabled.
96 CpuTensorHandle(const CpuTensorHandle& other) = delete;
97 CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
// Writable view of the buffer; SetMemory() sets it together with the base class's const pointer.
// NOTE(review): no initializer is visible here; confirm the constructor initializes it (e.g. to
// nullptr) before GetTensor() can be called.
98 void* m_MutableMemory;
101 // A CpuTensorHandle that owns the wrapped memory region.
102 class ScopedCpuTensorHandle : public CpuTensorHandle
// Creates a handle described by tensorInfo.
// NOTE(review): the definition is not visible here; presumably memory is only obtained by
// Allocate(). Confirm.
105 explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
107 // Copies contents from the given ConstTensor into memory owned by this handle.
108 explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
110 // Copies contents from a ConstCpuTensorHandle into memory owned by this handle.
111 explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
// User-declared copy operations and destructor, since this class owns its buffer.
// NOTE(review): copies are presumably deep (via CopyFrom). Confirm in the definitions.
113 ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
114 ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
115 ~ScopedCpuTensorHandle();
// NOTE(review): definition not visible; presumably allocates the owned buffer for the stored
// TensorInfo.
117 virtual void Allocate() override;
// Copy helpers: copy from another scoped handle, or copy numBytes raw bytes read from srcMemory.
// NOTE(review): the access specifier for these is not visible in this view; presumably private.
120 void CopyFrom(const ScopedCpuTensorHandle& other);
121 void CopyFrom(const void* srcMemory, unsigned int numBytes);
124 // A CpuTensorHandle that wraps an already allocated memory region.
126 // Clients must make sure the passed in memory region stays alive for the lifetime of
127 // the PassthroughCpuTensorHandle instance.
129 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
130 class PassthroughCpuTensorHandle : public CpuTensorHandle
// Wraps caller-owned writable memory `mem` described by tensorInfo, without copying it.
// NOTE(review): the constructor body is not visible in this view; presumably it calls
// SetMemory(mem).
133 PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
134 : CpuTensorHandle(tensorInfo)
// NOTE(review): definition not visible. The memory is supplied by the caller, so this is
// presumably a no-op or an error. Confirm.
139 virtual void Allocate() override;
142 // A ConstCpuTensorHandle that wraps an already allocated memory region.
144 // This allows users to pass in const memory to a network.
145 // Clients must make sure the passed in memory region stays alive for the lifetime of
146 // the ConstPassthroughCpuTensorHandle instance.
148 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
149 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
// Wraps caller-owned read-only memory `mem` described by tensorInfo, without copying it.
// NOTE(review): the constructor body is not visible in this view; presumably it calls
// SetConstMemory(mem).
152 ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
153 : ConstCpuTensorHandle(tensorInfo)
// NOTE(review): definition not visible. The memory is supplied by the caller, so this is
// presumably a no-op or an error. Confirm.
158 virtual void Allocate() override;
162 // Explicit specializations of the typed accessors for T = void, defined out of line.
// NOTE(review): the `template <>` prefix lines are not visible in this view. These presumably
// exist because the generic versions' GetDataType<T>() check has no meaning for void. Confirm
// against the definitions.
165 const void* ConstCpuTensorHandle::GetConstTensor() const;
168 void* CpuTensorHandle::GetTensor() const;