IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / backends / backendsCommon / CpuTensorHandle.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "CpuTensorHandleFwd.hpp"
8 #include "CompatibleTypes.hpp"
9
10 #include <armnn/TypesUtils.hpp>
11
12 #include <backendsCommon/OutputHandler.hpp>
13
14 #include <algorithm>
15
16 namespace armnn
17 {
18
19 // Abstract tensor handles wrapping a CPU-readable region of memory, interpreting it as tensor data.
20 class ConstCpuTensorHandle : public ITensorHandle
21 {
22 public:
23     template <typename T>
24     const T* GetConstTensor() const
25     {
26         BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
27         return reinterpret_cast<const T*>(m_Memory);
28     }
29
30     const TensorInfo& GetTensorInfo() const
31     {
32         return m_TensorInfo;
33     }
34
35     virtual void Manage() override {}
36
37     virtual ITensorHandle* GetParent() const override { return nullptr; }
38
39     virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
40     virtual void Unmap() const override {}
41
42     TensorShape GetStrides() const override
43     {
44         TensorShape shape(m_TensorInfo.GetShape());
45         auto size = GetDataTypeSize(m_TensorInfo.GetDataType());
46         auto runningSize = size;
47         std::vector<unsigned int> strides(shape.GetNumDimensions());
48         auto lastIdx = shape.GetNumDimensions()-1;
49         for (unsigned int i=0; i < lastIdx ; i++)
50         {
51             strides[lastIdx-i] = runningSize;
52             runningSize *= shape[lastIdx-i];
53         }
54         strides[0] = runningSize;
55         return TensorShape(shape.GetNumDimensions(), strides.data());
56     }
57     TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
58
59 protected:
60     ConstCpuTensorHandle(const TensorInfo& tensorInfo);
61
62     void SetConstMemory(const void* mem) { m_Memory = mem; }
63
64 private:
65     // Only used for testing
66     void CopyOutTo(void *) const override {}
67     void CopyInFrom(const void*) override {}
68
69     ConstCpuTensorHandle(const ConstCpuTensorHandle& other) = delete;
70     ConstCpuTensorHandle& operator=(const ConstCpuTensorHandle& other) = delete;
71
72     TensorInfo m_TensorInfo;
73     const void* m_Memory;
74 };
75
76 template<>
77 const void* ConstCpuTensorHandle::GetConstTensor<void>() const;
78
79 // Abstract specialization of ConstCpuTensorHandle that allows write access to the same data.
80 class CpuTensorHandle : public ConstCpuTensorHandle
81 {
82 public:
83     template <typename T>
84     T* GetTensor() const
85     {
86         BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
87         return reinterpret_cast<T*>(m_MutableMemory);
88     }
89
90 protected:
91     CpuTensorHandle(const TensorInfo& tensorInfo);
92
93     void SetMemory(void* mem)
94     {
95         m_MutableMemory = mem;
96         SetConstMemory(m_MutableMemory);
97     }
98
99 private:
100
101     CpuTensorHandle(const CpuTensorHandle& other) = delete;
102     CpuTensorHandle& operator=(const CpuTensorHandle& other) = delete;
103     void* m_MutableMemory;
104 };
105
106 template <>
107 void* CpuTensorHandle::GetTensor<void>() const;
108
109 // A CpuTensorHandle that owns the wrapped memory region.
110 class ScopedCpuTensorHandle : public CpuTensorHandle
111 {
112 public:
113     explicit ScopedCpuTensorHandle(const TensorInfo& tensorInfo);
114
115     // Copies contents from Tensor.
116     explicit ScopedCpuTensorHandle(const ConstTensor& tensor);
117
118     // Copies contents from ConstCpuTensorHandle
119     explicit ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle);
120
121     ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other);
122     ScopedCpuTensorHandle& operator=(const ScopedCpuTensorHandle& other);
123     ~ScopedCpuTensorHandle();
124
125     virtual void Allocate() override;
126
127 private:
128     // Only used for testing
129     void CopyOutTo(void* memory) const override;
130     void CopyInFrom(const void* memory) override;
131
132     void CopyFrom(const ScopedCpuTensorHandle& other);
133     void CopyFrom(const void* srcMemory, unsigned int numBytes);
134 };
135
136 // A CpuTensorHandle that wraps an already allocated memory region.
137 //
138 // Clients must make sure the passed in memory region stays alive for the lifetime of
139 // the PassthroughCpuTensorHandle instance.
140 //
141 // Note there is no polymorphism to/from ConstPassthroughCpuTensorHandle.
142 class PassthroughCpuTensorHandle : public CpuTensorHandle
143 {
144 public:
145     PassthroughCpuTensorHandle(const TensorInfo& tensorInfo, void* mem)
146     :   CpuTensorHandle(tensorInfo)
147     {
148         SetMemory(mem);
149     }
150
151     virtual void Allocate() override;
152 };
153
154 // A ConstCpuTensorHandle that wraps an already allocated memory region.
155 //
156 // This allows users to pass in const memory to a network.
157 // Clients must make sure the passed in memory region stays alive for the lifetime of
158 // the PassthroughCpuTensorHandle instance.
159 //
160 // Note there is no polymorphism to/from PassthroughCpuTensorHandle.
161 class ConstPassthroughCpuTensorHandle : public ConstCpuTensorHandle
162 {
163 public:
164     ConstPassthroughCpuTensorHandle(const TensorInfo& tensorInfo, const void* mem)
165     :   ConstCpuTensorHandle(tensorInfo)
166     {
167         SetConstMemory(mem);
168     }
169
170     virtual void Allocate() override;
171 };
172
173
174 // Template specializations.
175
176 template <>
177 const void* ConstCpuTensorHandle::GetConstTensor() const;
178
179 template <>
180 void* CpuTensorHandle::GetTensor() const;
181
182 } // namespace armnn