[QUANTIZE] Add nn.batch_flatten as quantizable. (#5805)
author Balint Cristian <cristian.balint@gmail.com>
Sun, 21 Jun 2020 22:36:51 +0000 (01:36 +0300)
committer GitHub <noreply@github.com>
Sun, 21 Jun 2020 22:36:51 +0000 (15:36 -0700)
* [ONNX] Skip ADD inside Gemm op when vector is zero

* [QUANTIZE] Add nn.batch_flatten as quantizable.

python/tvm/relay/quantize/_partition.py
src/relay/quantize/realize.cc
tests/python/relay/test_pass_auto_quantize.py
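For context, the end-to-end effect of this change looks roughly like the sketch below. It mirrors the new test added in this PR, but uses plain relay.quantize defaults instead of the test's quantize_and_build helper, so the exact qconfig settings are an assumption:

    from tvm import relay
    from tvm.relay import testing

    # conv2d -> batch_flatten, the same graph the new test builds
    data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32")
    out = relay.nn.conv2d(data, relay.var("weight"),
                          kernel_size=(3, 3), padding=(1, 1), channels=16)
    out = relay.nn.batch_flatten(out)

    func = relay.Function(relay.analysis.free_vars(out), out)
    mod, params = testing.create_workload(func)

    with relay.quantize.qconfig():
        qmod = relay.quantize.quantize(mod, params=params)
    print(qmod["main"])  # nn.batch_flatten now appears on int8 tensors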

diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index a607f4e..b72f51c 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -121,6 +121,21 @@ def add_partition_generic(ref_call, new_args, ctx):
 
     raise ValueError
 
+def mul_partition_generic(ref_call, new_args, ctx):
+    """Rewrite function for ewise mul for partition for generic devices"""
+    lhs_cond, lhs = partition_expr_check(new_args[0])
+    rhs_cond, rhs = partition_expr_check(new_args[1])
+
+    if lhs_cond:
+        # introduced by bn: multiply(out, scale)
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+
+    if not lhs_cond and not rhs_cond:
+        # trivial case
+        return None
+
+    raise ValueError
+
 
 # TODO(ziheng) enhance `register_partition_function` to dispatch
 # for target automatically
@@ -136,11 +151,5 @@ def add_partition_function(ref_call, new_args, ctx):
 
 @register_partition_function("multiply")
 def multiply_partition_function(ref_call, new_args, ctx):
-    """Rewrite function for ewise add for partition"""
-    lhs_cond, lhs = partition_expr_check(new_args[0])
-    rhs_cond, rhs = partition_expr_check(new_args[1])
-    if lhs_cond:
-        # introduced by bn: multiply(out, scale)
-        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
-    assert (not lhs_cond) and (not rhs_cond)
-    return None
+    """Rewrite function for ewise multiply for partition"""
+    return mul_partition_generic(ref_call, new_args, ctx)
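The new helper mirrors add_partition_generic above; its decision table can be summarized in a standalone sketch (the function name and return strings below are illustrative only, not part of the patch):

    def expected_mul_partition(lhs_quantized, rhs_quantized):
        """Decision table implemented by mul_partition_generic (sketch)."""
        if lhs_quantized:
            return "wrap in QPartitionExpr"   # e.g. multiply(conv_out, bn_scale)
        if not lhs_quantized and not rhs_quantized:
            return None                       # trivial case: leave the call untouched
        raise ValueError("float lhs with quantized rhs is not handled")

    assert expected_mul_partition(True, False) == "wrap in QPartitionExpr"
    assert expected_mul_partition(False, False) is None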
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 49d1e52..ddf945a 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -272,7 +272,7 @@ Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
     Expr dom_scale = FoldConstantOpt(mul);
     return QRealizeIntExpr(ret, dom_scale, dtype);
   }
-  CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
+  CHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
   return Expr(nullptr);
 }
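This CHECK relaxation is the subtle half of the patch: MulRealize's fallthrough may now be reached when exactly one argument is a TempExpr, e.g. multiply(quantized_out, float_scale) produced by the new partition rule, while the both-realized case is still rejected because it should have been handled by the branch above. A toy model of the invariant (illustrative Python, not from the patch):

    def fallthrough_allowed(lhs_is_temp, rhs_is_temp):
        # old invariant: not lhs_is_temp and not rhs_is_temp
        return not lhs_is_temp or not rhs_is_temp  # new, weaker invariant

    assert fallthrough_allowed(False, False)    # plain float multiply: fine before and after
    assert fallthrough_allowed(True, False)     # one quantized arg: newly allowed
    assert not fallthrough_allowed(True, True)  # both realized: still an error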
 
@@ -418,6 +418,9 @@ RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", Ident
 
 RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
+RELAY_REGISTER_OP("nn.batch_flatten")
+    .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
 RELAY_REGISTER_OP("annotation.stop_fusion")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
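Registering nn.batch_flatten with IdentityRealize forwards the quantized tensor through the op unchanged: flattening only reshapes data, so the domain scale and int8 dtype are preserved and no dequantize/requantize pair is inserted. A runnable toy paraphrase of an identity realize rule (the Python class below is a stand-in, not TVM's API):

    from dataclasses import dataclass

    @dataclass
    class QRealizeIntExpr:      # stand-in for the C++ temp-expr node
        data: object
        dom_scale: float
        dtype: str

    def identity_realize(forward_op, arg):
        # if the input is already realized as a quantized tensor, re-create the
        # call on the int8 data and keep scale/dtype; otherwise do nothing
        if isinstance(arg, QRealizeIntExpr):
            return QRealizeIntExpr(forward_op(arg.data), arg.dom_scale, arg.dtype)
        return None

    q = QRealizeIntExpr(data=[[1, 2], [3, 4]], dom_scale=0.05, dtype="int8")
    flat = identity_realize(lambda d: [x for row in d for x in row], q)
    assert flat.dtype == "int8" and flat.dom_scale == 0.05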
 
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
index 35d33b1..f742797 100644
--- a/tests/python/relay/test_pass_auto_quantize.py
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -21,6 +21,7 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import testing
+from tvm.relay.expr import Call
 
 
 def quantize_and_build(out):
@@ -32,6 +33,7 @@ def quantize_and_build(out):
 
     relay.build(qmod, "llvm", params=params)
 
+    return qmod
 
 def test_mul_rewrite():
     """a test case where rhs of mul is not constant"""
@@ -49,6 +51,26 @@ def test_mul_rewrite():
 
     quantize_and_build(act * pool)
 
+def test_batch_flatten_rewrite():
+
+    data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32")
+
+    out = relay.nn.conv2d(data, relay.var("weight"),
+                          kernel_size=(3, 3),
+                          padding=(1, 1),
+                          channels=16)
+
+    out = relay.nn.batch_flatten(out)
+
+    qmod = quantize_and_build(out)
+
+    def _check_batch_flatten(node):
+        if isinstance(node, Call):
+            if node.op.name == "nn.batch_flatten":
+                assert node.checked_type.dtype == "int8"
+
+    # check if batch_flatten is quantized
+    relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten)
 
 def get_calibration_dataset(input_name):
     dataset = []
@@ -83,6 +105,7 @@ def test_calibrate_memory_bound():
 
 if __name__ == "__main__":
     test_mul_rewrite()
+    test_batch_flatten_rewrite()
     test_calibrate_target(False)
     test_calibrate_target(True)
     test_calibrate_memory_bound()