after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zz, after)
+
+def test_fuse_take():
+ """Test fusion case involving concat and take"""
+
+ def before():
+ shape = (tvm.tir.const(10, "int64"),
+ tvm.tir.const(1, "int64"))
+ x = relay.var("x", shape=shape)
+ concat = relay.concatenate([x,x], axis=-1)
+ out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
+ return relay.Function(relay.analysis.free_vars(out), out)
+
+ def expected():
+ shape1 = (tvm.tir.const(10, "int64"),
+ tvm.tir.const(1, "int64"))
+ shape2 = (tvm.tir.const(1, "int64"),)
+ x = relay.var("x", shape=shape1)
+ p0 = relay.var("p0", shape=shape1)
+ p1 = relay.var("p1", shape=shape2,
+ dtype="int64")
+ c = relay.const([0], dtype="int64")
+ concat = relay.concatenate([p0,p0], axis=-1)
+ out = relay.op.take(concat, indices=p1)
+
+ f0 = relay.Function([p0, p1], out)
+ f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+ y = relay.Call(f0, [x, c])
+ return relay.Function([x], y)
+
+ orig = before()
+ m = fuse2(tvm.IRModule.from_expr(orig))
+ relay.build(m, 'llvm')
+ after = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(m["main"], after)
+
+
+def test_fuse_gather_nd():
+ """Test fusion case involving concat and gather_nd"""
+
+ def before():
+ shape = (tvm.tir.const(10, "int64"),
+ tvm.tir.const(1, "int64"))
+ x = relay.var("x", shape=shape)
+ concat = relay.concatenate([x,x], axis=-1)
+ out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64"))
+ return relay.Function(relay.analysis.free_vars(out), out)
+
+ def expected():
+ shape1 = (tvm.tir.const(10, "int64"),
+ tvm.tir.const(1, "int64"))
+ shape2 = (tvm.tir.const(2, "int64"),
+ tvm.tir.const(2, "int64"))
+ x = relay.var("x", shape=shape1)
+ p0 = relay.var("p0", shape=shape1)
+ p1 = relay.var("p1", shape=shape2, dtype="int64")
+ c = relay.const([[0,1],[1,0]], dtype="int64")
+ concat = relay.concatenate([p0,p0], axis=-1)
+ out = relay.gather_nd(concat, indices=p1)
+
+ f0 = relay.Function([p0, p1], out)
+ f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+ y = relay.Call(f0, [x, c])
+ return relay.Function([x], y)
+
+ orig = before()
+ m = fuse2(tvm.IRModule.from_expr(orig))
+ relay.build(m, 'llvm')
+ after = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(m["main"], after)
+
+
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
test_immutable()
test_split()
test_fuse_max()
+ test_fuse_take()
+ test_fuse_gather_nd()