class BuildModule(object):
- """Build a Relay function to run on TVM graph runtime. This class is used
+ """Build an IR module to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
"""
def __init__(self):
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
- def build(self, func, target=None, target_host=None, params=None):
+ def build(self, mod, target=None, target_host=None, params=None):
"""
Parameters
----------
- func: relay.Function
- The function to build.
+ mod : :py:class:`~tvm.IRModule`
+ The IRModule to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
# Setup the params.
if params:
self._set_params(params)
- # Build the function
- self._build(func, target, target_host)
+ # Build the IR module
+ self._build(mod, target, target_host)
# Get artifacts
graph_json = self.get_json()
mod = self.get_module()
return graph_json, mod, params
- def optimize(self, func, target=None, params=None):
+ def optimize(self, mod, target=None, params=None):
"""
Parameters
----------
- func: relay.Function
- The function to build.
+ mod : :py:class:`~tvm.IRModule`
+ The IR module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
Returns
-------
- mod : tvm.IRModule
+ mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
# Setup the params.
if params:
self._set_params(params)
- mod = self._optimize(func, target)
+ mod = self._optimize(mod, target)
# Get artifacts
params = self.get_params()
Parameters
----------
- mod : tvm.IRModule
- The module to build. Using relay.Function is deprecated.
+ mod : :py:class:`~tvm.IRModule`
+ The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
params : dict
The parameters of the final graph.
"""
- if isinstance(mod, IRModule):
- func = mod["main"]
- elif isinstance(mod, _expr.Function):
- func = mod
+ if not isinstance(mod, (IRModule, _expr.Function)):
+ raise ValueError("Type of input parameter mod must be tvm.IRModule")
+
+ if isinstance(mod, _expr.Function):
+ mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
- "instead of deprecated parameter func (tvm.relay.expr.Function)",
+ "instead of deprecated parameter mod (tvm.relay.expr.Function)",
DeprecationWarning)
- else:
- raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
with tophub_context:
bld_mod = BuildModule()
- graph_json, mod, params = bld_mod.build(func, target, target_host, params)
+ graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
return graph_json, mod, params
Parameters
----------
- mod : tvm.IRModule
+ mod : :py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
Returns
-------
- mod : tvm.IRModule
+ mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
The parameters of the final graph.
"""
- if isinstance(mod, IRModule):
- func = mod["main"]
- elif isinstance(mod, _expr.Function):
- func = mod
+ if not isinstance(mod, (IRModule, _expr.Function)):
+ raise ValueError("Type of input parameter mod must be tvm.IRModule")
+
+ if isinstance(mod, _expr.Function):
+ mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning)
- else:
- raise ValueError("Type of input parameter mod must be tvm.IRModule")
target = _update_target(target)
with tophub_context:
bld_mod = BuildModule()
- mod, params = bld_mod.optimize(func, target, params)
+ mod, params = bld_mod.optimize(mod, target, params)
return mod, params
}
/*!
- * \brief Build relay function for graph runtime
+ * \brief Build relay IRModule for graph runtime
*
- * \param func Relay Function
+ * \param mod Relay IRModule
* \param target Target device
* \param target_host Host target device
*/
- void Build(Function func,
+ void Build(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
- BuildRelay(func, params_);
+ BuildRelay(mod, params_);
}
protected:
/*!
- * \brief Optimize a Relay Function.
+ * \brief Optimize a Relay IRModule.
*
- * \param func The input Function where optmization will be applied on.
+ * \param relay_module The input IRModule where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
*
- * \return relay::Module The updated Relay module after optimization.
+ * \return relay::IRModule The updated Relay IR module after optimization.
*/
IRModule Optimize(
- Function func,
+ IRModule relay_module,
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
- func = BindParamsByName(func, params);
+ CHECK(relay_module->ContainGlobalVar("main"))
+ << "Missing the main entry function";
+ GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
+ Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
+ auto new_main = BindParamsByName(main_func, params);
+ relay_module->Update(main_glb_var, new_main);
}
- // Perform Module->Module optimizations.
- IRModule relay_module = IRModule::FromExpr(func);
-
Array<Pass> pass_seqs;
+ Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+ pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
}
/*!
- * \brief Compile a Relay function to runtime module.
+ * \brief Compile a Relay IR module to runtime module.
*
- * \param func The Relay function.
+ * \param relay_module The Relay IR module.
* \param params The parameters.
*/
void BuildRelay(
- Function func,
+ IRModule relay_module,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
- // Optimize input Relay Function and returns Relay Module
- IRModule relay_module = Optimize(func, targets_, params);
+ // Relay IRModule -> IRModule optimizations.
+ relay_module = Optimize(relay_module, targets_, params);
// Get the updated function.
- func = Downcast<Function>(relay_module->Lookup("main"));
+ auto func = Downcast<Function>(relay_module->Lookup("main"));
// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());