class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
- tvm::Map<GlobalVar, Function> functions;
+ tvm::Map<GlobalVar, BaseFunc> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
v->Visit("global_type_var_map_", &global_type_var_map_);
}
- TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
+ TVM_DLL static Module make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});
* \param update Controls whether you can replace a definition in the
* environment.
*/
- TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
+ TVM_DLL void Add(const GlobalVar& var, const BaseFunc& func, bool update = false);
/*!
* \brief Add a function to the global environment.
*
* It does not do type inference as Add does.
*/
- TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
+ TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
/*!
* \brief Add a type-level definition to the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
- TVM_DLL void Update(const GlobalVar& var, const Function& func);
+ TVM_DLL void Update(const GlobalVar& var, const BaseFunc& func);
/*!
* \brief Update a type definition in the global environment.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
- TVM_DLL Function Lookup(const GlobalVar& var) const;
+ TVM_DLL BaseFunc Lookup(const GlobalVar& var) const;
/*!
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
- TVM_DLL Function Lookup(const std::string& name) const;
+ TVM_DLL BaseFunc Lookup(const std::string& name) const;
/*!
* \brief Look up a global type definition by its variable.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
- const tvm::Map<GlobalVar, Function>& global_funcs = {},
+ const tvm::Map<GlobalVar, BaseFunc>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module";
// Optimize input Relay Function and returns Relay Module
relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function.
- func = relay_module->Lookup("main");
+ func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
CHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
- auto func = context_->module->Lookup(global);
+
+ // TODO(tvm-team):
+ // Think about mixed call into global that is not a relay::Function
+ // perhaps establish as an invariance(all functions in mod must be relay::Function)
+ auto func = Downcast<Function>(context_->module->Lookup(global));
+
+
if (IsClosure(func)) {
auto arity = func->params.size();
Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
- auto f = BindParamsByName(mod->Lookup("main"), params_);
+ BaseFunc base_func = mod->Lookup("main");
+ CHECK(base_func->IsInstance<FunctionNode>())
+ << "VM compiler expects to compile relay::Function";
+ auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
- auto func = named_func.second;
- VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
- auto vm_func = func_compiler.Compile(gvar, func);
-
- size_t func_index = context_.global_map.at(gvar);
- CHECK(func_index < exec_->functions.size());
- exec_->functions[func_index] = vm_func;
+ if (auto* n = named_func.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(n);
+ VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
+ auto vm_func = func_compiler.Compile(gvar, func);
+
+ size_t func_index = context_.global_map.at(gvar);
+ CHECK(func_index < exec_->functions.size());
+ exec_->functions[func_index] = vm_func;
+ }
}
#if USE_RELAY_DEBUG
auto gvar_funcs = module_->functions;
for (auto pair : gvar_funcs) {
auto global = pair.first;
- auto func = pair.second;
- DLOG(INFO) << "Before inlining primitives: " << global
- << std::endl << AsText(func, false);
-
- func = FunctionNode::make(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
- func->attrs);
- module_->Add(global, func, true);
-
- DLOG(INFO) << "After inlining primitives: " << global
- << std::endl << AsText(func, false);
+ auto base_func = pair.second;
+ if (auto* n = base_func.as<FunctionNode>()) {
+ auto func = GetRef<Function>(n);
+
+ DLOG(INFO) << "Before inlining primitives: " << global
+ << std::endl << AsText(func, false);
+
+ func = FunctionNode::make(func->params,
+ VisitExpr(func->body),
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ module_->Add(global, func, true);
+
+ DLOG(INFO) << "After inlining primitives: " << global
+ << std::endl << AsText(func, false);
+ }
}
return module_;
}
// There is an ordering bug here.
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
- auto func = pair.second;
- func = FunctionNode::make(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
- func->attrs);
- module_->Add(pair.first, func, true);
+ if (auto* n = pair.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(n);
+ func = FunctionNode::make(func->params,
+ VisitExpr(func->body),
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ module_->Add(pair.first, func, true);
+ }
}
return module_;
}
using tvm::NodePrinter;
using namespace runtime;
-Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
+Module ModuleNode::make(tvm::Map<GlobalVar, BaseFunc> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
- std::unordered_set<std::string> imports
- ) {
+ std::unordered_set<std::string> imports) {
auto n = make_object<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
return ret;
}
-void ModuleNode::Add(const GlobalVar& var,
- const Function& f,
- bool update) {
- Function func = Downcast<Function>(DeDup(f));
+// helper function to run type check
+relay::Function RunTypeCheck(const Module& mod,
+ const GlobalVar& var,
+ relay::Function f) {
+ auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
// Type check the item before we add it to the module.
- auto mod = GetRef<Module>(this);
- auto fv = FreeVars(func);
- auto ftv = FreeTypeVars(func, mod);
+ auto fv = relay::FreeVars(func);
+ auto ftv = relay::FreeTypeVars(func, mod);
if (fv.size() != 0) {
LOG(WARNING)
- << "There are free variables: "
- << fv
- << " in function: "
- << AsText(func, false)
- << std::endl;
+ << "There are free variables: "
+ << fv
+ << " in function: "
+ << AsText(func, false)
+ << std::endl;
}
if (ftv.size() != 0) {
LOG(WARNING)
- << "There are free type variables: "
- << ftv
- << " in function: "
- << AsText(func, false)
- << std::endl;
+ << "There are free type variables: "
+ << ftv
+ << " in function: "
+ << AsText(func, false)
+ << std::endl;
}
func =
- FunctionNode::make(concat(func->params, fv),
- func->body,
- func->ret_type,
- concat(func->type_params, ftv),
- func->attrs);
+ relay::FunctionNode::make(concat(func->params, fv),
+ func->body,
+ func->ret_type,
+ concat(func->type_params, ftv),
+ func->attrs);
// Type check the item before we add it to the module.
- Function checked_func = InferType(func, mod, var);
+ relay::Function checked_func = InferType(func, mod, var);
+ return checked_func;
+}
+
+void ModuleNode::Add(const GlobalVar& var,
+ const BaseFunc& f,
+ bool update) {
+ BaseFunc checked_func = f;
+ if (auto* ptr = f.as<relay::FunctionNode>()) {
+ checked_func = RunTypeCheck(GetRef<Module>(this),
+ var,
+ GetRef<relay::Function>(ptr));
+ }
+
auto type = checked_func->checked_type();
- CHECK(type.as<IncompleteTypeNode>() == nullptr);
+ CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
+
if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
}
void ModuleNode::AddUnchecked(const GlobalVar& var,
- const Function& func) {
- auto mod = GetRef<Module>(this);
+ const BaseFunc& func) {
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
}
}
-void ModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
+void ModuleNode::AddTypeDef(const GlobalTypeVar& var,
+ const TypeData& type,
+ bool update) {
AddTypeDefUnchecked(var, type, update);
// need to kind check at the end because the check can look up
// a definition potentially
- CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
+ CHECK(relay::KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
-void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
+void ModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var,
+ const TypeData& type,
+ bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
RegisterConstructors(var, type);
}
-void ModuleNode::Update(const GlobalVar& var, const Function& func) {
+void ModuleNode::Update(const GlobalVar& var,
+ const BaseFunc& func) {
this->Add(var, func, true);
}
-void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
+void ModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
+ const TypeData& type) {
this->AddTypeDef(var, type, true);
}
gvar_node->data.erase(var->name_hint);
}
-Function ModuleNode::Lookup(const GlobalVar& var) const {
+BaseFunc ModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}
-Function ModuleNode::Lookup(const std::string& name) const {
+BaseFunc ModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
}
Module ModuleNode::FromExpr(
- const Expr& expr,
- const tvm::Map<GlobalVar, Function>& global_funcs,
+ const RelayExpr& expr,
+ const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions);
- auto func_node = expr.as<FunctionNode>();
- Function func;
- if (func_node) {
- func = GetRef<Function>(func_node);
+ BaseFunc func;
+ if (auto* func_node = expr.as<relay::FunctionNode>()) {
+ func = GetRef<relay::Function>(func_node);
} else {
- func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
+ func = relay::FunctionNode::make(
+ relay::FreeVars(expr), expr, Type(),
+ relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module")
-.set_body_typed(
-[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
+.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
+ tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
});
ObjectRef val = args[2];
bool update = args[3];
CHECK(val->IsInstance<ExprNode>());
- if (val->IsInstance<FunctionNode>()) {
- mod->Add(var, Downcast<Function>(val), update);
+
+ if (val->IsInstance<relay::FunctionNode>()) {
+ mod->Add(var, Downcast<relay::Function>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_object<ModuleNode>(*mod.operator->()));
- mod_copy = transform::EtaExpand(
- /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
+ mod_copy = relay::transform::EtaExpand(
+ /* expand_constructor */ false,
+ /* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
- mod->Add(var, Downcast<Function>(func), update);
+ mod->Add(var, Downcast<relay::Function>(func), update);
} else {
- auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
+ auto func = FunctionNode::make({}, Downcast<relay::Expr>(val), Type(nullptr), {});
mod->Add(var, func, update);
}
*ret = mod;
});
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
-.set_body_typed([](Expr e,
- tvm::Map<GlobalVar, Function> funcs,
+.set_body_typed([](RelayExpr e,
+ tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
return doc;
}
- Doc PrintFunc(const Doc& prefix, const Function& fn) {
+ Doc PrintFunc(const Doc& prefix, const relay::Function& fn) {
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
return doc;
}
+ Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
+ if (auto* n = base_func.as<relay::FunctionNode>()) {
+ return PrintFunc(prefix, GetRef<relay::Function>(n));
+ } else {
+ // def @xyz = meta['ExternalFunc'][id]
+ Doc doc;
+ doc << prefix << " = " << meta_.GetMetaNode(base_func);
+ return doc;
+ }
+ }
+
Doc PrintMod(const Module& mod) {
Doc doc;
int counter = 0;
Module Expand() {
for (GlobalVar global_var : mod_->GetGlobalVars()) {
- const Function func = mod_->Lookup(global_var);
- const Function new_func = Downcast<Function>(VisitExpr(func));
- mod_->Update(global_var, new_func);
+ const BaseFunc base_func = mod_->Lookup(global_var);
+ if (auto* n = base_func.as<FunctionNode>()) {
+ const Function new_func = Downcast<Function>(
+ VisitExpr(GetRef<Function>(n)));
+ mod_->Update(global_var, new_func);
+ }
}
return mod_;
}
if (!expand_global_var_) {
return std::move(gvar);
}
-
- const auto func = mod_->Lookup(gvar);
- tvm::Array<Expr> params;
- tvm::Array<Var> args;
- for (size_t i = 0; i < func->params.size(); ++i) {
- auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
- params.push_back(var);
- args.push_back(var);
- }
+ const auto base_func = mod_->Lookup(gvar);
+ if (auto *ptr = base_func.as<FunctionNode>()) {
+ // handle relay function, skip external functions.
+ auto func = GetRef<Function>(ptr);
+ tvm::Array<Expr> params;
+ tvm::Array<Var> args;
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
+ params.push_back(var);
+ args.push_back(var);
+ }
return FunctionNode::make(
- args,
- CallNode::make(gvar, params),
- func->ret_type,
- func->type_params);
+ args,
+ CallNode::make(gvar, params),
+ func->ret_type,
+ func->type_params);
+ } else {
+ return std::move(gvar);
+ }
}
private:
mod->Add(global, func);
auto seq = transform::Sequential(passes);
mod = seq(mod);
- auto entry_func = mod->Lookup("main");
+ auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return ObjectToExpr(executor_(expr));
}
//! \brief if the expression is a GlobalVar, transform to it's expression.
Expr DeGlobal(const Module& mod, const Expr& e) {
if (const auto* x = e.as<GlobalVarNode>()) {
- return mod->Lookup(GetRef<GlobalVar>(x))->body;
+ BaseFunc base_func = mod->Lookup(GetRef<GlobalVar>(x));
+ if (auto* n = base_func.as<FunctionNode>()) {
+ return n->body;
+ } else {
+ return e;
+ }
} else {
return e;
}
PStatic VisitGlobalVar(const GlobalVar& gv) {
CHECK(mod_.defined());
if (gv_map_.count(gv) == 0) {
- Function func = mod_->Lookup(gv);
- InitializeFuncId(func);
- Func f = VisitFuncStatic(func, gv);
- gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
- func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
- mod_->Update(gv, func);
+ BaseFunc base_func = mod_->Lookup(gv);
+ if (auto* n = base_func.as<FunctionNode>()) {
+ Function func = GetRef<Function>(n);
+ InitializeFuncId(func);
+ Func f = VisitFuncStatic(func, gv);
+ gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
+ func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
+ mod_->Update(gv, func);
+ return gv_map_.at(gv);
+ } else {
+ return NoStatic(gv);
+ }
}
return gv_map_.at(gv);
}
auto mod = ModuleNode::FromExpr(expr);
auto seq = transform::Sequential(passes);
mod = seq(mod);
- auto entry_func = mod->Lookup("main");
+ auto entry_func = Downcast<Function>(mod->Lookup("main"));
auto fused_infered =
expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
return Reify(executor_(fused_infered), ll);
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
- auto updated_func = SkipFunction(it.second)
- ? it.second
- : pass_func(it.second, updated_mod, pass_ctx);
- updates.push_back({it.first, updated_func});
+ // only picks up relay::Function
+ if (auto* n = it.second.as<FunctionNode>()) {
+ Function func = GetRef<Function>(n);
+ auto updated_func = SkipFunction(func)
+ ? func
+ : pass_func(func, updated_mod, pass_ctx);
+ updates.push_back({it.first, updated_func});
+ }
}
for (const auto& pair : updates) {
Expr FoldConstantOpt(const Expr& expr) {
auto mod = ModuleNode::FromExpr(expr);
mod = transform::FoldConstant()(mod);
- auto entry_func = mod->Lookup("main");
+ auto entry_func = Downcast<Function>(mod->Lookup("main"));
return expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
}
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
auto gv = GetRef<GlobalVar>(op);
if (cm->count(gv) == 0) {
- auto cps_gv = GlobalVar(gv->name_hint + "_cps");
- cm->insert({gv, cps_gv});
- m->Add(cps_gv, ToCPS(m->Lookup(gv), m, cm));
+ // only look unfold non-external calls.
+ BaseFunc base_func = m->Lookup(gv);
+ if (auto* n = base_func.as<FunctionNode>()) {
+ auto cps_gv = GlobalVar(gv->name_hint + "_cps");
+ cm->insert({gv, cps_gv});
+ m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
+ } else {
+ // return the original global var if it is
+ // an external call to non-relay function.
+ return GetRef<GlobalVar>(op);
+ }
}
return k(cm->at(gv));
}
CHECK(mod.defined());
auto entry_func = mod->GetGlobalVar("main");
CHECK(entry_func.defined());
- relay::Function f = mod->Lookup("main");
+ relay::Function f = Downcast<relay::Function>(mod->Lookup("main"));
CHECK(f.defined());
// Expected function