raise ValueError
+def mul_partition_generic(ref_call, new_args, ctx):
+ """Rewrite function for ewise mul for partition for generic devices"""
+ lhs_cond, lhs = partition_expr_check(new_args[0])
+ rhs_cond, rhs = partition_expr_check(new_args[1])
+
+ if lhs_cond:
+ # introduced by bn: multiply(out, scale)
+ return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+
+ if not lhs_cond and not rhs_cond:
+ # trivial case
+ return None
+
+ raise ValueError
+
# TODO(ziheng) enhance `register_partition_function` to dispatch
# for target automatically
@register_partition_function("multiply")
def multiply_partition_function(ref_call, new_args, ctx):
- """Rewrite function for ewise add for partition"""
- lhs_cond, lhs = partition_expr_check(new_args[0])
- rhs_cond, rhs = partition_expr_check(new_args[1])
- if lhs_cond:
- # introduced by bn: multiply(out, scale)
- return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
- assert (not lhs_cond) and (not rhs_cond)
- return None
+ """Rewrite function for ewise multiply for partition; delegates to the generic handler."""
+ return mul_partition_generic(ref_call, new_args, ctx)
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExpr(ret, dom_scale, dtype);
}
- CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
+ CHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+// NOTE(review): batch_flatten only reshapes its input, so it appears safe to
+// reuse IdentityRealize (same treatment as strided_slice above) so the
+// quantized dtype passes through -- confirm against the realize pass.
+RELAY_REGISTER_OP("nn.batch_flatten")
+ .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
from tvm import te
from tvm import relay
from tvm.relay import testing
+from tvm.relay.expr import Call
def quantize_and_build(out):
relay.build(qmod, "llvm", params=params)
+ return qmod
def test_mul_rewrite():
"""a test case where rhs of mul is not constant"""
quantize_and_build(act * pool)
+def test_batch_flatten_rewrite():
+
+ data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32")
+
+ out = relay.nn.conv2d(data, relay.var("weight"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=16)
+
+ out = relay.nn.batch_flatten(out)
+
+ qmod = quantize_and_build(out)
+
+ def _check_batch_flatten(node):
+ if isinstance(node, Call):
+ if(node.op.name == "nn.batch_flatten"):
+ assert node.checked_type.dtype == "int8"
+
+ # check if batch_flatten is quantized
+ relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten)
def get_calibration_dataset(input_name):
dataset = []
if __name__ == "__main__":
test_mul_rewrite()
+ test_batch_flatten_rewrite()
test_calibrate_target(False)
test_calibrate_target(True)
test_calibrate_memory_bound()