44c955421ccbe0e89d25ac817e93c30d49345e61
[platform/core/ml/nnfw.git] / runtime / onert / core / src / interp / operations / BinaryArithmeticOps.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <cker/operation/BinaryArithmeticOps.h>
18
19 #include "OperationUtil.h"
20
21 #include "interp/Registration.h"
22 #include "ir/operation/Add.h"
23 #include "ir/operation/Sub.h"
24 #include "ir/operation/Mul.h"
25 #include "misc/polymorphic_downcast.h"
26 #include "cker/Types.h"
27
28 namespace onert
29 {
30 namespace interp
31 {
32 namespace
33 {
34
35 enum class OpType
36 {
37   ADD,
38   SUB,
39   MUL
40 };
41
42 template <typename node_type> void prepareAdd(ExecEnv *env, const ir::Operation &node)
43 {
44   const auto &add_node = nnfw::misc::polymorphic_downcast<const node_type &>(node);
45
46   const auto lhs_index = node.getInputs().at(add_node.LHS);
47   const auto rhs_index = node.getInputs().at(add_node.RHS);
48   const auto out_index = node.getOutputs().at(0);
49
50   const auto lhs_tensor = env->tensorAt(lhs_index);
51   const auto rhs_tensor = env->tensorAt(rhs_index);
52
53   // Check shape and type lhs is same with rhs
54   // TODO Util function to compare TensorInfo
55   if (lhs_tensor->data_type() != rhs_tensor->data_type())
56   {
57     throw std::runtime_error{"Interp(Add): Different input types"};
58   }
59
60   bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
61   if (try_broadcast)
62   {
63     bool success = true;
64     auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
65                                         rhs_tensor->tensorInfo().shape(), success);
66     if (!success)
67     {
68       throw std::runtime_error{"Interp(Add): Fail to brodcasting"};
69     }
70
71     auto output_info =
72         ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
73     // We can handle already allocated (ex. model output)
74     env->allocateIfNeeded(out_index, output_info);
75   }
76   else
77   {
78     // Output's shape and type is same with input
79     auto output_info = lhs_tensor->tensorInfo();
80     // We can handle already allocated (ex. model output)
81     env->allocateIfNeeded(out_index, output_info);
82   }
83
84   auto out_tensor = env->tensorAt(out_index);
85   // Check shape and type lhs is same with output
86   // TODO Util function to compare TensorInfo
87   if (lhs_tensor->data_type() != out_tensor->data_type())
88   {
89     throw std::runtime_error{"Interp(Add): Invalid output type"};
90   }
91 }
92
93 inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
94 {
95   params->float_activation_min = min;
96   params->float_activation_max = max;
97 }
98
99 inline void setActivationParams(int32_t min, int32_t max,
100                                 nnfw::cker::BinaryArithmeticOpParam *params)
101 {
102   params->quantized_activation_min = min;
103   params->quantized_activation_max = max;
104 }
105
106 template <typename raw_type, typename param_type, OpType op_type>
107 void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
108             const param_type &param)
109 {
110   const auto lhs_buffer = lhs_tensor->bufferRO();
111   const auto rhs_buffer = rhs_tensor->bufferRO();
112   auto out_buffer = out_tensor->buffer();
113
114   nnfw::cker::BinaryArithmeticOpParam cker_param;
115   raw_type activation_min, activation_max;
116   calculateActivationRange(param.activation, &activation_min, &activation_max);
117   setActivationParams(activation_min, activation_max, &cker_param);
118   const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
119   const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
120   raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);
121
122   const auto cker_op_type =
123       (op_type == OpType::ADD)
124           ? nnfw::cker::BinaryArithmeticOpType::ADD
125           : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
126                                       : nnfw::cker::BinaryArithmeticOpType::MUL);
127
128   const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes(
129       convertShape(lhs_tensor->tensorInfo().shape()),
130       convertShape(rhs_tensor->tensorInfo().shape()), &cker_param);
131
132   if (need_broadcast)
133   {
134     const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
135     const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
136     const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
137     nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape,
138                                                           rhs_ptr, out_shape, out_ptr);
139     return;
140   }
141
142   const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
143   const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
144   const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
145   nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
146                                                out_shape, out_ptr);
147 }
148
149 template <typename node_type, typename param_type, OpType op_type>
150 void invokeAdd(const ExecEnv *env, const ir::Operation &node)
151 {
152   const auto &arithmetic_node = nnfw::misc::polymorphic_downcast<const node_type &>(node);
153
154   const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
155   const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
156   const auto out_index = node.getOutputs().at(0);
157   const auto lhs_tensor = env->tensorAt(lhs_index);
158   const auto rhs_tensor = env->tensorAt(rhs_index);
159   const auto out_tensor = env->tensorAt(out_index);
160   const auto data_type = lhs_tensor->data_type();
161
162   if (data_type == ir::DataType::INT32)
163   {
164     invoke<int32_t, param_type, op_type>(lhs_tensor, rhs_tensor, out_tensor,
165                                          arithmetic_node.param());
166   }
167   else if (data_type == ir::DataType::FLOAT32)
168   {
169     invoke<float, param_type, op_type>(lhs_tensor, rhs_tensor, out_tensor, arithmetic_node.param());
170   }
171   else
172   {
173     throw std::runtime_error{"NYI: Unsupported data type"};
174   }
175 }
176 } // namespace
177
178 OpKernel *getAdd()
179 {
180   static OpKernel kernel = {prepareAdd<ir::operation::Add>,
181                             invokeAdd<ir::operation::Add, ir::operation::Add::Param, OpType::ADD>};
182   return &kernel;
183 }
184
185 OpKernel *getSub()
186 {
187   static OpKernel kernel = {prepareAdd<ir::operation::Sub>,
188                             invokeAdd<ir::operation::Sub, ir::operation::Sub::Param, OpType::SUB>};
189   return &kernel;
190 }
191
192 OpKernel *getMul()
193 {
194   static OpKernel kernel = {prepareAdd<ir::operation::Mul>,
195                             invokeAdd<ir::operation::Mul, ir::operation::Mul::Param, OpType::MUL>};
196   return &kernel;
197 }
198
199 } // namespace interp
200 } // namespace onert