[QNN] Refactor fixed point multiplication in requantize (#4073)
authorWuwei Lin <vincentl13x@gmail.com>
Tue, 8 Oct 2019 17:17:56 +0000 (13:17 -0400)
committerZhi <5145158+zhiics@users.noreply.github.com>
Tue, 8 Oct 2019 17:17:56 +0000 (10:17 -0700)
src/relay/pass/pattern_util.h
src/relay/qnn/op/requantize.cc
src/relay/qnn/util.cc [new file with mode: 0644]
src/relay/qnn/util.h

index 988b13c..525008b 100644 (file)
@@ -336,6 +336,14 @@ inline Expr ZerosLike(Expr e) {
   return CallNode::make(op, {e});
 }
 
+inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
+  auto attrs = make_node<InitOpAttrs>();
+  attrs->shape = std::move(shape);
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("zeros");
+  return CallNode::make(op, {}, Attrs(attrs), {});
+}
+
 inline Expr OnesLike(Expr e) {
   static const Op& op = Op::Get("ones_like");
   return CallNode::make(op, {e});
index cf5f316..85d8dc3 100644 (file)
@@ -37,50 +37,7 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
 
 // Lowering of qnn.requantize op
 
-/*
- * \brief Convert FP32 representation into fixed point representation.
- * \param double_multplier The input FP32 number.
- * \return The pair of multiplier and shift for fixed point representation.
- * \note Converts a floating point number so that it can be represented by
- *       integers. The representation is
- *             float_number = (significand) * 2^(exponent)
- *
- *       The significand is a number between 0.5 and 1. This is represented by
- *       an integer number. For example, if it is int32, then the decimal point
- *       exists between bit 31 and 30 from LSB (or between first and second bit
- *       from the left).
- *
- *       Some examples are
- *           0.25 = (0.5) * 2^(-1)
- *           0.125 = (0.5) * 2^(-2)
- *
- *       Credit to TFLite reference implementation.
- */
-std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
-  int32_t significand, exponent;
-  if (double_multiplier == 0.) {
-    significand = 0;
-    exponent = 0;
-    return std::make_pair(significand, exponent);
-  }
 
-  // Get the significand and exponent.
-  double significand_d = std::frexp(double_multiplier, &exponent);
-
-  // Convert the double significand to int significand, i.e., convert into a
-  // integer where the decimal point is between bit 31 and 30. This is done by
-  // multiplying the double value with 2^31 and then casting to int.
-  significand_d = std::round(significand_d * (1ll << 31));
-  auto significand_int64 = static_cast<int64_t>(significand_d);
-  CHECK_LE(significand_int64, (1ll << 31));
-  if (significand_int64 == (1ll << 31)) {
-    significand_int64 /= 2;
-    ++exponent;
-  }
-  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
-  significand = static_cast<int32_t>(significand_int64);
-  return std::make_pair(significand, exponent);
-}
 
 /*
  * \brief Lower requantize to a sequence of ops.
@@ -93,93 +50,41 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
  *       and shift. This is useful, if the target device does not support/have
  *       very expensive floating point computations.
  *
- *       Original compuation is scale_fp32 * quantized_tensor.  To convert into
- *       integer computation, the multiplication with fp32 scalar can be
- *       replaced by multiplication with an int value and then right shifting
- *       the result. This approximates the floating point computation with a
- *       fixed point computation.
- *
  *       The whole computation this can be broken down into following steps
  *       1) Calculate the integer multiplier and integer shift.
  *       2) Subtract the input integer zero point.
- *       3) Multiply the fixed point multiplier with quantized tensor.
- *       4) Round the result.
- *       5) Right shift the result.
- *       6) Add the output zero point.
- *       7) Cast to the out_dtype.
+ *       3) Perform fixed point multiplication.
+ *       4) Add the output zero point.
+ *       5) Cast to the out_dtype.
  */
 Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
                      const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   double double_multiplier = param->input_scale / param->output_scale;
 
-  // Choose high precision datatype to be int64. This is for avoiding overflow
-  // in multiplication of two int32 values.
   DataType hp_dtype = Int(64);
 
