[REFACTOR][IR] Allow Module to store BaseFunc. (#4678)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 11 Jan 2020 06:54:16 +0000 (22:54 -0800)
committerGitHub <noreply@github.com>
Sat, 11 Jan 2020 06:54:16 +0000 (22:54 -0800)
Under the unified IR. We will allow a single IRModule
to store different function variants, such as relay::Function,
ExternFunc, and low-level function.

This PR changes relay::Function -> BaseFunc in the module file
to support multiple function variants.

15 files changed:
include/tvm/relay/module.h
src/relay/backend/build_module.cc
src/relay/backend/vm/compiler.cc
src/relay/backend/vm/inline_primitives.cc
src/relay/backend/vm/lambda_lift.cc
src/relay/ir/module.cc
src/relay/ir/pretty_printer.cc
src/relay/pass/eta_expand.cc
src/relay/pass/fold_constant.cc
src/relay/pass/gradient.cc
src/relay/pass/partial_eval.cc
src/relay/pass/pass_manager.cc
src/relay/pass/quantize/realize.cc
src/relay/pass/to_cps.cc
tests/cpp/relay_transform_sequential.cc

index dba4e4a..4fc31eb 100644 (file)
@@ -62,7 +62,7 @@ struct Module;
 class ModuleNode : public RelayNode {
  public:
   /*! \brief A map from ids to all global functions. */
-  tvm::Map<GlobalVar, Function> functions;
+  tvm::Map<GlobalVar, BaseFunc> functions;
   /*! \brief A map from global type vars to ADT type data. */
   tvm::Map<GlobalTypeVar, TypeData> type_definitions;
 
@@ -75,7 +75,7 @@ class ModuleNode : public RelayNode {
     v->Visit("global_type_var_map_", &global_type_var_map_);
   }
 
-  TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
+  TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
                              tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
                              std::unordered_set<std::string> imports = {});
 
@@ -86,7 +86,7 @@ class ModuleNode : public RelayNode {
    * \param update Controls whether you can replace a definition in the
    * environment.
    */
-  TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
+  TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
 
   /*!
    * \brief Add a function to the global environment.
@@ -95,7 +95,7 @@ class ModuleNode : public RelayNode {
    *
    * It does not do type inference as Add does.
    */
-  TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
+  TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
 
   /*!
    * \brief Add a type-level definition to the global environment.
@@ -124,7 +124,7 @@ class ModuleNode : public RelayNode {
    * \param var The name of the global function to update.
    * \param func The new function.
    */
-  TVM_DLL void Update(const GlobalVar& var, const Function& func);
+  TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
 
   /*!
    * \brief Update a type definition in the global environment.
@@ -184,14 +184,14 @@ class ModuleNode : public RelayNode {
    * \param var The global var to lookup.
    * \returns The function named by the variable argument.
    */
-  TVM_DLL Function Lookup(const GlobalVar& var) const;
+  TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
 
   /*!
    * \brief Look up a global function by its string name
    * \param name The name of the function.
    * \returns The function named by the argument.
    */
-  TVM_DLL Function Lookup(const std::string& name) const;
+  TVM_DLL BaseFunc Lookup(const std::string& name) const;
 
   /*!
    * \brief Look up a global type definition by its variable.
@@ -256,7 +256,7 @@ class ModuleNode : public RelayNode {
    */
   TVM_DLL static Module FromExpr(
     const Expr& expr,
-    const tvm::Map<GlobalVar, Function>& global_funcs = {},
+    const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
     const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
 
   static constexpr const char* _type_key = "relay.Module";
index 839dabc..f64d556 100644 (file)
@@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     // Optimize input Relay Function and returns Relay Module
     relay::Module relay_module = Optimize(func, targets_, params);
     // Get the updated function.
-    func = relay_module->Lookup("main");
+    func = Downcast<Function>(relay_module->Lookup("main"));
 
     // Generate code for the updated function.
     graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
index bb47685..b27c55b 100644 (file)
@@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       CHECK(it != context_->global_map.end());
       DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                       << " with func_index=" << it->second;
-      auto func = context_->module->Lookup(global);
+
+      // TODO(tvm-team):
+      // Think about mixed call into global that is not a relay::Function
+      // perhaps establish as an invariance(all functions in mod must be relay::Function)
+      auto func = Downcast<Function>(context_->module->Lookup(global));
+
+
       if (IsClosure(func)) {
         auto arity = func->params.size();
         Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
@@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod,
   CHECK_EQ(targets.size(), 1)
     << "Currently VM compiler doesn't support heterogeneous compilation";
   if (params_.size()) {
-    auto f = BindParamsByName(mod->Lookup("main"), params_);
+    BaseFunc base_func = mod->Lookup("main");
+    CHECK(base_func->IsInstance<FunctionNode>())
+        << "VM compiler expects to compile relay::Function";
+    auto f = BindParamsByName(Downcast<Function>(base_func), params_);
     auto gvar = mod->GetGlobalVar("main");
     mod->Add(gvar, f);
   }
@@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod,
 
   for (auto named_func : context_.module->functions) {
     auto gvar = named_func.first;
-    auto func = named_func.second;
-    VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
-    auto vm_func = func_compiler.Compile(gvar, func);
-
-    size_t func_index = context_.global_map.at(gvar);
-    CHECK(func_index < exec_->functions.size());
-    exec_->functions[func_index] = vm_func;
+    if (auto* n = named_func.second.as<FunctionNode>()) {
+      auto func = GetRef<Function>(n);
+      VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
+      auto vm_func = func_compiler.Compile(gvar, func);
+
+      size_t func_index = context_.global_map.at(gvar);
+      CHECK(func_index < exec_->functions.size());
+      exec_->functions[func_index] = vm_func;
+    }
   }
 
 #if USE_RELAY_DEBUG
index 25b0735..abd8d29 100644 (file)
@@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator {
     auto gvar_funcs = module_->functions;
     for (auto pair : gvar_funcs) {
       auto global = pair.first;
-      auto func = pair.second;
-      DLOG(INFO) << "Before inlining primitives: " << global
-                 << std::endl << AsText(func, false);
-
-      func = FunctionNode::make(func->params,
-                                VisitExpr(func->body),
-                                func->ret_type,
-                                func->type_params,
-                                func->attrs);
-      module_->Add(global, func, true);
-
-      DLOG(INFO) << "After inlining primitives: " << global
-                 << std::endl << AsText(func, false);
+      auto base_func = pair.second;
+      if (auto* n = base_func.as<FunctionNode>()) {
+        auto func = GetRef<Function>(n);
+
+        DLOG(INFO) << "Before inlining primitives: " << global
+                   << std::endl << AsText(func, false);
+
+        func = FunctionNode::make(func->params,
+                                  VisitExpr(func->body),
+                                  func->ret_type,
+                                  func->type_params,
+                                  func->attrs);
+        module_->Add(global, func, true);
+
+        DLOG(INFO) << "After inlining primitives: " << global
+                   << std::endl << AsText(func, false);
+      }
     }
     return module_;
   }
index 601af9e..34f03d2 100644 (file)
@@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator {
     // There is an ordering bug here.
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
-      auto func = pair.second;
-      func = FunctionNode::make(func->params,
-                                VisitExpr(func->body),
-                                func->ret_type,
-                                func->type_params,
-                                func->attrs);
-      module_->Add(pair.first, func, true);
+      if (auto* n = pair.second.as<FunctionNode>()) {
+        auto func = GetRef<Function>(n);
+        func = FunctionNode::make(func->params,
+                                  VisitExpr(func->body),
+                                  func->ret_type,
+                                  func->type_params,
+                                  func->attrs);
+        module_->Add(pair.first, func, true);
+      }
     }
     return module_;
   }
index bf1ebf3..0d2aebc 100644 (file)
@@ -34,10 +34,9 @@ namespace relay {
 using tvm::NodePrinter;
 using namespace runtime;
 
-Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
+Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
                         tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
-                        std::unordered_set<std::string> imports
-                        ) {
+                        std::unordered_set<std::string> imports) {
   auto n = make_object<ModuleNode>();
   n->functions = std::move(global_funcs);
   n->type_definitions = std::move(global_type_defs);
@@ -112,40 +111,54 @@ tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
   return ret;
 }
 
-void ModuleNode::Add(const GlobalVar& var,
-                     const Function& f,
-                     bool update) {
-  Function func = Downcast<Function>(DeDup(f));
+// helper function to run type check
+relay::Function RunTypeCheck(const Module& mod,
+                             const GlobalVar& var,
+                             relay::Function f) {
+  auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
   // Type check the item before we add it to the module.
-  auto mod = GetRef<Module>(this);
-  auto fv = FreeVars(func);
-  auto ftv = FreeTypeVars(func, mod);
+  auto fv = relay::FreeVars(func);
+  auto ftv = relay::FreeTypeVars(func, mod);
   if (fv.size() != 0) {
     LOG(WARNING)
-      << "There are free variables: "
-      << fv
-      << " in function: "
-      << AsText(func, false)
-      << std::endl;
+        << "There are free variables: "
+        << fv
+        << " in function: "
+        << AsText(func, false)
+        << std::endl;
   }
   if (ftv.size() != 0) {
     LOG(WARNING)
-      << "There are free type variables: "
-      << ftv
-      << " in function: "
-      << AsText(func, false)
-      << std::endl;
+        << "There are free type variables: "
+        << ftv
+        << " in function: "
+        << AsText(func, false)
+        << std::endl;
   }
   func =
-    FunctionNode::make(concat(func->params, fv),
-                       func->body,
-                       func->ret_type,
-                       concat(func->type_params, ftv),
-                       func->attrs);
+      relay::FunctionNode::make(concat(func->params, fv),
+                                func->body,
+                                func->ret_type,
+                                concat(func->type_params, ftv),
+                                func->attrs);
   // Type check the item before we add it to the module.
-  Function checked_func = InferType(func, mod, var);
+  relay::Function checked_func = InferType(func, mod, var);
+  return checked_func;
+}
+
+void ModuleNode::Add(const GlobalVar& var,
+                     const BaseFunc& f,
+                     bool update) {
+  BaseFunc checked_func = f;
+  if (auto* ptr = f.as<relay::FunctionNode>()) {
+    checked_func = RunTypeCheck(GetRef<Module>(this),
+                                var,
+                                GetRef<relay::Function>(ptr));
+  }
+
   auto type = checked_func->checked_type();
-  CHECK(type.as<IncompleteTypeNode>() == nullptr);
+  CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
+
   if (functions.find(var) != functions.end()) {
     CHECK(update)
         << "Already have definition for " << var->name_hint;
@@ -158,8 +171,7 @@ void ModuleNode::Add(const GlobalVar& var,
 }
 
 void ModuleNode::AddUnchecked(const GlobalVar& var,
-                              const Function& func) {
-  auto mod = GetRef<Module>(this);
+                              const BaseFunc& func) {
   this->functions.Set(var, func);
 
   auto it = global_var_map_.find(var->name_hint);
@@ -185,15 +197,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
   }
 }
 
-void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
+void ModuleNode::AddTypeDef(const GlobalTypeVar& var,
+                            const TypeData& type,
+                            bool update) {
   AddTypeDefUnchecked(var, type, update);
   // need to kind check at the end because the check can look up
   // a definition potentially
-  CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
+  CHECK(relay::KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
     << "Invalid or malformed typedata given to module: " << type;
 }
 
-void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
+void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
+                                     const TypeData& type,
+                                     bool update) {
   this->type_definitions.Set(var, type);
   if (!update) {
     // set global type var map
@@ -204,11 +220,13 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& t
   RegisterConstructors(var, type);
 }
 
-void ModuleNode::Update(const GlobalVar& var, const Function& func) {
+void ModuleNode::Update(const GlobalVar& var,
+                        const BaseFunc& func) {
   this->Add(var, func, true);
 }
 
-void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
+void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
+                               const TypeData& type) {
   this->AddTypeDef(var, type, true);
 }
 
@@ -219,14 +237,14 @@ void ModuleNode::Remove(const GlobalVar& var) {
   gvar_node->data.erase(var->name_hint);
 }
 
-Function ModuleNode::Lookup(const GlobalVar& var) const {
+BaseFunc ModuleNode::Lookup(const GlobalVar& var) const {
   auto it = functions.find(var);
   CHECK(it != functions.end())
       << "There is no definition of " << var->name_hint;
   return (*it).second;
 }
 
-Function ModuleNode::Lookup(const std::string& name) const {
+BaseFunc ModuleNode::Lookup(const std::string& name) const {
   GlobalVar id = this->GetGlobalVar(name);
   return this->Lookup(id);
 }
@@ -268,16 +286,17 @@ void ModuleNode::Update(const Module& mod) {
 }
 
 Module ModuleNode::FromExpr(
-  const Expr& expr,
-  const tvm::Map<GlobalVar, Function>& global_funcs,
+  const RelayExpr& expr,
+  const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
   const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
   auto mod = ModuleNode::make(global_funcs, type_definitions);
-  auto func_node = expr.as<FunctionNode>();
-  Function func;
-  if (func_node) {
-    func = GetRef<Function>(func_node);
+  BaseFunc func;
+  if (auto* func_node = expr.as<relay::FunctionNode>()) {
+    func = GetRef<relay::Function>(func_node);
   } else {
-    func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
+    func = relay::FunctionNode::make(
+        relay::FreeVars(expr), expr, Type(),
+        relay::FreeTypeVars(expr, mod), {});
   }
   auto main_gv = GlobalVar("main");
   mod->Add(main_gv, func);
@@ -318,8 +337,8 @@ Module FromText(const std::string& source, const std::string& source_name) {
 TVM_REGISTER_NODE_TYPE(ModuleNode);
 
 TVM_REGISTER_GLOBAL("relay._make.Module")
-.set_body_typed(
-[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
+.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
+                   tvm::Map<GlobalTypeVar, TypeData> types) {
   return ModuleNode::make(funcs, types, {});
 });
 
@@ -330,17 +349,19 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add")
   ObjectRef val = args[2];
   bool update = args[3];
   CHECK(val->IsInstance<ExprNode>());
-  if (val->IsInstance<FunctionNode>()) {
-    mod->Add(var, Downcast<Function>(val), update);
+
+  if (val->IsInstance<relay::FunctionNode>()) {
+    mod->Add(var, Downcast<relay::Function>(val), update);
   } else if (val->IsInstance<GlobalVarNode>()) {
     GlobalVar gv = Downcast<GlobalVar>(val);
     auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->()));
-    mod_copy = transform::EtaExpand(
-      /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
+    mod_copy = relay::transform::EtaExpand(
+        /* expand_constructor */ false,
+        /* expand_global_var */ true)(mod_copy);
     auto func = mod_copy->Lookup(gv->name_hint);
-    mod->Add(var, Downcast<Function>(func), update);
+    mod->Add(var, Downcast<relay::Function>(func), update);
   } else {
-    auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
+    auto func = FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
     mod->Add(var, func, update);
   }
   *ret = mod;
@@ -390,8 +411,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
   });
 
 TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
-.set_body_typed([](Expr e,
-                   tvm::Map<GlobalVar, Function> funcs,
+.set_body_typed([](RelayExpr e,
+                   tvm::Map<GlobalVar, BaseFunc> funcs,
                    tvm::Map<GlobalTypeVar, TypeData> type_defs) {
   return ModuleNode::FromExpr(e, funcs, type_defs);
 });
index a9d788d..c88c3a0 100644 (file)
@@ -486,7 +486,7 @@ class PrettyPrinter :
     return doc;
   }
 
-  Doc PrintFunc(const Doc& prefix, const Function& fn) {
+  Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
     Doc doc;
     doc << prefix;
     if (fn->type_params.size() > 0) {
@@ -514,6 +514,17 @@ class PrettyPrinter :
     return doc;
   }
 
+  Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
+    if (auto* n = base_func.as<relay::FunctionNode>()) {
+      return PrintFunc(prefix, GetRef<relay::Function>(n));
+    } else {
+      // def @xyz = meta['ExternalFunc'][id]
+      Doc doc;
+      doc << prefix << " = " <<  meta_.GetMetaNode(base_func);
+      return doc;
+    }
+  }
+
   Doc PrintMod(const Module& mod) {
     Doc doc;
     int counter = 0;
index b9973f3..5716da6 100644 (file)
@@ -68,9 +68,12 @@ class EtaExpander : public ExprMutator {
 
   Module Expand() {
     for (GlobalVar global_var : mod_->GetGlobalVars()) {
-      const Function func = mod_->Lookup(global_var);
-      const Function new_func = Downcast<Function>(VisitExpr(func));
-      mod_->Update(global_var, new_func);
+      const BaseFunc base_func = mod_->Lookup(global_var);
+      if (auto* n = base_func.as<FunctionNode>()) {
+        const Function new_func = Downcast<Function>(
+            VisitExpr(GetRef<Function>(n)));
+        mod_->Update(global_var, new_func);
+      }
     }
     return mod_;
   }
@@ -120,21 +123,26 @@ class EtaExpander : public ExprMutator {
     if (!expand_global_var_) {
       return std::move(gvar);
     }
-
-    const auto func = mod_->Lookup(gvar);
-    tvm::Array<Expr> params;
-    tvm::Array<Var> args;
-    for (size_t i = 0; i < func->params.size(); ++i) {
-      auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
-      params.push_back(var);
-      args.push_back(var);
-    }
+    const auto base_func = mod_->Lookup(gvar);
+    if (auto *ptr = base_func.as<FunctionNode>()) {
+      // handle relay function, skip external functions.
+      auto func = GetRef<Function>(ptr);
+      tvm::Array<Expr> params;
+      tvm::Array<Var> args;
+      for (size_t i = 0; i < func->params.size(); ++i) {
+        auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
+        params.push_back(var);
+        args.push_back(var);
+      }
 
     return FunctionNode::make(
-      args,
-      CallNode::make(gvar, params),
-      func->ret_type,
-      func->type_params);
+        args,
+        CallNode::make(gvar, params),
+        func->ret_type,
+        func->type_params);
+    } else {
+      return std::move(gvar);
+    }
   }
 
  private:
index bce5879..352a1d7 100644 (file)
@@ -217,7 +217,7 @@ class ConstantFolder : public ExprMutator {
     mod->Add(global, func);
     auto seq = transform::Sequential(passes);
     mod = seq(mod);
-    auto entry_func = mod->Lookup("main");
+    auto entry_func = Downcast<Function>(mod->Lookup("main"));
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
     return ObjectToExpr(executor_(expr));
   }
index cd86aaf..8380128 100644 (file)
@@ -82,7 +82,12 @@ Type WithGradientType(const Type& t) {
 //! \brief if the expression is a GlobalVar, transform to it's expression.
 Expr DeGlobal(const Module& mod, const Expr& e) {
   if (const auto* x = e.as<GlobalVarNode>()) {
-    return mod->Lookup(GetRef<GlobalVar>(x))->body;
+    BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
+    if (auto* n = base_func.as<FunctionNode>()) {
+      return n->body;
+    } else {
+      return e;
+    }
   } else {
     return e;
   }
index b066803..a2e8d06 100644 (file)
@@ -676,12 +676,18 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   PStatic VisitGlobalVar(const GlobalVar& gv) {
     CHECK(mod_.defined());
     if (gv_map_.count(gv) == 0) {
-      Function func = mod_->Lookup(gv);
-      InitializeFuncId(func);
-      Func f = VisitFuncStatic(func, gv);
-      gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
-      func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
-      mod_->Update(gv, func);
+      BaseFunc base_func = mod_->Lookup(gv);
+      if (auto* n = base_func.as<FunctionNode>()) {
+        Function func = GetRef<Function>(n);
+        InitializeFuncId(func);
+        Func f = VisitFuncStatic(func, gv);
+        gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+        func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
+        mod_->Update(gv, func);
+        return gv_map_.at(gv);
+      } else {
+        return NoStatic(gv);
+      }
     }
     return gv_map_.at(gv);
   }
@@ -951,7 +957,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
     auto mod = ModuleNode::FromExpr(expr);
     auto seq = transform::Sequential(passes);
     mod = seq(mod);
-    auto entry_func = mod->Lookup("main");
+    auto entry_func = Downcast<Function>(mod->Lookup("main"));
     auto fused_infered =
         expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
     return Reify(executor_(fused_infered), ll);
index f5c65e5..7d27059 100644 (file)
@@ -323,10 +323,14 @@ Module FunctionPassNode::operator()(const Module& mod,
   Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
   std::vector<std::pair<GlobalVar, Function> > updates;
   for (const auto& it : updated_mod->functions) {
-    auto updated_func = SkipFunction(it.second)
-                            ? it.second
-                            : pass_func(it.second, updated_mod, pass_ctx);
-    updates.push_back({it.first, updated_func});
+    // only picks up relay::Function
+    if (auto* n = it.second.as<FunctionNode>()) {
+      Function func = GetRef<Function>(n);
+      auto updated_func = SkipFunction(func)
+                          ? func
+                          : pass_func(func, updated_mod, pass_ctx);
+      updates.push_back({it.first, updated_func});
+    }
   }
 
   for (const auto& pair : updates) {
index b3b44cc..5f80871 100644 (file)
@@ -192,7 +192,7 @@ Expr QuantizeRealize(const Call& ref_call,
 Expr FoldConstantOpt(const Expr& expr) {
   auto mod = ModuleNode::FromExpr(expr);
   mod = transform::FoldConstant()(mod);
-  auto entry_func = mod->Lookup("main");
+  auto entry_func = Downcast<Function>(mod->Lookup("main"));
   return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
 }
 
index 9e2516b..898e4e9 100644 (file)
@@ -155,9 +155,17 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const
     Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
       auto gv = GetRef<GlobalVar>(op);
       if (cm->count(gv) == 0) {
-        auto cps_gv = GlobalVar(gv->name_hint + "_cps");
-        cm->insert({gv, cps_gv});
-        m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm));
+        // only look unfold non-external calls.
+        BaseFunc base_func = m->Lookup(gv);
+        if (auto* n = base_func.as<FunctionNode>()) {
+          auto cps_gv = GlobalVar(gv->name_hint + "_cps");
+          cm->insert({gv, cps_gv});
+          m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
+        } else {
+          // return the original global var if it is
+          // an external call to non-relay function.
+          return GetRef<GlobalVar>(op);
+        }
       }
       return k(cm->at(gv));
     }
index 3914d96..22934a6 100644 (file)
@@ -86,7 +86,7 @@ TEST(Relay, Sequential) {
   CHECK(mod.defined());
   auto entry_func = mod->GetGlobalVar("main");
   CHECK(entry_func.defined());
-  relay::Function f = mod->Lookup("main");
+  relay::Function f = Downcast<relay::Function>(mod->Lookup("main"));
   CHECK(f.defined());
 
   // Expected function