INVALID_LAYOUT_TIME = 10e9
MAX_OUTPUT_NODES = 16
+
+OPT_OUT_OP = ["layout_transform"]
bind_inputs, expr2graph
from ._base import INVALID_LAYOUT_TIME
+from ._base import OPT_OUT_OP
def get_infer_layout(task_name):
if task_name.startswith("conv2d"):
self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
self._fetch_cfg()
+ self._opt_out_op = OPT_OUT_OP
# Setup infer_layout for elemwise-like nodes
# Note: graph tuner currently only supports tuning of single input and single output
# elemwise-like node, and use infer_layout function from input op to generate layouts.
input_names = self._input_shapes.keys()
for idx in sorted(self._in_nodes_dict.keys()):
- if has_multiple_inputs(self._node_list, idx, input_names):
+ if has_multiple_inputs(self._node_list, idx, input_names, self._opt_out_op):
node_entry = self._node_list[idx]
node_entry["topi_op"] = []
node_entry["workloads"] = []
node_entry = self._node_list[key]
target_input_idx = -1
target_input_pos = -1
- if has_multiple_inputs(self._node_list, key, input_names):
+ if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
continue
optimal_sch_idx = optimal_record_dict[node_idx]
full_states = self._stage_dict[node_idx].full_states
- if not has_multiple_inputs(self._node_list, node_idx, input_names):
+ if not has_multiple_inputs(self._node_list, node_idx, input_names, self._opt_out_op):
input_idx = self._in_nodes_dict[node_idx][0]
input_node = self._node_list[input_idx]
if is_boundary_node(input_node, input_names):
for key, val in self._in_nodes_dict.items():
target_input_idx = -1
target_input_pos = -1
- if has_multiple_inputs(self._node_list, key, input_names):
+ if has_multiple_inputs(self._node_list, key, input_names, self._opt_out_op):
for i, item in enumerate(val):
node = self._node_list[item]
if not is_boundary_node(node, input_names):
from tvm.autotvm.task import TaskExtractEnv
from .utils import has_multiple_inputs, is_boundary_node, is_skipped_node
-
+from .._base import OPT_OUT_OP
def expr2graph(expr, target_ops, node_dict, node_list):
"""Convert relay expr to graph data structure
node_direct_ancestor = []
for item_idx in node["inputs"]:
item = node_list[item_idx[0]]
- is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
+ is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], \
+ input_names, OPT_OUT_OP)
if item["op"] in target_ops or is_multiple_inputs:
node_direct_ancestor.append(item_idx[0])
else:
get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
for key, val in visited_dict.items():
node = node_list[key]
- is_multiple_inputs = has_multiple_inputs(node_list, key, input_names)
+ is_multiple_inputs = has_multiple_inputs(node_list, key, \
+ input_names, OPT_OUT_OP)
if node["op"] in target_ops or is_multiple_inputs:
in_node_dict[key] = val
from tvm import relay
from tvm.relay import transform
-
-def has_multiple_inputs(node_list, node_idx, input_names):
+def has_multiple_inputs(node_list, node_idx, input_names, opt_out_op):
"""Check whether a node has multiple input nodes
except variable nodes.
in_idx = in_idx[0]
in_node = node_list[in_idx]
# Exclude parameter nodes
- if in_node["op"] is not None or \
+ if(in_node["op"] is not None and in_node["op"].name in opt_out_op):
+ increase = False
+ for t_idx in in_node["inputs"]:
+ increase = has_multiple_inputs(node_list, t_idx[0], \
+ input_names, opt_out_op)
+ if increase:
+ num_inputs += 1
+ elif in_node["op"] is not None or \
("name" in in_node and in_node["name"] in input_names):
num_inputs += 1
return num_inputs > 1
from tvm.relay.testing import resnet
from tvm.autotvm.graph_tuner.utils import has_multiple_inputs, get_direct_ancestor, get_in_nodes, \
get_out_nodes, expr2graph, bind_inputs
+from tvm.autotvm.graph_tuner._base import OPT_OUT_OP
from tvm.relay.expr import Call, TupleGetItem, Tuple, Var
def verify_has_multiple_inputs(node_list, node_idx, input_names, expected_result):
- out = has_multiple_inputs(node_list, node_idx, input_names)
+ out = has_multiple_inputs(node_list, node_idx, input_names, OPT_OUT_OP)
assert out == expected_result, "Output mismatch: expecting checking %s to be %s but got %s." \
% (node_list[node_idx]["op"], str(expected_result), str(out))