62c720937f30956cf99a59de2b4ca2ed17a5a240
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / cmsisnn / PALreference_ops.h
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_REFERENCE_OPS_H
19 #define LUCI_INTERPRETER_PAL_REFERENCE_OPS_H
20
21 #include <stdint.h>
22 #include <sys/types.h>
23
24 #include <algorithm>
25 #include <cmath>
26 #include <cstring>
27 #include <functional>
28 #include <limits>
29 #include <memory>
30 #include <type_traits>
31
32 #include "third_party/eigen3/Eigen/Core"
33 #include "fixedpoint/fixedpoint.h"
34 #include "ruy/profiler/instrumentation.h" // from @ruy
35 #include "tensorflow/lite/c/common.h"
36 #include "tensorflow/lite/kernels/internal/common.h"
37 #include "tensorflow/lite/kernels/internal/quantization_util.h"
38 #include "tensorflow/lite/kernels/internal/reference/add.h"
39 #include "tensorflow/lite/kernels/internal/reference/add_n.h"
40 #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
41 #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
42 #include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
43 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
44 #include "tensorflow/lite/kernels/internal/reference/cast.h"
45 #include "tensorflow/lite/kernels/internal/reference/ceil.h"
46 #include "tensorflow/lite/kernels/internal/reference/comparisons.h"
47 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
48 #include "tensorflow/lite/kernels/internal/reference/conv.h"
49 #include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
50 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
51 #include "tensorflow/lite/kernels/internal/reference/div.h"
52 #include "tensorflow/lite/kernels/internal/reference/elu.h"
53 #include "tensorflow/lite/kernels/internal/reference/exp.h"
54 #include "tensorflow/lite/kernels/internal/reference/fill.h"
55 #include "tensorflow/lite/kernels/internal/reference/floor.h"
56 #include "tensorflow/lite/kernels/internal/reference/floor_div.h"
57 #include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
58 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
59 #include "tensorflow/lite/kernels/internal/reference/gather.h"
60 #include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
61 #include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
62 #include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
63 #include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
64 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
65 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
66 #include "tensorflow/lite/kernels/internal/reference/mul.h"
67 #include "tensorflow/lite/kernels/internal/reference/neg.h"
68 #include "tensorflow/lite/kernels/internal/reference/pad.h"
69 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
70 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
71 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
72 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
73 #include "tensorflow/lite/kernels/internal/reference/reduce.h"
74 #include "tensorflow/lite/kernels/internal/reference/requantize.h"
75 #include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
76 #include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
77 #include "tensorflow/lite/kernels/internal/reference/round.h"
78 #include "tensorflow/lite/kernels/internal/reference/softmax.h"
79 #include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
80 #include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
81 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
82 #include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
83 #include "tensorflow/lite/kernels/internal/reference/sub.h"
84 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
85 #include "tensorflow/lite/kernels/internal/reference/transpose.h"
86 #include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
87 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
88 #include "tensorflow/lite/kernels/internal/tensor.h"
89 #include "tensorflow/lite/kernels/internal/types.h"
90 namespace tflite
91 {
92
93 namespace reference_ops
94 {
95
96 template <typename T>
97 inline void Relu(const RuntimeShape &input_shape, const T *input_data,
98                  const RuntimeShape &output_shape, T *output_data)
99 {
100   const int flat_size = MatchingFlatSize(input_shape, output_shape);
101   for (int i = 0; i < flat_size; ++i)
102   {
103     const T val = input_data[i];
104     const T lower = 0;
105     const T clamped = val < lower ? lower : val;
106     output_data[i] = clamped;
107   }
108 }
109
110 template <typename T>
111 inline void Relu1(const RuntimeShape &input_shape, const T *input_data,
112                   const RuntimeShape &output_shape, T *output_data)
113 {
114   ruy::profiler::ScopeLabel label("Relu1 (not fused)");
115   const int flat_size = MatchingFlatSize(input_shape, output_shape);
116   for (int i = 0; i < flat_size; ++i)
117   {
118     const T val = input_data[i];
119     const T upper = 1;
120     const T lower = -1;
121     const T clamped = val > upper ? upper : val < lower ? lower : val;
122     output_data[i] = clamped;
123   }
124 }
125
126 inline void Relu6(const RuntimeShape &input_shape, const float *input_data,
127                   const RuntimeShape &output_shape, float *output_data)
128 {
129   ruy::profiler::ScopeLabel label("Relu6 (not fused)");
130   const int flat_size = MatchingFlatSize(input_shape, output_shape);
131   for (int i = 0; i < flat_size; ++i)
132   {
133     const float val = input_data[i];
134     const float upper = 6;
135     const float lower = 0;
136     const float clamped = val > upper ? upper : val < lower ? lower : val;
137     output_data[i] = clamped;
138   }
139 }
140
141 template <typename T>
142 inline void ReluX(const tflite::ReluParams &params, const RuntimeShape &input_shape,
143                   const T *input_data, const RuntimeShape &output_shape, T *output_data)
144 {
145   ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
146   const int flat_size = MatchingFlatSize(input_shape, output_shape);
147   for (int i = 0; i < flat_size; ++i)
148   {
149     const int32 val = static_cast<int32_t>(input_data[i]);
150     int32 clamped = params.output_offset + MultiplyByQuantizedMultiplier(val - params.input_offset,
151                                                                          params.output_multiplier,
152                                                                          params.output_shift);
153     clamped = std::max(params.quantized_activation_min, clamped);
154     clamped = std::min(params.quantized_activation_max, clamped);
155     output_data[i] = static_cast<T>(clamped);
156   }
157 }
158
159 template <typename T>
160 inline void ReluX(const tflite::ActivationParams &params, const RuntimeShape &input_shape,
161                   const T *input_data, const RuntimeShape &output_shape, T *output_data)
162 {
163   ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
164   const int flat_size = MatchingFlatSize(input_shape, output_shape);
165   const T max_value = params.quantized_activation_max;
166   const T min_value = params.quantized_activation_min;
167   for (int i = 0; i < flat_size; ++i)
168   {
169     const T val = input_data[i];
170     const T clamped = val > max_value ? max_value : val < min_value ? min_value : val;
171     output_data[i] = clamped;
172   }
173 }
174
175 // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
176 // dimensionality if the runtime code does a single loop over one dimension
177 // that handles broadcasting as the base case. The code generator would then
178 // generate max(D1, D2) nested for loops.
179 inline void BroadcastMulFivefold(const ArithmeticParams &unswitched_params,
180                                  const RuntimeShape &unswitched_input1_shape,
181                                  const uint8 *unswitched_input1_data,
182                                  const RuntimeShape &unswitched_input2_shape,
183                                  const uint8 *unswitched_input2_data,
184                                  const RuntimeShape &output_shape, uint8 *output_data)
185 {
186   ArithmeticParams switched_params = unswitched_params;
187   switched_params.input1_offset = unswitched_params.input2_offset;
188   switched_params.input2_offset = unswitched_params.input1_offset;
189
190   const bool use_unswitched = unswitched_params.broadcast_category ==
191                               tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
192
193   const ArithmeticParams &params = use_unswitched ? unswitched_params : switched_params;
194   const uint8 *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
195   const uint8 *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
196
197   // Fivefold nested loops. The second input resets its position for each
198   // iteration of the second loop. The first input resets its position at the
199   // beginning of the fourth loop. The innermost loop is an elementwise Mul of
200   // sections of the arrays.
201   uint8 *output_data_ptr = output_data;
202   const uint8 *input1_data_ptr = input1_data;
203   const uint8 *input2_data_reset = input2_data;
204   int y0 = params.broadcast_shape[0];
205   int y1 = params.broadcast_shape[1];
206   int y2 = params.broadcast_shape[2];
207   int y3 = params.broadcast_shape[3];
208   int y4 = params.broadcast_shape[4];
209   for (int i0 = 0; i0 < y0; ++i0)
210   {
211     const uint8 *input2_data_ptr;
212     for (int i1 = 0; i1 < y1; ++i1)
213     {
214       input2_data_ptr = input2_data_reset;
215       for (int i2 = 0; i2 < y2; ++i2)
216       {
217         for (int i3 = 0; i3 < y3; ++i3)
218         {
219           MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
220           input2_data_ptr += y4;
221           output_data_ptr += y4;
222         }
223         input1_data_ptr += y4;
224       }
225     }
226     input2_data_reset = input2_data_ptr;
227   }
228 }
229
230 inline void Mul(const ArithmeticParams &params, const RuntimeShape &input1_shape,
231                 const int16 *input1_data, const RuntimeShape &input2_shape,
232                 const int16 *input2_data, const RuntimeShape &output_shape, int16 *output_data)
233 {
234   ruy::profiler::ScopeLabel label("Mul/Int16");
235
236   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
237
238   for (int i = 0; i < flat_size; i++)
239   {
240     // F0 uses 0 integer bits, range [-1, 1].
241     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
242
243     F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
244     output_data[i] = unclamped_result.raw();
245   }
246 }
247
248 inline void Mul(const ArithmeticParams &params, const RuntimeShape &input1_shape,
249                 const int16 *input1_data, const RuntimeShape &input2_shape,
250                 const int16 *input2_data, const RuntimeShape &output_shape, uint8 *output_data)
251 {
252   ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
253   int32 output_offset = params.output_offset;
254   int32 output_activation_min = params.quantized_activation_min;
255   int32 output_activation_max = params.quantized_activation_max;
256   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
257
258   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
259
260   for (int i = 0; i < flat_size; i++)
261   {
262     // F0 uses 0 integer bits, range [-1, 1].
263     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
264
265     F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
266     int16 rescaled_result = gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
267     int16 clamped_result = std::min<int16>(output_activation_max - output_offset, rescaled_result);
268     clamped_result = std::max<int16>(output_activation_min - output_offset, clamped_result);
269     output_data[i] = output_offset + clamped_result;
270   }
271 }
272
273 inline void Sub16(const ArithmeticParams &params, const RuntimeShape &input1_shape,
274                   const int16_t *input1_data, const RuntimeShape &input2_shape,
275                   const int16_t *input2_data, const RuntimeShape &output_shape,
276                   int16_t *output_data)
277 {
278   ruy::profiler::ScopeLabel label("Sub/Int16");
279   const int input1_shift = params.input1_shift;
280   const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
281   const int16 output_activation_min = params.quantized_activation_min;
282   const int16 output_activation_max = params.quantized_activation_max;
283
284   TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
285   TFLITE_DCHECK_LE(input1_shift, 0);
286   TFLITE_DCHECK_LE(params.input2_shift, 0);
287   const int16 *not_shift_input = input1_shift == 0 ? input1_data : input2_data;
288   const int16 *shift_input = input1_shift == 0 ? input2_data : input1_data;
289   const int input_right_shift = input1_shift == 0 ? -params.input2_shift : -input1_shift;
290
291   if (input1_shift == 0)
292   {
293     // F0 uses 0 integer bits, range [-1, 1].
294     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
295     for (int i = 0; i < flat_size; ++i)
296     {
297       F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
298       F0 scaled_input =
299         F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
300       F0 result = SaturatingSub(input_ready_scaled, scaled_input);
301       const int16 raw_output = result.raw();
302       const int16 clamped_output =
303         std::min(output_activation_max, std::max(output_activation_min, raw_output));
304       output_data[i] = clamped_output;
305     }
306   }
307   else
308   {
309     // F0 uses 0 integer bits, range [-1, 1].
310     using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
311     for (int i = 0; i < flat_size; ++i)
312     {
313       F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
314       F0 scaled_input =
315         F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
316       F0 result = SaturatingSub(scaled_input, input_ready_scaled);
317       const int16 raw_output = result.raw();
318       const int16 clamped_output =
319         std::min(output_activation_max, std::max(output_activation_min, raw_output));
320       output_data[i] = clamped_output;
321     }
322   }
323 }
324
325 template <typename Scalar>
326 void Pack(const PackParams &params, const RuntimeShape *const *input_shapes,
327           const Scalar *const *input_data, const RuntimeShape &output_shape, Scalar *output_data)
328 {
329   ruy::profiler::ScopeLabel label("Pack");
330   const int dimensions = output_shape.DimensionsCount();
331   int axis = params.axis;
332   int inputs_count = params.inputs_count;
333
334   int outer_size = 1;
335   for (int i = 0; i < axis; i++)
336   {
337     outer_size *= output_shape.Dims(i);
338   }
339   int copy_size = 1;
340   for (int i = params.axis + 1; i < dimensions; i++)
341   {
342     copy_size *= output_shape.Dims(i);
343   }
344   TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
345
346   for (int i = 0; i < inputs_count; ++i)
347   {
348     for (int k = 0; k < outer_size; k++)
349     {
350       const Scalar *input_ptr = input_data[i] + copy_size * k;
351       int loc = k * inputs_count * copy_size + i * copy_size;
352       memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
353     }
354   }
355 }
356
357 template <typename Scalar>
358 void Unpack(const UnpackParams &params, const RuntimeShape &input_shape, const Scalar *input_data,
359             const RuntimeShape &output_shape, Scalar *const *output_datas)
360 {
361   ruy::profiler::ScopeLabel label("Unpack");
362   const int dimensions = input_shape.DimensionsCount();
363   const int outputs_count = params.num_split;
364
365   int outer_size = 1;
366   int axis = params.axis;
367   if (axis < 0)
368   {
369     axis += dimensions;
370   }
371   TFLITE_DCHECK_GE(axis, 0);
372   TFLITE_DCHECK_LT(axis, dimensions);
373   for (int i = 0; i < axis; ++i)
374   {
375     outer_size *= input_shape.Dims(i);
376   }
377   int copy_size = 1;
378   for (int i = axis + 1; i < dimensions; ++i)
379   {
380     copy_size *= input_shape.Dims(i);
381   }
382   TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
383
384   for (int i = 0; i < outputs_count; ++i)
385   {
386     for (int k = 0; k < outer_size; k++)
387     {
388       Scalar *output_ptr = output_datas[i] + copy_size * k;
389       int loc = k * outputs_count * copy_size + i * copy_size;
390       memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
391     }
392   }
393 }
394
395 template <typename Scalar>
396 void PackWithScaling(const PackParams &params, const RuntimeShape *const *input_shapes,
397                      const uint8 *const *input_data, const RuntimeShape &output_shape,
398                      uint8 *output_data)
399 {
400   ruy::profiler::ScopeLabel label("PackWithScaling");
401   const int dimensions = output_shape.DimensionsCount();
402   int axis = params.axis;
403   const int32 *input_zeropoint = params.input_zeropoint;
404   const float *input_scale = params.input_scale;
405   int inputs_count = params.inputs_count;
406   const int32 output_zeropoint = params.output_zeropoint;
407   const float output_scale = params.output_scale;
408
409   int outer_size = 1;
410   for (int i = 0; i < axis; i++)
411   {
412     outer_size *= output_shape.Dims(i);
413   }
414   int copy_size = 1;
415   for (int i = axis + 1; i < dimensions; i++)
416   {
417     copy_size *= output_shape.Dims(i);
418   }
419   TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
420
421   Scalar *output_ptr = output_data;
422   const float inverse_output_scale = 1.f / output_scale;
423   for (int k = 0; k < outer_size; k++)
424   {
425     for (int i = 0; i < inputs_count; ++i)
426     {
427       if (input_zeropoint[i] == output_zeropoint && input_scale[i] == output_scale)
428       {
429         memcpy(output_ptr, input_data[i] + k * copy_size, copy_size * sizeof(Scalar));
430       }
431       else
432       {
433         assert(false);
434         const float scale = input_scale[i] * inverse_output_scale;
435         const float bias = -input_zeropoint[i] * scale;
436         auto input_ptr = input_data[i];
437         for (int j = 0; j < copy_size; ++j)
438         {
439           const int value =
440             static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) + output_zeropoint;
441           output_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, value), 0));
442         }
443       }
444       output_ptr += copy_size;
445     }
446   }
447 }
448
449 template <typename Scalar>
450 void DepthConcatenation(const ConcatenationParams &params, const RuntimeShape *const *input_shapes,
451                         const Scalar *const *input_data, const RuntimeShape &output_shape,
452                         Scalar *output_data)
453 {
454   ruy::profiler::ScopeLabel label("DepthConcatenation");
455   auto params_copy = params;
456   params_copy.axis = 3;
457   Concatenation(params_copy, input_shapes, input_data, output_shape, output_data);
458 }
459
460 inline void LstmCell(const LstmCellParams &params, const RuntimeShape &unextended_input_shape,
461                      const float *input_data, const RuntimeShape &unextended_prev_activ_shape,
462                      const float *prev_activ_data, const RuntimeShape &weights_shape,
463                      const float *weights_data, const RuntimeShape &unextended_bias_shape,
464                      const float *bias_data, const RuntimeShape &unextended_prev_state_shape,
465                      const float *prev_state_data,
466                      const RuntimeShape &unextended_output_state_shape, float *output_state_data,
467                      const RuntimeShape &unextended_output_activ_shape, float *output_activ_data,
468                      const RuntimeShape &unextended_concat_temp_shape, float *concat_temp_data,
469                      const RuntimeShape &unextended_activ_temp_shape, float *activ_temp_data)
470 {
471   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
472   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
473   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
474   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
475   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
476   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
477   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
478   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
479   const RuntimeShape input_shape = RuntimeShape::ExtendedShape(4, unextended_input_shape);
480   const RuntimeShape prev_activ_shape = RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
481   const RuntimeShape bias_shape = RuntimeShape::ExtendedShape(4, unextended_bias_shape);
482   const RuntimeShape prev_state_shape = RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
483   const RuntimeShape output_state_shape =
484     RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
485   const RuntimeShape output_activ_shape =
486     RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
487   const RuntimeShape concat_temp_shape =
488     RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
489   const RuntimeShape activ_temp_shape = RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
490   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
491
492   const int weights_dim_count = weights_shape.DimensionsCount();
493   const int batches = MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
494                                   output_state_shape, 0, output_activ_shape, 0);
495   const int height = MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
496                                  output_state_shape, 1, output_activ_shape, 1);
497   const int width = MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
498                                 output_state_shape, 2, output_activ_shape, 2);
499   const int input_depth = input_shape.Dims(3);
500   const int prev_activ_depth = prev_activ_shape.Dims(3);
501   const int total_input_depth = prev_activ_depth + input_depth;
502   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), total_input_depth);
503   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
504   const int intern_activ_depth = MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
505   TFLITE_DCHECK_EQ(weights_shape.FlatSize(), intern_activ_depth * total_input_depth);
506   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
507   const int output_depth = MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
508                                        3, output_activ_shape, 3);
509   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
510
511   // Concatenate prev_activ and input data together
512   std::vector<float const *> concat_input_arrays_data;
513   std::vector<RuntimeShape const *> concat_input_arrays_shapes;
514   concat_input_arrays_data.push_back(input_data);
515   concat_input_arrays_data.push_back(prev_activ_data);
516   concat_input_arrays_shapes.push_back(&input_shape);
517   concat_input_arrays_shapes.push_back(&prev_activ_shape);
518   tflite::ConcatenationParams concat_params;
519   concat_params.axis = 3;
520   concat_params.inputs_count = concat_input_arrays_data.size();
521   Concatenation(concat_params, &(concat_input_arrays_shapes[0]), &(concat_input_arrays_data[0]),
522                 concat_temp_shape, concat_temp_data);
523
524   // Fully connected
525   tflite::FullyConnectedParams fc_params;
526   fc_params.float_activation_min = std::numeric_limits<float>::lowest();
527   fc_params.float_activation_max = std::numeric_limits<float>::max();
528   FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape, weights_data,
529                  bias_shape, bias_data, activ_temp_shape, activ_temp_data);
530
531   // Memory state update (the LSTM "guts")
532   for (int b = 0; b < batches; ++b)
533   {
534     for (int w = 0; w < width; ++w)
535     {
536       for (int h = 0; h < height; ++h)
537       {
538         for (int c = 0; c < output_depth; ++c)
539         {
540           const float input_gate =
541             1.f /
542             (1.f +
543              std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 0 * output_depth + c)]));
544           const float new_input =
545             std::tanh(activ_temp_data[Offset(activ_temp_shape, b, h, w, 1 * output_depth + c)]);
546           const float forget_gate =
547             1.f /
548             (1.f +
549              std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 2 * output_depth + c)]));
550           const float output_gate =
551             1.f /
552             (1.f +
553              std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 3 * output_depth + c)]));
554           const float new_state =
555             input_gate * new_input +
556             forget_gate * prev_state_data[Offset(prev_state_shape, b, h, w, c)];
557           output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
558           output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
559             output_gate * std::tanh(new_state);
560         }
561       }
562     }
563   }
564 }
565
566 // Quantized LSTM cell implementation.
567 // The quantization of the input, output arrays is as follows:
568 //  - The input activations are quantized as uint8 on the interval
569 //    [-1, 127/128].
570 //    The rationale for that is that is the natural interval for output
571 //    activations (see next point) and these need to be concatenated together.
572 //    We could accommodate different ranges by re-scaling, but we empirically
573 //    found that setting the input activations range to be [-1, 127/128] in the
574 //    first place, removing the need for re-scaling, greatly improves accuracy.
575 //  - The output activations are quantized as uint8 on the interval
576 //    [-1, 127/128].
577 //    The rationale for that is that the definition of a LSTM cell makes them
578 //    intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
579 //    makes for simpler, more accurate fixed-point arithmetic.
580 //  - The output-at-previous-timestep state array is obviously quantized as
581 //    the output activations.
582 //  - The internal LSTM memory (not the output-at-previous-timestep, the other
583 //    internal state array) is int16-quantized and may use any power-of-two,
584 //    symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
585 //    StateIntegerBits below, see the below discussion of that template
586 //    parameter ("The StateIntegerBits template parameter").
587 //  - The output of the internal fully-connected node is int16-quantized
588 //    on the interval [-8, 8 * 32767/32768], the rationale for which is
589 //    explained just below ("Why [-8, 8] for fully-connected output?").
590 //
591 //
592 // === The StateIntegerBits template parameter ===
593 //
594 // The StateIntegerBits template parameter controls the fixed-point format used
595 // to represent the internal memory of the LSTM cell (not the
596 // output-at-previous-timestep, the other internal state array). It's currently
597 // a template parameter so that the model can control that. The most typical
598 // value for StateIntegerBits is 4. Other plausible values are anywhere between
599 // 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
600 // and drop that template parameter. The reason why it can't be a runtime
601 // parameter is that this controls the fixed-point format used, i.e. we need to
602 // generate actually different code based on it. In particular, we generate code
603 // for a fixed-point tanh() implementation for that format, which internally
604 // uses a fixed-point exp() implementation, which internally uses a
605 // barrel-shifter with a number of steps that depends on StateIntegerBits.
606 // Another consequence of that is that a higher value of StateIntegerBits
607 // results in a more expensive implementation (more barrel shifter steps
608 // needed).
609 //
610 //
611 // === Why [-8, 8] for fully-connected output? ===
612 //
613 // This array is only fed to Logistic and Tanh functions, for which
614 // the quantized implementation will want to use fixed-point arithmetic,
615 // requiring a power-of-two representation interval. Thus, we should right
616 // away quantize this array to a power-of-two interval; otherwise,
617 // implementation will need to rescale that, losing any benefit that a tighter
618 // representation interval might otherwise yield, while introducing some
619 // numerical error and computational overhead.
620 //
621 // Now, Logistic and Tanh
622 // are nearly constant (nearly equal to their horizontal asymptotes)
623 // outside of a small bounded interval around 0:
624 //
625 //   Logistic(4) = 1 - 1.8e-2     Tanh(4) = 1 - 6.7e-4
626 //   Logistic(8) = 1 - 3.4e-4     Tanh(8) = 1 - 2.3e-7
627 //   Logistic(16) = 1 - 1.1e-7    Tanh(16) = 1 - 2.5e-14
628 //
629 // From this, we see that clamping to [-4, 4] would be too inaccurate
630 // (the error of 1.8e-2 on Logistic would be felt even in 8bit precision)
631 // while clamping to [-16, 16] would make no difference even in float32.
632 // However, for a fixed-point implementation in 16-bit integers, using 5
633 // integer bits to represent the [-16, 16] range would leave only 11
634 // fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
635 // representable values. Notice that is higher than the
636 // worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
637 // Using [-8, 8] thus seems like the better compromise overall, enjoying
638 // an increment of 2.4e-4 between representable values and a worst-case
639 // clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
640 // [-16, 16].
641 //
642 // Moreover, all other things being equal, it is nice to choose the narrower
643 // representation range, as that makes the implementation of fixed-point
644 // math functions a little cheaper (each integer bit requires an additional
645 // barrel-shifter atep in the implementation of exp(-x)). That is further
646 // reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
647 // sense for 32-bit float or 32-bit fixed-point quantization, but we are
648 // aiming for 16-bit fixed-point quantization of these internal nodes here.
649 //
650 template <int StateIntegerBits>
651 inline void
652 LstmCell(const LstmCellParams &params, const RuntimeShape &unextended_input_shape,
653          const uint8 *input_data_uint8, const RuntimeShape &unextended_prev_activ_shape,
654          const uint8 *prev_activ_data_uint8, const RuntimeShape &weights_shape,
655          const uint8 *weights_data_uint8, const RuntimeShape &unextended_bias_shape,
656          const int32 *bias_data_int32, const RuntimeShape &unextended_prev_state_shape,
657          const int16 *prev_state_data_int16, const RuntimeShape &unextended_output_state_shape,
658          int16 *output_state_data_int16, const RuntimeShape &unextended_output_activ_shape,
659          uint8 *output_activ_data_uint8, const RuntimeShape &unextended_concat_temp_shape,
660          uint8 *concat_temp_data_uint8, const RuntimeShape &unextended_activ_temp_shape,
661          int16 *activ_temp_data_int16, void *gemmlowp_context)
662 {
663   (void)gemmlowp_context; // only used in optimized code.
664   int32 weights_zero_point = params.weights_zero_point;
665   int32 accum_multiplier = params.accum_multiplier;
666   int accum_shift = params.accum_shift;
667   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
668   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
669   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
670   TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
671   TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
672   TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
673   TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
674   TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
675   const RuntimeShape input_shape = RuntimeShape::ExtendedShape(4, unextended_input_shape);
676   const RuntimeShape prev_activ_shape = RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
677   const RuntimeShape bias_shape = RuntimeShape::ExtendedShape(4, unextended_bias_shape);
678   const RuntimeShape prev_state_shape = RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
679   const RuntimeShape output_state_shape =
680     RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
681   const RuntimeShape output_activ_shape =
682     RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
683   const RuntimeShape concat_temp_shape =
684     RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
685   const RuntimeShape activ_temp_shape = RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
686   TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
687
688   // Gather dimensions information, and perform consistency checks.
689   const int weights_dim_count = weights_shape.DimensionsCount();
690   const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, prev_activ_shape, prev_state_shape,
691                                                  output_state_shape, output_activ_shape);
692   const int input_depth = input_shape.Dims(3);
693   const int prev_activ_depth = prev_activ_shape.Dims(3);
694   const int total_input_depth = prev_activ_depth + input_depth;
695   TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), total_input_depth);
696   const int intern_activ_depth = MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
697   TFLITE_DCHECK_EQ(weights_shape.FlatSize(), intern_activ_depth * total_input_depth);
698   TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
699   TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
700   const int output_depth = MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
701                                        3, output_activ_shape, 3);
702   TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
703   const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
704   const int fc_output_depth =
705     MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
706   const int fc_accum_depth = total_input_depth;
707   TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
708
709   // Depth-concatenate prev_activ and input data together.
710   uint8 const *concat_input_arrays_data[2] = {input_data_uint8, prev_activ_data_uint8};
711   const RuntimeShape *concat_input_arrays_shapes[2] = {&input_shape, &prev_activ_shape};
712   tflite::ConcatenationParams concat_params;
713   concat_params.axis = 3;
714   concat_params.inputs_count = 2;
715   Concatenation(concat_params, concat_input_arrays_shapes, concat_input_arrays_data,
716                 concat_temp_shape, concat_temp_data_uint8);
717
718   // Implementation of the fully connected node inside the LSTM cell.
719   // The operands are 8-bit integers, the accumulators are internally 32bit
720   // integers, and the output is 16-bit fixed-point with 3 integer bits so
721   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
722   // is explained in the function comment above.
723   for (int b = 0; b < fc_batches; ++b)
724   {
725     for (int out_c = 0; out_c < fc_output_depth; ++out_c)
726     {
727       // Internal accumulation.
728       // Initialize accumulator with the bias-value.
729       int32 accum = bias_data_int32[out_c];
730       // Accumulation loop.
731       for (int d = 0; d < fc_accum_depth; ++d)
732       {
733         int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
734         int16 weights_val = weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
735         accum += input_val * weights_val;
736       }
737       // Down-scale the final int32 accumulator to the scale used by our
738       // (16-bit, using 3 integer bits) fixed-point format. The quantized
739       // multiplier and shift here have been pre-computed offline
740       // (e.g. by toco).
741       accum = MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
742       // Saturate, cast to int16, and store to the temporary activations array.
743       accum = std::max(-32768, std::min(32767, static_cast<int>(accum)));
744       activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
745     }
746   }
747
748   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
749   // and muls, all done in 16-bit fixed-point.
750   for (int b = 0; b < outer_size; ++b)
751   {
752     for (int c = 0; c < output_depth; ++c)
753     {
754       // Define the fixed-point data types that we will use here. All use
755       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
756       // They only differ by the number of integral vs. fractional bits,
757       // determining the range of values that they can represent.
758       //
759       // F0 uses 0 integer bits, range [-1, 1].
760       // This is the return type of math functions such as tanh, logistic,
761       // whose range is in [-1, 1].
762       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
763       // F3 uses 3 integer bits, range [-8, 8].
764       // This is the range of the previous fully-connected node's output,
765       // which is our input here.
766       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
767       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
768       // 2^StateIntegerBits]. It's used to represent the internal state, whose
769       // number of integer bits is currently dictated by the model. See comment
770       // on the StateIntegerBits template parameter above.
771       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
772       // Implementation of input gate, using fixed-point logistic function.
773       F3 input_gate_input =
774         F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
775       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
776       // Implementation of input modulation gate, using fixed-point tanh
777       // function.
778       F3 input_modulation_gate_input =
779         F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
780       F0 input_modulation_gate_output = gemmlowp::tanh(input_modulation_gate_input);
781       // Implementation of forget gate, using fixed-point logistic function.
782       F3 forget_gate_input =
783         F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
784       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
785       // Implementation of output gate, using fixed-point logistic function.
786       F3 output_gate_input =
787         F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
788       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
789       // Implementation of internal multiplication nodes, still in fixed-point.
790       F0 input_times_input_modulation = input_gate_output * input_modulation_gate_output;
791       FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
792       FS prev_state_times_forget_state = forget_gate_output * prev_state;
793       // Implementation of internal addition node, saturating.
794       FS new_state =
795         gemmlowp::SaturatingAdd(gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
796                                 prev_state_times_forget_state);
797       // Implementation of last internal Tanh node, still in fixed-point.
798       // Since a Tanh fixed-point implementation is specialized for a given
799       // number or integer bits, and each specialization can have a substantial
800       // code size, and we already used above a Tanh on an input with 3 integer
801       // bits, and per the table in the above function comment there is no
802       // significant accuracy to be lost by clamping to [-8, +8] for a
803       // 3-integer-bits representation, let us just do that. This helps people
804       // porting this to targets where code footprint must be minimized.
805       F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
806       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
807       // Store the new internal state back to memory, as 16-bit integers.
808       // Note: here we store the original value with StateIntegerBits, not
809       // the rescaled 3-integer-bits value fed to tanh.
810       output_state_data_int16[b * output_depth + c] = new_state.raw();
811       // Down-scale the output activations to 8-bit integers, saturating,
812       // and store back to memory.
813       int16 rescaled_output_activ = gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
814       int16 clamped_output_activ =
815         std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
816       output_activ_data_uint8[b * output_depth + c] = 128 + clamped_output_activ;
817     }
818   }
819 }
820
821 template <typename Scalar>
822 void Split(const SplitParams &params, const RuntimeShape &input_shape, const Scalar *input_data,
823            const RuntimeShape *const *output_shapes, Scalar *const *output_data)
824 {
825   ruy::profiler::ScopeLabel label("Split");
826   const int split_dimensions = input_shape.DimensionsCount();
827   int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
828   int outputs_count = params.num_split;
829   TFLITE_DCHECK_LT(axis, split_dimensions);
830
831   int64_t split_size = 0;
832   for (int i = 0; i < outputs_count; i++)
833   {
834     TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
835     for (int j = 0; j < split_dimensions; j++)
836     {
837       if (j != axis)
838       {
839         MatchingDim(*output_shapes[i], j, input_shape, j);
840       }
841     }
842     split_size += output_shapes[i]->Dims(axis);
843   }
844   TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
845   int64_t outer_size = 1;
846   for (int i = 0; i < axis; ++i)
847   {
848     outer_size *= input_shape.Dims(i);
849   }
850   // For all output arrays,
851   // FlatSize() = outer_size * Dims(axis) * base_inner_size;
852   int64_t base_inner_size = 1;
853   for (int i = axis + 1; i < split_dimensions; ++i)
854   {
855     base_inner_size *= input_shape.Dims(i);
856   }
857
858   const Scalar *input_ptr = input_data;
859   for (int k = 0; k < outer_size; k++)
860   {
861     for (int i = 0; i < outputs_count; ++i)
862     {
863       const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
864       memcpy(output_data[i] + k * copy_size, input_ptr, copy_size * sizeof(Scalar));
865       input_ptr += copy_size;
866     }
867   }
868 }
869
870 inline int NodeOffset(int b, int h, int w, int height, int width)
871 {
872   return (b * height + h) * width + w;
873 }
874
875 inline void LocalResponseNormalization(const tflite::LocalResponseNormalizationParams &op_params,
876                                        const RuntimeShape &input_shape, const float *input_data,
877                                        const RuntimeShape &output_shape, float *output_data)
878 {
879   const int trailing_dim = input_shape.DimensionsCount() - 1;
880   const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
881   const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
882
883   for (int i = 0; i < outer_size; ++i)
884   {
885     for (int c = 0; c < depth; ++c)
886     {
887       const int begin_input_c = std::max(0, static_cast<int>(c - op_params.range));
888       const int end_input_c = std::min(depth, static_cast<int>(c + op_params.range));
889       float accum = 0.f;
890       for (int input_c = begin_input_c; input_c < end_input_c; ++input_c)
891       {
892         const float input_val = input_data[i * depth + input_c];
893         accum += input_val * input_val;
894       }
895       const float multiplier = std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
896       output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
897     }
898   }
899 }
900
901 inline void Dequantize(const RuntimeShape &input_shape, const Eigen::half *input_data,
902                        const RuntimeShape &output_shape, float *output_data)
903 {
904   const int flat_size = MatchingFlatSize(input_shape, output_shape);
905   for (int i = 0; i < flat_size; i++)
906   {
907     output_data[i] = static_cast<float>(input_data[i]);
908   }
909 }
910
911 inline void FakeQuant(const tflite::FakeQuantParams &op_params, const RuntimeShape &input_shape,
912                       const float *input_data, const RuntimeShape &output_shape, float *output_data)
913 {
914   ruy::profiler::ScopeLabel label("FakeQuant");
915   float rmin = op_params.minmax.min;
916   float rmax = op_params.minmax.max;
917   int num_bits = op_params.num_bits;
918   // 0 should always be a representable value. Let's assume that the initial
919   // min,max range contains 0.
920   TFLITE_DCHECK_LE(rmin, 0.0f);
921   TFLITE_DCHECK_GE(rmax, 0.0f);
922   TFLITE_DCHECK_LT(rmin, rmax);
923
924   // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
925   int quant_min = 0;
926   int quant_max = (1 << num_bits) - 1;
927   float nudged_min, nudged_max, nudged_scale;
928   NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale);
929   const int flat_size = MatchingFlatSize(input_shape, output_shape);
930   FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, output_data, flat_size);
931 }
932
933 // Common subroutine for both `GatherNd` and `GatherNdString`.
934 struct GatherNdHelperResult
935 {
936   int n_slices;
937   int slice_size;
938   int indices_nd;
939   std::vector<int> dims_to_count;
940 };
941
942 // Returns common values being used on both `GatherNd` and `GatherNdString`.
943 inline GatherNdHelperResult GatherNdHelper(const RuntimeShape &params_shape,
944                                            const RuntimeShape &indices_shape)
945 {
946   GatherNdHelperResult ret;
947   ret.n_slices = 1;
948   ret.slice_size = 1;
949   const int indices_dims = indices_shape.DimensionsCount();
950   ret.indices_nd = indices_shape.Dims(indices_dims - 1);
951   const int params_dims = params_shape.DimensionsCount();
952   for (int i = 0; i < indices_dims - 1; ++i)
953   {
954     ret.n_slices *= indices_shape.Dims(i);
955   }
956   for (int i = ret.indices_nd; i < params_dims; ++i)
957   {
958     ret.slice_size *= params_shape.Dims(i);
959   }
960
961   int remain_flat_size = params_shape.FlatSize();
962   ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
963   for (int i = 0; i < ret.indices_nd; ++i)
964   {
965     ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
966     remain_flat_size = ret.dims_to_count[i];
967   }
968
969   return ret;
970 }
971
972 template <typename ParamsT, typename IndicesT = int32>
973 inline void GatherNd(const RuntimeShape &params_shape, const ParamsT *params_data,
974                      const RuntimeShape &indices_shape, const IndicesT *indices_data,
975                      const RuntimeShape &output_shape, ParamsT *output_data)
976 {
977   ruy::profiler::ScopeLabel label("GatherNd");
978
979   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
980   for (int i = 0; i < res.n_slices; ++i)
981   {
982     int from_pos = 0;
983     for (int j = 0; j < res.indices_nd; ++j)
984     {
985       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
986     }
987     std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
988                 sizeof(ParamsT) * res.slice_size);
989   }
990 }
991
992 #ifndef TF_LITE_STATIC_MEMORY
993 template <typename IndicesT = int32>
994 inline void GatherNdString(const RuntimeShape &params_shape, const TfLiteTensor *params_data,
995                            const RuntimeShape &indices_shape, const IndicesT *indices_data,
996                            const RuntimeShape &output_shape, TfLiteTensor *output_data)
997 {
998   ruy::profiler::ScopeLabel label("GatherNdString");
999
1000   const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1001   DynamicBuffer buffer;
1002   for (int i = 0; i < res.n_slices; ++i)
1003   {
1004     int from_pos = 0;
1005     for (int j = 0; j < res.indices_nd; ++j)
1006     {
1007       from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1008     }
1009     for (int j = 0; j < res.slice_size; ++j)
1010     {
1011       buffer.AddString(GetString(params_data, from_pos + j));
1012     }
1013   }
1014   buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
1015 }
1016 #endif
1017
1018 template <typename IndicesT, typename UpdatesT>
1019 inline void ScatterNd(const RuntimeShape &indices_shape, const IndicesT *indices_data,
1020                       const RuntimeShape &updates_shape, const UpdatesT *updates_data,
1021                       const RuntimeShape &output_shape, UpdatesT *output_data)
1022 {
1023   ruy::profiler::ScopeLabel label("ScatterNd");
1024
1025   int n_slices = 1;
1026   int slice_size = 1;
1027   const int outer_dims = indices_shape.DimensionsCount() - 1;
1028   const int indices_nd = indices_shape.Dims(outer_dims);
1029   const int updates_dims = updates_shape.DimensionsCount();
1030   for (int i = 0; i < outer_dims; ++i)
1031   {
1032     n_slices *= indices_shape.Dims(i);
1033   }
1034   for (int i = outer_dims; i < updates_dims; ++i)
1035   {
1036     slice_size *= updates_shape.Dims(i);
1037   }
1038
1039   int output_flat_size = output_shape.FlatSize();
1040   int remain_flat_size = output_flat_size;
1041   std::vector<int> dims_to_count(indices_nd, 0);
1042   for (int i = 0; i < indices_nd; ++i)
1043   {
1044     dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
1045     remain_flat_size = dims_to_count[i];
1046   }
1047
1048   memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
1049   for (int i = 0; i < n_slices; ++i)
1050   {
1051     int to_pos = 0;
1052     for (int j = 0; j < indices_nd; ++j)
1053     {
1054       IndicesT idx = indices_data[i * indices_nd + j];
1055       TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
1056       to_pos += idx * dims_to_count[j];
1057     }
1058     for (int j = 0; j < slice_size; j++)
1059     {
1060       output_data[to_pos + j] += updates_data[i * slice_size + j];
1061     }
1062   }
1063 }
1064
1065 template <typename T>
1066 inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1067                   const RuntimeShape &output_shape, SequentialTensorWriter<T> *writer)
1068 {
1069   const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
1070   TFLITE_DCHECK_LE(op_params.begin_count, 5);
1071   TFLITE_DCHECK_LE(op_params.size_count, 5);
1072   const int begin_count = op_params.begin_count;
1073   const int size_count = op_params.size_count;
1074   // We front-pad the begin and size vectors.
1075   std::array<int, 5> start;
1076   std::array<int, 5> stop;
1077   for (int i = 0; i < 5; ++i)
1078   {
1079     int padded_i = 5 - i;
1080     start[i] = begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
1081     stop[i] = (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
1082                 ? ext_shape.Dims(i)
1083                 : start[i] + op_params.size[size_count - padded_i];
1084   }
1085
1086   for (int i0 = start[0]; i0 < stop[0]; ++i0)
1087   {
1088     for (int i1 = start[1]; i1 < stop[1]; ++i1)
1089     {
1090       for (int i2 = start[2]; i2 < stop[2]; ++i2)
1091       {
1092         for (int i3 = start[3]; i3 < stop[3]; ++i3)
1093         {
1094           for (int i4 = start[4]; i4 < stop[4]; ++i4)
1095           {
1096             writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
1097           }
1098         }
1099       }
1100     }
1101   }
1102 }
1103
1104 template <typename T>
1105 inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1106                   const T *input_data, const RuntimeShape &output_shape, T *output_data)
1107 {
1108   SequentialTensorWriter<T> writer(input_data, output_data);
1109   return Slice(op_params, input_shape, output_shape, &writer);
1110 }
1111
1112 template <typename T>
1113 inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1114                   const TfLiteTensor *input, const RuntimeShape &output_shape, TfLiteTensor *output)
1115 {
1116   SequentialTensorWriter<T> writer(input, output);
1117   return Slice(op_params, input_shape, output_shape, &writer);
1118 }
1119
1120 template <typename T>
1121 void Minimum(const RuntimeShape &input1_shape, const T *input1_data, const T *input2_data,
1122              const RuntimeShape &output_shape, T *output_data)
1123 {
1124   const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1125
1126   auto min_value = input2_data[0];
1127   for (int i = 0; i < flat_size; i++)
1128   {
1129     output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
1130   }
1131 }
1132
1133 // Convenience version that allows, for example, generated-code calls to be
1134 // the same as other binary ops.
1135 template <typename T>
1136 inline void Minimum(const RuntimeShape &input1_shape, const T *input1_data, const RuntimeShape &,
1137                     const T *input2_data, const RuntimeShape &output_shape, T *output_data)
1138 {
1139   // Drop shape of second input: not needed.
1140   Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
1141 }
1142
1143 template <typename T>
1144 void Maximum(const RuntimeShape &input1_shape, const T *input1_data, const T *input2_data,
1145              const RuntimeShape &output_shape, T *output_data)
1146 {
1147   const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1148
1149   auto max_value = input2_data[0];
1150   for (int i = 0; i < flat_size; i++)
1151   {
1152     output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
1153   }
1154 }
1155
1156 // Convenience version that allows, for example, generated-code calls to be
1157 // the same as other binary ops.
1158 template <typename T>
1159 inline void Maximum(const RuntimeShape &input1_shape, const T *input1_data, const RuntimeShape &,
1160                     const T *input2_data, const RuntimeShape &output_shape, T *output_data)
1161 {
1162   // Drop shape of second input: not needed.
1163   Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
1164 }
1165
1166 template <typename T1, typename T2, typename T3>
1167 void ArgMax(const RuntimeShape &input1_shape, const T1 *input1_data, const T3 *input2_data,
1168             const RuntimeShape &output_shape, T2 *output_data)
1169 {
1170   ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data, std::greater<T1>());
1171 }
1172
1173 // Convenience version that allows, for example, generated-code calls to be
1174 // the same as other binary ops.
1175 template <typename T1, typename T2, typename T3>
1176 inline void ArgMax(const RuntimeShape &input1_shape, const T1 *input1_data,
1177                    const RuntimeShape &input2_shape, const T3 *input2_data,
1178                    const RuntimeShape &output_shape, T2 *output_data)
1179 {
1180   // Drop shape of second input: not needed.
1181   ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
1182 }
1183
1184 template <typename D, typename T>
1185 void Select(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1186             const RuntimeShape &input_x_shape, const T *input_x_data,
1187             const RuntimeShape &input_y_shape, const T *input_y_data,
1188             const RuntimeShape &output_shape, T *output_data)
1189 {
1190   int64_t flatsize;
1191   // Allow select operator executions on mixed scalar tensors and one element
1192   // tensors.
1193   if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
1194       input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1)
1195   {
1196     flatsize = 1;
1197   }
1198   else
1199   {
1200     flatsize = MatchingFlatSize(input_condition_shape, input_x_shape, input_y_shape, output_shape);
1201   }
1202   for (int64_t i = 0; i < flatsize; ++i)
1203   {
1204     output_data[i] = input_condition_data[i] ? input_x_data[i] : input_y_data[i];
1205   }
1206 }
1207
1208 template <typename D, typename T>
1209 void RankOneSelect(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1210                    const RuntimeShape &input_x_shape, const T *input_x_data,
1211                    const RuntimeShape &input_y_shape, const T *input_y_data,
1212                    const RuntimeShape &output_shape, T *output_data)
1213 {
1214   const int64_t outer_size = input_condition_shape.FlatSize();
1215   int64_t inner_size;
1216   if (input_condition_shape.DimensionsCount() == 0)
1217   {
1218     inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
1219   }
1220   else
1221   {
1222     TFLITE_DCHECK_EQ(MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0), outer_size);
1223     inner_size = MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
1224   }
1225
1226   int64_t offset = 0;
1227   for (int64_t i = 0; i < outer_size; i++)
1228   {
1229     const T *input_data = input_condition_data[i] ? input_x_data : input_y_data;
1230     memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
1231     offset += inner_size;
1232   }
1233 }
1234
1235 template <typename D, typename T>
1236 void BroadcastSelect4DSlow(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1237                            const RuntimeShape &input_x_shape, const T *input_x_data,
1238                            const RuntimeShape &input_y_shape, const T *input_y_data,
1239                            const RuntimeShape &output_shape, T *output_data)
1240 {
1241   TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
1242   TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
1243   TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
1244   TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
1245
1246   const RuntimeShape extended_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
1247
1248   NdArrayDesc<4> desc_condition;
1249   NdArrayDesc<4> desc_x;
1250   NdArrayDesc<4> desc_y;
1251   NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape, input_y_shape,
1252                                       &desc_condition, &desc_x, &desc_y);
1253
1254   // In Tensorflow, the dimensions are canonically named (batch_number, row,
1255   // col, channel), with extents (batches, height, width, depth), with the
1256   // trailing dimension changing most rapidly (channels has the smallest
1257   // stride, typically 1 element).
1258   //
1259   // In generated C code, we store arrays with the dimensions reversed. The
1260   // first dimension has smallest stride.
1261   //
1262   // We name our variables by their Tensorflow convention, but generate C code
1263   // nesting loops such that the innermost loop has the smallest stride for
1264   // the best cache behavior.
1265   for (int b = 0; b < extended_output_shape.Dims(0); ++b)
1266   {
1267     for (int y = 0; y < extended_output_shape.Dims(1); ++y)
1268     {
1269       for (int x = 0; x < extended_output_shape.Dims(2); ++x)
1270       {
1271         for (int c = 0; c < extended_output_shape.Dims(3); ++c)
1272         {
1273           const int condition_index = SubscriptToIndex(desc_condition, b, y, x, c);
1274           const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
1275           const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
1276           output_data[Offset(extended_output_shape, b, y, x, c)] =
1277             input_condition_data[condition_index] ? input_x_data[x_index] : input_y_data[y_index];
1278         }
1279       }
1280     }
1281   }
1282 }
1283
1284 template <typename D, typename T>
1285 void SelectTrueCoords(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1286                       T *output_data)
1287 {
1288   const size_t size = input_condition_shape.FlatSize();
1289   if (size == 0)
1290   {
1291     // Dimension is zero, in which case we don't need to output.
1292     return;
1293   }
1294   const size_t cond_rank = input_condition_shape.DimensionsCount();
1295
1296   std::vector<int> dims_to_count(cond_rank, 0);
1297   int cur_flat_size = size;
1298   for (int i = 0; i < cond_rank; ++i)
1299   {
1300     dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
1301     cur_flat_size = dims_to_count[i];
1302   }
1303
1304   int output_index = 0;
1305   for (int i = 0; i < size; ++i)
1306   {
1307     if (input_condition_data[i])
1308     {
1309       // Insert the coordinate of the current item (row major) into output.
1310       int flat_index = i;
1311       for (int j = 0; j < cond_rank; ++j)
1312       {
1313         int coord_j = flat_index / dims_to_count[j];
1314         output_data[output_index * cond_rank + j] = coord_j;
1315         flat_index %= dims_to_count[j];
1316       }
1317       output_index++;
1318     }
1319   }
1320 }
1321
1322 // For easy implementation, the indices is always a vector of size-4 vectors.
1323 template <typename T, typename TI>
1324 inline void SparseToDense(const std::vector<std::vector<TI>> &indices, const T *values,
1325                           T default_value, bool value_is_scalar,
1326                           const RuntimeShape &unextended_output_shape, T *output_data)
1327 {
1328   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1329   const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape);
1330   const int value_count = indices.size();
1331
1332   // First fill the output_data with default value.
1333   const int num_elements = output_shape.FlatSize();
1334   for (int i = 0; i < num_elements; ++i)
1335   {
1336     output_data[i] = default_value;
1337   }
1338
1339   // Special handle for value is scalar case to avoid checking the boolean
1340   // condition within the loop every time.
1341   if (value_is_scalar)
1342   {
1343     for (int i = 0; i < value_count; ++i)
1344     {
1345       const std::vector<TI> &index = indices[i];
1346       TFLITE_DCHECK_EQ(index.size(), 4);
1347       const T value = *values; // just use the first value.
1348       output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] = value;
1349     }
1350     return;
1351   }
1352
1353   // Go through the values and indices to fill the sparse values.
1354   for (int i = 0; i < value_count; ++i)
1355   {
1356     const std::vector<TI> &index = indices[i];
1357     TFLITE_DCHECK_EQ(index.size(), 4);
1358     const T value = values[i];
1359     output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] = value;
1360   }
1361 }
1362
1363 template <typename T>
1364 inline void Pow(const RuntimeShape &input1_shape, const T *input1_data,
1365                 const RuntimeShape &input2_shape, const T *input2_data,
1366                 const RuntimeShape &output_shape, T *output_data)
1367 {
1368   const int flat_size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
1369   for (int i = 0; i < flat_size; ++i)
1370   {
1371     output_data[i] = std::pow(input1_data[i], input2_data[i]);
1372   }
1373 }
1374
1375 template <typename T>
1376 inline void BroadcastPow4DSlow(const RuntimeShape &unextended_input1_shape, const T *input1_data,
1377                                const RuntimeShape &unextended_input2_shape, const T *input2_data,
1378                                const RuntimeShape &unextended_output_shape, T *output_data)
1379 {
1380   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
1381   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
1382   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1383   const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape);
1384
1385   NdArrayDesc<4> desc1;
1386   NdArrayDesc<4> desc2;
1387   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
1388                                       &desc2);
1389
1390   for (int b = 0; b < output_shape.Dims(0); ++b)
1391   {
1392     for (int y = 0; y < output_shape.Dims(1); ++y)
1393     {
1394       for (int x = 0; x < output_shape.Dims(2); ++x)
1395       {
1396         for (int c = 0; c < output_shape.Dims(3); ++c)
1397         {
1398           auto out_idx = Offset(output_shape, b, y, x, c);
1399           auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
1400           auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
1401           auto in1_val = input1_data[in1_idx];
1402           auto in2_val = input2_data[in2_idx];
1403           output_data[out_idx] = std::pow(in1_val, in2_val);
1404         }
1405       }
1406     }
1407   }
1408 }
1409
1410 template <typename Scalar>
1411 void Reverse(int axis, const RuntimeShape &input_shape, const Scalar *input_data,
1412              const RuntimeShape &output_shape, Scalar *output_data)
1413 {
1414   ruy::profiler::ScopeLabel label("Reverse");
1415
1416   int outer_size = 1;
1417   for (int i = 0; i < axis; ++i)
1418   {
1419     outer_size *= input_shape.Dims(i);
1420   }
1421
1422   int copy_size = 1;
1423   for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i)
1424   {
1425     copy_size *= input_shape.Dims(i);
1426   }
1427
1428   const int dims_at_axis = input_shape.Dims(axis);
1429   for (int i = 0; i < outer_size; ++i)
1430   {
1431     for (int j = 0; j < dims_at_axis; ++j)
1432     {
1433       const int start_pos = (i * dims_at_axis + j) * copy_size;
1434       Scalar *output_ptr = output_data + start_pos;
1435       int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
1436       memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
1437     }
1438   }
1439 }
1440
1441 template <typename Scalar, typename TS>
1442 void ReverseSequence(const TS *seq_lengths, const int seq_dim, const int batch_dim,
1443                      const RuntimeShape &input_shape, const Scalar *input_data,
1444                      const RuntimeShape &output_shape, Scalar *output_data)
1445 {
1446   ruy::profiler::ScopeLabel label("ReverseSequence");
1447
1448   int outer_size = 1;
1449   int outer_dim = std::min(batch_dim, seq_dim);
1450   int medium_dim = std::max(batch_dim, seq_dim);
1451   for (int i = 0; i < outer_dim; ++i)
1452   {
1453     outer_size *= input_shape.Dims(i);
1454   }
1455
1456   int medium_size = 1;
1457   for (int i = outer_dim + 1; i < medium_dim; ++i)
1458   {
1459     medium_size *= input_shape.Dims(i);
1460   }
1461
1462   int copy_size = 1;
1463   for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i)
1464   {
1465     copy_size *= input_shape.Dims(i);
1466   }
1467
1468   const int dims_at_outer_dim = input_shape.Dims(outer_dim);
1469   const int dims_at_medium_dim = input_shape.Dims(medium_dim);
1470
1471   Scalar *output_ptr;
1472   if (batch_dim > seq_dim)
1473   {
1474     for (int i = 0; i < outer_size; ++i)
1475     {
1476       for (int j = 0; j < dims_at_outer_dim; ++j)
1477       {
1478         const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1479         for (int p = 0; p < medium_size; ++p)
1480         {
1481           for (int q = 0; q < dims_at_medium_dim; ++q)
1482           {
1483             const int in_pos = ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1484             const Scalar *in_ptr = input_data + in_pos;
1485             int sl = seq_lengths[q] - 1;
1486             if (j > sl)
1487             {
1488               output_ptr = output_data + in_pos;
1489             }
1490             else
1491             {
1492               const int out_pos_base = (i * dims_at_outer_dim + sl - j) * medium_size;
1493               const int out_pos = ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1494               output_ptr = output_data + out_pos;
1495             }
1496             memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1497           }
1498         }
1499       }
1500     }
1501   }
1502   else if (batch_dim < seq_dim)
1503   {
1504     for (int i = 0; i < outer_size; ++i)
1505     {
1506       for (int j = 0; j < dims_at_outer_dim; ++j)
1507       {
1508         const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1509         int sl = seq_lengths[j] - 1;
1510         const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1511         for (int p = 0; p < medium_size; ++p)
1512         {
1513           for (int q = 0; q < dims_at_medium_dim; ++q)
1514           {
1515             const int in_pos = ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1516             const Scalar *in_ptr = input_data + in_pos;
1517             if (q > sl)
1518             {
1519               output_ptr = output_data + in_pos;
1520             }
1521             else
1522             {
1523               const int out_pos = ((out_pos_base + p) * dims_at_medium_dim + sl - q) * copy_size;
1524               output_ptr = output_data + out_pos;
1525             }
1526             memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1527           }
1528         }
1529       }
1530     }
1531   }
1532 }
1533
1534 template <typename T>
1535 inline void SegmentSum(const RuntimeShape &input_shape, const T *input_data,
1536                        const RuntimeShape &segment_ids_shape, const int32_t *segment_ids_data,
1537                        const RuntimeShape &output_shape, T *output_data)
1538 {
1539   const int segment_flat_size = MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
1540
1541   memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
1542
1543   for (int i = 0; i < input_shape.Dims(0); i++)
1544   {
1545     int output_index = segment_ids_data[i];
1546     for (int j = 0; j < segment_flat_size; ++j)
1547     {
1548       output_data[output_index * segment_flat_size + j] += input_data[i * segment_flat_size + j];
1549     }
1550   }
1551 }
1552
1553 } // namespace reference_ops
1554 } // namespace tflite
1555
1556 #endif // LUCI_INTERPRETER_PAL_REFERENCE_OPS_H