[AutoTVM]Core functionality for Graph tuner (#2184)
authorYao Wang <kevinthesunwy@gmail.com>
Wed, 29 May 2019 23:36:05 +0000 (16:36 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Wed, 29 May 2019 23:36:05 +0000 (16:36 -0700)
* Add graph tuning

* Add tests

* Fix tests

* Fix pylint

* Small fix for docstring

* Minor fix

* Support fetching workload from relay expr

* Simplify benchmark layout transformation

* Add relay support

* Fix infer layout func name

* Refactor internal data representation

* Fix issues

* Add PBQP solver

* Fix layout transform check

* Add PBQPTuner test

* Fix lint

* Update tutorial

* Fix tutorial

* Fix lint

* Add relay test

* Remove nnvm since nnvm graph can be converted to relay function

* Modify benchmark layout wrt new layout_transform api

* Fix lint

* Update docstring for DP tuner

* Refactor traverse graph

* Support graph tuning for multiple target operators

* Fix fetching workloads

* Add x86 depthwise_conv2d infer_layout

* Fix x86 depthwise_conv2d autotvm

* Fix PBQP tuner

* Fix DP tuner

* Generate dummy layout transform record

* Update tutorial

* Modify layout records name

* Add ASF header

* Add ASF header for testing files

* Fix test

* Fix topi fetching

* Some refactors

* Fix lint

* Fix tutorial

* Rename test files

* Fix doc typo

* Add test case note link

18 files changed:
python/tvm/autotvm/graph_tuner/__init__.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/_base.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/base_graph_tuner.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/pbqp_tuner.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/utils/__init__.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py [new file with mode: 0644]
python/tvm/autotvm/graph_tuner/utils/utils.py [new file with mode: 0644]
python/tvm/autotvm/task/__init__.py
python/tvm/autotvm/task/topi_integration.py
tests/python/unittest/test_graph_tuner_core.py [new file with mode: 0644]
tests/python/unittest/test_graph_tuner_utils.py [new file with mode: 0644]
topi/python/topi/nn/conv2d.py
topi/python/topi/nn/depthwise_conv2d.py
topi/python/topi/x86/conv2d.py
topi/python/topi/x86/depthwise_conv2d.py
tutorials/autotvm/tune_relay_x86.py

diff --git a/python/tvm/autotvm/graph_tuner/__init__.py b/python/tvm/autotvm/graph_tuner/__init__.py
new file mode 100644 (file)
index 0000000..d590db0
--- /dev/null
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Autotvm graph tuner API."""
+from __future__ import absolute_import as _abs
+
+from . import _base
+from . import base_graph_tuner
+
+from .base_graph_tuner import BaseGraphTuner
+from .dynamic_programming_tuner import DPTuner
+from .pbqp_tuner import PBQPTuner
diff --git a/python/tvm/autotvm/graph_tuner/_base.py b/python/tvm/autotvm/graph_tuner/_base.py
new file mode 100644 (file)
index 0000000..83b9e06
--- /dev/null
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Helper functions and global data"""
+
+
+RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape",
+                       "multibox_prior", "multibox_transform_loc", "where",
+                       "non_max_suppression", "strided_slice"]
+
+# We set a large time to represent an invalid layout-transformation.
+# This number is set to be 10e9 seconds to align with autotvm.
+INVALID_LAYOUT_TIME = 10e9
diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
new file mode 100644 (file)
index 0000000..0fbfc27
--- /dev/null
@@ -0,0 +1,522 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-instance-attributes,too-many-branches,too-many-nested-blocks,invalid-name,unused-argument,unused-variable,no-member,no-value-for-parameter
+"""Base class for graph tuner."""
+import logging
+from abc import abstractmethod
+
+import numpy as np
+import topi
+
+import tvm
+from tvm import autotvm, relay
+from tvm.autotvm.task import get_config
+from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
+from tvm.autotvm.record import encode, load_from_file
+from tvm.autotvm.measure import MeasureResult, MeasureInput
+
+from ... import target as _target
+from .utils import is_input_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \
+    bind_inputs, expr2graph
+from ._base import INVALID_LAYOUT_TIME
+
+
+# Setup topi_op_name -> layout function
+# NOTE: To add more ops, change the following dictionary.
+OP2LAYOUT = {
+    "topi_nn_conv2d": topi.nn.conv2d_infer_layout,
+    "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
+}
+
+
+@autotvm.template
+def layout_transform(*args):
+    """Autotvm layout transform template."""
+    args = deserialize_args(args)
+    cfg = get_config()
+    cfg.add_flop(-1)
+    data = args[0]
+    out = topi.layout_transform(*args)
+    sch = topi.generic.schedule_injective([out])
+    return sch, [data, out]
+
+
+class BaseGraphTuner(object):
+    """Class to search schedules considering both kernel execution time and
+    layout transformation time.
+
+    Before creating a Graph Executor instance, schedule candidates for all kernels in
+    graph should be provided through tensor-level tuning.
+    """
+    def __init__(self, graph, input_shapes, records, target_ops,
+                 target, max_sch_num=20, dtype="float32", verbose=True,
+                 log_file="graph_tuner.log", log_level=logging.DEBUG,
+                 name="graph_tuner"):
+        """Create a GlobalTuner instance. Local schedule searching for all nodes with
+        target_op in the input graph and layout transformation benchmark need to be
+        executed before initialization.
+
+        graph : tvm.relay.Expr.Function
+            Input graph
+
+        input_shapes : dict of str to tuple.
+            Input shapes of graph
+
+        records : str or iterator of (MeasureInput, MeasureResult)
+            Collection of kernel level tuning records.
+            If it is str, then it should be the filename of a records log file.
+                       Each row of this file is an encoded record pair.
+            Otherwise, it is an iterator.
+
+        target_ops : List of str
+            Target tuning operators.
+
+        target : str or tvm.target
+            Compilation target.
+
+        max_sch_num : int, optional
+            Maximum number of schedule candidates for each workload.
+
+        dtype : str, optional
+            Data type.
+
+        log_file : str, optional
+            graph tuner log file name
+
+        name : str, optional
+            Name of global tuner.
+        """
+        self._node_list = []
+        self._layout_transform_perf_records = {}
+        self._layout_transform_interlayer_cost = {}
+        self._input_shapes = input_shapes
+        self._target_ops = [op.__name__ for op in target_ops]
+
+        self._name = name
+        self._max_sch_num = max_sch_num
+        self._optimal_sch_dict = {}
+        self._records = records
+        self._dtype = dtype
+        if isinstance(target, str):
+            target = _target.create(target)
+        self._target = target
+        self._optimal_record_dict = {}
+
+        # Set up logger
+        self._verbose = verbose
+        self._logger = logging.getLogger(name + "_logger")
+        need_file_handler = need_console_handler = True
+        for handler in self._logger.handlers:
+            if handler.__class__.__name__ == 'FileHandler':
+                need_file_handler = False
+            if handler.__class__.__name__ == 'StreamHandler':
+                need_console_handler = False
+        self._log_level = log_level
+        self._log_file = log_file
+        self._formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
+        self._logger.setLevel(log_level)
+        if need_file_handler:
+            file_handler = logging.FileHandler(log_file)
+            file_handler.setFormatter(self._formatter)
+            self._logger.addHandler(file_handler)
+        if self._verbose and need_console_handler:
+            console_handler = logging.StreamHandler()
+            console_handler.setFormatter(self._formatter)
+            self._logger.addHandler(console_handler)
+            self._logger.setLevel(log_level)
+            self._logger.propagate = False
+
+        # Generate workload and schedule dictionaries.
+        if isinstance(graph, relay.expr.Function):
+            node_dict = {}
+            graph = bind_inputs(graph, input_shapes, dtype)
+            expr2graph(graph, self._target_ops, node_dict, self._node_list)
+        else:
+            raise RuntimeError("Unsupported graph type: %s" % str(type(graph)))
+
+        self._graph = graph
+        self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
+        self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
+        self._fetch_cfg()
+
+        # Setup infer_layout for elemwise-like nodes
+        # Note: graph tuner currently only supports tuning of single input and single output
+        # op as target op, such as conv2d, dense and conv2d_transpose. In this case, we can
+        # reuse infer_layout function from target ops for elemwise-like nodes. The behavior
+        # is to modify the first tensor shape of input workload to the output shape of
+        # elemwise-like node, and use infer_layout function from input op to generate layouts.
+        input_names = self._input_shapes.keys()
+        for idx in sorted(self._in_nodes_dict.keys()):
+            if has_multiple_inputs(self._node_list, idx, input_names):
+                node_entry = self._node_list[idx]
+                node_entry["topi_op"] = []
+                node_entry["workloads"] = []
+                for input_idx in self._in_nodes_dict[idx]:
+                    input_node = self._node_list[input_idx]
+                    if not is_input_node(input_node, input_names):
+                        input_topi_op = input_node["topi_op"][0]
+                        node_entry["topi_op"].append(input_topi_op)
+                        # Only replace the first input tensor
+                        input_workload = input_node["workloads"][0]
+                        first_tensor = input_workload[1]
+                        dtype = first_tensor[-1]
+                        new_shape = tuple([val.value for val in node_entry["types"][0].shape])
+                        actual_workload = (input_workload[0],) + \
+                                          ((new_shape + (dtype,)),) + input_workload[2:]
+                        node_entry["workloads"].append(actual_workload)
+                        if "record_candidates" not in node_entry:
+                            node_entry["record_candidates"] = input_node["record_candidates"]
+                    else:
+                        node_entry["topi_op"].append(None)
+                        node_entry["workloads"].append(None)
+
+
+    def _fetch_cfg(self):
+        """Read and pre-process input schedules."""
+        if isinstance(self._records, str):
+            records = load_from_file(self._records)
+        else:
+            records = self._records
+        cfg_dict = {}
+        for record in records:
+            in_measure, _ = record
+            workload = in_measure.task.workload
+            if workload not in cfg_dict:
+                cfg_dict[workload] = []
+            cfg_dict[workload].append(record)
+
+        cache_dict = {}
+        for key in self._in_nodes_dict:
+            node_entry = self._node_list[key]
+            if node_entry["op"] not in self._target_ops:
+                continue
+            workload = node_entry["workloads"][0]
+            if workload in cache_dict:
+                node_entry["record_candidates"] = cache_dict[workload]
+                continue
+            record_candidates = []
+            infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+            layout_tracking_dict = {}
+            for record in cfg_dict[workload]:
+                in_measure, out_measure = record
+                workload = in_measure.task.workload
+                cfg = in_measure.config
+                # For multiple cfgs which produces the same in/out layouts,
+                # only the most efficient one is preserved.
+                with self._target:
+                    layouts = infer_layout_func(workload, cfg)
+                    if layouts in layout_tracking_dict:
+                        cost = out_measure.costs[0]
+                        current_best_cost = layout_tracking_dict[layouts][1].costs[0]
+                        if cost < current_best_cost:
+                            layout_tracking_dict[layouts] = record
+                    else:
+                        layout_tracking_dict[layouts] = record
+            sorted_records = sorted(layout_tracking_dict.values(),
+                                    key=lambda item: item[1].costs[0])
+            for i in range(min(self._max_sch_num, len(sorted_records))):
+                record_candidates.append(sorted_records[i])
+            node_entry["record_candidates"] = record_candidates
+            cache_dict[workload] = record_candidates
+
+    def _iterate_layout_transform(self, callback):
+        """Iterate all possible layout transformations and execute callback for each
+        iteration. callback function accepts 6 arguments: from_node_idx, to_node_idx,
+        from_sch_idx, to_sch_idx, args which represent the argument list of layout
+        transformation and is_valid showing whether this is a valid layout transformation.
+        """
+        input_names = self._input_shapes.keys()
+        for key, val in self._in_nodes_dict.items():
+            node_entry = self._node_list[key]
+            target_input_idx = -1
+            target_input_pos = -1
+            if has_multiple_inputs(self._node_list, key, input_names):
+                for i, item in enumerate(val):
+                    if not is_input_node(self._node_list[item], input_names):
+                        target_input_idx = item
+                        target_input_pos = i
+                        break
+
+            for i, item in enumerate(val):
+                i_idx = item
+                in_node_entry = self._node_list[i_idx]
+                if is_input_node(in_node_entry, input_names):
+                    continue
+
+                if node_entry["op"] in self._target_ops:
+                    o_idx = key
+                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_wkl = node_entry["workloads"][0]
+                    i_topi_op = in_node_entry["topi_op"][0]
+                    i_wkl = in_node_entry["workloads"][0]
+                    pivot = 0
+                    while not i_wkl:
+                        pivot += 1
+                        i_topi_op = in_node_entry["topi_op"][pivot]
+                        i_wkl = in_node_entry["workloads"][pivot]
+                    i_infer_layout_func = OP2LAYOUT[i_topi_op]
+                else:
+                    o_idx = target_input_idx
+                    if i <= target_input_pos:
+                        continue
+                    o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                    o_wkl = node_entry["workloads"][target_input_pos]
+                    i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
+                    i_wkl = node_entry["workloads"][i]
+
+
+                for m, i_record in enumerate(in_node_entry["record_candidates"]):
+                    for n, o_record in enumerate(node_entry["record_candidates"]):
+                        i_cfg, o_cfg = i_record[0].config, o_record[0].config
+                        with self._target:
+                            i_input_info, i_output_info = i_infer_layout_func(i_wkl, i_cfg)
+                            o_input_info, o_output_info = o_infer_layout_func(o_wkl, o_cfg)
+                        if len(i_input_info) > 1 or len(i_output_info) > 1 or \
+                                len(o_input_info) > 1 or len(o_output_info) > 1:
+                            raise RuntimeError("Graph tuner only supports target operator "
+                                               "with single input and single output. "
+                                               "Please check target_ops argument.")
+
+                        in_shape, in_layout = i_output_info[0]
+                        if node_entry["op"] in self._target_ops:
+                            _, out_layout = o_input_info[0]
+                        else:
+                            _, out_layout = o_output_info[0]
+                        data_placeholder = tvm.placeholder(in_shape, name="data",
+                                                           dtype=self._dtype)
+                        args = [data_placeholder, in_layout, out_layout]
+                        callback(i_idx, o_idx, m, n, args)
+
+
+    def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx,
+                                to_sch_idx, args):
+        """Create dictionary containing matrix format of layout transformation
+        between nodes."""
+        sargs = serialize_args(args)
+        in_layout, out_layout = args[1], args[2]
+        ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs)
+        idx_pair_key = (from_node_idx, to_node_idx)
+
+        if in_layout == out_layout:
+            layout_transform_time = 0
+        else:
+            layout_transform_time = \
+                self._layout_transform_perf_records[ltf_workload][1].costs[0]
+
+        if idx_pair_key not in self._layout_transform_interlayer_cost:
+            self._layout_transform_interlayer_cost[idx_pair_key] = []
+        if len(self._layout_transform_interlayer_cost[idx_pair_key]) <= from_sch_idx:
+            self._layout_transform_interlayer_cost[idx_pair_key].append([])
+        self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx]\
+            .append(layout_transform_time)
+
+    def benchmark_layout_transform(self, min_exec_num=100, timeout=10,
+                                   use_rpc=False, device_key=None, host="localhost",
+                                   port=9190, n_parallel=1, build_func='default',
+                                   layout_records=None, target_host=None, infer_layout=False):
+        """Benchmark all possible layout transformation in the graph,
+        given a set of schedule candidates for each workload of target operator.
+
+        Parameters
+        ----------
+        min_exec_num : int, optional
+            Minimum number of execution. Final execution time is the average of
+            all execution time.
+
+        timeout : int, optional
+            Time out for each execution.
+
+        use_rpc : boolean, optional
+            Whether to use rpc mode for benchmarking.
+
+        device_key : str, optional
+            Remote device key which can be queried by
+            python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
+
+        host : str, optional
+            IP address used to create RPC tracker on host machine.
+
+        port : int, optional
+            Port number used to create RPC tracker on host machine.
+
+        n_parallel: int, optional
+            The number of measurement task that can run in parallel.
+            Set this according to the number of cpu cores (for compilation) and
+            the number of devices you have (for measuring generate code).
+
+        build_func: str or callable, optional
+            'default': call default builder. This works for normal target (llvm, cuda)
+
+            'ndk': use Android NDK to create shared library. Use this for android target.
+
+            callable: customized build function for other backends (e.g. VTA).
+                      See autotvm/measure/measure_methods.py::default_build_func for example.
+
+        layout_records : str or iterator of (MeasureInput, MeasureResult). optional
+            Collection of layout_transform benchmarking records.
+            If is str, then it should be the filename of a records log file.
+                   Each row of this file is an encoded record pair.
+            Otherwise, it is an iterator.
+
+            If this argument is set, graph tuner will first check whether layout_transform
+            workload already exists in records and skip benchmarking if possible.
+
+        target_host : str, optional
+            str or :any:`tvm.target.Target` optional
+            Host compilation target, if target is device.
+            When TVM compiles device specific program such as CUDA,
+            we also need host(CPU) side code to interact with the driver
+            setup the dimensions and parameters correctly.
+            target_host is used to specify the host side codegen target.
+            By default, llvm is used if it is enabled,
+            otherwise a stackvm intepreter is used.
+
+        infer_layout : bool, optional
+            Whether to infer layout transformation time if it doesn't exist in records, instead
+            of benchmarking on target device.
+
+            This might bring performance loss comparing to benchmarking layout transformation.
+        """
+        self._logger.info("Start to benchmark layout transformation...")
+        if layout_records is None and infer_layout:
+            raise RuntimeError("Requires some records to infer layout transformation time.")
+
+        if isinstance(layout_records, str):
+            layout_records = load_from_file(layout_records)
+            if not layout_records and infer_layout:
+                raise RuntimeError("Records must be non-empty to infer layout transformation time.")
+
+        if isinstance(layout_records, str):
+            layout_records = load_from_file(layout_records)
+        num_flops, total_time = 0, 0
+        if layout_records is not None:
+            for record in layout_records:
+                ltf_wkl = record[0].task.workload
+                self._layout_transform_perf_records[ltf_wkl] = record
+                input_shape = ltf_wkl[1][1]
+                flops = np.prod(input_shape)
+                num_flops += flops
+                total_time += record[1].costs[0]
+        avg_time = total_time / num_flops if num_flops > 0 else 0
+
+        args_list = []
+        def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx,
+                                 to_sch_idx, args):
+            """Callback function to fetch layout transform args"""
+            _, in_layout, out_layout = args
+            if in_layout != out_layout:
+                args_list.append(args)
+
+        self._iterate_layout_transform(_fetch_args_callback)
+
+        def _log_to_list(record_list):
+            """Callback to log result to a list."""
+            def _callback(_, inputs, results):
+                """Callback implementation"""
+                record_list.append((inputs[0], results[0]))
+            return _callback
+
+        builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func)
+        runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout)
+        if use_rpc:
+            if device_key is None:
+                raise RuntimeError("device_key need to be set to use rpc tracker mode.")
+            runner = autotvm.measure.RPCRunner(device_key, host, port, n_parallel=n_parallel,
+                                               number=min_exec_num, repeat=1,
+                                               timeout=timeout)
+        measure_option = autotvm.measure_option(builder=builder, runner=runner)
+        for args in args_list:
+            args = serialize_args(args)
+            ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
+            if ltf_workload in  self._layout_transform_perf_records:
+                continue
+
+            if infer_layout:
+                input_shape = ltf_workload[1][1]
+                flops = 1
+                for i in input_shape:
+                    flops *= i
+                inferred_time = flops * avg_time
+                record_input = MeasureInput(target=self._target, task=None, config=None)
+                record_output = MeasureResult(costs=(inferred_time,), error_no=0,
+                                              all_cost=-1, timestamp=-1)
+                self._layout_transform_perf_records[ltf_workload] = (record_input, record_output)
+                continue
+
+            records = []
+            task = autotvm.task.create(layout_transform, args=args, target=self._target,
+                                       target_host=target_host)
+            task.workload = ltf_workload
+            tuner = autotvm.tuner.GridSearchTuner(task)
+            tuner.tune(n_trial=1, measure_option=measure_option,
+                       callbacks=[_log_to_list(records)])
+            if not isinstance(records[0][1].costs[0], float):
+                records[0] = (records[0][0], records[0][1]._replace(costs=(INVALID_LAYOUT_TIME,)))
+            self._layout_transform_perf_records[ltf_workload] = records[0]
+
+        self._iterate_layout_transform(self._create_matrix_callback)
+        self._logger.info("Benchmarking layout transformation successful.")
+
+    @property
+    def layout_transform_perf_records(self):
+        """Get layout transformation dictionary for input graph.
+
+        Returns
+        -------
+        layout_transform_perf_records : dict of tuple to (MeasureInput, MeasureResult)
+            Layout transformation dictionary for input graph.
+        """
+        return self._layout_transform_perf_records
+
+
+    def get_optimal_records(self):
+        """Convert optimal record dictionary to a list of records
+        with ascending order of node index in graph.
+
+        Returns
+        -------
+        sch_list : list of tuple
+            List of records with ascending order of node index in graph.
+        """
+        ordered_index_list = sorted(self._optimal_record_dict.keys())
+        ret = []
+        for index in ordered_index_list:
+            node_entry = self._node_list[index]
+            if node_entry["op"] not in self._target_ops:
+                continue
+            ret.append(node_entry["record_candidates"][self._optimal_record_dict[index]])
+        return ret
+
+    def write_opt_sch2record_file(self, record_file="graph_opt_schedule.log"):
+        """Write graph level optimal schedules into file.
+
+        Parameters
+        ----------
+        record_file : str, optional
+            Output schedule file.
+        """
+        with open(record_file, "a") as out_file:
+            records = self.get_optimal_records()
+            for record in records:
+                out_file.write(encode(record[0], record[1]) + "\n")
+        msg = "Writing optimal schedules to %s successfully." % record_file
+        self._logger.info(msg)
+
+    @abstractmethod
+    def run(self, **kwargs):
+        """Run graph tuning."""
+        pass
diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py
new file mode 100644 (file)
index 0000000..4a512c2
--- /dev/null
@@ -0,0 +1,358 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-instance-attributes,too-many-branches,too-many-statements,too-many-arguments,too-many-locals,invalid-name
+"""Stage class for dynamic programming tuner"""
+import numpy as np
+
+from .utils import is_input_node
+
+
+class DPStage(object):
+    """Class to represent node in Markov decision process. A stage has states
+    to represent different schedules of the current node. Since in this problem
+    the action is the schedule selected for current node, action can be fully
+    represented by states. No extra attribute needs for action.
+
+    In most cases, instance of this class should be created through DPTuner.
+    """
+    def __init__(self, idx, input_shapes, node_list,
+                 counted_nodes_set, layout_transform_interlayer_cost,
+                 stage_dict, in_nodes_dict, out_nodes_dict,
+                 dep_dict, target_ops, dtype="float32"):
+        """Initialize a stage and create all states.
+
+        Parameters
+        ----------
+        idx : int
+            Index for current node.
+
+        input_shapes : dict of string to tuple of int
+            Input shapes for current graph.
+
+        node_list : list of dict
+            List of all nodes for current graph.
+
+        counted_nodes_set : set of int
+            Global set recording whether the execution time of a node has been counted.
+
+        layout_transform_interlayer_cost : dict of tuple to list
+            Dictionary maps node index pair to layout transformation time between them.
+
+        stage_dict : dict of int to Stage
+            Global dictionary for all stages mapping node index to stage.
+
+        in_nodes_dict : dict of int to list of int
+            Dictionary maps node index to corresponding input node index.
+
+        out_nodes_dict : dict of int to list of int
+            Dictionary maps node index to corresponding output node index.
+
+        dep_dict : dict of int to set of int
+            Dictionary maps node index to dependent node index.
+
+        target_ops : list of str
+            Target operators
+
+        dtype : str, optional
+            Data type.
+        """
+        self._global_input_shapes = input_shapes
+        self._global_input_names = input_shapes.keys()
+        self._global_node_list = node_list
+        self._global_counted_nodes_set = counted_nodes_set
+        self._global_layout_transform_interlayer_cost = layout_transform_interlayer_cost
+        self._global_stage_dict = stage_dict
+        self._global_in_nodes_dict = in_nodes_dict
+        self._global_out_nodes_dict = out_nodes_dict
+        self._global_dep_dict = dep_dict
+
+        self._idx = idx
+        self._node_entry = self._global_node_list[idx]
+        self._target_ops = target_ops
+        self._wkl = self._node_entry["workloads"][0]
+        self._record_list = self._node_entry["record_candidates"]
+        self._dep = []
+        self._dtype = dtype
+        self._states = None
+        self._full_states = None
+        self._full_states_idx = None
+        self._create_states()
+
+    def _create_states(self):
+        """Create states."""
+        node = self._global_node_list[self._idx]
+        if node["op"] in self._target_ops:
+            self._create_op_states()
+        else:
+            self._create_multi_inputs_states()
+
+    def _create_op_states(self):
+        """State creation routine for nodes with target_op."""
+        input_idx = -1
+        for index in self._global_in_nodes_dict[self._idx]:
+            input_idx = index
+            if not is_input_node(self._global_node_list[input_idx],
+                                 self._global_input_names):
+                break
+
+        if is_input_node(self._global_node_list[input_idx],
+                         self._global_input_names):
+            self._full_states = np.array([record[1].costs[0]
+                                          for record in self._record_list])
+            self._states = self._full_states
+        else:
+            input_node_entry = self._global_node_list[input_idx]
+            input_stage = self._global_stage_dict[input_idx]
+            input_dep = input_stage.dep
+            input_states = input_stage.states
+            input_flatten_states = input_states.flatten()
+            input_record_list = input_node_entry["record_candidates"]
+            num_schedules = len(self._record_list)
+            num_input_schedules = len(input_record_list)
+            num_input_states = input_flatten_states.shape[0]
+
+            full_states_shape = tuple([num_schedules, num_input_schedules] +
+                                      [len(self._global_node_list[dep_idx]["record_candidates"])
+                                       for dep_idx in input_dep])
+            self._full_states = np.zeros(full_states_shape).flatten().astype("float32")
+            self._full_states_idx = [self._idx, input_idx] + input_dep
+            dep_multiplier = 1
+            for i in range(2, len(full_states_shape)):
+                dep_multiplier *= full_states_shape[i]
+            input_node_time_counted = input_idx in self._global_counted_nodes_set
+
+            for i in range(num_schedules):
+                current_sch_time = float(self._record_list[i][1].costs[0])
+                for j in range(num_input_states):
+                    input_sch_idx = j // dep_multiplier
+                    layout_transform_time = \
+                        self._global_layout_transform_interlayer_cost \
+                            [(input_idx, self._idx)][input_sch_idx][i]
+
+                    if input_node_time_counted:
+                        total_time = current_sch_time + layout_transform_time
+                    else:
+                        total_time = \
+                            current_sch_time + layout_transform_time + input_flatten_states[j]
+                    current_state_idx = i * num_input_states + j
+                    self._full_states[current_state_idx] = total_time
+
+            if not input_node_time_counted:
+                self._global_counted_nodes_set.add(input_idx)
+            self._full_states = self._full_states.reshape(full_states_shape)
+
+            # If out degree of input node is 1, we can remove the dimension of input node,
+            # since the states of input node will not be needed any more. Otherwise, input
+            # node should become a dependency.
+            if len(self._global_out_nodes_dict[input_idx]) == 1:
+                self._states = np.amin(self._full_states, axis=1)
+                self._dep = list(input_dep)
+            else:
+                self._states = self._full_states
+                self._dep = [input_idx,] + input_dep
+
+        # Update global dependency dictionary.
+        # This is to monitor the dependency states to decide
+        # when a dependency can be eliminated, so that total
+        # number of states can be largely reduced.
+        for dep_idx in self._dep:
+            self._global_dep_dict[dep_idx].remove(self._idx)
+            for child in self._global_out_nodes_dict[self._idx]:
+                self._global_dep_dict[dep_idx].add(child)
+        if len(self._global_out_nodes_dict[self._idx]) > 1:
+            self._global_dep_dict[self._idx] = set()
+            for child in self._global_out_nodes_dict[self._idx]:
+                self._global_dep_dict[self._idx].add(child)
+
+    def _create_multi_inputs_states(self):
+        """State creation routine for multi_input operator
+
+        In tvm, layout transformation for an elemwise-like follow the rule which
+        all input operators transform their layouts to the leftmost input operator
+        layout. For example:
+                            elemwise-sum
+                            |    |    |
+                            |    |    |
+                           op0  op1  op2
+        In this block, the possible layout transformations are: op1 -> op0 and op2 -> op0.
+        In graph tuning, a 3-D array with shape (k0, k1, k2) can represent the layout
+        transformations between these three nodes. It is also possible some earlier states
+        belong to other nodes(We name them as dependency) are required for dynamic programming.
+        The final states array for this elemwise-sum can be with shape (e0, k0, k1, e1, k2).
+        To iterate through all states, we first align the shape of op0, op1 and op2 to be
+        (e0, k0, k1, e1, k2) by broadcasting the original states. We also record the axis of
+        each input node in the states array, together with the multiplier. For example,
+        the axis index for op0 is 1, and multiplier is k1 * e1 * k2. If current iterating index
+        in the flatten array is i, the index of op0 can be computed as:
+        i % (k0 * k1 * e1 * k2) // (k1 * e1 * k2).
+        """
+        full_input_node_list = list(self._global_in_nodes_dict[self._idx])
+        input_index_list = []
+        # Remove input and parameter nodes
+        for input_idx in full_input_node_list:
+            if not is_input_node(self._global_node_list[input_idx],
+                                 self._global_input_names):
+                input_index_list.append(input_idx)
+
+        # Generate new states
+        states_list, aligned_node_list = DPStage.align_states(input_index_list,
+                                                              self._global_stage_dict,
+                                                              self._global_node_list)
+        target_node_idx, target_major_axis, target_multiplier, target_states = states_list[0]
+        aligned_shape = target_states.shape
+        self._full_states = np.zeros(aligned_shape).astype("float32").flatten()
+        self._full_states_idx = list(aligned_node_list)
+        num_states = self._full_states.shape[0]
+        node_time_counted = [item[0] in self._global_counted_nodes_set for item in states_list]
+        target_states = target_states.flatten()
+        src_states_list = [states_list[i][3].flatten() for i in range(1, len(states_list))]
+
+        for i in range(num_states):
+            target_sch_idx = (i % (target_multiplier *
+                                   aligned_shape[target_major_axis])) // target_multiplier
+            if node_time_counted[0]:
+                new_state = 0
+            else:
+                new_state = target_states[i]
+
+            for j in range(1, len(states_list)):
+                src_states = src_states_list[j - 1]
+                src_node_idx, src_major_axis, src_multiplier, _ = states_list[j]
+                src_sch_idx = (i % (src_multiplier *
+                                    aligned_shape[src_major_axis])) // src_multiplier
+                layout_transform_time = \
+                    self._global_layout_transform_interlayer_cost\
+                        [(src_node_idx, target_node_idx)][src_sch_idx][target_sch_idx]
+
+                if node_time_counted[j]:
+                    new_state += layout_transform_time
+                else:
+                    new_state += layout_transform_time + src_states[i]
+                self._full_states[i] = new_state
+
+        for i, node_counted in enumerate(node_time_counted):
+            if not node_counted:
+                self._global_counted_nodes_set.add(states_list[i][0])
+        self._full_states = self._full_states.reshape(aligned_shape)
+
+        # Remove dependency to reduce states
+        reduced_states = np.array(self._full_states)
+        reduced_states_transpose = [states_list[0][1]]
+        reduced_states_dep_list = []
+        self._dep = []
+        for i in range(len(reduced_states.shape)):
+            if i != states_list[0][1]:
+                reduced_states_transpose.append(i)
+                reduced_states_dep_list.append(aligned_node_list[i])
+        reduced_states = np.transpose(reduced_states, reduced_states_transpose)
+        shift = 0
+        for i, dep in enumerate(reduced_states_dep_list):
+            if dep not in self._global_dep_dict or len(self._global_dep_dict[dep]) == 1:
+                self._global_dep_dict.pop(dep, None)
+                reduced_states = np.amin(reduced_states, axis=i+1-shift)
+                shift += 1
+            else:
+                self._dep.append(dep)
+        self._states = reduced_states
+
+        # Update dependency
+        for dep in self._dep:
+            self._global_dep_dict[dep].remove(self._idx)
+            for child in self._global_out_nodes_dict[self._idx]:
+                self._global_dep_dict[dep].add(child)
+        if len(self._global_out_nodes_dict[self._idx]) > 1:
+            self._global_dep_dict[self._idx] = set()
+            for child in self._global_out_nodes_dict[self._idx]:
+                self._global_dep_dict[self._idx].add(child)
+
+    @property
+    def dep(self):
+        """Get dependency list."""
+        return self._dep
+
+    @property
+    def states(self):
+        """Get states."""
+        return self._states
+
+    @property
+    def full_states(self):
+        """Get complete states."""
+        return self._full_states
+
+    @property
+    def full_states_idx(self):
+        """Get node index of complete states."""
+        return self._full_states_idx
+
+    @staticmethod
+    def align_states(input_index_list, stage_dict, node_list):
+        """Align all input node states shapes to be the same and transpose/reshape properly.
+
+        This is used in creating multi_input operator states.
+
+        Parameters
+        ----------
+        input_index_list : list of int
+            List of input node index.
+
+        stage_dict : dict of int to Stage
+            Global dictionary of node index to stage.
+
+        node_list : list of dict
+            List of all nodes for current graph.
+
+        Returns
+        -------
+        states_list : list of tuple
+            List of aligned states.
+
+        aligned_node_list : list in int
+            List of node index for aligned states.
+        """
+        aligned_node_list = list(input_index_list)
+        states_list = []
+        for input_idx in input_index_list:
+            input_node_stage = stage_dict[input_idx]
+            for dep_idx in input_node_stage.dep:
+                if dep_idx not in aligned_node_list:
+                    aligned_node_list.append(dep_idx)
+        aligned_shape = tuple([len(node_list[idx]["record_candidates"])
+                               for idx in aligned_node_list])
+        for input_idx in input_index_list:
+            input_node_stage = stage_dict[input_idx]
+            input_node_shape_idx_list = [input_idx] + input_node_stage.dep
+            transpose_idx_list = []
+            reshape_list = []
+            major_axis = -1
+            for i, idx in enumerate(aligned_node_list):
+                if input_idx == idx:
+                    major_axis = i
+                if idx in input_node_shape_idx_list:
+                    transpose_idx_list.append(idx)
+                    reshape_list.append(aligned_shape[i])
+                else:
+                    reshape_list.append(1)
+            transpose_list = [input_node_shape_idx_list.index(idx) for idx in transpose_idx_list]
+            input_node_states = np.transpose(input_node_stage.states, tuple(transpose_list))
+            input_node_states = np.reshape(input_node_states, tuple(reshape_list))
+            input_node_states = np.broadcast_to(input_node_states, aligned_shape)
+            multiplier = 1
+            for i in range(major_axis + 1, len(aligned_shape)):
+                multiplier *= aligned_shape[i]
+            states_list.append((input_idx, major_axis, multiplier, input_node_states))
+        return states_list, aligned_node_list
diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py
new file mode 100644 (file)
index 0000000..11571f2
--- /dev/null
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
+"""Dynamic programming tuner."""
+import sys
+import numpy as np
+
+from .base_graph_tuner import BaseGraphTuner
+from .dynamic_programming_stage import DPStage
+from .utils import has_multiple_inputs, is_input_node
+
+if sys.version_info[0] == 3:
+    import queue
+else:
+    import Queue as queue
+
+class DPTuner(BaseGraphTuner):
+    """Tuner which uses dynamic programming to solve MDP problem.
+
+    Note: currently dynamic programming is used to solve this MDP problem. However,
+    this problem is intrinsically non-polynomial. DP can't apply for more complicated
+    models, such as networks with many element-wise sum operators. In this case, switch
+    to heuristic algorithm such as PBQP tuner.
+    """
+    def __init__(self, *args, **kwargs):
+        """Create a dynamic programming tuner.
+        """
+        super(DPTuner, self).__init__(*args, **kwargs)
+        self._num_states = self._max_num_states = None
+        self._stage_dict = {}
+        self._dep_dict = {}
+        self._counted_nodes_set = set()
+
+        self._global_data_dict = {
+            "dtype": self._dtype,
+            "counted_nodes_set": self._counted_nodes_set,
+            "stage_dict": self._stage_dict,
+            "in_nodes_dict": self._in_nodes_dict,
+            "out_nodes_dict": self._out_nodes_dict,
+            "dep_dict": self._dep_dict,
+            "node_list": self._node_list,
+            "input_shapes": self._input_shapes,
+            "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost
+        }
+
+    def _check_num_states(self, num_states):
+        """Track the number of states."""
+        self._num_states += num_states
+        if self._max_num_states is not None:
+            if self._num_states > self._max_num_states:
+                raise RuntimeError("Too many states detected while running dynamic "
+                                   "programming: got %d states but upper limit is %d." %
+                                   (self._num_states, self._max_num_states))
+
+    def _forward(self):
+        """Forward pass in DP to generate states for all stages.
+        """
+        self._logger.info("Start forward pass...")
+        for node_idx in sorted(self._in_nodes_dict.keys()):
+            stage = DPStage(idx=node_idx, target_ops=self._target_ops,
+                            **self._global_data_dict)
+            self._check_num_states(stage.full_states.size)
+            self._stage_dict[node_idx] = stage
+        self._logger.info("Finished forward pass.")
+
+    def _backward(self):
+        """Backward pass in DP to generate optimal solution.
+        """
+        self._logger.info("Start backward pass...")
+        input_names = self._input_shapes.keys()
+        optimal_record_dict = {}
+        # Pick optimal schedule for output nodes
+        output_idx_list = []
+        for key, val in self._out_nodes_dict.items():
+            if not val:
+                output_idx_list.append(key)
+        states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict,
+                                                              self._node_list)
+        num_states = states_list[0][3].size
+        self._check_num_states(num_states * len(output_idx_list))
+        aligned_node_shape = states_list[0][3].shape
+        min_time = 0
+        min_pos = -1
+        for states in states_list:
+            min_time += np.amax(states[3])
+        flatten_states_list = [current_states[3].flatten() for current_states in states_list]
+        for i in range(num_states):
+            current_time = 0
+            for j, current_states in enumerate(states_list):
+                current_time += flatten_states_list[j][i]
+            if min_time > current_time:
+                min_time = current_time
+                min_pos = i
+        for i, states in enumerate(states_list):
+            current_major_axis = states[1]
+            current_sch_idx = (min_pos % (states[2] *
+                                          aligned_node_shape[current_major_axis])) // states[2]
+            optimal_record_dict[aligned_node_list[i]] = current_sch_idx
+        # Pick optimal schedule for dependencies of output nodes
+        for i in range(len(states_list), len(aligned_node_list)):
+            multiplier = 1
+            for j in range(i + 1, len(aligned_node_list)):
+                multiplier *= aligned_node_shape[j]
+            optimal_record_dict[aligned_node_list[i]] = \
+                min_pos // multiplier % aligned_node_shape[i]
+
+        # Backward pass to get optimal schedules for other nodes
+        bfs_q = queue.Queue()
+        visited = set()
+        for out_idx in output_idx_list:
+            bfs_q.put(out_idx)
+        while not bfs_q.empty():
+            node_idx = bfs_q.get()
+            visited.add(node_idx)
+            if is_input_node(self._node_list[node_idx], input_names):
+                continue
+            optimal_sch_idx = optimal_record_dict[node_idx]
+            full_states = self._stage_dict[node_idx].full_states
+            if not has_multiple_inputs(self._node_list, node_idx, input_names):
+                input_idx = self._in_nodes_dict[node_idx][0]
+                if is_input_node(self._node_list[input_idx], input_names):
+                    continue
+                if input_idx not in visited:
+                    bfs_q.put(input_idx)
+                    if input_idx not in optimal_record_dict:
+                        dep_list = self._stage_dict[node_idx].dep
+                        dep_idx = tuple([optimal_record_dict[item] for item in dep_list])
+                        tmp = np.argmin(full_states, axis=1)
+                        optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx]
+                        optimal_record_dict[input_idx] = optimal_input_sch_idx
+            else:
+                input_idx_list = self._in_nodes_dict[node_idx]
+                optimal_record_dict[input_idx_list[0]] = optimal_sch_idx
+                full_states_idx = self._stage_dict[node_idx].full_states_idx
+                tmp = full_states[optimal_sch_idx]
+                new_states_idx, new_states_pos = [], []
+                visited_states_idx, visited_states_pos = [], []
+                for i in range(1, len(full_states_idx)):
+                    if full_states_idx[i] in optimal_record_dict:
+                        visited_states_idx.append(full_states_idx[i])
+                        visited_states_pos.append(i - 1)
+                    else:
+                        new_states_idx.append(full_states_idx[i])
+                        new_states_pos.append(i - 1)
+                if visited_states_idx:
+                    tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos))
+                    tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])]
+                min_pos = np.argmin(tmp)
+                multiplier = 1
+                for i in range(len(new_states_idx)):
+                    multiplier *= full_states.shape[new_states_pos[i] + 1]
+                for pos, idx in zip(new_states_pos, new_states_idx):
+                    multiplier //= full_states.shape[pos + 1]
+                    optimal_record_dict[idx] = min_pos // multiplier
+                    min_pos %= multiplier
+                for input_idx in input_idx_list:
+                    if input_idx not in visited:
+                        bfs_q.put(input_idx)
+
+        self._optimal_record_dict = optimal_record_dict
+        for node_idx, _ in self._in_nodes_dict.items():
+            if self._node_list[node_idx]["op"] not in self._target_ops:
+                continue
+        self._logger.info("Finished backward pass...")
+
+    def run(self, **kwargs):
+        """Run dynamic programming solver.
+        """
+        max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"]
+        self._num_states = 0
+        self._max_num_states = max_num_states
+        self._logger.info("Start to run dynamic programming algorithm...")
+        self._forward()
+        self._backward()
+        self._logger.info("Finished DPExecutor run.")
diff --git a/python/tvm/autotvm/graph_tuner/pbqp_tuner.py b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py
new file mode 100644 (file)
index 0000000..1d7089e
--- /dev/null
@@ -0,0 +1,288 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals
+"""Partitioned Boolean Quadratic Programming Tuner"""
+from ._base import INVALID_LAYOUT_TIME
+from .base_graph_tuner import BaseGraphTuner
+from .utils import is_input_node, has_multiple_inputs
+
+
+class PBQPTuner(BaseGraphTuner):
+    """An approximation method to deal with intractably
+    large size of graph tuning problem.
+
+    This graph coloring algorithm mainly comes from:
+
+    Lang Hames and Bernhard Scholz.
+    Nearly optimal register allocation with pbqp.JMLC 2006.
+    LNCS, vol.4228,pp. 346-361, 2016
+    """
+    def __init__(self, *args, **kwargs):
+        """Create a partitioned boolean quadratic programming tuner.
+        """
+        super(PBQPTuner, self).__init__(*args, **kwargs)
+
+        # Remove input nodes
+        input_names = self._input_shapes.keys()
+        for node_idx in self._out_nodes_dict:
+            if is_input_node(self._node_list[node_idx], input_names):
+                for out_node_idx in self._out_nodes_dict[node_idx]:
+                    self._in_nodes_dict[out_node_idx].remove(node_idx)
+
+        self._adj_dict = {}
+        for node_idx in self._in_nodes_dict:
+            self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + \
+                                       list(self._out_nodes_dict[node_idx])
+
+        self._record_cost_dict = {}
+        for key in self._in_nodes_dict:
+            self._record_cost_dict[key] = []
+            for record in self._node_list[key]["record_candidates"]:
+                self._record_cost_dict[key].append(record[1].costs[0])
+
+        self._max_degree = -1
+        self._node_degree_dict = {}
+        for node_idx in self._in_nodes_dict:
+            node_degree = self._get_degree(node_idx)
+            self._node_degree_dict[node_idx] = node_degree
+            self._max_degree = max(self._max_degree, node_degree)
+
+        self._stack = []
+        self._buckets = [[] for _ in range(self._max_degree + 2)]
+        for node_idx in sorted(self._in_nodes_dict):
+            node_degree = self._get_degree(node_idx)
+            self._buckets[node_degree].append(node_idx)
+
+        self._is_optimal = True
+
+    def _get_degree(self, node_idx):
+        """Get node degree.
+        """
+        return len(self._adj_dict[node_idx])
+
+    def _reorder_adj_nodes(self, node_idx):
+        """Update buckets list with current adjacency list.
+        """
+        for adj_node in self._adj_dict[node_idx]:
+            current_degree = self._get_degree(adj_node)
+            prev_degree = self._node_degree_dict[adj_node]
+            if prev_degree != current_degree:
+                self._buckets[prev_degree].remove(adj_node)
+                self._buckets[current_degree].insert(0, adj_node)
+                self._node_degree_dict[adj_node] = current_degree
+
+    def _remove_node(self, node_idx):
+        """Remove node from graph. Update adjacency list accordingly.
+        """
+        node_degree = self._get_degree(node_idx)
+        self._buckets[node_degree].remove(node_idx)
+        for adj_node in self._adj_dict[node_idx]:
+            self._adj_dict[adj_node].remove(node_idx)
+
+    def _insert_edge(self, node_x, node_y, adj_cost_matrix):
+        """Insert an edge between two nodes.
+        """
+        self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix
+        self._layout_transform_interlayer_cost[(node_y, node_x)] = []
+        for i in range(len(adj_cost_matrix[0])):
+            self._layout_transform_interlayer_cost[(node_y, node_x)].append([])
+            for cost_vec in adj_cost_matrix:
+                self._layout_transform_interlayer_cost[(node_y, node_x)][i] \
+                    .append(cost_vec[i])
+
+        self._adj_dict[node_x].append(node_y)
+        self._adj_dict[node_y].append(node_x)
+
+    def _backward_insert_node(self, node_idx):
+        """Reinsert node in backward pass.
+        """
+        for adj_node in self._adj_dict[node_idx]:
+            self._adj_dict[adj_node].append(node_idx)
+
+    def _RI_reduction(self, node_idx):
+        """Reduce nodes with degree 1.
+        """
+        adj_node = self._adj_dict[node_idx][0]
+        ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)]
+        for i, cost_vec in enumerate(ltf_matrix):
+            min_cost = INVALID_LAYOUT_TIME
+            for j, cost in enumerate(cost_vec):
+                min_cost = min(min_cost, cost + self._record_cost_dict[node_idx][j])
+            self._record_cost_dict[adj_node][i] += min_cost
+        self._remove_node(node_idx)
+        self._reorder_adj_nodes(node_idx)
+        self._stack.append(node_idx)
+
+    def _RII_reduction(self, node_idx):
+        """Reduce nodes with degree 2.
+        """
+        adj_node_x, adj_node_y = self._adj_dict[node_idx]
+        ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)]
+        ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)]
+        delta_matrix = [[] for _ in range(len(ltf_matrix_x))]
+        for i, cost_vec_x in enumerate(ltf_matrix_x):
+            for j, cost_vec_y in enumerate(ltf_matrix_y):
+                min_cost = INVALID_LAYOUT_TIME
+                for k in range(len(self._record_cost_dict[node_idx])):
+                    min_cost = min(min_cost, cost_vec_x[k] + cost_vec_y[k]
+                                   + self._record_cost_dict[node_idx][k])
+                delta_matrix[i].append(min_cost)
+
+        if adj_node_x == adj_node_y:
+            for i, delta_row in enumerate(delta_matrix):
+                self._record_cost_dict[adj_node_x][i] += delta_row[i]
+        elif adj_node_x in self._adj_dict[adj_node_y]:
+            for i, _ in enumerate(delta_matrix):
+                for j, delta in enumerate(delta_matrix[i]):
+                    self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] \
+                        += delta
+                    self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] \
+                        += delta
+        else:
+            self._insert_edge(adj_node_x, adj_node_y, delta_matrix)
+
+        self._remove_node(node_idx)
+        self._reorder_adj_nodes(node_idx)
+        self._stack.append(node_idx)
+
+    def _RN_reduction(self, node_idx):
+        """Reduce nodes with degree greater than 2.
+        """
+        min_cost = INVALID_LAYOUT_TIME
+        record_idx = -1
+
+        for i, record_cost in enumerate(self._record_cost_dict[node_idx]):
+            current_cost = record_cost
+            for adj_node in self._adj_dict[node_idx]:
+                ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
+                adj_record_cost = list(self._record_cost_dict[adj_node])
+                for j, ltf_cost in enumerate(ltf_matrix[i]):
+                    adj_record_cost[j] += ltf_cost
+                current_cost += min(adj_record_cost)
+            if current_cost < min_cost:
+                min_cost = current_cost
+                record_idx = i
+
+        if record_idx < 0:
+            raise RuntimeError("Can't find a soltuion for node %d when "
+                               "applying RN reduction" % node_idx)
+        self._optimal_record_dict[node_idx] = record_idx
+        self._is_optimal = False
+
+        for adj_node in self._adj_dict[node_idx]:
+            ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
+            for i, ltf_cost in enumerate(ltf_matrix[record_idx]):
+                self._record_cost_dict[adj_node][i] += ltf_cost
+
+        self._remove_node(node_idx)
+        self._reorder_adj_nodes(node_idx)
+        self._stack.append(node_idx)
+
+    def _forward(self):
+        """Forward pass in PBQP to reduce nodes.
+        """
+        while True:
+            if self._buckets[1]:
+                node_idx = self._buckets[1][0]
+                self._RI_reduction(node_idx)
+            elif self._max_degree >= 2 and self._buckets[2]:
+                node_idx = self._buckets[2][0]
+                self._RII_reduction(node_idx)
+            elif self._max_degree >= 3:
+                max_degree_node = -1
+                for i in range(self._max_degree, 2, -1):
+                    if self._buckets[i]:
+                        max_degree_node = self._buckets[i][0]
+                        self._RN_reduction(max_degree_node)
+                        break
+                if max_degree_node < 0:
+                    break
+            else:
+                break
+
+    def _backward(self):
+        """Backward pass in PBQP to generate optimal solution.
+        """
+        # Solve nodes left in the forward graph
+        for node_idx in self._buckets[0]:
+            record_costs = self._record_cost_dict[node_idx]
+            min_cost = min(record_costs)
+            self._optimal_record_dict[node_idx] = record_costs.index(min_cost)
+
+        # Solve nodes with one or two degrees
+        for node_idx in reversed(self._stack):
+            self._backward_insert_node(node_idx)
+            if node_idx not in self._optimal_record_dict:
+                record_costs = list(self._record_cost_dict[node_idx])
+                for adj_node in self._adj_dict[node_idx]:
+                    adj_optimal_idx = self._optimal_record_dict[adj_node]
+                    for i, _ in enumerate(record_costs):
+                        record_costs[i] += \
+                            self._layout_transform_interlayer_cost \
+                                [(node_idx, adj_node)][i][adj_optimal_idx]
+                min_cost = min(record_costs)
+                self._optimal_record_dict[node_idx] = record_costs.index(min_cost)
+
+    def run(self, **kwargs):
+        """Run partitioned boolean quadratic programming tuner.
+        """
+        self._logger.info("Start to run PBQP algorithm...")
+        # Define virtual record lists and layout transformaton matrices
+        # for multi-input nodes.
+        input_names = self._input_shapes.keys()
+        temp = {}
+        for key, val in self._in_nodes_dict.items():
+            target_input_idx = -1
+            target_input_pos = -1
+            if has_multiple_inputs(self._node_list, key, input_names):
+                for i, item in enumerate(val):
+                    if not is_input_node(self._node_list[item], input_names):
+                        target_input_idx = item
+                        target_input_pos = i
+                        break
+                temp[(target_input_idx, key)] = []
+                record_candidates = self._node_list[target_input_idx]["record_candidates"]
+                for j in range(len(record_candidates)):
+                    temp[(target_input_idx, key)].append([])
+                    for k in range(len(record_candidates)):
+                        temp[(target_input_idx, key)][j].append(0 if j == k
+                                                                else INVALID_LAYOUT_TIME)
+
+                for j in range(target_input_pos + 1, len(val)):
+                    input_idx = val[j]
+                    if is_input_node(self._node_list[input_idx], input_names):
+                        continue
+                    temp[(input_idx, key)] = \
+                        self._layout_transform_interlayer_cost[(input_idx, target_input_idx)]
+        self._layout_transform_interlayer_cost.update(temp)
+
+        # Create reverse layout transformation matrices
+        temp = {}
+        for idx_pair, ltf_matrix in self._layout_transform_interlayer_cost.items():
+            reverse_key = (idx_pair[1], idx_pair[0])
+            reverse_matrix = [[] for _ in range(len(ltf_matrix[0]))]
+            for i, _ in enumerate(ltf_matrix):
+                for j, ltf in enumerate(ltf_matrix[i]):
+                    reverse_matrix[j].append(ltf)
+            temp[reverse_key] = reverse_matrix
+        self._layout_transform_interlayer_cost.update(temp)
+
+        self._forward()
+        self._backward()
+        is_optimal = "optimal" if self._is_optimal else "sub-optimal"
+        msg = "Finished PBQPExecutor run. Got %s solution." % is_optimal
+        self._logger.info(msg)
diff --git a/python/tvm/autotvm/graph_tuner/utils/__init__.py b/python/tvm/autotvm/graph_tuner/utils/__init__.py
new file mode 100644 (file)
index 0000000..8b36e75
--- /dev/null
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import
+"""Graph tuner utility functions"""
+from __future__ import absolute_import
+
+from . import traverse_graph
+from . import utils
+
+from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \
+    get_out_nodes
+from .utils import has_multiple_inputs, is_input_node, bind_inputs
diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
new file mode 100644 (file)
index 0000000..08f1017
--- /dev/null
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access
+"""API for graph traversing."""
+import threading
+
+import topi
+
+from tvm import relay, autotvm
+from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.ty import TupleType, TensorType
+from tvm.autotvm.task import TaskExtractEnv
+
+from .._base import RULE_OUT_NODE_NAMES
+from .utils import has_multiple_inputs, is_input_node
+
+
+# Setup relay op base name -> topi compute functions
+# NOTE: To add more ops, change the following dictionary.
+OP2COMPUTE = {
+    "conv2d" : [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
+}
+
+
+def expr2graph(expr, target_ops, node_dict, node_list):
+    """Convert relay expr to graph data structure
+    and fetch workloads of target operators.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr.Function
+        Input relay function expression.
+
+    target_ops: List of str
+        List of target relay base op name
+
+    node_dict : dictionary from tvm.relay.Expr to int
+        Dictionary to record node index
+
+    node_list : list of dictionary
+        List of nodes which contains all expr in the input relay function.
+        Each node will be stored as a dictionary in the format of
+        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
+         "name": str, "workloads": [tuple], "topi_op": [function]}
+    """
+    env = TaskExtractEnv.get(allow_duplicate=True)
+    topi_funcs = []
+    for op_name in target_ops:
+        if op_name not in OP2COMPUTE:
+            raise RuntimeError("Not supported relay op in graph tuner: %s"
+                               % op_name)
+        topi_funcs += OP2COMPUTE[op_name]
+    env.reset(topi_funcs)
+    _expr2graph_impl(expr, target_ops, node_dict, node_list)
+    task_pos = 0
+    for node_entry in node_list:
+        if node_entry["op"] in target_ops:
+            task_name, args = env.task_collection[task_pos]
+            task = autotvm.task.create(task_name, args,
+                                       target="llvm",
+                                       target_host=None,
+                                       template_key='direct')
+            node_entry["workloads"] = [task.workload]
+            node_entry["topi_op"] = [task_name]
+            task_pos += 1
+
+
+def _expr2graph_impl(expr, target_ops, node_dict, node_list):
+    """Implementation to convert relay expr to graph data structure
+    """
+    def _traverse_expr(node):
+        if node in node_dict:
+            return
+        node_index = len(node_list)
+        node_entry = {"node": node, "inputs": [], "types": [],
+                      "op": "null", "name": None}
+
+        if isinstance(node, Call):
+            op_name = node.op.name.split(".")[-1]
+            node_entry["op"] = op_name
+            for arg in node.args:
+                in_node_idx = node_dict[arg]
+                if isinstance(arg, (Tuple, TupleGetItem)):
+                    node_entry["inputs"] += node_list[in_node_idx]["inputs"]
+                else:
+                    node_entry["inputs"].append([in_node_idx, 0, 0])
+            infer_out = relay.ir_pass.infer_type(node)
+            out_type = infer_out._checked_type_
+            if isinstance(out_type, TensorType):
+                node_entry["types"].append(out_type)
+            elif isinstance(out_type, TupleType):
+                for tupe_type in out_type.fields:
+                    node_entry["types"].append(tupe_type)
+            else:
+                raise RuntimeError("Unsupported output type %s in operator %s"
+                                   % (type(out_type), op_name))
+
+            # Utilize tracing target to fetch workload with topo-order.
+            # Since we only need workload, dummy target can be used to
+            # create task.
+            if op_name in target_ops:
+                params = []
+                for i, input_idx in enumerate(node_entry["inputs"]):
+                    input_node_entry = node_list[input_idx[0]]
+                    input_type = input_node_entry["types"][input_idx[1]]
+                    if not isinstance(input_node_entry["node"], (Var, Call)):
+                        raise RuntimeError("Graph tuner can only tune target "
+                                           "operators with input node of type "
+                                           "relay.expr.Var or relay.expr.Call. Now "
+                                           "find a target op %s with input type %s"
+                                           % (op_name, str(type(input_node_entry["node"]))))
+                    free_var = relay.Var("var_%d" % i, input_type)
+                    params.append(free_var)
+                call = relay.Call(node.op, params, node.attrs)
+                func = relay.Function(params, call)
+                relay.backend.compile_engine.get().clear()
+                build_thread = threading.Thread(target=relay.build,
+                                                args=(func,
+                                                      "llvm -device=tracing",
+                                                      None,
+                                                      None))
+                build_thread.start()
+                build_thread.join()
+        elif isinstance(node, Var):
+            node_entry["name"] = node.name_hint
+            node_entry["types"] = [node.type_annotation]
+        elif isinstance(node, Function):
+            # Ignore root node since it equals to input function expression
+            if node != expr:
+                _expr2graph_impl(node, target_ops, node_dict, node_list)
+            return
+        elif isinstance(node, TupleGetItem):
+            node_entry["op"] = "TupleGetItem"
+            in_node_idx = node_dict[node.tuple_value]
+            node_entry["inputs"].append([in_node_idx, node.index, 0])
+        elif isinstance(node, Tuple):
+            node_entry["op"] = "Tuple"
+            for tuple_item in node:
+                in_node_idx = node_dict[tuple_item]
+                if isinstance(tuple_item, TupleGetItem):
+                    node_entry["inputs"] += node_list[in_node_idx]["inputs"]
+                elif isinstance(tuple_item, Tuple):
+                    raise RuntimeError("Graph tuner doesn't support nested tuple.")
+                else:
+                    node_entry["inputs"].append([in_node_idx, 0, 0])
+        elif isinstance(node, Constant):
+            pass
+        elif isinstance(node, relay.op.op.Op):
+            return
+        else:
+            raise RuntimeError("Not supported relay node type in graph tuning: %s"
+                               % str(type(node)))
+        node_dict[node] = node_index
+        node_list.append(node_entry)
+
+    relay.ir_pass.post_order_visit(expr, _traverse_expr)
+
+
+def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
+    """Given a node_list in relay function and a node index, return the
+    closest ancestor which has op_name as operator name or is multi_input operator.
+
+    If node has multiple inputs, multiple ancestor nodes will be returned.
+
+    Parameters
+    ----------
+    node_list : list of dict of str to object
+        List of all nodes in a graph.
+
+    visited_dict : dict of int to int
+        Nodes and corresponding ancestors which have been visited.
+
+    target_ops: List of str
+        List of target relay base op name
+
+    node_idx : int
+        Input node index.
+
+    input_names : list of str
+        Names of graph input nodes.
+
+    Returns
+    -------
+    out : list of int
+        List of ancestor node index.
+    """
+    if node_idx in visited_dict:
+        return visited_dict[node_idx]
+    if is_input_node(node_list[node_idx], input_names):
+        return [node_idx]
+    node = node_list[node_idx]
+    # Rule out injective operators
+    is_rule_out = False
+    for item_idx in node["inputs"]:
+        item = node_list[item_idx[0]]
+        if item["op"] in RULE_OUT_NODE_NAMES:
+            is_rule_out = True
+            break
+    if is_rule_out:
+        visited_dict[node_idx] = []
+        return []
+
+    node_direct_ancestor = []
+    for item_idx in node["inputs"]:
+        item = node_list[item_idx[0]]
+        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
+        if item["op"] in target_ops or is_multiple_inputs:
+            node_direct_ancestor.append(item_idx[0])
+        else:
+            tmp = get_direct_ancestor(node_list, visited_dict, target_ops,
+                                      item_idx[0], input_names)
+            for tmp_item in tmp:
+                node_direct_ancestor.append(tmp_item)
+    if not has_multiple_inputs(node_list, node_idx, input_names) and node_direct_ancestor:
+        node_direct_ancestor = [node_direct_ancestor[0]]
+    visited_dict[node_idx] = node_direct_ancestor
+    return node_direct_ancestor
+
+
+def get_in_nodes(node_list, target_ops, input_names):
+    """Create a dictionary mapping from op_name nodes or multi_input
+    nodes to closest input ancestors.
+
+    Parameters
+    ----------
+    node_list : list of dict of str to object
+        List of all nodes in a graph.
+
+    target_ops: List of str
+        List of target relay op
+
+    input_names : list of str
+        Names of graph input nodes.
+
+    Returns
+    -------
+    out : dict of int to list of int
+        Dictionary maps node index to closest input ancestors.
+    """
+
+    visited_dict = {}
+    in_node_dict = {}
+    for i, node in enumerate(node_list):
+        if node["op"] in RULE_OUT_NODE_NAMES:
+            continue
+        get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
+    for key, val in visited_dict.items():
+        node = node_list[key]
+        is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
+        if node["op"] in target_ops or is_multiple_inputs:
+            in_node_dict[key] = val
+
+    # Remove empty nodes
+    has_empty_node = True
+    out_node_dict = get_out_nodes(in_node_dict)
+    while has_empty_node:
+        empty_nodes = []
+        for key, val in in_node_dict.items():
+            if not val:
+                empty_nodes.append(key)
+        if empty_nodes:
+            has_empty_node = True
+            for node in empty_nodes:
+                del in_node_dict[node]
+                if node in out_node_dict:
+                    for out_node in out_node_dict[node]:
+                        in_node_dict[out_node].remove(node)
+        else:
+            has_empty_node = False
+
+    return in_node_dict
+
+
+def get_out_nodes(in_node_dict):
+    """Create output dictionary from input dictionary.
+
+    Parameters
+    ----------
+    in_node_dict : dict of int to list of int
+        Dictionary maps node index to closest input ancestors.
+        It can be created with get_in_nodes.
+
+    Returns
+    -------
+    out : dict of int to list of int
+        Dictionary maps node index to closest output nodes.
+    """
+    out_node_dict = {}
+    for key in in_node_dict:
+        out_node_dict[key] = []
+    for key, val in in_node_dict.items():
+        for item in val:
+            if item in out_node_dict:
+                out_node_dict[item].append(key)
+            else:
+                out_node_dict[item] = [key]
+
+    return out_node_dict
diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py
new file mode 100644 (file)
index 0000000..6151734
--- /dev/null
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=eval-used,invalid-name,too-many-arguments
+"""Utility functions"""
+from tvm import relay
+
+
+def has_multiple_inputs(node_list, node_idx, input_names):
+    """Check whether a node has multiple input nodes
+    except variable nodes.
+
+    Parameters
+    ----------
+    node_list : list of dict of str to object
+        List of all nodes in a graph.
+
+    node_idx : int
+        Node index to be checked.
+
+    input_names : list of str
+        List of input names of graph.
+
+    Returns
+    -------
+    out : bool
+        Whether the specified node has multiple input nodes
+    """
+    num_inputs = 0
+    node = node_list[node_idx]
+    for in_idx in node["inputs"]:
+        in_idx = in_idx[0]
+        in_node = node_list[in_idx]
+        # Exclude parameter nodes
+        if in_node["op"] != "null" or is_input_node(in_node,
+                                                    input_names):
+            num_inputs += 1
+    return num_inputs > 1
+
+
+def is_input_node(node_entry, input_names):
+    """Whether a node is an input node.
+
+    Parameters
+    ----------
+    node_entry : dict
+        Node entry.
+
+    input_names : list of str
+        List of input names of graph.
+
+    Returns
+    -------
+    out : bool
+        whether node is a input node.
+    """
+    return "name" in node_entry and node_entry["name"] in input_names
+
+
+def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
+    """Bind input variables of a relay function expression
+    to new shapes and/or dtypes.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr.Function
+        Input relay function expression.
+
+    input_shapes : dict of str to tuple of int, optional
+        Input shapes.
+
+    input_dtypes : str or dict of str to str, optional
+        Input dtypes.
+
+    Returns
+    -------
+    out : tvm.relay.Expr.Function
+        Bind relay function expression.
+    """
+    if input_shapes is None:
+        return expr
+    if isinstance(input_dtypes, str):
+        input_dtypes = {key : input_dtypes for key in input_shapes.keys()}
+
+    updated_input_dict = {}
+    for input_name in input_shapes.keys():
+        updated_input = relay.var(input_name, shape=input_shapes[input_name],
+                                  dtype=input_dtypes[input_name])
+        updated_input_dict[input_name] = updated_input
+
+    rebind_dict = {}
+    for var in expr.params:
+        if var.name_hint in updated_input_dict:
+            rebind_dict[var] = updated_input_dict[var.name_hint]
+    updated_expr = relay.expr.bind(expr, rebind_dict)
+
+    return relay.ir_pass.infer_type(updated_expr)
index ff50a4e..0a0e6e1 100644 (file)
@@ -28,6 +28,7 @@ from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
     FallbackContext, clear_fallback_cache, ApplyGraphBest
 
-from .topi_integration import register_topi_compute, register_topi_schedule
+from .topi_integration import register_topi_compute, register_topi_schedule, \
+    TaskExtractEnv
 from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
 from .relay_integration import extract_from_program, extract_from_multiple_program
index 3c98376..ef0cb56 100644 (file)
@@ -74,7 +74,7 @@ class TaskExtractEnv:
     """Global environment for extracting tuning tasks from nnvm graph"""
     current = None
 
-    def __init__(self):
+    def __init__(self, allow_duplicate=False):
         import topi
 
         # topi compute -> autotvm task name
@@ -106,6 +106,7 @@ class TaskExtractEnv:
             topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
         }
 
