304939ee8c00acaaaeb32eb32f9a59065fe47f05
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / NotEqual.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/NotEqual.h"
19 #include "kernels/Utils.h"
20
21 #include <tensorflow/lite/kernels/internal/reference/comparisons.h>
22
23 namespace luci_interpreter
24 {
25
26 namespace kernels
27 {
28
29 NotEqual::NotEqual(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
30
31 void NotEqual::configure()
32 {
33   LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
34   LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
35
36   if (x()->element_type() == DataType::U8)
37   {
38     quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
39     quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
40   }
41   // TODO: enable it only if kernel with dynamic shapes
42   output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
43 }
44
45 void NotEqual::execute() const
46 {
47   switch (x()->element_type())
48   {
49     case DataType::FLOAT32:
50       evalFloat();
51       break;
52     case DataType::S64:
53       evalInteger<int64_t>();
54       break;
55     case DataType::S32:
56       evalInteger<int32_t>();
57       break;
58     case DataType::U8:
59       evalQuantized();
60       break;
61     default:
62       assert(false && "Unsupported type.");
63   }
64 }
65
66 void NotEqual::evalFloat() const
67 {
68   const auto x_data = getTensorData<float>(x());
69   const auto y_data = getTensorData<float>(y());
70   auto output_data = getTensorData<bool>(output());
71
72   tflite::ComparisonParams op_params;
73   op_params.is_broadcast = x()->shape() != y()->shape();
74
75   if (op_params.is_broadcast)
76   {
77     tflite::reference_ops::Broadcast4DSlowNotEqual(op_params, getTensorShape(x()), x_data,
78                                                    getTensorShape(y()), y_data,
79                                                    getTensorShape(output()), output_data);
80   }
81   else
82   {
83     tflite::reference_ops::NotEqual(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
84                                     y_data, getTensorShape(output()), output_data);
85   }
86 }
87
88 template <typename T> void NotEqual::evalInteger() const
89 {
90   const auto x_data = getTensorData<T>(x());
91   const auto y_data = getTensorData<T>(y());
92   auto output_data = getTensorData<bool>(output());
93
94   tflite::ComparisonParams op_params;
95   op_params.is_broadcast = x()->shape() != y()->shape();
96
97   if (op_params.is_broadcast)
98   {
99     tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data,
100                                                             getTensorShape(y()), y_data,
101                                                             getTensorShape(output()), output_data);
102   }
103   else
104   {
105     tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data,
106                                              getTensorShape(y()), y_data, getTensorShape(output()),
107                                              output_data);
108   }
109 }
110
111 void NotEqual::evalQuantized() const
112 {
113   const auto x_data = getTensorData<uint8_t>(x());
114   const auto y_data = getTensorData<uint8_t>(y());
115   auto output_data = getTensorData<bool>(output());
116
117   tflite::ComparisonParams op_params;
118   op_params.left_shift = 8;
119   op_params.input1_offset = -x()->zero_point(); // Note the '-'
120   op_params.input1_shift = _x_shift;
121   op_params.input1_multiplier = _x_multiplier;
122   op_params.input2_offset = -y()->zero_point(); // Note the '-'
123   op_params.input2_shift = _y_shift;
124   op_params.input2_multiplier = _y_multiplier;
125   op_params.is_broadcast = x()->shape() != y()->shape();
126
127   if (op_params.is_broadcast)
128   {
129     tflite::reference_ops::Broadcast4DSlowNotEqualWithScaling(
130       op_params, getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
131       output_data);
132   }
133   else
134   {
135     tflite::reference_ops::NotEqualWithScaling(op_params, getTensorShape(x()), x_data,
136                                                getTensorShape(y()), y_data,
137                                                getTensorShape(output()), output_data);
138   }
139 }
140
141 } // namespace kernels
142 } // namespace luci_interpreter