Update schedule_dataflow_rewrite.cc (#2934)
authorMr You <244510556@qq.com>
Mon, 1 Apr 2019 04:32:34 +0000 (12:32 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 1 Apr 2019 04:32:34 +0000 (21:32 -0700)
src/schedule/schedule_dataflow_rewrite.cc
tests/python/unittest/test_schedule_schedule_ops.py

index 774c623..a6ef2ac 100644 (file)
@@ -603,8 +603,8 @@ void InjectInline(ScheduleNode* sch) {
       if (!op.same_as(s->op)) {
         for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
           repl[s->op.output(idx)] = op.output(idx);
-          s->op = op;
         }
+        s->op = op;
       }
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
index c7cf1c1..44ee904 100644 (file)
@@ -459,6 +459,26 @@ def test_reduction_and_dummy_fuse_split():
     f(*args)
     assert np.all(args[0].asnumpy() == n)
 
+def test_schedule_compute_inline():
+    shape = [10, 1024]
+    A = tvm.placeholder(shape, name="A")
+    B = tvm.placeholder(shape, name="B")
+    C = tvm.compute(shape, lambda *index:A(*index)+ B(*index), name = "C")
+    def _compute(*index) :
+        return C(*index) , C(*index) * B(*index)
+    F,E = tvm.compute(shape, _compute, name = "F")
+
+    s = tvm.create_schedule([F.op, E.op])
+    AL = s.cache_read(A, "local", [C])
+    BL = s.cache_read(B, "local", [C,E])
+    CL = s.cache_write(C, "local")
+    FL, EL = s.cache_write([F, E], "local")
+    s[C].compute_inline()
+
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 if __name__ == "__main__":
     test_loop_dep_reduce()
     test_loop_dep_reduce_cache_write()
@@ -483,3 +503,4 @@ if __name__ == "__main__":
     test_schedule_tensor_compute2()
     test_schedule_tensor_compute3()
     test_reduction_and_dummy_fuse_split()
+    test_schedule_compute_inline()