AutoTVM: selecting tuning templates when extracting task (#4338)
author黎明灰烬 <i@jackwish.net>
Sat, 16 Nov 2019 00:53:01 +0000 (08:53 +0800)
committerWuwei Lin <wuwei@apache.org>
Sat, 16 Nov 2019 00:53:01 +0000 (19:53 -0500)
* AutoTVM: selecting tuning templates when extracting task

Make the procedure of trying new templates easier.

Test: tests/python/relay/test_autotvm_task_extraction.py

* Use dict to match key for topi ops

* fix lint issue

* be more pythonic :)

python/tvm/autotvm/task/relay_integration.py
tests/python/relay/test_autotvm_task_extraction.py

index 345da66..b65c5d4 100644 (file)
@@ -54,7 +54,8 @@ def _lower(func,
     return grc.codegen(mod["main"])
 
 
-def extract_from_program(func, params, ops, target, target_host=None):
+def extract_from_program(func, params, ops, target, target_host=None,
+                         template_keys=None):
     """ Extract tuning tasks from a relay program.
 
     This function is the single program version of extract_from_multiple_program.
@@ -71,16 +72,21 @@ def extract_from_program(func, params, ops, target, target_host=None):
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
+    template_keys: dict of topi op to str
+        The tuning template keys map for schedules, default to None.
+        Example: {topi.nn.conv2d: 'direct'}
 
     Returns
     -------
     task: Array of autotvm.task.Task
         collected tasks
     """
-    return extract_from_multiple_program([func], [params], ops, target, target_host)
+    return extract_from_multiple_program([func], [params], ops, target, target_host,
+                                         template_keys=template_keys)
 
 
-def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
+def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
+                                  template_keys=None):
     """ Extract tuning tasks from multiple relay programs.
 
     This function collects tuning tasks by building a list of programs
@@ -98,6 +104,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
         The compilation target
     target_host: tvm.target.Target
         The host compilation target
+    template_keys: dict of topi op to str
+        The tuning template keys map for schedules, default to None.
+        Example: {topi.nn.conv2d: 'direct'}
 
     Returns
     -------
@@ -146,15 +155,26 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
 
         logger.disabled = old_state
 
+    # convert *topi op to template key* map to *task name to template key* map
+    task_name_to_keys = {}
+    if template_keys is not None:
+        for op in template_keys.keys():
+            if op in env.topi_to_task:
+                task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
+            else:
+                logger.warning("Invalid template key, fallback to direct")
+                task_name_to_keys[env.topi_to_task[op]] = 'direct'
+
     # create tasks for target
     tasks = []
     for task_name, args in env.get_tasks():
         try:
+            key = task_name_to_keys[task_name] if task_name in task_name_to_keys else 'direct'
             tsk = create(task_name, args,
                          target=target, target_host=target_host,
-                         template_key='direct')
+                         template_key=key)
             tasks.append(tsk)
         except topi.InvalidShapeError:
-            print("[Warning] Invalid shape during AutoTVM task creation")
+            logger.warning("Invalid shape during AutoTVM task creation")
 
     return tasks
index 242096f..d29d743 100644 (file)
@@ -79,5 +79,51 @@ def test_task_extraction():
                                                        ops=(relay.op.nn.conv2d,))
     assert len(tasks) == 31
 
+def test_template_key_provided():
+    """test task extraction using non-'direct' template_key"""
+    target = 'llvm'
+
+    import topi
+    template_keys = {
+        # topi.nn.conv2d - is left blank to test fallback logic
+        topi.nn.dense: 'direct_nopack',
+        topi.nn.depthwise_conv2d_nchw: 'direct',
+    }
+
+    mod, params, _ = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
+                                              template_keys=template_keys)
+    for task in tasks:
+        if 'dense' in task.name:
+            assert task.config_space.template_key == 'direct_nopack'
+        else:
+            assert task.config_space.template_key == 'direct'
+
+def test_template_key_empty():
+    """test task extraction using empty template_key"""
+    target = 'llvm'
+    mod, params, _ = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense),
+                                              template_keys=None)
+    for task in tasks:
+        assert task.config_space.template_key == 'direct'
+
+def test_template_key_default():
+    """test task extraction without template_key"""
+    target = 'llvm'
+    mod, params, _ = get_network('mobilenet', batch_size=1)
+    tasks = autotvm.task.extract_from_program(mod['main'], target=target,
+                                              params=params,
+                                              ops=(relay.op.nn.conv2d, relay.op.nn.dense))
+    for task in tasks:
+        assert task.config_space.template_key == 'direct'
+
 if __name__ == '__main__':
     test_task_extraction()
+    test_template_key_provided()
+    test_template_key_empty()
+    test_template_key_default()