[Graph tuner]Add opt out operator for has_multiple_inputs for graph tuner (#5000)
authorzhen-jia <53954057+zhen-jia@users.noreply.github.com>
Fri, 13 Mar 2020 17:07:06 +0000 (10:07 -0700)
committerGitHub <noreply@github.com>
Fri, 13 Mar 2020 17:07:06 +0000 (10:07 -0700)
* consider layout_transform in has_multiple_inputs

* refactor code

* remove debug info

* remove subclass assignment

* refactoring a little bit

* remove default value

* remove trailing whitespace

* modify test for has_multiple_inputs

Co-authored-by: Ubuntu <ubuntu@ip-172-31-40-194.us-west-2.compute.internal>
python/tvm/autotvm/graph_tuner/_base.py
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py
python/tvm/autotvm/graph_tuner/pbqp_tuner.py
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
python/tvm/autotvm/graph_tuner/utils/utils.py
tests/python/unittest/test_graph_tuner_utils.py

index e8d35ac..ae220bb 100644 (file)
@@ -23,3 +23,5 @@
 INVALID_LAYOUT_TIME = 10e9
 
 MAX_OUTPUT_NODES = 16
+
+OPT_OUT_OP = ["layout_transform"]
index bb9c52d..f1a0756 100644 (file)
@@ -34,6 +34,7 @@ from .utils import is_boundary_node, get_in_nodes, get_out_nodes, has_multiple_i
     bind_inputs, expr2graph
 from ._base import INVALID_LAYOUT_TIME
 
+from ._base import OPT_OUT_OP
 
 def get_infer_layout(task_name):
     if task_name.startswith("conv2d"):
@@ -153,6 +154,7 @@ class BaseGraphTuner(object):
         self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
         self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
         self._fetch_cfg()
+        self._opt_out_op = OPT_OUT_OP
 
         # Setup infer_layout for elemwise-like nodes
         # Note: graph tuner currently only supports tuning of single input and single output
@@ -162,7 +164,7 @@ class BaseGraphTuner(object):
         # elemwise-like node, and use infer_layout function from input op to generate layouts.
         input_names = self._input_shapes.keys()
         for idx in sorted(self._in_nodes_dict.keys()):
-            if has_multiple_inputs(self._node_list, idx, input_names):
+            if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op):
                 node_entry = self._node_list[idx]
                 node_entry["topi_op"] = []
                 node_entry["workloads"] = []
@@ -246,7 +248,7 @@ class BaseGraphTuner(object):
             node_entry = self._node_list[key]
             target_input_idx = -1
             target_input_pos = -1
-            if has_multiple_inputs(self._node_list, key, input_names):
+            if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
                 for i, item in enumerate(val):
                     node = self._node_list[item]
                     if not is_boundary_node(node, input_names):
index e3e4d11..b9d40c8 100644 (file)
@@ -144,7 +144,7 @@ class DPTuner(BaseGraphTuner):
                 continue
             optimal_sch_idx = optimal_record_dict[node_idx]
             full_states = self._stage_dict[node_idx].full_states
-            if not has_multiple_inputs(self._node_list, node_idx, input_names):
+            if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op):
                 input_idx = self._in_nodes_dict[node_idx][0]
                 input_node = self._node_list[input_idx]
                 if is_boundary_node(input_node, input_names):
index 36090f4..d58694c 100644 (file)
@@ -249,7 +249,7 @@ class PBQPTuner(BaseGraphTuner):
         for key, val in self._in_nodes_dict.items():
             target_input_idx = -1
             target_input_pos = -1
-            if has_multiple_inputs(self._node_list, key, input_names):
+            if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
                 for i, item in enumerate(val):
                     node = self._node_list[item]
                     if not is_boundary_node(node, input_names):
index 17450ca..f1dd404 100644 (file)
@@ -26,7 +26,7 @@ from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
 
 from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
-
+from .._base import OPT_OUT_OP
 
 def expr2graph(expr, target_ops, node_dict, node_list):
     """Convert relay expr to graph data structure
@@ -204,7 +204,8 @@ def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_nam
     node_direct_ancestor = []
     for item_idx in node["inputs"]:
         item = node_list[item_idx[0]]
-        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
+        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
+                input_names, OPT_OUT_OP)
         if item["op"] in target_ops or is_multiple_inputs:
             node_direct_ancestor.append(item_idx[0])
         else:
@@ -245,7 +246,8 @@ def get_in_nodes(node_list, target_ops, input_names):
         get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
     for key, val in visited_dict.items():
         node = node_list[key]
-        is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
+        is_multiple_inputs = has_multiple_inputs(node_list, key, \
+                input_names, OPT_OUT_OP)
         if node["op"] in target_ops or is_multiple_inputs:
             in_node_dict[key] = val
 
index 2486d0c..70e95c9 100644 (file)
@@ -20,8 +20,7 @@ import tvm
 from tvm import relay
 from tvm.relay import transform
 
-
-def has_multiple_inputs(node_list, node_idx, input_names):
+def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
     """Check whether a node has multiple input nodes
     except variable nodes.
 
@@ -47,7 +46,14 @@ def has_multiple_inputs(node_list, node_idx, input_names):
         in_idx = in_idx[0]
         in_node = node_list[in_idx]
         # Exclude parameter nodes
-        if in_node["op"] is not None or \
+        if(in_node["op"] is not None and in_node["op"].name in opt_out_op):
+            increase = False
+            for t_idx in in_node["inputs"]:
+                increase = has_multiple_inputs(node_list, t_idx[0], \
+                        input_names, opt_out_op)
+            if increase:
+                num_inputs += 1
+        elif in_node["op"] is not None or \
                 ("name" in in_node and in_node["name"] in input_names):
             num_inputs += 1
     return num_inputs > 1
index f620acc..bd0ebe0 100644 (file)
@@ -27,11 +27,12 @@ from tvm import autotvm, relay
 from tvm.relay.testing import resnet
 from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
     get_out_nodes, expr2graph, bind_inputs
+from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
 from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
 
 
 def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
-    out = has_multiple_inputs(node_list, node_idx, input_names)
+    out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
     assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
                                    % (node_list[node_idx]["op"], str(expected_result), str(out))