Adding support for a standalone Tanh operator.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Feb 2018 17:01:44 +0000 (09:01 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Feb 2018 17:05:53 +0000 (09:05 -0800)
PiperOrigin-RevId: 184845130
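
This change replaces the scalar-only quantized uint8 Tanh kernel in
optimized_ops.h with a NEON-vectorized implementation that processes 16
values per iteration, falling back to a scalar leftover loop. It also
replaces the provisional output zero offset of 127 (and its -1 underflow
workaround) with a proper output zero point of 128 in both the optimized
and reference kernels, and updates toco's hardcoded quantization
parameters for Tanh to match.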

tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/toco/graph_transformations/quantize.cc

diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 5a8df67..cd52385 100644
@@ -3102,51 +3102,152 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
                  int32 input_zero_point, int32 input_range_radius,
                  int32 input_multiplier, int input_left_shift,
                  uint8* output_data, const Dims<4>& output_dims) {
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
-  const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
-  const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-  for (int b = 0; b < batches; ++b) {
-    for (int y = 0; y < height; ++y) {
-      for (int x = 0; x < width; ++x) {
-        for (int c = 0; c < depth; ++c) {
-          const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)];
-          const int32 input_val_centered =
-              static_cast<int32>(input_val_u8) - input_zero_point;
-          uint8 output_val;
-          if (input_val_centered <= -input_range_radius) {
-            output_val = 0;
-          } else if (input_val_centered >= input_range_radius) {
-            output_val = 255;
-          } else {
-            const int32 input_val_rescaled =
-                MultiplyByQuantizedMultiplierGreaterThanOne(
-                    input_val_centered, input_multiplier, input_left_shift);
-            using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
-            using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-            const FixedPoint4 input_val_f4 =
-                FixedPoint4::FromRaw(input_val_rescaled);
-            const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
-
-            using gemmlowp::RoundingDivideByPOT;
-            int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
-            // TODO(mjmatthews): properly wire through this zero offset
-            output_val_s32 += 127;
-            if (output_val_s32 == -1) {
-              // May underflow since we cannot properly represent -1.0f
-              output_val_s32 = 0;
-            }
-            TFLITE_DCHECK_GE(output_val_s32, 0);
-            TFLITE_DCHECK_LE(output_val_s32, 255);
-            output_val = static_cast<uint8>(output_val_s32);
-          }
-          output_data[Offset(output_dims, c, x, y, b)] = output_val;
-        }
+  // Note that this is almost the exact same code as in Logistic().
+  gemmlowp::ScopedProfilingLabel label("Tanh");
+  /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
+  /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
+  /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
+  /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
+  const int size = RequiredBufferSizeForDims(input_dims);
+
+  int c = 0;
+  int32_t output_zero_point = 128;
+#ifdef USE_NEON
+  // Handle 16 values at a time
+  for (; c <= size - 16; c += 16) {
+    // Read input uint8 values, cast to int16 and subtract input_zero_point
+    uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
+    int16x8_t input_val_centered_0 =
+        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
+                  vdupq_n_s16(input_zero_point));
+    int16x8_t input_val_centered_1 =
+        vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
+                  vdupq_n_s16(input_zero_point));
+
+    // Prepare the bit masks that we will use at the end to implement the logic
+    // that was expressed in the scalar code with branching:
+    //   if (input_val_centered < -input_range_radius) {
+    //     output_val = 0;
+    //   } else if (input_val_centered > input_range_radius) {
+    //     output_val = 255;
+    //   } else {
+    //     ...
+    uint16x8_t mask_rightclamp_0 =
+        vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
+    uint16x8_t mask_rightclamp_1 =
+        vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
+    uint16x8_t mask_leftclamp_0 =
+        vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
+    uint16x8_t mask_leftclamp_1 =
+        vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
+    uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
+                                             vshrn_n_u16(mask_rightclamp_1, 8));
+    uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
+                                            vshrn_n_u16(mask_leftclamp_1, 8));
+
+    // This performs what is expressed in the scalar code as
+    // const int32 input_val_rescaled =
+    //     MultiplyByQuantizedMultiplierGreaterThanOne(
+    //         input_val_centered, input_multiplier, input_left_shift);
+    int32x4_t input_val_rescaled_0 =
+        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
+                  vdupq_n_s32(input_left_shift));
+    int32x4_t input_val_rescaled_1 =
+        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
+                  vdupq_n_s32(input_left_shift));
+    int32x4_t input_val_rescaled_2 =
+        vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
+                  vdupq_n_s32(input_left_shift));
+    int32x4_t input_val_rescaled_3 =
+        vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
+                  vdupq_n_s32(input_left_shift));
+    input_val_rescaled_0 =
+        vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
+    input_val_rescaled_1 =
+        vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
+    input_val_rescaled_2 =
+        vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
+    input_val_rescaled_3 =
+        vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
+
+    // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
+    using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
+    using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
+    const FixedPoint4 input_val_f4_0 =
+        FixedPoint4::FromRaw(input_val_rescaled_0);
+    const FixedPoint4 input_val_f4_1 =
+        FixedPoint4::FromRaw(input_val_rescaled_1);
+    const FixedPoint4 input_val_f4_2 =
+        FixedPoint4::FromRaw(input_val_rescaled_2);
+    const FixedPoint4 input_val_f4_3 =
+        FixedPoint4::FromRaw(input_val_rescaled_3);
+    const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
+    const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
+    const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
+    const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
+
+    // Divide by 2^24 as in the scalar code
+    using gemmlowp::RoundingDivideByPOT;
+    int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
+    int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
+    int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
+    int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
+
+    // Add the output zero point
+    int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
+    output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
+    output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
+    output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
+    output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
+
+    // Cast output values to uint8, saturating
+    int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
+                                              vqmovn_s32(output_val_s32_1));
+    int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
+                                              vqmovn_s32(output_val_s32_3));
+    uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
+                                           vqmovun_s16(output_val_s16_1));
+
+    // Perform the bit-masking with the bit masks computed at the beginning,
+    // see the comment there.
+    output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
+    output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
+
+    // Store back to memory
+    vst1q_u8(output_data + c, output_val_u8);
+  }
+#endif
+  // Leftover loop: handle one value at a time with scalar code.
+  for (; c < size; ++c) {
+    const uint8 input_val_u8 = input_data[c];
+    const int32 input_val_centered =
+        static_cast<int32>(input_val_u8) - input_zero_point;
+    uint8 output_val;
+    if (input_val_centered < -input_range_radius) {
+      output_val = 0;
+    } else if (input_val_centered > input_range_radius) {
+      output_val = 255;
+    } else {
+      const int32 input_val_rescaled =
+          MultiplyByQuantizedMultiplierGreaterThanOne(
+              input_val_centered, input_multiplier, input_left_shift);
+      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
+      const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
+      using gemmlowp::RoundingDivideByPOT;
+      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
+      output_val_s32 += output_zero_point;
+      if (output_val_s32 == 256) {
+        output_val_s32 = 255;
       }
+      TFLITE_DCHECK_GE(output_val_s32, 0);
+      TFLITE_DCHECK_LE(output_val_s32, 255);
+      output_val = static_cast<uint8>(output_val_s32);
     }
+    output_data[c] = output_val;
   }
 }
-
 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
                        int32 zero_point, double scale, float* output_data,
                        const Dims<4>& output_dims) {
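
The NEON path above replaces the scalar branches with bit masks:
vcgtq_s16/vcgeq_s16 set each 16-bit lane to all ones or all zeros, the
vshrn_n_u16 narrowing turns those into 0xFF/0x00 bytes, and the final
vorrq_u8/vandq_u8 pair forces saturated lanes to 255 and below-range
lanes to 0. A minimal scalar sketch of that mask trick (MaskedClamp is a
hypothetical helper name, not part of this patch):

#include <cstdint>

// Branchless form of the three-way clamp in the leftover loop: each
// comparison yields an all-ones (0xFF) or all-zeros (0x00) mask, so the
// OR forces 255 wherever centered > range_radius and the AND forces 0
// wherever centered < -range_radius, mirroring vorrq_u8/vandq_u8 above.
inline uint8_t MaskedClamp(uint8_t tanh_val, int32_t centered,
                           int32_t range_radius) {
  const uint8_t mask_rightclamp = centered > range_radius ? 0xFF : 0x00;
  const uint8_t mask_leftclamp = centered >= -range_radius ? 0xFF : 0x00;
  return (tanh_val | mask_rightclamp) & mask_leftclamp;
}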
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index c5f09af..f18543f 100644
@@ -2279,6 +2279,7 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
                  int32 input_zero_point, int32 input_range_radius,
                  int32 input_multiplier, int input_left_shift,
                  uint8* output_data, const Dims<4>& output_dims) {
+  const int32 output_zero_point = 128;
   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
@@ -2307,11 +2308,9 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
 
             using gemmlowp::RoundingDivideByPOT;
             int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
-            // TODO(mjmatthews): properly wire through this zero offset
-            output_val_s32 += 127;
-            if (output_val_s32 == -1) {
-              // May underflow since we cannot properly represent -1.0f
-              output_val_s32 = 0;
+            output_val_s32 += output_zero_point;
+            if (output_val_s32 == 256) {
+              output_val_s32 = 255;
             }
             TFLITE_DCHECK_GE(output_val_s32, 0);
             TFLITE_DCHECK_LE(output_val_s32, 255);
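
The 256 -> 255 clamp introduced above exists because gemmlowp::tanh
returns Q0.31 fixed point, so the raw output spans [-2^31, 2^31) for tanh
in [-1, 1); the rounding shift by 24 then lands in [-128, 128], and adding
the zero point of 128 can yield 256, one past the top of uint8. A
self-contained worked check of that arithmetic (the helper below mirrors
the scalar case of gemmlowp::RoundingDivideByPOT and is written here only
for illustration):

#include <cstdint>

// Rounding right shift by 24: round to nearest, matching the scalar
// behavior of gemmlowp::RoundingDivideByPOT(x, 24).
inline int32_t RoundingShift24(int32_t x) {
  const int32_t mask = (1 << 24) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> 24) + (remainder > threshold ? 1 : 0);
}

// Saturated positive output: raw = 2^31 - 1 (just under +1.0 in Q0.31)
// gives RoundingShift24(0x7FFFFFFF) == 128, and 128 + 128 == 256, hence
// the clamp to 255. Saturated negative output: raw = -2^31 gives -128,
// and -128 + 128 == 0, so the old -1 underflow case no longer occurs.
// Exact zero: raw = 0 gives 0 + 128 == 128, the output zero point.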
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 3a674f8..d7f804e 100644
@@ -294,7 +294,7 @@ bool ChooseHardcodedQuantizationForOperatorOutput(
   if (op.type == OperatorType::kTanh) {
     // Tanh has the range: [-1, 1].
     *quantized_data_type = ArrayDataType::kUint8;
-    quantization_params->zero_point = 127;
+    quantization_params->zero_point = 128;
     quantization_params->scale = 1. / 128.;
     // 0 should be exactly representable, as values will typically be centered
     // around 0, with many values near 0.
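
With these parameters, dequantization is real = scale * (q - zero_point)
= (q - 128) / 128, which lines the uint8 range up with tanh's output
range. A tiny standalone check (illustrative only):

#include <cstdio>

int main() {
  const double scale = 1.0 / 128.0;  // as hardcoded in quantize.cc above
  const int zero_point = 128;
  for (int q : {0, 128, 255}) {
    std::printf("q = %3d -> %+.7f\n", q, scale * (q - zero_point));
  }
  // q =   0 -> -1.0000000  (tanh's lower limit, exactly representable)
  // q = 128 -> +0.0000000  (exact zero, where tanh outputs cluster)
  // q = 255 -> +0.9921875  (one step below tanh's upper limit of 1.0)
  return 0;
}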