from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
rpc as _rpc, target as _target
-from ...contrib import nvcc, ndk
+from ...contrib import nvcc, ndk, tar
from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
- If is callable, use it as custom build function
+ If is callable, use it as custom build function, expect lib_format field.
"""
def __init__(self, timeout=10, n_parallel=None, build_func='default'):
super(LocalBuilder, self).__init__(timeout, n_parallel)
if isinstance(build_func, str):
if build_func == 'default':
- build_func = default_build_func
+ build_func = tar.tar
elif build_func == 'ndk':
- build_func = android_ndk_build_func
+ build_func = ndk.create_shared
else:
raise ValueError("Invalid build_func" + build_func)
- self.build_func = build_func
+ self.build_func = _wrap_build_func(build_func)
self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()
return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
-def default_build_func(measure_input, tmp_dir, **kwargs):
+def _wrap_build_func(build_func):
"""
- Default build func. This can work for cuda, opencl, llvm backend
+ Wrap build_func to a function that can be used in measure.
Parameters
----------
- measure_input: MeasureInput
- The input of measurement
- tmp_dir: str
- The path of temporary directory to export generated library
- """
- tic = time.time()
- try:
- filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64))
- func, arg_info = _build_func_common(measure_input, **kwargs)
- func.export_library(filename)
- except Exception as e: # pylint: disable=broad-except
- return BuildResult(None, None, e, time.time() - tic)
- return BuildResult(filename, arg_info, None, time.time() - tic)
-
-
-def android_ndk_build_func(measure_input, tmp_dir, **kwargs):
- """
- Build function for android device using ndk.
+ build_func : The compilation function
+ We expect fcompile to contain an attr "output_format"
- Parameters
- ----------
- measure_input: MeasureInput
- The input of measurement
- tmp_dir: str
- The path of temporary directory to export generated library
+ Returns
+ -------
+ wrapped_build_func : function
+ The wrapped build function
"""
- tic = time.time()
- try:
- filename = os.path.join(tmp_dir, "tmp_func_%0x.so" % getrandbits(64))
- func, arg_info = _build_func_common(measure_input, **kwargs)
- func.export_library(filename, ndk.create_shared)
- except Exception as e: # pylint: disable=broad-except
- return BuildResult(None, None, e, time.time() - tic)
- return BuildResult(filename, arg_info, None, time.time() - tic)
+ if not hasattr(build_func, "output_format"):
+ raise AttributeError("Expect build_func to have the attribute output_format.")
+ output_format = build_func.output_format
+
+ def _wrapped(measure_input, tmp_dir, **kwargs):
+ """
+ Wrapped build func.
+
+ Parameters
+ ----------
+ measure_input: MeasureInput
+ The input of measurement
+
+ tmp_dir: str
+ The path of temporary directory to export generated library
+ """
+ tic = time.time()
+ try:
+ filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
+ getrandbits(64), output_format))
+ # TODO(tvm-team) consider linline _build_func_common
+ func, arg_info = _build_func_common(measure_input, **kwargs)
+ func.export_library(filename, build_func)
+ except Exception as e: # pylint: disable=broad-except
+ return BuildResult(None, None, e, time.time() - tic)
+ return BuildResult(filename, arg_info, None, time.time() - tic)
+ return _wrapped
def run_through_rpc(measure_input, build_result,
cc : str, optional
The compile string.
"""
- if sys.platform == "darwin" or sys.platform.startswith('linux'):
+ if sys.platform == "darwin" or sys.platform.startswith("linux"):
_linux_shared(output, objects, options, cc)
elif sys.platform == "win32":
_windows_shared(output, objects, options)
raise ValueError("Unsupported platform")
+# assign so as default output format
+create_shared.output_format = "so" if sys.platform != "win32" else "dll"
+
+
+def cross_compiler(cc, options=None, output_format="so"):
+ """Create a cross compiler function.
+
+ Parameters
+ ----------
+ cc : str
+ The cross compiler name.
+
+ options : list, optional
+ List of additional optional string.
+
+ output_format : str, optional
+ Library output format.
+
+ Returns
+ -------
+ fcompile : function
+ A compilation function that can be passed to export_library.
+ """
+ def _fcompile(outputs, objects, opts=None):
+ opts = opts if opts else []
+ if options:
+ opts += options
+ _linux_shared(outputs, objects, opts, cc=cc)
+ _fcompile.output_format = output_format
+ return _fcompile
+
+
def _linux_shared(output, objects, options, cc="g++"):
cmd = [cc]
cmd += ["-shared", "-fPIC"]