2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
#include "kernels/NotEqual.h"

#include "kernels/Utils.h"

#include <tensorflow/lite/kernels/internal/reference/comparisons.h>

#include <stdexcept>
23 namespace luci_interpreter
29 NotEqual::NotEqual(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
31 void NotEqual::configure()
33 LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
34 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
36 if (x()->element_type() == DataType::U8)
38 quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
39 quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
41 // TODO: enable it only if kernel with dynamic shapes
42 output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
45 void NotEqual::execute() const
47 switch (x()->element_type())
49 case DataType::FLOAT32:
53 evalInteger<int64_t>();
56 evalInteger<int32_t>();
62 assert(false && "Unsupported type.");
66 void NotEqual::evalFloat() const
68 const auto x_data = getTensorData<float>(x());
69 const auto y_data = getTensorData<float>(y());
70 auto output_data = getTensorData<bool>(output());
72 tflite::ComparisonParams op_params;
73 op_params.is_broadcast = x()->shape() != y()->shape();
75 if (op_params.is_broadcast)
77 tflite::reference_ops::Broadcast4DSlowNotEqual(op_params, getTensorShape(x()), x_data,
78 getTensorShape(y()), y_data,
79 getTensorShape(output()), output_data);
83 tflite::reference_ops::NotEqual(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
84 y_data, getTensorShape(output()), output_data);
88 template <typename T> void NotEqual::evalInteger() const
90 const auto x_data = getTensorData<T>(x());
91 const auto y_data = getTensorData<T>(y());
92 auto output_data = getTensorData<bool>(output());
94 tflite::ComparisonParams op_params;
95 op_params.is_broadcast = x()->shape() != y()->shape();
97 if (op_params.is_broadcast)
99 tflite::reference_ops::Broadcast4DSlowNotEqualNoScaling(op_params, getTensorShape(x()), x_data,
100 getTensorShape(y()), y_data,
101 getTensorShape(output()), output_data);
105 tflite::reference_ops::NotEqualNoScaling(op_params, getTensorShape(x()), x_data,
106 getTensorShape(y()), y_data, getTensorShape(output()),
111 void NotEqual::evalQuantized() const
113 const auto x_data = getTensorData<uint8_t>(x());
114 const auto y_data = getTensorData<uint8_t>(y());
115 auto output_data = getTensorData<bool>(output());
117 tflite::ComparisonParams op_params;
118 op_params.left_shift = 8;
119 op_params.input1_offset = -x()->zero_point(); // Note the '-'
120 op_params.input1_shift = _x_shift;
121 op_params.input1_multiplier = _x_multiplier;
122 op_params.input2_offset = -y()->zero_point(); // Note the '-'
123 op_params.input2_shift = _y_shift;
124 op_params.input2_multiplier = _y_multiplier;
125 op_params.is_broadcast = x()->shape() != y()->shape();
127 if (op_params.is_broadcast)
129 tflite::reference_ops::Broadcast4DSlowNotEqualWithScaling(
130 op_params, getTensorShape(x()), x_data, getTensorShape(y()), y_data, getTensorShape(output()),
135 tflite::reference_ops::NotEqualWithScaling(op_params, getTensorShape(x()), x_data,
136 getTensorShape(y()), y_data,
137 getTensorShape(output()), output_data);
141 } // namespace kernels
142 } // namespace luci_interpreter