IVGCVSW-3694 Add ArgMinMax implementation for Ref
[platform/upstream/armnn.git] / src / backends / reference / workloads / Softmax.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "Softmax.hpp"

#include <TensorUtils.hpp>

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
12
13 namespace armnn
14 {
15
16 /// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
17 void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
18 {
19     BOOST_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
20                      "Required axis index greater than number of dimensions.");
21     BOOST_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
22                      "Required axis index lower than negative of the number of dimensions");
23
24     unsigned int uAxis = axis < 0  ?
25                          inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
26                          : static_cast<unsigned int>(axis);
27
28     const TensorShape& inputShape = inputTensorInfo.GetShape();
29     const unsigned int outerSize  = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
30     const unsigned int axisSize   = inputShape[uAxis];
31     const unsigned int innerSize  = armnnUtils::GetNumElementsBetween(inputShape,
32                                                                       uAxis + 1,
33                                                                       inputShape.GetNumDimensions());
34
35     for (unsigned int outer = 0; outer < outerSize; ++outer)
36     {
37         unsigned int inputBeginIdx  = outer * axisSize * innerSize;
38         unsigned int inputEndIdx    = inputBeginIdx + axisSize * innerSize;
39         unsigned int outputBeginIdx = outer * axisSize * innerSize;
40
41         for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
42         {
43             // Find max
44             float maxValue = std::numeric_limits<float>::lowest();
45             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
46             {
47                 in[iter];
48                 maxValue = std::max(maxValue, in.Get());
49             }
50
51             // Compute sum
52             float sum = 0.0f;
53             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
54             {
55                 in[iter];
56                 sum += std::exp((in.Get() - maxValue) * beta);
57             }
58
59             // Compute result
60             unsigned int outputIter = outputBeginIdx;
61             out[outputIter];
62             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
63             {
64                 out[outputIter];
65                 in[iter];
66                 out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
67             }
68         }
69     }
70 }
71
72 } //namespace armnn