[Relay][Quantize] Use fixed point mulplications (#4160)
authorWuwei Lin <vincentl13x@gmail.com>
Tue, 29 Oct 2019 04:51:29 +0000 (00:51 -0400)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 29 Oct 2019 04:51:29 +0000 (00:51 -0400)
python/tvm/relay/quantize/quantize.py
src/relay/pass/quantize/quantize.cc
src/relay/pass/quantize/quantize.h
src/relay/pass/quantize/realize.cc
src/relay/qnn/op/requantize.cc
src/relay/qnn/util.cc
src/relay/qnn/util.h

index adde205..7fa8a66 100644 (file)
@@ -83,6 +83,7 @@ class QConfig(NodeBase):
         "do_simulation": False,
         "round_for_shift": True,
         "debug_enabled_ops": None,
+        "rounding": "UPWARD"
     }
 
     # pylint: disable=no-member
@@ -160,6 +161,9 @@ def qconfig(**kwargs):
         is None, which means will try to call all operartors' annotate rewrite
         function.
 
+    rounding: "UPWARD" or "TONEAREST"
+        Rounding direction for fixed point multiplications.
+
     Returns
     -------
     config: QConfig
index 3d0e71e..d564d2e 100644 (file)
@@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "do_simulation==" << op->do_simulation << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
-  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+  p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", ";
+  p->stream << "rounding==" << op->rounding;
   p->stream << ")";
 });
 
index 412bce0..8a0282a 100644 (file)
@@ -75,6 +75,7 @@ class QConfigNode : public Node {
   bool do_simulation = false;
   bool round_for_shift = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+  std::string rounding = "UPWARD";
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("nbit_input", &nbit_input);
@@ -88,6 +89,7 @@ class QConfigNode : public Node {
     v->Visit("do_simulation", &do_simulation);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
+    v->Visit("rounding", &rounding);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";
index bdd0d73..a0e8ffc 100644 (file)
@@ -31,6 +31,7 @@
 #include <tvm/relay/attrs/annotation.h>
 #include "./quantize.h"
 #include "../pattern_util.h"
+#include "../../qnn/util.h"
 
 namespace tvm {
 namespace relay {
@@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
 
 
 /* calculate `data * s1 / s2`, use shift if possible */
-inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
+inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
+                      const Array<IndexExpr> &data_shape) {
+  const QConfig& cfg = QConfig::Current();
   // here we assume the dtype of data is dtype activation
   if (s1 == s2) return data;
 
@@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    data = Cast(data, Float(32));
-    data = Multiply(data, MakeConstantScalar(Float(32), factor));
-    return Cast(Round(data), dtype);
+    data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
+    return Cast(data, dtype);
   }
 }
 
@@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
       data = Clip(data, clip_min_imm, clip_max_imm);
       return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
     } else {
-      // float computation
-      data = Cast(data, Float(32));
-      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
-      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
-      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
+      data = Cast(data, Int(64));
+      data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
+                                     ref_call->type_as<TensorTypeNode>()->shape,
+                                     cfg->rounding);
+      data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
+      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
     }
   }
 
@@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
   Expr dom_scale = MakeConstantScalar(Float(32), s);
   for (size_t i = 0; i < ret.size(); ++i) {
     float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
-    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
+    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
   }
 
   *dtype_ptr = dtype;
index 4a424d1..a361969 100644 (file)
@@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);
 
 // Lowering of qnn.requantize op
 
-
-
 /*
  * \brief Lower requantize to a sequence of ops.
  * \param input_tensor The input tensor to requantize op.
@@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
   // 2) If the input and output scales are same, we can skip the fixed point multiplication.
   auto scaled_int64_t = tensor;
   if (param->input_scale != param->output_scale) {
-    scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
-                                       param->rounding);
+    scaled_int64_t =
+        FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
   }
 
   // 3) Add the output zero point.
index d9e4506..f0ad8ab 100644 (file)
@@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
   return std::make_pair(significand, exponent);
 }
 
-Expr FixedPointMuliply(Expr tensor, double multiplier,
+Expr FixedPointMultiply(Expr tensor, double multiplier,
                    const Array<IndexExpr>& input_shape, const std::string& rounding) {
   // Choose high precision datatype to be int64. This is for avoiding overflow
   // in multiplication of two int32 values.
@@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
     auto zero_t = Zeros(input_shape, hp_dtype);
     round_scalar =
         Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
+  } else {
+    LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
   }
   // Add the rounding scalar.
   tensor = Add(tensor, round_scalar);
index f94860d..0c35737 100644 (file)
@@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
  *       2) Round the result.
  *       3) Right shift the result
  */
-Expr FixedPointMuliply(Expr tensor, double multiplier,
-                       const Array<IndexExpr>& input_shape,
-                       const std::string& rounding);
+Expr FixedPointMultiply(Expr tensor, double multiplier,
+                        const Array<IndexExpr>& input_shape,
+                        const std::string& rounding);
 
 }  // namespace qnn
 }  // namespace relay