Employ array flat sizes more directly in optimized_ops, some places in reference_ops.h.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 16 May 2018 13:30:19 +0000 (06:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 13:33:30 +0000 (06:33 -0700)
PiperOrigin-RevId: 196819423

tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/internal/types.h

index c92ed68..3b59f24 100644 (file)
@@ -67,7 +67,7 @@ using VectorMap = typename std::conditional<
 
 template <typename Scalar, int N>
 VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
-  const int size = RequiredBufferSizeForDims(dims);
+  const int size = FlatSize(dims);
   return VectorMap<Scalar>(data, size, 1);
 }
 
@@ -249,8 +249,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
                                              float output_activation_max) {
 #ifdef USE_NEON
   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
-  const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
-  const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+  const int bias_size = FlatSize(bias_dims);
+  const int array_size = FlatSize(array_dims);
   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
   float* array_ptr = array_data;
   float* array_end_ptr = array_ptr + array_size;
@@ -300,8 +300,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
   }
 #else  // not NEON
   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
-  const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
-  const int array_size = array_dims.sizes[3] * array_dims.strides[3];
+  const int bias_size = FlatSize(bias_dims);
+  const int array_size = FlatSize(array_dims);
   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
   for (int array_offset = 0; array_offset < array_size;
        array_offset += bias_size) {
@@ -372,10 +372,8 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
   TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                       ArraySize(output_dims, 3),
-                   1);
-  const int input_size = input_dims.strides[3];
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
+  const int input_size = FlatSizeSkipDim(input_dims, 3);
   const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
   // This special fast path for quantized LSTM cells does not try to support
   // odd sizes that we haven't encountered in any LSTM cell, that would
@@ -558,10 +556,8 @@ inline void GEMVForLstmCellWithSymmetricRange(
   TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                       ArraySize(output_dims, 3),
-                   1);
-  const int input_size = input_dims.strides[3];
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
+  const int input_size = FlatSizeSkipDim(input_dims, 3);
   const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
   // This special fast path for quantized LSTM cells does not try to support
   // odd sizes that we haven't encountered in any LSTM cell, that would
@@ -894,10 +890,8 @@ inline void FullyConnectedAsGEMV(
   TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                       ArraySize(output_dims, 3),
-                   1);
-  const int input_size = input_dims.strides[3];
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
+  const int input_size = FlatSizeSkipDim(input_dims, 3);
   const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
   static constexpr int kPeel = 4;
   for (int k = 0; k < input_size; k += 64) {
@@ -1078,8 +1072,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
+  const int batches = FlatSizeSkipDim(output_dims, 0);
 #ifdef USE_NEON
   const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
   if (batches == 1 && !(output_size % 4)) {
@@ -1135,8 +1128,7 @@ inline void FullyConnected(
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
+  const int batches = FlatSizeSkipDim(output_dims, 0);
   const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
   const int accum_depth = ArraySize(filter_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
@@ -1551,8 +1543,7 @@ inline void ExperimentalShuffledFullyConnected(
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
+  const int batches = FlatSizeSkipDim(output_dims, 0);
   const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
   const int accum_depth = ArraySize(weights_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
@@ -1988,15 +1979,11 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
   }
 
   const int gemm_input_rows = gemm_input_dims->sizes[0];
-  const int gemm_input_cols = gemm_input_dims->sizes[1] *
-                              gemm_input_dims->sizes[2] *
-                              gemm_input_dims->sizes[3];
+  const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
   const int filter_rows = filter_dims.sizes[3];
-  const int filter_cols =
-      filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+  const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
   const int output_rows = output_dims.sizes[0];
-  const int output_cols =
-      output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+  const int output_cols = FlatSizeSkipDim(output_dims, 0);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
@@ -2150,14 +2137,11 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
                     Ac == FusedActivationFunctionType::kRelu1,
                 "");
   const int input_rows = input_dims.sizes[0];
-  const int input_cols =
-      input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3];
+  const int input_cols = FlatSizeSkipDim(input_dims, 0);
   const int filter_rows = filter_dims.sizes[3];
-  const int filter_cols =
-      filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+  const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
   const int output_rows = output_dims.sizes[0];
-  const int output_cols =
-      output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+  const int output_cols = FlatSizeSkipDim(output_dims, 0);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
   TFLITE_DCHECK_EQ(output_cols, input_cols);
   TFLITE_DCHECK_EQ(filter_cols, input_rows);
@@ -2221,27 +2205,15 @@ void NonGlobalBatchNormalization(
     const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height =
-      MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
-                        offset_dims, 2, output_dims, 2);
-  const int width =
-      MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
-                        offset_dims, 1, output_dims, 1);
-  const int depth =
-      MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
-                        offset_dims, 0, output_dims, 0);
+  const int inner_size = MatchingFlatSizeSkipDim(
+      input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
 
   for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
-              (input_data[Offset(input_dims, c, x, y, b)] -
-               mean_data[Offset(mean_dims, c, x, y, 0)]) *
-                  multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
-              offset_data[Offset(offset_dims, c, x, y, 0)]);
-        }
-      }
+    for (int i = 0; i < inner_size; ++i) {
+      *output_data = ActivationFunction<Ac>(
+          (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]);
+      ++output_data;
+      ++input_data;
     }
   }
 }
