Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / backend / cpu / ops / SoftMaxLayer.cc
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SoftMaxLayer.h"
18
19 #include "OperationUtils.h"
20
21 #include <cker/operation/SoftMax.h>
22
23 namespace onert
24 {
25 namespace backend
26 {
27 namespace cpu
28 {
29 namespace ops
30 {
31
32 SoftMaxLayer::SoftMaxLayer() : _input(nullptr), _output(nullptr), _beta(0.0)
33 {
34   // DO NOTHING
35 }
36
37 void SoftMaxLayer::softmaxFloat32()
38 {
39   if (getNumberOfDimensions(_input) == 1)
40   {
41     uint32_t input_size = getNumberOfElements(_input);
42     nnfw::cker::Softmax(getBuffer<float>(_input), input_size, 1, _beta, getBuffer<float>(_output));
43   }
44   else if (getNumberOfDimensions(_input) == 2)
45   {
46     uint32_t batch_size = getSizeOfDimension(_input, 0);
47     if (batch_size == 0)
48       throw std::runtime_error("batch_size should not be 0");
49
50     uint32_t input_size = getNumberOfElements(_input) / batch_size;
51     nnfw::cker::Softmax(getBuffer<float>(_input), input_size, batch_size, _beta,
52                         getBuffer<float>(_output));
53   }
54   else if (getNumberOfDimensions(_input) == 4)
55   {
56     nnfw::cker::SoftmaxParams op_params;
57     op_params.beta = _beta;
58     nnfw::cker::Softmax(op_params, getShape(_input), getBuffer<float>(_input), getShape(_output),
59                         getBuffer<float>(_output));
60   }
61   else
62   {
63     nnfw::cker::SoftmaxParams op_params;
64     op_params.beta = _beta;
65     nnfw::cker::reference::Softmax(op_params, getShape(_input), getBuffer<float>(_input),
66                                    getShape(_output), getBuffer<float>(_output));
67   }
68 }
69
70 template <typename T> void SoftMaxLayer::softmaxQuant8()
71 {
72   nnfw::cker::SoftmaxParams op_params;
73   op_params.scale = _output->data_scale();
74   op_params.zero_point = _output->data_zero_point();
75   op_params.uint8_table1 = _uint8_table1;
76   op_params.uint8_table2 = _uint8_table2;
77   op_params.table = _table;
78
79 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
80   nnfw::cker::SoftmaxInt8LUT<T, T>(op_params, getShape(_input), getBuffer<T>(_input),
81                                    getShape(_output), getBuffer<T>(_output));
82 #else
83   nnfw::cker::Softmax<T, T>(op_params, getShape(_input), getBuffer<T>(_input), getShape(_output),
84                             getBuffer<T>(_output));
85 #endif
86 }
87
88 void SoftMaxLayer::configure(const IPortableTensor *input, const float beta,
89                              IPortableTensor *output)
90 {
91   _input = input;
92   _output = output;
93   _beta = beta;
94
95   if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM ||
96       _input->data_type() == OperandType::QUANT_INT8_ASYMM)
97   {
98 #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT
99     // Only apply when both input & output are uint8/int8 & build with clang
100     // on aarch64.
101     nnfw::cker::PopulateSoftmaxUInt8LookupTable(_uint8_table1, _uint8_table2, _input->data_scale(),
102                                                 _beta);
103 #else
104     nnfw::cker::PopulateSoftmaxLookupTable(_table, _input->data_scale(), _beta);
105 #endif
106   }
107 }
108
109 void SoftMaxLayer::run()
110 {
111   switch (_input->data_type())
112   {
113     case OperandType::FLOAT32:
114       softmaxFloat32();
115       break;
116     case OperandType::QUANT_UINT8_ASYMM:
117       softmaxQuant8<uint8_t>();
118       break;
119     case OperandType::QUANT_INT8_ASYMM:
120       softmaxQuant8<int8_t>();
121       break;
122     default:
123       throw std::runtime_error{"SoftMax: unsupported data type"};
124   }
125 }
126
127 } // namespace ops
128 } // namespace cpu
129 } // namespace backend
130 } // namespace onert