remove BOM from files
[platform/upstream/armnn.git] / src / backends / backendsCommon / CpuTensorHandle.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <armnn/Exceptions.hpp>
6 #include <armnn/utility/IgnoreUnused.hpp>
7
8 #include <backendsCommon/CpuTensorHandle.hpp>
9
10 #include <cstring>
11
12 namespace armnn
13 {
14
15 TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16 {
17     TensorShape shape(tensorInfo.GetShape());
18     auto size = GetDataTypeSize(tensorInfo.GetDataType());
19     auto runningSize = size;
20     std::vector<unsigned int> strides(shape.GetNumDimensions());
21     auto lastIdx = shape.GetNumDimensions()-1;
22     for (unsigned int i=0; i < lastIdx ; i++)
23     {
24         strides[lastIdx-i] = runningSize;
25         runningSize *= shape[lastIdx-i];
26     }
27     strides[0] = runningSize;
28     return TensorShape(shape.GetNumDimensions(), strides.data());
29 }
30
31 ConstCpuTensorHandle::ConstCpuTensorHandle(const TensorInfo& tensorInfo)
32 : m_TensorInfo(tensorInfo)
33 , m_Memory(nullptr)
34 {
35 }
36
37 template <>
38 const void* ConstCpuTensorHandle::GetConstTensor<void>() const
39 {
40     return m_Memory;
41 }
42
43 CpuTensorHandle::CpuTensorHandle(const TensorInfo& tensorInfo)
44 : ConstCpuTensorHandle(tensorInfo)
45 , m_MutableMemory(nullptr)
46 {
47 }
48
49 template <>
50 void* CpuTensorHandle::GetTensor<void>() const
51 {
52     return m_MutableMemory;
53 }
54
55 ScopedCpuTensorHandle::ScopedCpuTensorHandle(const TensorInfo& tensorInfo)
56 : CpuTensorHandle(tensorInfo)
57 {
58 }
59
60 ScopedCpuTensorHandle::ScopedCpuTensorHandle(const ConstTensor& tensor)
61 : ScopedCpuTensorHandle(tensor.GetInfo())
62 {
63     CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64 }
65
66 ScopedCpuTensorHandle::ScopedCpuTensorHandle(const ConstCpuTensorHandle& tensorHandle)
67 : ScopedCpuTensorHandle(tensorHandle.GetTensorInfo())
68 {
69     CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70 }
71
72 ScopedCpuTensorHandle::ScopedCpuTensorHandle(const ScopedCpuTensorHandle& other)
73 : CpuTensorHandle(other.GetTensorInfo())
74 {
75     CopyFrom(other);
76 }
77
78 ScopedCpuTensorHandle& ScopedCpuTensorHandle::operator=(const ScopedCpuTensorHandle& other)
79 {
80     ::operator delete(GetTensor<void>());
81     SetMemory(nullptr);
82     CopyFrom(other);
83     return *this;
84 }
85
86 ScopedCpuTensorHandle::~ScopedCpuTensorHandle()
87 {
88     ::operator delete(GetTensor<void>());
89 }
90
91 void ScopedCpuTensorHandle::Allocate()
92 {
93     if (GetTensor<void>() == nullptr)
94     {
95         SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96     }
97     else
98     {
99         throw InvalidArgumentException("CpuTensorHandle::Allocate Trying to allocate a CpuTensorHandle"
100             "that already has allocated memory.");
101     }
102 }
103
104 void ScopedCpuTensorHandle::CopyOutTo(void* memory) const
105 {
106     memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107 }
108
109 void ScopedCpuTensorHandle::CopyInFrom(const void* memory)
110 {
111     memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112 }
113
114 void ScopedCpuTensorHandle::CopyFrom(const ScopedCpuTensorHandle& other)
115 {
116     CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117 }
118
119 void ScopedCpuTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120 {
121     ARMNN_ASSERT(GetTensor<void>() == nullptr);
122     ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123
124     if (srcMemory)
125     {
126         Allocate();
127         memcpy(GetTensor<void>(), srcMemory, numBytes);
128     }
129 }
130
131 void PassthroughCpuTensorHandle::Allocate()
132 {
133     throw InvalidArgumentException("PassthroughCpuTensorHandle::Allocate() should never be called");
134 }
135
136 void ConstPassthroughCpuTensorHandle::Allocate()
137 {
138     throw InvalidArgumentException("ConstPassthroughCpuTensorHandle::Allocate() should never be called");
139 }
140
141 } // namespace armnn