From 81118023e193f6ec81db8f9d56494e118b5f31a1 Mon Sep 17 00:00:00 2001 From: Jon Soifer Date: Thu, 3 Oct 2019 11:55:17 -0700 Subject: [PATCH] [Relay][TopHub] Add switch to disable TopHub download (#4015) --- python/tvm/autotvm/tophub.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index aa49bea..c290063 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -18,7 +18,8 @@ TopHub: Tensor Operator Hub To get the best performance, we typically need auto-tuning for the specific devices. TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets. -TVM will download these parameters for you when you call nnvm.compiler.build_module . +TVM will download these parameters for you when you call +nnvm.compiler.build_module or relay.build. """ # pylint: disable=invalid-name @@ -30,6 +31,16 @@ from .task import ApplyHistoryBest from .. import target as _target from ..contrib.download import download from .record import load_from_file +from .util import EmptyContext + +# environment variable to read TopHub location +AUTOTVM_TOPHUB_LOC_VAR = "TOPHUB_LOCATION" + +# default location of TopHub +AUTOTVM_TOPHUB_DEFAULT_LOC = "https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub" + +# value of AUTOTVM_TOPHUB_LOC_VAR to specify to not read from TopHub +AUTOTVM_TOPHUB_NONE_LOC = "NONE" # root path to store TopHub files AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub") @@ -61,6 +72,9 @@ def _alias(name): } return table.get(name, name) +def _get_tophub_location(): + location = os.getenv(AUTOTVM_TOPHUB_LOC_VAR, None) + return AUTOTVM_TOPHUB_DEFAULT_LOC if location is None else location def context(target, extra_files=None): """Return the dispatch context with pre-tuned parameters. @@ -75,6 +89,10 @@ def context(target, extra_files=None): extra_files: list of str, optional Extra log files to load """ + tophub_location = _get_tophub_location() + if tophub_location == AUTOTVM_TOPHUB_NONE_LOC: + return EmptyContext() + best_context = ApplyHistoryBest([]) targets = target if isinstance(target, (list, tuple)) else [target] @@ -94,7 +112,7 @@ def context(target, extra_files=None): for name in possible_names: name = _alias(name) if name in all_packages: - if not check_backend(name): + if not check_backend(tophub_location, name): continue filename = "%s_%s.log" % (name, PACKAGE_VERSION[name]) @@ -108,7 +126,7 @@ def context(target, extra_files=None): return best_context -def check_backend(backend): +def check_backend(tophub_location, backend): """Check whether have pre-tuned parameters of the certain target. If not, will download it. @@ -135,18 +153,21 @@ def check_backend(backend): else: import urllib2 try: - download_package(package_name) + download_package(tophub_location, package_name) return True except urllib2.URLError as e: logging.warning("Failed to download tophub package for %s: %s", backend, e) return False -def download_package(package_name): +def download_package(tophub_location, package_name): """Download pre-tuned parameters of operators for a backend Parameters ---------- + tophub_location: str + The location to download TopHub parameters from + package_name: str The name of package """ @@ -160,9 +181,9 @@ def download_package(package_name): if not os.path.isdir(path): os.mkdir(path) - logger.info("Download pre-tuned parameters package %s", package_name) - download("https://raw.githubusercontent.com/uwsampl/tvm-distro/master/tophub/%s" - % package_name, os.path.join(rootpath, package_name), True, verbose=0) + download_url = "{0}/{1}".format(tophub_location, package_name) + logger.info("Download pre-tuned parameters package from %s", download_url) + download(download_url, os.path.join(rootpath, package_name), True, verbose=0) # global cache for load_reference_log -- 2.7.4