logger = logging.getLogger('autotvm')
-# TODO(moreau89) find a more elegant way to build for VTAs
-def _build(func,
+# TODO(moreau89) find a more elegant way to lower for VTAs
+def _lower(func,
target,
- target_host,
params):
- """ Helper to build VTA properly.
+ """ Helper to lower VTA properly.
"""
from tvm import relay
+ from tvm.relay.backend import graph_runtime_codegen
if hasattr(target, 'device_name') and target.device_name == "vta":
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
import vta
with vta.build_config():
- return relay.build(func, target, target_host, params)
+ mod, _ = relay.optimize(func, target, params)
+ grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+ return grc.codegen(mod["main"])
# default case
- return relay.build(func, target, target_host, params)
+ mod, _ = relay.optimize(func, target, params)
+ grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+ return grc.codegen(mod["main"])
+
def extract_from_program(func, params, ops, target, target_host=None):
""" Extract tuning tasks from a relay program.
relay.backend.compile_engine.get().clear()
# wrap build call in thread to avoid multiprocessing problems
mod = relay.Module.from_expr(func)
- build_thread = threading.Thread(target=_build,
- args=(mod, target, target_host, param))
+ build_thread = threading.Thread(target=_lower,
+ args=(mod, target, param))
build_thread.start()
build_thread.join()
from . import adt
from . import analysis
from . import transform
-from .build_module import build, create_executor
+from .build_module import build, create_executor, optimize
from .transform import build_config
from . import prelude
from . import parser
from __future__ import absolute_import
from tvm.ndarray import empty
-from tvm.relay import build_module
+from tvm.relay import _build_module
from tvm import target as _target
from tvm import expr as _expr
"""The compiler from Relay to the TVM runtime system."""
def __init__(self, mod, target):
- self._mod = build_module._GraphRuntimeCodegen()
+ self._mod = _build_module._GraphRuntimeCodegen()
self._init = self._mod["init"]
self._codegen = self._mod["codegen"]
self._get_graph_json = self._mod["get_graph_json"]
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
+ self._optimize = self.mod["optimize"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
return graph_json, mod, params
+ def optimize(self, func, target=None, params=None):
+ """
+ Parameters
+ ----------
+ func: relay.Function
+ The function to build.
+
+ target : str, :any:`tvm.target.Target`, or dict of str(i.e.
+ device/context name) to str/tvm.target.Target, optional
+ For heterogeneous compilation, it is a dictionary indicating context
+ to target mapping. For homogeneous compilation, it is a build target.
+
+ params : dict of str to NDArray
+ Input parameters to the graph that do not change
+ during inference time. Used for constant folding.
+
+ Returns
+ -------
+ mod : relay.Module
+ The optimized relay module.
+
+ params : dict
+ The parameters of the final graph.
+ """
+ target = _update_target(target)
+
+ # Setup the params.
+ if params:
+ self._set_params(params)
+ mod = self._optimize(func, target)
+ # Get artifacts
+ params = self.get_params()
+
+ return mod, params
+
+
def _set_params(self, params):
inputs = {}
for name, param in params.items():
return graph_json, mod, params
+def optimize(mod, target=None, params=None):
+ """Helper function that optimizes a Relay module.
+
+ Parameters
+ ----------
+ mod : relay.Module
+ The module to build. Using relay.Function is deprecated.
+
+ target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
+ name) to str/tvm.target.Target, optional
+ For heterogeneous compilation, it is a dictionary indicating context to
+ target mapping. For homogeneous compilation, it is a build target.
+
+ params : dict of str to NDArray
+ Input parameters to the graph that do not change
+ during inference time. Used for constant folding.
+
+ Returns
+ -------
+ mod : relay.Module
+ The optimized relay module.
+
+ params : dict
+ The parameters of the final graph.
+ """
+ if isinstance(mod, _Module):
+ func = mod["main"]
+ elif isinstance(mod, _expr.Function):
+ func = mod
+ warnings.warn(
+ "Please use input parameter mod (tvm.relay.module.Module) "
+ "instead of deprecated parameter func (tvm.relay.expr.Function)",
+ DeprecationWarning)
+ else:
+ raise ValueError("Type of input parameter mod must be tvm.relay.module.Module")
+
+ target = _update_target(target)
+
+ # If current dispatch context is fallback context (the default root context),
+ # then load pre-tuned parameters from TopHub
+ if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
+ tophub_context = autotvm.tophub.context(list(target.values()))
+ else:
+ tophub_context = autotvm.util.EmptyContext()
+
+ with tophub_context:
+ bld_mod = BuildModule()
+ mod, params = bld_mod.optimize(func, target, params)
+ return mod, params
+
+
class GraphExecutor(_interpreter.Executor):
"""Wrapper around Executor interface.
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
});
+ } else if (name == "optimize") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ CHECK_EQ(args.num_args, 2);
+ *rv = this->Optimize(args[0], args[1], this->params_);
+ });
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
}
/*!
- * \brief Optimize a Relay module.
+ * \brief Optimize a Relay Function.
*
- * \param relay_module The input Relay module where optmization will be
- * applied on.
+ * \param func The input Function where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
*
* \return relay::Module The updated Relay module after optimization.
*/
relay::Module Optimize(
- relay::Module relay_module,
+ Function func,
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
+ if (params.size()) {
+ func = BindParamsByName(func, params);
+ }
+
+ // Perform Module->Module optimizations.
+ relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+
Array<Pass> pass_seqs;
// Run all dialect legalization passes.
// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module);
+ CHECK(relay_module.defined());
return relay_module;
}
void BuildRelay(
Function func,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
- if (params.size()) {
- func = BindParamsByName(func, params);
- }
-
- // Perform Module->Module optimizations.
- relay::Module relay_module = relay::ModuleNode::FromExpr(func);
- relay_module = Optimize(relay_module, targets_, params);
- CHECK(relay_module.defined());
+ // Optimize input Relay Function and returns Relay Module
+ relay::Module relay_module = Optimize(func, targets_, params);
// Get the updated function.
func = relay_module->Lookup("main");