[Relay][Compilation] replace relay.build_module with C++ BuildModule (#3174)
author    Zhi <5145158+zhiics@users.noreply.github.com>
Thu, 16 May 2019 00:28:18 +0000 (17:28 -0700)
committer Tianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 May 2019 00:28:18 +0000 (17:28 -0700)
13 files changed:
python/tvm/relay/__init__.py
python/tvm/relay/_build_module.py [new file with mode: 0644]
python/tvm/relay/backend/graph_runtime_codegen.py
python/tvm/relay/build_module.py
python/tvm/relay/quantize/quantize.py
src/codegen/build_module.cc
src/relay/backend/build_module.cc
src/relay/backend/graph_runtime_codegen.cc
tests/cpp/relay_build_module_test.cc
tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py
tests/python/relay/test_cpp_build_module.py
tests/python/relay/test_pass_annotation.py
tests/python/relay/test_pass_quantize.py

diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 6201681..1f1e4a6 100644 (file)
@@ -25,7 +25,7 @@ from . import expr_functor
 from . import module
 from . import adt
 from . import ir_pass
-from .build_module import build, build_config, create_executor, optimize
+from .build_module import build, build_config, create_executor
 from . import prelude
 from . import parser
 from . import debug
diff --git a/python/tvm/relay/_build_module.py b/python/tvm/relay/_build_module.py
new file mode 100644 (file)
index 0000000..bdbcbef
--- /dev/null
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
+"""The interface for building Relay functions exposed from C++."""
+from tvm._ffi.function import _init_api
+
+_init_api("relay.build_module", __name__)
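The _init_api call above pulls in the C++ globals registered under "relay.build_module", most importantly _BuildModule, which returns a runtime::Module whose member functions are looked up by name. A minimal sketch of that lookup (it mirrors what the BuildModule wrapper added in build_module.py below does):

    from tvm.relay import _build_module

    bld = _build_module._BuildModule()   # runtime::Module backed by the C++ RelayBuildModule
    build_f = bld["build"]               # member PackedFuncs are fetched by name
    json_f = bld["get_graph_json"]
    mod_f = bld["get_module"]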
diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py
index ea1846b..cf31e9c 100644 (file)
@@ -36,12 +36,9 @@ contrib.graph_runtime or any other TVM runtime compatible systems.
 from __future__ import absolute_import
 
 from tvm.ndarray import empty
-from tvm._ffi.function import _init_api
-
 from tvm.relay import build_module
 from tvm import target as _target
-
-_init_api("tvm.relay.build_module")
+from tvm import expr as _expr
 
 class GraphRuntimeCodegen(object):
     """The compiler from Relay to the TVM runtime system."""
@@ -57,17 +54,14 @@ class GraphRuntimeCodegen(object):
         self._setup(mod, target)
 
     def _setup(self, mod, target):
-        tgts = []
+        tgts = {}
         if isinstance(target, dict):
-            for kv in target.items():
-                tgts.append(kv[0])
-                if isinstance(kv[1], (str, _target.Target)):
-                    tgts.append(str(kv[1]))
-                else:
+            for dev, tgt in target.items():
+                if not isinstance(tgt, (str, _target.Target)):
                     raise Exception("Unknown target type")
+                tgts[dev] = _target.create(tgt)
         elif isinstance(target, (str, _target.Target)):
-            tgts.append("0")
-            tgts.append(str(target))
+            tgts[_expr.IntImm("int32", 0)] = _target.create(target)
         self._init(mod, tgts)
 
     def codegen(self, func):
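The codegen is now initialized with a dict keyed by device type instead of a flattened list of strings. A small sketch of the structures that _setup above and relay.build's _update_target produce (the targets used here are illustrative only):

    from tvm import nd, expr as _expr, target as _target

    # homogeneous case (what _setup builds for target="llvm")
    tgts = {_expr.IntImm("int32", 0): _target.create("llvm")}

    # heterogeneous case (what _update_target in build_module.py builds)
    tgts = {_expr.IntImm("int32", nd.context(dev).device_type): _target.create(tgt)
            for dev, tgt in {"cpu": "llvm", "cuda": "cuda"}.items()}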
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index c8b69e0..d0ad78f 100644 (file)
 Construct the necessary state for the TVM graph runtime
 from a Relay expression.
 """
-import warnings
+import numpy as np
 
 from tvm._ffi.runtime_ctypes import TVMContext
-from ..build_module import build as _tvm_build_module
+from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt
+from . import _build_module
 from . import ir_pass
-from . import expr as _expr
 from . import ty as _ty
+from . import expr as _expr
 from .backend import interpreter as _interpreter
-from .backend import graph_runtime_codegen as _graph_gen
 from .backend.vm import VMExecutor
 
-# List of optimization pass and level when switch on
-OPT_PASS_LEVEL = {
-    "SimplifyInference": 0,
-    "OpFusion": 1,
-    "FoldConstant": 2,
-    "CombineParallelConv2D": 3,
-    "FoldScaleAxis": 3,
-    "AlterOpLayout": 3,
-    "CanonicalizeOps": 3,
-    "EliminateCommonSubexpr": 3,
-}
-
-
 class BuildConfig(object):
     """Configuration scope to set a build config option.
 
@@ -56,6 +43,7 @@ class BuildConfig(object):
     defaults = {
         "opt_level": 2,
         "add_pass": None,
+        "disable_pass": None,
         "fallback_device": None,
     }
 
@@ -85,23 +73,6 @@ class BuildConfig(object):
         assert self._old_scope
         BuildConfig.current = self._old_scope
 
-    def pass_enabled(self, pass_name):
-        """Get whether pass is enabled.
-
-        Parameters
-        ----------
-        pass_name : str
-            The optimization pass name
-
-        Returns
-        -------
-        enabled : bool
-            Whether pass is enabled.
-        """
-        if self.add_pass and pass_name in self.add_pass:
-            return True
-        return self.opt_level >= OPT_PASS_LEVEL[pass_name]
-
 
 BuildConfig.current = BuildConfig()
 
@@ -117,6 +88,9 @@ def build_config(**kwargs):
     add_pass: set of str
         Optimization pass to be added regardless of optimization level.
 
+    disable_pass: set of str
+        Optimization pass to be disabled during optimization.
+
     fallback_device : str or tvm.TVMContext
         The fallback device. It is also used as the default device for
         operators without specified device during heterogeneous execution.
@@ -129,108 +103,203 @@ def build_config(**kwargs):
     return BuildConfig(**kwargs)
 
 
-def _bind_params_by_name(func, params):
-    """Bind parameters of function by its name."""
-    name_dict = {}
-    for arg in func.params:
-        name = arg.name_hint
-        if name in name_dict:
-            name_dict[name] = None
-        else:
-            name_dict[name] = arg
-    bind_dict = {}
-    for k, v in params.items():
-        if k not in name_dict:
-            continue
-        arg = name_dict[k]
-        if arg is None:
-            raise ValueError("Multiple args in the function have name %s" % k)
-        bind_dict[arg] = _expr.const(v)
-    return _expr.bind(func, bind_dict)
-
-
-def optimize(func, target=None, params=None):
-    """Perform target invariant optimizations.
-
-    Parameters
-    ----------
-    func : tvm.relay.Function
-        The input to optimization.
+def _update_target(target):
+    target = target if target else _target.current_target()
+    if target is None:
+        raise ValueError("Target is not set in env or passed as argument.")
 
-    target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]]
-        The optimization target. For heterogeneous compilation, it is a
-        dictionary mapping device type to compilation target. For homogeneous
-        compilation, it is a build target.
+    tgts = {}
+    if isinstance(target, (str, _target.Target)):
+        dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type)
+        tgts[dev_type] = _target.create(target)
+    elif isinstance(target, dict):
+        for dev, tgt in target.items():
+            dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type)
+            tgts[dev_type] = _target.create(tgt)
+    else:
+        raise TypeError("target is expected to be str or " +
+                        "tvm.target.Target, but received " +
+                        "{}".format(type(target)))
+    return tgts
 
-    params : Optional[Dict[str, tvm.nd.NDArray]]
-        Input parameters to the graph that do not change
-        during inference time. used for constant folding.
 
-    Returns
-    -------
-    opt_func : tvm.relay.Function
-        The optimized version of the function.
+class BuildModule(object):
+    """Build a Relay function to run on TVM graph runtime. This class is used
+    to expose the `RelayBuildModule` APIs implemented in C++.
     """
-    cfg = BuildConfig.current
-
-    # bind expressions
-    if params:
-        func = _bind_params_by_name(func, params)
-
-    if cfg.pass_enabled("SimplifyInference"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.simplify_inference(func)
-
-    if cfg.pass_enabled("EliminateCommonSubexpr"):
-        def fskip(expr):
-            if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \
-               expr.attrs.dtype == 'int32':
-                return True
-            return False
-
-        func = ir_pass.infer_type(func)
-        func = ir_pass.eliminate_common_subexpr(func, fskip)
-
-    if cfg.pass_enabled("CombineParallelConv2D"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.combine_parallel_conv2d(func)
-
-    # The constant folding pass is necessary because FoldScaleAxis pass needs
-    # to check the constantness and positiveness of scales.
-    if cfg.pass_enabled("FoldConstant"):
-        func = ir_pass.fold_constant(func)
-
-    if cfg.pass_enabled("FoldScaleAxis"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.backward_fold_scale_axis(func)
-        func = ir_pass.infer_type(func)
-        func = ir_pass.forward_fold_scale_axis(func)
-        func = ir_pass.fold_constant(func)
-
-    if cfg.pass_enabled("CanonicalizeOps"):
-        func = ir_pass.infer_type(func)
-        func = ir_pass.canonicalize_ops(func)
-
-    # FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for
-    # now. We probably need to pass target to this pass as well. Fix it in
-    # a followup PR.
-    if cfg.pass_enabled("AlterOpLayout"):
-        if isinstance(target, _target.Target):
-            func = ir_pass.infer_type(func)
-            with target:
-                func = ir_pass.alter_op_layout(func)
-        elif isinstance(target, dict):
-            warnings.warn("AlterOpLayout pass is not enabled for heterogeneous"
-                          " execution yet.")
-
-    if cfg.pass_enabled("FoldConstant"):
-        func = ir_pass.fold_constant(func)
-
-    return func
+    def __init__(self):
+        self.mod = _build_module._BuildModule()
+        self._get_graph_json = self.mod["get_graph_json"]
+        self._get_module = self.mod["get_module"]
+        self._build = self.mod["build"]
+        self._add_pass = self.mod["add_pass"]
+        self._disable_pass = self.mod["disable_pass"]
+        self._set_opt_level = self.mod["set_opt_level"]
+        self._set_fallback_device = self.mod["set_fallback_device"]
+        self._set_params_func = self.mod["set_params"]
+        self._get_params_func = self.mod["get_params"]
+
+    def build(self, func, target=None, target_host=None, params=None):
+        """
+        Parameters
+        ----------
+        func: relay.Function
+            The function to build.
+
+        target : str, :any:`tvm.target.Target`, or dict of str (i.e.
+        device/context name) to str/tvm.target.Target, optional
+            For heterogeneous compilation, it is a dictionary indicating context
+            to target mapping. For homogeneous compilation, it is a build target.
+
+        target_host : str or :any:`tvm.target.Target`, optional
+            Host compilation target, if target is device.
+            When TVM compiles device-specific programs such as CUDA,
+            we also need host (CPU) side code to interact with the driver
+            to set up the dimensions and parameters correctly.
+            target_host is used to specify the host-side codegen target.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm interpreter is used.
+
+        params : dict of str to NDArray
+            Input parameters to the graph that do not change
+            during inference time. Used for constant folding.
+
+        Returns
+        -------
+        graph_json : str
+            The json string that can be accepted by graph runtime.
+
+        mod : tvm.Module
+            The module containing necessary libraries.
+
+        params : dict
+            The parameters of the final graph.
+        """
+        target = _update_target(target)
+
+        # Setup the build configurations passed in through `with build_config`.
+        self._setup_build_config(params)
+        # Build the function
+        self._build(func, target, target_host)
+        # Get artifacts
+        graph_json = self.get_json()
+        mod = self.get_module()
+        params = self.get_params()
+
+        return graph_json, mod, params
+
+    def _setup_build_config(self, params):
+        cfg = BuildConfig.current
+
+        # Set opt_level.
+        self.set_opt_level(cfg.opt_level)
+
+        # Set fallback device if it is available.
+        if cfg.fallback_device:
+            self.set_fallback_device(cfg.fallback_device)
+
+        # Add required passes.
+        if cfg.add_pass:
+            passes = set()
+            if isinstance(cfg.add_pass, (list, tuple, set)):
+                passes = set(cfg.add_pass)
+            else:
+                raise TypeError("add_pass must be list, tuple, or set, but " +
+                                "got {}".format(type(cfg.add_pass)))
+            for pass_name in passes:
+                self.add_pass(pass_name)
+
+        # Add disabled passes.
+        if cfg.disable_pass:
+            passes = set()
+            if isinstance(cfg.disable_pass, (list, tuple, set)):
+                passes = set(cfg.disable_pass)
+            else:
+                raise TypeError("disable_pass must be list, tuple, or set, " +
+                                "but got {}".format(type(cfg.disable_pass)))
+            for pass_name in passes:
+                self.disable_pass(pass_name)
+
+        if params:
+            self._set_params(params)
+
+    def _set_params(self, params):
+        inputs = {}
+        for name, param in params.items():
+            if isinstance(param, np.ndarray):
+                param = _nd.array(param)
+            inputs[name] = _expr.const(param)
+        self._set_params_func(inputs)
+
+    def add_pass(self, pass_name):
+        """Add a pass to the pass list.
+
+        Parameters
+        ----------
+        pass_name : str
+            The name of the pass that will be added to the list of passes used
+            for optimizations.
+        """
+        self._add_pass(pass_name)
+
+    def disable_pass(self, pass_name):
+        """Add a pass to the disabled pass list.
+
+        Parameters
+        ----------
+        pass_name : str
+            The name of a pass. This pass will be added to the list of passes
+            that are disabled during optimization.
+        """
+        self._disable_pass(pass_name)
+
+    def get_json(self):
+        """Return the graph JSON of the built program."""
+        return self._get_graph_json()
+
+    def get_module(self):
+        """Return the built module."""
+        return self._get_module()
+
+    def get_params(self):
+        """Return the updated weights."""
+        params = self._get_params_func()
+        ret = {}
+        for key, value in params.items():
+            ret[key] = value.data
+        return ret
+
+    def set_opt_level(self, level):
+        """Set the optimization level.
+
+        Parameters
+        ----------
+        level : int
+            The optimization level for build.
+        """
+        self._set_opt_level(level)
+
+    def set_fallback_device(self, fallback_device):
+        """Set the fallback device for heterogeneous execution.
+
+        Parameters
+        ----------
+        fallback_device : str or tvm.TVMContext
+            The fallback device used for heterogeneous execution.
+        """
+        if isinstance(fallback_device, str):
+            fallback_device = _nd.context(fallback_device)
+        if not isinstance(fallback_device, TVMContext):
+            raise TypeError("fallback_device is expected to be str or " +
+                            "TVMContext, but received: " +
+                            "{}".format(type(fallback_device)))
+
+        self._set_fallback_device(fallback_device.device_type)
 
 
 def build(func, target=None, target_host=None, params=None):
-    """Build a function to run on TVM graph runtime.
+    """Helper function that builds a Relay function to run on TVM graph
+    runtime.
 
     Parameters
     ----------
@@ -266,146 +335,28 @@ def build(func, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    target = target if target else _target.current_target()
-    if target is None:
-        raise ValueError("Target is not set in env or passed as argument.")
+    target = _update_target(target)
 
-    if isinstance(target, dict):
-        target, fallback_device = _update_heterogeneous_inputs(target)
-    elif isinstance(target, (str, _target.Target)):
-        target = _target.create(target)
-    else:
-        raise ValueError("target must be the type of str, tvm.target.Target," +
-                         "or dict of device name to target")
+    if isinstance(target_host, (str, _target.Target)):
+        target_host = _target.create(target_host)
+    elif target_host:
+        raise ValueError("target_host must be str, " +
+                         "tvm.target.Target, or None")
 
     # If current dispatch context is fallback context (the default root context),
     # then load pre-tuned parameters from TopHub
     if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext):
-        if isinstance(target, dict):
-            tophub_context = autotvm.tophub.context(list(target.values()))
-        else:
-            tophub_context = autotvm.tophub.context(target)
+        tophub_context = autotvm.tophub.context(list(target.values()))
     else:
         tophub_context = autotvm.util.EmptyContext()
 
-    cfg = BuildConfig.current
-
     with tophub_context:
-        func = optimize(func, target, params)
-        # Annotate the ops for heterogeneous execution.
-        if isinstance(target, dict):
-            func, target = _run_device_annotation_passes(func, target,
-                                                         fallback_device)
-        # Fuse ops before running code gen
-        func = ir_pass.infer_type(func)
-        func = ir_pass.fuse_ops(func, cfg.opt_level)
-        # Graph code generation
-        func = ir_pass.infer_type(func)
-        graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
-        graph_json, lowered_funcs, params = graph_gen.codegen(func)
-        mod = _tvm_build_module(
-            lowered_funcs, target=target, target_host=target_host)
+        bld_mod = BuildModule()
+        graph_json, mod, params = bld_mod.build(func, target, target_host,
+                                                params)
     return graph_json, mod, params
 
 
-def _update_heterogeneous_inputs(target):
-    """Update the target and fallback device required for heterogeneous
-    compilation. CPU is used as the fallback device if it wasn't provided.
-    Meanwhile, a CPU device type and "llvm" pair will be added to the target
-    dictionary in this case.
-
-    Parameters
-    ----------
-    target : dict of str(i.e. device/context name) to str/tvm.target.Target.
-        A dict contains context to target pairs.
-
-    Returns
-    -------
-    device_target : dict of int to tvm.target.Target.
-        The updated device type to target dict.
-
-    fallback_device : int
-        The updated fallback device type.
-    """
-    if not isinstance(target, dict):
-        raise ValueError("target must be dict of device name to target for " +
-                         "heterogeneous execution, but received %s."
-                         % type(target))
-
-    fallback_device = BuildConfig.current.fallback_device
-    if fallback_device is None:
-        # cpu is used as the default fallback device when heterogeneous
-        # execution is needed, but no fallback device is provided.
-        fallback_device = _nd.cpu(0).device_type
-        target[fallback_device] = str(_target.create("llvm"))
-    elif isinstance(fallback_device, str):
-        fallback_device = _nd.context(fallback_device).device_type
-    elif isinstance(fallback_device, TVMContext):
-        fallback_device = fallback_device.device_type
-    else:
-        raise ValueError("fallback_device expects the type of str or " +
-                         "TVMContext, but received %s." % type(fallback_device))
-
-    device_target = {}
-    for dev, tgt in target.items():
-        device_target[_nd.context(dev).device_type] = _target.create(tgt)
-
-    if fallback_device not in device_target:
-        raise ValueError("%s is used as the default device, but the target" +
-                         "is not provided."
-                         % _nd.context(fallback_device).device_name)
-    return device_target, fallback_device
-
-
-def _run_device_annotation_passes(func, target, fallback_device):
-    """Execute the device annotation passes to update the input program and
-    target information.
-
-    Parameters
-    ----------
-    func: tvm.relay.Function
-        The function where annotation passes will be execute at.
-
-    target : Dict[int, tvm.target.Target]
-        A dict contains device type to target pairs.
-
-    fallback_device : int
-        The fallback device type.
-
-    Returns
-    -------
-    target : Dict[int, tvm.target.Target]
-        The updated device type to target dict.
-
-    func : tvm.relay.Function
-        The updated func.
-    """
-    func = ir_pass.infer_type(func)
-    func = ir_pass.rewrite_annotated_ops(func, fallback_device)
-    device_map = ir_pass.collect_device_info(func)
-    # The expression to device type map will be empty if all or none of
-    # the expressions in the `func` are annotated because this map is
-    # obtained by propagating the device information in the device copy
-    # operator. None of the above cases needs device copy operator.
-    if not device_map:
-        annotation_map = ir_pass.collect_device_annotation_ops(func)
-        # No annotation.
-        if not annotation_map:
-            target = {0: target[fallback_device]}
-        else:
-            dev_type = next(iter(annotation_map.values()))
-            # All annotated with the same device type.
-            if all(val == dev_type for val in annotation_map.values()):
-                target = {0: target[dev_type]}
-            else:
-                raise RuntimeError("Expressions in the function are "
-                                   "annotated with various device types,"
-                                   "but not device copy operators "
-                                   "found. Please check the "
-                                   "RewriteAnnotation pass.")
-    return func, target
-
-
 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
 
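For reference, a hedged end-to-end sketch of the new flow: relay.build delegates to the C++ BuildModule, and opt_level, add_pass, disable_pass, and fallback_device are forwarded through build_config. The function, shapes, and targets below are illustrative, not part of this change:

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    x = relay.var("x", shape=(1, 4))
    w = relay.var("w", shape=(3, 4))
    func = relay.Function([x, w], relay.nn.dense(x, w))
    params = {"w": np.random.uniform(size=(3, 4)).astype("float32")}

    # The config scope is read by _setup_build_config and pushed into the
    # C++ RelayBuildModule before building.
    with relay.build_config(opt_level=3, disable_pass={"AlterOpLayout"}):
        graph_json, mod, params = relay.build(func, "llvm", params=params)

    rt = graph_runtime.create(graph_json, mod, tvm.cpu())
    rt.load_params(relay.save_param_dict(params))
    rt.set_input("x", np.ones((1, 4), dtype="float32"))
    rt.run()
    out = rt.get_output(0)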
diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py
index 607ee18..b84d3eb 100644 (file)
@@ -269,6 +269,77 @@ def realize(graph):
     return _quantize.realize(graph)
 
 
+def optimize(func, params=None):
+    """ Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
+    "CanonicalizeOps" optimization before quantization.
+
+    # TODO(zhiics) These passes are executed one by one so far. We need to
+    # move them to the pass manager.
+
+    Parameters
+    ----------
+    func: tvm.relay.Function
+        The original Relay function to be optimized.
+
+    params : dict of str to tvm.NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    ret: tvm.relay.Function
+        The optimized Relay function.
+    """
+
+    opt_passes = ["SimplifyInference",
+                  "FoldScaleAxis",
+                  "FoldConstant",
+                  "CanonicalizeOps"]
+
+    cfg = _build.build_config(add_pass=opt_passes)
+
+    if params:
+        name_dict = {}
+        for arg in func.params:
+            name = arg.name_hint
+            if name in name_dict:
+                name_dict[name] = None
+            else:
+                name_dict[name] = arg
+        bind_dict = {}
+        for k, v in params.items():
+            if k not in name_dict:
+                continue
+            arg = name_dict[k]
+            if arg is None:
+                raise ValueError("Multiple args in the function have name %s" % k)
+            bind_dict[arg] = _expr.const(v)
+        func = _expr.bind(func, bind_dict)
+
+    if "SimplifyInference" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.simplify_inference(func)
+
+    if "FoldConstant" in cfg.add_pass:
+        func = _ir_pass.fold_constant(func)
+
+    if "FoldScaleAxis" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.backward_fold_scale_axis(func)
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.forward_fold_scale_axis(func)
+        func = _ir_pass.fold_constant(func)
+
+    if "CanonicalizeOps" in cfg.add_pass:
+        func = _ir_pass.infer_type(func)
+        func = _ir_pass.canonicalize_ops(func)
+
+    if "FoldConstant" in cfg.add_pass:
+        func = _ir_pass.fold_constant(func)
+
+    return func
+
+
 def quantize(graph, params=None, dataset=None):
     """ The quantization procedure. Before running the three main
     procedure of quantization, "annotate", "calibrate" and "realize"
@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
     ret: Function
         The graph after quantization
     """
-    opt_passes = ["SimplifyInference",
-                  "FoldScaleAxis",
-                  "FoldConstant",
-                  "CanonicalizeOps"]
-    with _build.build_config(add_pass=opt_passes):
-        graph = _build.optimize(graph, params=params)
+    # TODO(zhiics) Move this to the pass manager.
+    graph = optimize(graph, params)
 
     graph = annotate(graph)
     graph = calibrate(graph, dataset)
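The quantize entry point itself is unchanged for callers; a minimal usage sketch, modeled on tests/python/relay/test_pass_quantize.py (shapes and config values here are illustrative):

    import numpy as np
    from tvm import relay
    from tvm.relay import quantize as qtz

    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
    weight = relay.var("conv_weight")
    out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=3)
    graph = relay.Function(relay.ir_pass.free_vars(out), out)
    params = {"conv_weight": np.random.uniform(size=(3, 3, 3, 3)).astype("float32")}

    with qtz.qconfig(skip_k_conv=0, global_scale=4.0):
        # quantize() now runs the standalone optimize() above before
        # annotate/calibrate/realize
        qgraph = qtz.quantize(graph, params)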
diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc
index 57e300f..9b30ced 100644 (file)
@@ -311,7 +311,7 @@ bool LLVMEnabled() {
 
 /*! \return The default host target for a given device target */
 Target DefaultTargetHost(Target target) {
-  if (target->device_type == kDLCPU) {
+  if (target.defined() && target->device_type == kDLCPU) {
     return target;
   } else {
     if (LLVMEnabled()) {
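This guard matters because the Relay build path can now pass an undefined Target as target_host (relay.build defaults target_host to None); DefaultTargetHost then falls back to llvm if enabled, otherwise stackvm, instead of dereferencing an undefined handle. A hedged Python-side illustration:

    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    func = relay.Function([x], x + relay.const(1.0))

    # No target_host given: the host target is chosen by DefaultTargetHost.
    graph_json, mod, params = relay.build(func, target="llvm")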
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 08a88d5..63ee2d5 100644 (file)
@@ -38,54 +38,31 @@ namespace tvm {
 namespace relay {
 namespace backend {
 
+using TargetsMap = Map<tvm::Integer, tvm::Target>;
+
 /*!
- * \brief Context name / index
- *        See: python/tvm/_ffi/runtime_ctypes.py
+ * \brief Context index to Target
  */
-struct ContextMap {
-  static const std::unordered_map<int, std::string> mask2str;
-  static const std::unordered_map<std::string, int> str2mask;
-  static std::string Mask2Str(int mask) {
+struct ContextTargetMap {
+  static const std::unordered_map<int, tvm::Target> mask2str;
+  static tvm::Target Mask2Str(int mask) {
     CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
     return mask2str.at(mask);
   }
-  static int Str2Mask(const std::string& str) {
-    CHECK_GT(str2mask.count(str), 0) << "Unknown context.";
-    return str2mask.at(str);
-  }
-};
-
-const std::unordered_map<int, std::string> ContextMap::mask2str = {
-  {1, "cpu"},
-  {2, "gpu"},
-  {4, "opencl"},
-  {5, "aocl"},
-  {6, "sdaccel"},
-  {7, "vulkan"},
-  {8, "metal"},
-  {9, "vpi"},
-  {10, "rocm"},
-  {11, "opengl"},
-  {12, "ext_dev"}
 };
 
-const std::unordered_map<std::string, int> ContextMap::str2mask = {
-  {"llvm", 1},
-  {"cpu", 1},
-  {"c", 1},
-  {"gpu", 2},
-  {"cuda", 2},
-  {"nvptx", 2},
-  {"cl", 4},
-  {"opencl", 4},
-  {"aocl", 5},
-  {"aocl_sw_emu", 5},
-  {"vulkan", 7},
-  {"metal", 8},
-  {"vpi", 9},
-  {"rocm", 10},
-  {"opengl", 11},
-  {"ext_dev", 12}
+const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
+  {1, tvm::Target::create("llvm")},
+  {2, tvm::Target::create("cuda")},
+  {4, tvm::Target::create("opencl")},
+  {5, tvm::Target::create("aocl")},
+  {6, tvm::Target::create("sdaccel")},
+  {7, tvm::Target::create("vulkan")},
+  {8, tvm::Target::create("metal")},
+  {9, tvm::Target::create("vpi")},
+  {10, tvm::Target::create("rocm")},
+  {11, tvm::Target::create("opengl")},
+  {12, tvm::Target::create("ext_dev")}
 };
 
 /*!
@@ -137,7 +114,7 @@ struct BuildOutput {
  */
 struct RelayBuildConfig {
   int opt_level{2};
-  std::string fallback_device{"llvm"};
+  int fallback_device{static_cast<int>(kDLCPU)};
   std::unordered_set<std::string> enabled_pass;
   std::unordered_set<std::string> disabled_pass;
   OptPassLevel OPT_PASS_LEVEL;
@@ -164,14 +141,8 @@ struct GraphCodegen {
   }
   ~GraphCodegen() {}
 
-  void Init(runtime::Module* m,
-            Map<HalideIR::Expr, HalideIR::Expr> targets) {
-    Array<HalideIR::Expr> tgts;
-    for (auto kv : targets) {
-      tgts.push_back(kv.first);
-      tgts.push_back(kv.second);
-    }
-    CallFunc("init", m, tgts);
+  void Init(runtime::Module* m, TargetsMap targets) {
+    CallFunc("init", m, targets);
   }
 
   void Codegen(const Function& func) {
@@ -248,14 +219,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     } else if (name == "build") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         CHECK_EQ(args.num_args, 3);
-        Array<HalideIR::Expr> tmp = args[1];
-        std::unordered_map<std::string, std::string> targets;
-        for (size_t i = 0; i < tmp.size(); i += 2) {
-          auto k = tmp[i].as<ir::StringImm>()->value;
-          auto v = tmp[i + 1].as<ir::StringImm>()->value;
-          targets[k] = v;
-        }
-        this->Build(args[0], targets, args[2]);
+        this->Build(args[0], args[1], args[2]);
       });
     } else if (name == "list_params") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -273,7 +237,8 @@ class RelayBuildModule : public runtime::ModuleNode {
       });
     } else if (name == "set_fallback_device") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        std::string dev = args[0];
+        CHECK_EQ(args.num_args, 1);
+        int dev = args[0];
         this->SetFallBackDev(dev);
       });
     } else if (name == "add_pass") {
@@ -328,7 +293,7 @@ class RelayBuildModule : public runtime::ModuleNode {
    *
    * \param device name
    */
-  void SetFallBackDev(const std::string& dev) {
+  void SetFallBackDev(int dev) {
     cfg_.fallback_device = dev;
   }
   /*!
@@ -402,8 +367,8 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \param target_host Host target device
    */
   void Build(Function func,
-             const std::unordered_map<std::string, std::string>& targets,
-             const std::string& target_host) {
+             const TargetsMap& targets,
+             const tvm::Target& target_host) {
     targets_ = targets;
     target_host_ = target_host;
     BuildRelay(func, cfg_, params_);
@@ -416,8 +381,9 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \param params params dict
    * \return relay::Function
    */
-  relay::Function BindParamsByName(relay::Function func,
-                              const std::unordered_map<std::string, runtime::NDArray>& params) {
+  relay::Function BindParamsByName(
+      relay::Function func,
+      const std::unordered_map<std::string, runtime::NDArray>& params) {
     std::unordered_map<std::string, relay::Var> name_dict;
     std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
     for (auto arg : func->params) {
@@ -454,7 +420,7 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \return relay::Function
    */
   relay::Function Optimize(relay::Function func,
-                           const std::unordered_map<std::string, std::string>& targets,
+                           const TargetsMap& targets,
                            const RelayBuildConfig& cfg,
                            const std::unordered_map<std::string, runtime::NDArray>& params) {
     if (params.size()) {
@@ -507,8 +473,7 @@ class RelayBuildModule : public runtime::ModuleNode {
         auto enter_pf = GetPackedFunc("_EnterTargetScope");
         auto exit_pf = GetPackedFunc("_ExitTargetScope");
         for (const auto& kv : targets) {
-          auto target = Target::create(kv.second);
-          (*enter_pf)(target);
+          (*enter_pf)(kv.second);
           func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
           (*exit_pf)();
         }
@@ -530,25 +495,19 @@ class RelayBuildModule : public runtime::ModuleNode {
    *
    * \param targets dictionary
    * \param cfg
-   * \return Map<HalideIR::Expr, HalideIR::Expr>
+   * \return Map<tvm::Integer, tvm::Target>
    */
-  Map<HalideIR::Expr, HalideIR::Expr> UpdateHeterogeneousInputs(
-    const std::unordered_map<std::string, std::string>& targets,
-    const RelayBuildConfig& cfg) {
-    Map<HalideIR::Expr, HalideIR::Expr> device_target;
-    std::unordered_map<int64_t, std::string> tmp_map;
-    auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
-
+  TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
+                                       const RelayBuildConfig& cfg) {
+    TargetsMap device_target = targets;
+    std::unordered_map<int64_t, tvm::Target> tmp_map;
     for (const auto& kv : targets) {
-      tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second;
-    }
-    if (tmp_map.count(fallback_idx) == 0) {
-      tmp_map[fallback_idx] = cfg.fallback_device;
+      tmp_map[kv.first->value] = kv.second;
     }
-    for (const auto& kv : tmp_map) {
+    if (tmp_map.count(cfg.fallback_device) == 0) {
       device_target.Set(
-        ir::IntImm::make(HalideIR::Int(64), kv.first),
-        ir::StringImm::make(kv.second));
+          cfg.fallback_device,
+          ContextTargetMap::Mask2Str(cfg.fallback_device));
     }
     return device_target;
   }
@@ -561,25 +520,19 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \param targets_map_ptr
    * \return Function
    */
-  Function RunDeviceAnnotationPass(
-      Function func,
-      const RelayBuildConfig& cfg,
-      Map<HalideIR::Expr, HalideIR::Expr>* targets_map_ptr) {
-    auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
+  Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
+                                   TargetsMap* targets_map_ptr) {
     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-    func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx);
-    auto device_map = CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceInfo",
-                                                       func,
-                                                       nullptr);
+    func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
+                          cfg.fallback_device);
+    auto device_map = CallPackedFunc<Map<Expr, Integer> >(
+        "relay._ir_pass.CollectDeviceInfo", func, nullptr);
     if (device_map.size() == 0) {
-      auto annotation_map =
-        CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceAnnotationOps",
-                                            func,
-                                            nullptr);
+      auto annotation_map = CallPackedFunc<Map<Expr, Integer> >(
+          "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
       if (annotation_map.size() == 0) {
         targets_map_ptr->Set(
-          ir::IntImm::make(HalideIR::Int(64), 0),
-          ir::StringImm::make(cfg.fallback_device));
+            0, ContextTargetMap::Mask2Str(cfg.fallback_device));
       } else {
         int64_t dev_type = -1;
         for (auto kv : annotation_map) {
@@ -594,9 +547,7 @@ class RelayBuildModule : public runtime::ModuleNode {
             << "found. Please check the "
             << "RewriteAnnotation pass.";
         }
-        targets_map_ptr->Set(
-          ir::IntImm::make(HalideIR::Int(64), 0),
-          ir::StringImm::make(ContextMap::Mask2Str(dev_type)));
+        targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
       }
     }
     return func;
@@ -614,15 +565,11 @@ class RelayBuildModule : public runtime::ModuleNode {
                   const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
     // convert
     tvm_cfg_ = build_config();
-    Map<HalideIR::Expr, HalideIR::Expr> device_target;
+    TargetsMap device_target;
     if (targets_.size() > 1) {
       device_target = UpdateHeterogeneousInputs(targets_, cfg);
     } else {
-      for (auto &kv : targets_) {
-        device_target.Set(
-          ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)),
-          ir::StringImm::make(kv.second));
-      }
+      device_target = targets_;
     }
     func = Optimize(func, targets_, cfg, params);
     if (device_target.size() > 1) {
@@ -640,16 +587,15 @@ class RelayBuildModule : public runtime::ModuleNode {
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
 
-    auto target_host = Target::create(target_host_);
-    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_);
+    ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
   }
 
  protected:
   std::unique_ptr<GraphCodegen> graph_codegen_;
   /*! \brief target device */
-  std::unordered_map<std::string, std::string> targets_;
+  TargetsMap targets_;
   /*! \brief target host device */
-  std::string target_host_;
+  tvm::Target target_host_;
   /*! \brief frontend optimization configure */
   RelayBuildConfig cfg_;
   /*! \brief parameters */
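Note that set_fallback_device now takes an integer device type and build takes a Map<Integer, Target>. On the Python side this is hidden behind the BuildModule wrapper from build_module.py; a small sketch:

    from tvm import relay

    bld = relay.build_module.BuildModule()
    bld.set_opt_level(3)
    # The wrapper converts the context name to its integer device type before
    # calling the C++ set_fallback_device.
    bld.set_fallback_device("cpu")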
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 415e0ec..b14448c 100644 (file)
@@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
 using GraphNodePtr = std::shared_ptr<GraphNode>;
 using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
 using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
-using TargetsMap = std::unordered_map<std::string, Target>;
+using TargetsMap = std::unordered_map<int, Target>;
 
 /*! \brief Lowered outputs */
 struct LoweredOutput {
@@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode {
 class GraphRuntimeCodegen
     : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
  public:
-  GraphRuntimeCodegen(runtime::Module* mod,
-                      const std::unordered_map<std::string, std::string>& targets) : mod_(mod) {
+  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
+      : mod_(mod) {
     compile_engine_ = CompileEngine::Global();
-    for (auto &kv : targets) {
-      targets_[kv.first] = Target::create(kv.second);
-    }
+    targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
@@ -406,7 +404,7 @@ class GraphRuntimeCodegen
     auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
     auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
     auto &device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;  //-> int to string
+    auto call_dev_type = device_type[0]->value;
     Target target;
     if (targets_.size() == 1) {
        // homogeneous execution.
@@ -415,22 +413,17 @@ class GraphRuntimeCodegen
        }
     } else {
       // heterogeneous execution.
-      const auto call_dev_key = std::to_string(call_dev_type);
       std::string call_dev_name;
       if (call_dev_type == 0) {
         call_dev_name = "llvm";
       } else {
         call_dev_name = runtime::DeviceName(call_dev_type);
       }
-      if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
+      if (targets_.count(call_dev_type) == 0) {
         LOG(FATAL) << "No target is provided for device "
                    << call_dev_name;
       }
-      if (targets_.count(call_dev_key)) {
-        target = targets_[call_dev_key];
-      } else {
-        target = targets_[call_dev_name];
-      }
+      target = targets_[call_dev_type];
     }
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
@@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
   virtual PackedFunc GetFunction(const std::string& name,
                                  const std::shared_ptr<ModuleNode>& sptr_to_self) {
      if (name == "init") {
-      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        CHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
-                                   << "runtime::Module mod and Map<str, StringImm> targets";
-        void* mod = args[0];
-        auto& sptr = args[1].node_sptr();
-        auto* node = static_cast<const ArrayNode*>(sptr.get());
-        auto& tmp_targets = node->data;
-        std::unordered_map<std::string, std::string> targets;
-        for (size_t i = 0; i < tmp_targets.size(); i += 2) {
-          std::string key;
-          auto sk = Expr(tmp_targets[i]).as<ir::StringImm>();
-          auto ik = Expr(tmp_targets[i]).as<ir::IntImm>();
-          if (sk) {
-            key = sk->value;
-          }
-          if (ik) {
-            key = std::to_string(ik->value);
-          }
-          auto v = Expr(tmp_targets[i + 1]).as<ir::StringImm>();
-          targets[key] = v->value;
-        }
-        codegen_ = std::make_shared<GraphRuntimeCodegen>(
-          reinterpret_cast<runtime::Module*>(mod), targets);
-      });
+       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+         CHECK_EQ(args.num_args, 2)
+             << "The expected arguments are: "
+             << "runtime::Module mod and Map<int, Target> targets";
+         void* mod = args[0];
+         Map<Integer, tvm::Target> tmp = args[1];
+         TargetsMap targets;
+         for (const auto& it : tmp) {
+           auto dev_type = it.first.as<ir::IntImm>();
+           CHECK(dev_type);
+           targets[dev_type->value] = it.second;
+         }
+         codegen_ = std::make_shared<GraphRuntimeCodegen>(
+             reinterpret_cast<runtime::Module*>(mod), targets);
+       });
     } else if (name == "codegen") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
         Function func = args[0];
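The "init" packed function now expects a runtime::Module pointer and a Map<Integer, Target>; the updated Python GraphRuntimeCodegen wrapper produces exactly that. A hedged sketch of driving it directly (the fuse/infer steps mirror what the old Python build path did before codegen; the function is illustrative):

    from tvm import relay
    from tvm.relay.backend.graph_runtime_codegen import GraphRuntimeCodegen

    x = relay.var("x", shape=(2, 2))
    func = relay.Function([x], x * relay.const(2.0))
    func = relay.ir_pass.infer_type(func)
    func = relay.ir_pass.fuse_ops(func, 2)   # codegen expects fused primitive calls
    func = relay.ir_pass.infer_type(func)

    # Homogeneous case: _setup maps device type 0 to Target("llvm") and hands
    # the C++ "init" a Map<Integer, Target> instead of a flattened string list.
    codegen = GraphRuntimeCodegen(mod=None, target="llvm")
    graph_json, lowered_funcs, params = codegen.codegen(func)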
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 38481bf..a1ab299 100644 (file)
@@ -18,6 +18,7 @@
  */
 
 #include <gtest/gtest.h>
+#include <tvm/build_module.h>
 #include <tvm/tvm.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
@@ -73,10 +74,10 @@ TEST(Relay, BuildModule) {
   auto build_f = build_mod.GetFunction("build", false);
   auto json_f = build_mod.GetFunction("get_graph_json", false);
   auto mod_f = build_mod.GetFunction("get_module", false);
-  Array<HalideIR::Expr> target_pair;
-  target_pair.push_back(ir::StringImm::make("cpu"));
-  target_pair.push_back(ir::StringImm::make("llvm"));
-  build_f(func, target_pair, "llvm");
+  Map<tvm::Integer, tvm::Target> targets;
+  Target llvm_tgt = Target::create("llvm");
+  targets.Set(0, llvm_tgt);
+  build_f(func, targets, llvm_tgt);
   std::string json = json_f();
   tvm::runtime::Module mod = mod_f();
   // run
diff --git a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py
index a038685..d3538bb 100644 (file)
@@ -74,13 +74,12 @@ def test_alter_layout_conv2d():
 
     for tgt in targets:
         with tvm.target.create(tgt) as target:
-            with relay.build_config(opt_level=-1, add_pass='AlterOpLayout'):
-               with autotvm.tophub.context(target):
-                   O = relay.optimize(N, target, params=None)
-                   O = relay.ir_pass.infer_type(O)
+            with autotvm.tophub.context(target):
+                O = relay.ir_pass.alter_op_layout(N)
+                O = relay.ir_pass.infer_type(O)
 
-                   # graph should differ
-                   assert not relay.ir_pass.alpha_equal(N, O)
+                # graph should differ
+                assert not relay.ir_pass.alpha_equal(N, O)
 
 if __name__ == "__main__":
     np.random.seed(42)
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index b94f57d..affc6ce 100644 (file)
@@ -18,55 +18,10 @@ import numpy as np
 
 import tvm
 from tvm import relay
+from tvm.contrib.nvcc import have_fp16
 
-from tvm._ffi.function import _init_api
-_init_api("tvm.relay.build_module")
-
-class BuildModule(object):
-    def __init__(self):
-        self.mod = relay.build_module._BuildModule()
-        self._get_graph_json = self.mod["get_graph_json"]
-        self._get_module = self.mod["get_module"]
-        self._build = self.mod["build"]
-        self._set_opt_level = self.mod["set_opt_level"]
-        self._set_params_func = self.mod["set_params"]
-        self._get_params_func = self.mod["get_params"]
-
-  
-    def build(self, func, target, target_host, params):
-        tgts = []
-        for kv in target.items():
-            tgts.append(kv[0])
-            tgts.append(kv[1])
-        self._set_params(params)
-        self._build(func, tgts, target_host)
-
-    def get_json(self):
-        return self._get_graph_json()
-
-    def get_module(self):
-        return self._get_module()
-
-    def set_opt_level(self, level):
-        self._set_opt_level(level)
-
-    def _set_params(self, params):
-        inputs = {}
-        for name, param in params.items():
-            inputs[name] = relay.Constant(param)
-        self._set_params_func(inputs)
-
-    def get_params(self):
-        params = self._get_params_func()
-        ret = {}
-        for key, value in params.items():
-            ret[key] = value.data
-        return ret
-
-
-def test_build():
-    m_bld = BuildModule()
-    tgt_name = "llvm"
+
+def test_basic_build():
     tgt = "llvm"
     ctx = tvm.cpu()
     # func
@@ -86,21 +41,96 @@ def test_build():
     }
     # build
     targets = {
-        tgt: tgt
+        tvm.expr.IntImm("int32", ctx.device_type): tgt
     }
-    m_bld.set_opt_level(3)
-    m_bld.build(func, targets, "llvm", params=params)
-    g_json = m_bld.get_json()
-    mmod = m_bld.get_module()
-    params = m_bld.get_params()
-   
+    g_json, mmod, params = relay.build(func, targets, "llvm", params=params)
+
     # test
     rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
     rt.set_input("a", A)
     rt.load_params(relay.save_param_dict(params))
     rt.run()
     out = rt.get_output(0)
-   
-    np.testing.assert_allclose(out.asnumpy(),
-        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
-  
+
+    np.testing.assert_allclose(out.asnumpy(), np.maximum(np.dot(A.asnumpy(),
+                                                                B.asnumpy().T),
+                                                         0) + C.asnumpy(),
+                               atol=1e-5, rtol=1e-5)
+
+
+def test_fp16_build():
+    dtype = "float16"
+
+    if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist:
+        print("skip because cuda is not enabled.")
+        return
+
+    ctx = tvm.gpu(0)
+    if dtype == "float16" and not have_fp16(ctx.compute_version):
+        print("skip because gpu does not support fp16")
+        return
+
+    x = relay.var("x", dtype=dtype, shape=(4, 4))
+    y = relay.var("y", dtype=dtype, shape=(4, 4))
+    z = x + y
+    func = relay.Function([x, y], z)
+    X = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
+    Y = tvm.nd.array(np.random.uniform(-1, 1, (4, 4)).astype(dtype), ctx=ctx)
+    params = {
+        "x": X,
+        "y": Y,
+    }
+
+    # build
+    g_json, mmod, params = relay.build(func, "cuda", params=params)
+
+    # test
+    rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+    rt.load_params(relay.save_param_dict(params))
+    rt.run()
+    out = rt.get_output(0)
+
+    np.testing.assert_allclose(out.asnumpy(), X.asnumpy() + Y.asnumpy(),
+                               atol=1e-5, rtol=1e-5)
+
+
+def test_fp16_conversion():
+    def check_conversion(tgt, ctx):
+        if not tvm.module.enabled(tgt):
+            print("skip because {} is not enabled.".format(tgt))
+            return
+        elif tgt == "cuda" and ctx.exist and not have_fp16(ctx.compute_version):
+            print("skip because gpu does not support fp16")
+            return
+
+        n = 10
+
+        for (src, dst) in [('float32', 'float16'), ('float16', 'float32')]:
+            x = relay.var("x", relay.TensorType((n,), src))
+            y = x.astype(dst)
+            func = relay.Function([x], y)
+
+            # init input
+            X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2)
+
+            # build
+            with relay.build_config(opt_level=1):
+                g_json, mmod, params = relay.build(func, tgt)
+
+            # test
+            rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+            rt.set_input("x", X)
+            rt.run()
+            out = rt.get_output(0)
+
+            np.testing.assert_allclose(out.asnumpy(), X.asnumpy().astype(dst),
+                                       atol=1e-5, rtol=1e-5)
+
+    for target, ctx in [('llvm', tvm.cpu()), ('cuda', tvm.gpu())]:
+        check_conversion(target, ctx)
+
+
+if __name__ == "__main__":
+    test_basic_build()
+    test_fp16_build()
+    test_fp16_conversion()
diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py
index 04081e0..9a77d2f 100644 (file)
@@ -411,7 +411,7 @@ def run_fusible_network(dev, tgt):
                      expected_index)
 
     def test_fallback_all_operators(device, tgt):
-        target = {device: tgt}
+        target = {device: tgt, "cpu": "llvm"}
         annotated_func = get_func()
         expected_func = get_func()
         check_annotated_graph(annotated_func, expected_func)
diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py
index a9a683c..1630efc 100644 (file)
@@ -47,54 +47,54 @@ def test_simulated_quantize():
     assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
 
 
-def test_quantize_pass():
-    def quantize_weight(arr):
-        maximum = np.amax(np.abs(arr.asnumpy()))
-        scale = 2**math.ceil(math.log(maximum, 2))
-        out = np.around(arr.asnumpy() / scale * 128).astype('int8')
-        out = np.clip(out, -127, 127)
-        return relay.const(out, 'int8')
-#
-    n, c, h, w = 1, 3, 224, 224
-    def make_graph(data):
-        weight = relay.var("conv_weight")
-        out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
-        out = relay.Function(relay.ir_pass.free_vars(out), out)
-        return out
-#
-    def make_qgraph(data, weight):
-        out = data * relay.const(32.0)
-        out = relay.round(out)
-        out = relay.clip(out, a_min=-127, a_max=127)
-        out = out.astype('int8')
-#
-        out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
-                              padding=(1, 1), channels=c, out_dtype='int32')
-        out = out.astype('float32')
-        out = relay.multiply(out, relay.const(0.00024414062))
-        out = relay.Function(relay.ir_pass.free_vars(out), out)
-        return out
-#
-    data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
-    graph = make_graph(data)
-    dataset, params = make_dataset(graph, 10)
-#
-    with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
-                     round_for_shift=False, store_lowbit_output=False):
-        qgraph0 = qtz.quantize(graph, params)
-        qgraph0 = relay.ir_pass.infer_type(qgraph0)
-#
-    conv_weight = quantize_weight(params['conv_weight'])
-    qgraph1 = make_qgraph(data, conv_weight)
-    qgraph1 = relay.ir_pass.infer_type(qgraph1)
-#
-    graph = relay.create_executor('graph')
-    res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
-    res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
-    tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
+def test_quantize_pass():
+    def quantize_weight(arr):
+        maximum = np.amax(np.abs(arr.asnumpy()))
+        scale = 2**math.ceil(math.log(maximum, 2))
+        out = np.around(arr.asnumpy() / scale * 128).astype('int8')
+        out = np.clip(out, -127, 127)
+        return relay.const(out, 'int8')
+
+    n, c, h, w = 1, 3, 224, 224
+    def make_graph(data):
+        weight = relay.var("conv_weight")
+        out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
+        out = relay.Function(relay.ir_pass.free_vars(out), out)
+        return out
+
+    def make_qgraph(data, weight):
+        out = data * relay.const(32.0)
+        out = relay.round(out)
+        out = relay.clip(out, a_min=-127, a_max=127)
+        out = out.astype('int8')
+
+        out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
+                              padding=(1, 1), channels=c, out_dtype='int32')
+        out = out.astype('float32')
+        out = relay.multiply(out, relay.const(0.00024414062))
+        out = relay.Function(relay.ir_pass.free_vars(out), out)
+        return out
+
+    data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
+    graph = make_graph(data)
+    dataset, params = make_dataset(graph, 10)
+
+    with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
+                     round_for_shift=False, store_lowbit_output=False):
+        qgraph0 = qtz.quantize(graph, params)
+        qgraph0 = relay.ir_pass.infer_type(qgraph0)
+
+    conv_weight = quantize_weight(params['conv_weight'])
+    qgraph1 = make_qgraph(data, conv_weight)
+    qgraph1 = relay.ir_pass.infer_type(qgraph1)
+
+    graph = relay.create_executor('graph')
+    res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
+    res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
+    tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
 
 
 if __name__ == "__main__":
     np.random.seed(42)
     test_simulated_quantize()
-    test_quantize_pass()
+    test_quantize_pass()