Implements a linear no-offset (aka symmetric) quantizer.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Apr 2018 00:56:08 +0000 (17:56 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 00:58:50 +0000 (17:58 -0700)
PiperOrigin-RevId: 194482547
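
The scheme maps floats to int8 with no zero-point offset: scaling_factor = kScale / max(|min|, |max|) with kScale = 127, and each value quantizes to clamp(round(scaling_factor * x), -127, 127). For example, for inputs spanning [-640, 1000] the range is 1000, so scaling_factor = 127 / 1000 = 0.127 and -640 maps to round(-81.28) = -81.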

tensorflow/contrib/lite/kernels/internal/BUILD
tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
tensorflow/contrib/lite/kernels/internal/tensor_utils.h
tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc

tensorflow/contrib/lite/kernels/internal/BUILD
index dce14cd..c5539af 100644
@@ -289,6 +289,7 @@ cc_library(
         "reference/portable_tensor_utils.h",
     ],
     deps = [
+        ":round",
         "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite/kernels:activation_functor",
         "//tensorflow/contrib/lite/kernels:op_macros",
@@ -310,6 +311,7 @@ cc_library(
     deps = [
         ":cpu_check",
         ":portable_tensor_utils",
+        ":round",
         ":types",
         "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite/kernels:activation_functor",
tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 780401e..47dfcbe 100644
@@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <stdlib.h>
 #include <string.h>
 
 #include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
 
 #ifdef USE_NEON
 
@@ -248,6 +249,83 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit,
   }
 }
 
+void NeonSymmetricQuantizeFloats(const float* values, const int size,
+                                 int8_t* quantized_values, float* min,
+                                 float* max, float* scaling_factor) {
+  // TODO(raziel): vectorize min/max calculation.
+  auto minmax = std::minmax_element(values, values + size);
+  *min = *minmax.first;
+  *max = *minmax.second;
+  const int kScale = 127;
+  const float range = std::max(std::abs(*min), std::abs(*max));
+  if (range == 0) {
+    memset(quantized_values, 0, size * sizeof(int8_t));
+    *scaling_factor = 1;
+    return;
+  }
+  *scaling_factor = kScale / range;
+
+  const int postamble_start =
+      size - (size & (2 * kFloatWeightsPerNeonLane - 1));
+
+  // Vectorized constants.
+  const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor);
+  const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
+  const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
+  const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
+  const int32x4_t neg_scale_i32x4 = vmovq_n_s32(-kScale);
+
+  for (int i = 0; i < postamble_start; i += 2 * kFloatWeightsPerNeonLane) {
+    // Implements the vectorized version of the following:
+    // const int32 quantized_value = static_cast<int32>(
+    //    std::round(*scaling_factor * values[i]));
+    // Since the vectorized round intrinsic (vrndaq_f32) is not supported
+    // on all Neon flavors, we use the following method for rounding:
+    // if (x < 0) (int)(x - 0.5); if (x >= 0) (int)(x + 0.5).
+    float32x4_t value0_f32x4 = vld1q_f32(&values[i]);
+    float32x4_t value1_f32x4 = vld1q_f32(&values[i + kFloatWeightsPerNeonLane]);
+    float32x4_t mul0_f32x4 = vmulq_f32(value0_f32x4, q_factor_f32x4);
+    float32x4_t mul1_f32x4 = vmulq_f32(value1_f32x4, q_factor_f32x4);
+
+    int32x4_t cmp_with_zero0_ui32x4 =
+        (int32x4_t)vcltq_f32(mul0_f32x4, zero_f32x4);  // NOLINT
+    int32x4_t cmp_with_zero1_ui32x4 =
+        (int32x4_t)vcltq_f32(mul1_f32x4, zero_f32x4);  // NOLINT
+
+    float32x4_t cmp_with_zero0_f32x4 = vcvtq_f32_s32(cmp_with_zero0_ui32x4);
+    float32x4_t cmp_with_zero1_f32x4 = vcvtq_f32_s32(cmp_with_zero1_ui32x4);
+    cmp_with_zero0_f32x4 = vaddq_f32(cmp_with_zero0_f32x4, point5_f32x4);
+    cmp_with_zero1_f32x4 = vaddq_f32(cmp_with_zero1_f32x4, point5_f32x4);
+
+    mul0_f32x4 = vaddq_f32(mul0_f32x4, cmp_with_zero0_f32x4);
+    mul1_f32x4 = vaddq_f32(mul1_f32x4, cmp_with_zero1_f32x4);
+
+    int32x4_t f2i0_i32x4 = vcvtq_s32_f32(mul0_f32x4);
+    int32x4_t f2i1_i32x4 = vcvtq_s32_f32(mul1_f32x4);
+
+    // Implements the vectorized version of the following block:
+    //  quantized_values[i] = std::min(kScale, std::max(-kScale,
+    //  quantized_value));
+    int32x4_t max0_i32x4 = vmaxq_s32(f2i0_i32x4, neg_scale_i32x4);
+    int32x4_t max1_i32x4 = vmaxq_s32(f2i1_i32x4, neg_scale_i32x4);
+    int32x4_t min0_i32x4 = vminq_s32(max0_i32x4, scale_i32x4);
+    int32x4_t min1_i32x4 = vminq_s32(max1_i32x4, scale_i32x4);
+
+    int16x4_t min0_16x4 = vmovn_s32(min0_i32x4);
+    int16x4_t min1_16x4 = vmovn_s32(min1_i32x4);
+
+    int16x8_t min_16x8 = vcombine_s16(min0_16x4, min1_16x4);
+    int8x8_t min_s8x8 = vqmovn_s16(min_16x8);
+    vst1_s8(&quantized_values[i], min_s8x8);
+  }
+
+  for (int i = postamble_start; i < size; ++i) {
+    const int32 quantized_value =
+        static_cast<int32>(TfLiteRound(*scaling_factor * values[i]));
+    quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
+  }
+}
+
 float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
                                  int v_size) {
   // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
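
The ±0.5 rounding trick used above can be checked in scalar form. Below is a minimal standalone sketch (the helper name RoundHalfAwayFromZero is illustrative, not part of this patch): it mirrors the Neon sequence by building a -0.5/+0.5 bias from the sign comparison, adding it, and letting the float-to-int conversion truncate toward zero, which reproduces round-half-away-from-zero.

#include <cassert>

// Scalar model of the Neon rounding above: the comparison mask (-1 for
// x < 0, 0 otherwise) plus 0.5 yields a bias of -0.5 or +0.5, and the
// subsequent float-to-int conversion truncates toward zero.
int RoundHalfAwayFromZero(float x) {
  const float bias = (x < 0.0f) ? -0.5f : 0.5f;
  return static_cast<int>(x + bias);
}

int main() {
  assert(RoundHalfAwayFromZero(2.5f) == 3);
  assert(RoundHalfAwayFromZero(-2.5f) == -3);
  assert(RoundHalfAwayFromZero(-81.28f) == -81);  // -640 * 0.127
  return 0;
}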
tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index b7e317d..3b6f4bd 100644
@@ -97,6 +97,13 @@ void ClipVector(const float* vector, int v_size, float abs_limit,
   NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
 }
 
+void SymmetricQuantizeFloats(const float* values, const int size,
+                             int8_t* quantized_values, float* min, float* max,
+                             float* scaling_factor) {
+  NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min,
+                   max, scaling_factor);
+}
+
 void VectorShiftLeft(float* vector, int v_size, float shift_value) {
   NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value);
 }
tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index ff15f3e..1922047 100644
@@ -117,6 +117,14 @@ void PortableZeroVector(float* vector, int v_size);
 // Limit a float input f between +abs_limit and -abs_limit.
 float PortableClip(float f, float abs_limit);
 
+// Symmetric quantizer.
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values, float* min,
+                                     float* max, float* scaling_factor);
+void NeonSymmetricQuantizeFloats(const float* values, const int size,
+                                 int8_t* quantized_values, float* min,
+                                 float* max, float* scaling_factor);
+
 // Shift left a vector in place with v_size size.
 void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
 void NeonVectorShiftLeft(float* vector, int v_size, float shift_value);
tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index c5b0bcc..5e7586e 100644
@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <stdlib.h>
 #include <string.h>
 
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
 namespace tflite {
@@ -27,6 +29,28 @@ float PortableClip(float f, float abs_limit) {
   return result;
 }
 
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values, float* min,
+                                     float* max, float* scaling_factor) {
+  auto minmax = std::minmax_element(values, values + size);
+  *min = *minmax.first;
+  *max = *minmax.second;
+  const int kScale = 127;
+  const float range = std::max(std::abs(*min), std::abs(*max));
+  if (range == 0) {
+    memset(quantized_values, 0, size * sizeof(int8_t));
+    *scaling_factor = 1;
+    return;
+  }
+  *scaling_factor = kScale / range;
+  for (int i = 0; i < size; ++i) {
+    const int32_t quantized_value =
+        static_cast<int32_t>(TfLiteRound(*scaling_factor * values[i]));
+    // Clamp in case rounding produces a value outside [-kScale, kScale].
+    quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
+  }
+}
+
 void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
                                                  int m_rows, int m_cols,
                                                  const float* vector,
tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index c05c21b..478cda8 100644
@@ -25,6 +25,10 @@ namespace tensor_utils {
 // Limit a float input f between +abs_limit and -abs_limit.
 float PortableClip(float f, float abs_limit);
 
+void PortableSymmetricQuantizeFloats(const float* values, const int size,
+                                     int8_t* quantized_values, float* min,
+                                     float* max, float* scaling_factor);
+
 // Multiply a matrix by a batch vector, and store results in a batch-size
 // vector.
 void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix,
@@ -103,6 +107,13 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
 
 float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
 
+void SymmetricQuantizeFloats(const float* values, const int size,
+                             int8_t* quantized_values, float* min, float* max,
+                             float* scaling_factor) {
+  return PortableSymmetricQuantizeFloats(values, size, quantized_values, min,
+                                         max, scaling_factor);
+}
+
 void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                          int m_cols, const float* vector,
                                          int n_batch, float* result,
tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 40d1449..997dc44 100644
@@ -23,6 +23,14 @@ namespace tensor_utils {
 // Limit a float input f between +abs_limit and -abs_limit.
 float Clip(float f, float abs_limit);
 
+// Quantizes a buffer of floating point values to 8-bit signed integers using
+// symmetric quantization (i.e. linear quantization without an offset).
+// It also outputs the range (min, max) of the floating point buffer, and the
+// scaling factor used to quantize the values.
+void SymmetricQuantizeFloats(const float* values, const int size,
+                             int8_t* quantized_values, float* min, float* max,
+                             float* scaling_factor);
+
 // Multiply a matrix by a batch vector, and store results in a batch-size
 // vector using a stride value provided in result_stride. 'result_stride' is
 // the number of elements between consecutive result values. For example
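
A minimal usage sketch of the new public entry point (assumptions: the tflite::tensor_utils namespace and header path are taken from this patch, and the input values come from the tests in tensor_utils_test.cc below):

#include <cstdint>
#include <cstdio>

#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"

int main() {
  const float input[4] = {-640.0f, 10.0f, 0.0f, 1000.0f};
  int8_t quantized[4];
  float min, max, scaling_factor;
  // range = max(|-640|, |1000|) = 1000, so scaling_factor = 127 / 1000 = 0.127.
  tflite::tensor_utils::SymmetricQuantizeFloats(input, 4, quantized, &min,
                                                &max, &scaling_factor);
  // Dequantize by dividing by the scaling factor: -640 -> -81 -> ~-637.8.
  for (int i = 0; i < 4; ++i) {
    printf("%f -> %d -> %f\n", input[i], quantized[i],
           quantized[i] / scaling_factor);
  }
  return 0;
}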
tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 588f1a4..22b0167 100644
@@ -32,6 +32,55 @@ TEST(uKernels, ClipTest) {
                   {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0})));
 }
 
+TEST(uKernels, SymmetricQuantizeFloatsTest) {
+  constexpr int kVectorSize = 9;
+  static float input[kVectorSize] = {-640, -635.0, -630, 10.0,  2.0,
+                                     -5.0, -10.0,  0.0,  1000.0};
+
+  int8 output[kVectorSize];
+  float min, max, scaling_factor;
+  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
+                          &scaling_factor);
+
+  EXPECT_EQ(min, -640);
+  EXPECT_EQ(max, 1000);
+  EXPECT_NEAR(scaling_factor, 0.127, 1e-6);  // EQ won't work due to floating point.
+  EXPECT_THAT(output,
+              testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
+}
+
+TEST(uKernels, SymmetricQuantizeFloatsAllZerosTest) {
+  constexpr int kVectorSize = 9;
+  static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+  int8 output[kVectorSize];
+  float min, max, scaling_factor;
+  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
+                          &scaling_factor);
+
+  EXPECT_EQ(min, 0);
+  EXPECT_EQ(max, 0);
+  EXPECT_EQ(scaling_factor, 1);
+  EXPECT_THAT(output, testing::ElementsAreArray({0, 0, 0, 0, 0, 0, 0, 0, 0}));
+}
+
+TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
+  constexpr int kVectorSize = 9;
+  static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6,
+                                     4e-5,  9e-6, 2e-4,  0};
+
+  int8 output[kVectorSize];
+  float min, max, scaling_factor;
+  SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max,
+                          &scaling_factor);
+
+  EXPECT_NEAR(min, -9e-05, 1e-6);
+  EXPECT_NEAR(max, 0.0002, 1e-6);
+  EXPECT_EQ(scaling_factor, 635000);
+  EXPECT_THAT(output,
+              testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
+}
+
 TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) {
   constexpr int kRow = 3;
   constexpr int kCol = 4;
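
For reference, the expected values in SymmetricQuantizeFloatsAllAlmostZeroTest follow directly from the scheme: range = max(|-9e-5|, |2e-4|) = 2e-4, so scaling_factor = 127 / 2e-4 = 635000, and e.g. -9e-5 * 635000 = -57.15 rounds to -57, matching the expected array.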