@@ -2256,24 +2228,17 @@ void GlobalBatchNormalization(const float* input_data,
                               const Dims<4>& offset_dims, float* output_data,
                               const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
   const int depth =
       MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
                         offset_dims, 0, output_dims, 0);
 
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
-              (input_data[Offset(input_dims, c, x, y, b)] -
-               mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
-                  multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
-              offset_data[Offset(offset_dims, c, 0, 0, 0)]);
-        }
-      }
+  for (int i = 0; i < outer_size; ++i) {
+    for (int c = 0; c < depth; ++c) {
+      *output_data = ActivationFunction<Ac>(
+          (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]);
+      ++output_data;
+      ++input_data;
     }
   }
 }
@@ -2290,44 +2255,26 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
                   float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          float val = input_data[Offset(input_dims, c, x, y, b)];
-          const float upper = 1;
-          const float lower = -1;
-          float clamped = val > upper ? upper : val < lower ? lower : val;
-          output_data[Offset(output_dims, c, x, y, b)] = clamped;
-        }
-      }
-    }
+  const int flat_size = MatchingFlatSize(input_dims, output_dims);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const float upper = 1;
+    const float lower = -1;
+    const float clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
   }
 }
 
 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
                   float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          float val = input_data[Offset(input_dims, c, x, y, b)];
-          const float upper = 6;
-          const float lower = 0;
-          float clamped = val > upper ? upper : val < lower ? lower : val;
-          output_data[Offset(output_dims, c, x, y, b)] = clamped;
-        }
-      }
-    }
+  const int flat_size = MatchingFlatSize(input_dims, output_dims);
+  for (int i = 0; i < flat_size; ++i) {
+    const float val = input_data[i];
+    const float upper = 6;
+    const float lower = 0;
+    const float clamped = val > upper ? upper : val < lower ? lower : val;
+    output_data[i] = clamped;
   }
 }
 
@@ -2336,24 +2283,19 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims,
                      float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("L2Normalization");
   static_assert(Ac == FusedActivationFunctionType::kNone, "");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        float squared_l2_norm = 0;
-        for (int c = 0; c < depth; ++c) {
-          float val = input_data[Offset(input_dims, c, x, y, b)];
-          squared_l2_norm += val * val;
-        }
-        float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] =
-              input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm;
-        }
-      }
+  for (int i = 0; i < outer_size; ++i) {
+    float squared_l2_norm = 0;
+    for (int c = 0; c < depth; ++c) {
+      const float val = input_data[depth * i + c];
+      squared_l2_norm += val * val;
+    }
+    const float l2_norm = std::sqrt(squared_l2_norm);
+    for (int c = 0; c < depth; ++c) {
+      *output_data = *input_data / l2_norm;
+      ++output_data;
+      ++input_data;
     }
   }
 }
