Improve graph tuner dealing with Tuple (#3649)
authorYao Wang <kevinthesunwy@gmail.com>
Sun, 11 Aug 2019 16:36:13 +0000 (09:36 -0700)
committerYizhi Liu <liuyizhi@apache.org>
Sun, 11 Aug 2019 16:36:13 +0000 (00:36 +0800)
* Improve graph tuner dealing with Tuple

* Add test case

* Move some data out of _base.py

* Fix lint

python/tvm/autotvm/graph_tuner/_base.py
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
python/tvm/autotvm/graph_tuner/utils/utils.py
tests/python/unittest/test_graph_tuner_core.py

index 4002f67..e8d35ac 100644 (file)
 """Helper functions and global data"""
 
 
-# Operators dependent on original layouts.
-LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                   "multibox_prior", "multibox_transform_loc", "where",
-                   "non_max_suppression", "strided_slice"]
-
 # We set a large time to represent an invalid layout-transformation.
 # This number is set to be 10e9 seconds to align with autotvm.
 INVALID_LAYOUT_TIME = 10e9
index 68bc614..dca4148 100644 (file)
@@ -444,6 +444,7 @@ class BaseGraphTuner(object):
                                                timeout=timeout)
         measure_option = autotvm.measure_option(builder=builder, runner=runner)
         for args in args_list:
+            data, in_layout, out_layout = args
             args = serialize_args(args)
             ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args)
             if ltf_workload in  self._layout_transform_perf_records:
@@ -454,7 +455,18 @@ class BaseGraphTuner(object):
                 flops = 1
                 for i in input_shape:
                     flops *= i
-                inferred_time = flops * avg_time
+
+                # Rule out invalid layout transformations
+                out = topi.layout_transform(data, in_layout, out_layout)
+                out_flops = 1
+                for i in topi.util.get_const_tuple(out.shape):
+                    out_flops *= i
+
+                if flops != out_flops:
+                    inferred_time = INVALID_LAYOUT_TIME
+                else:
+                    inferred_time = flops * avg_time
+
                 record_input = MeasureInput(target=self._target, task=None, config=None)
                 record_output = MeasureResult(costs=(inferred_time,), error_no=0,
                                               all_cost=-1, timestamp=-1)
index a6eea6d..19c3193 100644 (file)
@@ -26,7 +26,7 @@ from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
 
-from .utils import has_multiple_inputs, is_boundary_node
+from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
 
 
 # Setup relay op base name -> topi compute functions
@@ -252,7 +252,7 @@ def get_in_nodes(node_list, target_ops, input_names):
     visited_dict = {}
     in_node_dict = {}
     for i, node in enumerate(node_list):
-        if is_boundary_node(node, input_names):
+        if is_boundary_node(node, input_names) or is_skipped_node(node):
             continue
         get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
     for key, val in visited_dict.items():
@@ -282,10 +282,12 @@ def get_in_nodes(node_list, target_ops, input_names):
                         boundary_nodes.append(key)
         if boundary_nodes:
             for idx in boundary_nodes:
-                del in_node_dict[idx]
+                if idx in in_node_dict:
+                    del in_node_dict[idx]
         else:
             has_reduced_node = False
 
+
     # Remove empty nodes to ignore pre-computed sub-graph
     has_empty_node = True
     while has_empty_node:
index 2570d81..d73f2c3 100644 (file)
@@ -19,8 +19,6 @@
 from tvm import relay
 from tvm.relay import transform
 
-from .._base import LAYOUT_FIXED_OP
-
 
 def has_multiple_inputs(node_list, node_idx, input_names):
     """Check whether a node has multiple input nodes
@@ -72,11 +70,35 @@ def is_boundary_node(node_entry, input_names):
     out : bool
         whether node is a boundary node.
     """
-    out = node_entry["op"] in LAYOUT_FIXED_OP or \
+    # Operators dependent on original layouts.
+    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
+                        "multibox_prior", "multibox_transform_loc", "where",
+                        "non_max_suppression", "strided_slice"]
+
+    out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)
     return out
 
 
