Supports quantized reduce_mean in TF Lite.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 19:29:32 +0000 (12:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 19:33:39 +0000 (12:33 -0700)
PiperOrigin-RevId: 190813997

tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/mean.cc
tensorflow/contrib/lite/kernels/mean_test.cc
tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc

index ce12fad..33d60af 100644 (file)
@@ -3183,19 +3183,20 @@ inline void Exp(const T* input_data, const size_t num_elements,
   }
 }
 
-template <typename T>
-inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
+template <typename T, typename U>
+inline bool Mean(T* input_data, const int* input_dims, const int input_num_dims,
                  T* output_data, const int* output_dims,
                  const int output_num_dims, const int* axis,
                  const int num_axis_dimensions, bool keep_dims, int* temp_index,
-                 int* resolved_axis) {
+                 int* resolved_axis, U* temp_sum) {
   // resets output data.
   size_t num_outputs = 1;
   for (int idx = 0; idx < output_num_dims; ++idx) {
     num_outputs *= static_cast<size_t>(output_dims[idx]);
   }
   for (size_t idx = 0; idx < num_outputs; ++idx) {
-    output_data[idx] = 0;
+    output_data[idx] = T();
+    temp_sum[idx] = U();
   }
   // resets temp index.
   for (int idx = 0; idx < input_num_dims; ++idx) {
@@ -3228,19 +3229,24 @@ inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
     size_t output_offset =
         ReducedOutputOffset(input_num_dims, input_dims, temp_index,
                             num_resolved_axis, resolved_axis);
-    output_data[output_offset] += input_data[input_offset];
+    temp_sum[output_offset] += static_cast<U>(input_data[input_offset]);
   }
   // takes average by num of elements added to get mean.
   size_t num_elements_in_axis = 1;
   for (int idx = 0; idx < num_resolved_axis; ++idx) {
-    num_elements_in_axis *= static_cast<size_t>(input_dims[resolved_axis[idx]]);
+    size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+    if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+      return false;
+    }
+    num_elements_in_axis *= current;
   }
   if (num_elements_in_axis > 0) {
     for (size_t idx = 0; idx < num_outputs; ++idx) {
-      output_data[idx] = static_cast<T>(static_cast<float>(output_data[idx]) /
-                                        num_elements_in_axis);
+      output_data[idx] =
+          static_cast<T>(temp_sum[idx] / static_cast<U>(num_elements_in_axis));
     }
   }
+  return true;
 }
 
 template <typename T>
index aff1958..047bdd1 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include <vector>
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -48,7 +49,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // Creates two temp tensors to store index and axis for internal
   // implementation only.
   auto* scratch_tensor_index = new int;
-  context->AddTensors(context, 2, scratch_tensor_index);
+  context->AddTensors(context, 3, scratch_tensor_index);
   return scratch_tensor_index;
 }
 
@@ -64,6 +65,14 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
   return context->ResizeTensor(context, resolved_axis, axis_size);
 }
 
+// Resizes the temp tensor that stores temp sum of reduced elements.
+TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context,
+                           TfLiteTensor* temp_sum) {
+  TfLiteIntArray* size = TfLiteIntArrayCreate(1);
+  size->data[0] = static_cast<int>(NumElements(op_context->output));
+  return context->ResizeTensor(context, temp_sum, size);
+}
+
 // Resizes output array based on the input size and resolved axis.
 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                 MeanContext* op_context) {
@@ -135,7 +144,7 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   // Creates a temp index to iterate through input data.
   int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
   TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(2);
+  node->temporaries = TfLiteIntArrayCreate(3);
   node->temporaries->data[0] = *scratch_tensor_index;
   TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
   scratch_tensor->type = kTfLiteInt32;
@@ -149,6 +158,25 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
   node->temporaries->data[1] = *scratch_tensor_index + 1;
   TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
   resolved_axis->type = kTfLiteInt32;
+  // Creates a temp tensor to store temp sums when calculating mean.
+  node->temporaries->data[2] = *scratch_tensor_index + 2;
+  TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
+  switch (op_context->input->type) {
+    case kTfLiteFloat32:
+      temp_sum->type = kTfLiteFloat32;
+      break;
+    case kTfLiteInt32:
+      temp_sum->type = kTfLiteInt64;
+      break;
+    case kTfLiteInt64:
+      temp_sum->type = kTfLiteInt64;
+      break;
+    case kTfLiteUInt8:
+      temp_sum->type = kTfLiteInt32;
+      break;
+    default:
+      return kTfLiteError;
+  }
   return kTfLiteOk;
 }
 