@@ -2407,15 +2349,11 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
                             int32 input_zero_point, uint8* output_data,
                             const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(batches, 1);
-  TFLITE_DCHECK_EQ(height, 1);
-  TFLITE_DCHECK_EQ(width, 1);
+  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+  TFLITE_DCHECK_EQ(outer_size, 1);
   int32 square_l2_norm = 0;
   for (int i = 0; i < depth; i++) {
     int32 diff = input_data[i] - input_zero_point;
@@ -2441,20 +2379,12 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims,
                 float output_activation_min, float output_activation_max,
                 float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Add");
-  /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
-                                              output_dims, 3);
-  /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
-                                             output_dims, 2);
-  /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
-                                            output_dims, 1);
-  /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
-                                            output_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
 
   int i = 0;
-  const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+  const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
 #ifdef USE_NEON
   const auto activation_min = vdupq_n_f32(output_activation_min);
   const auto activation_max = vdupq_n_f32(output_activation_max);
@@ -2658,9 +2588,7 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
     TFLITE_DCHECK_EQ(output_activation_max, 32767);
   }
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
   TFLITE_DCHECK_GE(input1_shift, 0);
@@ -2696,10 +2624,10 @@ void Add(const int32* input1_data, const Dims<4>& input1_dims,
   auto output_map = MapAsVector(output_data, output_dims);
   if (AreSameDims(input1_dims, input2_dims)) {
     output_map.array() = input1_map.array() + input2_map.array();
-  } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+  } else if (FlatSize(input2_dims) == 1) {
     auto scalar = input2_data[0];
     output_map.array() = input1_map.array() + scalar;
-  } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+  } else if (FlatSize(input1_dims) == 1) {
     auto scalar = input1_data[0];
     output_map.array() = scalar + input2_map.array();
   } else {
@@ -2923,20 +2851,12 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
                 float output_activation_min, float output_activation_max,
                 float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Mul");
-  /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
-                                              output_dims, 3);
-  /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
-                                             output_dims, 2);
-  /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
-                                            output_dims, 1);
-  /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
-                                            output_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
 
   int i = 0;
-  const int size = input1_dims.sizes[3] * input1_dims.strides[3];
+  const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
 #ifdef USE_NEON
   const auto activation_min = vdupq_n_f32(output_activation_min);
   const auto activation_max = vdupq_n_f32(output_activation_max);
@@ -3011,10 +2931,10 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims,
   auto output_map = MapAsVector(output_data, output_dims);
   if (AreSameDims(input1_dims, input2_dims)) {
     output_map.array() = input1_map.array() * input2_map.array();
-  } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+  } else if (FlatSize(input2_dims) == 1) {
     auto scalar = input2_data[0];
     output_map.array() = input1_map.array() * scalar;
-  } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+  } else if (FlatSize(input1_dims) == 1) {
     auto scalar = input1_data[0];
     output_map.array() = scalar * input2_map.array();
   } else {
@@ -3030,9 +2950,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
   // This is a copy of the reference implementation. We do not currently have a
   // properly optimized version.
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -3054,9 +2972,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
   // properly optimized version.
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -3199,26 +3115,11 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims,
                 const float* input2_data, const Dims<4>& input2_dims,
                 float output_activation_min, float output_activation_max,
                 float* output_data, const Dims<4>& output_dims) {
-  const int batches =
-      MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
-  const int height =
-      MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
-  const int width =
-      MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
-  const int depth =
-      MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] =
-              ActivationFunctionWithMinMax(
-                  input1_data[Offset(input1_dims, c, x, y, b)] /
-                      input2_data[Offset(input2_dims, c, x, y, b)],
-                  output_activation_min, output_activation_max);
-        }
-      }
-    }
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+  for (int i = 0; i < flat_size; i++) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] / input2_data[i], output_activation_min,
+        output_activation_max);
   }
 }
 