+        self.allow_duplicate = allow_duplicate
         self._register_tracing()
         self._register_topi_task()
         self.task_collection = []
@@ -123,10 +124,9 @@ class TaskExtractEnv:
                     assert not kwargs, "Do not support extracting tuning tasks when" \
                                        "kwargs is used in TOPI function call." \
                                        "Please modify it to use only positional args."
-
                     if compute_func in self.wanted_topi_funcs:  # record this call
                         key = (self.topi_to_task[compute_func], serialize_args(args))
-                        if key not in self.task_collection:
+                        if self.allow_duplicate or key not in self.task_collection:
                             self.task_collection.append(key)
                     return compute_func.fdefault(*args)
             _local_scope(topi_compute)
@@ -262,16 +262,25 @@ class TaskExtractEnv:
         return self.task_collection
 
     @staticmethod
-    def get():
+    def get(allow_duplicate=False):
         """Get the single instance of TaskExtractEnv
 
+        Parameters
+        ----------
+        allow_duplicate : boolean
+            Whether to fetch all workloads in the network,
+            even though some of them are the same. This is
+            useful for graph tuning.
+
         Returns
         -------
         env: TaskExtractEnv
             The single instance of TaskExtractEnv
         """
         if not TaskExtractEnv.current:
-            TaskExtractEnv.current = TaskExtractEnv()
+            TaskExtractEnv.current = TaskExtractEnv(allow_duplicate)
+        else:
+            TaskExtractEnv.current.allow_duplicate = allow_duplicate
         return TaskExtractEnv.current
 
 
