[AutoTVM] Support range in index based tuners (#4870)
authorCody Yu <comaniac0422@gmail.com>
Sat, 15 Feb 2020 04:15:47 +0000 (20:15 -0800)
committerGitHub <noreply@github.com>
Sat, 15 Feb 2020 04:15:47 +0000 (20:15 -0800)
* Support range in index based tuners

* Address comments

* Remove __*state__

* trigger CI

python/tvm/autotvm/tuner/__init__.py
python/tvm/autotvm/tuner/gridsearch_tuner.py [deleted file]
python/tvm/autotvm/tuner/index_based_tuner.py [new file with mode: 0644]
tests/python/unittest/test_autotvm_common.py
tests/python/unittest/test_autotvm_index_tuner.py [new file with mode: 0644]
tests/python/unittest/test_autotvm_measure.py

index c5ad6bf..7ffe9a2 100644 (file)
@@ -25,6 +25,6 @@ from . import callback
 
 from .tuner import Tuner
 
-from .gridsearch_tuner import GridSearchTuner, RandomTuner
+from .index_based_tuner import GridSearchTuner, RandomTuner
 from .ga_tuner import GATuner
 from .xgboost_tuner import XGBTuner
diff --git a/python/tvm/autotvm/tuner/gridsearch_tuner.py b/python/tvm/autotvm/tuner/gridsearch_tuner.py
deleted file mode 100644 (file)
index 4e9a4a2..0000000
+++ /dev/null
@@ -1,85 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=abstract-method
-"""Grid search tuner and random tuner"""
-
-import numpy as np
-
-from .tuner import Tuner
-
-
-class GridSearchTuner(Tuner):
-    """Enumerate the search space in a grid search order"""
-    def __init__(self, task):
-        super(GridSearchTuner, self).__init__(task)
-        self.counter = 0
-
-    def next_batch(self, batch_size):
-        ret = []
-        for _ in range(batch_size):
-            if self.counter >= len(self.task.config_space):
-                continue
-            index = self.counter
-            ret.append(self.task.config_space.get(index))
-            self.counter = self.counter + 1
-        return ret
-
-    def has_next(self):
-        return self.counter < len(self.task.config_space)
-
-    def load_history(self, data_set):
-        pass
-
-    def __getstate__(self):
-        return {"counter": self.counter}
-
-    def __setstate__(self, state):
-        self.counter = state['counter']
-
-
-class RandomTuner(Tuner):
-    """Enumerate the search space in a random order"""
-    def __init__(self, task):
-        super(RandomTuner, self).__init__(task)
-        self.visited = set()
-
-    def next_batch(self, batch_size):
-        ret = []
-        counter = 0
-        while counter < batch_size:
-            if len(self.visited) >= len(self.task.config_space):
-                break
-            index = np.random.randint(len(self.task.config_space))
-            while index in self.visited:
-                index = np.random.randint(len(self.task.config_space))
-
-            ret.append(self.task.config_space.get(index))
-            self.visited.add(index)
-            counter += 1
-        return ret
-
-    def has_next(self):
-        return len(self.visited) < len(self.task.config_space)
-
-    def load_history(self, data_set):
-        pass
-
-    def __getstate__(self):
-        return {"visited": self.counter}
-
-    def __setstate__(self, state):
-        self.counter = state['visited']
diff --git a/python/tvm/autotvm/tuner/index_based_tuner.py b/python/tvm/autotvm/tuner/index_based_tuner.py
new file mode 100644 (file)
index 0000000..99fc9f2
--- /dev/null
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=abstract-method
+"""Grid search tuner and random tuner"""
+
+import numpy as np
+
+from .tuner import Tuner
+
+class IndexBaseTuner(Tuner):
+    """Base class for index based tuner
+    This type of tuner determine the next batch of configs based on config indices.
+
+    Parameters
+    ----------
+    task: autotvm.task.Task
+        The tuning task
+
+    range_idx: Optional[Tuple[int, int]]
+        A tuple of index range that this tuner can select from
+    """
+    def __init__(self, task, range_idx=None):
+        super(IndexBaseTuner, self).__init__(task)
+        assert range_idx is None or isinstance(range_idx, tuple), \
+            "range_idx must be None or (int, int)"
+
+        self.range_length = len(self.task.config_space)
+        self.index_offset = 0
+        if range_idx is not None:
+            assert range_idx[1] > range_idx[0], "Index range must be positive"
+            assert range_idx[0] >= 0, "Start index must be positive"
+            self.range_length = range_idx[1] - range_idx[0] + 1
+            self.index_offset = range_idx[0]
+        self.counter = 0
+
+    def has_next(self):
+        return self.counter < self.range_length
+
+    def load_history(self, data_set):
+        pass
+
+
+class GridSearchTuner(IndexBaseTuner):
+    """Enumerate the search space in a grid search order"""
+
+    def next_batch(self, batch_size):
+        ret = []
+        for _ in range(batch_size):
+            if self.counter >= self.range_length:
+                break
+            index = self.counter + self.index_offset
+            ret.append(self.task.config_space.get(index))
+            self.counter = self.counter + 1
+        return ret
+
+
+class RandomTuner(IndexBaseTuner):
+    """Enumerate the search space in a random order
+
+    Parameters
+    ----------
+    task: autotvm.task.Task
+        Tuning Task
+
+    range_idx: Optional[Tuple[int, int]]
+        A tuple of index range to random
+    """
+    def __init__(self, task, range_idx=None):
+        super(RandomTuner, self).__init__(task, range_idx)
+
+        # Use a dict to mimic a range(n) list without storing rand_state[i] = i entries so that
+        # we can generate non-repetitive random indices.
+        self.rand_state = {}
+        self.rand_max = self.range_length
+        self.visited = []
+
+    def next_batch(self, batch_size):
+        ret = []
+        for _ in range(batch_size):
+            if self.rand_max == 0:
+                break
+
+            # Random an indirect index.
+            index_ = np.random.randint(self.rand_max)
+            self.rand_max -= 1
+
+            # Use the indirect index to get a direct index.
+            index = self.rand_state.get(index_, index_) + self.index_offset
+            ret.append(self.task.config_space.get(index))
+            self.visited.append(index)
+
+            # Update the direct index map.
+            self.rand_state[index_] = self.rand_state.get(self.rand_max, self.rand_max)
+            self.rand_state.pop(self.rand_max, None)
+            self.counter += 1
+        return ret
index 7043e47..fac9f06 100644 (file)
 """Common utilities for testing autotvm"""
 import time
 
+import numpy as np
+
 import tvm
 from tvm import autotvm
 from tvm.autotvm import MeasureInput, MeasureResult
+from tvm.autotvm.measure.measure import Runner
+
+
+class DummyRunner(Runner):
+    def __init__(self):
+        super(DummyRunner, self).__init__(1, 1)
+
+    def run(self, measure_inputs, build_results):
+        return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
+                for _ in range(len(measure_inputs))]
+
+    def get_build_kwargs(self):
+        return {}
 
 @autotvm.template
 def matmul(N, L, M, dtype):
