2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <cker/operation/BinaryArithmeticOps.h>
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
22 #include "ir/operation/Add.h"
23 #include "ir/operation/Sub.h"
24 #include "ir/operation/Mul.h"
25 #include "misc/polymorphic_downcast.h"
26 #include "cker/Types.h"
// Prepare step shared by the binary arithmetic kernels: validates the two inputs and makes sure
// the output tensor is allocated through env->allocateIfNeeded().
//
// node_type : concrete operation class that `node` is downcast to; its LHS/RHS members select
//             which inputs are read.
// env       : execution environment used to look up tensors and to allocate the output.
// node      : operation being prepared; output 0 is the result operand.
//
// Throws std::runtime_error when the two inputs have different data types, with the
// "Fail to brodcasting" message (presumably when calcBroadcastShape clears `success`), and when
// the output data type differs from the lhs data type.
//
// NOTE(review): brace-only lines and some `if`/`else` lines are not visible in this extract, so
// the branch structure described in the comments below is inferred from `try_broadcast` and the
// surrounding comments - confirm against the complete file.
// NOTE(review): the messages say "Interp(Add)" although this template is also instantiated with
// ir::operation::Sub and ir::operation::Mul further down in this file.
42 template <typename node_type> void prepareAdd(ExecEnv *env, const ir::Operation &node)
44 const auto &add_node = nnfw::misc::polymorphic_downcast<const node_type &>(node);
46 const auto lhs_index = node.getInputs().at(add_node.LHS);
47 const auto rhs_index = node.getInputs().at(add_node.RHS);
48 const auto out_index = node.getOutputs().at(0);
50 const auto lhs_tensor = env->tensorAt(lhs_index);
51 const auto rhs_tensor = env->tensorAt(rhs_index);
53 // Both inputs must have the same data type; differing shapes are handled separately below
54 // TODO Util function to compare TensorInfo
55 if (lhs_tensor->data_type() != rhs_tensor->data_type())
57 throw std::runtime_error{"Interp(Add): Different input types"};
// Shapes that differ are candidates for broadcasting.
60 bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
// Presumably the `try_broadcast` branch: the output shape comes from calcBroadcastShape and the
// output type info is taken from lhs.
64 auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
65 rhs_tensor->tensorInfo().shape(), success);
68 throw std::runtime_error{"Interp(Add): Fail to brodcasting"};
72 ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
73 // allocateIfNeeded is used because the output may already be allocated (ex. model output)
74 env->allocateIfNeeded(out_index, output_info);
// Presumably the non-broadcast branch: the output reuses the lhs tensor info unchanged.
78 // Output's shape and type are the same as the input's
79 auto output_info = lhs_tensor->tensorInfo();
80 // allocateIfNeeded is used because the output may already be allocated (ex. model output)
81 env->allocateIfNeeded(out_index, output_info);
84 auto out_tensor = env->tensorAt(out_index);
85 // The output must have the same data type as lhs (only the data type is compared here)
86 // TODO Util function to compare TensorInfo
87 if (lhs_tensor->data_type() != out_tensor->data_type())
89 throw std::runtime_error{"Interp(Add): Invalid output type"};
// Float overload: stores the activation range [min, max] into the float_activation_min and
// float_activation_max fields of `params`.
93 inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
95 params->float_activation_min = min;
96 params->float_activation_max = max;
// int32_t overload: stores the activation range [min, max] into the quantized_activation_min and
// quantized_activation_max fields of `params`.
99 inline void setActivationParams(int32_t min, int32_t max,
100 nnfw::cker::BinaryArithmeticOpParam *params)
102 params->quantized_activation_min = min;
103 params->quantized_activation_max = max;
// Runs one element-wise binary arithmetic operation through the cker kernels.
//
// raw_type   : element type the tensor buffers are reinterpreted as.
// param_type : operation parameter type; only its `activation` member is read here.
// op_type    : selects the cker operation (ADD, SUB, anything else maps to MUL).
// lhs_tensor / rhs_tensor : inputs, read through bufferRO().
// out_tensor : destination, written through buffer().
//
// NOTE(review): "¶m" on the signature line below looks like a mis-encoded "&param" (the body
// reads `param.activation`) - confirm against the original file.
// NOTE(review): brace-only lines, the branch on `need_broadcast` and the tail of the final call
// are not visible in this extract; the branch comments below are inferred.
106 template <typename raw_type, typename param_type, OpType op_type>
107 void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
108 const param_type ¶m)
110 const auto lhs_buffer = lhs_tensor->bufferRO();
111 const auto rhs_buffer = rhs_tensor->bufferRO();
112 auto out_buffer = out_tensor->buffer();
// Translate the fused activation into a [min, max] range; the setActivationParams overload is
// picked by raw_type (float or int32_t overloads are defined in this file).
114 nnfw::cker::BinaryArithmeticOpParam cker_param;
115 raw_type activation_min, activation_max;
116 calculateActivationRange(param.activation, &activation_min, &activation_max);
117 setActivationParams(activation_min, activation_max, &cker_param);
// View the raw tensor buffers as arrays of raw_type.
118 const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
119 const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
120 raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);
// Map the interpreter OpType onto the cker enum; any value other than ADD or SUB becomes MUL.
122 const auto cker_op_type =
123 (op_type == OpType::ADD)
124 ? nnfw::cker::BinaryArithmeticOpType::ADD
125 : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
126 : nnfw::cker::BinaryArithmeticOpType::MUL);
// ProcessBroadcastShapes receives &cker_param and returns whether broadcasting is needed.
128 const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes(
129 convertShape(lhs_tensor->tensorInfo().shape()),
130 convertShape(rhs_tensor->tensorInfo().shape()), &cker_param);
// Presumably the `need_broadcast` branch: broadcasting kernel.
134 const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
135 const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
136 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
137 nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape,
138 rhs_ptr, out_shape, out_ptr);
// Presumably the non-broadcast branch: plain element-wise kernel.
142 const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
143 const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
144 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
145 nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
// Execution step shared by the binary arithmetic kernels: looks up the operand tensors and
// dispatches to invoke<> based on the data type of the lhs tensor.
//
// node_type  : concrete operation class that `node` is downcast to.
// param_type : type passed to invoke<> as the parameter type of arithmetic_node.param().
// op_type    : arithmetic operation forwarded to invoke<>.
// env        : execution environment used to look up the tensors.
// node       : operation being executed; output 0 is the result operand.
//
// INT32 is dispatched as int32_t and FLOAT32 as float. The "NYI: Unsupported data type"
// std::runtime_error is presumably thrown for every other data type (the `else` line is not
// visible in this extract - confirm against the complete file).
149 template <typename node_type, typename param_type, OpType op_type>
150 void invokeAdd(const ExecEnv *env, const ir::Operation &node)
152 const auto &arithmetic_node = nnfw::misc::polymorphic_downcast<const node_type &>(node);
154 const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
155 const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
156 const auto out_index = node.getOutputs().at(0);
157 const auto lhs_tensor = env->tensorAt(lhs_index);
158 const auto rhs_tensor = env->tensorAt(rhs_index);
159 const auto out_tensor = env->tensorAt(out_index);
// Only the lhs data type is inspected when choosing the element type.
160 const auto data_type = lhs_tensor->data_type();
162 if (data_type == ir::DataType::INT32)
164 invoke<int32_t, param_type, op_type>(lhs_tensor, rhs_tensor, out_tensor,
165 arithmetic_node.param());
167 else if (data_type == ir::DataType::FLOAT32)
169 invoke<float, param_type, op_type>(lhs_tensor, rhs_tensor, out_tensor, arithmetic_node.param());
173 throw std::runtime_error{"NYI: Unsupported data type"};
// Static OpKernel for Add, initialised with the shared prepare/invoke templates instantiated for
// ir::operation::Add. NOTE(review): the enclosing function is not visible in this extract.
180 static OpKernel kernel = {prepareAdd<ir::operation::Add>,
181 invokeAdd<ir::operation::Add, ir::operation::Add::Param, OpType::ADD>};
// Static OpKernel for Sub, initialised with the shared prepare/invoke templates instantiated for
// ir::operation::Sub. NOTE(review): the enclosing function is not visible in this extract.
187 static OpKernel kernel = {prepareAdd<ir::operation::Sub>,
188 invokeAdd<ir::operation::Sub, ir::operation::Sub::Param, OpType::SUB>};
// Static OpKernel for Mul, initialised with the shared prepare/invoke templates instantiated for
// ir::operation::Mul. NOTE(review): the enclosing function is not visible in this extract.
194 static OpKernel kernel = {prepareAdd<ir::operation::Mul>,
195 invokeAdd<ir::operation::Mul, ir::operation::Mul::Param, OpType::MUL>};
199 } // namespace interp