@@ -3272,26 +3173,12 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
                 const float* input2_data, const Dims<4>& input2_dims,
                 float output_activation_min, float output_activation_max,
                 float* output_data, const Dims<4>& output_dims) {
-  const int batches =
-      MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
-  const int height =
-      MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
-  const int width =
-      MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
-  const int depth =
-      MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] =
-              ActivationFunctionWithMinMax(
-                  input1_data[Offset(input1_dims, c, x, y, b)] -
-                      input2_data[Offset(input2_dims, c, x, y, b)],
-                  output_activation_min, output_activation_max);
-        }
-      }
-    }
+  gemmlowp::ScopedProfilingLabel label("Sub");
+  const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+  for (int i = 0; i < flat_size; ++i) {
+    output_data[i] = ActivationFunctionWithMinMax(
+        input1_data[i] - input2_data[i], output_activation_min,
+        output_activation_max);
   }
 }
 
@@ -3600,15 +3487,9 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
   gemmlowp::ScopedProfilingLabel label(
       "LstmCell/quantized (8bit external, 16bit internal)");
   // Gather dimensions information, and perform consistency checks.
-  const int batches =
-      MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
-                        output_state_dims, 3, output_activ_dims, 3);
-  const int height =
-      MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
-                        output_state_dims, 2, output_activ_dims, 2);
-  const int width =
-      MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
-                        output_state_dims, 1, output_activ_dims, 1);
+  const int outer_size =
+      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
+                              output_state_dims, output_activ_dims);
   TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
   TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
   const int input_depth = ArraySize(input_dims, 0);
@@ -3624,9 +3505,7 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
       MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
                         output_state_dims, 0, output_activ_dims, 0);
   TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
-  const int fc_batches = ArraySize(activ_temp_dims, 1) *
-                         ArraySize(activ_temp_dims, 2) *
-                         ArraySize(activ_temp_dims, 3);
+  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
   const int fc_output_depth =
       MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
   const int fc_accum_depth = ArraySize(weights_dims, 0);
@@ -3682,7 +3561,6 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
 
   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
   // and muls, all done in 16-bit fixed-point.
-  const int outer_size = batches * width * height;
   const int16* input_gate_input_ptr = activ_temp_data_int16;
   const int16* input_modulation_gate_input_ptr =
       activ_temp_data_int16 + output_depth;
@@ -3848,20 +3726,15 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
   gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
   TFLITE_DCHECK_GE(outputs_count, 1);
   for (int i = 0; i < outputs_count; i++) {
-    /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
-    /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
-    /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+    MatchingFlatSizeSkipDim(*output_dims[i], 0, input_dims);
   }
-  const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
-  const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
-  const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
+  const int outer_size = FlatSizeSkipDim(input_dims, 0);
   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  // for now we dont have a model with a TensorFlowSplit
+  // For now we don't have a model with a TensorFlowSplit
   // with fused activation function.
   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
-  const int whb = width * height * batches;
   const Scalar* input_ptr = input_data;
