[TUTORIAL][ANSOR] Using the template-free auto-scheduler on CPU (#6488)
authorLianmin Zheng <lianminzheng@gmail.com>
Thu, 17 Sep 2020 02:14:09 +0000 (19:14 -0700)
committerGitHub <noreply@github.com>
Thu, 17 Sep 2020 02:14:09 +0000 (19:14 -0700)
* add tutorial

* add tutorial

* update

* Apply suggestions from code review

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
* address comments

* fix bugs

* add the exmple for resuming the search

* fix lint

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
docs/api/python/auto_scheduler.rst [new file with mode: 0644]
docs/api/python/autotvm.rst
docs/api/python/index.rst
docs/conf.py
python/tvm/auto_scheduler/__init__.py
python/tvm/auto_scheduler/auto_schedule.py
src/auto_scheduler/search_policy/sketch_policy_rules.cc
tutorials/auto_scheduler/README.txt [new file with mode: 0644]
tutorials/auto_scheduler/tune_matmul_x86.py [new file with mode: 0644]
tutorials/autotvm/README.txt

diff --git a/docs/api/python/auto_scheduler.rst b/docs/api/python/auto_scheduler.rst
new file mode 100644 (file)
index 0000000..85ff22f
--- /dev/null
@@ -0,0 +1,35 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+tvm.auto_scheduler
+------------------
+.. automodule:: tvm.auto_scheduler
+
+tvm.auto_scheduler.auto_schedule
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: tvm.auto_scheduler.auto_schedule
+
+.. autoclass:: tvm.auto_scheduler.auto_schedule.SearchTask
+
+.. autoclass:: tvm.auto_scheduler.auto_schedule.TuningOptions
+
+.. autofunction:: tvm.auto_scheduler.auto_schedule.create_task
+
+.. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule
+
+
+
index 9357d1b..5bde9ac 100644 (file)
@@ -18,7 +18,7 @@
 tvm.autotvm
 -----------
 .. automodule:: tvm.autotvm
-.. automodule:: tvm.autotvm.apply_history_best
+.. autofunction:: tvm.autotvm.apply_history_best
 
 tvm.autotvm.measure
 ~~~~~~~~~~~~~~~~~~~
index bc9ec5f..a617968 100644 (file)
@@ -40,6 +40,7 @@ Python API
    relay/dataflow_pattern
    relay/testing
    autotvm
+   auto_scheduler
    rpc
    micro
    contrib
index ca0bc9b..9322f5a 100644 (file)
@@ -193,6 +193,7 @@ subsection_order = ExplicitOrder(
         "../tutorials/language",
         "../tutorials/optimize",
         "../tutorials/autotvm",
+        "../tutorials/auto_scheduler",
         "../tutorials/dev",
         "../tutorials/topi",
         "../tutorials/deployment",
index 43e08a4..2b36287 100644 (file)
@@ -26,7 +26,7 @@ from . import workload_registry
 from . import feature
 
 # Shortcut
-from .auto_schedule import SearchTask, TuningOptions, HardwareParams, auto_schedule
+from .auto_schedule import SearchTask, TuningOptions, HardwareParams, create_task, auto_schedule
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
 from .measure import (
index af257f5..eae8b25 100644 (file)
@@ -31,7 +31,10 @@ Candidate schedules are measured against the specific hardware target.
 import tvm._ffi
 from tvm.runtime import Object
 from .measure import LocalBuilder, LocalRunner
-from .search_policy import EmptyPolicy
+from .workload_registry import make_workload_key
+from .compute_dag import ComputeDAG
+from .cost_model import XGBModel
+from .search_policy import SketchPolicy
 from . import _ffi_api
 
 
@@ -89,26 +92,26 @@ class TuningOptions(Object):
     Parameters
     ----------
     num_measure_trials: int = 0
-      The number of measurement trials.
-      The search policy measures `num_measure_trials` schedules in total and returns the best one
-      among them.
-      With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
-      measurement. This can be used to get a runnable schedule quickly without auto-tuning.
+        The number of measurement trials.
+        The search policy measures `num_measure_trials` schedules in total and returns the best one
+        among them.
+        With `num_measure_trials` == 0, the policy will do the schedule search but won't involve
+        measurement. This can be used to get a runnable schedule quickly without auto-tuning.
     early_stopping: Optional[int]
-      Stop the tuning early if getting no improvement after n measurements.
+        Stop the tuning early if getting no improvement after n measurements.
     num_measures_per_round: int = 64
-      The number of schedules to be measured at each search round.
-      The whole schedule search process will try a total number of `num_measure_trials` in several
-      rounds.
+        The number of schedules to be measured at each search round.
+        The whole schedule search process will try a total number of `num_measure_trials` in several
+        rounds.
     verbose: int = 1
-      Verbosity level. 0 for silent, 1 to output information during schedule search.
+        Verbosity level. 0 for silent, 1 to output information during schedule search.
     builder: Union[ProgramBuilder, str] = 'local'
-      ProgramBuilder which builds the program.
+        ProgramBuilder which builds the program.
     runner: Union[ProgramRunner, str] = 'local'
-      ProgramRunner which runs the program and measures time costs.
+        ProgramRunner which runs the program and measures time costs.
     measure_callbacks: Optional[List[MeasureCallback]]
-      Callback functions called after each measurement.
-      Candidates:
+        Callback functions called after each measurement.
+        Candidates:
         - auto_scheduler.RecordToFile
     """
 
@@ -156,16 +159,41 @@ class TuningOptions(Object):
         )
 
 
+def create_task(func, args, target, target_host=None, hardware_params=None):
+    """Create a search task
+
+    Parameters
+    ----------
+    func : Union[Function, str]
+        The function that returns the compute declaration Tensors.
+        Can be the a function or the function name.
+    args : Union[Tuple[Any, ...], List[Any]]
+        The args of the function.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+
+    Returns
+    -------
+        SearchTask: the created task
+    """
+    workload_key = make_workload_key(func, args)
+    dag = ComputeDAG(workload_key)
+    return SearchTask(dag, workload_key, target, target_host, hardware_params)
+
+
 def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
-    """Do auto scheduling for a computation declaration.
+    """Run auto scheduling search for a task
 
     Parameters
     ----------
     task : SearchTask
         The SearchTask for the computation declaration.
     search_policy : Optional[SearchPolicy]
-        The search policy to be used for schedule search. Use EmptyPolicy as default, which always
-        returns an empty schedule.
+        The search policy to be used for schedule search.
     tuning_options : Optional[TuningOptions]
         Tuning and measurement options.
 
@@ -178,5 +206,9 @@ def auto_schedule(task, search_policy=None, tuning_options=TuningOptions()):
             "Invalid task: " + task + " . `auto_scheduler.auto_schedule` expects a SearchTask."
         )
 
-    sch, tensors = _ffi_api.AutoSchedule(search_policy or EmptyPolicy(task), tuning_options)
+    if search_policy is None:
+        cost_model = XGBModel()
+        search_policy = SketchPolicy(task, cost_model)
+
+    sch, tensors = _ffi_api.AutoSchedule(search_policy, tuning_options)
     return sch, tensors
index 7e7b447..dab6e4d 100644 (file)
@@ -593,7 +593,7 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod
 
 PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
                                                                       State* state) const {
-  return MutateComputeLocationCommon(policy, state, false);
+  return MutateComputeLocationCommon(policy, state, true);
 }
 
 PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,
@@ -1059,7 +1059,7 @@ PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNo
 
 PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                   State* state) const {
-  return MutateComputeLocationCommon(policy, state, true);
+  return MutateComputeLocationCommon(policy, state, false);
 }
 
 PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
diff --git a/tutorials/auto_scheduler/README.txt b/tutorials/auto_scheduler/README.txt
new file mode 100644 (file)
index 0000000..7598667
--- /dev/null
@@ -0,0 +1,2 @@
+AutoScheduler : Template-free Auto Scheduling
+---------------------------------------------
diff --git a/tutorials/auto_scheduler/tune_matmul_x86.py b/tutorials/auto_scheduler/tune_matmul_x86.py
new file mode 100644 (file)
index 0000000..1a9af42
--- /dev/null
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling matrix multiplication for CPU
+=============================================
+**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, \
+            `Chengfan Jia <https://github.com/jcf94/>`_
+
+Different from the existing :ref:`autotvm <tutorials-autotvm-sec>` which relies on 
+manual templates to define the search space, the auto-scheduler does not require any templates.
+The auto-scheduler is template-free, so users only need to write the computation declaration without
+any schedule commands or templates.
+The auto-scheduler can automatically generate a large
+search space and find a good schedule in the space.
+
+We use matrix multiplication as an example in this tutorial.
+"""
+
+import numpy as np
+import tvm
+from tvm import te, testing, auto_scheduler
+
+######################################################################
+# Define the computation
+# ^^^^^^^^^^^^^^^^^^^^^^
+# To begin with, we define the computation of a matmul with bias add.
+# The function should return the list of input/output tensors.
+# From these tensors, the auto-scheduler can get the whole computational graph.
+
+
+@auto_scheduler.register_workload
+def matmul_add(N, L, M, dtype):
+    A = te.placeholder((N, L), name="A", dtype=dtype)
+    B = te.placeholder((L, M), name="B", dtype=dtype)
+    C = te.placeholder((N, M), name="C", dtype=dtype)
+
+    k = te.reduce_axis((0, L), name="k")
+    matmul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")
+    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")
+
+    return [A, B, C, out]
+
+
+######################################################################
+# Create the search task
+# ^^^^^^^^^^^^^^^^^^^^^^
+# We then create a search task with N=L=M=128 and dtype="float32"
+
+target = tvm.target.Target("llvm")
+task = auto_scheduler.create_task(matmul_add, (128, 128, 128, "float32"), target)
+
+# Inspect the computational graph
+print(task.compute_dag)
+
+######################################################################
+# Next, we set parameters for the auto-scheduler.
+#
+# * `num_measure_trials` is the number of measurement trials we can use during the search.
+#   We only make 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a
+#   good value for the search to converge. You can do more trials according to your time budget.
+# * In addition, we use `RecordToFile` to dump measurement records into a file `matmul.json`.
+#   The measurement records can be used to query the history best, resume the search,
+#   and do more analyses later.
+# * see :any:`auto_schedule.TuningOptions`: for more parameters
+
+tune_option = auto_scheduler.TuningOptions(
+    num_measure_trials=10, measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")]
+)
+
+######################################################################
+# Run the search
+# ^^^^^^^^^^^^^^
+# Now we get all inputs ready. Pretty simple, isn't it?
+# We can kick off the search and let the auto-scheduler do its magic.
+# After some measurement trials, it will return the best schedule it found.
+
+sch, args = auto_scheduler.auto_schedule(task, tuning_options=tune_option)
+
+######################################################################
+# We can lower the schedule to see the IR after auto-scheduling.
+# The auto-scheduler correctly performs optimizations including multi-level tiling,
+# parallelization, vectorization, unrolling and fusion.
+
+print(tvm.lower(sch, args, simple_mode=True))
+
+######################################################################
+# Check correctness
+# ^^^^^^^^^^^^^^^^^
+# We build the binary and check its correctness
+
+func = tvm.build(sch, args)
+a_np = np.random.uniform(size=(128, 128)).astype(np.float32)
+b_np = np.random.uniform(size=(128, 128)).astype(np.float32)
+c_np = np.random.uniform(size=(128, 128)).astype(np.float32)
+d_np = a_np.dot(b_np) + c_np
+
+d_tvm = tvm.nd.empty(d_np.shape)
+func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
+
+tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-3)
+
+######################################################################
+# Using the record file
+# ^^^^^^^^^^^^^^^^^^^^^
+# During the search, all measuremnt records are dumpped into the record
+# file "matmul.json". The measurement records can be used to re-apply search results,
+# resume the search, and perform other analyses.
+
+######################################################################
+# Here is an example where we load the best schedule from a file,
+# print the equivalent python schedule API, and build the binary again.
+
+# Load the measuremnt record for the best schedule
+inp, res = auto_scheduler.load_best("matmul.json", task.workload_key)
+
+# Print equivalent python schedule API. This can be used for debugging and
+# learning the behavior of the auto-scheduler.
+print(task.compute_dag.print_python_code_from_state(inp.state))
+
+# Rebuild the binary. This shows how you can apply the best schedule from a
+# log file without reruning the search again.
+sch, args = task.compute_dag.apply_steps_from_state(inp.state)
+func = tvm.build(sch, args)
+
+######################################################################
+# A more complicated example is to resume the search.
+# In this case, we need to create the search policy and cost model by ourselves
+# and resume the status of search policy and cost model with the log file.
+# In the example below we resume the status and do more 5 trials.
+
+
+def resume_search(task, log_file):
+    cost_model = auto_scheduler.XGBModel()
+    cost_model.update_from_file(log_file)
+    search_policy = auto_scheduler.SketchPolicy(
+        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
+    )
+    tune_option = auto_scheduler.TuningOptions(
+        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
+    )
+    sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options=tune_option)
+
+
+# resume_search(task, "matmul.json")
+
+######################################################################
+# .. note::
+#   We cannot run the line above because of the conflict between
+#   python's multiprocessing and tvm's thread pool.
+#   After running a tvm generated binary (L112), the python's multiprocessing
+#   library will hang forever.
+#   You have to make sure that you don't run any tvm generated binaries before
+#   calling ansor's search. To run the L156 above, you should comment out L112-114.
+#
+#   You should be careful about this problem in your applications.
+#   There are other workarounds for this problem.
+#   For example, you can start a new thread/process (with the builtin python library
+#   threading or multiprocessing) and run the tvm binaries in the new thread/process.
+#   This provides an isolation and avoids the conflict in the main thread/process.
index 38e3b33..a1d33ba 100644 (file)
@@ -1,4 +1,4 @@
 .. _tutorials-autotvm-sec:
 
-Auto tuning
------------
+AutoTVM : Template-based Auto Tuning
+------------------------------------