[RELAY] Filter PlaceholderOp from schedule. (#2412)
authorSiva <sivar.b@huawei.com>
Tue, 15 Jan 2019 05:07:20 +0000 (10:37 +0530)
committerTianqi Chen <tqchen@users.noreply.github.com>
Tue, 15 Jan 2019 05:07:19 +0000 (21:07 -0800)
src/relay/backend/compile_engine.cc
tests/python/relay/test_backend_compile_engine.py

index 0dc9e458d7aa30ba1845cf7356f3dce165b820d5..73bae053cc828ce5d10f8870f6d65508358aeb0b 100644 (file)
@@ -83,11 +83,20 @@ class ScheduleGetter :
     cache_node->func_name = readable_name_stream_.str();
     CachedFunc cfunc(cache_node);
     CHECK(master_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<Tensor> tensor_outs;
+    for (const auto& tensor : cache_node->outputs) {
+      if (!tensor->op.as<PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
     Schedule schedule;
     // No need to register schedule for device copy op.
     if (master_attrs_.as<DeviceCopyAttrs>() == nullptr) {
       schedule =
-          fschedule[master_op_](master_attrs_, cache_node->outputs, target_);
+          fschedule[master_op_](master_attrs_, tensor_outs, target_);
       for (const auto& scalar : scalars_) {
         schedule[scalar].compute_inline();
       }
index 568d7849e7ee78adc82507dd5b5df0e619e2fbce..a3a3af3f94b89e3ac3da7ca2ad9048fbc2bf228d 100644 (file)
@@ -33,6 +33,16 @@ def test_compile_engine():
                 y.asnumpy(), x.asnumpy() * 3)
     engine.dump()
 
+def test_compile_placeholder_bypass():
+    engine = relay.backend.compile_engine.get()
+    x = relay.var("x", shape=(2, 3))
+    y = relay.var("y", shape=(2, 3))
+    z = relay.var("z", shape=(2, 3))
+    result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)])
+    func = relay.Function(relay.ir_pass.free_vars(result), result)
+    with relay.build_config(opt_level=0):
+       graph, lib, params = relay.build(func, 'llvm')
 
 if __name__ == "__main__":
     test_compile_engine()
+    test_compile_placeholder_bypass()