[BuildModule] Fix AlterLayout Pass (#3155)
authorBing Xu <antinucleon@gmail.com>
Thu, 9 May 2019 04:06:33 +0000 (21:06 -0700)
committerJared Roesch <roeschinc@gmail.com>
Thu, 9 May 2019 04:06:33 +0000 (00:06 -0400)
src/relay/backend/build_module.cc

index b60a048..67ab750 100644 (file)
@@ -504,7 +504,14 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (cfg.pass_enabled("AlterOpLayout")) {
       if (targets.size() == 1) {
         func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-        func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
+        auto enter_pf = GetPackedFunc("_EnterTargetScope");
+        auto exit_pf = GetPackedFunc("_ExitTargetScope");
+        for (const auto& kv : targets) {
+          auto target = Target::create(kv.second);
+          (*enter_pf)(target);
+          func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
+          (*exit_pf)();
+        }
       } else {
         LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
                   << " execution yet.";