4e8cba82b4abf6081da1189251fba66bfc32abfd
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Logistic.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "Builders.h"
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/logistic.h>
22
23 namespace luci_interpreter
24 {
25 namespace
26 {
27
28 #ifndef DIS_FLOAT
29 void evalFloat(const circle::Tensor *input, const circle::Tensor *output, bool is_inplace,
30                BaseRuntimeGraph *runtime_graph)
31 {
32   const float *input_data = reinterpret_cast<const float *>(runtime_graph->getDataByTensor(input));
33   float *output_data = reinterpret_cast<float *>(runtime_graph->getDataByTensor(output));
34
35   if (is_inplace)
36   {
37     output_data = const_cast<float *>(input_data);
38   }
39
40   assert(input_data != nullptr);
41   assert(output_data != nullptr);
42
43   tflite::reference_ops::Logistic(kernels::getTensorShape(input), input_data,
44                                   kernels::getTensorShape(output), output_data);
45   if (is_inplace)
46   {
47     runtime_graph->makeInplaceOperation(input, output);
48   }
49 }
50 #endif // DIS_FLOAT
51
52 #ifndef DIS_QUANT
53 void evalQuantized(const circle::Tensor *input, const circle::Tensor *output, bool is_inplace,
54                    BaseRuntimeGraph *runtime_graph)
55 {
56   const int8_t *input_data =
57     reinterpret_cast<const int8_t *>(runtime_graph->getDataByTensor(input));
58   int8_t *output_data = reinterpret_cast<int8_t *>(runtime_graph->getDataByTensor(output));
59   if (is_inplace)
60     output_data = const_cast<int8_t *>(input_data);
61
62   tflite::reference_ops::Logistic(kernels::getTensorShape(input), input_data, Tensor::scale(input),
63                                   Tensor::zero_point(input), kernels::getTensorShape(output),
64                                   output_data, Tensor::scale(output), Tensor::zero_point(output));
65   if (is_inplace)
66   {
67     runtime_graph->makeInplaceOperation(input, output);
68   }
69 }
70 #endif // DIS_QUANT
71
72 } // namespace
73
74 void configure_kernel_CircleLogistic(const circle::Operator *cur_op,
75                                      BaseRuntimeGraph *runtime_graph)
76 {
77   const auto input_index = cur_op->inputs()->operator[](0);
78   const auto output_index = cur_op->outputs()->operator[](0);
79
80   assert(input_index != -1);
81   assert(output_index != -1);
82
83   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
84   auto output = runtime_graph->getCircleTensorByIndex(output_index);
85
86   assert(input != nullptr);
87   assert(output != nullptr);
88
89   LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
90
91 #ifndef DIS_QUANT
92   if (Tensor::element_type(input) == DataType::U8)
93   {
94     LUCI_INTERPRETER_CHECK(Tensor::scale(output) == 1. / 256);
95   }
96 #endif // DIS_QUANT
97 }
98
99 void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
100                                    bool is_inplace)
101 {
102   const auto input_index = cur_op->inputs()->operator[](0);
103   const auto output_index = cur_op->outputs()->operator[](0);
104
105   assert(input_index != -1);
106   assert(output_index != -1);
107
108   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
109   auto output = runtime_graph->getCircleTensorByIndex(output_index);
110
111   assert(input != nullptr);
112   assert(output != nullptr);
113
114   switch (Tensor::element_type(input))
115   {
116 #ifndef DIS_FLOAT
117     case DataType::FLOAT32:
118       evalFloat(input, output, is_inplace, runtime_graph);
119       break;
120 #endif // DIS_FLOAT
121 #ifndef DIS_QUANT
122     case DataType::S8:
123       evalQuantized(input, output, is_inplace, runtime_graph);
124       break;
125 #endif // DIS_QUANT
126     default:
127       assert(false && "Unsupported type.");
128   }
129 }
130
131 } // namespace luci_interpreter