-  for (int k = 0; k < whb; k++) {
+  for (int k = 0; k < outer_size; k++) {
     for (int i = 0; i < outputs_count; ++i) {
       memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
              output_dims[i]->sizes[0] * sizeof(Scalar));
@@ -4386,10 +4259,7 @@ inline void LocalResponseNormalization(const float* input_data,
                                        float* output_data,
                                        const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
-  /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
-  /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
-  /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
-  /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+  MatchingFlatSize(input_dims, output_dims);
 
   const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
@@ -4432,10 +4302,7 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
                     float beta, float* output_data,
                     const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Softmax");
-  /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
-  /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
-  /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
-  /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
+  MatchingFlatSize(input_dims, output_dims);
 
   const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
@@ -4467,13 +4334,9 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
   using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
 
   gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
 
-  const int outer_size = batches * height * width;
-
   for (int b = 0; b < outer_size; ++b) {
     const uint8* input_data_ptr = input_data + b * depth;
     uint8* output_data_ptr = output_data + b * depth;
@@ -4665,35 +4528,30 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
                        float* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("LogSoftmax");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
 
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        // Find max element value which we'll use to ensure numerical stability
-        // taking advantage of the following equality:
-        // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
-        float max = std::numeric_limits<float>::lowest();
-        for (int c = 0; c < depth; ++c) {
-          max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]);
-        }
+  for (int i = 0; i < outer_size; ++i) {
+    const float* block_input_data = input_data + i * depth;
+    float* block_output_data = output_data + i * depth;
+    // Find max element value which we'll use to ensure numerical stability
+    // taking advantage of the following equality:
+    // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
+    float max = std::numeric_limits<float>::lowest();
+    for (int c = 0; c < depth; ++c) {
+      max = std::max(max, block_input_data[c]);
+    }
 
-        // Compute sum.
-        float sum = 0.f;
-        for (int c = 0; c < depth; ++c) {
-          sum += std::exp(input_data[Offset(input_dims, c, x, y, b)] - max);
-        }
+    // Compute sum.
+    float sum = 0.f;
+    for (int c = 0; c < depth; ++c) {
+      sum += std::exp(block_input_data[c] - max);
+    }
 
-        // Compute result.
-        const float log_sum = std::log(sum);
-        for (int c = 0; c < depth; ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] =
-              input_data[Offset(input_dims, c, x, y, b)] - max - log_sum;
-        }
-      }
+    // Compute result.
+    const float log_sum = std::log(sum);
+    for (int c = 0; c < depth; ++c) {
+      block_output_data[c] = block_input_data[c] - max - log_sum;
     }
   }
 }
@@ -4722,15 +4580,16 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
 
   for (int i = 0; i < outer_size; ++i) {
+    const uint8* block_input_data = input_data + i * depth;
+    uint8* block_output_data = output_data + i * depth;
     uint8 max_in_row = 0;
     for (int c = 0; c < depth; ++c) {
-      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+      max_in_row = std::max(max_in_row, block_input_data[c]);
     }
 
     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
     for (int c = 0; c < depth; ++c) {
-      int32 input_diff =
-          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+      int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
       if (input_diff >= diff_min) {
         const int32 input_diff_rescaled =
             MultiplyByQuantizedMultiplierGreaterThanOne(
@@ -4764,8 +4623,7 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                      reverse_scaling_right_shift));
 
     for (int c = 0; c < depth; ++c) {
-      int32 input_diff =
-          static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+      int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
       if (input_diff > adjusted_diff_min) {
         const int32 input_diff_rescaled =
             MultiplyByQuantizedMultiplierGreaterThanOne(
@@ -4776,11 +4634,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
                 31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
             255;
 
-        output_data[i * depth + c] = static_cast<uint8>(
+        block_output_data[c] = static_cast<uint8>(
             std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
       } else {
         // Set output to smallest value.
-        output_data[i * depth + c] = 0;
+        block_output_data[c] = 0;
       }
     }
   }
@@ -4800,11 +4658,7 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
                      int32 input_multiplier, int input_left_shift,
                      uint8* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
-  /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
-  /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
-  /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
-  /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
-  const int size = RequiredBufferSizeForDims(input_dims);
+  const int size = MatchingFlatSize(input_dims, output_dims);
 
   int c = 0;
 #ifdef USE_NEON
@@ -4939,8 +4793,7 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
                      int16* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input_dims);
 
   for (int i = 0; i < flat_size; i++) {
   }
@@ -5011,11 +4864,7 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
                  uint8* output_data, const Dims<4>& output_dims) {
   // Note that this is almost the exact same code as in Logistic().
   gemmlowp::ScopedProfilingLabel label("Tanh");
-  /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
-  /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
-  /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
-  /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
-  const int size = RequiredBufferSizeForDims(input_dims);
+  const int size = MatchingFlatSize(input_dims, output_dims);
 
   int c = 0;
   int32_t output_zero_point = 128;
@@ -5165,8 +5014,7 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
   TFLITE_DCHECK_GE(input_left_shift, 0);
   TFLITE_DCHECK_LE(input_left_shift, 1);
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input_dims);
 
   int c = 0;
   const int16* input_data_ptr = input_data;
@@ -5261,20 +5109,11 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
                        int32 zero_point, double scale, float* output_data,
                        const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Dequantize");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          int32 val = input_data[Offset(input_dims, c, x, y, b)];
-          float result = static_cast<float>(scale * (val - zero_point));
-          output_data[Offset(output_dims, c, x, y, b)] = result;
-        }
-      }
-    }
+  const int flat_size = MatchingFlatSize(output_dims, input_dims);
+  for (int i = 0; i < flat_size; ++i) {
+    int32 val = input_data[i];
+    float result = static_cast<float>(scale * (val - zero_point));
+    output_data[i] = result;
   }
 }
 
