From e8899285ea7a2d57938c64907c4dfd3f1a60688e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 29 Oct 2019 00:51:29 -0400 Subject: [PATCH] [Relay][Quantize] Use fixed point multiplications (#4160) --- python/tvm/relay/quantize/quantize.py | 4 ++++ src/relay/pass/quantize/quantize.cc | 3 ++- src/relay/pass/quantize/quantize.h | 2 ++ src/relay/pass/quantize/realize.cc | 23 +++++++++++++---------- src/relay/qnn/op/requantize.cc | 6 ++---- src/relay/qnn/util.cc | 4 +++- src/relay/qnn/util.h | 6 +++--- 7 files changed, 29 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index adde205..7fa8a66 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -83,6 +83,7 @@ class QConfig(NodeBase): "do_simulation": False, "round_for_shift": True, "debug_enabled_ops": None, + "rounding": "UPWARD" } # pylint: disable=no-member @@ -160,6 +161,9 @@ def qconfig(**kwargs): is None, which means will try to call all operartors' annotate rewrite function. + rounding: "UPWARD" or "TONEAREST" Rounding direction for fixed point multiplications. 
+ Returns ------- config: QConfig diff --git a/src/relay/pass/quantize/quantize.cc b/src/relay/pass/quantize/quantize.cc index 3d0e71e..d564d2e 100644 --- a/src/relay/pass/quantize/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "do_simulation==" << op->do_simulation << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", "; + p->stream << "rounding==" << op->rounding; p->stream << ")"; }); diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index 412bce0..8a0282a 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -75,6 +75,7 @@ class QConfigNode : public Node { bool do_simulation = false; bool round_for_shift = true; Array debug_enabled_ops = Array(NodePtr(nullptr)); + std::string rounding = "UPWARD"; void VisitAttrs(AttrVisitor* v) { v->Visit("nbit_input", &nbit_input); @@ -88,6 +89,7 @@ class QConfigNode : public Node { v->Visit("do_simulation", &do_simulation); v->Visit("round_for_shift", &round_for_shift); v->Visit("debug_enabled_ops", &debug_enabled_ops); + v->Visit("rounding", &rounding); } static constexpr const char* _type_key = "relay.quantize.QConfig"; diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index bdd0d73..a0e8ffc 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -31,6 +31,7 @@ #include #include "./quantize.h" #include "../pattern_util.h" +#include "../../qnn/util.h" namespace tvm { namespace relay { @@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array& args) { /* calculate `data * s1 / s2`, use shift if possible */ -inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { 
+inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, + const Array &data_shape) { + const QConfig& cfg = QConfig::Current(); // here we assume the dtype of data is dtype activation if (s1 == s2) return data; @@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - data = Cast(data, Float(32)); - data = Multiply(data, MakeConstantScalar(Float(32), factor)); - return Cast(Round(data), dtype); + data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding); + return Cast(data, dtype); } } @@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call, data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } else { - // float computation - data = Cast(data, Float(32)); - Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); - Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + data = Cast(data, Int(64)); + data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, + ref_call->type_as()->shape, + cfg->rounding); + data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype); + return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } } @@ -355,7 +358,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args Expr dom_scale = MakeConstantScalar(Float(32), s); for (size_t i = 0; i < ret.size(); ++i) { float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); - ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); + ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as()->shape)); } *dtype_ptr = dtype; diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 4a424d1..a361969 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -37,8 +37,6 @@ 
TVM_REGISTER_NODE_TYPE(RequantizeAttrs); // Lowering of qnn.requantize op - - /* * \brief Lower requantize to a sequence of ops. * \param input_tensor The input tensor to requantize op. @@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param, // 2) If the input and output scales are same, we can skip the fixed point multiplication. auto scaled_int64_t = tensor; if (param->input_scale != param->output_scale) { - scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape, - param->rounding); + scaled_int64_t = + FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding); } // 3) Add the output zero point. diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index d9e4506..f0ad8ab 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -76,7 +76,7 @@ std::pair GetFixedPointMultiplierShift( return std::make_pair(significand, exponent); } -Expr FixedPointMuliply(Expr tensor, double multiplier, +Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& input_shape, const std::string& rounding) { // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. @@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier, auto zero_t = Zeros(input_shape, hp_dtype); round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + } else { + LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } // Add the rounding scalar. tensor = Add(tensor, round_scalar); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index f94860d..0c35737 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) { * 2) Round the result. 
* 3) Right shift the result */ -Expr FixedPointMuliply(Expr tensor, double multiplier, - const Array& input_shape, - const std::string& rounding); +Expr FixedPointMultiply(Expr tensor, double multiplier, + const Array& input_shape, + const std::string& rounding); } // namespace qnn } // namespace relay -- 2.7.4