2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
23 #include "ir/operation/ElementwiseActivation.h"
25 #include <misc/polymorphic_downcast.h>
26 #include <cker/operation/Logistic.h>
27 #include <cker/operation/Tanh.h>
// Activation kinds this interpreter kernel can dispatch on.
// The set is derived from the usages below: Logistic is delegated to cker,
// ReLU and Tanh are evaluated by evalFloat<>.
enum class ActivationType
{
  Logistic,
  ReLU,
  Tanh
};
43 void prepare(ExecEnv *env, const ir::Operation &node)
45 const auto input_index = node.getInputs().at(0);
46 const auto output_index = node.getOutputs().at(0);
48 const auto input_tensor = env->tensorAt(input_index);
50 const auto output_info = env->graph().operands().at(output_index).info();
51 if (output_info.total_size() == 0)
53 // Output's shape and type is same with input
54 auto input_info = input_tensor->tensorInfo();
55 // We can handle already allocated (ex. model output)
56 env->allocateIfNeeded(output_index, input_info);
60 env->allocateIfNeeded(output_index, output_info);
63 const auto output_tensor = env->tensorAt(output_index);
64 // Check shape and type lhs is same with output
65 // TODO Util function to compare TensorInfo
66 if (input_tensor->data_type() != output_tensor->data_type())
68 throw std::runtime_error{"Interp(ElementwiseActivation): Invalid output type"};
72 template <ActivationType act_type>
73 void evalFloat(const float *input_ptr, float *output_ptr, uint64_t num_elements, float alpha,
76 std::function<float(const float &)> fn = [](const float &) { return std::nanf(""); };
79 case ActivationType::ReLU:
80 fn = [alpha, beta](const float &in) { return std::min(std::max(beta, in), alpha); };
82 case ActivationType::Tanh:
83 fn = [](const float &in) { return std::tanh(in); };
86 throw std::runtime_error{"Interp(ElementwiseActivation): NYI - Unsupported activation"};
90 const float *input_end = input_ptr + num_elements;
91 for (; input_ptr < input_end; input_ptr++, output_ptr++)
93 *output_ptr = fn(*input_ptr);
97 template <ActivationType act_type> void invoke(const ExecEnv *env, const ir::Operation &node)
99 const auto input_index = node.getInputs().at(0);
100 const auto output_index = node.getOutputs().at(0);
102 // Check lhs shape is same with rhs (with broadcast)
103 const auto input_tensor = env->tensorAt(input_index);
104 const auto output_tensor = env->tensorAt(output_index);
106 const auto data_type = input_tensor->data_type();
107 if (data_type == ir::DataType::FLOAT32)
109 uint64_t elements = input_tensor->num_elements();
110 const float *input_start = reinterpret_cast<const float *>(input_tensor->bufferRO());
111 float *out = reinterpret_cast<float *>(output_tensor->buffer());
112 if (act_type == ActivationType::Logistic)
114 const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape());
115 const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape());
116 nnfw::cker::Logistic(cker_input_shape, input_start, cker_output_shape, out);
120 const auto &act_node =
121 nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node);
122 evalFloat<act_type>(input_start, out, elements, act_node.param().alpha,
123 act_node.param().beta);
128 throw std::runtime_error{"Interp(" + node.name() + "): NYI - Support float only"};
132 void invokeElementwiseActivation(const ExecEnv *env, const ir::Operation &node)
134 const auto &act_node =
135 nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node);
136 switch (act_node.param().op_type)
138 case ir::operation::ElementwiseActivation::Type::LOGISTIC:
139 invoke<ActivationType::Logistic>(env, node);
141 case ir::operation::ElementwiseActivation::Type::RELU:
142 invoke<ActivationType::ReLU>(env, node);
144 case ir::operation::ElementwiseActivation::Type::TANH:
145 invoke<ActivationType::Tanh>(env, node);
148 throw std::runtime_error("Interp(" + node.name() + "): NYI - Unsupported activation");
154 OpKernel *getElementwiseActivation()
156 static OpKernel kernel = {prepare, invokeElementwiseActivation};
160 } // namespace interp