[AutoTVM] Suppress the warning messages when compile engine selects impls (#5821)
author     Haichen Shen <shenhaichen@gmail.com>    Thu, 18 Jun 2020 23:29:21 +0000 (16:29 -0700)
committer  GitHub <noreply@github.com>             Thu, 18 Jun 2020 23:29:21 +0000 (16:29 -0700)
python/tvm/autotvm/env.py
python/tvm/autotvm/task/dispatcher.py
python/tvm/relay/backend/compile_engine.py
tests/python/integration/test_winograd_nnpack.py
topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py

diff --git a/python/tvm/autotvm/env.py b/python/tvm/autotvm/env.py
index b1a5e31..18674d4 100644
@@ -25,5 +25,6 @@ class AutotvmGlobalScope(object):
 
         self.cuda_target_arch = None
         self.in_tuning = False
+        self.silent = False
 
 GLOBAL_SCOPE = AutotvmGlobalScope()
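
The flag added above lives on the process-wide AutotvmGlobalScope singleton, so any code path can mute the AutoTVM fallback warnings without reaching into a dispatch context. A minimal sketch of how a caller might toggle it (the try/finally restore is illustrative only, not part of this patch):

    from tvm import autotvm

    autotvm.GLOBAL_SCOPE.silent = True       # mute the "Cannot find config ..." fallback warnings
    try:
        ...  # compilation or benchmarking work that may query untuned workloads
    finally:
        autotvm.GLOBAL_SCOPE.silent = False  # restore the default so later warnings still show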
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index 97ee538..736b5f3 100644
@@ -35,6 +35,7 @@ import logging
 import numpy as np
 
 from .space import FallbackConfigEntity
+from .. import env as _env
 
 logger = logging.getLogger('autotvm')
 
@@ -47,6 +48,8 @@ class DispatchContext(object):
     specific dispatch mechanism for templates.
     """
     current = None
+    # a set to prevent printing duplicated messages
+    warning_messages = set()
 
     def __init__(self):
         self._old_ctx = DispatchContext.current
@@ -295,21 +298,17 @@ class FallbackContext(DispatchContext):
     def __init__(self):
         super(FallbackContext, self).__init__()
         self.memory = {}
-        self.silent = False
-
-        # a set to prevent print duplicated message
-        self.messages = set()
 
     def _query_inside(self, target, workload):
         key = (str(target), workload)
         if key in self.memory:
             return self.memory[key]
 
-        if not self.silent:
+        if not _env.GLOBAL_SCOPE.silent:
             msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
                   "is used, which may bring great performance regression." % (target, workload)
-            if msg not in self.messages:
-                self.messages.add(msg)
+            if msg not in DispatchContext.warning_messages:
+                DispatchContext.warning_messages.add(msg)
                 logger.warning(msg)
         cfg = FallbackConfigEntity()
 
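The deduplication set moves from a per-instance attribute on FallbackContext to a class-level attribute on DispatchContext, so the compile engine (next hunk) can record its one-time warnings in the same place. A small illustrative check of that sharing, assuming the classes are imported from tvm.autotvm.task.dispatcher as in this patch:

    from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext

    ctx_a = FallbackContext()
    ctx_b = FallbackContext()
    # Both instances resolve the attribute to the same class-level set, so a
    # message recorded under one context is never logged again under another.
    assert ctx_a.warning_messages is ctx_b.warning_messages
    assert ctx_a.warning_messages is DispatchContext.warning_messages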
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index eb5c2b3..8e6698e 100644
@@ -30,7 +30,7 @@ from .. import ty as _ty
 from . import _backend
 
 logger = logging.getLogger('compile_engine')
-
+autotvm_logger = logging.getLogger('autotvm')
 
 @tvm._ffi.register_object("relay.LoweredOutput")
 class LoweredOutput(Object):
@@ -190,24 +190,38 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
         return best_plevel_impl, outs
 
     outputs = {}
+    workloads = {}
     best_autotvm_impl = None
     best_cfg = None
     dispatch_ctx = autotvm.task.DispatchContext.current
+    autotvm.GLOBAL_SCOPE.silent = True
     for impl in all_impls:
         outs = impl.compute(attrs, inputs, out_type)
         outputs[impl] = outs
         workload = autotvm.task.get_workload(outs)
+        workloads[impl] = workload
         if workload is None:
+            # Not an AutoTVM tunable implementation
             continue
         cfg = dispatch_ctx.query(target, workload)
         if cfg.is_fallback:
-            # It's a fallback config
+            # Skip fallback config
             continue
         if best_cfg is None or best_cfg.cost > cfg.cost:
             best_autotvm_impl = impl
             best_cfg = cfg
+    autotvm.GLOBAL_SCOPE.silent = False
     if best_autotvm_impl:
+        # The best autotvm implementation definitely doesn't use a fallback config
         return best_autotvm_impl, outputs[best_autotvm_impl]
+    # Use the implementation with the highest plevel
+    if workloads[best_plevel_impl] is not None:
+        msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
+              "is used, which may bring great performance regression." \
+              % (target, workloads[best_plevel_impl])
+        if msg not in autotvm.task.DispatchContext.warning_messages:
+            autotvm.task.DispatchContext.warning_messages.add(msg)
+            autotvm_logger.warning(msg)
     return best_plevel_impl, outputs[best_plevel_impl]
 
 
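Putting the hunk above together: the compile engine now probes every candidate implementation with warnings muted, and only warns, once per unique message via the shared DispatchContext.warning_messages set, if the implementation it ultimately selects is tunable but untuned. An illustrative restatement of that flow, relying on the local names from select_implementation above rather than being standalone code:

    autotvm.GLOBAL_SCOPE.silent = True              # probe every candidate quietly
    best_autotvm_impl, best_cfg = None, None
    for impl in all_impls:
        workload = workloads[impl]
        if workload is None:
            continue                                # not tunable by AutoTVM
        cfg = dispatch_ctx.query(target, workload)  # no warning is logged while silent
        if cfg.is_fallback:
            continue                                # untuned candidate; skip for now
        if best_cfg is None or cfg.cost < best_cfg.cost:
            best_autotvm_impl, best_cfg = impl, cfg
    autotvm.GLOBAL_SCOPE.silent = False             # restore warnings
    # If no tuned candidate was found and the plevel winner is tunable but untuned,
    # the warning is emitted exactly once through DispatchContext.warning_messages.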
diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py
index 7dad2ca..c974496 100644
@@ -106,7 +106,7 @@ def test_conv2d_nchw():
         skip("nnpack is not available")
 
     devices = ['llvm -device=arm_cpu']
-    autotvm.DispatchContext.current.silent = True
+    autotvm.GLOBAL_SCOPE.silent = True
     with WinogradFallback():
         # resnet 18 workloads
         verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, devices=devices)
@@ -137,8 +137,9 @@ def test_conv2d_nchw():
         # weird workloads
         verify_conv2d_nchw(1, 3, 3, 3, 3, 1, 1, devices=devices)
         verify_conv2d_nchw(1, 13, 71, 59, 3, 1, 1, devices=devices)
+    autotvm.GLOBAL_SCOPE.silent = False
 
 
 if __name__ == "__main__":
     import pytest
-    pytest.main()
+    pytest.main([__file__])
diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py
index 0fd4205..ef3faae 100644
@@ -61,7 +61,7 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
             break
 
     ic_block = 8
-    autotvm.DispatchContext.current.silent = True
+    autotvm.GLOBAL_SCOPE.silent = True
     A = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8')
     W = te.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8')
 
@@ -103,6 +103,7 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte
     for device in ["llvm -mcpu=skylake-avx512"]:
         with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
             check_device(device)
+    autotvm.GLOBAL_SCOPE.silent = False
 
 @pytest.mark.skip
 def test_conv2d_NCHWc():