@@ -82,4 +97,3 @@ def get_sample_records(n):
         inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
         ress.append(MeasureResult((i+1,), 0, i, time.time()))
     return list(zip(inps, ress))
-
diff --git a/tests/python/unittest/test_autotvm_index_tuner.py b/tests/python/unittest/test_autotvm_index_tuner.py
new file mode 100644 (file)
index 0000000..c7fa2ea
--- /dev/null
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test index based tuners"""
+
+from test_autotvm_common import DummyRunner, get_sample_task
+from tvm import autotvm
+from tvm.autotvm.tuner import GridSearchTuner, RandomTuner
+
+
+def test_gridsearch_tuner():
+    """Test GridSearchTuner"""
+
+    task, _ = get_sample_task()
+    measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
+
+    # When no range index, range_length should be the length of config space
+    tuner = autotvm.tuner.GridSearchTuner(task)
+    assert tuner.range_length == len(task.config_space)
+    assert tuner.index_offset == 0
+
+    # With range index, range_length should be the length of the specified range
+    tuner = autotvm.tuner.GridSearchTuner(task, range_idx=(8, 15))
+    assert tuner.range_length == 8
+    assert tuner.index_offset == 8
+
+    # Tuner should only focus on the specified range
+    tuner.tune(n_trial=8, measure_option=measure_option)
+    assert tuner.counter == 8
+    assert not tuner.has_next()
+
+
+def test_random_tuner():
+    """Test RandomTuner"""
+
+    task, _ = get_sample_task()
+    measure_option = autotvm.measure_option(builder=autotvm.LocalBuilder(), runner=DummyRunner())
+
+    tuner = autotvm.tuner.RandomTuner(task, range_idx=(8, 15))
+    assert tuner.range_length == 8
+    assert tuner.index_offset == 8
+
+    # Tuner should only focus on the specified range and should visit all indices
+    tuner.tune(n_trial=8, measure_option=measure_option)
+    assert tuner.counter == 8
+    assert not tuner.has_next()
+    visited = set()
+    for idx in tuner.visited:
+        assert idx not in visited
+        assert 8 <= idx <= 15
+
+
+if __name__ == '__main__':
+    test_gridsearch_tuner()
+    test_random_tuner()
\ No newline at end of file
index 2900948..48a1d31 100644 (file)
@@ -21,24 +21,14 @@ import time
 import numpy as np
 
 import tvm
+from test_autotvm_common import DummyRunner, bad_matmul, get_sample_task
 from tvm import autotvm
-from test_autotvm_common import get_sample_task, bad_matmul
-from tvm.autotvm.measure.measure import Runner, MeasureResult, MeasureErrorNo
+from tvm.autotvm.measure.measure import MeasureErrorNo, MeasureResult
+
 
 def test_task_tuner_without_measurement():
     """test task and tuner without measurement"""
-    task, target = get_sample_task()
-
-    class DummyRunner(Runner):
-        def __init__(self):
-            super(DummyRunner, self).__init__(1, 1)
-
-        def run(self, measure_inputs, build_results):
-            return [MeasureResult((np.random.random(),), 0, 0.2, time.time())
-                    for _ in range(len(measure_inputs))]
-
-        def get_build_kwargs(self):
-            return {}
+    task, _ = get_sample_task()
 
     measure_option = autotvm.measure_option(
         builder=autotvm.LocalBuilder(),
@@ -64,7 +54,7 @@ def test_check_correctness():
     )
 
     def _callback_correct(tuner, measure_inputs, measure_results):
-        for inp, res in zip(measure_inputs, measure_results):
+        for _, res in zip(measure_inputs, measure_results):
             assert res.error_no == 0
 
     tuner = autotvm.tuner.RandomTuner(task)
@@ -77,7 +67,7 @@ def test_check_correctness():
     task = autotvm.task.create(bad_matmul, args=(n, n, n, 'float32'), target=target)
 
     def _callback_wrong(tuner, measure_inputs, measure_results):
-        for inp, res in zip(measure_inputs, measure_results):
+        for _, res in zip(measure_inputs, measure_results):
             assert res.error_no == MeasureErrorNo.WRONG_ANSWER
 
     tuner = autotvm.tuner.RandomTuner(task)
@@ -90,4 +80,3 @@ if __name__ == '__main__':
 
     test_task_tuner_without_measurement()
     test_check_correctness()
-