From ddf7190846b23d6f55b2bb295e470e338a292e7c Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Fri, 29 May 2020 07:52:03 -0700 Subject: [PATCH] [REFACTOR][RELAY] move fallback_device to config (#5690) --- include/tvm/ir/transform.h | 5 ----- python/tvm/ir/transform.py | 18 +----------------- python/tvm/relay/transform/transform.py | 10 ++-------- src/ir/transform.cc | 7 ++----- src/relay/backend/build_module.cc | 7 ++++++- src/relay/ir/transform.cc | 2 ++ tests/cpp/relay_transform_sequential.cc | 2 +- tests/python/relay/test_pass_annotation.py | 12 ++++++------ 8 files changed, 20 insertions(+), 43 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 6c8bad2..4c36c7c 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -92,9 +92,6 @@ class PassContextNode : public Object { /*! \brief The default optimization level. */ int opt_level{2}; - /*! \brief CPU is the default fallback device for heterogeneous execution. */ - int fallback_device{static_cast(kDLCPU)}; - /*! \brief The list of required passes. */ Array required_pass; /*! \brief The list of disabled passes. */ @@ -139,7 +136,6 @@ class PassContextNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); - v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); v->Visit("config", &config); @@ -157,7 +153,6 @@ class PassContextNode : public Object { * * auto new_ctx = PassContext::Create(); * ctx->opt_level = 2; - * ctx->fallback_device = kDLCPU; * With scope(ctx); * // pass context in effect. * diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 5f49092..358ad19 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -21,9 +21,7 @@ import inspect import functools import tvm._ffi - import tvm.runtime -from tvm.runtime import ndarray as _nd from . import _ffi_transform_api @@ -61,10 +59,6 @@ class PassContext(tvm.runtime.Object): opt_level : Optional[int] The optimization level of this pass. - fallback_device : Optional[Union[int, str, TVMContext]] - The fallback device type. It is also used as the default device for - operators that are not annotated during heterogeneous execution. - required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are required by a certain pass. @@ -76,19 +70,10 @@ class PassContext(tvm.runtime.Object): """ def __init__(self, opt_level=2, - fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, trace=None, config=None): - if isinstance(fallback_device, str): - fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, tvm.runtime.TVMContext): - fallback_device = fallback_device.device_type - if not isinstance(fallback_device, int): - raise TypeError("fallback_device is expected to be the type of " + - "int/str/TVMContext.") - required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): raise TypeError("required_pass is expected to be the type of " + @@ -101,8 +86,7 @@ class PassContext(tvm.runtime.Object): config = config if config else None self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level, - fallback_device, required, - disabled, trace, config) + required, disabled, trace, config) def __enter__(self): _ffi_transform_api.EnterPassContext(self) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7222ff2..19ddb32 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -31,7 +31,6 @@ from . import _ffi_api def build_config(opt_level=2, - fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, trace=None): @@ -59,10 +58,6 @@ def build_config(opt_level=2, "FastMath": 4 } - fallback_device : int, str, or tvmContext, optional - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - required_pass: set of str, optional Optimization passes that are required regardless of optimization level. @@ -77,9 +72,8 @@ def build_config(opt_level=2, pass_context: PassContext The pass context for optimizations. """ - return tvm.ir.transform.PassContext( - opt_level, fallback_device, required_pass, - disabled_pass, trace) + return tvm.ir.transform.PassContext(opt_level, required_pass, + disabled_pass, trace) @tvm._ffi.register_object("relay.FunctionPass") diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 1dbad1a..9eb327d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -454,12 +454,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") - .set_body_typed([](int opt_level, int fallback_device, Array required, - Array disabled, TraceFunc trace_func, - Optional> config) { + .set_body_typed([](int opt_level, Array required, Array disabled, + TraceFunc trace_func, Optional> config) { auto pctx = PassContext::Create(); pctx->opt_level = opt_level; - pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); @@ -477,7 +475,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; p->stream << "\trequired passes: ["; for (const auto& it : node->required_pass) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index fe53336..abce068 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -304,7 +304,12 @@ class RelayBuildModule : public runtime::ModuleNode { // Handle heterogeneous compilation. transform::PassContext pass_ctx = PassContext::Current(); if (targets_.size() > 1) { - relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", + IntImm(runtime::DataType::Int(32), static_cast(kDLCPU))); + auto fallback_dev = opt_fallback_dev.value(); + CHECK_GT(fallback_dev->value, 0U); + relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value); } // Fuse the operations if it is needed. diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 6942df2..184ee58 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -30,6 +30,8 @@ namespace tvm { namespace relay { namespace transform { +TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm); + class FunctionPass; /*! diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index e01a6ea..60d3a5e 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -70,7 +70,7 @@ TEST(Relay, Sequential) { auto mod = IRModule::FromExpr(func); auto pass_ctx = relay::transform::PassContext::Create(); pass_ctx->opt_level = 3; - pass_ctx->fallback_device = 1; + pass_ctx->config.Set("relay.fallback_device_type", IntImm(DataType::Int(32), 1)); { tvm::With ctx_scope(pass_ctx); tvm::With tctx(tvm::Target::Create("llvm")); diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 582d46a..0ecb2b5 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -344,10 +344,10 @@ def run_fusible_network(dev, tgt): def test_runtime(target, device, func, fallback_device=None, expected_index=None): params = {"x": x_data, "y": y_data} - config = {"opt_level": 1} + config = {} if fallback_device: - config["fallback_device"] = fallback_device - with relay.build_config(**config): + config["relay.fallback_device_type"] = fallback_device.device_type + with tvm.transform.PassContext(opt_level=1, config=config): graph, lib, params = relay.build( func, target, @@ -538,9 +538,9 @@ def run_unpropagatable_graph(dev, tgt): expected_index = [2, 2, 2, 1, 1, 1, 2, 2] check_annotated_graph(annotated_func, expected_func) params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - config = {"opt_level": 0} - config["fallback_device"] = fallback_device - with relay.build_config(**config): + with tvm.transform.PassContext(opt_level=0, + config={"relay.fallback_device_type": + fallback_device.device_type}): graph, lib, params = relay.build(annotated_func, target, params=params) contexts = [tvm.cpu(0), tvm.context(dev)] graph_json = json.loads(graph) -- 2.7.4