IVGCVSW-3663 Add utility function to expand tensor dimension
[platform/upstream/armnn.git] / src / armnnUtils / TensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "TensorUtils.hpp"
7 #include <backendsCommon/ITensorHandle.hpp>
8
9 #include <boost/assert.hpp>
10 #include <boost/format.hpp>
11 #include <boost/numeric/conversion/cast.hpp>
12
13 namespace armnnUtils
14 {
15
16 armnn::TensorShape GetTensorShape(unsigned int numberOfBatches,
17                                   unsigned int numberOfChannels,
18                                   unsigned int height,
19                                   unsigned int width,
20                                   const armnn::DataLayout dataLayout)
21 {
22     switch (dataLayout)
23     {
24         case armnn::DataLayout::NCHW:
25             return armnn::TensorShape({numberOfBatches, numberOfChannels, height, width});
26         case armnn::DataLayout::NHWC:
27             return armnn::TensorShape({numberOfBatches, height, width, numberOfChannels});
28         default:
29             throw armnn::InvalidArgumentException("Unknown data layout ["
30                                                   + std::to_string(static_cast<int>(dataLayout)) +
31                                                   "]", CHECK_LOCATION());
32     }
33 }
34
35 armnn::TensorInfo GetTensorInfo(unsigned int numberOfBatches,
36                                 unsigned int numberOfChannels,
37                                 unsigned int height,
38                                 unsigned int width,
39                                 const armnn::DataLayout dataLayout,
40                                 const armnn::DataType dataType)
41 {
42     switch (dataLayout)
43     {
44         case armnn::DataLayout::NCHW:
45             return armnn::TensorInfo({numberOfBatches, numberOfChannels, height, width}, dataType);
46         case armnn::DataLayout::NHWC:
47             return armnn::TensorInfo({numberOfBatches, height, width, numberOfChannels}, dataType);
48         default:
49             throw armnn::InvalidArgumentException("Unknown data layout ["
50                                                   + std::to_string(static_cast<int>(dataLayout)) +
51                                                   "]", CHECK_LOCATION());
52     }
53 }
54
55 std::pair<float, float> FindMinMax(armnn::ITensorHandle* tensorHandle)
56 {
57     auto tensor_data = static_cast<const float *>(tensorHandle->Map(true));
58     auto tensor_size = tensorHandle->GetShape().GetNumElements();
59
60     // Set min/max initially to first value in tensor
61     float min = tensor_data[0];
62     float max = tensor_data[0];
63
64     // Loop over rest of tensor and update min/max if necessary
65     for (unsigned int val = 1; val < tensor_size; val++)
66     {
67         if (tensor_data[val] < min)
68         {
69             min = tensor_data[val];
70         }
71         else if (tensor_data[val] > max)
72         {
73             max = tensor_data[val];
74         }
75     }
76
77     tensorHandle->Unmap();
78
79     return std::make_pair(min, max);
80 }
81
82 armnn::TensorShape ExpandDims(const armnn::TensorShape& tensorShape, int axis)
83 {
84     unsigned int outputDim = tensorShape.GetNumDimensions() + 1;
85
86     if (axis < -boost::numeric_cast<int>(outputDim) || axis > boost::numeric_cast<int>(tensorShape.GetNumDimensions()))
87     {
88         throw armnn::InvalidArgumentException(
89             boost::str(boost::format("Invalid expansion axis %1% for %2%D input tensor. %3%") %
90                        axis %
91                        tensorShape.GetNumDimensions() %
92                        CHECK_LOCATION().AsString()));
93     }
94
95     if (axis < 0)
96     {
97         axis = boost::numeric_cast<int>(outputDim) + axis;
98     }
99
100     std::vector<unsigned int> outputShape;
101     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); ++i)
102     {
103         outputShape.push_back(tensorShape[i]);
104     }
105     outputShape.insert(outputShape.begin() + axis, 1);
106
107     return armnn::TensorShape(outputDim, outputShape.data());
108 }
109
110 }