Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / RefWorkloads / Activation.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "Activation.hpp"

#include <boost/log/trivial.hpp>

#include <algorithm>
#include <cmath>
11
12 namespace armnn
13 {
14
15 void Activation(const float* in,
16                float* out,
17                const TensorInfo& tensorInfo,
18                ActivationFunction function,
19                float a,
20                float b)
21 {
22     for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
23     {
24         float input = in[i];
25         float output;
26
27         // Compute the result of the activation function.
28         switch (function)
29         {
30             case ActivationFunction::Linear:
31             {
32                 output = a * input + b;
33                 break;
34             }
35             case ActivationFunction::Sigmoid:
36             {
37                 output = 1.f / (1.f + expf(-input));
38                 break;
39             }
40             case ActivationFunction::ReLu:
41             {
42                 output = std::max(0.f, input);
43                 break;
44             }
45             case ActivationFunction::BoundedReLu:
46             {
47                 output = std::min(a, std::max(b, input));
48                 break;
49             }
50             case ActivationFunction::SoftReLu:
51             {
52                 output = logf(1.0f + expf(input));
53                 break;
54             }
55             case ActivationFunction::LeakyReLu:
56             {
57                 output = input > 0.0f ? input : (input * a);
58                 break;
59             }
60             case ActivationFunction::Abs:
61             {
62                 output = input < 0 ? -input : input;
63                 break;
64             }
65             case ActivationFunction::Sqrt:
66             {
67                 output = sqrtf(input);
68                 break;
69             }
70             case ActivationFunction::Square:
71             {
72                 output = input * input;
73                 break;
74             }
75             case ActivationFunction::TanH:
76             {
77                 output = a * tanhf(b * input);
78                 break;
79             }
80             default:
81             {
82                 BOOST_LOG_TRIVIAL(error) << "Unsupported activation function";
83                 return;
84             }
85         }
86
87         out[i] = output;
88     }
89 }
90
91 } //namespace armnn