diff --git a/tests/python/unittest/test_graph_tuner_core.py b/tests/python/unittest/test_graph_tuner_core.py
new file mode 100644 (file)
index 0000000..240da7f
--- /dev/null
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# NOTE: We name this test file to start with test_graph_tuner
+# to make it execute after zero_rank tensor test cases. This
+# helps avoid topi arithmetic operator overloading issue:
+# https://github.com/dmlc/tvm/issues/3240.
+# TODO: restore the file name after this issue is resolved.
+import os
+import copy
+import numpy as np
+import tvm
+import tvm.relay.testing
+
+from tvm import autotvm
+from tvm import relay
+from tvm.autotvm.task import ConfigEntity
+from tvm.autotvm.measure import MeasureResult, MeasureInput
+from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
+from test_graph_tuner_utils import create_workload
+
+
+def _create_data(target, dshape, dtype, layout):
+    data = relay.var("data", shape=dshape, dtype=dtype)
+    w0 = relay.var("w0_weight")
+    conv0 = relay.nn.conv2d(data, w0, channels=16, kernel_size=(3, 3), padding=(1, 1))
+    w1 = relay.var("w1_weight")
+    conv1 = relay.nn.conv2d(conv0, w1, channels=32, kernel_size=(1, 1))
+    w2 = relay.var("w2_weight")
+    conv2 = relay.nn.conv2d(conv1, w2, channels=32, kernel_size=(3, 3), padding=(1, 1))
+    out = relay.add(conv1, conv2)
+    net = relay.Function(relay.ir_pass.free_vars(out), out)
+    net, params = relay.testing.create_workload(net)
+    tasks = autotvm.task.extract_from_program(net,
+                                              target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    wkl_list = [
+        create_workload((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+        create_workload((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0), (1, 1), layout, layout, dtype, dtype),
+        create_workload((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
+    costs = [0.04, 0.012, 0.03]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [3, 1]],
+                      ["tile_oc", "sp", [4, 4]],
+                      ["tile_ow", "sp", [4, 2]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [2, 8]],
+                      ["tile_oc", "sp", [1, 32]],
+                      ["tile_oh", "ot", 1],
+                      ["tile_ow", "sp", [4, 2]]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [8, 4]],
+                      ["tile_oc", "sp", [4, 8]],
+                      ["tile_ow", "sp", [2, 4]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+
+    records = []
+    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
+        task.workload = wkl
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    ltf_records = []
+    ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
+    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
+    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_task = copy.deepcopy(tasks[0])
+    ltf_task.workload = ltf_wkl
+    ms_input = MeasureInput(target=target, task=ltf_task, config=None)
+    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ltf_records.append((ms_input, ms_output))
+
+    ltf_keys = []
+    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
+    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
+    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_keys.append(ltf_wkl)
+    ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
+    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
+    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_keys.append(ltf_wkl)
+    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
+    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
+    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
+    ltf_keys.append(ltf_wkl)
+
+    return net, records, ltf_records, ltf_keys, tasks
+
+
+def test_graph_tuner_layout_transform():
+    log_file = "%s/test_tuner.log" % (os.getcwd())
+    target = "llvm"
+    dshape = (1, 3, 8, 8)
+    dtype = "float32"
+    layout = "NCHW"
+    target_ops = [relay.nn.conv2d]
+
+    g, records, ltf_records, ltf_keys, _ = _create_data(target, dshape, dtype, layout)
+    executor = DPTuner(g, {"data": dshape}, records, target_ops, target=target, log_file=log_file)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    out = executor._layout_transform_perf_records
+
+    num_flops = 0
+    total_time = 0
+    for record in ltf_records:
+        ltf_wkl = record[0].task.workload
+        input_shape = ltf_wkl[1][1]
+        flops = np.prod(input_shape)
+        num_flops += flops
+        total_time += record[1].costs[0]
+    avg_time = total_time / num_flops
+
+    for ltf_workload in out:
+        input_shape = ltf_workload[1][1]
+        flops = 1
+        for i in input_shape:
+            flops *= i
+        expected_time = flops * avg_time
+        out_time = out[ltf_workload][1].costs[0]
+        assert expected_time == out_time, "Inferred layout transformation time mismatch for %s: " \
+                                          "expecting %f but got %f" % (str(ltf_workload), expected_time,
+                                                                       out_time)
+
+
+def test_DPTuner_run():
+    log_file = "%s/test_tuner.log" % (os.getcwd())
+    target = "llvm"
+    dtype = "float32"
+    layout = "NCHW"
+    dshape = (1, 3, 8, 8)
+    target_ops = [relay.nn.conv2d]
+
+    g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
+    costs = [0.02, 0.02, 0.045]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 3]],
+                      ["tile_oc", "sp", [2, 8]],
+                      ["tile_ow", "sp", [4, 2]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [4, 4]],
+                      ["tile_oc", "sp", [2, 16]],
+                      ["tile_oh", "ot", 1],
+                      ["tile_ow", "sp", [4, 2]]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [16, 2]],
+                      ["tile_oc", "sp", [8, 4]],
+                      ["tile_ow", "sp", [2, 4]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    for cost, config, task in zip(costs, config_list, tasks):
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    executor = DPTuner(g, {"data": dshape}, records, target_ops, target, log_file=log_file)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+    assert os.path.isfile(log_file), "No log file with name %s exists." % log_file
+
+
+def test_PBQPTuner_run():
+    target = "llvm"
+    dtype = "float32"
+    layout = "NCHW"
+    dshape = (1, 3, 8, 8)
+    target_ops = [relay.nn.conv2d]
+
+    g, records, ltf_records, ltf_keys, tasks = _create_data(target, dshape, dtype, layout)
+    costs = [0.02, 0.02, 0.045]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 3]],
+                      ["tile_oc", "sp", [2, 8]],
+                      ["tile_ow", "sp", [4, 2]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [4, 4]],
+                      ["tile_oc", "sp", [2, 16]],
+                      ["tile_oh", "ot", 1],
+                      ["tile_ow", "sp", [4, 2]]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [16, 2]],
+                      ["tile_oc", "sp", [8, 4]],
+                      ["tile_ow", "sp", [2, 4]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    for cost, config, task in zip(costs, config_list, tasks):
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    executor = PBQPTuner(g, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                           % (str(expected_out), str(out))
+
+
+if __name__=="__main__":
+    test_graph_tuner_layout_transform()
+    test_DPTuner_run()
+    test_PBQPTuner_run()
diff --git a/tests/python/unittest/test_graph_tuner_utils.py b/tests/python/unittest/test_graph_tuner_utils.py
new file mode 100644 (file)
index 0000000..0847166
--- /dev/null
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# NOTE: We name this test file to start with test_graph_tuner
+# to make it execute after zero_rank tensor test cases. This
+# helps avoid topi arithmetic operator overloading issue:
+# https://github.com/dmlc/tvm/issues/3240
+# TODO: restore the file name after this issue is resolved.
+import tvm
+
+from tvm import autotvm, relay
+from tvm.relay.testing import resnet
+from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
+    get_out_nodes, expr2graph, bind_inputs
+from tvm.relay.expr import Call, TupleGetItem, Tuple
+from topi.nn.conv2d import conv2d
+
+
+def create_workload(dshape, kshape, strides,
+                    padding, dilation, layout,
+                    out_layout, dtype, out_dtype):
+    data = tvm.placeholder(dshape, dtype=dtype)
+    kernel = tvm.placeholder(kshape, dtype=dtype)
+    return autotvm.task.args_to_workload([data, kernel, strides, padding, dilation, layout,
+                                          out_dtype], conv2d)
+
+
+def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
+    out = has_multiple_inputs(node_list, node_idx, input_names)
+    assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
+                                   % (node_list[node_idx]["op"], str(expected_result), str(out))
+
+
+def test_has_multiple_inputs():
+    data = relay.var("data")
+    out1 = data * relay.expr.const(3.0)
+    w0 = relay.var("w0")
+    out2 = relay.nn.conv2d(data, w0)
+    out = relay.add(out1, out2)
+    net = relay.Function(relay.ir_pass.free_vars(out), out)
+    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1)})
+    target_ops = ["conv2d"]
+    node_list = []
+    node_dict = {}
+    expr2graph(net, target_ops, node_dict, node_list)
+    input_names = ["data"]
+    verify_has_multiple_inputs(node_list, 2, input_names, False)
+    verify_has_multiple_inputs(node_list, 4, input_names, False)
+    verify_has_multiple_inputs(node_list, 5, input_names, True)
+
+
+def test_expr2graph():
+    net, _ = resnet.get_workload(num_layers=50, batch_size=1)
+    node_dict = {}
+    node_list = []
+    target_ops = ["conv2d"]
+    op_name_list = []
+    def _count_node(node):
+        if not isinstance(node, relay.op.op.Op,):
+            return
+        if isinstance(node, Call):
+            op_name_list.append(node.op.name.split(".")[-1])
+        elif isinstance(node, TupleGetItem):
+            op_name_list.append("TupleGetItem")
+        elif isinstance(node, Tuple):
+            op_name_list.append("Tuple")
+        else:
+            op_name_list.append("null")
+    relay.ir_pass.post_order_visit(net, _count_node)
+
+    expr2graph(net, target_ops, node_dict, node_list)
+    for i, item in enumerate(zip(op_name_list, node_list)):
+        op_name, node = item
+        assert op_name == node["op"], "%dth Node operator mismatch: expecting %s but got %s" \
+                                      % (i, str(op_name), str(node["op"]))
+
+
+def test_get_direct_ancestor():
+    data = relay.var("data")
+    w0 = relay.var("w0")
+    out1 = relay.nn.conv2d(data, w0)
+    out2 = relay.add(out1, data * relay.expr.const(5.0))
+    out3 = out2 + relay.expr.const(2.5)
+    w1 = relay.var("w1")
+    out = relay.nn.conv2d(out3, w1)
+    net = relay.Function(relay.ir_pass.free_vars(out), out)
+    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
+    target_ops = ["conv2d"]
+    node_list = []
+    node_dict = {}
+    expr2graph(net, target_ops, node_dict, node_list)
+    visited_dict = {}
+    input_names = ["data"]
+    out = get_direct_ancestor(node_list, visited_dict, target_ops, 5, input_names)
+    assert out == [2, 0], "Output mismatch: expecting [2, 0] but got %s." % str(out)
+
+
+def test_get_in_nodes():
+    data = relay.var("data")
+    w0 = relay.var("w0")
+    out1 = relay.nn.conv2d(data, w0)
+    out2 = relay.add(out1, data)
+    out3 = out2 + relay.expr.const(2.5)
+    w1 = relay.var("w1")
+    out = relay.nn.conv2d(out3, w1)
+    net = relay.Function(relay.ir_pass.free_vars(out), out)
+    net = bind_inputs(net, {"data": (1, 16, 224, 224), "w0": (16, 16, 1, 1), "w1": (16, 16, 1, 1)})
+    target_ops = ["conv2d"]
+    input_names = ["data"]
+    node_list = []
+    node_dict = {}
+    expr2graph(net, target_ops, node_dict, node_list)
+    out = get_in_nodes(node_list, target_ops, input_names)
+    expected_out = {7: [3], 3: [2, 0], 2: [0]}
+    diff_set = set(out) ^ set(expected_out)
+    if len(diff_set) != 0:
+        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
+
+
+def test_get_out_nodes():
+    in_nodes_dict = {8: [4], 4: [3, 0], 3: [0]}
+    expected_out = {0: [3, 4], 3: [4], 4: [8], 8: []}
+    out = get_out_nodes(in_nodes_dict)
+    diff_set = set(out) ^ set(expected_out)
+    if len(diff_set) != 0:
+        raise RuntimeError("Output mismatch: expecting %s but got %s." % (str(expected_out), str(out)))
+
+
+
+if __name__ == "__main__":
+    test_has_multiple_inputs()
+    test_expr2graph()
+    test_get_direct_ancestor()
+    test_get_in_nodes()
+    test_get_out_nodes()
index 83e0274..57c1d20 100644 (file)
@@ -94,6 +94,26 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F):
     # not to change by default
     return None
 
