for tensor in input_tensors:
if tensor.op not in scheduled_ops:
traverse_before_reduce(tensor.op)
+ elif isinstance(operator, tvm.te.PlaceholderOp):
+ pass
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
assert tvm.ir.structural_equal(m["main"], after)
+def test_fuse_bcast_reduce_scalar():
+ """Test fusion case with broadcast and reduction involving scalar"""
+
+ def before():
+ x = relay.var("x", shape=(), dtype="int32")
+ less = relay.less(x, relay.const(10, dtype="int32"))
+ z = relay.min(less)
+ return relay.Function([x], z)
+
+ def expected():
+ p0 = relay.var("p0", shape=(), dtype="int32")
+ less = relay.less(p0, relay.const(10, dtype="int32"))
+ z0 = relay.min(less)
+ f0 = relay.Function([p0], z0)
+ f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+ x = relay.var("x", shape=(), dtype="int32")
+ f = relay.Call(f0, [x])
+ return relay.Function([x], f)
+
+ orig = before()
+ m = fuse2(tvm.IRModule.from_expr(orig))
+ for tgt, _ in tvm.relay.testing.config.ctx_list():
+ relay.build(m, tgt)
+ after = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(m["main"], after)
+
+
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_fuse_max()
test_fuse_take()
test_fuse_gather_nd()
+ test_fuse_bcast_reduce_scalar()