[REFACTOR][RELAY] move fallback_device to config (#5690)
authorZhi <5145158+zhiics@users.noreply.github.com>
Fri, 29 May 2020 14:52:03 +0000 (07:52 -0700)
committerGitHub <noreply@github.com>
Fri, 29 May 2020 14:52:03 +0000 (07:52 -0700)
include/tvm/ir/transform.h
python/tvm/ir/transform.py
python/tvm/relay/transform/transform.py
src/ir/transform.cc
src/relay/backend/build_module.cc
src/relay/ir/transform.cc
tests/cpp/relay_transform_sequential.cc
tests/python/relay/test_pass_annotation.py

index 6c8bad2..4c36c7c 100644 (file)
@@ -92,9 +92,6 @@ class PassContextNode : public Object {
   /*! \brief The default optimization level. */
   int opt_level{2};
 
-  /*! \brief CPU is the default fallback device for heterogeneous execution. */
-  int fallback_device{static_cast<int>(kDLCPU)};
-
   /*! \brief The list of required passes. */
   Array<String> required_pass;
   /*! \brief The list of disabled passes. */
@@ -139,7 +136,6 @@ class PassContextNode : public Object {
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("opt_level", &opt_level);
-    v->Visit("fallback_device", &fallback_device);
     v->Visit("required_pass", &required_pass);
     v->Visit("disabled_pass", &disabled_pass);
     v->Visit("config", &config);
@@ -157,7 +153,6 @@ class PassContextNode : public Object {
  *
  *  auto new_ctx = PassContext::Create();
  *  ctx->opt_level = 2;
- *  ctx->fallback_device = kDLCPU;
  *  With<PassContext> scope(ctx);
  *  // pass context in effect.
  *
index 5f49092..358ad19 100644 (file)
@@ -21,9 +21,7 @@ import inspect
 import functools
 
 import tvm._ffi
-
 import tvm.runtime
-from tvm.runtime import ndarray as _nd
 
 from . import _ffi_transform_api
 
@@ -61,10 +59,6 @@ class PassContext(tvm.runtime.Object):
     opt_level : Optional[int]
         The optimization level of this pass.
 
-    fallback_device : Optional[Union[int, str, TVMContext]]
-        The fallback device type. It is also used as the default device for
-        operators that are not annotated during heterogeneous execution.
-
     required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
         The list of passes that are required by a certain pass.
 
@@ -76,19 +70,10 @@ class PassContext(tvm.runtime.Object):
     """
     def __init__(self,
                  opt_level=2,
-                 fallback_device=_nd.cpu(),
                  required_pass=None,
                  disabled_pass=None,
                  trace=None,
                  config=None):
-        if isinstance(fallback_device, str):
-            fallback_device = _nd.context(fallback_device).device_type
-        elif isinstance(fallback_device, tvm.runtime.TVMContext):
-            fallback_device = fallback_device.device_type
-        if not isinstance(fallback_device, int):
-            raise TypeError("fallback_device is expected to be the type of " +
-                            "int/str/TVMContext.")
-
         required = list(required_pass) if required_pass else []
         if not isinstance(required, (list, tuple)):
             raise TypeError("required_pass is expected to be the type of " +
@@ -101,8 +86,7 @@ class PassContext(tvm.runtime.Object):
 
         config = config if config else None
         self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level,
-                                            fallback_device, required,
-                                            disabled, trace, config)
+                                            required, disabled, trace, config)
 
     def __enter__(self):
         _ffi_transform_api.EnterPassContext(self)
index 7222ff2..19ddb32 100644 (file)
@@ -31,7 +31,6 @@ from . import _ffi_api
 
 
 def build_config(opt_level=2,
-                 fallback_device=_nd.cpu(),
                  required_pass=None,
                  disabled_pass=None,
                  trace=None):
@@ -59,10 +58,6 @@ def build_config(opt_level=2,
                 "FastMath": 4
             }
 
-    fallback_device : int, str, or tvmContext, optional
-        The fallback device. It is also used as the default device for
-        operators without specified device during heterogeneous execution.
-
     required_pass: set of str, optional
         Optimization passes that are required regardless of optimization level.
 
@@ -77,9 +72,8 @@ def build_config(opt_level=2,
     pass_context: PassContext
         The pass context for optimizations.
     """
-    return tvm.ir.transform.PassContext(
-        opt_level, fallback_device, required_pass,
-        disabled_pass, trace)
+    return tvm.ir.transform.PassContext(opt_level, required_pass,
+                                        disabled_pass, trace)
 
 
 @tvm._ffi.register_object("relay.FunctionPass")
index 1dbad1a..9eb327d 100644 (file)
@@ -454,12 +454,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(PassContextNode);
 
 TVM_REGISTER_GLOBAL("transform.PassContext")
-    .set_body_typed([](int opt_level, int fallback_device, Array<String> required,
-                       Array<String> disabled, TraceFunc trace_func,
-                       Optional<Map<std::string, ObjectRef>> config) {
+    .set_body_typed([](int opt_level, Array<String> required, Array<String> disabled,
+                       TraceFunc trace_func, Optional<Map<std::string, ObjectRef>> config) {
       auto pctx = PassContext::Create();
       pctx->opt_level = opt_level;
-      pctx->fallback_device = fallback_device;
 
       pctx->required_pass = std::move(required);
       pctx->disabled_pass = std::move(disabled);
@@ -477,7 +475,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "Pass context information: "
                 << "\n";
       p->stream << "\topt_level: " << node->opt_level << "\n";
-      p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n";
 
       p->stream << "\trequired passes: [";
       for (const auto& it : node->required_pass) {
index fe53336..abce068 100644 (file)
@@ -304,7 +304,12 @@ class RelayBuildModule : public runtime::ModuleNode {
     // Handle heterogeneous compilation.
     transform::PassContext pass_ctx = PassContext::Current();
     if (targets_.size() > 1) {
-      relay_module = RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device);
+      Optional<IntImm> opt_fallback_dev =
+          pass_ctx->GetConfig("relay.fallback_device_type",
+                              IntImm(runtime::DataType::Int(32), static_cast<int>(kDLCPU)));
+      auto fallback_dev = opt_fallback_dev.value();
+      CHECK_GT(fallback_dev->value, 0U);
+      relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value);
     }
 
     // Fuse the operations if it is needed.
index 6942df2..184ee58 100644 (file)
@@ -30,6 +30,8 @@ namespace tvm {
 namespace relay {
 namespace transform {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm);
+
 class FunctionPass;
 
 /*!
index e01a6ea..60d3a5e 100644 (file)
@@ -70,7 +70,7 @@ TEST(Relay, Sequential) {
   auto mod = IRModule::FromExpr(func);
   auto pass_ctx = relay::transform::PassContext::Create();
   pass_ctx->opt_level = 3;
-  pass_ctx->fallback_device = 1;
+  pass_ctx->config.Set("relay.fallback_device_type", IntImm(DataType::Int(32), 1));
   {
     tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
     tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
index 582d46a..0ecb2b5 100644 (file)
@@ -344,10 +344,10 @@ def run_fusible_network(dev, tgt):
     def test_runtime(target, device, func, fallback_device=None,
                      expected_index=None):
         params = {"x": x_data, "y": y_data}
-        config = {"opt_level": 1}
+        config = {}
         if fallback_device:
-            config["fallback_device"] = fallback_device
-        with relay.build_config(**config):
+            config["relay.fallback_device_type"] = fallback_device.device_type
+        with tvm.transform.PassContext(opt_level=1, config=config):
             graph, lib, params = relay.build(
                 func,
                 target,
@@ -538,9 +538,9 @@ def run_unpropagatable_graph(dev, tgt):
     expected_index = [2, 2, 2, 1, 1, 1, 2, 2]
     check_annotated_graph(annotated_func, expected_func)
     params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
-    config = {"opt_level": 0}
-    config["fallback_device"] = fallback_device
-    with relay.build_config(**config):
+    with tvm.transform.PassContext(opt_level=0,
+                                   config={"relay.fallback_device_type":
+                                           fallback_device.device_type}):
         graph, lib, params = relay.build(annotated_func, target, params=params)
         contexts = [tvm.cpu(0), tvm.context(dev)]
         graph_json = json.loads(graph)