#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
+#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
// Run some optimizations first, this code should
// be moved to pass manager.
- context_.module = OptimizeModule(mod_ref);
+ context_.module = OptimizeModule(mod_ref, targets_);
// Populate the global map.
//
}
}
-Module VMCompiler::OptimizeModule(const Module& mod) {
- // TODO(@icemelon9): check number of targets and build config, add more optimization pass
- transform::Sequential seq({transform::SimplifyInference(),
- transform::InlinePrimitives(),
- // TODO(@wweic): FuseOps pass currently don't handle Let
- // For now, we put FuseOps before ToANormalForm to enable it
- transform::FuseOps(),
- transform::ToANormalForm(),
- transform::LambdaLift(),
- transform::InlinePrimitives()});
- auto pass_ctx = transform::PassContext::Create();
+// Builds the Relay optimization pipeline for the VM compiler and runs it on
+// `mod`.
+//
+// \param mod The module to optimize.
+// \param targets The compilation targets. Passes restricted to homogeneous
+//        execution (Legalize, AlterOpLayout) are scheduled only when exactly
+//        one target is supplied; in that case the pipeline is also executed
+//        inside that target's scope.
+// \return The module produced by running the pass sequence on `mod`.
+Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
+ Array<Pass> pass_seqs;
+ // Run all dialect legalization passes.
+ pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+ // Legalize pass is restricted to homogeneous execution for now.
+ if (targets.size() == 1) {
+ pass_seqs.push_back(transform::Legalize());
+ }
+
+ pass_seqs.push_back(transform::SimplifyInference());
+ // Predicate handed to EliminateCommonSubexpr: yields true only for `cast`
+ // calls whose destination dtype is int32, false for everything else.
+ PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+ Expr expr = args[0];
+ // Set the default first. Assigning false after the checks (as this code
+ // used to) overwrote the `true` result and made the predicate a no-op.
+ *rv = false;
+ if (const auto* call_node = expr.as<CallNode>()) {
+ // `op` may be something other than an OpNode, and `attrs` something other
+ // than CastAttrs; `as<>()` then returns nullptr, so guard both.
+ const auto* op_node = call_node->op.as<OpNode>();
+ if (op_node != nullptr && op_node->name == "cast") {
+ const auto* attrs = call_node->attrs.as<CastAttrs>();
+ if (attrs != nullptr && attrs->dtype == Int(32)) {
+ *rv = true;
+ }
+ }
+ }
+ });
+ pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+ pass_seqs.push_back(transform::InlinePrimitives());
+
+ pass_seqs.push_back(transform::CombineParallelConv2D(3));
+ pass_seqs.push_back(transform::CombineParallelDense(3));
+ pass_seqs.push_back(transform::FoldConstant());
+ pass_seqs.push_back(transform::FoldScaleAxis());
+ pass_seqs.push_back(transform::CanonicalizeCast());
+ pass_seqs.push_back(transform::CanonicalizeOps());
+
+ // Alter layout transformation is only applied to homogeneous execution yet.
+ if (targets.size() == 1) {
+ pass_seqs.push_back(transform::AlterOpLayout());
+ }
+
+ pass_seqs.push_back(transform::FoldConstant());
+
+ // FuseOps is scheduled ahead of ToANormalForm, followed by LambdaLift and a
+ // second InlinePrimitives.
+ pass_seqs.push_back(transform::FuseOps());
+ pass_seqs.push_back(transform::ToANormalForm());
+ pass_seqs.push_back(transform::LambdaLift());
+ pass_seqs.push_back(transform::InlinePrimitives());
+
+ transform::Sequential seq(pass_seqs);
+ // Reuse whatever pass context is current at the call site instead of
+ // creating a fresh one.
+ transform::PassContext pass_ctx = PassContext::Current();
+ // TODO(wweic): Support heterogeneous execution
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
+ if (targets.size() == 1) {
+ // Exactly one entry exists, so enter its target scope for the whole run.
+ const auto& it = targets.begin();
+ With<Target> tctx((*it).second);
+ return seq(mod);
+ }
return seq(mod);
}