ea5e2417b3706e5868642b4ad563057e0acd6a25
[platform/core/ml/nnfw.git] / runtime / onert / core / src / interp / operations / UnaryActivations.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <cmath>
18
19 #include "OperationUtil.h"
20
21 #include "interp/Registration.h"
22
23 #include "ir/operation/ReLU.h"
24 #include "ir/operation/ReLU1.h"
25 #include "ir/operation/ReLU6.h"
26 #include "ir/operation/Tanh.h"
27
28 namespace onert
29 {
30 namespace interp
31 {
32 namespace
33 {
34
35 enum class ActivationType
36 {
37   ReLU,
38   ReLU1,
39   ReLU6,
40   Tanh
41 };
42
43 void prepare(ExecEnv *env, const ir::Operation &node)
44 {
45   const auto input_index = node.getInputs().at(0);
46   const auto output_index = node.getOutputs().at(0);
47
48   const auto input_tensor = env->tensorAt(input_index);
49
50   const auto output_info = env->graph().operands().at(output_index).info();
51   if (output_info.total_size() == 0)
52   {
53     // Output's shape and type is same with input
54     auto input_info = input_tensor->tensorInfo();
55     // We can handle already allocated (ex. model output)
56     env->allocateIfNeeded(output_index, input_info);
57   }
58   else
59   {
60     env->allocateIfNeeded(output_index, output_info);
61   }
62
63   const auto output_tensor = env->tensorAt(output_index);
64   // Check shape and type lhs is same with output
65   // TODO Util function to compare TensorInfo
66   if (input_tensor->data_type() != output_tensor->data_type())
67   {
68     throw std::runtime_error{"Interp(Activations): Invalid output type"};
69   }
70 }
71
72 template <ActivationType act_type>
73 void evalFloat(const float *input_ptr, float *output_ptr, uint64_t num_elements)
74 {
75   std::function<float(const float &)> fn = [](const float &) { return std::nanf(""); };
76   switch (act_type)
77   {
78     case ActivationType::ReLU:
79       fn = [](const float &in) { return std::max(0.f, in); };
80       break;
81     case ActivationType::ReLU1:
82       fn = [](const float &in) { return std::min(std::max(-1.f, in), 1.f); };
83       break;
84     case ActivationType::ReLU6:
85       fn = [](const float &in) { return std::min(std::max(0.f, in), 6.f); };
86       break;
87     case ActivationType::Tanh:
88       fn = [](const float &in) { return std::tanh(in); };
89       break;
90     default:
91       throw std::runtime_error{"Interp(Activations): NYI - Unsupported activation"};
92       break;
93   }
94
95   const float *input_end = input_ptr + num_elements;
96   for (; input_ptr < input_end; input_ptr++, output_ptr++)
97   {
98     *output_ptr = fn(*input_ptr);
99   }
100 }
101
102 template <ActivationType act_type> void invoke(const ExecEnv *env, const ir::Operation &node)
103 {
104   const auto input_index = node.getInputs().at(0);
105   const auto output_index = node.getOutputs().at(0);
106
107   // Check lhs shape is same with rhs (with broadcast)
108   const auto input_tensor = env->tensorAt(input_index);
109   const auto output_tensor = env->tensorAt(output_index);
110
111   const auto data_type = input_tensor->data_type();
112   if (data_type == ir::DataType::FLOAT32)
113   {
114     uint64_t elements = input_tensor->num_elements();
115     const float *input_start = reinterpret_cast<const float *>(input_tensor->bufferRO());
116     float *out = reinterpret_cast<float *>(output_tensor->buffer());
117
118     evalFloat<act_type>(input_start, out, elements);
119   }
120   else
121   {
122     throw std::runtime_error{"Interp(ReLU6): NYI - Support float only"};
123   }
124 }
125
126 } // namespace
127
128 OpKernel *getReLU()
129 {
130   static OpKernel kernel = {prepare, invoke<ActivationType::ReLU>};
131   return &kernel;
132 }
133
134 OpKernel *getReLU1()
135 {
136   static OpKernel kernel = {prepare, invoke<ActivationType::ReLU1>};
137   return &kernel;
138 }
139
140 OpKernel *getReLU6()
141 {
142   static OpKernel kernel = {prepare, invoke<ActivationType::ReLU6>};
143   return &kernel;
144 }
145
146 OpKernel *getTanh()
147 {
148   static OpKernel kernel = {prepare, invoke<ActivationType::Tanh>};
149   return &kernel;
150 }
151
152 } // namespace interp
153 } // namespace onert