// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "armnn/Tensor.hpp"
#include "armnn/Utils.hpp"
#include "armnn/Exceptions.hpp"
#include "armnn/TypesUtils.hpp"

#include <boost/assert.hpp>
#include <boost/log/trivial.hpp>
#include <boost/numeric/conversion/cast.hpp>

#include <algorithm>
#include <functional>
#include <numeric>
21 TensorShape::TensorShape()
26 TensorShape::TensorShape(const unsigned int numDimensions, const unsigned int* const dimensionSizes)
27 : m_NumDimensions(numDimensions)
29 if (numDimensions < 1)
31 throw InvalidArgumentException("Tensor numDimensions must be greater than 0");
34 if (numDimensions > MaxNumOfTensorDimensions)
36 throw InvalidArgumentException("Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions");
39 if (dimensionSizes == nullptr)
41 throw InvalidArgumentException("Tensor dimensionSizes must not be NULL");
44 std::copy(dimensionSizes, dimensionSizes + numDimensions, m_Dimensions.begin());
47 TensorShape::TensorShape(std::initializer_list<unsigned int> dimensionSizeList)
48 : TensorShape(boost::numeric_cast<unsigned int>(dimensionSizeList.size()), dimensionSizeList.begin())
52 TensorShape::TensorShape(const TensorShape& other)
53 : m_NumDimensions(other.m_NumDimensions)
55 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
58 TensorShape& TensorShape::operator =(const TensorShape& other)
60 m_NumDimensions = other.m_NumDimensions;
61 std::copy(other.m_Dimensions.cbegin(), other.m_Dimensions.cbegin() + other.m_NumDimensions, m_Dimensions.begin());
65 bool TensorShape::operator==(const TensorShape& other) const
67 return ((m_NumDimensions == other.m_NumDimensions) &&
68 std::equal(m_Dimensions.cbegin(), m_Dimensions.cbegin() + m_NumDimensions, other.m_Dimensions.cbegin()));
71 bool TensorShape::operator!=(const TensorShape& other) const
73 return !(*this == other);
76 unsigned int TensorShape::GetNumElements() const
78 if (m_NumDimensions == 0)
83 unsigned int count = 1;
84 for (unsigned int i = 0; i < m_NumDimensions; i++)
86 count *= m_Dimensions[i];
96 TensorInfo::TensorInfo()
97 : m_DataType(DataType::Float32)
101 TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType,
102 float quantizationScale, int32_t quantizationOffset)
104 , m_DataType(dataType)
106 m_Quantization.m_Scale = quantizationScale;
107 m_Quantization.m_Offset = quantizationOffset;
110 TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType,
111 float quantizationScale, int32_t quantizationOffset)
112 : m_Shape(numDimensions, dimensionSizes)
113 , m_DataType(dataType)
115 m_Quantization.m_Scale = quantizationScale;
116 m_Quantization.m_Offset = quantizationOffset;
119 TensorInfo::TensorInfo(const TensorInfo& other)
120 : m_Shape(other.m_Shape)
121 , m_DataType(other.m_DataType)
122 , m_Quantization(other.m_Quantization)
126 TensorInfo& TensorInfo::operator=(const TensorInfo& other)
128 m_Shape = other.m_Shape;
129 m_DataType = other.m_DataType;
130 m_Quantization = other.m_Quantization;
134 bool TensorInfo::operator==(const TensorInfo& other) const
136 return ((m_Shape == other.m_Shape) &&
137 (m_DataType == other.m_DataType) &&
138 (m_Quantization == other.m_Quantization));
141 bool TensorInfo::operator!=(const TensorInfo& other) const
143 return !(*this == other);
146 unsigned int TensorInfo::GetNumBytes() const
148 return GetDataTypeSize(m_DataType) * GetNumElements();
155 template<typename MemoryType>
156 BaseTensor<MemoryType>::BaseTensor()
157 : m_MemoryArea(nullptr)
161 template<typename MemoryType>
162 BaseTensor<MemoryType>::BaseTensor(const TensorInfo& info, MemoryType memoryArea)
163 : m_MemoryArea(memoryArea)
168 template<typename MemoryType>
169 BaseTensor<MemoryType>::BaseTensor(const BaseTensor<MemoryType>& other)
170 : m_MemoryArea(other.m_MemoryArea)
171 , m_Info(other.GetInfo())
175 template<typename MemoryType>
176 BaseTensor<MemoryType>& BaseTensor<MemoryType>::operator =(const BaseTensor<MemoryType>& other)
178 m_Info = other.m_Info;
179 m_MemoryArea = other.m_MemoryArea;
183 // Explicit instantiations.
184 template class BaseTensor<const void*>;
185 template class BaseTensor<void*>;