[TOPI] Fix reduction (#6250)
authorZhi <5145158+zhiics@users.noreply.github.com>
Wed, 12 Aug 2020 08:12:56 +0000 (01:12 -0700)
committerGitHub <noreply@github.com>
Wed, 12 Aug 2020 08:12:56 +0000 (17:12 +0900)
python/tvm/topi/cuda/reduction.py
tests/python/relay/test_pass_fuse_ops.py

index 38e3086..664ea44 100644 (file)
@@ -139,6 +139,8 @@ def schedule_reduce(outs):
             for tensor in input_tensors:
                 if tensor.op not in scheduled_ops:
                     traverse_before_reduce(tensor.op)
+        elif isinstance(operator, tvm.te.PlaceholderOp):
+            pass
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
 
index f4369c1..1727429 100644 (file)
@@ -694,6 +694,34 @@ def test_fuse_gather_nd():
     assert tvm.ir.structural_equal(m["main"], after)
 
 
+def test_fuse_bcast_reduce_scalar():
+    """Test fusion case with broadcast and reduction involving scalar"""
+
+    def before():
+        x = relay.var("x", shape=(), dtype="int32")
+        less = relay.less(x, relay.const(10, dtype="int32"))
+        z = relay.min(less)
+        return relay.Function([x], z)
+
+    def expected():
+        p0 = relay.var("p0", shape=(), dtype="int32")
+        less = relay.less(p0, relay.const(10, dtype="int32"))
+        z0 = relay.min(less)
+        f0 = relay.Function([p0], z0)
+        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+        x = relay.var("x", shape=(), dtype="int32")
+        f = relay.Call(f0, [x])
+        return relay.Function([x], f)
+
+    orig = before()
+    m = fuse2(tvm.IRModule.from_expr(orig))
+    for tgt, _ in tvm.relay.testing.config.ctx_list():
+        relay.build(m, tgt)
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(m["main"], after)
+
+
 if __name__ == "__main__":
     test_fuse_simple()
     test_conv2d_fuse()
@@ -712,3 +740,4 @@ if __name__ == "__main__":
     test_fuse_max()
     test_fuse_take()
     test_fuse_gather_nd()
+    test_fuse_bcast_reduce_scalar()