+def is_skipped_node(node_entry):
+    """Whether a node is not counted.
+
+    Parameters
+    ----------
+    node_entry : dict
+        Node entry.
+
+    Returns
+    -------
+    out : bool
+        whether node is skipped.
+    """
+    # Operators not counted in graph tuner.
+    _SKIPPED_OP = ["Tuple"]
+
+    return node_entry["op"] in _SKIPPED_OP
+
+
 def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
     """Bind input variables of a relay function expression
     to new shapes and/or dtypes.
index 30b037e..c26b4b8 100644 (file)
@@ -354,25 +354,107 @@ def test_many_sub_graphs():
     ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
     ltf_records.append((ms_input, ms_output))
 
-    ltf_keys = []
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 4), dtype=dtype), "NCHW4c", "NCHW8c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 1, 8, 8, 32), dtype=dtype), "NCHW32c", "NCHW4c"]
-    ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
-    ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
-    ltf_arg = [tvm.placeholder((1, 4, 8, 8, 8), dtype=dtype), "NCHW8c", "NCHW32c"]
+    executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+    executor = PBQPTuner(net, {"data": dshape}, records, target_ops, target)
+    executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
+    executor.run()
+    out = [record[0].config for record in executor.get_optimal_records()]
+    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    assert expected_out == out, "Output mismatch: expecting %s but got %s" \
+                                % (str(expected_out), str(out))
+
+
+def test_tuple():
+    target = "llvm"
+    dtype = "float32"
+    dshape = (1, 5, 32, 32)
+    layout = "NCHW"
+    target_ops = [relay.nn.conv2d]
+
+    data = relay.var("data", shape=dshape, dtype=dtype)
+    w0 = relay.var("w0_weight")
+    conv0 = relay.nn.conv2d(data, w0, channels=2, kernel_size=(3, 3), padding=(1, 1))
+    w1 = relay.var("w1_weight")
+    conv1 = relay.nn.conv2d(data, w1, channels=3, kernel_size=(3, 3), padding=(1, 1))
+    out = relay.concatenate([conv0, conv1], axis=1)
+    net = relay.Function(relay.analysis.free_vars(out), out)
+    net, params = relay.testing.create_workload(net)
+
+    tasks = autotvm.task.extract_from_program(net["main"],
+                                              target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d,))
+    wkl_list = [
+        create_workload((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+        create_workload((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
+    costs = [0.01, 0.012, 0.03, 0.04]
+    config_list = []
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 2]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [1, 3]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [2, 1]],
+                      ["tile_ow", "sp", [4, 8]],
+                      ["unroll_kw", "ot", True]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+    cfg_dict = {"i": -1,
+                "c": None,
+                "e": [["tile_ic", "sp", [1, 5]],
+                      ["tile_oc", "sp", [3, 1]],
+                      ["tile_ow", "sp", [2, 16]],
+                      ["unroll_kw", "ot", False]],
+                "t": ""}
+    config_list.append(ConfigEntity.from_json_dict(cfg_dict))
+
+    records = []
+
+    wkl_list = wkl_list + wkl_list
+    tasks = tasks + tasks
+    for wkl, cost, config, task in zip(wkl_list, costs, config_list, tasks):
+        task.workload = wkl
+        ms_input = MeasureInput(target=target, task=task, config=config)
+        ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
+        records.append((ms_input, ms_output))
+
+    ltf_records = []
+    ltf_arg = [tvm.placeholder((1, 64, 16, 16, 8), dtype=dtype), "NCHW8c", "NCHW512c"]
     ltf_arg = autotvm.task.topi_integration.serialize_args(ltf_arg)
     ltf_wkl = ('layout_transform',) + autotvm.task.args_to_workload(ltf_arg)
-    ltf_keys.append(ltf_wkl)
+    ltf_task = copy.deepcopy(tasks[0])
+    ltf_task.workload = ltf_wkl
+    ms_input = MeasureInput(target=target, task=ltf_task, config=None)
+    ms_output =  MeasureResult(costs=(1.91224744e-05,), error_no=0, all_cost=-1, timestamp=-1)
+    ltf_records.append((ms_input, ms_output))
 
     executor = DPTuner(net, {"data": dshape}, records, target_ops, target)
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))
 
@@ -380,7 +462,7 @@ def test_many_sub_graphs():
     executor.benchmark_layout_transform(layout_records=ltf_records, infer_layout=True)
     executor.run()
     out = [record[0].config for record in executor.get_optimal_records()]
-    expected_out = [records[3][0].config, records[1][0].config, records[2][0].config]
+    expected_out = [records[2][0].config, records[1][0].config]
     assert expected_out == out, "Output mismatch: expecting %s but got %s" \
                                 % (str(expected_out), str(out))
 
@@ -390,3 +472,4 @@ if __name__=="__main__":
     test_DPTuner_run()
     test_PBQPTuner_run()
     test_many_sub_graphs()
+    test_tuple()