+@tvm.target.generic_func
+def conv2d_infer_layout(workload, cfg):
+    """Infer input/output shapes and layouts from a workload and cfg.
+
+    Parameters
+    ----------
+    workload : tuple
+        conv2d workload
+
+    cfg : tuple
+        tvm.autotvm config
+
+    Returns
+    -------
+    Output : [tuple of tuple and str, tuple of tuple and str]
+        Input shapes and layouts, and output shapes and layouts
+    """
+    raise ValueError("missing register for topi.nn.conv2d_infer_layout")
+
+
 
 def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
     """ Get the workload structure. """
index 460f4fe..e703bec 100644 (file)
@@ -336,3 +336,22 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
         5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
     """
     raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
+
+@tvm.target.generic_func
+def depthwise_conv2d_infer_layout(workload, cfg):
+    """Infer input/output shapes and layouts from a workload and cfg.
+
+    Parameters
+    ----------
+    workload : tuple
+        conv2d workload
+
+    cfg : tuple
+        tvm.autotvm config
+
+    Returns
+    -------
+    Output : [tuple of tuple and str, tuple of tuple and str]
+        Input shapes and layouts, and output shapes and layouts
+    """
+    raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout")
index de18abd..d0894ad 100644 (file)
@@ -28,7 +28,7 @@ from .. import generic, tag
 from .. import nn
 from ..util import get_const_tuple
 from ..nn.conv2d import conv2d, conv2d_NCHWc, \
-    conv2d_alter_layout, _get_workload as _get_conv2d_workload
+    conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
 from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
 from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
 from ..nn.pad import pad
@@ -475,6 +475,21 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
         return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
 
 
+@conv2d_infer_layout.register("cpu")
+def _conv2d_infer_layout(workload, cfg):
+    _, data, kernel, strides, padding, dilation, layout, dtype = workload
+    batch_size, in_channel, in_height, in_width = data[:-1]
+    out_channel, _, k_height, k_width = kernel[:-1]
+    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
+    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
+    in_layout = "NCHW%dc" % tile_ic
+    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
+    out_layout = "NCHW%dc" % tile_oc
+    return ((in_shape, in_layout),), ((out_shape, out_layout),)
+
+
 @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
 def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                             padding, dilation, layout, out_layout, out_dtype):
index f570aaf..6ea11f2 100644 (file)
@@ -25,7 +25,8 @@ from .. import generic, tag
 from ..nn.pad import pad
 from ..util import get_const_tuple
 from ..nn.util import get_pad_tuple
-from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload
+from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
+    depthwise_conv2d_infer_layout
 
 from .util import get_fp32_len
 
@@ -206,7 +207,7 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
     # change shape with the value in config
     ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
     new_data_shape = (batch, in_channel // ic_bn, height, width, ic_bn)
-    new_kernel_shape = (out_channel // oc_bn, kh, kw, oc_bn)
+    new_kernel_shape = (out_channel // oc_bn, 1, kh, kw, 1, oc_bn)
     new_data = tvm.placeholder(new_data_shape, data.dtype)
     new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)
 
@@ -217,3 +218,18 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
                                     data_layout, out_layout, dtype)
     s = schedule_depthwise_conv2d_NCHWc(cfg, [C])
     return s, [new_data, new_kernel, C]
+
+@depthwise_conv2d_infer_layout.register("cpu")
+def _depthwise_conv2d_infer_layout(workload, cfg):
+    _, data, kernel, strides, padding, dilation, dtype = workload
+    batch_size, in_channel, in_height, in_width = data[:-1]
+    filter_channel, channel_multiplier, k_height, k_width = kernel[:-1]
+    out_channel = filter_channel * channel_multiplier
+    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
+    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1
+    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
+    in_layout = "NCHW%dc" % tile_ic
+    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
+    out_layout = "NCHW%dc" % tile_oc
+    return ((in_shape, in_layout),), ((out_shape, out_layout),)
index f100a35..ad35c19 100644 (file)
@@ -30,6 +30,7 @@ from tvm import autotvm
 from tvm import relay
 from tvm.relay import testing
 from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner
 import tvm.contrib.graph_runtime as runtime
 
 #################################################################
@@ -81,6 +82,7 @@ batch_size = 1
 dtype = "float32"
 model_name = "resnet-18"
 log_file = "%s.log" % model_name
+graph_opt_sch_file = "%s_graph_opt.log" % model_name
 
 # Set number of threads used for tuning based on the number of
 # physical CPU cores on your machine.
@@ -157,6 +159,16 @@ def tune_kernels(tasks,
                            autotvm.callback.progress_bar(n_trial, prefix=prefix),
                            autotvm.callback.log_to_file(log_filename)])
 
+# Use graph tuner to achieve graph level optimal schedules
+# Set use_DP=False if it takes too long to finish.
+def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
+    target_op = [relay.nn.conv2d]
+    Tuner = DPTuner if use_DP else PBQPTuner
+    executor = Tuner(graph, {"data": dshape}, records, target_op, target)
+    executor.benchmark_layout_transform(min_exec_num=2000)
+    executor.run()
+    executor.write_opt_sch2record_file(opt_sch_file)
+
 
 ########################################################################
 # Finally, we launch tuning jobs and evaluate the end-to-end performance.
@@ -171,9 +183,10 @@ def tune_and_evaluate(tuning_opt):
     # run tuning tasks
     print("Tuning...")
     tune_kernels(tasks, **tuning_opt)
+    tune_graph(net, data_shape, log_file, graph_opt_sch_file)
 
-    # compile kernels with history best records
-    with autotvm.apply_history_best(log_file):
+    # compile kernels with graph-level best records
+    with autotvm.apply_graph_best(graph_opt_sch_file):
         print("Compile...")
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build_module.build(