Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / interp / operations / BinaryArithmeticOps.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <cker/operation/BinaryArithmeticOps.h>
18
19 #include "OperationUtil.h"
20
21 #include "interp/Registration.h"
22 #include "ir/operation/BinaryArithmetic.h"
23 #include "misc/polymorphic_downcast.h"
24 #include "cker/Types.h"
25
26 namespace onert
27 {
28 namespace interp
29 {
30 namespace
31 {
32
33 enum class OpType
34 {
35   ADD,
36   SUB,
37   MUL
38 };
39
40 void prepare(ExecEnv *env, const ir::Operation &node)
41 {
42   const auto &arithmetic_node =
43       nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
44
45   const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
46   const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
47   const auto out_index = node.getOutputs().at(0);
48
49   const auto lhs_tensor = env->tensorAt(lhs_index);
50   const auto rhs_tensor = env->tensorAt(rhs_index);
51
52   // Check shape and type lhs is same with rhs
53   // TODO Util function to compare TensorInfo
54   if (lhs_tensor->data_type() != rhs_tensor->data_type())
55   {
56     throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Different input types"};
57   }
58
59   bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
60   if (try_broadcast)
61   {
62     bool success = true;
63     auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
64                                         rhs_tensor->tensorInfo().shape(), success);
65     if (!success)
66     {
67       throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Fail to brodcasting"};
68     }
69
70     auto output_info =
71         ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
72     // We can handle already allocated (ex. model output)
73     env->allocateIfNeeded(out_index, output_info);
74   }
75   else
76   {
77     // Output's shape and type is same with input
78     auto output_info = lhs_tensor->tensorInfo();
79     // We can handle already allocated (ex. model output)
80     env->allocateIfNeeded(out_index, output_info);
81   }
82
83   auto out_tensor = env->tensorAt(out_index);
84   // Check shape and type lhs is same with output
85   // TODO Util function to compare TensorInfo
86   if (lhs_tensor->data_type() != out_tensor->data_type())
87   {
88     throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Invalid output type"};
89   }
90 }
91
92 inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
93 {
94   params->float_activation_min = min;
95   params->float_activation_max = max;
96 }
97
98 inline void setActivationParams(int32_t min, int32_t max,
99                                 nnfw::cker::BinaryArithmeticOpParam *params)
100 {
101   params->quantized_activation_min = min;
102   params->quantized_activation_max = max;
103 }
104
105 template <typename raw_type, OpType op_type>
106 void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
107             const ir::operation::BinaryArithmetic::Param &param)
108 {
109   const auto lhs_buffer = lhs_tensor->bufferRO();
110   const auto rhs_buffer = rhs_tensor->bufferRO();
111   auto out_buffer = out_tensor->buffer();
112
113   nnfw::cker::BinaryArithmeticOpParam cker_param;
114   raw_type activation_min, activation_max;
115   calculateActivationRange(param.activation, &activation_min, &activation_max);
116   setActivationParams(activation_min, activation_max, &cker_param);
117   const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
118   const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
119   raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);
120
121   const auto cker_op_type =
122       (op_type == OpType::ADD)
123           ? nnfw::cker::BinaryArithmeticOpType::ADD
124           : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
125                                       : nnfw::cker::BinaryArithmeticOpType::MUL);
126
127   const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes(
128       convertShape(lhs_tensor->tensorInfo().shape()),
129       convertShape(rhs_tensor->tensorInfo().shape()), &cker_param);
130
131   if (need_broadcast)
132   {
133     const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
134     const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
135     const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
136     nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape,
137                                                           rhs_ptr, out_shape, out_ptr);
138     return;
139   }
140
141   const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
142   const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
143   const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
144   nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
145                                                out_shape, out_ptr);
146 }
147
148 template <OpType op_type>
149 void invokeBinaryArithmetic(const ExecEnv *env, const ir::operation::BinaryArithmetic &node)
150 {
151   const auto lhs_index = node.getInputs().at(node.LHS);
152   const auto rhs_index = node.getInputs().at(node.RHS);
153   const auto out_index = node.getOutputs().at(0);
154   const auto lhs_tensor = env->tensorAt(lhs_index);
155   const auto rhs_tensor = env->tensorAt(rhs_index);
156   const auto out_tensor = env->tensorAt(out_index);
157   const auto data_type = lhs_tensor->data_type();
158
159   if (data_type == ir::DataType::INT32)
160   {
161     invoke<int32_t, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
162   }
163   else if (data_type == ir::DataType::FLOAT32)
164   {
165     invoke<float, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
166   }
167   else
168   {
169     throw std::runtime_error{"NYI: Unsupported data type"};
170   }
171 }
172
173 void invokeBinaryArithmeticOps(const ExecEnv *env, const ir::Operation &node)
174 {
175   const auto &arithmetic_node =
176       nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
177
178   switch (arithmetic_node.param().arithmetic_type)
179   {
180     case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
181       invokeBinaryArithmetic<OpType::ADD>(env, arithmetic_node);
182       break;
183     case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
184       invokeBinaryArithmetic<OpType::SUB>(env, arithmetic_node);
185       break;
186     case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
187       invokeBinaryArithmetic<OpType::MUL>(env, arithmetic_node);
188       break;
189     default:
190       throw std::runtime_error{"Interp(BinaryArithmetic): NYI unsupported operation " +
191                                arithmetic_node.name()};
192       break;
193   }
194 }
195
196 } // namespace
197
198 OpKernel *getBinaryArithmetic()
199 {
200   static OpKernel kernel = {prepare, invokeBinaryArithmeticOps};
201   return &kernel;
202 }
203
204 } // namespace interp
205 } // namespace onert