@@ -160,16 +188,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
 
   TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+  TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
   // Leaves work to Eval if axis is not constant; else resizes output.
   if (!IsConstantTensor(op_context.axis)) {
     SetTensorToDynamic(op_context.output);
     SetTensorToDynamic(resolved_axis);
+    SetTensorToDynamic(temp_sum);
     return kTfLiteOk;
   }
   resolved_axis->allocation_type = kTfLiteArenaRw;
   TF_LITE_ENSURE_OK(context,
                     ResizeTempAxis(context, &op_context, resolved_axis));
-  return ResizeOutputTensor(context, &op_context);
+  TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+  temp_sum->allocation_type = kTfLiteArenaRw;
+  return ResizeTempSum(context, &op_context, temp_sum);
 }
 
 template <KernelType kernel_type>
@@ -178,14 +210,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   int num_axis = static_cast<int>(NumElements(op_context.axis));
   TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]];
   TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+  TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]];
   // Resize the output tensor if the output tensor is dynamic.
   if (IsDynamicTensor(op_context.output)) {
     TF_LITE_ENSURE_OK(context,
                       ResizeTempAxis(context, &op_context, resolved_axis));
     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+    TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
   }
 
-#define TF_LITE_MEAN(kernel_type, data_type)                        \
+#define TF_LITE_MEAN(kernel_type, data_type, temp_data_type)        \
   kernel_type::Mean<>(                                              \
       GetTensorData<data_type>(op_context.input),                   \
       op_context.input->dims->data, op_context.input->dims->size,   \
@@ -193,21 +227,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       op_context.output->dims->data, op_context.output->dims->size, \
       GetTensorData<int>(op_context.axis), num_axis,                \
       op_context.params->keep_dims, GetTensorData<int>(temp_index), \
-      GetTensorData<int>(resolved_axis))
+      GetTensorData<int>(resolved_axis),                            \
+      GetTensorData<temp_data_type>(temp_sum))
 
   if (kernel_type == kReference) {
     switch (op_context.input->type) {
       case kTfLiteFloat32:
-        TF_LITE_MEAN(reference_ops, float);
+        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
         break;
       case kTfLiteInt32:
-        TF_LITE_MEAN(reference_ops, int);
-        break;
-      case kTfLiteUInt8:
-        TF_LITE_MEAN(reference_ops, uint8_t);
+        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
         break;
       case kTfLiteInt64:
-        TF_LITE_MEAN(reference_ops, int64_t);
+        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
+        break;
+      case kTfLiteUInt8:
+        TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+                          op_context.output->params.scale);
+        TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+                          op_context.output->params.zero_point);
+        TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
         break;
       default:
         return kTfLiteError;
@@ -216,7 +255,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 #undef TF_LITE_MEAN
   return kTfLiteOk;
 }
-
 }  // namespace mean
 
 TfLiteRegistration* Register_MEAN_REF() {
index 2d6d4bc..79c9957 100644 (file)
@@ -37,8 +37,15 @@ class BaseMeanOpModel : public SingleOpModel {
     return ExtractVector<T>(output_);
   }
 
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+                               GetScale(output_), GetZeroPoint(output_));
+  }
+
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 
+  int Input() { return input_; }
+
  protected:
   int input_;
   int axis_;