@@ -5297,25 +5136,15 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
                          &nudged_max, &nudged_scale);
   const float inv_nudged_scale = 1.0f / nudged_scale;
 
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          const float src_val = input_data[Offset(input_dims, c, x, y, b)];
-          const float clamped =
-              std::min(nudged_max, std::max(nudged_min, src_val));
-          const float clamped_shifted = clamped - nudged_min;
-          const float dst_val =
-              TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
-              nudged_min;
-          output_data[Offset(output_dims, c, x, y, b)] = dst_val;
-        }
-      }
-    }
+  const int flat_size = MatchingFlatSize(output_dims, input_dims);
+  for (int i = 0; i < flat_size; ++i) {
+    const float src_val = input_data[i];
+    const float clamped = std::min(nudged_max, std::max(nudged_min, src_val));
+    const float clamped_shifted = clamped - nudged_min;
+    const float dst_val =
+        TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
+        nudged_min;
+    output_data[i] = dst_val;
   }
 }
 
@@ -6146,10 +5975,10 @@ void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
   auto output_map = MapAsVector(output_data, output_dims);
   if (AreSameDims(input1_dims, input2_dims)) {
     output_map.array() = input1_map.array() - input2_map.array();
-  } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
+  } else if (FlatSize(input1_dims) == 1) {
     auto scalar = input1_data[0];
     output_map.array() = scalar - input2_map.array();
-  } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
+  } else if (FlatSize(input2_dims) == 1) {
     auto scalar = input2_data[0];
     output_map.array() = input1_map.array() - scalar;
   } else {
@@ -6193,25 +6022,22 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
   // input dimensions here. We enforce the constraint that the last dimension
   // must always be 1.
   TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
   const int depth = ArraySize(input_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        auto max_value = input_data[Offset(input_dims, 0, x, y, b)];
-        int max_index = 0;
-        for (int d = 1; d < depth; ++d) {
-          const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)];
-          if (curr_value > max_value) {
-            max_value = curr_value;
-            max_index = d;
-          }
-        }
-        output_data[Offset(output_dims, 0, x, y, b)] = max_index;
+  for (int i = 0; i < outer_size; ++i) {
+    auto max_value = *input_data;
+    ++input_data;
+    int max_index = 0;
+    for (int d = 1; d < depth; ++d) {
+      const auto& curr_value = *input_data;
+      if (curr_value > max_value) {
+        max_value = curr_value;
+        max_index = d;
       }
+      ++input_data;
     }
+    *output_data = max_index;
+    ++output_data;
   }
 }
 
index 0dacbb2..a56fc06 100644 (file)
@@ -893,13 +893,9 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
                             int32 input_zero_point, uint8* output_data,
                             const Dims<4>& output_dims) {
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  TFLITE_DCHECK_EQ(batches, 1);
-  TFLITE_DCHECK_EQ(height, 1);
-  TFLITE_DCHECK_EQ(width, 1);
+  const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+  TFLITE_DCHECK_EQ(outer_size, 1);
   int32 square_l2_norm = 0;
   for (int i = 0; i < depth; i++) {
     int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point;
@@ -1021,9 +1017,7 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
     TFLITE_DCHECK_EQ(output_activation_max, 32767);
   }
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0);
   TFLITE_DCHECK_GE(input1_shift, 0);
