From: Tianqi Chen Date: Sat, 11 Jan 2020 06:54:16 +0000 (-0800) Subject: [REFACTOR][IR] Allow Module to store BaseFunc. (#4678) X-Git-Tag: upstream/0.7.0~1396 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3d52a99c8bf6bab20a010932061c6335ee97fff0;p=platform%2Fupstream%2Ftvm.git [REFACTOR][IR] Allow Module to store BaseFunc. (#4678) Under the unified IR. We will allow a single IRModule to store different function variants, such as relay::Function, ExternFunc, and low-level function. This PR changes relay::Function -> BaseFunc in the module file to support multiple function variants. --- diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index dba4e4a..4fc31eb 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -62,7 +62,7 @@ struct Module; class ModuleNode : public RelayNode { public: /*! \brief A map from ids to all global functions. */ - tvm::Map functions; + tvm::Map functions; /*! \brief A map from global type vars to ADT type data. */ tvm::Map type_definitions; @@ -75,7 +75,7 @@ class ModuleNode : public RelayNode { v->Visit("global_type_var_map_", &global_type_var_map_); } - TVM_DLL static Module make(tvm::Map global_funcs, + TVM_DLL static Module make(tvm::Map global_funcs, tvm::Map global_type_defs, std::unordered_set imports = {}); @@ -86,7 +86,7 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); + TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false); /*! * \brief Add a function to the global environment. @@ -95,7 +95,7 @@ class ModuleNode : public RelayNode { * * It does not do type inference as Add does. */ - TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func); /*! * \brief Add a type-level definition to the global environment. @@ -124,7 +124,7 @@ class ModuleNode : public RelayNode { * \param var The name of the global function to update. * \param func The new function. */ - TVM_DLL void Update(const GlobalVar& var, const Function& func); + TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func); /*! * \brief Update a type definition in the global environment. @@ -184,14 +184,14 @@ class ModuleNode : public RelayNode { * \param var The global var to lookup. * \returns The function named by the variable argument. */ - TVM_DLL Function Lookup(const GlobalVar& var) const; + TVM_DLL BaseFunc Lookup(const GlobalVar& var) const; /*! * \brief Look up a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL Function Lookup(const std::string& name) const; + TVM_DLL BaseFunc Lookup(const std::string& name) const; /*! * \brief Look up a global type definition by its variable. @@ -256,7 +256,7 @@ class ModuleNode : public RelayNode { */ TVM_DLL static Module FromExpr( const Expr& expr, - const tvm::Map& global_funcs = {}, + const tvm::Map& global_funcs = {}, const tvm::Map& type_definitions = {}); static constexpr const char* _type_key = "relay.Module"; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 839dabc..f64d556 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -463,7 +463,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Optimize input Relay Function and returns Relay Module relay::Module relay_module = Optimize(func, targets_, params); // Get the updated function. - func = relay_module->Lookup("main"); + func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. graph_codegen_ = std::unique_ptr(new GraphCodegen()); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index bb47685..b27c55b 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -612,7 +612,13 @@ class VMFunctionCompiler : ExprFunctor { CHECK(it != context_->global_map.end()); DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint << " with func_index=" << it->second; - auto func = context_->module->Lookup(global); + + // TODO(tvm-team): + // Think about mixed call into global that is not a relay::Function + // perhaps establish as an invariance(all functions in mod must be relay::Function) + auto func = Downcast(context_->module->Lookup(global)); + + if (IsClosure(func)) { auto arity = func->params.size(); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); @@ -813,7 +819,10 @@ void VMCompiler::Lower(Module mod, CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { - auto f = BindParamsByName(mod->Lookup("main"), params_); + BaseFunc base_func = mod->Lookup("main"); + CHECK(base_func->IsInstance()) + << "VM compiler expects to compile relay::Function"; + auto f = BindParamsByName(Downcast(base_func), params_); auto gvar = mod->GetGlobalVar("main"); mod->Add(gvar, f); } @@ -837,13 +846,15 @@ void VMCompiler::Lower(Module mod, for (auto named_func : context_.module->functions) { auto gvar = named_func.first; - auto func = named_func.second; - VMFunctionCompiler func_compiler(&context_, targets_, target_host_); - auto vm_func = func_compiler.Compile(gvar, func); - - size_t func_index = context_.global_map.at(gvar); - CHECK(func_index < exec_->functions.size()); - exec_->functions[func_index] = vm_func; + if (auto* n = named_func.second.as()) { + auto func = GetRef(n); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_); + auto vm_func = func_compiler.Compile(gvar, func); + + size_t func_index = context_.global_map.at(gvar); + CHECK(func_index < exec_->functions.size()); + exec_->functions[func_index] = vm_func; + } } #if USE_RELAY_DEBUG diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 25b0735..abd8d29 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -110,19 +110,23 @@ struct PrimitiveInliner : ExprMutator { auto gvar_funcs = module_->functions; for (auto pair : gvar_funcs) { auto global = pair.first; - auto func = pair.second; - DLOG(INFO) << "Before inlining primitives: " << global - << std::endl << AsText(func, false); - - func = FunctionNode::make(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); - module_->Add(global, func, true); - - DLOG(INFO) << "After inlining primitives: " << global - << std::endl << AsText(func, false); + auto base_func = pair.second; + if (auto* n = base_func.as()) { + auto func = GetRef(n); + + DLOG(INFO) << "Before inlining primitives: " << global + << std::endl << AsText(func, false); + + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(global, func, true); + + DLOG(INFO) << "After inlining primitives: " << global + << std::endl << AsText(func, false); + } } return module_; } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 601af9e..34f03d2 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -188,13 +188,15 @@ class LambdaLifter : public ExprMutator { // There is an ordering bug here. auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { - auto func = pair.second; - func = FunctionNode::make(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); - module_->Add(pair.first, func, true); + if (auto* n = pair.second.as()) { + auto func = GetRef(n); + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(pair.first, func, true); + } } return module_; } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index bf1ebf3..0d2aebc 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -34,10 +34,9 @@ namespace relay { using tvm::NodePrinter; using namespace runtime; -Module ModuleNode::make(tvm::Map global_funcs, +Module ModuleNode::make(tvm::Map global_funcs, tvm::Map global_type_defs, - std::unordered_set imports - ) { + std::unordered_set imports) { auto n = make_object(); n->functions = std::move(global_funcs); n->type_definitions = std::move(global_type_defs); @@ -112,40 +111,54 @@ tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { return ret; } -void ModuleNode::Add(const GlobalVar& var, - const Function& f, - bool update) { - Function func = Downcast(DeDup(f)); +// helper function to run type check +relay::Function RunTypeCheck(const Module& mod, + const GlobalVar& var, + relay::Function f) { + auto func = Downcast(relay::DeDup(std::move(f))); // Type check the item before we add it to the module. - auto mod = GetRef(this); - auto fv = FreeVars(func); - auto ftv = FreeTypeVars(func, mod); + auto fv = relay::FreeVars(func); + auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { LOG(WARNING) - << "There are free variables: " - << fv - << " in function: " - << AsText(func, false) - << std::endl; + << "There are free variables: " + << fv + << " in function: " + << AsText(func, false) + << std::endl; } if (ftv.size() != 0) { LOG(WARNING) - << "There are free type variables: " - << ftv - << " in function: " - << AsText(func, false) - << std::endl; + << "There are free type variables: " + << ftv + << " in function: " + << AsText(func, false) + << std::endl; } func = - FunctionNode::make(concat(func->params, fv), - func->body, - func->ret_type, - concat(func->type_params, ftv), - func->attrs); + relay::FunctionNode::make(concat(func->params, fv), + func->body, + func->ret_type, + concat(func->type_params, ftv), + func->attrs); // Type check the item before we add it to the module. - Function checked_func = InferType(func, mod, var); + relay::Function checked_func = InferType(func, mod, var); + return checked_func; +} + +void ModuleNode::Add(const GlobalVar& var, + const BaseFunc& f, + bool update) { + BaseFunc checked_func = f; + if (auto* ptr = f.as()) { + checked_func = RunTypeCheck(GetRef(this), + var, + GetRef(ptr)); + } + auto type = checked_func->checked_type(); - CHECK(type.as() == nullptr); + CHECK(type.as() == nullptr); + if (functions.find(var) != functions.end()) { CHECK(update) << "Already have definition for " << var->name_hint; @@ -158,8 +171,7 @@ void ModuleNode::Add(const GlobalVar& var, } void ModuleNode::AddUnchecked(const GlobalVar& var, - const Function& func) { - auto mod = GetRef(this); + const BaseFunc& func) { this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); @@ -185,15 +197,19 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& } } -void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { +void ModuleNode::AddTypeDef(const GlobalTypeVar& var, + const TypeData& type, + bool update) { AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially - CHECK(KindCheck(type, GetRef(this)) == Kind::kTypeData) + CHECK(relay::KindCheck(type, GetRef(this)) == Kind::kTypeData) << "Invalid or malformed typedata given to module: " << type; } -void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { +void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, + const TypeData& type, + bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map @@ -204,11 +220,13 @@ void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& t RegisterConstructors(var, type); } -void ModuleNode::Update(const GlobalVar& var, const Function& func) { +void ModuleNode::Update(const GlobalVar& var, + const BaseFunc& func) { this->Add(var, func, true); } -void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { +void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, + const TypeData& type) { this->AddTypeDef(var, type, true); } @@ -219,14 +237,14 @@ void ModuleNode::Remove(const GlobalVar& var) { gvar_node->data.erase(var->name_hint); } -Function ModuleNode::Lookup(const GlobalVar& var) const { +BaseFunc ModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -Function ModuleNode::Lookup(const std::string& name) const { +BaseFunc ModuleNode::Lookup(const std::string& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } @@ -268,16 +286,17 @@ void ModuleNode::Update(const Module& mod) { } Module ModuleNode::FromExpr( - const Expr& expr, - const tvm::Map& global_funcs, + const RelayExpr& expr, + const tvm::Map& global_funcs, const tvm::Map& type_definitions) { auto mod = ModuleNode::make(global_funcs, type_definitions); - auto func_node = expr.as(); - Function func; - if (func_node) { - func = GetRef(func_node); + BaseFunc func; + if (auto* func_node = expr.as()) { + func = GetRef(func_node); } else { - func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {}); + func = relay::FunctionNode::make( + relay::FreeVars(expr), expr, Type(), + relay::FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVar("main"); mod->Add(main_gv, func); @@ -318,8 +337,8 @@ Module FromText(const std::string& source, const std::string& source_name) { TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_GLOBAL("relay._make.Module") -.set_body_typed( -[](tvm::Map funcs, tvm::Map types) { +.set_body_typed([](tvm::Map funcs, + tvm::Map types) { return ModuleNode::make(funcs, types, {}); }); @@ -330,17 +349,19 @@ TVM_REGISTER_GLOBAL("relay._module.Module_Add") ObjectRef val = args[2]; bool update = args[3]; CHECK(val->IsInstance()); - if (val->IsInstance()) { - mod->Add(var, Downcast(val), update); + + if (val->IsInstance()) { + mod->Add(var, Downcast(val), update); } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); auto mod_copy = Module(make_object(*mod.operator->())); - mod_copy = transform::EtaExpand( - /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); + mod_copy = relay::transform::EtaExpand( + /* expand_constructor */ false, + /* expand_global_var */ true)(mod_copy); auto func = mod_copy->Lookup(gv->name_hint); - mod->Add(var, Downcast(func), update); + mod->Add(var, Downcast(func), update); } else { - auto func = FunctionNode::make({}, Downcast(val), Type(nullptr), {}); + auto func = FunctionNode::make({}, Downcast(val), Type(nullptr), {}); mod->Add(var, func, update); } *ret = mod; @@ -390,8 +411,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") }); TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") -.set_body_typed([](Expr e, - tvm::Map funcs, +.set_body_typed([](RelayExpr e, + tvm::Map funcs, tvm::Map type_defs) { return ModuleNode::FromExpr(e, funcs, type_defs); }); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index a9d788d..c88c3a0 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -486,7 +486,7 @@ class PrettyPrinter : return doc; } - Doc PrintFunc(const Doc& prefix, const Function& fn) { + Doc PrintFunc(const Doc& prefix, const relay::Function& fn) { Doc doc; doc << prefix; if (fn->type_params.size() > 0) { @@ -514,6 +514,17 @@ class PrettyPrinter : return doc; } + Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { + if (auto* n = base_func.as()) { + return PrintFunc(prefix, GetRef(n)); + } else { + // def @xyz = meta['ExternalFunc'][id] + Doc doc; + doc << prefix << " = " << meta_.GetMetaNode(base_func); + return doc; + } + } + Doc PrintMod(const Module& mod) { Doc doc; int counter = 0; diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index b9973f3..5716da6 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -68,9 +68,12 @@ class EtaExpander : public ExprMutator { Module Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { - const Function func = mod_->Lookup(global_var); - const Function new_func = Downcast(VisitExpr(func)); - mod_->Update(global_var, new_func); + const BaseFunc base_func = mod_->Lookup(global_var); + if (auto* n = base_func.as()) { + const Function new_func = Downcast( + VisitExpr(GetRef(n))); + mod_->Update(global_var, new_func); + } } return mod_; } @@ -120,21 +123,26 @@ class EtaExpander : public ExprMutator { if (!expand_global_var_) { return std::move(gvar); } - - const auto func = mod_->Lookup(gvar); - tvm::Array params; - tvm::Array args; - for (size_t i = 0; i < func->params.size(); ++i) { - auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation); - params.push_back(var); - args.push_back(var); - } + const auto base_func = mod_->Lookup(gvar); + if (auto *ptr = base_func.as()) { + // handle relay function, skip external functions. + auto func = GetRef(ptr); + tvm::Array params; + tvm::Array args; + for (size_t i = 0; i < func->params.size(); ++i) { + auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation); + params.push_back(var); + args.push_back(var); + } return FunctionNode::make( - args, - CallNode::make(gvar, params), - func->ret_type, - func->type_params); + args, + CallNode::make(gvar, params), + func->ret_type, + func->type_params); + } else { + return std::move(gvar); + } } private: diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index bce5879..352a1d7 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -217,7 +217,7 @@ class ConstantFolder : public ExprMutator { mod->Add(global, func); auto seq = transform::Sequential(passes); mod = seq(mod); - auto entry_func = mod->Lookup("main"); + auto entry_func = Downcast(mod->Lookup("main")); expr = expr.as() == nullptr ? entry_func->body : entry_func; return ObjectToExpr(executor_(expr)); } diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index cd86aaf..8380128 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -82,7 +82,12 @@ Type WithGradientType(const Type& t) { //! \brief if the expression is a GlobalVar, transform to it's expression. Expr DeGlobal(const Module& mod, const Expr& e) { if (const auto* x = e.as()) { - return mod->Lookup(GetRef(x))->body; + BaseFunc base_func = mod->Lookup(GetRef(x)); + if (auto* n = base_func.as()) { + return n->body; + } else { + return e; + } } else { return e; } diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index b066803..a2e8d06 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -676,12 +676,18 @@ class PartialEvaluator : public ExprFunctor PStatic VisitGlobalVar(const GlobalVar& gv) { CHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { - Function func = mod_->Lookup(gv); - InitializeFuncId(func); - Func f = VisitFuncStatic(func, gv); - gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); - func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv))); - mod_->Update(gv, func); + BaseFunc base_func = mod_->Lookup(gv); + if (auto* n = base_func.as()) { + Function func = GetRef(n); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv))); + mod_->Update(gv, func); + return gv_map_.at(gv); + } else { + return NoStatic(gv); + } } return gv_map_.at(gv); } @@ -951,7 +957,7 @@ class PartialEvaluator : public ExprFunctor auto mod = ModuleNode::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); - auto entry_func = mod->Lookup("main"); + auto entry_func = Downcast(mod->Lookup("main")); auto fused_infered = expr.as() == nullptr ? entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index f5c65e5..7d27059 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -323,10 +323,14 @@ Module FunctionPassNode::operator()(const Module& mod, Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports()); std::vector > updates; for (const auto& it : updated_mod->functions) { - auto updated_func = SkipFunction(it.second) - ? it.second - : pass_func(it.second, updated_mod, pass_ctx); - updates.push_back({it.first, updated_func}); + // only picks up relay::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + auto updated_func = SkipFunction(func) + ? func + : pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } } for (const auto& pair : updates) { diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index b3b44cc..5f80871 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -192,7 +192,7 @@ Expr QuantizeRealize(const Call& ref_call, Expr FoldConstantOpt(const Expr& expr) { auto mod = ModuleNode::FromExpr(expr); mod = transform::FoldConstant()(mod); - auto entry_func = mod->Lookup("main"); + auto entry_func = Downcast(mod->Lookup("main")); return expr.as() == nullptr ? entry_func->body : entry_func; } diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 9e2516b..898e4e9 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -155,9 +155,17 @@ Function ToCPS(const Function& f, const Module& m, CPSMap* cm, VarMap* vm, const Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); if (cm->count(gv) == 0) { - auto cps_gv = GlobalVar(gv->name_hint + "_cps"); - cm->insert({gv, cps_gv}); - m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm)); + // only look unfold non-external calls. + BaseFunc base_func = m->Lookup(gv); + if (auto* n = base_func.as()) { + auto cps_gv = GlobalVar(gv->name_hint + "_cps"); + cm->insert({gv, cps_gv}); + m->Add(cps_gv, ToCPS(GetRef(n), m, cm)); + } else { + // return the original global var if it is + // an external call to non-relay function. + return GetRef(op); + } } return k(cm->at(gv)); } diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 3914d96..22934a6 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -86,7 +86,7 @@ TEST(Relay, Sequential) { CHECK(mod.defined()); auto entry_func = mod->GetGlobalVar("main"); CHECK(entry_func.defined()); - relay::Function f = mod->Lookup("main"); + relay::Function f = Downcast(mod->Lookup("main")); CHECK(f.defined()); // Expected function