return grc.codegen(mod["main"])
-def extract_from_program(func, params, ops, target, target_host=None):
+def extract_from_program(func, params, ops, target, target_host=None,
+ template_keys=None):
""" Extract tuning tasks from a relay program.
This function is the single program version of extract_from_multiple_program.
The compilation target
target_host: tvm.target.Target
The host compilation target
+ template_keys: dict of topi op to str
+ The tuning template keys map for schedules, default to None.
+ Example: {topi.nn.conv2d: 'direct'}
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
- return extract_from_multiple_program([func], [params], ops, target, target_host)
+ return extract_from_multiple_program([func], [params], ops, target, target_host,
+ template_keys=template_keys)
-def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
+def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
+ template_keys=None):
""" Extract tuning tasks from multiple relay programs.
This function collects tuning tasks by building a list of programs
The compilation target
target_host: tvm.target.Target
The host compilation target
+ template_keys: dict of topi op to str
+ The tuning template keys map for schedules, default to None.
+ Example: {topi.nn.conv2d: 'direct'}
Returns
-------
logger.disabled = old_state
+ # convert *topi op to template key* map to *task name to template key* map
+ task_name_to_keys = {}
+ if template_keys is not None:
+ for op in template_keys.keys():
+ if op in env.topi_to_task:
+ task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
+ else:
+ logger.warning("Invalid template key, fallback to direct")
+ task_name_to_keys[env.topi_to_task[op]] = 'direct'
+
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
try:
+ key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
tsk = create(task_name, args,
target=target, target_host=target_host,
- template_key='direct')
+ template_key=key)
tasks.append(tsk)
except topi.InvalidShapeError:
- print("[Warning] Invalid shape during AutoTVM task creation")
+ logger.warning("Invalid shape during AutoTVM task creation")
return tasks
ops=(relay.op.nn.conv2d,))
assert len(tasks) == 31
+def test_template_key_provided():
+ """test task extraction using non-'direct' template_key"""
+ target = 'llvm'
+
+ import topi
+ template_keys = {
+ # topi.nn.conv2d - is left blank to test fallback logic
+ topi.nn.dense: 'direct_nopack',
+ topi.nn.depthwise_conv2d_nchw: 'direct',
+ }
+
+ mod, params, _ = get_network('mobilenet', batch_size=1)
+ tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+ params=params,
+ ops=(relay.op.nn.conv2d, relay.op.nn.dense),
+ template_keys=template_keys)
+ for task in tasks:
+ if 'dense' in task.name:
+ assert task.config_space.template_key == 'direct_nopack'
+ else:
+ assert task.config_space.template_key == 'direct'
+
+def test_template_key_empty():
+ """test task extraction using empty template_key"""
+ target = 'llvm'
+ mod, params, _ = get_network('mobilenet', batch_size=1)
+ tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+ params=params,
+ ops=(relay.op.nn.conv2d, relay.op.nn.dense),
+ template_keys=None)
+ for task in tasks:
+ assert task.config_space.template_key == 'direct'
+
+def test_template_key_default():
+ """test task extraction without template_key"""
+ target = 'llvm'
+ mod, params, _ = get_network('mobilenet', batch_size=1)
+ tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+ params=params,
+ ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+ for task in tasks:
+ assert task.config_space.template_key == 'direct'
+
if __name__ == '__main__':
test_task_extraction()
+ test_template_key_provided()
+ test_template_key_empty()
+ test_template_key_default()