2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <cker/operation/SoftMax.h>
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
22 #include "ir/operation/Softmax.h"
23 #include "misc/polymorphic_downcast.h"
/**
 * @brief Softmax over each row of a [batch_size x input_size] float buffer.
 *
 * For every row: subtract the row maximum (numerical stability), exponentiate
 * scaled by beta, then normalize so the row sums to 1.
 *
 * @param in         Input buffer, batch_size * input_size contiguous floats
 * @param input_size Elements per row; must be positive
 * @param batch_size Number of rows
 * @param beta       Scaling factor applied to (x - max) before exp
 * @param out        Output buffer, same layout as @p in
 */
void Softmax2D(const float *in, const int input_size, const int batch_size, const float beta,
               float *out)
{
  assert(input_size > 0);

  for (int b = 0; b < batch_size; b++)
  {
    const float *row_in = in + b * input_size;
    float *row_out = out + b * input_size;

    // Locate the row maximum so the exponent arguments are all <= 0.
    float row_max = row_in[0];
    for (int i = 1; i < input_size; i++)
    {
      if (row_in[i] > row_max)
        row_max = row_in[i];
    }

    // Exponentiate (shifted and scaled by beta) while accumulating the sum.
    float exp_sum = 0.0;
    for (int i = 0; i < input_size; i++)
    {
      row_out[i] = std::exp((row_in[i] - row_max) * beta);
      exp_sum += row_out[i];
    }

    // Normalize the row by the reciprocal of the accumulated sum.
    const float inv_sum = 1.f / exp_sum;
    for (int i = 0; i < input_size; i++)
    {
      row_out[i] *= inv_sum;
    }
  }
}
69 void prepareSoftMax(ExecEnv *env, const ir::Operation &node)
71 const auto in_index = node.getInputs().at(0);
72 const auto out_index = node.getOutputs().at(0);
74 const auto in_tensor = env->tensorAt(in_index);
75 UNUSED_RELEASE(in_tensor);
77 assert((in_tensor->num_dimensions() == 4) || (in_tensor->num_dimensions() == 2));
79 // Output shape should be same with input
80 // Output type is pre-defined in model
81 const auto output_shape = env->graph().operands().at(in_index).info().shape();
82 const auto output_type = env->graph().operands().at(out_index).info().typeInfo();
84 const auto output_info = ir::OperandInfo::createStaticInfo(output_shape, output_type);
85 env->allocateIfNeeded(out_index, output_info);
87 auto out_tensor = env->tensorAt(out_index);
88 UNUSED_RELEASE(out_tensor);
90 // Check output shape is same with input
91 assert(out_tensor->num_dimensions() == out_tensor->num_dimensions());
92 for (uint32_t i = 0; i < in_tensor->num_dimensions(); i++)
94 assert(in_tensor->dimension(i) == out_tensor->dimension(i));
98 void invoke(const ITensor *in_tensor, const ITensor *out_tensor,
99 const ir::operation::Softmax::Param ¶m)
101 const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO());
102 float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer());
104 float beta = param.beta;
106 if (in_tensor->num_dimensions() == 2)
108 uint32_t batch_size = in_tensor->dimension(0);
109 uint32_t input_size = in_tensor->dimension(1);
111 Softmax2D(in_ptr, input_size, batch_size, beta, out_ptr);
113 else if (in_tensor->num_dimensions() == 4)
115 const auto in_shape = convertShape(in_tensor->tensorInfo().shape());
116 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
118 nnfw::cker::SoftmaxParams cker_param;
119 cker_param.beta = beta;
121 nnfw::cker::Softmax(cker_param, in_shape, in_ptr, out_shape, out_ptr);
125 throw std::runtime_error{"Unsuported input dimension: support 2D or 4D"};
129 void invokeSoftMax(const ExecEnv *env, const ir::Operation &node)
131 const auto &softmax_node = nnfw::misc::polymorphic_downcast<const ir::operation::Softmax &>(node);
133 const auto in_index = node.getInputs().at(0);
134 const auto out_index = node.getOutputs().at(0);
136 const auto in_tensor = env->tensorAt(in_index);
137 const auto out_tensor = env->tensorAt(out_index);
139 const auto in_data_type = in_tensor->data_type();
140 const auto out_data_type = out_tensor->data_type();
141 if ((in_data_type == ir::DataType::FLOAT32) && (out_data_type == ir::DataType::FLOAT32))
143 invoke(in_tensor, out_tensor, softmax_node.param());
147 throw std::runtime_error{"NYI: Support float32 only"};
153 OpKernel *getSoftmax()
155 static OpKernel kernel = {prepareSoftMax, invokeSoftMax};
159 } // namespace interp