[AUTOTVM] Refactor measure build func (#2927)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sat, 30 Mar 2019 13:33:28 +0000 (09:33 -0400)
committerGitHub <noreply@github.com>
Sat, 30 Mar 2019 13:33:28 +0000 (09:33 -0400)
python/tvm/autotvm/measure/measure_methods.py
python/tvm/contrib/cc.py
python/tvm/contrib/tar.py
python/tvm/contrib/xcode.py

index f77a13b..7f65f2e 100644 (file)
@@ -19,7 +19,7 @@ import numpy as np
 
 from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
     rpc as _rpc, target as _target
-from ...contrib import nvcc, ndk
+from ...contrib import nvcc, ndk, tar
 
 from ..util import get_const_tuple
 from ..env import AutotvmGlobalScope
@@ -58,20 +58,20 @@ class LocalBuilder(Builder):
     build_func: callable or str
         If is 'default', use default build function
         If is 'ndk', use function for android ndk
-        If is callable, use it as custom build function
+        If is callable, use it as custom build function, expect lib_format field.
     """
     def __init__(self, timeout=10, n_parallel=None, build_func='default'):
         super(LocalBuilder, self).__init__(timeout, n_parallel)
 
         if isinstance(build_func, str):
             if build_func == 'default':
-                build_func = default_build_func
+                build_func = tar.tar
             elif build_func == 'ndk':
-                build_func = android_ndk_build_func
+                build_func = ndk.create_shared
             else:
                 raise ValueError("Invalid build_func" + build_func)
 
-        self.build_func = build_func
+        self.build_func = _wrap_build_func(build_func)
         self.executor = LocalExecutor(timeout=timeout)
         self.tmp_dir = tempfile.mkdtemp()
 
@@ -349,46 +349,47 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti
     return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
 
 
-def default_build_func(measure_input, tmp_dir, **kwargs):
+def _wrap_build_func(build_func):
     """
-    Default build func. This can work for cuda, opencl, llvm backend
+    Wrap build_func to a function that can be used in measure.
 
     Parameters
     ----------
-    measure_input: MeasureInput
-        The input of measurement
-    tmp_dir: str
-        The path of temporary directory to export generated library
-    """
-    tic = time.time()
-    try:
-        filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
-        func, arg_info = _build_func_common(measure_input, **kwargs)
-        func.export_library(filename)
-    except Exception as e:  # pylint: disable=broad-except
-        return BuildResult(None, None, e, time.time() - tic)
-    return BuildResult(filename, arg_info, None, time.time() - tic)
-
-
-def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
-    """
-    Build function for android device using ndk.
+    build_func : The compilation function
+        We expect fcompile to contain an attr "output_format"
 
-    Parameters
-    ----------
-    measure_input: MeasureInput
-        The input of measurement
-    tmp_dir: str
-        The path of temporary directory to export generated library
+    Returns
+    -------
+    wrapped_build_func : function
+        The wrapped build function
     """
-    tic = time.time()
-    try:
-        filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
-        func, arg_info = _build_func_common(measure_input, **kwargs)
-        func.export_library(filename, ndk.create_shared)
-    except Exception as e:  # pylint: disable=broad-except
-        return BuildResult(None, None, e, time.time() - tic)
-    return BuildResult(filename, arg_info, None, time.time() - tic)
+    if not hasattr(build_func, "output_format"):
+        raise AttributeError("Expect build_func to have the attribute output_format.")
+    output_format = build_func.output_format
+
+    def _wrapped(measure_input, tmp_dir, **kwargs):
+        """
+        Wrapped build func.
+
+        Parameters
+        ----------
+        measure_input: MeasureInput
+            The input of measurement
+
+        tmp_dir: str
+            The path of temporary directory to export generated library
+        """
+        tic = time.time()
+        try:
+            filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
+                getrandbits(64), output_format))
+            # TODO(tvm-team) consider linline _build_func_common
+            func, arg_info = _build_func_common(measure_input, **kwargs)
+            func.export_library(filename, build_func)
+        except Exception as e:  # pylint: disable=broad-except
+            return BuildResult(None, None, e, time.time() - tic)
+        return BuildResult(filename, arg_info, None, time.time() - tic)
+    return _wrapped
 
 
 def run_through_rpc(measure_input, build_result,
index ee84da8..09822e5 100644 (file)
@@ -29,7 +29,7 @@ def create_shared(output,
     cc : str, optional
         The compile string.
     """
-    if sys.platform == "darwin" or sys.platform.startswith('linux'):
+    if sys.platform == "darwin" or sys.platform.startswith("linux"):
         _linux_shared(output, objects, options, cc)
     elif sys.platform == "win32":
         _windows_shared(output, objects, options)
@@ -37,6 +37,38 @@ def create_shared(output,
         raise ValueError("Unsupported platform")
 
 
+# assign so as default output format
+create_shared.output_format = "so" if sys.platform != "win32" else "dll"
+
+
+def cross_compiler(cc, options=None, output_format="so"):
+    """Create a cross compiler function.
+
+    Parameters
+    ----------
+    cc :  str
+        The cross compiler name.
+
+    options : list, optional
+        List of additional optional string.
+
+    output_format : str, optional
+        Library output format.
+
+    Returns
+    -------
+    fcompile : function
+        A compilation function that can be passed to export_library.
+    """
+    def _fcompile(outputs, objects, opts=None):
+        opts = opts if opts else []
+        if options:
+            opts += options
+        _linux_shared(outputs, objects, opts, cc=cc)
+    _fcompile.output_format = output_format
+    return _fcompile
+
+
 def _linux_shared(output, objects, options, cc="g++"):
     cmd = [cc]
     cmd += ["-shared", "-fPIC"]
index 7e075d7..741a914 100644 (file)
@@ -42,6 +42,9 @@ def tar(output, files):
         msg += py_str(out)
         raise RuntimeError(msg)
 
+# assign output format
+tar.output_format = "tar"
+
 
 def untar(tar_file, directory):
     """Unpack all tar files into the directory
index a43dc9a..99f5938 100644 (file)
@@ -98,6 +98,9 @@ def create_dylib(output, objects, arch, sdk="macosx"):
         raise RuntimeError(msg)
 
 
+# assign so as default output format
+create_dylib.output_format = "dylib"
+
 def compile_metal(code, path_target=None, sdk="macosx"):
     """Compile metal with CLI tool from env.