@@ -1399,9 +1393,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
                 int16* output_data, const Dims<4>& output_dims) {
   gemmlowp::ScopedProfilingLabel label("Mul/Int16");
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -1421,9 +1413,7 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
   gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
-  const int flat_size = RequiredBufferSizeForDims(output_dims);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input1_dims), flat_size);
-  TFLITE_DCHECK_EQ(RequiredBufferSizeForDims(input2_dims), flat_size);
+  const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
 
   for (int i = 0; i < flat_size; i++) {
     // F0 uses 0 integer bits, range [-1, 1].
@@ -3529,7 +3519,7 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
   // computing their influence on the output, rather than looping through the
   // output elements in the typical "gather" access pattern of a conv. We
   // therefore must initialize the output array to zero.
-  for (int i = 0; i < RequiredBufferSizeForDims(output_dims); i++) {
+  for (int i = 0; i < FlatSize(output_dims); i++) {
     output_data[i] = 0.0f;
   }
 
@@ -3592,15 +3582,9 @@ template <typename T, ComparisonFn<T> F>
 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
                        const T* input2_data, const Dims<4>& input2_dims,
                        bool* output_data, const Dims<4>& output_dims) {
-  const int64_t batches =
-      MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
-  const int64_t height =
-      MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
-  const int64_t width =
-      MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
-  const int64_t depth =
-      MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
-  for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+  const int64_t flatsize =
+      MatchingFlatSize(input1_dims, input2_dims, output_dims);
+  for (int64_t i = 0; i < flatsize; ++i) {
     output_data[i] = F(input1_data[i], input2_data[i]);
   }
 }
@@ -3613,15 +3597,9 @@ inline void Comparison(int left_shift, const T* input1_data,
                        int32 input2_offset, int32 input2_multiplier,
                        int input2_shift, bool* output_data,
                        const Dims<4>& output_dims) {
-  const int64_t batches =
-      MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
-  const int64_t height =
-      MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
-  const int64_t width =
-      MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
-  const int64_t depth =
-      MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
-  for (int64_t i = 0; i < batches * height * width * depth; ++i) {
+  const int64_t flatsize =
+      MatchingFlatSize(input1_dims, input2_dims, output_dims);
+  for (int64_t i = 0; i < flatsize; ++i) {
     const int32 input1_val = input1_offset + input1_data[i];
     const int32 input2_val = input2_offset + input2_data[i];
     const int32 shifted_input1_val = input1_val * (1 << left_shift);
@@ -3749,19 +3727,9 @@ inline void Select(const D* input_condition_data,
                    const Dims<4>& input_x_dims, const T* input_y_data,
                    const Dims<4>& input_y_dims, T* output_data,
                    const Dims<4>& output_dims) {
-  const int64_t batches =
-      MatchingArraySize(input_condition_dims, 3, input_x_dims, 3, input_y_dims,
-                        3, output_dims, 3);
-  const int64_t height =
-      MatchingArraySize(input_condition_dims, 2, input_x_dims, 2, input_y_dims,
-                        2, output_dims, 2);
-  const int64_t width = MatchingArraySize(input_condition_dims, 1, input_x_dims,
-                                          1, input_y_dims, 1, output_dims, 1);
-  const int64_t depth = MatchingArraySize(input_condition_dims, 0, input_x_dims,
-                                          0, input_y_dims, 0, output_dims, 0);
-
-  const int64_t num_elements = batches * height * width * depth;
-  for (int64_t i = 0; i < num_elements; ++i) {
+  const int64_t flatsize =
+      MatchingFlatSize(input_x_dims, input_y_dims, output_dims);
+  for (int64_t i = 0; i < flatsize; ++i) {
     output_data[i] =
         input_condition_data[i] ? input_x_data[i] : input_y_data[i];
   }
@@ -3773,25 +3741,16 @@ inline void RankOneSelect(const D* input_condition_data,
                           const T* input_x_data, const Dims<4>& input_x_dims,
                           const T* input_y_data, const Dims<4>& input_y_dims,
                           T* output_data, const Dims<4>& output_dims) {
-  const int64_t rank = ArraySize(input_condition_dims, 0);
-
-  const int64_t batches =
-      MatchingArraySize(input_x_dims, 3, input_y_dims, 3, output_dims, 3);
-  const int64_t height =
-      MatchingArraySize(input_x_dims, 2, input_y_dims, 2, output_dims, 2);
-  const int64_t width =
-      MatchingArraySize(input_x_dims, 1, input_y_dims, 1, output_dims, 1);
-  const int64_t depth =
-      MatchingArraySize(input_x_dims, 0, input_y_dims, 0, output_dims, 0);
-
-  TFLITE_DCHECK_EQ(rank, batches);
+  const int64_t rank = MatchingArraySize(input_condition_dims, 0, input_x_dims,
+                                         3, input_y_dims, 3, output_dims, 3);
+  const int64_t inner_size =
+      MatchingFlatSizeSkipDim(input_x_dims, 3, input_y_dims, output_dims);
 
   int64_t offset = 0;
-  int64_t size = depth * height * width;
   for (int64_t i = 0; i < rank; i++) {
     const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
-    memcpy(output_data + offset, input_data + offset, size * sizeof(T));
-    offset += size;
+    memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
+    offset += inner_size;
   }
 }
 
index 3290c36..43c6883 100644 (file)
@@ -132,11 +132,11 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
 
 template <int N>
 inline int FlatSize(const Dims<N>& dims) {
-  int max_offset = 0;
-  for (int i = 0; i < N; i++) {
-    max_offset += (dims.sizes[i] - 1) * dims.strides[i];
+  int flat_size = 1;
+  for (int i = 0; i < N; ++i) {
+    flat_size *= dims.sizes[i];
   }
-  return max_offset + 1;
+  return flat_size;
 }
 
 // Deprecated. Prefer FlatSize.
@@ -148,7 +148,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
 // arrays.
 template <int N>
 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
   }
   return FlatSize(dims);
@@ -157,7 +157,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
 template <int N>
 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                             const Dims<N>& check_dims_1) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
   }
   return MatchingFlatSize(dims, check_dims_1);
@@ -167,7 +167,7 @@ template <int N>
 inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                             const Dims<N>& check_dims_1,
                             const Dims<N>& check_dims_2) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
   }
   return FlatSize(dims, check_dims_1, check_dims_2);
@@ -178,7 +178,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
                             const Dims<N>& check_dims_1,
                             const Dims<N>& check_dims_2,
                             const Dims<N>& check_dims_3) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
   }
   return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
