/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include <cker/operation/BinaryArithmeticOps.h>
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
22 #include "ir/operation/BinaryArithmetic.h"
23 #include "misc/polymorphic_downcast.h"
24 #include "cker/Types.h"
40 void prepare(ExecEnv *env, const ir::Operation &node)
42 const auto &arithmetic_node =
43 nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
45 const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
46 const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
47 const auto out_index = node.getOutputs().at(0);
49 const auto lhs_tensor = env->tensorAt(lhs_index);
50 const auto rhs_tensor = env->tensorAt(rhs_index);
52 // Check shape and type lhs is same with rhs
53 // TODO Util function to compare TensorInfo
54 if (lhs_tensor->data_type() != rhs_tensor->data_type())
56 throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Different input types"};
59 bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
63 auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
64 rhs_tensor->tensorInfo().shape(), success);
67 throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Fail to brodcasting"};
71 ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
72 // We can handle already allocated (ex. model output)
73 env->allocateIfNeeded(out_index, output_info);
77 // Output's shape and type is same with input
78 auto output_info = lhs_tensor->tensorInfo();
79 // We can handle already allocated (ex. model output)
80 env->allocateIfNeeded(out_index, output_info);
83 auto out_tensor = env->tensorAt(out_index);
84 // Check shape and type lhs is same with output
85 // TODO Util function to compare TensorInfo
86 if (lhs_tensor->data_type() != out_tensor->data_type())
88 throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Invalid output type"};
92 inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
94 params->float_activation_min = min;
95 params->float_activation_max = max;
98 inline void setActivationParams(int32_t min, int32_t max,
99 nnfw::cker::BinaryArithmeticOpParam *params)
101 params->quantized_activation_min = min;
102 params->quantized_activation_max = max;
105 template <typename raw_type, OpType op_type>
106 void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
107 const ir::operation::BinaryArithmetic::Param ¶m)
109 const auto lhs_buffer = lhs_tensor->bufferRO();
110 const auto rhs_buffer = rhs_tensor->bufferRO();
111 auto out_buffer = out_tensor->buffer();
113 nnfw::cker::BinaryArithmeticOpParam cker_param;
114 raw_type activation_min, activation_max;
115 calculateActivationRange(param.activation, &activation_min, &activation_max);
116 setActivationParams(activation_min, activation_max, &cker_param);
117 const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
118 const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
119 raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);
121 const auto cker_op_type =
122 (op_type == OpType::ADD)
123 ? nnfw::cker::BinaryArithmeticOpType::ADD
124 : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
125 : nnfw::cker::BinaryArithmeticOpType::MUL);
127 const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes(
128 convertShape(lhs_tensor->tensorInfo().shape()),
129 convertShape(rhs_tensor->tensorInfo().shape()), &cker_param);
133 const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
134 const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
135 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
136 nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape,
137 rhs_ptr, out_shape, out_ptr);
141 const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
142 const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
143 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
144 nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
148 template <OpType op_type>
149 void invokeBinaryArithmetic(const ExecEnv *env, const ir::operation::BinaryArithmetic &node)
151 const auto lhs_index = node.getInputs().at(node.LHS);
152 const auto rhs_index = node.getInputs().at(node.RHS);
153 const auto out_index = node.getOutputs().at(0);
154 const auto lhs_tensor = env->tensorAt(lhs_index);
155 const auto rhs_tensor = env->tensorAt(rhs_index);
156 const auto out_tensor = env->tensorAt(out_index);
157 const auto data_type = lhs_tensor->data_type();
159 if (data_type == ir::DataType::INT32)
161 invoke<int32_t, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
163 else if (data_type == ir::DataType::FLOAT32)
165 invoke<float, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
169 throw std::runtime_error{"NYI: Unsupported data type"};
173 void invokeBinaryArithmeticOps(const ExecEnv *env, const ir::Operation &node)
175 const auto &arithmetic_node =
176 nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
178 switch (arithmetic_node.param().arithmetic_type)
180 case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
181 invokeBinaryArithmetic<OpType::ADD>(env, arithmetic_node);
183 case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
184 invokeBinaryArithmetic<OpType::SUB>(env, arithmetic_node);
186 case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
187 invokeBinaryArithmetic<OpType::MUL>(env, arithmetic_node);
190 throw std::runtime_error{"Interp(BinaryArithmetic): NYI unsupported operation " +
191 arithmetic_node.name()};
198 OpKernel *getBinaryArithmetic()
200 static OpKernel kernel = {prepare, invokeBinaryArithmeticOps};
204 } // namespace interp