[Relay][Quantization] Fix out-of-date realize (#3790)
authorWuwei Lin <wuwei@apache.org>
Fri, 16 Aug 2019 06:23:08 +0000 (23:23 -0700)
committerziheng <ziheng@apache.org>
Fri, 16 Aug 2019 06:23:08 +0000 (23:23 -0700)
src/relay/pass/quantize/realize.cc

index e4bc63a..7eae999 100644 (file)
@@ -110,7 +110,6 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
   } else if (static_cast<int>(factor) == factor) {
     return Multiply(data, MakeConstantScalar(dtype, factor));
   } else {
-    LOG(FATAL) << "fall back to float computation";
     data = Cast(data, Float(32));
     data = Multiply(data, MakeConstantScalar(Float(32), factor));
     return Cast(Round(data), dtype);
@@ -147,15 +146,21 @@ Expr QuantizeRealize(const Call& ref_call,
     }
 
     float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
-    CHECK_GT(shift_nbit, 0);
+    CHECK_NE(shift_nbit, 0);
     if (static_cast<int>(shift_nbit) == shift_nbit) {
-      // use right shift
-      if (cfg->round_for_shift) {
-        float round_bias = std::pow(2.0, shift_nbit - 1);
-        data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
+      if (shift_nbit > 0) {
+        // use right shift
+        if (cfg->round_for_shift) {
+          float round_bias = std::pow(2.0, shift_nbit - 1);
+          data = Add(data, MakeConstantScalar(cfg->dtype_activation,
+                                              static_cast<int>(round_bias)));
+        }
+        data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
+                                                   static_cast<int>(shift_nbit)));
+      } else {
+        data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
+                                                  static_cast<int>(shift_nbit)));
       }
-      data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
-                                                 static_cast<int>(shift_nbit)));
       data = Clip(data, clip_min_imm, clip_max_imm);
       return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
     } else {