-  // 1) Calculating the integer multiplier and integer shift
-  int32_t fixed_point_multiplier, shift;
-  std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
-  int left_shift = shift > 0 ? shift : 0;
-  int right_shift = shift > 0 ? 0 : -shift;
-
-  // 2) Subtract the input_zero_point
   auto tensor = Cast(input_tensor, hp_dtype);
+  // 1) Subtract the input_zero_point
   if (param->input_zero_point != 0) {
     auto input_zp = MakeConstantScalar(hp_dtype, param->input_zero_point);
     tensor = Subtract(tensor, input_zp);
   }
 
-  // If the input and output scales are same, we can skip the fixed point multiplication.
+  // 2) If the input and output scales are same, we can skip the fixed point multiplication.
   auto scaled_int64_t = tensor;
   if (param->input_scale != param->output_scale) {
-    // 3) Multiply the integer multiplier
-    if (left_shift != 0) {
-      tensor = Multiply(tensor, MakeConstantScalar(hp_dtype, 1 << left_shift));
-    }
-    // Perform the multiplication in higher precision.
-    // The scalar is a fixed point value of int32 where the decimal point is
-    // between bits 31 and 30. After multiplying with input_tensor, the result is
-    // in int64 where the decimal point is sitting between bits 31 and 30 (from
-    // the right, rightmost bit is bit 0). The computation is performed in higher
-    // precision to avoid overflow in multiplying two int32 values.
-    Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
-    auto multiplied_t = Multiply(tensor, scalar);
-
-    // 4) Find the rounding scalar. This depends on where the final decimal point
-    // sits. As we will be right shifting the multiplied_t, we need to first
-    // calculate the total_right_shift.
-    int total_right_shift = right_shift + 31;
-    int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
-
-    tensor = multiplied_t;
-    Expr round_scalar;
-    if (param->rounding == "UPWARD") {
-      round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
-    } else if (param->rounding == "TONEAREST") {
-      auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
-      auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
-      auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
-      auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
-
-      auto zero = MakeConstantScalar(hp_dtype, 0);
-      auto zero_t = Full(zero, input_shape, hp_dtype);
-      round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
-    }
-    // Add the rounding scalar.
-    tensor = Add(tensor, round_scalar);
-
-    // 5) Simply right shift the result to get the final output.
-    scaled_int64_t = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+    scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
+                                       param->rounding);
   }
 
-  // 6) Add the output zero point.
+  // 3) Add the output zero point.
   auto shifted_int64_t = scaled_int64_t;
   if (param->output_zero_point != 0) {
     auto output_zp = MakeConstantScalar(hp_dtype, param->output_zero_point);
     shifted_int64_t = Add(output_zp, scaled_int64_t);
   }
 
-  // 7) Clip to the out_dtype min/max.
+  // 4) Clip to the out_dtype min/max.
   auto q_min = GetQmin(out_dtype);
   auto q_max = GetQmax(out_dtype);
   auto clipped_t = Clip(shifted_int64_t, q_min, q_max);
diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc
new file mode 100644 (file)
index 0000000..d9e4506
--- /dev/null
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file src/relay/qnn/util.cc
+ * \brief Utility functions for QNN.
+ */
+
+#include "util.h"
+#include "../pass/pattern_util.h"
+
+namespace tvm {
+namespace relay {
+namespace qnn {
+
+/*
+ * \brief Convert FP32 representation into fixed point representation.
+ * \param double_multplier The input FP32 number.
+ * \return The pair of multiplier and shift for fixed point representation.
+ * \note Converts a floating point number so that it can be represented by
+ *       integers. The representation is
+ *             float_number = (significand) * 2^(exponent)
+ *
+ *       The significand is a number between 0.5 and 1. This is represented by
+ *       an integer number. For example, if it is int32, then the decimal point
+ *       exists between bit 31 and 30 from LSB (or between first and second bit
+ *       from the left).
+ *
+ *       Some examples are
+ *           0.25 = (0.5) * 2^(-1)
+ *           0.125 = (0.5) * 2^(-2)
+ *
+ *       Credit to TFLite reference implementation.
+ */
+std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
+    double double_multiplier) {
+  int32_t significand, exponent;
+  if (double_multiplier == 0.) {
+    significand = 0;
+    exponent = 0;
+    return std::make_pair(significand, exponent);
+  }
+
+  // Get the significand and exponent.
+  double significand_d = std::frexp(double_multiplier, &exponent);
+
+  // Convert the double significand to int significand, i.e., convert into a
+  // integer where the decimal point is between bit 31 and 30. This is done by
+  // multiplying the double value with 2^31 and then casting to int.
+  significand_d = std::round(significand_d * (1ll << 31));
+  auto significand_int64 = static_cast<int64_t>(significand_d);
+  CHECK_LE(significand_int64, (1ll << 31));
+  if (significand_int64 == (1ll << 31)) {
+    significand_int64 /= 2;
+    ++exponent;
+  }
+  CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
+  significand = static_cast<int32_t>(significand_int64);
+  return std::make_pair(significand, exponent);
+}
+
+Expr FixedPointMuliply(Expr tensor, double multiplier,
+                   const Array<IndexExpr>& input_shape, const std::string& rounding) {
+  // Choose high precision datatype to be int64. This is for avoiding overflow
+  // in multiplication of two int32 values.
+  DataType hp_dtype = Int(64);
+
+  // 1) Calculating the integer multiplier and integer shift
+  int32_t fixed_point_multiplier, shift;
+  std::tie(fixed_point_multiplier, shift) =
+      GetFixedPointMultiplierShift(multiplier);
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+
+  // 2) Multiply the integer multiplier
+  if (left_shift != 0) {
+    tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
+  }
+
+  // 3) Perform the multiplication in higher precision.
+  // The scalar is a fixed point value of int32 where the decimal point is
+  // between bits 31 and 30. After multiplying with input_tensor, the result
+  // is in int64 where the decimal point is sitting between bits 31 and 30
+  // (from the right, rightmost bit is bit 0). The computation is performed in
+  // higher precision to avoid overflow in multiplying two int32 values.
+  Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
+  tensor = Multiply(tensor, scalar);
+
+  // 4) Find the rounding scalar. This depends on where the final decimal
+  // point sits. As we will be right shifting the multiplied_t, we need to
+  // first calculate the total_right_shift.
+  int total_right_shift = right_shift + 31;
+  int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
+
+  Expr round_scalar;
+  if (rounding == "UPWARD") {
+    round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
+  } else if (rounding == "TONEAREST") {
+    auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
+    auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
+    auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
+    auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
+
+    auto zero_t = Zeros(input_shape, hp_dtype);
+    round_scalar =
+        Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  }
+  // Add the rounding scalar.
+  tensor = Add(tensor, round_scalar);
+
+  // 5) Simply right shift the result to get the final output.
+  tensor =
+      RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
+
+  return tensor;
+}
+
+}  // namespace qnn
+}  // namespace relay
+}  // namespace tvm
index 9482331..c261837 100644 (file)
@@ -27,6 +27,7 @@
 
 #include <tvm/expr.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/qnn/attrs.h>
 #include <limits>
 #include <string>
 #include <utility>
@@ -92,6 +93,32 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
   return value_ptr[0];
 }
 
+/*
+ * \brief Fixed point multiplication between integer tensor with floating point
+ scalar.
+ * \param tensor The quantized input tensor of dtype int64.
+ * \param multiplier The scalar multiplier.
+ * \param input_shape Shape of the input tensor.
+ * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value
+ is midway between" "two representable values.
+ * \return The sequence of Relay ops for fixed point multiplication.
+
+ * \note Original compuation is scale_fp32 * quantized_tensor.  To convert into
+ *       integer computation, the multiplication with fp32 scalar can be
+ *       replaced by multiplication with an int value and then right shifting
+ *       the result. This approximates the floating point computation with a
+ *       fixed point computation.
+ *
+ *       Computation of fixed point multiplication is consist of following
+ steps:
+ *       1) Multiply the fixed point multiplier with quantized tensor.
+ *       2) Round the result.
+ *       3) Right shift the result
+ */
+Expr FixedPointMuliply(Expr tensor, double multiplier,
+                       const Array<IndexExpr>& input_shape,
+                       const std::string& rounding);
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm