2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef LUCI_INTERPRETER_PAL_COMPARISONS_H
19 #define LUCI_INTERPRETER_PAL_COMPARISONS_H
22 #include "ProcessBroadcastShapes.h"
25 namespace luci_interpreter_pal
30 struct BroadcastComparison4DSlowCommon
32 const luci_interpreter::RuntimeShape output_shape;
37 inline BroadcastComparison4DSlowCommon
38 BroadcastComparison4DSlowPreprocess(const luci_interpreter::RuntimeShape &unextended_input1_shape,
39 const luci_interpreter::RuntimeShape &unextended_input2_shape,
40 const luci_interpreter::RuntimeShape &unextended_output_shape)
44 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
46 return {luci_interpreter::RuntimeShape::extendedShape(4, unextended_output_shape), desc1, desc2};
// Returns the result of `lhs < rhs` for the given operand type.
template <typename T> inline bool LessFn(T lhs, T rhs)
{
  const bool is_less = lhs < rhs;
  return is_less;
}
// Returns the result of `lhs <= rhs` for the given operand type.
template <typename T> inline bool LessEqualFn(T lhs, T rhs)
{
  const bool is_less_or_equal = lhs <= rhs;
  return is_less_or_equal;
}
// Returns the result of `lhs == rhs` for the given operand type.
template <typename T> inline bool EqualFn(T lhs, T rhs)
{
  const bool is_equal = lhs == rhs;
  return is_equal;
}
// Returns the result of `lhs > rhs` for the given operand type.
template <typename T> inline bool GreaterFn(T lhs, T rhs)
{
  const bool is_greater = lhs > rhs;
  return is_greater;
}
// Returns the result of `lhs >= rhs` for the given operand type.
template <typename T> inline bool GreaterEqualFn(T lhs, T rhs)
{
  const bool is_greater_or_equal = lhs >= rhs;
  return is_greater_or_equal;
}
// Returns the result of `lhs != rhs` for the given operand type.
template <typename T> inline bool NotEqualFn(T lhs, T rhs)
{
  const bool is_different = lhs != rhs;
  return is_different;
}
// Element-wise comparison of two flat buffers without any rescaling.
//
// For every i in [0, flat_size) stores F(input1_data[i], input2_data[i]) in
// output_data[i]. Nothing is written when flat_size is zero or negative.
template <typename T>
inline void ComparisonNoScaling(const int64_t flat_size, const T *input1_data, const T *input2_data,
                                bool *output_data, bool F(T, T))
{
  const T *lhs_cursor = input1_data;
  const T *rhs_cursor = input2_data;
  bool *out_cursor = output_data;

  // Count down the remaining elements; the three cursors advance in lockstep.
  for (int64_t remaining = flat_size; remaining > 0; --remaining)
  {
    *out_cursor = F(*lhs_cursor, *rhs_cursor);
    ++lhs_cursor;
    ++rhs_cursor;
    ++out_cursor;
  }
}
69 inline void BroadcastComparison4DSlowWithScaling(
70 const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
71 const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
72 const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
73 bool *output_data, bool F(T, T))
75 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
76 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
78 int left_shift = op_params.left_shift;
79 int32_t input1_offset = op_params.input1_offset;
80 int32_t input1_multiplier = op_params.input1_multiplier;
81 int input1_shift = op_params.input1_shift;
82 int32_t input2_offset = op_params.input2_offset;
83 int32_t input2_multiplier = op_params.input2_multiplier;
84 int input2_shift = op_params.input2_shift;
86 for (int b = 0; b < dims.output_shape.dims(0); ++b)
88 for (int y = 0; y < dims.output_shape.dims(1); ++y)
90 for (int x = 0; x < dims.output_shape.dims(2); ++x)
92 for (int c = 0; c < dims.output_shape.dims(3); ++c)
94 const int32_t input1_val =
95 input1_offset + input1_data[subscriptToIndex(dims.desc1, b, y, x, c)];
96 const int32_t input2_val =
97 input2_offset + input2_data[subscriptToIndex(dims.desc2, b, y, x, c)];
98 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
99 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
100 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
101 shifted_input1_val, input1_multiplier, input1_shift);
102 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
103 shifted_input2_val, input2_multiplier, input2_shift);
105 const int output_data_offset =
106 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
107 dims.output_shape.dims(3) +
109 output_data[output_data_offset] = F(scaled_input1_val, scaled_input2_val);
116 template <typename T>
117 inline void ComparisonWithScaling(const ComparisonParams &op_params, const int64_t flat_size,
118 const T *input1_data, const T *input2_data, bool *output_data,
121 int left_shift = op_params.left_shift;
122 int32_t input1_offset = op_params.input1_offset;
123 int32_t input1_multiplier = op_params.input1_multiplier;
124 int input1_shift = op_params.input1_shift;
125 int32_t input2_offset = op_params.input2_offset;
126 int32_t input2_multiplier = op_params.input2_multiplier;
127 int input2_shift = op_params.input2_shift;
129 for (int64_t i = 0; i < flat_size; ++i)
131 const int32_t input1_val = input1_offset + input1_data[i];
132 const int32_t input2_val = input2_offset + input2_data[i];
133 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
134 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
135 const int32_t scaled_input1_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
136 shifted_input1_val, input1_multiplier, input1_shift);
137 const int32_t scaled_input2_val = multiplyByQuantizedMultiplierSmallerThanOneExp(
138 shifted_input2_val, input2_multiplier, input2_shift);
139 output_data[i] = F(scaled_input1_val, scaled_input2_val);
143 template <typename T>
144 inline void BroadcastComparison4DSlowNoScaling(
145 const ComparisonParams &op_params, const luci_interpreter::RuntimeShape &unextended_input1_shape,
146 const T *input1_data, const luci_interpreter::RuntimeShape &unextended_input2_shape,
147 const T *input2_data, const luci_interpreter::RuntimeShape &unextended_output_shape,
148 bool *output_data, bool F(T, T))
150 const BroadcastComparison4DSlowCommon dims = BroadcastComparison4DSlowPreprocess(
151 unextended_input1_shape, unextended_input2_shape, unextended_output_shape);
153 for (int b = 0; b < dims.output_shape.dims(0); ++b)
155 for (int y = 0; y < dims.output_shape.dims(1); ++y)
157 for (int x = 0; x < dims.output_shape.dims(2); ++x)
159 for (int c = 0; c < dims.output_shape.dims(3); ++c)
161 const int output_data_offset =
162 ((b * dims.output_shape.dims(1) + y) * dims.output_shape.dims(2) + x) *
163 dims.output_shape.dims(3) +
165 output_data[output_data_offset] =
166 F(input1_data[subscriptToIndex(dims.desc1, b, y, x, c)],
167 input2_data[subscriptToIndex(dims.desc2, b, y, x, c)]);
174 } // namespace luci_interpreter_pal
176 #endif // LUCI_INTERPRETER_PAL_COMPARISONS_H