From: Balint Cristian Date: Sun, 21 Jun 2020 22:36:51 +0000 (+0300) Subject: [QUANTIZE] Add nn.batch_flatten as quantizable. (#5805) X-Git-Tag: upstream/0.7.0~521 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7902d0fee63a8d8a132a6ca3369ebfecb3aee68d;p=platform%2Fupstream%2Ftvm.git [QUANTIZE] Add nn.batch_flatten as quantizable. (#5805) * [ONNX] Skip ADD inside Gemm op when vector is zero * [QUANTIZE] Add nn.batch_flatten as quantizable. --- diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index a607f4e..b72f51c 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -121,6 +121,21 @@ def add_partition_generic(ref_call, new_args, ctx): raise ValueError +def mul_partition_generic(ref_call, new_args, ctx): + """Rewrite function for ewise mul for partition for generic devices""" + lhs_cond, lhs = partition_expr_check(new_args[0]) + rhs_cond, rhs = partition_expr_check(new_args[1]) + + if lhs_cond: + # introduced by bn: multiply(out, scale) + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) + + if not lhs_cond and not rhs_cond: + # trivial case + return None + + raise ValueError + # TODO(ziheng) enhance `register_partition_function` to dispatch # for target automatically @@ -136,11 +151,5 @@ def add_partition_function(ref_call, new_args, ctx): @register_partition_function("multiply") def multiply_partition_function(ref_call, new_args, ctx): - """Rewrite function for ewise add for partition""" - lhs_cond, lhs = partition_expr_check(new_args[0]) - rhs_cond, rhs = partition_expr_check(new_args[1]) - if lhs_cond: - # introduced by bn: multiply(out, scale) - return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) - assert (not lhs_cond) and (not rhs_cond) - return None + """Rewrite function for ewise multiply for partition""" + return mul_partition_generic(ref_call, new_args, ctx) diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 
49d1e52..ddf945a 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -272,7 +272,7 @@ Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, dtype); } - CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()); + CHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()); return Expr(nullptr); } @@ -418,6 +418,9 @@ RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", Ident RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("nn.batch_flatten") + .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize); + RELAY_REGISTER_OP("annotation.stop_fusion") .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize); diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 35d33b1..f742797 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -21,6 +21,7 @@ import tvm from tvm import te from tvm import relay from tvm.relay import testing +from tvm.relay.expr import Call def quantize_and_build(out): @@ -32,6 +33,7 @@ def quantize_and_build(out): relay.build(qmod, "llvm", params=params) + return qmod def test_mul_rewrite(): """a test case where rhs of mul is not constant""" @@ -49,6 +51,26 @@ def test_mul_rewrite(): quantize_and_build(act * pool) +def test_batch_flatten_rewrite(): + + data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32") + + out = relay.nn.conv2d(data, relay.var("weight"), + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + + out = relay.nn.batch_flatten(out) + + qmod = quantize_and_build(out) + + def _check_batch_flatten(node): + if isinstance(node, Call): + if(node.op.name == "nn.batch_flatten"): + assert node.checked_type.dtype == "int8" + + # check if batch_flatten is quantized + relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten) def
get_calibration_dataset(input_name): dataset = [] @@ -83,6 +105,7 @@ def test_calibrate_memory_bound(): if __name__ == "__main__": test_mul_rewrite() + test_batch_flatten_rewrite() test_calibrate_target(False) test_calibrate_target(True) test_calibrate_memory_bound()