From: Haozheng Fan Date: Thu, 23 Jul 2020 21:04:10 +0000 (+0800) Subject: [RELAY][Fix] i64 indices (#5235) X-Git-Tag: upstream/0.7.0~365 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6dbc344b56294b54ec70dfae29fbb2cbc0591c64;p=platform%2Fupstream%2Ftvm.git [RELAY][Fix] i64 indices (#5235) * fix * resolve comments --- diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index aab30ed..7399af9 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -63,7 +63,8 @@ class OperationInliner final : public StmtExprMutator { } else { Map vmap; for (size_t i = 0; i < args_.size(); ++i) { - vmap.Set(args_[i], op->indices[i]); + // cast indices to the type of the original indexing variable + vmap.Set(args_[i], cast(args_[i].dtype(), op->indices[i])); } expr = Substitute(Evaluate(expr), vmap).as()->value; } diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 6b7d297..f4369c1 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -621,6 +621,79 @@ def test_fuse_max(): after = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(zz, after) + +def test_fuse_take(): + """Test fusion case involving concat and take""" + + def before(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + concat = relay.concatenate([x,x], axis=-1) + out = relay.op.take(concat, indices=relay.const([0], dtype="int64")) + return relay.Function(relay.analysis.free_vars(out), out) + + def expected(): + shape1 = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + shape2 = (tvm.tir.const(1, "int64"),) + x = relay.var("x", shape=shape1) + p0 = relay.var("p0", shape=shape1) + p1 = relay.var("p1", shape=shape2, + dtype="int64") + c = relay.const([0], dtype="int64") + concat = relay.concatenate([p0,p0], axis=-1) + out = relay.op.take(concat, indices=p1) + + f0 = relay.Function([p0, p1], out) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + y = relay.Call(f0, [x, c]) + return relay.Function([x], y) + + orig = before() + m = fuse2(tvm.IRModule.from_expr(orig)) + relay.build(m, 'llvm') + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + +def test_fuse_gather_nd(): + """Test fusion case involving concat and gather_nd""" + + def before(): + shape = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + x = relay.var("x", shape=shape) + concat = relay.concatenate([x,x], axis=-1) + out = relay.gather_nd(concat, indices=relay.expr.const([[0,1],[1,0]], dtype="int64")) + return relay.Function(relay.analysis.free_vars(out), out) + + def expected(): + shape1 = (tvm.tir.const(10, "int64"), + tvm.tir.const(1, "int64")) + shape2 = (tvm.tir.const(2, "int64"), + tvm.tir.const(2, "int64")) + x = relay.var("x", shape=shape1) + p0 = relay.var("p0", shape=shape1) + p1 = relay.var("p1", shape=shape2, dtype="int64") + c = relay.const([[0,1],[1,0]], dtype="int64") + concat = relay.concatenate([p0,p0], axis=-1) + out = relay.gather_nd(concat, indices=p1) + + f0 = relay.Function([p0, p1], out) + f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + + y = relay.Call(f0, [x, c]) + return relay.Function([x], y) + + orig = before() + m = fuse2(tvm.IRModule.from_expr(orig)) + relay.build(m, 'llvm') + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(m["main"], after) + + if __name__ == "__main__": test_fuse_simple() test_conv2d_fuse() @@ -637,3 +710,5 @@ if __name__ == "__main__": test_immutable() test_split() test_fuse_max() + test_fuse_take() + test_fuse_gather_nd()