/*! \brief The default optimization level. */
int opt_level{2};
- /*! \brief CPU is the default fallback device for heterogeneous execution. */
- int fallback_device{static_cast<int>(kDLCPU)};
-
/*! \brief The list of required passes. */
Array<String> required_pass;
/*! \brief The list of disabled passes. */
void VisitAttrs(AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
- v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
v->Visit("config", &config);
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
- * ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
import functools
import tvm._ffi
-
import tvm.runtime
-from tvm.runtime import ndarray as _nd
from . import _ffi_transform_api
opt_level : Optional[int]
The optimization level of this pass.
- fallback_device : Optional[Union[int, str, TVMContext]]
- The fallback device type. It is also used as the default device for
- operators that are not annotated during heterogeneous execution.
-
required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
The list of passes that are required by a certain pass.
"""
def __init__(self,
opt_level=2,
- fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None,
config=None):
- if isinstance(fallback_device, str):
- fallback_device = _nd.context(fallback_device).device_type
- elif isinstance(fallback_device, tvm.runtime.TVMContext):
- fallback_device = fallback_device.device_type
- if not isinstance(fallback_device, int):
- raise TypeError("fallback_device is expected to be the type of " +
- "int/str/TVMContext.")
-
required = list(required_pass) if required_pass else []
if not isinstance(required, (list, tuple)):
raise TypeError("required_pass is expected to be the type of " +
config = config if config else None
self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
- fallback_device, required,
- disabled, trace, config)
+ required, disabled, trace, config)
def __enter__(self):
_ffi_transform_api.EnterPassContext(self)
def build_config(opt_level=2,
- fallback_device=_nd.cpu(),
required_pass=None,
disabled_pass=None,
trace=None):
"FastMath": 4
}
- fallback_device : int, str, or tvmContext, optional
- The fallback device. It is also used as the default device for
- operators without specified device during heterogeneous execution.
-
required_pass: set of str, optional
Optimization passes that are required regardless of optimization level.
pass_context: PassContext
The pass context for optimizations.
"""
- return tvm.ir.transform.PassContext(
- opt_level, fallback_device, required_pass,
- disabled_pass, trace)
+ return tvm.ir.transform.PassContext(opt_level, required_pass,
+ disabled_pass, trace)
@tvm._ffi.register_object("relay.FunctionPass")
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_GLOBAL("transform.PassContext")
- .set_body_typed([](int opt_level, int fallback_device, Array<String> required,
- Array<String> disabled, TraceFunc trace_func,
- Optional<Map<std::string, ObjectRef>> config) {
+ .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
+ TraceFunc trace_func, Optional<Map<std::string, ObjectRef>> config) {
auto pctx = PassContext::Create();
pctx->opt_level = opt_level;
- pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
p->stream << "Pass context information: "
<< "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
- p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";
p->stream << "\trequired passes: [";
for (const auto& it : node->required_pass) {
// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
if (targets_.size() > 1) {
- relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
+ Optional<IntImm> opt_fallback_dev =
+ pass_ctx->GetConfig("relay.fallback_device_type",
+ IntImm(runtime::DataType::Int(32), static_cast<int>(kDLCPU)));
+ auto fallback_dev = opt_fallback_dev.value();
+ CHECK_GT(fallback_dev->value, 0U);
+ relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
}
// Fuse the operations if it is needed.
namespace relay {
namespace transform {
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm);
+
class FunctionPass;
/*!
auto mod = IRModule::FromExpr(func);
auto pass_ctx = relay::transform::PassContext::Create();
pass_ctx->opt_level = 3;
- pass_ctx->fallback_device = 1;
+ pass_ctx->config.Set("relay.fallback_device_type", IntImm(DataType::Int(32), 1));
{
tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
def test_runtime(target, device, func, fallback_device=None,
expected_index=None):
params = {"x": x_data, "y": y_data}
- config = {"opt_level": 1}
+ config = {}
if fallback_device:
- config["fallback_device"] = fallback_device
- with relay.build_config(**config):
+ config["relay.fallback_device_type"] = fallback_device.device_type
+ with tvm.transform.PassContext(opt_level=1, config=config):
graph, lib, params = relay.build(
func,
target,
expected_index = [2, 2, 2, 1, 1, 1, 2, 2]
check_annotated_graph(annotated_func, expected_func)
params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
- config = {"opt_level": 0}
- config["fallback_device"] = fallback_device
- with relay.build_config(**config):
+ with tvm.transform.PassContext(opt_level=0,
+ config={"relay.fallback_device_type":
+ fallback_device.device_type}):
graph, lib, params = relay.build(annotated_func, target, params=params)
contexts = [tvm.cpu(0), tvm.context(dev)]
graph_json = json.loads(graph)