@@ -191,7 +191,7 @@ template <int N>
 inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
   TFLITE_DCHECK(skip_dim >= 0 && skip_dim < N);
   int flat_size = 1;
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     flat_size *= (i == skip_dim) ? 1 : dims.sizes[i];
   }
   return flat_size;
@@ -201,7 +201,7 @@ inline int FlatSizeSkipDim(const Dims<N>& dims, int skip_dim) {
 template <int N>
 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                    const Dims<N>& check_dims_0) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     if (i != skip_dim) {
       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
     }
@@ -213,7 +213,7 @@ template <int N>
 inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                    const Dims<N>& check_dims_0,
                                    const Dims<N>& check_dims_1) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     if (i != skip_dim) {
       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
     }
@@ -226,7 +226,7 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                    const Dims<N>& check_dims_0,
                                    const Dims<N>& check_dims_1,
                                    const Dims<N>& check_dims_2) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     if (i != skip_dim) {
       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
     }
@@ -240,7 +240,7 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
                                    const Dims<N>& check_dims_1,
                                    const Dims<N>& check_dims_2,
                                    const Dims<N>& check_dims_3) {
-  for (int i = 0; i < N; i++) {
+  for (int i = 0; i < N; ++i) {
     if (i != skip_dim) {
       TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
     }