@@ -142,56 +149,64 @@ TEST(DynamicFloatMeanOpTest, Scale) {
   EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
 }
 
+// for quantized Add, the error shouldn't exceed step
+float GetTolerance(int min, int max) { return (max - min) / 255.0; }
+
 TEST(ConstUint8MeanOpTest, NotKeepDims) {
-  std::initializer_list<uint8_t> data = {1,  2,  3,  4,  5,  6,  7,  8,
-                                         9,  10, 11, 12, 13, 14, 15, 16,
-                                         17, 18, 19, 20, 21, 22, 23, 24};
-  MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
-                     {4}, {1, 0, -3, -3}, false);
-  m.SetInput(data);
+  float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+  std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+  MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+                     {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+  m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   m.Invoke();
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
-  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+                                            {0.4, 0.4}, kQuantizedTolerance)));
 }
 
 TEST(ConstUint8MeanOpTest, KeepDims) {
-  std::initializer_list<uint8_t> data = {1,  2,  3,  4,  5,  6,  7,  8,
-                                         9,  10, 11, 12, 13, 14, 15, 16,
-                                         17, 18, 19, 20, 21, 22, 23, 24};
-  MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
-                     {2}, {0, 2}, true);
-  m.SetInput(data);
+  float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+  std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+  MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+                     {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+  m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   m.Invoke();
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
-  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+  EXPECT_THAT(
+      m.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
 }
 
 TEST(DynamicUint8MeanOpTest, NotKeepDims) {
-  std::initializer_list<uint8_t> data = {1,  2,  3,  4,  5,  6,  7,  8,
-                                         9,  10, 11, 12, 13, 14, 15, 16,
-                                         17, 18, 19, 20, 21, 22, 23, 24};
-  MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}},
-                       {TensorType_INT32, {4}}, false);
-  std::initializer_list<int> axis = {1, 0, -3, -3};
+  float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+  std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
+  MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+                       {TensorType_UINT8, {2}, -5.0, 2.0},
+                       {TensorType_INT32, {1}}, false);
+  std::initializer_list<int> axis = {1};
   m.SetAxis(axis);
-  m.SetInput(data);
+  m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   m.Invoke();
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
-  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({12, 13}));
+  EXPECT_THAT(
+      m.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance)));
 }
 
 TEST(DynamicUint8MeanOpTest, KeepDims) {
-  std::initializer_list<uint8_t> data = {1,  2,  3,  4,  5,  6,  7,  8,
-                                         9,  10, 11, 12, 13, 14, 15, 16,
-                                         17, 18, 19, 20, 21, 22, 23, 24};
-  MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}},
-                       {TensorType_INT32, {2}}, true);
-  std::initializer_list<int> axis = {0, 2};
+  float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+  std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
+  MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+                       {TensorType_UINT8, {2}, -10.0, 12.0},
+                       {TensorType_INT32, {1}}, true);
+  std::initializer_list<int> axis = {0};
   m.SetAxis(axis);
-  m.SetInput(data);
+  m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
   m.Invoke();
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
-  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({10, 12, 14}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+  EXPECT_THAT(
+      m.GetDequantizedOutput(),
+      ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
 }
 
 }  // namespace
index 5cc82da..7c97ef0 100644 (file)
@@ -332,6 +332,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
     case OperatorType::kPad:
     case OperatorType::kGather:
     case OperatorType::kTranspose:
+    case OperatorType::kMean:
       changed = HardcodeMinMaxFromFirstInput(model, op);
       break;
 
index 9679ea0..9fcc95e 100644 (file)
@@ -52,7 +52,7 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kStridedSlice ||
          type == OperatorType::kDepthToSpace ||
          type == OperatorType::kLstmCell || type == OperatorType::kGather ||
-         type == OperatorType::kTranspose;
+         type == OperatorType::kTranspose || type == OperatorType::kMean;
 }
 
 template <ArrayDataType A>