Implementation of the fully-connected TFLite Op using symmetric quantization.
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Tue, 1 May 2018 22:47:27 +0000 (15:47 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 1 May 2018 22:53:12 +0000 (15:53 -0700)
PiperOrigin-RevId: 195010312
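
Background for the change below: symmetric quantization maps a float buffer to signed 8-bit values using a single scale derived from the buffer's largest absolute value, so the zero point is always 0 and no offset-correction terms are needed in the matrix multiply. The snippet is a minimal illustrative sketch of that idea, not the actual tensor_utils::SymmetricQuantizeFloats implementation (which also reports the observed min/max); the function name and std::vector interface are assumptions made for the example.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch: derive one scale from the largest absolute value so
// that zero maps exactly to 0, then round and clamp into [-127, 127].
void SymmetricQuantizeSketch(const std::vector<float>& values,
                             std::vector<int8_t>* quantized,
                             float* scaling_factor) {
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  // Guard against an all-zero buffer.
  *scaling_factor = (max_abs == 0.f) ? 1.f : max_abs / 127.f;
  quantized->resize(values.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    const float q = std::round(values[i] / *scaling_factor);
    (*quantized)[i] =
        static_cast<int8_t>(std::min(127.f, std::max(-127.f, q)));
  }
}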

tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/fully_connected.cc
tensorflow/contrib/lite/kernels/fully_connected_test.cc
tensorflow/contrib/lite/kernels/test_util.h

diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 689f9bf..57b3136 100644
@@ -31,6 +31,7 @@ cc_library(
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:schema_fbs_version",
         "//tensorflow/contrib/lite:string_util",
+        "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "//tensorflow/contrib/lite/testing:util",
         "//tensorflow/core:tflite_portable_logging",
         "@com_google_googletest//:gtest",
@@ -672,6 +673,7 @@ tf_cc_test(
         ":builtin_ops",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite/kernels:test_util",
+        "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
     ],

diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 888e679..c5bf50d 100644
@@ -55,19 +55,24 @@ struct OpData {
   // uint8_t these would be 0 and 255.
   int32_t output_activation_min;
   int32_t output_activation_max;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int input_quantized_index;
 };
 
 constexpr int kInputTensor = 0;
 constexpr int kWeightsTensor = 1;
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
+constexpr int kScratchBufferTensor = 1;
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to carry information from Prepare() to
   // Eval().
   gemm_support::IncrementUsageCounter(context);
-  return new OpData;
+  auto* op_data = new OpData;
+  context->AddTensors(context, 1, &op_data->input_quantized_index);
+  return op_data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
@@ -121,6 +126,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                   &data->output_activation_max);
   }
 
+  // If we have to perform on-the-fly quantization (with quantized weights and
+  // float inputs) first we need to quantize the inputs. Allocate a temporary
+  // buffer to store the intermediate quantized values.
+  if (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8) {
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(1);
+    node->temporaries->data[0] = data->input_quantized_index;
+
+    TfLiteTensor* input_quantized =
+        &context->tensors[node->temporaries->data[0]];
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+
+    // TODO(raziel): add this logic to ResizeTensor.
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+  }
+
   // Resize output.
   TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
   output_size_array->data[0] = batch_size;
@@ -163,6 +189,74 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
+                              TfLiteFullyConnectedParams* params, OpData* data,
+                              TfLiteTensor* input, TfLiteTensor* filter,
+                              TfLiteTensor* bias, TfLiteTensor* input_quantized,
+                              TfLiteTensor* output) {
+  // Check the types for this hybrid Op.
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteUInt8);
+  TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+
+  int total_input_size = 1;
+  for (int i = 0; i < input->dims->size; i++) {
+    total_input_size *= input->dims->data[i];
+  }
+
+  const int input_size = filter->dims->data[1];
+  const int batch_size = total_input_size / filter->dims->data[1];
+  const int num_units = filter->dims->data[0];
+
+  // Output = bias if bias tensor exists.
+  if (bias) {
+    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+                                          output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+  }
+
+  // TODO(mirkov): replace std::minmax_element with a vectorized call.
+  auto minmax_element =
+      std::minmax_element(input->data.f, input->data.f + total_input_size);
+  // Save matrix multiplication computation for all zero input.
+  if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) {
+    tensor_utils::ApplyActivationToVector(output->data.f,
+                                          batch_size * num_units,
+                                          params->activation, output->data.f);
+    return kTfLiteOk;
+  }
+
+  // Quantize input from float to uint8 + quantization params (scaling factor).
+  float min, max;
+  float* scaling_factors = new float[batch_size];
+
+  // Quantize each batch independently.
+  for (int b = 0; b < batch_size; ++b) {
+    const int offset = b * input_size;
+    tensor_utils::SymmetricQuantizeFloats(
+        input->data.f + offset, input_size,
+        reinterpret_cast<int8_t*>(input_quantized->data.uint8) + offset, &min,
+        &max, &scaling_factors[b]);
+    // Incorporate scaling of the filter.
+    scaling_factors[b] *= filter->params.scale;
+  }
+
+  // Compute output += weight * quantized_input
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      reinterpret_cast<int8_t*>(filter->data.uint8), num_units, input_size,
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8), scaling_factors,
+      batch_size, output->data.f, /*result_stride=*/1);
+
+  // Apply activation function to floats.
+  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+                                        params->activation, output->data.f);
+  delete[] scaling_factors;
+
+  return kTfLiteOk;
+}
+
 #define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
   if (params->activation == kTfLiteActNone) {                        \
     macro_name(target_namespace, kNone);                             \
@@ -178,7 +272,8 @@ template <KernelType kernel_type>
 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            TfLiteFullyConnectedParams* params, OpData* data,
                            TfLiteTensor* input, TfLiteTensor* filter,
-                           TfLiteTensor* bias, TfLiteTensor* output) {
+                           TfLiteTensor* bias, TfLiteTensor* input_quantized,
+                           TfLiteTensor* output) {
   gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
 
   int32_t input_offset = -input->params.zero_point;
@@ -195,9 +290,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   if (kernel_type == kReference) {
     TF_LITE_FULLY_CONNECTED(reference_ops);
   } else if (kernel_type == kPie) {
-    // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
-    // we just defer to the MINI ones.
-    TF_LITE_FULLY_CONNECTED(optimized_ops);
+    if (input->type == kTfLiteFloat32) {
+      // Pie currently only supports quantized models and float inputs/outputs.
+      return EvalPieQuantized(context, node, params, data, input, filter, bias,
+                              input_quantized, output);
+    } else {
+      // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+      // we just defer to the MINI ones.
+      TF_LITE_FULLY_CONNECTED(optimized_ops);
+    }
   } else {
     TF_LITE_FULLY_CONNECTED(optimized_ops);
   }
@@ -245,13 +346,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  switch (input->type) {  // Already know in/out types are same.
+  TfLiteTensor* input_quantized = &context->tensors[node->temporaries->data[0]];
+
+  switch (filter->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                     bias, output);
     case kTfLiteUInt8:
       return EvalQuantized<kernel_type>(context, node, params, data, input,
-                                        filter, bias, output);
+                                        filter, bias, input_quantized, output);
     default:
       context->ReportError(context, "Type not currently supported.");
       return kTfLiteError;
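
To make the hybrid path in EvalPieQuantized easier to follow: the kernel seeds the output with the bias, quantizes each input batch symmetrically, folds filter->params.scale into the per-batch scaling factor, and then has MatrixBatchVectorMultiplyAccumulate add the rescaled integer dot products onto the float output. The fragment below is a conceptual sketch of that accumulation for a single (unit, batch) pair, not the optimized kernel; the flat row-major weight layout and function name are assumptions made for the example.

#include <cstdint>

// Conceptual model of one accumulation step in the hybrid path: int8 x int8
// products are summed in int32 and rescaled back to float with the combined
// factor (per-batch input scale * filter->params.scale). Because both
// operands are symmetrically quantized (zero_point == 0), no offset
// correction terms are needed.
void HybridDotProductSketch(const int8_t* weights_row,
                            const int8_t* quantized_input, int input_size,
                            float combined_scale,
                            float* output /* already holds the bias */) {
  int32_t accum = 0;
  for (int i = 0; i < input_size; ++i) {
    accum += static_cast<int32_t>(weights_row[i]) *
             static_cast<int32_t>(quantized_input[i]);
  }
  *output += accum * combined_scale;
}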

diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index 8741300..05dd028 100644
@@ -21,6 +21,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "absl/memory/memory.h"
 #include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
 #include "tensorflow/contrib/lite/model.h"
@@ -224,6 +225,60 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
   }
 };
 
+// In the hybrid model the weights are quantized (to uint8). But the bias,
+// input (and output) are expected to be in float precision.
+class HybridFullyConnectedOpModel : public SingleOpModel {
+ public:
+  HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
+                              const TensorData& weights,
+                              const TensorData& output = {TensorType_FLOAT32})
+      : batches_(batches), units_(units) {
+    int total_input_size = 1;
+    for (int i = 0; i < input.shape.size(); ++i) {
+      total_input_size *= input.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    input_ = AddInput(input);
+    weights_ = AddInput(weights);
+
+    TensorData bias{TensorType_FLOAT32, {units_}};
+    bias_ = AddInput(bias);
+
+    output_ = AddOutput(output);
+
+    SetBuiltinOp(
+        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+            .Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_FULLY_CONNECTED,
+        ops::builtin::Register_FULLY_CONNECTED_PIE());
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+  }
+  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+  void SetWeights(std::initializer_list<float> data) {
+    SymmetricQuantizeAndPopulate(weights_, data);
+  }
+
+  void SetInput(std::initializer_list<float> f) { PopulateTensor(input_, f); }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  int input_size() { return input_size_; }
+  int num_units() { return units_; }
+  int num_batches() { return batches_; }
+
+ protected:
+  int input_;
+  int weights_;
+  int bias_;
+  int output_;
+
+  int batches_;
+  int units_;
+  int input_size_;
+};
+
 const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
     {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
@@ -231,18 +286,43 @@ const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
 });
 
-class FullyConnectedOpTest : public SingleOpTest {
+class FloatFullyConnectedOpTest : public SingleOpTest {
  protected:
   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
     return *kKernelMap;
   }
 };
 
+const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
+    {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
+    {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
+    {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
+});
+
+class QuantizedFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapNoPie;
+  }
+};
+
+const auto kKernelMapPie = new std::map<string, TfLiteRegistration*>({
+    {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
+});
+
+// Hybrid mode is used by the Pie quantized kernel.
+class HybridFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapPie;
+  }
+};
+
 // TODO(ahentz): add more small tests like this one, focused on making sure the
 // calculations are correct.
-TEST_P(FullyConnectedOpTest, SimpleTest) {
-  FloatFullyConnectedOpModel m(GetRegistration(), 3, 2,
-                               {TensorType_FLOAT32, {2, 10}});
+TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
+  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2,
+                               /*input=*/{TensorType_FLOAT32, {2, 10}});
   m.SetWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
@@ -260,9 +340,9 @@ TEST_P(FullyConnectedOpTest, SimpleTest) {
   EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
 }
 
-TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
   QuantizedFullyConnectedOpModel m(
-      GetRegistration(), 3, 2,
+      GetRegistration(), /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
       /*output=*/{TensorType_UINT8, {}, -127, 128});
 
@@ -288,13 +368,40 @@ TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
   EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
 }
 
-TEST(FullyConnectedOpTest, SimpleTest4DInput) {
+TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
+  HybridFullyConnectedOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 10}},
+      /*weights=*/{TensorType_UINT8, {3, 10}, -63.5, 64});  // PIE
+
+  m.SetWeights({
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  });
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, 25, 26,  //
+                                     58, 59, 60,  //
+                                 },
+                                 /*max_abs_error=*/1.3f)));
+}
+
+TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) {
   // Note that it is not required that the first dimension be the number of
   // batches. All we care is that the input can be evenly distributed in
   // batches. In this case, we need the input to have multiples of '2'.
   FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
-                               /*units=*/3,
-                               /*batches=*/2,
+                               /*units=*/3, /*batches=*/2,
                                /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
   m.SetWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
@@ -316,9 +423,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
                              }));
 }
 
-TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
   QuantizedFullyConnectedOpModel m(
-      GetRegistration(), 3, 2,
+      GetRegistration(), /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
       /*output=*/{TensorType_UINT8, {}, -127, 128});
 
@@ -345,14 +452,18 @@ TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
 }
 
 INSTANTIATE_TEST_CASE_P(
-    FullyConnectedOpTest, FullyConnectedOpTest,
+    FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
 
+INSTANTIATE_TEST_CASE_P(
+    QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
+
 // TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
 // to debug errors and doesn't necessarily test all the important details.
-TEST_P(FullyConnectedOpTest, BlackBoxTest) {
-  FloatFullyConnectedOpModel m(GetRegistration(), 16, 2,
-                               {TensorType_FLOAT32, {2, 8}});
+TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
+  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
+                               /*input=*/{TensorType_FLOAT32, {2, 8}});
   m.SetWeights(
       {0.091327,  0.103366,  -0.316505, -0.083120, 0.149366,  -0.196636,
        -0.123672, 0.062800,  0.063031,  0.191670,  -0.062001, -0.061504,

diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index a9064d5..6fb6fe2 100644
@@ -21,6 +21,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 
 #include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/model.h"
 #include "tensorflow/contrib/lite/string_util.h"
@@ -133,6 +134,22 @@ class SingleOpModel {
     PopulateTensor(index, 0, q.data(), q.data() + q.size());
   }
 
+  void SymmetricQuantizeAndPopulate(int index,
+                                    std::initializer_list<float> data) {
+    TfLiteTensor* t = interpreter_->tensor(index);
+    std::vector<float> values(data);
+    const int length = values.size();
+    std::vector<int8_t> q(length);
+    float min, max, scaling_factor;
+    tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min,
+                                          &max, &scaling_factor);
+    // Update quantization params.
+    t->params.scale = scaling_factor;
+    t->params.zero_point = 0;
+    PopulateTensor(index, /*offset=*/0, reinterpret_cast<uint8_t*>(q.data()),
+                   reinterpret_cast<uint8_t*>(q.data() + q.size()));
+  }
+
   const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
 
   float GetScale(int id) { return tensor_data_.at(id).scale; }
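
As a sanity check on the new SymmetricQuantizeAndPopulate helper, here is a small standalone example (not part of the change) of what symmetric quantization does to the weight row {1, ..., 10} used in HybridFullyConnectedOpTest, assuming the scale is max|w| / 127 as in the sketch near the top. Each dequantized weight lands within roughly half a quantization step (~0.04) of its float value; rounding errors like these, on both the quantized weights and the on-the-fly quantized inputs, accumulate across the 10-element dot products, which is why the hybrid test compares outputs with a 1.3f absolute tolerance rather than exact equality.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float weights[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  const float scale = 10.0f / 127.0f;  // max absolute value / 127
  for (float w : weights) {
    const int8_t q = static_cast<int8_t>(std::round(w / scale));
    // Print the quantized value, its dequantized reconstruction, and the
    // resulting rounding error for each weight.
    std::printf("w=%5.2f  q=%4d  dequantized=%7.4f  error=%+.4f\n", w, q,
                q * scale, q * scale - w);
  }
  return 0;
}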