2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include "armnn/Descriptors.hpp"
6 #include "armnn/Logging.hpp"
8 #include <armnn/utility/Assert.hpp>
9 #include <armnn/utility/NumericCast.hpp>
15 #include <fmt/format.h>
20 PermutationVector::PermutationVector(const ValueType *dimMappings, const SizeType numDimMappings)
24 if (numDimMappings > MaxNumOfTensorDimensions)
26 throw InvalidArgumentException(
27 fmt::format("The number of mappings ({0}) cannot be greater "
28 "than the maximum number of dimensions supported ({1})",
30 MaxNumOfTensorDimensions));
33 if ((dimMappings == nullptr) && (numDimMappings != 0))
35 throw InvalidArgumentException("Dimension mappings must not be NULL if the number of mappings is positive");
38 for (SizeType i = 0; i < numDimMappings; ++i)
40 const ValueType dstIndex = dimMappings[i];
41 if (dstIndex >= numDimMappings)
43 throw InvalidArgumentException(
44 fmt::format("Dimension mapping at index {0} is invalid: "
45 "{1} is outside of the valid range [0,{2}]",
48 (numDimMappings - 1)));
52 // Validation: Detect duplicates
54 std::array<bool, MaxNumOfTensorDimensions> observedDims;
55 observedDims.fill(false);
57 for (SizeType i = 0; i < numDimMappings; ++i)
59 const ValueType dstIndex = dimMappings[i];
60 if (observedDims[dstIndex])
62 throw InvalidArgumentException("Invalid dimension mappings: Two or more source dimensions are mapped "
63 "to the same output dimension");
65 observedDims[dstIndex] = true;
70 for (SizeType i = 0; i < numDimMappings; ++i)
72 m_DimMappings[i] = dimMappings[i];
74 m_NumDimMappings = numDimMappings;
77 PermutationVector::PermutationVector(std::initializer_list<ValueType> dimMappings)
78 : PermutationVector(dimMappings.begin(), armnn::numeric_cast<SizeType>(dimMappings.size()))
82 OriginsDescriptor::OriginsDescriptor()
86 , m_ViewOrigins(nullptr)
89 OriginsDescriptor::OriginsDescriptor(uint32_t numViews, uint32_t numDimensions /*= 4*/)
91 , m_NumViews(numViews)
92 , m_NumDimensions(numDimensions)
93 , m_ViewOrigins(numViews && numDimensions > 0 ? new uint32_t *[numViews]() : nullptr)
95 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
97 m_ViewOrigins[i] = new uint32_t[m_NumDimensions]();
101 OriginsDescriptor::OriginsDescriptor(const OriginsDescriptor& other)
102 : m_ConcatAxis(other.m_ConcatAxis)
103 , m_NumViews(other.m_NumViews)
104 , m_NumDimensions(other.m_NumDimensions)
105 , m_ViewOrigins(other.m_NumViews && other.m_NumDimensions > 0 ? new uint32_t *[other.m_NumViews]() : nullptr)
107 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
109 m_ViewOrigins[i] = new uint32_t[m_NumDimensions]();
110 memcpy(m_ViewOrigins[i], other.m_ViewOrigins[i], m_NumDimensions * sizeof(uint32_t));
114 OriginsDescriptor::OriginsDescriptor(OriginsDescriptor&& other)
115 : OriginsDescriptor()
120 OriginsDescriptor::~OriginsDescriptor()
122 for (uint32_t i = 0; m_NumDimensions > 0 && i < m_NumViews; ++i)
124 delete[] m_ViewOrigins[i];
126 delete[] m_ViewOrigins;
129 OriginsDescriptor& OriginsDescriptor::operator=(OriginsDescriptor rhs)
135 bool OriginsDescriptor::operator==(const OriginsDescriptor& rhs) const
137 if (GetNumViews() != rhs.GetNumViews() ||
138 GetNumDimensions() != rhs.GetNumDimensions() ||
139 GetConcatAxis() != rhs.GetConcatAxis())
144 for (unsigned int i = 0u; i < GetNumViews(); ++i)
146 for (unsigned int j = 0u; j < GetNumDimensions(); ++j)
148 if (GetViewOrigin(i)[j] != rhs.GetViewOrigin(i)[j])
158 void OriginsDescriptor::SetConcatAxis(unsigned int concatAxis)
160 m_ConcatAxis = concatAxis;
162 unsigned int OriginsDescriptor::GetConcatAxis() const
167 Status OriginsDescriptor::SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value)
169 if (view >= m_NumViews)
171 ARMNN_LOG(error) << "OriginsDescriptor::SetViewOriginCoord: view argument:" << view <<
173 return Status::Failure;
175 if (coord >= m_NumDimensions)
177 ARMNN_LOG(error) << "OriginsDescriptor::SetViewOriginCoord: coord argument:" << coord <<
179 return Status::Failure;
182 m_ViewOrigins[view][coord] = value;
183 return Status::Success;
187 uint32_t OriginsDescriptor::GetNumViews() const
192 uint32_t OriginsDescriptor::GetNumDimensions() const
194 return m_NumDimensions;
197 const uint32_t* OriginsDescriptor::GetViewOrigin(uint32_t idx) const
199 return m_ViewOrigins ? m_ViewOrigins[idx] : nullptr;
203 // Reorders the viewOrigins in accordance with the indices presented in newOrdering array.
204 void OriginsDescriptor::ReorderOrigins(unsigned int* newOrdering, unsigned int numNewOrdering)
206 ARMNN_ASSERT_MSG(m_NumViews == numNewOrdering, "number of views must match number of "
207 "elements in the new ordering array");
208 std::vector<uint32_t*> viewOrigins(&m_ViewOrigins[0], &m_ViewOrigins[m_NumViews]);
210 for (unsigned int i = 0; i < numNewOrdering; ++i)
212 m_ViewOrigins[i] = viewOrigins[newOrdering[i]];
216 ViewsDescriptor::ViewsDescriptor()
218 , m_ViewSizes(nullptr)
221 ViewsDescriptor::ViewsDescriptor(uint32_t numViews, uint32_t numDimensions /*= 4*/)
222 : m_Origins(numViews, numDimensions)
223 , m_ViewSizes(numViews > 0 && numDimensions > 0 ?
224 new uint32_t *[numViews]() : nullptr)
228 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
230 m_ViewSizes[i] = new uint32_t[GetNumDimensions()]();
235 ViewsDescriptor::ViewsDescriptor(const ViewsDescriptor& other)
236 : m_Origins(other.m_Origins)
237 , m_ViewSizes(other.GetNumViews() > 0 && other.GetNumDimensions() > 0 ?
238 new uint32_t *[other.GetNumViews()]() : nullptr)
242 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
244 m_ViewSizes[i] = new uint32_t[GetNumDimensions()]();
245 memcpy(m_ViewSizes[i], other.m_ViewSizes[i], GetNumDimensions() * sizeof(uint32_t));
250 ViewsDescriptor::ViewsDescriptor(ViewsDescriptor&& other)
256 ViewsDescriptor::~ViewsDescriptor()
260 for (uint32_t i = 0; GetNumDimensions() > 0 && i < GetNumViews(); ++i)
262 delete[] m_ViewSizes[i];
264 delete[] m_ViewSizes;
268 ViewsDescriptor& ViewsDescriptor::operator=(ViewsDescriptor rhs)
274 bool ViewsDescriptor::operator==(const ViewsDescriptor& rhs) const
276 if (GetNumViews() != rhs.GetNumViews() || GetNumDimensions() != rhs.GetNumDimensions())
281 for (unsigned int i = 0u; i < GetNumViews(); ++i)
283 for (unsigned int j = 0u; j < GetNumDimensions(); ++j)
285 if (GetViewOrigin(i)[j] != rhs.GetViewOrigin(i)[j] || GetViewSizes(i)[j] != rhs.GetViewSizes(i)[j])
295 uint32_t ViewsDescriptor::GetNumViews() const
297 return m_Origins.GetNumViews();
300 uint32_t ViewsDescriptor::GetNumDimensions() const
302 return m_Origins.GetNumDimensions();
305 const uint32_t* ViewsDescriptor::GetViewOrigin(uint32_t idx) const
307 return m_Origins.GetViewOrigin(idx);
310 Status ViewsDescriptor::SetViewOriginCoord(uint32_t view, uint32_t coord, uint32_t value)
312 return m_Origins.SetViewOriginCoord(view, coord, value);
315 Status ViewsDescriptor::SetViewSize(uint32_t view, uint32_t coord, uint32_t value)
319 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: invalid view sizes";
320 return Status::Failure;
323 if (view >= GetNumViews())
325 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: view argument:" << view <<
327 return Status::Failure;
329 if (coord >= GetNumDimensions())
331 ARMNN_LOG(error) << "ViewsDescriptor::SetViewSize: coord argument:" << coord <<
333 return Status::Failure;
336 m_ViewSizes[view][coord] = value;
337 return Status::Success;
340 const uint32_t* ViewsDescriptor::GetViewSizes(uint32_t idx) const
342 return m_ViewSizes ? m_ViewSizes[idx] : nullptr;
345 const OriginsDescriptor& ViewsDescriptor::GetOrigins() const
350 void swap(OriginsDescriptor& first, OriginsDescriptor& second)
353 swap(first.m_NumViews, second.m_NumViews);
354 swap(first.m_NumDimensions, second.m_NumDimensions);
355 swap(first.m_ViewOrigins, second.m_ViewOrigins);
356 swap(first.m_ConcatAxis, second.m_ConcatAxis);
359 void swap(ViewsDescriptor& first, ViewsDescriptor& second)
362 swap(first.m_Origins, second.m_Origins);
363 swap(first.m_ViewSizes, second.m_ViewSizes);
366 int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape,
367 unsigned int axis) const
369 int start = m_Begin[axis];
371 if (m_BeginMask & (1 << axis))
373 if (m_Stride[axis] > 0)
375 start = std::numeric_limits<int>::min();
379 start = std::numeric_limits<int>::max();
383 const int axisSize = armnn::numeric_cast<int>(inputShape[axis]);
389 return std::max(0, std::min(start, axisSize - 1));
393 int StridedSliceDescriptor::GetStopForAxis(const TensorShape& inputShape,
395 int startForAxis) const
398 if (m_ShrinkAxisMask & (1 << axis))
400 return startForAxis + 1;
403 int stop = m_End[axis];
405 if (m_EndMask & (1 << axis))
407 if (m_Stride[axis] > 0)
409 stop = std::numeric_limits<int>::max();
413 stop = std::numeric_limits<int>::min();
417 const int axisSize = armnn::numeric_cast<int>(inputShape[axis]);
423 return m_Stride[axis] > 0 ? std::max(0, std::min(stop, axisSize)) :
424 std::max(-1, std::min(stop, axisSize - 1));