From b25e6fe32cccd29ec4cb4014bbb45d62b75835b4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 1 May 2018 15:47:27 -0700
Subject: [PATCH] Implementation of the fully-connected TFLite Op using the
 symmetric quantization.

PiperOrigin-RevId: 195010312
---
 tensorflow/contrib/lite/kernels/BUILD              |   2 +
 tensorflow/contrib/lite/kernels/fully_connected.cc | 117 ++++++++++++++++-
 .../contrib/lite/kernels/fully_connected_test.cc   | 141 ++++++++++++++++++---
 tensorflow/contrib/lite/kernels/test_util.h        |  17 +++
 4 files changed, 255 insertions(+), 22 deletions(-)

diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 689f9bf..57b3136 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -31,6 +31,7 @@ cc_library(
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:schema_fbs_version",
         "//tensorflow/contrib/lite:string_util",
+        "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "//tensorflow/contrib/lite/testing:util",
         "//tensorflow/core:tflite_portable_logging",
         "@com_google_googletest//:gtest",
@@ -672,6 +673,7 @@ tf_cc_test(
         ":builtin_ops",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite/kernels:test_util",
+        "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
        "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
     ],
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 888e679..c5bf50d 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -55,19 +55,24 @@ struct OpData {
   // uint8_t these would be 0 and 255.
   int32_t output_activation_min;
   int32_t output_activation_max;
+  // The index of the temporary tensor where the quantized inputs are cached.
+  int input_quantized_index;
 };
 
 constexpr int kInputTensor = 0;
 constexpr int kWeightsTensor = 1;
 constexpr int kBiasTensor = 2;
 constexpr int kOutputTensor = 0;
+constexpr int kScratchBufferTensor = 1;
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // This is a builtin op, so we don't use the contents in 'buffer', if any.
   // Instead, we allocate a new object to carry information from Prepare() to
   // Eval().
   gemm_support::IncrementUsageCounter(context);
-  return new OpData;
+  auto* op_data = new OpData;
+  context->AddTensors(context, 1, &op_data->input_quantized_index);
+  return op_data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
@@ -121,6 +126,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
                                   &data->output_activation_max);
   }
 
+  // If we have to perform on-the-fly quantization (with quantized weights and
+  // float inputs) first we need to quantize the inputs. Allocate a temporary
+  // buffer to store the intermediate quantized values.
+  if (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8) {
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(1);
+    node->temporaries->data[0] = data->input_quantized_index;
+
+    TfLiteTensor* input_quantized =
+        &context->tensors[node->temporaries->data[0]];
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+
+    // TODO(raziel): add this logic to ResizeTensor.
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+  }
+
   // Resize output.
   TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
   output_size_array->data[0] = batch_size;
@@ -163,6 +189,74 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node,
+                              TfLiteFullyConnectedParams* params, OpData* data,
+                              TfLiteTensor* input, TfLiteTensor* filter,
+                              TfLiteTensor* bias, TfLiteTensor* input_quantized,
+                              TfLiteTensor* output) {
+  // Check the types for this hybrid Op.
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteUInt8);
+  TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+
+  int total_input_size = 1;
+  for (int i = 0; i < input->dims->size; i++) {
+    total_input_size *= input->dims->data[i];
+  }
+
+  const int input_size = filter->dims->data[1];
+  const int batch_size = total_input_size / filter->dims->data[1];
+  const int num_units = filter->dims->data[0];
+
+  // Output = bias if bias tensor exists.
+  if (bias) {
+    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+                                          output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+  }
+
+  // TODO(mirkov): change std::minmax_element with a vectorized call.
+  auto minmax_element =
+      std::minmax_element(input->data.f, input->data.f + total_input_size);
+  // Save matrix multiplication computation for all zero input.
+  if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) {
+    tensor_utils::ApplyActivationToVector(output->data.f,
+                                          batch_size * num_units,
+                                          params->activation, output->data.f);
+    return kTfLiteOk;
+  }
+
+  // Quantize input from float to uint8 + quantization params (scaling factor).
+  float min, max;
+  float* scaling_factors = new float[batch_size];
+
+  // Quantize each batch independently.
+  for (int b = 0; b < batch_size; ++b) {
+    const int offset = b * input_size;
+    tensor_utils::SymmetricQuantizeFloats(
+        input->data.f + offset, input_size,
+        reinterpret_cast<int8_t*>(input_quantized->data.uint8) + offset, &min,
+        &max, &scaling_factors[b]);
+    // Incorporate scaling of the filter.
+    scaling_factors[b] *= filter->params.scale;
+  }
+
+  // Compute output += weight * quantized_input
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      reinterpret_cast<int8_t*>(filter->data.uint8), num_units, input_size,
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8), scaling_factors,
+      batch_size, output->data.f, /*result_stride=*/1);
+
+  // Apply activation function to floats.
+  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+                                        params->activation, output->data.f);
+  delete[] scaling_factors;
+
+  return kTfLiteOk;
+}
+
 #define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \
   if (params->activation == kTfLiteActNone) {                        \
     macro_name(target_namespace, kNone);                             \
@@ -178,7 +272,8 @@ template <KernelType kernel_type>
 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            TfLiteFullyConnectedParams* params, OpData* data,
                            TfLiteTensor* input, TfLiteTensor* filter,
-                           TfLiteTensor* bias, TfLiteTensor* output) {
+                           TfLiteTensor* bias, TfLiteTensor* input_quantized,
+                           TfLiteTensor* output) {
   gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
 
   int32_t input_offset = -input->params.zero_point;
@@ -195,9 +290,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   if (kernel_type == kReference) {
     TF_LITE_FULLY_CONNECTED(reference_ops);
   } else if (kernel_type == kPie) {
-    // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
-    // we just defer to the MINI ones.
-    TF_LITE_FULLY_CONNECTED(optimized_ops);
+    if (input->type == kTfLiteFloat32) {
+      // Pie currently only supports quantized models and float inputs/outputs.
+      return EvalPieQuantized(context, node, params, data, input, filter, bias,
+                              input_quantized, output);
+    } else {
+      // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+      // we just defer to the MINI ones.
+      TF_LITE_FULLY_CONNECTED(optimized_ops);
+    }
   } else {
     TF_LITE_FULLY_CONNECTED(optimized_ops);
   }
@@ -245,13 +346,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  switch (input->type) {  // Already know in/out types are same.
+  TfLiteTensor* input_quantized = &context->tensors[node->temporaries->data[0]];
+
+  switch (filter->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
       return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                     bias, output);
     case kTfLiteUInt8:
       return EvalQuantized<kernel_type>(context, node, params, data, input,
-                                        filter, bias, output);
+                                        filter, bias, input_quantized, output);
     default:
       context->ReportError(context, "Type not currently supported.");
       return kTfLiteError;
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index 8741300..05dd028 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "absl/memory/memory.h"
 #include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
 #include "tensorflow/contrib/lite/model.h"
@@ -224,6 +225,60 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
   }
 };
 
+// In the hybrid model the weights are quantized (to uint8). But the bias,
+// input (and output) are expected to be in float precision.
+class HybridFullyConnectedOpModel : public SingleOpModel {
+ public:
+  HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
+                              const TensorData& weights,
+                              const TensorData& output = {TensorType_FLOAT32})
+      : batches_(batches), units_(units) {
+    int total_input_size = 1;
+    for (int i = 0; i < input.shape.size(); ++i) {
+      total_input_size *= input.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    input_ = AddInput(input);
+    weights_ = AddInput(weights);
+
+    TensorData bias{TensorType_FLOAT32, {units_}};
+    bias_ = AddInput(bias);
+
+    output_ = AddOutput(output);
+
+    SetBuiltinOp(
+        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+            .Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_FULLY_CONNECTED,
+        ops::builtin::Register_FULLY_CONNECTED_PIE());
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+  }
+  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+  void SetWeights(std::initializer_list<float> data) {
+    SymmetricQuantizeAndPopulate(weights_, data);
+  }
+
+  void SetInput(std::initializer_list<float> f) { PopulateTensor(input_, f); }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  int input_size() { return input_size_; }
+  int num_units() { return units_; }
+  int num_batches() { return batches_; }
+
+ protected:
+  int input_;
+  int weights_;
+  int bias_;
+  int output_;
+
+  int batches_;
+  int units_;
+  int input_size_;
+};
+
 const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
     {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
     {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
@@ -231,18 +286,43 @@ const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
     {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
 });
 
-class FullyConnectedOpTest : public SingleOpTest {
+class FloatFullyConnectedOpTest : public SingleOpTest {
  protected:
   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
     return *kKernelMap;
   }
 };
 
+const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
+    {"Reference", ops::builtin::Register_FULLY_CONNECTED_REF()},
+    {"NeonOptimized", ops::builtin::Register_FULLY_CONNECTED_NEON_OPT()},
+    {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
+});
+
+class QuantizedFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapNoPie;
+  }
+};
+
+const auto kKernelMapPie = new std::map<string, TfLiteRegistration*>({
+    {"Pie", ops::builtin::Register_FULLY_CONNECTED_PIE()},
+});
+
+// Hybrid mode is used by the Pie quantized kernel.
+class HybridFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapPie;
+  }
+};
+
 // TODO(ahentz): add more small tests like this one, focused on making sure the
 // calculations are correct.
-TEST_P(FullyConnectedOpTest, SimpleTest) {
-  FloatFullyConnectedOpModel m(GetRegistration(), 3, 2,
-                               {TensorType_FLOAT32, {2, 10}});
+TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
+  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2,
+                               /*input=*/{TensorType_FLOAT32, {2, 10}});
   m.SetWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
@@ -260,9 +340,9 @@ TEST_P(FullyConnectedOpTest, SimpleTest) {
   EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
 }
 
-TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
   QuantizedFullyConnectedOpModel m(
-      GetRegistration(), 3, 2,
+      GetRegistration(), /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
       /*output=*/{TensorType_UINT8, {}, -127, 128});
 
@@ -288,13 +368,40 @@ TEST_P(FullyConnectedOpTest, SimpleTestQuantized) {
   EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
 }
 
-TEST(FullyConnectedOpTest, SimpleTest4DInput) {
+TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
+  HybridFullyConnectedOpModel m(
+      /*units=*/3, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 10}},
+      /*weights=*/{TensorType_UINT8, {3, 10}, -63.5, 64});  // PIE
+
+  m.SetWeights({
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  });
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                 {
+                                     24, 25, 26,  //
+                                     58, 59, 60,  //
+                                 },
+                                 /*max_abs_error=*/1.3f)));
+}
+
+TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) {
   // Note that it is not required that the first dimension be the number of
   // batches. All we care is that the input can be evenly distributed in
   // batches. In this case, we need the input to have multiples of '2'.
   FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
-                               /*units=*/3,
-                               /*batches=*/2,
+                               /*units=*/3, /*batches=*/2,
                                /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
   m.SetWeights({
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
@@ -316,9 +423,9 @@ TEST(FullyConnectedOpTest, SimpleTest4DInput) {
                              }));
 }
 
-TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
+TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
   QuantizedFullyConnectedOpModel m(
-      GetRegistration(), 3, 2,
+      GetRegistration(), /*units=*/3, /*batches=*/2,
       /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
       /*output=*/{TensorType_UINT8, {}, -127, 128});
 
@@ -345,14 +452,18 @@ TEST_P(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
 }
 
 INSTANTIATE_TEST_CASE_P(
-    FullyConnectedOpTest, FullyConnectedOpTest,
+    FloatFullyConnectedOpTest, FloatFullyConnectedOpTest,
     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
 
+INSTANTIATE_TEST_CASE_P(
+    QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapNoPie)));
+
 // TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
 // to debug errors and doesn't necessarily test all the important details.
-TEST_P(FullyConnectedOpTest, BlackBoxTest) {
-  FloatFullyConnectedOpModel m(GetRegistration(), 16, 2,
-                               {TensorType_FLOAT32, {2, 8}});
+TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
+  FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/16, /*batches=*/2,
+                               /*input=*/{TensorType_FLOAT32, {2, 8}});
   m.SetWeights(
       {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636,
        -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504,
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index a9064d5..6fb6fe2 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 
 #include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/model.h"
 #include "tensorflow/contrib/lite/string_util.h"
@@ -133,6 +134,22 @@ class SingleOpModel {
     PopulateTensor(index, 0, q.data(), q.data() + q.size());
   }
 
+  void SymmetricQuantizeAndPopulate(int index,
+                                    std::initializer_list<float> data) {
+    TfLiteTensor* t = interpreter_->tensor(index);
+    std::vector<float> values(data);
+    const int length = values.size();
+    std::vector<int8_t> q(length);
+    float min, max, scaling_factor;
+    tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min,
+                                          &max, &scaling_factor);
+    // Update quantization params.
+    t->params.scale = scaling_factor;
+    t->params.zero_point = 0;
+    PopulateTensor(index, /*offset=*/0, reinterpret_cast<uint8_t*>(q.data()),
+                   reinterpret_cast<uint8_t*>(q.data() + q.size()));
+  }
+
   const std::vector<int>& GetShape(int id) { return tensor_data_.at(id).shape; }
 
   float GetScale(int id) { return tensor_data_.at(id).scale; }
-- 
2.7.4
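
Note on the symmetric scheme used throughout this patch: SymmetricQuantizeFloats maps a
float buffer to int8 with a single per-buffer scaling factor and no zero point. The
snippet below is a minimal standalone sketch of that idea for readers of the patch; the
function name, the use of std::vector, and the choice of 127 as the clamping bound are
illustrative assumptions, not the tensor_utils implementation itself.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch only: symmetric quantization with scale = max|x| / 127, zero_point = 0.
    void SymmetricQuantizeSketch(const std::vector<float>& values,
                                 std::vector<int8_t>* quantized,
                                 float* scaling_factor) {
      float max_abs = 0.f;
      for (const float v : values) max_abs = std::max(max_abs, std::fabs(v));
      // Guard against an all-zero buffer; the kernel above short-circuits this case.
      *scaling_factor = max_abs > 0.f ? max_abs / 127.0f : 1.0f;
      quantized->resize(values.size());
      for (std::size_t i = 0; i < values.size(); ++i) {
        const float scaled = std::round(values[i] / *scaling_factor);
        (*quantized)[i] =
            static_cast<int8_t>(std::min(127.0f, std::max(-127.0f, scaled)));
      }
    }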
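
The hybrid path in EvalPieQuantized then accumulates int8 products and rescales them
with the per-batch input scale, which was pre-multiplied by the filter scale; this is
also why the hybrid test above allows a small absolute error relative to the float
reference. A rough scalar equivalent of that accumulation is sketched below, assuming
row-major weights of shape [num_units, input_size]; the real kernel goes through
tensor_utils::MatrixBatchVectorMultiplyAccumulate and its optimized counterparts.

    #include <cstdint>

    // Sketch only: scaled int8 accumulation as performed by the hybrid kernel.
    // 'scaling_factors[b]' already includes the filter scale, and 'output'
    // has been pre-loaded with the bias.
    void HybridMatmulSketch(const int8_t* weights, const int8_t* input,
                            const float* scaling_factors, int num_units,
                            int input_size, int batch_size, float* output) {
      for (int b = 0; b < batch_size; ++b) {
        for (int u = 0; u < num_units; ++u) {
          int32_t acc = 0;  // Accumulate in int32 to avoid int8 overflow.
          for (int i = 0; i < input_size; ++i) {
            acc += static_cast<int32_t>(weights[u * input_size + i]) *
                   static_cast<int32_t>(input[b * input_size + i]);
          }
          // Rescale to float: output += scale_input_b * scale_filter * (q_w . q_x).
          output[b * num_units + u] += scaling_factors[b] * acc;
        }
      }
    }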