Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / interp / operations / Softmax.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <cker/operation/SoftMax.h>
18
19 #include "OperationUtil.h"
20
21 #include "interp/Registration.h"
22 #include "ir/operation/Softmax.h"
23 #include "misc/polymorphic_downcast.h"
24
25 namespace onert
26 {
27 namespace interp
28 {
29 namespace
30 {
31
32 void prepareSoftMax(ExecEnv *env, const ir::Operation &node)
33 {
34   const auto in_index = node.getInputs().at(0);
35   const auto out_index = node.getOutputs().at(0);
36
37   const auto in_tensor = env->tensorAt(in_index);
38   UNUSED_RELEASE(in_tensor);
39
40   assert((in_tensor->num_dimensions() == 4) || (in_tensor->num_dimensions() == 2));
41
42   // Output shape should be same with input
43   // Output type is pre-defined in model
44   const auto output_shape = env->graph().operands().at(in_index).info().shape();
45   const auto output_type = env->graph().operands().at(out_index).info().typeInfo();
46
47   const auto output_info = ir::OperandInfo::createStaticInfo(output_shape, output_type);
48   env->allocateIfNeeded(out_index, output_info);
49
50   auto out_tensor = env->tensorAt(out_index);
51   UNUSED_RELEASE(out_tensor);
52
53   // Check output shape is same with input
54   assert(out_tensor->num_dimensions() == out_tensor->num_dimensions());
55   for (uint32_t i = 0; i < in_tensor->num_dimensions(); i++)
56   {
57     assert(in_tensor->dimension(i) == out_tensor->dimension(i));
58   }
59 }
60
61 void invoke(const ITensor *in_tensor, const ITensor *out_tensor,
62             const ir::operation::Softmax::Param &param)
63 {
64   const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO());
65   float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer());
66
67   float beta = param.beta;
68
69   if (in_tensor->num_dimensions() == 2)
70   {
71     uint32_t batch_size = in_tensor->dimension(0);
72     uint32_t input_size = in_tensor->dimension(1);
73
74     nnfw::cker::Softmax(in_ptr, input_size, batch_size, beta, out_ptr);
75   }
76   else if (in_tensor->num_dimensions() == 4)
77   {
78     const auto in_shape = convertShape(in_tensor->tensorInfo().shape());
79     const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
80
81     nnfw::cker::SoftmaxParams cker_param;
82     cker_param.beta = beta;
83
84     nnfw::cker::Softmax(cker_param, in_shape, in_ptr, out_shape, out_ptr);
85   }
86   else
87   {
88     throw std::runtime_error{"Unsuported input dimension: support 2D or 4D"};
89   }
90 }
91
92 void invokeSoftMax(const ExecEnv *env, const ir::Operation &node)
93 {
94   const auto &softmax_node = nnfw::misc::polymorphic_downcast<const ir::operation::Softmax &>(node);
95
96   const auto in_index = node.getInputs().at(0);
97   const auto out_index = node.getOutputs().at(0);
98
99   const auto in_tensor = env->tensorAt(in_index);
100   const auto out_tensor = env->tensorAt(out_index);
101
102   const auto in_data_type = in_tensor->data_type();
103   const auto out_data_type = out_tensor->data_type();
104   if ((in_data_type == ir::DataType::FLOAT32) && (out_data_type == ir::DataType::FLOAT32))
105   {
106     invoke(in_tensor, out_tensor, softmax_node.param());
107   }
108   else
109   {
110     throw std::runtime_error{"NYI: Support float32 only"};
111   }
112 }
113
114 } // namespace
115
116 OpKernel *getSoftmax()
117 {
118   static OpKernel kernel = {prepareSoftMax, invokeSoftMax};
119   return &kernel;
120 }
121
122 } // namespace interp
123 } // namespace onert