[Relay][VM] Add more passes to VMCompiler (#4058)
authorWei Chen <ipondering.weic@gmail.com>
Sat, 5 Oct 2019 23:08:53 +0000 (16:08 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Sat, 5 Oct 2019 23:08:53 +0000 (16:08 -0700)
* [Relay][VM] Add more passes to VMCompiler

* Check build config

* Add todo

src/relay/backend/vm/compiler.cc
src/relay/backend/vm/compiler.h

index 8c60fe6..49079fb 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
+#include <tvm/relay/qnn/transform.h>
 #include <tvm/logging.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/vm.h>
@@ -803,7 +804,7 @@ void VMCompiler::Compile(const Module& mod_ref,
 
   // Run some optimizations first, this code should
   // be moved to pass manager.
-  context_.module = OptimizeModule(mod_ref);
+  context_.module = OptimizeModule(mod_ref, targets_);
 
   // Populate the global map.
   //
@@ -844,18 +845,63 @@ void VMCompiler::Compile(const Module& mod_ref,
   }
 }
 
-Module VMCompiler::OptimizeModule(const Module& mod) {
-  // TODO(@icemelon9): check number of targets and build config, add more optimization pass
-  transform::Sequential seq({transform::SimplifyInference(),
-                             transform::InlinePrimitives(),
-                             // TODO(@wweic): FuseOps pass currently don't handle Let
-                             // For now, we put FuseOps before ToANormalForm to enable it
-                             transform::FuseOps(),
-                             transform::ToANormalForm(),
-                             transform::LambdaLift(),
-                             transform::InlinePrimitives()});
-  auto pass_ctx = transform::PassContext::Create();
+Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
+  Array<Pass> pass_seqs;
+  // Run all dialect legalization passes.
+  pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+  // Legalize pass is restricted to homogeneous execution for now.
+  if (targets.size() == 1) {
+    pass_seqs.push_back(transform::Legalize());
+  }
+
+  pass_seqs.push_back(transform::SimplifyInference());
+  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    Expr expr = args[0];
+    if (expr.as<CallNode>()) {
+      auto call_node = expr.as<CallNode>();
+      auto op_node = call_node->op.as<OpNode>();
+      if (op_node->name == "cast") {
+        auto attrs = call_node->attrs.as<CastAttrs>();
+        if (attrs->dtype == Int(32)) {
+          *rv = true;
+        }
+      }
+    }
+    *rv = false;
+  });
+  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+  pass_seqs.push_back(transform::InlinePrimitives());
+
+  pass_seqs.push_back(transform::CombineParallelConv2D(3));
+  pass_seqs.push_back(transform::CombineParallelDense(3));
+  pass_seqs.push_back(transform::FoldConstant());
+  pass_seqs.push_back(transform::FoldScaleAxis());
+  pass_seqs.push_back(transform::CanonicalizeCast());
+  pass_seqs.push_back(transform::CanonicalizeOps());
+
+  // Alter layout transformation is only applied to homogeneous execution yet.
+  if (targets.size() == 1) {
+    pass_seqs.push_back(transform::AlterOpLayout());
+  }
+
+  pass_seqs.push_back(transform::FoldConstant());
+
+  pass_seqs.push_back(transform::FuseOps());
+  pass_seqs.push_back(transform::ToANormalForm());
+  pass_seqs.push_back(transform::LambdaLift());
+  pass_seqs.push_back(transform::InlinePrimitives());
+
+  transform::Sequential seq(pass_seqs);
+  transform::PassContext pass_ctx = PassContext::Current();
+  // TODO(wweic): Support heterogenous execution
   tvm::With<relay::transform::PassContext> ctx(pass_ctx);
+  if (targets.size() == 1) {
+    for (const auto& kv : targets) {
+      With<Target> tctx(kv.second);
+      return seq(mod);
+    }
+  }
   return seq(mod);
 }
 
index bfe19ac..14a5035 100644 (file)
@@ -105,7 +105,7 @@ class VMCompiler : public runtime::ModuleNode {
                const tvm::Target& target_host);
 
  protected:
-  Module OptimizeModule(const Module& mod);
+  Module OptimizeModule(const Module& mod, const TargetsMap& targets);
 
   void PopulateGlobalMap();