From 64766c2cb1cd3ce41ee91a54909882036fffd412 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 18 Jun 2020 16:29:21 -0700 Subject: [PATCH] [AutoTVM] Suppress the warning messages when compile engine selects impls (#5821) --- python/tvm/autotvm/env.py | 1 + python/tvm/autotvm/task/dispatcher.py | 13 ++++++------- python/tvm/relay/backend/compile_engine.py | 18 ++++++++++++++++-- tests/python/integration/test_winograd_nnpack.py | 5 +++-- topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py | 3 ++- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/python/tvm/autotvm/env.py b/python/tvm/autotvm/env.py index b1a5e31..18674d4 100644 --- a/python/tvm/autotvm/env.py +++ b/python/tvm/autotvm/env.py @@ -25,5 +25,6 @@ class AutotvmGlobalScope(object): self.cuda_target_arch = None self.in_tuning = False + self.silent = False GLOBAL_SCOPE = AutotvmGlobalScope() diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 97ee538..736b5f3 100644 --- a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -35,6 +35,7 @@ import logging import numpy as np from .space import FallbackConfigEntity +from .. import env as _env logger = logging.getLogger('autotvm') @@ -47,6 +48,8 @@ class DispatchContext(object): specific dispatch mechanism for templates. """ current = None + # a set to prevent print duplicated message + warning_messages = set() def __init__(self): self._old_ctx = DispatchContext.current @@ -295,21 +298,17 @@ class FallbackContext(DispatchContext): def __init__(self): super(FallbackContext, self).__init__() self.memory = {} - self.silent = False - - # a set to prevent print duplicated message - self.messages = set() def _query_inside(self, target, workload): key = (str(target), workload) if key in self.memory: return self.memory[key] - if not self.silent: + if not _env.GLOBAL_SCOPE.silent: msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\ "is used, which may bring great performance regression." % (target, workload) - if msg not in self.messages: - self.messages.add(msg) + if msg not in DispatchContext.warning_messages: + DispatchContext.warning_messages.add(msg) logger.warning(msg) cfg = FallbackConfigEntity() diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index eb5c2b3..8e6698e 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -30,7 +30,7 @@ from .. import ty as _ty from . import _backend logger = logging.getLogger('compile_engine') - +autotvm_logger = logging.getLogger('autotvm') @tvm._ffi.register_object("relay.LoweredOutput") class LoweredOutput(Object): @@ -190,24 +190,38 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outs outputs = {} + workloads = {} best_autotvm_impl = None best_cfg = None dispatch_ctx = autotvm.task.DispatchContext.current + autotvm.GLOBAL_SCOPE.silent = True for impl in all_impls: outs = impl.compute(attrs, inputs, out_type) outputs[impl] = outs workload = autotvm.task.get_workload(outs) + workloads[impl] = workload if workload is None: + # Not an AutoTVM tunable implementation continue cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: - # It's a fallback config + # Skip fallback config continue if best_cfg is None or best_cfg.cost > cfg.cost: best_autotvm_impl = impl best_cfg = cfg + autotvm.GLOBAL_SCOPE.silent = False if best_autotvm_impl: + # The best autotvm implementation definitely doesn't use fallback config return best_autotvm_impl, outputs[best_autotvm_impl] + # Use the implementation with highest plevel + if workloads[best_plevel_impl] is not None: + msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\ + "is used, which may bring great performance regression." \ + % (target, workloads[best_plevel_impl]) + if msg not in autotvm.task.DispatchContext.warning_messages: + autotvm.task.DispatchContext.warning_messages.add(msg) + autotvm_logger.warning(msg) return best_plevel_impl, outputs[best_plevel_impl] diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py index 7dad2ca..c974496 100644 --- a/tests/python/integration/test_winograd_nnpack.py +++ b/tests/python/integration/test_winograd_nnpack.py @@ -106,7 +106,7 @@ def test_conv2d_nchw(): skip("nnpack is not available") devices = ['llvm -device=arm_cpu'] - autotvm.DispatchContext.current.silent = True + autotvm.GLOBAL_SCOPE.silent = True with WinogradFallback(): # resnet 18 workloads verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, devices=devices) @@ -137,8 +137,9 @@ def test_conv2d_nchw(): # werid workloads verify_conv2d_nchw(1, 3, 3, 3, 3, 1, 1, devices=devices) verify_conv2d_nchw(1, 13, 71, 59, 3, 1, 1, devices=devices) + autotvm.GLOBAL_SCOPE.silent = False if __name__ == "__main__": import pytest - pytest.main() + pytest.main([__file__]) diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py index 0fd4205..ef3faae 100644 --- a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py +++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py @@ -61,7 +61,7 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte break ic_block = 8 - autotvm.DispatchContext.current.silent = True + autotvm.GLOBAL_SCOPE.silent = True A = te.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8') W = te.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8') @@ -103,6 +103,7 @@ def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filte for device in ["llvm -mcpu=skylake-avx512"]: with autotvm.tophub.context(device): # load tophub pre-tuned parameters check_device(device) + autotvm.GLOBAL_SCOPE.silent = False @pytest.mark.skip def test_conv2d_NCHWc(): -- 2.7.4