int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
uint8* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
- const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)];
- const int32 input_val_centered =
- static_cast<int32>(input_val_u8) - input_zero_point;
- uint8 output_val;
- if (input_val_centered <= -input_range_radius) {
- output_val = 0;
- } else if (input_val_centered >= input_range_radius) {
- output_val = 255;
- } else {
- const int32 input_val_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_val_centered, input_multiplier, input_left_shift);
- using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
- using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const FixedPoint4 input_val_f4 =
- FixedPoint4::FromRaw(input_val_rescaled);
- const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
-
- using gemmlowp::RoundingDivideByPOT;
- int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
- // TODO(mjmatthews): properly wire through this zero offset
- output_val_s32 += 127;
- if (output_val_s32 == -1) {
- // May underflow since we cannot properly represent -1.0f
- output_val_s32 = 0;
- }
- TFLITE_DCHECK_GE(output_val_s32, 0);
- TFLITE_DCHECK_LE(output_val_s32, 255);
- output_val = static_cast<uint8>(output_val_s32);
- }
- output_data[Offset(output_dims, c, x, y, b)] = output_val;
- }
+ // Note that this is almost the exact same code as in Logistic().
+ gemmlowp::ScopedProfilingLabel label("Tanh");
+ /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
+ /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
+ /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
+ /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int size = RequiredBufferSizeForDims(input_dims);
+
+ int c = 0;
+ int32_t output_zero_point = 128;
+#ifdef USE_NEON
+ // Handle 16 values at a time
+ for (; c <= size - 16; c += 16) {
+ // Read input uint8 values, cast to int16 and subtract input_zero_point
+ uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
+ int16x8_t input_val_centered_0 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+ int16x8_t input_val_centered_1 =
+ vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
+ vdupq_n_s16(input_zero_point));
+
+ // Prepare the bit masks that we will use at the end to implement the logic
+ // that was expressed in the scalar code with branching:
+ // if (input_val_centered < -input_range_radius) {
+ // output_val = 0;
+ // } else if (input_val_centered > input_range_radius) {
+ // output_val = 255;
+ // } else {
+ // ...
+ uint16x8_t mask_rightclamp_0 =
+ vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_rightclamp_1 =
+ vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
+ uint16x8_t mask_leftclamp_0 =
+ vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
+ uint16x8_t mask_leftclamp_1 =
+ vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
+ uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
+ vshrn_n_u16(mask_rightclamp_1, 8));
+ uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
+ vshrn_n_u16(mask_leftclamp_1, 8));
+
+ // This performs what is expressed in the scalar code as
+ // const int32 input_val_rescaled =
+ // MultiplyByQuantizedMultiplierGreaterThanOne(
+ // input_val_centered, input_multiplier, input_left_shift);
+ int32x4_t input_val_rescaled_0 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_1 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_2 =
+ vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ int32x4_t input_val_rescaled_3 =
+ vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
+ vdupq_n_s32(input_left_shift));
+ input_val_rescaled_0 =
+ vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
+ input_val_rescaled_1 =
+ vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
+ input_val_rescaled_2 =
+ vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
+ input_val_rescaled_3 =
+ vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
+
+ // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
+ using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
+ const FixedPoint4 input_val_f4_0 =
+ FixedPoint4::FromRaw(input_val_rescaled_0);
+ const FixedPoint4 input_val_f4_1 =
+ FixedPoint4::FromRaw(input_val_rescaled_1);
+ const FixedPoint4 input_val_f4_2 =
+ FixedPoint4::FromRaw(input_val_rescaled_2);
+ const FixedPoint4 input_val_f4_3 =
+ FixedPoint4::FromRaw(input_val_rescaled_3);
+ const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
+ const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
+ const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
+ const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
+
+ // Divide by 2^24 as in the scalar code
+ using gemmlowp::RoundingDivideByPOT;
+ int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
+ int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
+ int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
+ int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
+
+ // Add the output zero point
+ int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
+ output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
+ output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
+ output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
+ output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
+
+ // Cast output values to uint8, saturating
+ int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
+ vqmovn_s32(output_val_s32_1));
+ int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
+ vqmovn_s32(output_val_s32_3));
+ uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
+ vqmovun_s16(output_val_s16_1));
+
+    // Perform the bit-masking with the bit masks computed at the beginning;
+    // see the comment there.
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
+ output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
+
+ // Store back to memory
+ vst1q_u8(output_data + c, output_val_u8);
+ }
+#endif
+ // Leftover loop: handle one value at a time with scalar code.
+ for (; c < size; ++c) {
+ const uint8 input_val_u8 = input_data[c];
+ const int32 input_val_centered =
+ static_cast<int32>(input_val_u8) - input_zero_point;
+ uint8 output_val;
+ if (input_val_centered < -input_range_radius) {
+ output_val = 0;
+ } else if (input_val_centered > input_range_radius) {
+ output_val = 255;
+ } else {
+ const int32 input_val_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_val_centered, input_multiplier, input_left_shift);
+ using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+ const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
+ const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
+ using gemmlowp::RoundingDivideByPOT;
+ int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
+ output_val_s32 += output_zero_point;
+ if (output_val_s32 == 256) {
+ output_val_s32 = 255;
}
+ TFLITE_DCHECK_GE(output_val_s32, 0);
+ TFLITE_DCHECK_LE(output_val_s32, 255);
+ output_val = static_cast<uint8>(output_val_s32);
}
+ output_data[c] = output_val;
}
}
-
inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
int32 zero_point, double scale, float* output_data,
const Dims<4>& output_dims) {