IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / backendsCommon / CpuTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "CpuTensorHandleFwd.hpp"
8
9 #include <armnn/TypesUtils.hpp>
10
11 #include <backendsCommon/OutputHandler.hpp>
12
13 #include <algorithm>
14
15 namespace armnn
16 {
17
18 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
19 class ConstCpuTensorHandle : public ITensorHandle
20 {
21 public:
22     template <typename T>
23     const T* GetConstTensor() const
24     {
25         BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
26         return reinterpret_cast<const T*>(m_Memory);
27     }
28
29     const TensorInfo& GetTensorInfo() const
30     {
31         return m_TensorInfo;
32     }
33
34     virtual ITensorHandle::Type GetType() const override
35     {
36         return ITensorHandle::Cpu;
37     }
38
39     virtual void Manage() override {}
40
41     virtual ITensorHandle* GetParent() const override { return nullptr; }
42
43     virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
44     virtual void Unmap() const override {}
45
46     TensorShape GetStrides() const override
47     {
48         TensorShape shape(m_TensorInfo.GetShape());
49         auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
50         auto runningSize = size;
51         std::vector<unsigned int> strides(shape.GetNumDimensions());
52         auto lastIdx = shape.GetNumDimensions()-1;
53         for (unsigned int i=0; i < lastIdx ; i++)
54         {
55             strides[lastIdx-i] = runningSize;
56             runningSize *= shape[lastIdx-i];
57         }
58         strides[0] = runningSize;
59         return TensorShape(shape.GetNumDimensions(), strides.data());
60     }
61     TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
62
63 protected:
64     ConstCpuTensorHandle(const TensorInfo& tensorInfo);
65
66     void SetConstMemory(const void* mem) { m_Memory = mem; }
67
68 private:
69     ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
70     ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
71
72     TensorInfo m_TensorInfo;
73     const void* m_Memory;
74 };
75
76 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
77 class CpuTensorHandle : public ConstCpuTensorHandle
78 {
79 public:
80     template <typename T>
81     T* GetTensor() const
82     {
83         BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
84         return reinterpret_cast<T*>(m_MutableMemory);
85     }
86
87 protected:
88     CpuTensorHandle(const TensorInfo& tensorInfo);
89
90     void SetMemory(void* mem)
91     {
92         m_MutableMemory = mem;
93         SetConstMemory(m_MutableMemory);
94     }
95
96 private:
97
98     CpuTensorHandle(const CpuTensorHandle& other) = delete;
99     CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
100     void* m_MutableMemory;
101 };
102
103 // A CpuTensorHandle that owns the wrapped memory region.
104 class ScopedCpuTensorHandle : public CpuTensorHandle
105 {
106 public:
107     explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
108
109     // Copies contents from Tensor.
110     explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
111
112     // Copies contents from ConstCpuTensorHandle
113     explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
114
115     ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
116     ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
117     ~ScopedCpuTensorHandle();
118
119     virtual void Allocate() override;
120
121 private:
122     void CopyFrom(const ScopedCpuTensorHandle& other);
123     void CopyFrom(const void* srcMemory, unsigned int numBytes);
124 };
125
126 // A CpuTensorHandle that wraps an already allocated memory region.
127 //
128 // Clients must make sure the passed in memory region stays alive for the lifetime of
129 // the PassthroughCpuTensorHandle instance.
130 //
131 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
132 class PassthroughCpuTensorHandle : public CpuTensorHandle
133 {
134 public:
135     PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
136     :   CpuTensorHandle(tensorInfo)
137     {
138         SetMemory(mem);
139     }
140
141     virtual void Allocate() override;
142 };
143
144 // A ConstCpuTensorHandle that wraps an already allocated memory region.
145 //
146 // This allows users to pass in const memory to a network.
147 // Clients must make sure the passed in memory region stays alive for the lifetime of
148 // the PassthroughCpuTensorHandle instance.
149 //
150 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
151 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
152 {
153 public:
154     ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
155     :   ConstCpuTensorHandle(tensorInfo)
156     {
157         SetConstMemory(mem);
158     }
159
160     virtual void Allocate() override;
161 };
162
163
164 // Template specializations.
165
166 template <>
167 const void* ConstCpuTensorHandle::GetConstTensor() const;
168
169 template <>
170 void* CpuTensorHandle::GetTensor() const;
171
172 } // namespace armnn