[QNN] Requantize - Optimize lowering for some corner cases. (#3864)
authorAnimesh Jain <anijain@umich.edu>
Mon, 2 Sep 2019 02:58:38 +0000 (20:58 -0600)
committerWuwei Lin <wuwei@apache.org>
Mon, 2 Sep 2019 02:58:38 +0000 (22:58 -0400)
src/relay/qnn/op/requantize.cc
tests/python/relay/test_qnn_requantize.py

index 448395a..cf5f316 100644 (file)
@@ -129,48 +129,55 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
     tensor = Subtract(tensor, input_zp);
   }
 
-  // 3) Multiply the integer multiplier
-  if (left_shift != 0) {
-    tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
-  }
-  // Perform the multiplication in higher precision.
-  // The scalar is a fixed point value of int32 where the decimal point is
-  // between bits 31 and 30. After multiplying with input_tensor, the result is
-  // in int64 where the decimal point is sitting between bits 31 and 30 (from
-  // the right, rightmost bit is bit 0). The computation is performed in higher
-  // precision to avoid overflow in multiplying two int32 values.
-  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
-  auto multiplied_t = Multiply(tensor, scalar);
+  // If the input and output scales are same, we can skip the fixed point multiplication.
+  auto scaled_int64_t = tensor;
+  if (param->input_scale != param->output_scale) {
+    // 3) Multiply the integer multiplier
+    if (left_shift != 0) {
+      tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
+    }
+    // Perform the multiplication in higher precision.
+    // The scalar is a fixed point value of int32 where the decimal point is
+    // between bits 31 and 30. After multiplying with input_tensor, the result is
+    // in int64 where the decimal point is sitting between bits 31 and 30 (from
+    // the right, rightmost bit is bit 0). The computation is performed in higher
+    // precision to avoid overflow in multiplying two int32 values.
+    Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+    auto multiplied_t = Multiply(tensor, scalar);
 
-  // 4) Find the rounding scalar. This depends on where the final decimal point
-  // sits. As we will be right shifting the multiplied_t, we need to first
-  // calculate the total_right_shift.
-  int total_right_shift = right_shift + 31;
-  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+    // 4) Find the rounding scalar. This depends on where the final decimal point
+    // sits. As we will be right shifting the multiplied_t, we need to first
+    // calculate the total_right_shift.
+    int total_right_shift = right_shift + 31;
+    int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
 
-  tensor = multiplied_t;
-  Expr round_scalar;
-  if (param->rounding == "UPWARD") {
-    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
-  } else if (param->rounding == "TONEAREST") {
-    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
-    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+    tensor = multiplied_t;
+    Expr round_scalar;
+    if (param->rounding == "UPWARD") {
+      round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    } else if (param->rounding == "TONEAREST") {
+      auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+      auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+      auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+      auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
 
-    auto zero = MakeConstantScalar(hp_dtype, 0);
-    auto zero_t = Full(zero, input_shape, hp_dtype);
-    round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
-  }
-  // Add the rounding scalar.
-  tensor = Add(tensor, round_scalar);
+      auto zero = MakeConstantScalar(hp_dtype, 0);
+      auto zero_t = Full(zero, input_shape, hp_dtype);
+      round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+    }
+    // Add the rounding scalar.
+    tensor = Add(tensor, round_scalar);
 
-  // 5) Simply right shift the result to get the final output.
-  auto scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+    // 5) Simply right shift the result to get the final output.
+    scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+  }
 
   // 6) Add the output zero point.
-  auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
-  auto shifted_int64_t = Add(output_zp, scaled_int64_t);
+  auto shifted_int64_t = scaled_int64_t;
+  if (param->output_zero_point != 0) {
+    auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
+    shifted_int64_t = Add(output_zp, scaled_int64_t);
+  }
 
   // 7) Clip to the out_dtype min/max.
   auto q_min = GetQmin(out_dtype);
index 2afa7d9..131500c 100644 (file)
@@ -64,6 +64,7 @@ def test_requantize():
                           input_scale=0.5,
                           output_scale=0.5,
                           rounding=rounding)
+            assert 'right_shift' not in mod.astext()
             verify(mod, (golden_data, golden_output))
 
     def downscale_test():