/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include <cker/operation/SoftMax.h>
19 #include "OperationUtil.h"
21 #include "interp/Registration.h"
22 #include "ir/operation/Softmax.h"
23 #include "misc/polymorphic_downcast.h"
32 void prepareSoftMax(ExecEnv *env, const ir::Operation &node)
34 const auto in_index = node.getInputs().at(0);
35 const auto out_index = node.getOutputs().at(0);
37 const auto in_tensor = env->tensorAt(in_index);
38 UNUSED_RELEASE(in_tensor);
40 assert((in_tensor->num_dimensions() == 4) || (in_tensor->num_dimensions() == 2));
42 // Output shape should be same with input
43 // Output type is pre-defined in model
44 const auto output_shape = env->graph().operands().at(in_index).info().shape();
45 const auto output_type = env->graph().operands().at(out_index).info().typeInfo();
47 const auto output_info = ir::OperandInfo::createStaticInfo(output_shape, output_type);
48 env->allocateIfNeeded(out_index, output_info);
50 auto out_tensor = env->tensorAt(out_index);
51 UNUSED_RELEASE(out_tensor);
53 // Check output shape is same with input
54 assert(out_tensor->num_dimensions() == out_tensor->num_dimensions());
55 for (uint32_t i = 0; i < in_tensor->num_dimensions(); i++)
57 assert(in_tensor->dimension(i) == out_tensor->dimension(i));
61 void invoke(const ITensor *in_tensor, const ITensor *out_tensor,
62 const ir::operation::Softmax::Param ¶m)
64 const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO());
65 float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer());
67 float beta = param.beta;
69 if (in_tensor->num_dimensions() == 2)
71 uint32_t batch_size = in_tensor->dimension(0);
72 uint32_t input_size = in_tensor->dimension(1);
74 nnfw::cker::Softmax(in_ptr, input_size, batch_size, beta, out_ptr);
76 else if (in_tensor->num_dimensions() == 4)
78 const auto in_shape = convertShape(in_tensor->tensorInfo().shape());
79 const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
81 nnfw::cker::SoftmaxParams cker_param;
82 cker_param.beta = beta;
84 nnfw::cker::Softmax(cker_param, in_shape, in_ptr, out_shape, out_ptr);
88 throw std::runtime_error{"Unsuported input dimension: support 2D or 4D"};
92 void invokeSoftMax(const ExecEnv *env, const ir::Operation &node)
94 const auto &softmax_node = nnfw::misc::polymorphic_downcast<const ir::operation::Softmax &>(node);
96 const auto in_index = node.getInputs().at(0);
97 const auto out_index = node.getOutputs().at(0);
99 const auto in_tensor = env->tensorAt(in_index);
100 const auto out_tensor = env->tensorAt(out_index);
102 const auto in_data_type = in_tensor->data_type();
103 const auto out_data_type = out_tensor->data_type();
104 if ((in_data_type == ir::DataType::FLOAT32) && (out_data_type == ir::DataType::FLOAT32))
106 invoke(in_tensor, out_tensor, softmax_node.param());
110 throw std::runtime_error{"NYI: Support float32 only"};
116 OpKernel *getSoftmax()
118 static OpKernel kernel = {prepareSoftMax, invokeSoftMax};
122 } // namespace interp