/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
23 #include "ir/operation/ReLU.h"
24 #include "ir/operation/ReLU1.h"
25 #include "ir/operation/ReLU6.h"
26 #include "ir/operation/Tanh.h"
// Activation kinds served by the shared prepare/invoke kernels in this file.
// NOTE(review): the enum body was truncated in this view; enumerators are
// reconstructed from the switch cases and kernel registrations below
// (ReLU, ReLU1, ReLU6, Tanh) — confirm order/values against the original file.
enum class ActivationType
{
  ReLU,
  ReLU1,
  ReLU6,
  Tanh
};
43 void prepare(ExecEnv *env, const ir::Operation &node)
45 const auto input_index = node.getInputs().at(0);
46 const auto output_index = node.getOutputs().at(0);
48 const auto input_tensor = env->tensorAt(input_index);
50 const auto output_info = env->graph().operands().at(output_index).info();
51 if (output_info.total_size() == 0)
53 // Output's shape and type is same with input
54 auto input_info = input_tensor->tensorInfo();
55 // We can handle already allocated (ex. model output)
56 env->allocateIfNeeded(output_index, input_info);
60 env->allocateIfNeeded(output_index, output_info);
63 const auto output_tensor = env->tensorAt(output_index);
64 // Check shape and type lhs is same with output
65 // TODO Util function to compare TensorInfo
66 if (input_tensor->data_type() != output_tensor->data_type())
68 throw std::runtime_error{"Interp(Activations): Invalid output type"};
72 template <ActivationType act_type>
73 void evalFloat(const float *input_ptr, float *output_ptr, uint64_t num_elements)
75 std::function<float(const float &)> fn = [](const float &) { return std::nanf(""); };
78 case ActivationType::ReLU:
79 fn = [](const float &in) { return std::max(0.f, in); };
81 case ActivationType::ReLU1:
82 fn = [](const float &in) { return std::min(std::max(-1.f, in), 1.f); };
84 case ActivationType::ReLU6:
85 fn = [](const float &in) { return std::min(std::max(0.f, in), 6.f); };
87 case ActivationType::Tanh:
88 fn = [](const float &in) { return std::tanh(in); };
91 throw std::runtime_error{"Interp(Activations): NYI - Unsupported activation"};
95 const float *input_end = input_ptr + num_elements;
96 for (; input_ptr < input_end; input_ptr++, output_ptr++)
98 *output_ptr = fn(*input_ptr);
102 template <ActivationType act_type> void invoke(const ExecEnv *env, const ir::Operation &node)
104 const auto input_index = node.getInputs().at(0);
105 const auto output_index = node.getOutputs().at(0);
107 // Check lhs shape is same with rhs (with broadcast)
108 const auto input_tensor = env->tensorAt(input_index);
109 const auto output_tensor = env->tensorAt(output_index);
111 const auto data_type = input_tensor->data_type();
112 if (data_type == ir::DataType::FLOAT32)
114 uint64_t elements = input_tensor->num_elements();
115 const float *input_start = reinterpret_cast<const float *>(input_tensor->bufferRO());
116 float *out = reinterpret_cast<float *>(output_tensor->buffer());
118 evalFloat<act_type>(input_start, out, elements);
122 throw std::runtime_error{"Interp(ReLU6): NYI - Support float only"};
// Kernel registrations: each static OpKernel pairs the shared shape/type
// `prepare` step with the activation-specific `invoke` instantiation.
// NOTE(review): these four statements each live in their own getter function
// whose header/footer lines are not visible in this chunk — each `kernel` is
// a function-local static, not four redefinitions in one scope; confirm
// against the full file.
static OpKernel kernel = {prepare, invoke<ActivationType::ReLU>};
static OpKernel kernel = {prepare, invoke<ActivationType::ReLU1>};
static OpKernel kernel = {prepare, invoke<ActivationType::ReLU6>};
static OpKernel kernel = {prepare, invoke<ActivationType::Tanh>};
152 } // namespace interp