[Ansor][AutoTVM v2.0] Phase 2: Basic GPU Sketch Search Policy (#6269)
author Chenfan <chengfan.jcf@alibaba-inc.com>
Mon, 24 Aug 2020 10:46:31 +0000 (18:46 +0800)
committer GitHub <noreply@github.com>
Mon, 24 Aug 2020 10:46:31 +0000 (18:46 +0800)
* Add PreloadMeasuredStates & Split search_policy.py

* Add GPU sketch rule

* Update

* Bug fix for log record

* Lint fix

* Update tutorial

* Update

* UT fix

* Remove tutorial

* Update

* Update

* Update UT

* Lint fix

* Update

* Update

18 files changed:
include/tvm/auto_scheduler/search_policy.h
python/tvm/auto_scheduler/__init__.py
python/tvm/auto_scheduler/auto_schedule.py
python/tvm/auto_scheduler/loop_state.py
python/tvm/auto_scheduler/search_policy.py [new file with mode: 0644]
src/auto_scheduler/auto_schedule.cc
src/auto_scheduler/search_policy/search_policy.cc
src/auto_scheduler/search_policy/sketch_policy.cc
src/auto_scheduler/search_policy/sketch_policy.h
src/auto_scheduler/search_policy/sketch_policy_rules.cc
src/auto_scheduler/search_policy/sketch_policy_rules.h
src/auto_scheduler/search_policy/utils.cc
src/auto_scheduler/search_policy/utils.h
src/auto_scheduler/search_task.cc
src/auto_scheduler/transform_step.cc
src/auto_scheduler/utils.h
tests/python/unittest/test_auto_scheduler_search_policy.py
tests/python/unittest/test_auto_scheduler_sketch_generation.py

index 33a58aa..176b10c 100644 (file)
@@ -100,6 +100,35 @@ class SearchCallback : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
 };
 
+/*! \brief Preload measured states from a log file.
+ * This can resume the state of the search policy */
+class PreloadMeasuredStatesNode : public SearchCallbackNode {
+ public:
+  /*! \brief The name of the record log file. */
+  String filename;
+
+  void Callback(SearchPolicyNode* policy) final;
+
+  static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PreloadMeasuredStatesNode.
+ * \sa PreloadMeasuredStatesNode
+ */
+class PreloadMeasuredStates : public SearchCallback {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param filename The name of the record log file.
+   */
+  explicit PreloadMeasuredStates(String filename);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback,
+                                        PreloadMeasuredStatesNode);
+};
+
 /*! \brief Attribute keys of ops used for SearchPolicy. */
 struct SearchPolicyKey {
   /*! \brief Always apply unroll to the innermost iterator of the specified iterators. */
@@ -142,6 +171,12 @@ class SearchPolicyNode : public Object {
                        ProgramMeasurer measurer) = 0;
 
   /*!
+   * \brief Preload measured states from a log file to resume the state of the search policy.
+   * \param log_file The name of the record log file.
+   */
+  void PreloadMeasuredStates(const String& log_file);
+
+  /*!
    * \brief Call SearchCallback with the current SearchPolicyNode
    * \param callbacks SearchCallback to be called.
    */
index 8d262b9..9ad526c 100644 (file)
@@ -27,11 +27,12 @@ from . import feature
 
 # Shortcut
 from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
-    auto_schedule, EmptyPolicy, SketchPolicy
+    auto_schedule
 from .compute_dag import ComputeDAG
 from .cost_model import RandomModel, XGBModel
 from .measure import MeasureInput, MeasureResult, LocalBuilder, LocalRunner, RPCRunner, \
     LocalRPCMeasureContext
 from .measure_record import RecordToFile, RecordReader, load_best, \
     load_records, save_records
+from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
 from .workload_registry import register_workload, make_workload_key
index eb5a3fb..2942025 100644 (file)
@@ -28,12 +28,10 @@ uses cost model based evolutionary search to select schedules with the best perf
 Candidate schedules are measured against the specific hardware target.
 """
 
-import random
-
 import tvm._ffi
 from tvm.runtime import Object
 from .measure import LocalBuilder, LocalRunner
-from .cost_model import RandomModel
+from .search_policy import EmptyPolicy
 from . import _ffi_api
 
 
@@ -82,124 +80,6 @@ class SearchTask(Object):
                                             hardware_params)
 
 
-@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
-class SearchPolicy(Object):
-    """ The base class of search policies. """
-
-
-@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
-class EmptyPolicy(SearchPolicy):
-    """ This is an example empty search policy which will always generate
-    the init state of ComputeDAG.
-
-    Parameters
-    ----------
-    task : SearchTask
-        The SearchTask for the computation declaration.
-    init_search_callbacks : Optional[List[SearchCallback]]
-        Callback functions called before the search process.
-    """
-    def __init__(self, task, init_search_callbacks=None):
-        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)
-
-
-@tvm._ffi.register_object("auto_scheduler.SketchPolicy")
-class SketchPolicy(SearchPolicy):
-    """  The search policy that searches in a hierarchical search space defined by sketches.
-    The policy randomly samples programs from the space defined by sketches
-    and use evolutionary search to fine-tune them.
-
-    Parameters
-    ----------
-    task : SearchTask
-        The SearchTask for the computation declaration.
-    schedule_cost_model : CostModel = RandomModel()
-        The cost model to estimate the complete schedules.
-    params : Optional[Dict[str, Any]]
-        Parameters of the search policy.
-        See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions.
-        See `DEFAULT_PARAMS` below to find the default values.
-    seed : Optional[int]
-        Random seed.
-    verbose : int = 1
-        Verbosity level. 0 for silent, 1 to output information during schedule search.
-    init_search_callbacks : Optional[List[SearchCallback]]
-        Callback functions called before the search process, usually used to do extra
-        initializations.
-        Possible callbacks:
-            - auto_scheduler.PreloadMeasuredStates
-            - auto_scheduler.PreloadCustomSketchRule
-            TODO(jcf94): Add these search callback implementations.
-    """
-
-    DEFAULT_PARAMS = {
-        "eps_greedy": 0.05,
-
-        'evolutionary_search_population': 2048,
-        "evolutionary_search_use_measured_ratio": 0.2,
-
-        'cpu_multi_level_tiling_structure': 'SSRSRS',
-        'gpu_multi_level_tiling_structure': 'SSSRRSRS',
-
-        'max_innermost_split_factor': 16,
-        'max_vectorize_size': 16,
-
-        'disable_change_compute_location': 0,
-    }
-
-    def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
-                 init_search_callbacks=None):
-        if params is None:
-            params = SketchPolicy.DEFAULT_PARAMS
-        else:
-            for key, value in SketchPolicy.DEFAULT_PARAMS.items():
-                if key not in params:
-                    params[key] = value
-
-        self.__init_handle_by_constructor__(
-            _ffi_api.SketchPolicy, task, schedule_cost_model, params,
-            seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
-
-    def generate_sketches(self, print_for_debug=False):
-        """ Generate the sketches.
-        This python interface is mainly used for debugging and testing.
-        The actual search is all doen in c++.
-
-        Parameters
-        ----------
-        print_for_debug : bool = False
-            Whether print out the sketches for debug.
-
-        Returns
-        -------
-        sketches : List[State]
-            The generated sketches of this search task.
-        """
-        sketches = _ffi_api.SketchPolicyGenerateSketches(self)
-        if print_for_debug:
-            for i, s in enumerate(sketches):
-                print("=" * 20 + " %d " % i + "=" * 20)
-                print(s)
-        return sketches
-
-    def sample_initial_population(self, pop_size):
-        """Sample initial population.
-        This python interface is mainly used for debugging and testing.
-        The actual search is all doen in c++.
-
-        Parameters
-        ----------
-        pop_size : int
-            The size of sampled population
-
-        Returns
-        -------
-        states: List[State]
-            The sampled states
-        """
-        states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
-        return states
-
 @tvm._ffi.register_object("auto_scheduler.TuningOptions")
 class TuningOptions(Object):
     """ This controls the options of performance tuning.
index da3a4bf..e26c20f 100644 (file)
@@ -58,6 +58,13 @@ class Iterator(Object):
 class Stage(Object):
     """ A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """
 
+    # Static trans table for compute_at location
+    # This is used to transform the compute_at location to C++ enum
+    COMPUTE_AT_TRANS_TABLE = {
+        "root": 0,
+        "inlined": 1,
+        "iter": 2
+    }
 
 @tvm._ffi.register_object("auto_scheduler.State")
 class StateObject(Object):
@@ -85,7 +92,7 @@ class State:
     This is a wrapper class of StateObject to deal with copy-on-write property
     """
 
-    # Static trans table for thread bind
+    # Static trans table for thread bind and annotation
     # This is used to transform the annotation name to C++ enum
     ANNOTATION_TRANS_TABLE = {
         "none": 0,
diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py
new file mode 100644 (file)
index 0000000..278822e
--- /dev/null
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+The search policies for TVM Auto-scheduler.
+
+This contains the strategies to generate a schedule automatically. We provide an EmptyPolicy
+which always returns an unchanged initial state, and a more advanced SketchPolicy which can
+deal with various ops/subgraphs on different target devices.
+
+Reference:
+L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor
+Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+"""
+
+import random
+
+import tvm._ffi
+from tvm.runtime import Object
+from .cost_model import RandomModel
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchCallback")
+class SearchCallback(Object):
+    """Callback function before or after search process"""
+
+
+@tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates")
+class PreloadMeasuredStates(SearchCallback):
+    """ A SearchCallback to load measured states from the log file for a search policy.
+
+    This can resume the state of the search policy:
+        - Making sure an already measured state in former searches will never be measured again.
+        - The history states can be used to speed up the search process(e.g. SketchPolicy uses
+          history states as starting point to perform Evolutionary Search).
+
+    Parameters
+    ----------
+    filename : str
+        The name of the record file.
+    """
+    def __init__(self, filename="auto_scheduler_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
+class SearchPolicy(Object):
+    """ The base class of search policies. """
+
+
+@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+    """ This is an example empty search policy which will always generate
+    the init state of ComputeDAG.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The SearchTask for the computation declaration.
+    init_search_callbacks : Optional[List[SearchCallback]]
+        Callback functions called before the search process.
+    """
+    def __init__(self, task, init_search_callbacks=None):
+        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)
+
+
+@tvm._ffi.register_object("auto_scheduler.SketchPolicy")
+class SketchPolicy(SearchPolicy):
+    """  The search policy that searches in a hierarchical search space defined by sketches.
+    The policy randomly samples programs from the space defined by sketches and use evolutionary
+    search to fine-tune them.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The SearchTask for the computation declaration.
+    schedule_cost_model : CostModel = RandomModel()
+        The cost model to estimate the complete schedules.
+    params : Optional[Dict[str, Any]]
+        Parameters of the search policy.
+        See `src/auto_scheduler/search_policy/sketch_policy.h` for the definitions.
+        See `DEFAULT_PARAMS` below to find the default values.
+    seed : Optional[int]
+        Random seed.
+    verbose : int = 1
+        Verbosity level. 0 for silent, 1 to output information during schedule search.
+    init_search_callbacks : Optional[List[SearchCallback]]
+        Callback functions called before the search process, usually used to do extra
+        initializations.
+        Possible callbacks:
+            - auto_scheduler.PreloadMeasuredStates
+            - auto_scheduler.PreloadCustomSketchRule
+            TODO(jcf94): Add these search callback implementations.
+    """
+
+    DEFAULT_PARAMS = {
+        "eps_greedy": 0.05,
+        "retry_search_one_round_on_empty": 10,
+
+        'evolutionary_search_population': 2048,
+        "evolutionary_search_use_measured_ratio": 0.2,
+
+        'cpu_multi_level_tiling_structure': 'SSRSRS',
+        'gpu_multi_level_tiling_structure': 'SSSRRSRS',
+        # Notice: the default GPU thread-binding policy assumes the tiling structure has at
+        # least 3 spatial tiling levels at its outermost positions
+
+        'max_innermost_split_factor': 16,
+        'max_vectorize_size': 16,
+
+        'disable_change_compute_location': 0,
+    }
+
+    def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
+                 init_search_callbacks=None):
+        if params is None:
+            params = SketchPolicy.DEFAULT_PARAMS
+        else:
+            for key, value in SketchPolicy.DEFAULT_PARAMS.items():
+                if key not in params:
+                    params[key] = value
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.SketchPolicy, task, schedule_cost_model, params,
+            seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
+
+    def generate_sketches(self, print_for_debug=False):
+        """ Generate the sketches.
+        This Python interface is mainly used for debugging and testing.
+        The actual search is all done in C++.
+
+        Parameters
+        ----------
+        print_for_debug : bool = False
+            Whether to print out the sketches for debugging.
+
+        Returns
+        -------
+        sketches : List[State]
+            The generated sketches of this search task.
+        """
+        sketches = _ffi_api.SketchPolicyGenerateSketches(self)
+        if print_for_debug:
+            for i, s in enumerate(sketches):
+                print("=" * 20 + " %d " % i + "=" * 20)
+                print(s)
+        return sketches
+
+    def sample_initial_population(self, pop_size):
+        """Sample initial population.
+        This Python interface is mainly used for debugging and testing.
+        The actual search is all done in C++.
+
+        Parameters
+        ----------
+        pop_size : int
+            The size of the sampled population
+
+        Returns
+        -------
+        states: List[State]
+            The sampled states
+        """
+        states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
+        return states
index 4867959..dd6b705 100644 (file)
@@ -27,6 +27,8 @@
 #include <tvm/auto_scheduler/auto_schedule.h>
 #include <tvm/runtime/registry.h>
 
+#include "utils.h"
+
 namespace tvm {
 namespace auto_scheduler {
 
@@ -56,7 +58,16 @@ std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_poli
   State state =
       search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping,
                             tuning_options->num_measures_per_round, measurer);
-  return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);
+  if (state.defined()) {
+    return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);
+  } else {
+    StdCout(tuning_options->verbose)
+        << "No valid state found in this search round. Check if the search has already "
+        << "traversed all of the search space." << std::endl;
+    // Return the default schedule
+    return {te::Schedule(search_policy->search_task->compute_dag->ops),
+            search_policy->search_task->compute_dag->tensors};
+  }
 }
 
 TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")
index f21c8ae..723f8ee 100644 (file)
  * \brief The base class of search policies.
  */
 
+#include <tvm/auto_scheduler/measure_record.h>
 #include <tvm/auto_scheduler/search_policy.h>
 #include <tvm/runtime/registry.h>
 
+#include "utils.h"
+
 namespace tvm {
 namespace auto_scheduler {
 
 TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
 TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
+TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
+
+void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
+  RecordReader reader = RecordReader(log_file);
+  const auto& res = reader->ReadLines(-1);
+  size_t log_size = res.first.size();
+  CHECK_EQ(log_size, res.second.size());
+  if (log_size) {
+    Array<State> measured_states;
+    std::vector<float> measured_throughputs;
+    for (size_t i = 0; i < log_size; i++) {
+      const auto& inp = res.first[i];
+      if (inp->task->workload_key == search_task->workload_key &&
+          inp->task->target->kind->name.compare(search_task->target->kind->name) == 0) {
+        State state = search_task->compute_dag->init_state;
+        auto pstate = state.CopyOnWrite();
+        pstate->transform_steps = inp->state->transform_steps;
+        for (const auto& step : pstate->transform_steps) {
+          StepApplyToState(step, &state, search_task->compute_dag);
+        }
+        measured_states.push_back(std::move(state));
+        measured_throughputs.push_back(
+            res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
+      }
+    }
+    measured_states = search_task->compute_dag.InferBound(measured_states);
+    for (size_t i = 0; i < measured_states.size(); i++) {
+      auto& state = measured_states[i];
+      const auto& state_str = state.ToStr();
+      if (!measured_states_set_.count(state_str)) {
+        measured_states_set_.insert(state_str);
+        if (measured_throughputs[i] != 0.0) {
+          measured_states_vector_.emplace_back(std::move(state));
+          measured_states_throughputs_.emplace_back(measured_throughputs[i]);
+        }
+      }
+    }
+
+    StdCout(verbose) << "SearchPolicy: Loaded " << measured_states_set_.size()
+                     << " measurement records from " << log_file << " for "
+                     << search_task->workload_key << std::endl;
+  } else {
+    StdCout(verbose) << "SearchPolicy: No measurement records found in " << log_file << " for "
+                     << search_task->workload_key << std::endl;
+  }
+}
 
 void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
   for (const auto& callback : callbacks) {
@@ -37,6 +86,16 @@ void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
   }
 }
 
+PreloadMeasuredStates::PreloadMeasuredStates(String filename) {
+  auto node = make_object<PreloadMeasuredStatesNode>();
+  node->filename = std::move(filename);
+  data_ = std::move(node);
+}
+
+void PreloadMeasuredStatesNode::Callback(SearchPolicyNode* policy) {
+  policy->PreloadMeasuredStates(filename);
+}
+
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
     .set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) {
       if (callbacks) {
@@ -50,5 +109,9 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetTask")
 TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
     .set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.PreloadMeasuredStates").set_body_typed([](String filename) {
+  return PreloadMeasuredStates(filename);
+});
+
 }  // namespace auto_scheduler
 }  // namespace tvm
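The loader above filters records by workload key and target kind, replays the transform steps onto a fresh initial state, and converts each successful measurement into a throughput of 1 / mean(costs). A hedged Python analogue of that filtering, using the exported record API (load_records); task and log_file are the hypothetical names from the earlier sketch:

for inp, res in auto_scheduler.load_records(log_file):
    if inp.task.workload_key != task.workload_key:
        continue  # only records of the same workload are preloaded
    costs = [c.value for c in res.costs]
    throughput = 1.0 / (sum(costs) / len(costs)) if res.error_no == 0 else 0.0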
index c428cf7..51c138b 100644 (file)
@@ -49,9 +49,12 @@ static RuleSkipStage rule_skip_stage;
 static RuleAlwaysInline rule_always_inline;
 static RuleMultiLevelTiling rule_multi_level_tiling;
 static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
+static RuleAddCacheRead rule_add_cache_read_stage;
 static RuleAddCacheWrite rule_add_cache_write_stage;
 static RuleAddRfactor rule_add_rfactor;
+static RuleCrossThreadReduction rule_cross_thread_reduction;
 static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor;
+static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;
 
 /********** Init population rules **********/
 
@@ -60,6 +63,7 @@ static InitChangeComputeLocation init_change_compute_location;
 static InitParallel init_parallel;
 static InitUnroll init_unroll;
 static InitVectorization init_vectorization;
+static InitThreadBind init_thread_bind;
 
 /********** Sketch policy **********/
 
@@ -85,23 +89,45 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
     node->RunCallbacks(init_search_callbacks.value());
   }
 
-  // The default sketch rules for CPU policy
   // Notice: Some rules require us to skip all the rest rules after they are applied.
   // So the rules below should be ordered carefully.
-  node->sketch_rules.push_back(&rule_always_inline);
-  node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
-  node->sketch_rules.push_back(&rule_add_rfactor);
-  node->sketch_rules.push_back(&rule_add_cache_write_stage);
-  node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
-  node->sketch_rules.push_back(&rule_multi_level_tiling);
+  if (IsCPUTask(node->search_task)) {
+    // The default sketch rules for CPU policy
+    node->sketch_rules.push_back(&rule_always_inline);
+    node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
+    node->sketch_rules.push_back(&rule_add_rfactor);
+    node->sketch_rules.push_back(&rule_add_cache_write_stage);
+    node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
+    node->sketch_rules.push_back(&rule_multi_level_tiling);
+  } else if (IsCUDATask(node->search_task)) {
+    // The default sketch rules for CUDA policy
+    node->sketch_rules.push_back(&rule_add_cache_read_stage);
+    node->sketch_rules.push_back(&rule_always_inline);
+    node->sketch_rules.push_back(&rule_special_compute_location_gpu);
+    node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
+    node->sketch_rules.push_back(&rule_cross_thread_reduction);
+    node->sketch_rules.push_back(&rule_add_cache_write_stage);
+    node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
+    node->sketch_rules.push_back(&rule_multi_level_tiling);
+  } else {
+    LOG(FATAL) << "No default sketch rules for target: " << task->target;
+  }
   node->sketch_rules.push_back(&rule_skip_stage);  // This should always be the last rule
 
-  // The default init population rules for CPU policy
-  node->init_rules.push_back(&init_fill_tile_size);
-  node->init_rules.push_back(&init_change_compute_location);
-  node->init_rules.push_back(&init_parallel);
-  node->init_rules.push_back(&init_unroll);
-  node->init_rules.push_back(&init_vectorization);
+  node->init_rules.push_back(&init_fill_tile_size);  // This should always be the first rule
+  if (IsCPUTask(node->search_task)) {
+    // The default init population rules for CPU policy
+    node->init_rules.push_back(&init_change_compute_location);
+    node->init_rules.push_back(&init_parallel);
+    node->init_rules.push_back(&init_unroll);
+    node->init_rules.push_back(&init_vectorization);
+  } else if (IsCUDATask(node->search_task)) {
+    // The default init population rules for CUDA policy
+    node->init_rules.push_back(&init_thread_bind);
+    node->init_rules.push_back(&init_unroll);
+  } else {
+    LOG(FATAL) << "No default init rules for target: " << task->target;
+  }
 
   data_ = std::move(node);
 }
@@ -122,6 +148,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
     measurer->Reset();
 
     int ct = 0;
+    int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
     Array<MeasureInput> inputs;
     Array<MeasureResult> results;
     while (ct < n_trials) {
@@ -144,10 +171,19 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
       // Also pick some random states to do eps-greedy
       inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);
 
-      // Have traversed all of the search space
+      // Currently it is hard to detect whether all of the search space has been traversed.
+      // Stop if no extra valid states are found after several retries.
       if (inputs.empty()) {
-        StdCout(verbose) << "All candidates in the search space have been measured." << std::endl;
-        break;
+        if (empty_retry_count-- > 0) {
+          continue;
+        } else {
+          StdCout(verbose) << "It seems all candidates in the search space have been measured."
+                           << std::endl;
+          break;
+        }
+      } else {
+        // Reset the retry count
+        empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
       }
 
       // Measure candidate states
@@ -216,7 +252,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
 }
 
 Array<State> SketchPolicyNode::GenerateSketches() {
-  State init_state = search_task->compute_dag->init_state;
+  const State& init_state = search_task->compute_dag->init_state;
 
   // Two ping pong buffers to avoid copy
   Array<State> states_buf1{init_state}, states_buf2;
@@ -248,7 +284,7 @@ Array<State> SketchPolicyNode::GenerateSketches() {
             cur_stage_id_map[pair.first] = pair.second;
             pnext->push_back(pair.first);
           }
-          // Skip the reset rules
+          // Skip the rest rules
           if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) {
             break;
           }
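Rule selection above is driven purely by the task's target, so enabling the new GPU sketch rules only requires a CUDA target; the retry behavior added to Search() is controlled through the retry_search_one_round_on_empty key. A hedged sketch, reusing the hypothetical dag and workload_key from the earlier example (detecting default HardwareParams for CUDA assumes a GPU is present):

cuda_task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create("cuda"))
cuda_policy = auto_scheduler.SketchPolicy(
    cuda_task,
    params={"retry_search_one_round_on_empty": 5},  # other keys fall back to DEFAULT_PARAMS
)
cuda_sketches = cuda_policy.generate_sketches(print_for_debug=True)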
index 104e51e..0c1e6df 100644 (file)
@@ -50,6 +50,8 @@ namespace auto_scheduler {
 struct SketchParamKey {
   /*! \brief Always allocate this percentage of measurements to random sampled states. */
   static constexpr const char* eps_greedy = "eps_greedy";
+  /*! \brief Retry several times if SearchOneRound gets no valid state. */
+  static constexpr const char* empty_retry_count = "retry_search_one_round_on_empty";
 
   struct EvolutionarySearch {
     /*! \brief The population size for evolutionary search. */
index 587e2c7..92073b6 100644 (file)
@@ -38,7 +38,8 @@ namespace auto_scheduler {
 /********** RuleSkipStage **********/
 
 SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
-                                                                 const State& state, int stage_id) {
+                                                                 const State& state,
+                                                                 int stage_id) const {
   // This rule should be the last rule, always return true to decrease the stage index count
   return ConditionKind::kApply;
 }
@@ -52,7 +53,7 @@ std::vector<std::pair<State, int>> RuleSkipStage::Apply(const SketchPolicyNode&
 
 SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy,
                                                                     const State& state,
-                                                                    int stage_id) {
+                                                                    int stage_id) const {
   const Stage& stage = state->stages[stage_id];
   // Check the inline limitation of TE first
   if (stage->op_type == StageKind::kPlaceholder ||
@@ -60,8 +61,8 @@ SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const Sketch
     return ConditionKind::kSkip;
   }
 
-  // TODO(jcf94): Greedily inline all inlinable ops on GPU when introducing GPU search policy.
-  return IsStrictlyInlineable(policy.search_task, state, stage_id)
+  // Always do compute inline if it's strictly inlineable or is in GPU policy
+  return IsStrictlyInlineable(policy.search_task, state, stage_id) || IsGPUTask(policy.search_task)
              ? ConditionKind::kApplyAndSkipRest
              : ConditionKind::kSkip;
 }
@@ -76,7 +77,7 @@ std::vector<std::pair<State, int>> RuleAlwaysInline::Apply(const SketchPolicyNod
 /********** RuleMultiLevelTiling **********/
 
 SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
-    const SketchPolicyNode& policy, const State& state, int stage_id) {
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
   return NeedsMultilevelTiling(policy.search_task, state, stage_id)
              ? ConditionKind::kApplyAndSkipRest
              : ConditionKind::kSkip;
@@ -85,9 +86,10 @@ SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
 std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy,
                                                                const State& state,
                                                                int stage_id) const {
-  // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
   const std::string& multi_level_tiling_structure =
-      GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+      IsGPUTask(policy.search_task)
+          ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
+          : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
   State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure);
   return {std::make_pair(std::move(tmp_s), stage_id - 1)};
 }
@@ -95,29 +97,32 @@ std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolic
 /********** RuleMultiLevelTilingWithFusion **********/
 
 SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition(
-    const SketchPolicyNode& policy, const State& state, int stage_id) {
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
   if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
-      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id)) {
-    // Always do fusion for stage with cache_write
-    // TODO(jcf94): Always do fusion on GPU when introducing GPU search policy.
-    return HasCacheWriteStage(state, stage_id) ? ConditionKind::kApplyAndSkipRest
-                                               : ConditionKind::kApply;
+      HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
+    // Always do fusion for stage with cache_write or is in GPU policy
+    return HasCacheWriteStage(state, stage_id) || IsGPUTask(policy.search_task)
+               ? ConditionKind::kApplyAndSkipRest
+               : ConditionKind::kApply;
   }
   return ConditionKind::kSkip;
 }
 
 std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
     const SketchPolicyNode& policy, const State& state, int stage_id) const {
-  // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
+  int target_stage_id;
+  CHECK(HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id));
   const std::string& multi_level_tiling_structure =
-      GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+      IsGPUTask(policy.search_task)
+          ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
+          : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
   std::vector<int> spatial_split_step_ids;
   State base_state =
       DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);
 
   std::vector<std::pair<State, int>> ret;
-  // TODO(jcf94): Add follow_tiling_levels for GPU when introducing GPU search policy.
-  std::vector<int> follow_tiling_levels{1, 2};
+  std::vector<int> follow_tiling_levels =
+      IsGPUTask(policy.search_task) ? std::vector<int>{3} : std::vector<int>{1, 2};
   for (int level : follow_tiling_levels) {
     if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
       continue;
@@ -133,18 +138,67 @@ std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
   return ret;
 }
 
+/********** RuleAddCacheRead **********/
+
+SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const SketchPolicyNode& policy,
+                                                                    const State& state,
+                                                                    int stage_id) const {
+  const SearchTask& task = policy.search_task;
+
+  // Don't cache_read a stage if it has multiple consumers
+  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+  if (consumers.size() != 1) {
+    return ConditionKind::kSkip;
+  }
+
+  // Don't cache_read a stage if its consumer does not need multi-level tiling
+  int target_stage_id = *consumers.begin();
+  if (!NeedsMultilevelTiling(task, state, target_stage_id)) {
+    return ConditionKind::kSkip;
+  }
+
+  // Don't cache_read a stage if its consumer does cross-thread reduction
+  if (HasCrossThreadReduction(state, target_stage_id)) {
+    return ConditionKind::kSkip;
+  }
+
+  // Only direct producers can be cache read
+  const std::set<int>& producers = GetDirectProducers(task, state, target_stage_id);
+  if (producers.find(stage_id) == producers.end()) {
+    return ConditionKind::kSkip;
+  }
+
+  return ConditionKind::kApplyAndSkipRest;
+}
+
+std::vector<std::pair<State, int>> RuleAddCacheRead::Apply(const SketchPolicyNode& policy,
+                                                           const State& state, int stage_id) const {
+  const SearchTask& task = policy.search_task;
+  const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+  CHECK_EQ(consumers.size(), 1);
+  int target_stage_id = *consumers.begin();
+  State tmp_s = state;
+
+  // Add a cache read stage in shared memory
+  int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, task->compute_dag);
+  target_stage_id++;
+  const auto& share_read_pos =
+      GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
+  tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+  return {std::make_pair(tmp_s, stage_id)};
+}
+
 /********** RuleAddCacheWrite **********/
 
 SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy,
                                                                      const State& state,
-                                                                     int stage_id) {
+                                                                     int stage_id) const {
   // Add cache write if a stage needs multi-level tiling, but does not have an element-wise
   // matched consumer
   if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
       !HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
     // An apply and skip rule will be handled in RuleMultiLevelTilingWithFusion
-    // TODO(jcf94): Always do cache_write on GPU when introducing GPU search policy.
-    return ConditionKind::kApply;
+    return IsGPUTask(policy.search_task) ? ConditionKind::kApplyAndSkipRest : ConditionKind::kApply;
   }
 
   return ConditionKind::kSkip;
@@ -162,7 +216,7 @@ std::vector<std::pair<State, int>> RuleAddCacheWrite::Apply(const SketchPolicyNo
 
 SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy,
                                                                   const State& state,
-                                                                  int stage_id) {
+                                                                  int stage_id) const {
   return (NeedsRfactor(policy.search_task, state, stage_id) && !HasCacheWriteStage(state, stage_id))
              ? ConditionKind::kApply
              : ConditionKind::kSkip;
@@ -211,7 +265,7 @@ std::vector<std::pair<State, int>> RuleAddRfactor::Apply(const SketchPolicyNode&
 /********** RuleSimplifyComputeWithConstTensor **********/
 
 SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition(
-    const SketchPolicyNode& policy, const State& state, int stage_id) {
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
   return state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)
              ? ConditionKind::kApplyAndSkipRest
              : ConditionKind::kSkip;
@@ -254,6 +308,132 @@ std::vector<std::pair<State, int>> RuleSimplifyComputeWithConstTensor::Apply(
   return {std::make_pair(tmp_s, stage_id - 1)};
 }
 
+/********** RuleCrossThreadReduction **********/
+
+SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
+  CHECK(IsGPUTask(policy.search_task));
+
+  // If it is an intermediate state created by RuleAddCacheWrite,
+  // we just skip it.
+  if (HasCacheWriteStage(state, stage_id)) {
+    return ConditionKind::kSkip;
+  }
+
+  const auto& op = state->stages[stage_id]->op;
+  if (op->IsInstance<te::ComputeOpNode>()) {
+    // Compute the product of lengths of all space iters and all reduce iters
+    int cum_space_len, cum_reduce_len;
+    std::tie(cum_space_len, cum_reduce_len) =
+        GetCumulativeSpaceAndReductionLengh(state->stages[stage_id]);
+
+    if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
+      // Apply cross-thread reduction if we do not have enough parallelism on space iters
+      return cum_space_len < cum_reduce_len ? ConditionKind::kApply : ConditionKind::kSkip;
+    } else if (cum_reduce_len > 1) {
+      // Try cross-thread reduction for other reduction operators
+      return cum_reduce_len > policy.search_task->hardware_params->warp_size ? ConditionKind::kApply
+                                                                             : ConditionKind::kSkip;
+    }
+  }
+
+  return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleCrossThreadReduction::Apply(const SketchPolicyNode& policy,
+                                                                   const State& state,
+                                                                   int stage_id) const {
+  const SearchTask& task = policy.search_task;
+  State tmp_s = state;
+
+  // fuse all reduction iters
+  Array<Iterator> space_iters, reduce_iters;
+  Iterator fused_reduce_iter;
+  tmp_s =
+      FuseAllReductionIterators(tmp_s, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);
+
+  // Check the opportunity for kernel fusion
+  bool fusible = false;
+  int target_stage_id = GetSingleConsumerId(policy.search_task, tmp_s, stage_id);
+  int num_common_outer = -1;
+  if (target_stage_id >= 0) {
+    num_common_outer =
+        GetNumCommonOuterIterator(policy.search_task, tmp_s, stage_id, target_stage_id);
+    if (num_common_outer > 0 &&
+        !NeedsMultilevelTiling(policy.search_task, state, target_stage_id)) {
+      fusible = true;
+    }
+  }
+
+  if (fusible) {
+    const Stage& target_stage = state->stages[target_stage_id];
+    std::vector<int> split_step_ids;
+
+    GetSplitStepIds(tmp_s, target_stage_id, &split_step_ids);
+
+    if (split_step_ids.size() == 0) {
+      // If the target stage does not have split step,
+      // it must be a simple stage without reduce iters.
+      // We then should do a split for it.
+      CHECK(!HasReduceIter(target_stage));
+      const auto& split_res = tmp_s.split(target_stage_id, target_stage->iters.back(),
+                                          {Integer(task->hardware_params->warp_size)});
+      tmp_s.bind(target_stage_id, split_res[1], IteratorAnnotation::kThreadX);
+      split_step_ids.push_back(tmp_s->transform_steps.size() - 2);
+    }
+
+    CHECK_EQ(split_step_ids.size(), 1);
+
+    const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[num_common_outer - 1];
+    const auto& split_res = tmp_s.follow_split(stage_id, fused_reduce_iter, split_step_ids[0], 1);
+    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
+    tmp_s.compute_at(stage_id, target_stage_id, target_iter);
+  } else {
+    const auto& split_res =
+        tmp_s.split(stage_id, fused_reduce_iter, {Integer(task->hardware_params->warp_size)});
+    tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
+  }
+
+  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
+/********** RuleSpecialComputeLocationGPU **********/
+
+SketchGenerationRule::ConditionKind RuleSpecialComputeLocationGPU::MeetCondition(
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
+  if (GetProducers(policy.search_task, state, stage_id).empty()) {
+    return ConditionKind::kSkip;
+  }
+
+  const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
+  if (consumers.size() == 1 && state->stages[*consumers.begin()]->op->attrs.count(
+                                   SearchPolicyKey::simplify_const_tensor_indices)) {
+    return ConditionKind::kApplyAndSkipRest;
+  }
+
+  return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
+    const SketchPolicyNode& policy, const State& state, int stage_id) const {
+  State tmp_s = state;
+  const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
+  CHECK_EQ(consumers.size(), 1);
+
+  // Get the last outer space iterator that is not unrolled.
+  const Stage& target_stage = state->stages[*consumers.begin()];
+  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
+    if (target_stage->iters[i]->annotation == IteratorAnnotation::kUnroll) {
+      CHECK_GT(i, 0);
+
+      tmp_s.compute_at(stage_id, *consumers.begin(), target_stage->iters[i - 1]);
+      break;
+    }
+  }
+
+  return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
 /********** Init Population **********/
 
 InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
@@ -473,7 +653,9 @@ InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, Sta
 }
 
 InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const {
-  std::vector<int> auto_unroll_configs = {0, 16, 64, 512};
+  std::vector<int> auto_unroll_configs = IsGPUTask(policy->search_task)
+                                             ? std::vector<int>({0, 16, 64, 512, 1024})
+                                             : std::vector<int>({0, 16, 64, 512});
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     const Stage& stage = (*state)->stages[stage_id];
     // Skip the inlined stage and placeholder stage
@@ -580,5 +762,155 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy
   return ResultKind::kValid;
 }
 
+InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const {
+  std::set<int> multi_level_tiling_root_set;
+  for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+    if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
+      const Stage& stage = (*state)->stages[stage_id];
+      if (stage->compute_at != ComputeAtKind::kIter) {
+        // This stage is not multi-level tiled,
+        // so it must be produced by RuleCrossThreadReduction.
+        CHECK(HasCrossThreadReduction(*state, stage_id));
+      } else {
+        const auto res = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
+        CHECK(res != (*state)->attach_map->stage_to_attach_iter.end());
+        multi_level_tiling_root_set.insert(res->second.first);
+      }
+    }
+  }
+
+  *state = policy->search_task->compute_dag.InferBound(*state);
+
+  for (int stage_id = (*state)->stages.size() - 1; stage_id >= 0; --stage_id) {
+    const Stage& stage = (*state)->stages[stage_id];
+
+    if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
+      continue;
+    }
+
+    // Deal with the cross-thread reduction generated by RuleCrossThreadReduction
+    if (HasCrossThreadReduction(*state, stage_id)) {
+      Iterator fused_it;
+      *state = std::move(FuseAllOuterSpaceIterators(*state, stage_id, &fused_it));
+      state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX);
+      continue;
+    }
+
+    // Skip if this stage has already been annotated with threadIdx.x
+    if (HasAnnotatedIter(stage, IteratorAnnotation::kThreadX)) {
+      continue;
+    }
+
+    if (stage->compute_at == ComputeAtKind::kRoot) {
+      // This stage has not been tiled, but in a GPU schedule we must tile the root stage
+      // to do thread binding
+      if (!multi_level_tiling_root_set.count(stage_id)) {
+        Iterator fused_it;
+        *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
+
+        if (GetExtent(fused_it) <= policy->search_task->hardware_params->warp_size) {
+          state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
+        } else {
+          // Set threadIdx.x = default_warp_size by default.
+          // The later EvolutionarySearch will try more possibilities
+          const auto& split_its = state->split(
+              stage_id, fused_it, {Integer(policy->search_task->hardware_params->warp_size)});
+          state->bind(stage_id, split_its[0], IteratorAnnotation::kBlockX);
+          state->bind(stage_id, split_its[1], IteratorAnnotation::kThreadX);
+        }
+        continue;
+      }
+
+      // Otherwise, this is a tiled root stage; we assume it is tiled with 3 space levels
+      // in its outer iterators.
+      // The remaining part deals with thread binding for multi-level tiled stages
+      auto pop = stage->op.as<te::ComputeOpNode>();
+      std::vector<Iterator> to_fuse;
+      int total_space_extent = 1;
+      for (const auto& i : pop->root_iter_vars()) {
+        CHECK(i->dom.defined());
+        const auto& pint = i->dom->extent.as<IntImmNode>();
+        CHECK(pint);
+        total_space_extent *= pint->value;
+      }
+
+      // Check if the total space extent is too small for multi-level thread binding
+      if (total_space_extent <= policy->search_task->hardware_params->warp_size) {
+        Iterator fused_it;
+        *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
+        state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
+        continue;
+      }
+
+      // Fuse the outermost space tile as blockIdx
+      for (size_t i = 0; i < pop->axis.size(); i++) {
+        const auto& it = (*state)->stages[stage_id]->iters[i];
+        // There may be some iterators that are marked with no split; stop when we reach the
+        // next tiling level
+        if (!StrEndsWith(it->name, ".0")) {
+          break;
+        }
+        to_fuse.push_back(it);
+      }
+      const auto& blockidx_it = state->fuse(stage_id, to_fuse);
+      state->bind(stage_id, blockidx_it, IteratorAnnotation::kBlockX);
+
+      // Fuse the second outermost space tile as vthread
+      to_fuse.clear();
+      for (size_t i = 1; i < pop->axis.size() + 1; i++) {
+        const auto& it = (*state)->stages[stage_id]->iters[i];
+        // There may be some iterators that are marked with no split; stop when we reach the
+        // next tiling level
+        if (!StrEndsWith(it->name, ".1")) {
+          break;
+        }
+        to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+      }
+      const auto& vthread_it = state->fuse(stage_id, to_fuse);
+      if (GetExtent(vthread_it) > policy->search_task->hardware_params->max_vthread_extent) {
+        return ResultKind::kInvalid;
+      }
+      state->bind(stage_id, vthread_it, IteratorAnnotation::kVThread);
+
+      // Fuse the third outermost space tile as threadIdx
+      to_fuse.clear();
+      for (size_t i = 2; i < pop->axis.size() + 2; i++) {
+        const auto& it = (*state)->stages[stage_id]->iters[i];
+        // There may be some iterators that are marked with no split; stop when we reach the
+        // next tiling level
+        if (!StrEndsWith(it->name, ".2")) {
+          break;
+        }
+        to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+      }
+      const auto& threadidx_it = state->fuse(stage_id, to_fuse);
+      if (GetExtent(threadidx_it) < policy->search_task->hardware_params->warp_size) {
+        return ResultKind::kInvalid;
+      }
+      state->bind(stage_id, threadidx_it, IteratorAnnotation::kThreadX);
+    } else if (stage->compute_at == ComputeAtKind::kIter &&
+               StrEndsWith(stage->op->name, ".shared")) {
+      // Do cooperative fetching for the cache read stage.
+      // Get spatial_split_step_ids from the root stage
+      const auto& it = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
+      CHECK(it != (*state)->attach_map->stage_to_attach_iter.end());
+      Array<Integer> spatial_split_step_ids = GetSpatialSplitStepIds(*state, it->second.first);
+
+      // Fuse all iterators to do cooperative fetching
+      Iterator fused = state->fuse(stage_id, (*state)->stages[stage_id]->iters);
+      // Split out an extra iterator for vectorization
+      // The later EvolutionarySearch will try more possibilities
+      const auto& iters0 = state->split(stage_id, fused, {Integer(1)});
+      state->vectorize(stage_id, iters0[1]);
+      // Follow the split to keep the same thread extent as the root stage
+      const auto& iters1 =
+          state->follow_fused_split(stage_id, iters0[0], spatial_split_step_ids, 1, true);
+      state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX);
+    }
+  }
+
+  return ResultKind::kValid;
+}
+
 }  // namespace auto_scheduler
 }  // namespace tvm
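For reference, the binding that InitThreadBind derives for a tiled root stage corresponds to the following hand-written TE pattern: fuse the outermost spatial tiles into blockIdx.x, the second level into vthread, and the third into threadIdx.x. This is a hedged sketch with arbitrary shapes and tile factors, not values the policy itself would choose:

import tvm
from tvm import te

A = te.placeholder((1024, 1024), name="A")
B = te.compute((1024, 1024), lambda i, j: A[i, j] * 2.0, name="B")
s = te.create_schedule(B.op)
i, j = s[B].op.axis
fused = s[B].fuse(i, j)
bx, rest = s[B].split(fused, factor=4096)  # outermost level -> blockIdx.x
vt, tx = s[B].split(rest, factor=128)      # second level -> vthread, third -> threadIdx.x
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(vt, te.thread_axis("vthread"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))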
index dac186d..5ddfd18 100644 (file)
@@ -59,7 +59,7 @@ class SketchGenerationRule {
    * \return The condition check result of this rule.
    */
   virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                                      int stage_id) = 0;
+                                      int stage_id) const = 0;
 
   /*!
    * \brief Apply function of this rule.
@@ -73,84 +73,51 @@ class SketchGenerationRule {
                                                    const State& state, int stage_id) const = 0;
 };
 
+#define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                                 \
+  class rule_name : public SketchGenerationRule {                                                \
+   public:                                                                                       \
+    ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,              \
+                                int stage_id) const final;                                       \
+    std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, \
+                                             int stage_id) const final;                          \
+  };
+
 /*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves to
  * the next stage. */
-class RuleSkipStage : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
-
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage);
 
 /*! \brief The rule that inlines simple elementwise ops.
  * \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
  * inlineable will have a chance to try different compute at location in InitPopulation later.
  */
-class RuleAlwaysInline : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
-
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline);
 
 /*! \brief The rule that performs multi-level tiling. */
-class RuleMultiLevelTiling : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
+DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling);
 
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
+DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion);
 
-/*! The rule that performs multi-level tiling and fuses later consumers. */
-class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
-
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-
- private:
-  int target_stage_id;
-};
+/*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching;
+ * currently only supports 1-to-1 matched cache reads. */
+DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead);
 
 /*! \brief The rule that adds a cache write stage. */
-class RuleAddCacheWrite : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
-
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite);
 
 /*! \brief The rule that adds rfactor stage. */
-class RuleAddRfactor : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
-
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor);
 
 /*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
- * This kind of op comes from winograd transformation.
- */
-class RuleSimplifyComputeWithConstTensor : public SketchGenerationRule {
- public:
-  ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
-                              int stage_id) final;
+ * This kind of op comes from the Winograd transformation. */
+DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor);
 
-  std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
-                                           int stage_id) const final;
-};
+/*! \brief The rule that uses cross thread reduction for GPU. */
+DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
+
+/*! \brief Handle special cases in Winograd transformation for GPU. We need to change the compute
+ * location of the producers of compute ops that perform "fake reduction" with const tensors. */
+DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
 
 /********** Init Population **********/
 
@@ -170,36 +137,30 @@ class InitPopulationRule {
   virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
 };
 
+#define DEFINE_INIT_POPULATION_RULE(rule_name)                            \
+  class rule_name : public InitPopulationRule {                           \
+   public:                                                                \
+    ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
+  };
+
 /*! \brief The rule that fills the incomplete SplitSteps. */
-class InitFillTileSize : public InitPopulationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitFillTileSize);
 
 /*! \brief The rule that randomly changes the computation location for some stages, which do not
  * need tiling and are not strictly inlineable (e.g. data padding). */
-class InitChangeComputeLocation : public InitPopulationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);
 
 /*! \brief The rule that annotates parallel for CPU. */
-class InitParallel : public InitPopulationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitParallel);
 
 /*! \brief The rule that annotates unroll. */
-class InitUnroll : public InitPopulationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitUnroll);
 
 /*! \brief The rule that annotates vectorization. */
-class InitVectorization : public InitPopulationRule {
- public:
-  ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitVectorization);
+
+/*! \brief The rule that annotates thread binding for GPU. */
+DEFINE_INIT_POPULATION_RULE(InitThreadBind);
 
 }  // namespace auto_scheduler
 }  // namespace tvm
index 6c2e68d..b3f07b1 100644 (file)
 namespace tvm {
 namespace auto_scheduler {
 
+Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
+  const auto& stage = s->stages[stage_id];
+  const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  const std::set<std::string>& no_split_at_inner_name_set =
+      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+          : std::set<std::string>();
+  size_t reduce_count = 0;
+  for (const auto axis : pop->reduce_axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      reduce_count++;
+    }
+  }
+
+  Array<Integer> spatial_split_step_ids;
+  for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        // Assume SplitSteps on reduction axes always come after SplitSteps on spatial axes.
+        if (reduce_count) {
+          reduce_count--;
+        } else {
+          spatial_split_step_ids.push_back(i);
+        }
+      }
+    }
+  }
+
+  return spatial_split_step_ids;
+}
+
 State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
                          std::vector<int>* spatial_split_step_ids) {
   // Temporal object to be used if the input pointer is nullptr
@@ -282,5 +320,22 @@ const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {
   return res;
 }
 
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
+    .set_body_typed([](const Stage& stage) { return IsTiled(stage); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
+    .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
+    .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
+    .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
+    .set_body_typed([](const State& s, int stage_id) {
+      return HasCrossThreadReduction(s, stage_id);
+    });
+
 }  // namespace auto_scheduler
 }  // namespace tvm
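The registrations above expose the new search-policy predicates to Python through _ffi_api, which is how the sketch-generation unit tests later in this change inspect sketch structure. A minimal usage sketch, assuming "state" is any auto_scheduler State (e.g. one element returned by SketchPolicy.generate_sketches()):

    from tvm.auto_scheduler import _ffi_api

    # Stage ids index into state.stages; these calls mirror the assert helpers
    # added to tests/python/unittest/test_auto_scheduler_sketch_generation.py.
    if _ffi_api.SearchPolicyUtilsHasCacheReadStage(state, 2):
        print("stage 2 is read through a cache stage")
    if _ffi_api.SearchPolicyUtilsHasCrossThreadReduction(state, 2):
        print("stage 2 uses cross thread reduction")
    if _ffi_api.SearchPolicyUtilsIsTiled(state.stages[2]):
        print("stage 2 has been multi-level tiled")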
index 6814d25..2d49ab0 100644 (file)
 namespace tvm {
 namespace auto_scheduler {
 
+/*! \brief Return whether the search task is targeting a CPU. */
+inline bool IsCPUTask(const SearchTask& task) {
+  return (task)->target->kind->device_type == kDLCPU;
+}
+
+/*! \brief Return whether the search task is targeting a GPU. */
+inline bool IsGPUTask(const SearchTask& task) {
+  return (task)->target->kind->device_type == kDLGPU ||
+         (task)->target->kind->device_type == kDLOpenCL ||
+         (task)->target->kind->device_type == kDLVulkan ||
+         (task)->target->kind->device_type == kDLMetal ||
+         (task)->target->kind->device_type == kDLROCM ||
+         (task)->target->kind->device_type == kOpenGL;
+}
+
+/*! \brief Return whether the search task is targeting a CUDA GPU. */
+inline bool IsCUDATask(const SearchTask& task) {
+  return (task)->target->kind->device_type == kDLGPU;
+}
+
+/*! \brief Return whether the search task is targeting an OpenCL GPU. */
+inline bool IsOpenCLTask(const SearchTask& task) {
+  return (task)->target->kind->device_type == kDLOpenCL;
+}
+
 /*! \brief Argsort. Order: largest to smallest */
 template <typename T>
 inline std::vector<int> Argsort(const std::vector<T>& scores) {
@@ -354,6 +379,26 @@ inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, const St
   return false;
 }
 
+/*! \brief Return whether the state does cache_read for stage_id. */
+inline bool HasCacheReadStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        return true;
+      }
+    }
+
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
 /*! \brief Return whether the state does cache_write for stage_id. */
 inline bool HasCacheWriteStage(const State& s, int stage_id) {
   for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
@@ -374,6 +419,59 @@ inline bool HasCacheWriteStage(const State& s, int stage_id) {
   return false;
 }
 
+/*! \brief Return whether the state does rfactor for stage_id. */
+inline bool HasRfactorStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        return true;
+      }
+    }
+
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+/*! \brief Return whether the stage does cross thread reduction. */
+inline bool HasCrossThreadReduction(const State& state, int stage_id) {
+  std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
+    for (const auto& iter : in_stage->iters) {
+      if (iter->annotation == IteratorAnnotation::kThreadX &&
+          iter->iter_kind == IteratorKind::kReduction) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Check the stage itself
+  if (check_stage(state->stages[stage_id])) {
+    return true;
+  }
+
+  // Check the attached stages
+  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
+    const auto& res =
+        state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
+    if (res != state->attach_map->iter_to_attached_stages.end()) {
+      for (int attached_stage_id : res->second) {
+        if (check_stage(state->stages[attached_stage_id])) {
+          return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
 /*! \brief Return whether the stage has been tiled already. */
 inline bool IsTiled(const Stage& stage) {
   auto op = stage->op.as<te::ComputeOpNode>();
@@ -399,6 +497,63 @@ inline void ExtractOriginalIterators(const std::string& name, std::set<std::stri
   }
 }
 
+/*! \brief Get the last reduce iterator in the outermost reduce tile. */
+inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
+  auto pop = stage->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  std::set<std::string> original_names;
+
+  const std::set<std::string>& no_split_at_inner_name_set =
+      stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+          ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+          : std::set<std::string>();
+  size_t reduce_axis_size = 0;
+  for (const auto axis : pop->reduce_axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      reduce_axis_size++;
+    }
+  }
+  if (reduce_axis_size) {
+    for (const auto& iter : stage->iters) {
+      if (iter->iter_kind == IteratorKind::kReduction) {
+        ExtractOriginalIterators(iter->name, &original_names);
+        if (original_names.size() == reduce_axis_size) {
+          return iter;
+        }
+      }
+    }
+  } else {
+    // Return the first reduce iterator
+    for (const auto& iter : stage->iters) {
+      if (iter->iter_kind == IteratorKind::kReduction) {
+        return iter;
+      }
+    }
+  }
+
+  LOG(FATAL) << "Cannot find the iterator.";
+  return stage->iters[0];
+}
+
+/*! \brief Get all split steps for one stage. */
+inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
+      if (stage_id == ps->stage_id) {
+        split_step_ids->push_back(i);
+      }
+    }
+
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+}
+
 /*! \brief Fuse all reduction iterators. */
 inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
                                        Array<Iterator>* space_iters,
@@ -424,6 +579,32 @@ inline State FuseAllReductionIterators(const State& state, int stage_id, Iterato
   return tmp_s;
 }
 
+/*! \brief Fuse all outer level space iterators. */
+inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
+  std::vector<Iterator> to_fuse;
+  for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
+    const auto& it = state->stages[stage_id]->iters[iter_id];
+    // Stop at reduce iterator or annotated iterator
+    if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
+      break;
+    }
+    // Stop at compute_at attach point
+    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
+      break;
+    }
+    to_fuse.push_back(it);
+  }
+
+  CHECK(!to_fuse.empty());
+  State tmp_s = state;
+  if (to_fuse.size() > 1) {
+    *fused_iter = tmp_s.fuse(stage_id, to_fuse);
+  } else {
+    *fused_iter = to_fuse[0];
+  }
+  return tmp_s;
+}
+
 /*! \brief Random sample states. */
 inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
                                        size_t out_size) {
@@ -464,6 +645,9 @@ class SplitFactorizationMemo {
   std::unordered_map<int, std::vector<int>> factor_memory_;
 };
 
+/*! \brief Get the indices of the SplitSteps that process spatial iterators. */
+Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);
+
 // Apply multi-level tiling structure according to a string format,
 // where "S" stands a space level, "R" stands for a reudciton level.
 // For example, if the format is "SSRSRS", the we will
index e632d4e..e3f35e9 100644 (file)
@@ -23,6 +23,7 @@
  */
 
 #include <tvm/auto_scheduler/search_task.h>
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/threading_backend.h>
 
@@ -44,8 +45,33 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l
 
 HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                             const Target& target_host) {
-  if (target->kind->name == "llvm") {
+  if (target->kind->device_type == kDLCPU) {
     return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
+  } else if (target->kind->device_type == kDLGPU) {
+    auto hardware_params = HardwareParams(-1, 16, 64);
+    auto* p_hardware_params = hardware_params.CopyOnWrite();
+
+    auto ctx = TVMContext{kDLGPU, 0};
+    auto func = tvm::runtime::Registry::Get("device_api.gpu");
+    CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
+    p_hardware_params->max_shared_memory_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
+    p_hardware_params->max_registers_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
+    p_hardware_params->max_threads_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+    p_hardware_params->warp_size = ret;
+
+    p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
+
+    return hardware_params;
   } else {
     LOG(FATAL) << "No default hardware parameters for target: " << target;
   }
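For reference, the same device attributes queried through DeviceAPI above are also visible from Python via the device context; a hedged sketch, assuming a locally visible CUDA GPU:

    import tvm

    ctx = tvm.gpu(0)
    if ctx.exist:
        # These mirror kMaxSharedMemoryPerBlock, kMaxThreadsPerBlock and
        # kWarpSize queried in GetDefaultHardwareParams above.
        print("max shared memory per block:", ctx.max_shared_memory_per_block)
        print("max threads per block:", ctx.max_threads_per_block)
        print("warp size:", ctx.warp_size)
        # The default max_vthread_extent chosen above is warp_size / 4.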
index e533a7c..02a28b1 100644 (file)
 
 #include "utils.h"
 
+namespace dmlc {
+namespace json {
+
+template <>
+struct Handler<::tvm::Array<::tvm::Integer>> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::Integer>& array) {
+    writer->BeginArray(false);
+    for (const auto& i : array) {
+      CHECK(i.defined());
+      writer->WriteArrayItem(i->value);
+    }
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::Integer>* array) {
+    array->clear();
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      int value;
+      Handler<int>::Read(reader, &value);
+      array->push_back(value);
+    }
+  }
+};
+
+template <>
+struct Handler<::tvm::Array<::tvm::Optional<::tvm::Integer>>> {
+  inline static void Write(dmlc::JSONWriter* writer,
+                           const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& array) {
+    writer->BeginArray(false);
+    for (const auto& i : array) {
+      CHECK(i);
+      writer->WriteArrayItem(i.value()->value);
+    }
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader,
+                          ::tvm::Array<::tvm::Optional<::tvm::Integer>>* array) {
+    array->clear();
+    reader->BeginArray();
+    while (reader->NextArrayItem()) {
+      int value;
+      Handler<int>::Read(reader, &value);
+      array->push_back(::tvm::Integer(value));
+    }
+  }
+};
+
+}  // namespace json
+}  // namespace dmlc
+
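With these handlers, Array<Integer> and Array<Optional<Integer>> fields serialize directly as JSON integer arrays, so the on-disk record layout matches what the removed IntArrayToVector path produced. A hedged illustration of the resulting FuseStep record shape; the "FU" prefix is illustrative only, the real value comes from record_prefix_str:

    import json

    # A FuseStep record is written as [prefix, stage_id, fused_ids]; the fused
    # ids now round-trip as a plain JSON integer array.
    prefix, stage_id, fused_ids = json.loads('["FU", 2, [0, 1, 2]]')
    assert fused_ids == [0, 1, 2]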
 namespace tvm {
 namespace auto_scheduler {
 
@@ -371,15 +421,9 @@ FuseStep::FuseStep(dmlc::JSONReader* reader) {
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->stage_id);
-  std::vector<int> int_list;
   s = reader->NextArrayItem();
   CHECK(s);
-  reader->Read(&int_list);
-  ::tvm::Array<::tvm::Integer> fused_ids;
-  for (const auto& i : int_list) {
-    fused_ids.push_back(i);
-  }
-  node->fused_ids = fused_ids;
+  reader->Read(&node->fused_ids);
   data_ = std::move(node);
 }
 
@@ -387,7 +431,7 @@ void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArraySeperator();
   writer->WriteString(record_prefix_str);
   writer->WriteArrayItem(stage_id);
-  writer->WriteArrayItem(IntArrayToVector(fused_ids));
+  writer->WriteArrayItem(fused_ids);
 }
 
 Iterator FuseStepNode::ApplyToState(State* state) const {
@@ -638,13 +682,7 @@ ReorderStep::ReorderStep(dmlc::JSONReader* reader) {
   reader->Read(&node->stage_id);
   s = reader->NextArrayItem();
   CHECK(s);
-  std::vector<int> int_list;
-  reader->Read(&int_list);
-  ::tvm::Array<::tvm::Integer> after_ids;
-  for (const auto& i : int_list) {
-    after_ids.push_back(i);
-  }
-  node->after_ids = after_ids;
+  reader->Read(&node->after_ids);
   data_ = std::move(node);
 }
 
@@ -652,7 +690,7 @@ void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArraySeperator();
   writer->WriteString(record_prefix_str);
   writer->WriteArrayItem(stage_id);
-  writer->WriteArrayItem(IntArrayToVector(after_ids));
+  writer->WriteArrayItem(after_ids);
 }
 
 void ReorderStepNode::ApplyToState(State* state) const {
@@ -887,13 +925,7 @@ SplitStep::SplitStep(dmlc::JSONReader* reader) {
   }
   s = reader->NextArrayItem();
   CHECK(s);
-  std::vector<int> int_list;
-  reader->Read(&int_list);
-  ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
-  for (const auto& i : int_list) {
-    lengths.push_back(::tvm::Integer(i));
-  }
-  node->lengths = lengths;
+  reader->Read(&node->lengths);
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->inner_to_outer);
@@ -906,7 +938,7 @@ void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArrayItem(stage_id);
   writer->WriteArrayItem(iter_id);
   writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
-  writer->WriteArrayItem(IntArrayToVector(lengths));
+  writer->WriteArrayItem(lengths);
   writer->WriteArrayItem(static_cast<int>(inner_to_outer));
 }
 
@@ -1044,19 +1076,13 @@ FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
   reader->Read(&node->iter_id);
   s = reader->NextArrayItem();
   CHECK(s);
-  std::vector<int> int_list;
-  reader->Read(&int_list);
+  reader->Read(&node->src_step_ids);
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->level);
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->factor_or_nparts);
-  ::tvm::Array<::tvm::Integer> src_step_ids;
-  for (const auto& i : int_list) {
-    src_step_ids.push_back(i);
-  }
-  node->src_step_ids = src_step_ids;
   data_ = std::move(node);
 }
 
@@ -1065,7 +1091,7 @@ void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteString(record_prefix_str);
   writer->WriteArrayItem(stage_id);
   writer->WriteArrayItem(iter_id);
-  writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+  writer->WriteArrayItem(src_step_ids);
   writer->WriteArrayItem(level);
   writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
 }
@@ -1416,13 +1442,7 @@ CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
   node->scope_name = std::move(string_value);
   s = reader->NextArrayItem();
   CHECK(s);
-  std::vector<int> int_list;
-  reader->Read(&int_list);
-  Array<Integer> reader_stage_ids;
-  for (int i : int_list) {
-    reader_stage_ids.push_back(i);
-  }
-  node->reader_stage_ids = std::move(reader_stage_ids);
+  reader->Read(&node->reader_stage_ids);
   data_ = std::move(node);
 }
 
@@ -1432,7 +1452,7 @@ void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArrayItem(stage_id);
   writer->WriteArraySeperator();
   writer->WriteString(scope_name);
-  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+  writer->WriteArrayItem(reader_stage_ids);
 }
 
 int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
index b3fa2dc..85bd7b0 100644 (file)
@@ -141,27 +141,6 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st
   }
 }
 
-/*! \brief Convert a Array<Integer> to std::vector<int>. */
-inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
-  std::vector<int> out;
-  for (const auto& x : data) {
-    CHECK(x.defined());
-    out.push_back(x);
-  }
-  return out;
-}
-
-/*! \brief Convert a Array<Optional<Integer>> to std::vector<int>. */
-inline std::vector<int> IntArrayToVector(
-    const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
-  std::vector<int> out;
-  for (const auto& x : data) {
-    CHECK(x);
-    out.push_back(x.value());
-  }
-  return out;
-}
-
 /*! \brief Return whether two int arrays are elementwise-equal */
 inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
   if (arr1.size() != arr2.size()) {
index 5dfc649..a646c38 100644 (file)
@@ -39,15 +39,18 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
     target = tvm.target.create(target)
     task = auto_scheduler.SearchTask(dag, workload_key, target)
 
-    if search_policy == 'empty':
-        search_policy = auto_scheduler.EmptyPolicy(task)
-    elif search_policy == 'sketch':
-        search_policy = auto_scheduler.SketchPolicy(task,
-                init_search_callbacks=init_search_callbacks)
-
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
 
+        init_search_callbacks = init_search_callbacks or []
+        init_search_callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))
+
+        if search_policy == 'empty':
+            search_policy = auto_scheduler.EmptyPolicy(task)
+        elif search_policy == 'sketch':
+            search_policy = auto_scheduler.SketchPolicy(task,
+                    init_search_callbacks=init_search_callbacks)
+
         tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
                 runner=runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
         sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
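Outside of this test harness, the same callback lets a user resume tuning from an earlier log; a minimal sketch, assuming "task" is an existing auto_scheduler.SearchTask and "matmul.json" is a hypothetical record file:

    from tvm import auto_scheduler

    log_file = "matmul.json"  # hypothetical path to a previous tuning log
    policy = auto_scheduler.SketchPolicy(
        task, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)])
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=64,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
    sch, args = auto_scheduler.auto_schedule(task, policy, tune_option)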
@@ -104,6 +107,20 @@ def test_sketch_search_policy_basic():
     t.join()
 
 
+def test_sketch_search_policy_cuda_rpc_runner():
+    if not tvm.runtime.enabled("cuda"):
+        return
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+                                  'runner': measure_ctx.runner})
+    t.start()
+    t.join()
+
+
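LocalRPCMeasureContext starts a local RPC tracker and server so that GPU candidates can be built and measured in separate processes; a hedged sketch of how its runner plugs into the tuning options ("conv2d_cuda.json" is a hypothetical log path):

    from tvm import auto_scheduler

    log_file = "conv2d_cuda.json"  # hypothetical path
    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=4, runner=measure_ctx.runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)])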
 if __name__ == "__main__":
     test_workload_registry_search_basic()
     test_sketch_search_policy_basic()
+    test_sketch_search_policy_cuda_rpc_runner()
index 4ef0cbc..f518866 100644 (file)
 
 import tvm
 from tvm import te, auto_scheduler
+from tvm.auto_scheduler import _ffi_api
+from tvm.auto_scheduler.loop_state import Stage
 
 from test_auto_scheduler_common import (matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test,
                                         max_pool2d_auto_scheduler_test, min_nm_auto_scheduler_test,
                                         softmax_nm_auto_scheduler_test, softmax_abcd_auto_scheduler_test,
                                         conv2d_winograd_nhwc_auto_scheduler_test)
 
+
 def generate_sketches(workload_func, args, target, print_for_debug=False):
     workload_key = auto_scheduler.make_workload_key(workload_func, args)
     dag = auto_scheduler.ComputeDAG(workload_key)
@@ -32,6 +35,28 @@ def generate_sketches(workload_func, args, target, print_for_debug=False):
     policy = auto_scheduler.SketchPolicy(task, verbose=0)
     return policy.generate_sketches(print_for_debug)
 
+def assert_compute_at_condition(stage, condition):
+    assert stage.compute_at == Stage.COMPUTE_AT_TRANS_TABLE[condition]
+
+def assert_is_tiled(stage):
+    assert _ffi_api.SearchPolicyUtilsIsTiled(stage)
+
+def assert_is_not_tiled(stage):
+    assert not _ffi_api.SearchPolicyUtilsIsTiled(stage)
+
+def assert_has_cache_write(state, stage_id):
+    assert _ffi_api.SearchPolicyUtilsHasCacheWriteStage(state, stage_id)
+
+def assert_has_cache_read(state, stage_id):
+    assert _ffi_api.SearchPolicyUtilsHasCacheReadStage(state, stage_id)
+
+def assert_has_rfactor(state, stage_id):
+    assert _ffi_api.SearchPolicyUtilsHasRfactorStage(state, stage_id)
+
+def assert_has_cross_thread_reduction(state, stage_id):
+    assert _ffi_api.SearchPolicyUtilsHasCrossThreadReduction(state, stage_id)
+
+
 def test_cpu_matmul_sketch():
     sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'llvm')
     ''' 3 multi-level tiling sketches
@@ -40,6 +65,17 @@ def test_cpu_matmul_sketch():
         2 - Multi-level tiling with cache write on position 1
     '''
     assert len(sketches) == 3
+    # Sketch 0
+    assert_is_tiled(sketches[0].stages[2])
+    # Sketch 1
+    assert_is_tiled(sketches[1].stages[2])
+    assert_has_cache_write(sketches[1], 2)
+    assert_compute_at_condition(sketches[1].stages[2], "iter")
+    # Sketch 2
+    assert_is_tiled(sketches[2].stages[2])
+    assert_has_cache_write(sketches[2], 2)
+    assert_compute_at_condition(sketches[2].stages[2], "iter")
+    assert sketches[1] != sketches[2]
 
     sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 'llvm')
     ''' 2 rfactor sketches + 3 multi-level tiling sketches
@@ -50,6 +86,23 @@ def test_cpu_matmul_sketch():
         4 - Multi-level tiling with cache write on position 1
     '''
     assert len(sketches) == 5
+    # Sketch 0
+    assert_has_rfactor(sketches[0], 2)
+    # Sketch 1
+    assert_has_rfactor(sketches[1], 2)
+    assert sketches[0] != sketches[1]
+    # Sketch 2
+    assert_is_tiled(sketches[2].stages[2])
+    # Sketch 3
+    assert_is_tiled(sketches[3].stages[2])
+    assert_has_cache_write(sketches[3], 2)
+    assert_compute_at_condition(sketches[3].stages[2], "iter")
+    # Sketch 4
+    assert_is_tiled(sketches[4].stages[2])
+    assert_has_cache_write(sketches[4], 2)
+    assert_compute_at_condition(sketches[4].stages[2], "iter")
+    assert sketches[3] != sketches[4]
+
 
 def test_cpu_conv2d_bn_relu_sketch():
     sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
@@ -60,28 +113,82 @@ def test_cpu_conv2d_bn_relu_sketch():
         2 - Conv2d multi-level tiling without fusion
     '''
     assert len(sketches) == 3
+    # Sketch 0
+    assert_is_not_tiled(sketches[0].stages[1])
+    assert_is_tiled(sketches[0].stages[3])
+    assert_compute_at_condition(sketches[0].stages[3], "iter")
+    assert_compute_at_condition(sketches[0].stages[5], "inlined")
+    assert_compute_at_condition(sketches[0].stages[7], "inlined")
+    assert_compute_at_condition(sketches[0].stages[9], "inlined")
+    assert_is_tiled(sketches[0].stages[10])
+    # Sketch 1
+    assert_is_not_tiled(sketches[1].stages[1])
+    assert_is_tiled(sketches[1].stages[3])
+    assert_compute_at_condition(sketches[1].stages[3], "iter")
+    assert_compute_at_condition(sketches[1].stages[5], "inlined")
+    assert_compute_at_condition(sketches[1].stages[7], "inlined")
+    assert_compute_at_condition(sketches[1].stages[9], "inlined")
+    assert_is_tiled(sketches[1].stages[10])
+    # Sketch 2
+    assert_is_not_tiled(sketches[2].stages[1])
+    assert_is_tiled(sketches[2].stages[3])
+    assert_compute_at_condition(sketches[2].stages[3], "root")
+    assert_compute_at_condition(sketches[2].stages[5], "inlined")
+    assert_compute_at_condition(sketches[2].stages[7], "inlined")
+    assert_compute_at_condition(sketches[2].stages[9], "inlined")
+    assert_is_not_tiled(sketches[2].stages[10])
+
 
 def test_cpu_max_pool2d_sketch():
     sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 1), 'llvm')
-    assert len(sketches) == 1  # 1 default sketch
+    ''' 1 default sketch '''
+    assert len(sketches) == 1
+    # Sketch 0
+    assert len(sketches[0].transform_steps) == 0
+
 
 def test_cpu_min_sketch():
     sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'llvm')
-    assert len(sketches) == 3
     ''' 2 rfactor sketches + 1 default sketch
         0 - Rfactor with factor position 0
         1 - Rfactor with factor position 1
         2 - Default sketch
     '''
+    assert len(sketches) == 3
+    # Sketch 0
+    assert_has_rfactor(sketches[0], 1)
+    # Sketch 1
+    assert_has_rfactor(sketches[1], 1)
+    assert sketches[0] != sketches[1]
+    # Sketch 2
+    assert len(sketches[2].transform_steps) == 0
+
 
 def test_cpu_softmax_sketch():
     sketches = generate_sketches(softmax_nm_auto_scheduler_test, (1, 1024), 'llvm')
     ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
     assert len(sketches) == (3 * 3)
+    for i in range(0, 3):
+        for j in range(0, 3):
+            sketch = sketches[i * 3 + j]
+            if j in [0, 1]:
+                assert_has_rfactor(sketch, 1)
+            if i in [0, 1]:
+                assert_has_rfactor(sketch, 4 if j in [0, 1] else 3)
+    assert len(sketches[8].transform_steps) == 0
 
     sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'llvm')
     ''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
     assert len(sketches) == (3 * 3)
+    for i in range(0, 3):
+        for j in range(0, 3):
+            sketch = sketches[i * 3 + j]
+            if j in [0, 1]:
+                assert_has_rfactor(sketch, 1)
+            if i in [0, 1]:
+                assert_has_rfactor(sketch, 4 if j in [0, 1] else 3)
+    assert len(sketches[8].transform_steps) == 0
+
 
 def test_cpu_conv2d_winograd_sketch():
     sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
@@ -92,6 +199,175 @@ def test_cpu_conv2d_winograd_sketch():
         2 - Bgemm multi-level tiling with cache write on position 1
     '''
     assert len(sketches) == 3
+    # Sketch 0
+    assert_is_not_tiled(sketches[0].stages[1])
+    assert_is_not_tiled(sketches[0].stages[2])
+    assert_compute_at_condition(sketches[0].stages[3], "inlined")
+    assert_is_tiled(sketches[0].stages[4])
+    assert_is_tiled(sketches[0].stages[6])
+    assert_compute_at_condition(sketches[0].stages[7], "inlined")
+    assert_is_tiled(sketches[0].stages[8])
+    assert_is_not_tiled(sketches[0].stages[9])
+    # Sketch 1
+    assert_is_not_tiled(sketches[1].stages[1])
+    assert_is_not_tiled(sketches[1].stages[2])
+    assert_compute_at_condition(sketches[1].stages[3], "inlined")
+    assert_is_tiled(sketches[1].stages[4])
+    assert_is_tiled(sketches[1].stages[6])
+    assert_has_cache_write(sketches[1], 6)
+    assert_compute_at_condition(sketches[1].stages[6], "iter")
+    assert_compute_at_condition(sketches[1].stages[8], "inlined")
+    assert_is_tiled(sketches[1].stages[9])
+    assert_is_not_tiled(sketches[1].stages[10])
+    # Sketch 2
+    assert_is_not_tiled(sketches[2].stages[1])
+    assert_is_not_tiled(sketches[2].stages[2])
+    assert_compute_at_condition(sketches[2].stages[3], "inlined")
+    assert_is_tiled(sketches[2].stages[4])
+    assert_is_tiled(sketches[2].stages[6])
+    assert_has_cache_write(sketches[2], 6)
+    assert_compute_at_condition(sketches[2].stages[6], "iter")
+    assert_compute_at_condition(sketches[2].stages[8], "inlined")
+    assert_is_tiled(sketches[2].stages[9])
+    assert_is_not_tiled(sketches[2].stages[10])
+    assert sketches[1] != sketches[2]
+
+
+def test_cuda_matmul_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'cuda')
+    ''' 1 multi-level tiling sketch '''
+    assert len(sketches) == 1
+    assert_has_cache_read(sketches[0], 0)
+    assert_compute_at_condition(sketches[0].stages[1], "iter")
+    assert_has_cache_read(sketches[0], 2)
+    assert_compute_at_condition(sketches[0].stages[3], "iter")
+    assert_has_cache_write(sketches[0], 4)
+    assert_is_tiled(sketches[0].stages[4])
+    assert_compute_at_condition(sketches[0].stages[4], "iter")
+    assert_is_tiled(sketches[0].stages[5])
+
+    sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 1024), 'cuda')
+    ''' 1 cross thread reduction sketch + 1 multi-level tiling sketch '''
+    assert len(sketches) == 2
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 2)
+    # Sketch 1
+    assert_has_cache_read(sketches[1], 0)
+    assert_compute_at_condition(sketches[1].stages[1], "iter")
+    assert_has_cache_read(sketches[1], 2)
+    assert_compute_at_condition(sketches[1].stages[3], "iter")
+    assert_has_cache_write(sketches[1], 4)
+    assert_is_tiled(sketches[1].stages[4])
+    assert_compute_at_condition(sketches[1].stages[4], "iter")
+    assert_is_tiled(sketches[1].stages[5])
+
+
+def test_cuda_conv2d_bn_relu_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
+                                 (1, 56, 56, 512, 512, 3, 1, 1), 'cuda')
+    ''' 1 multi-level tiling sketch '''
+    assert len(sketches) == 1
+    assert_has_cache_read(sketches[0], 1)
+    assert_compute_at_condition(sketches[0].stages[1], "inlined")
+    assert_compute_at_condition(sketches[0].stages[2], "iter")
+    assert_has_cache_read(sketches[0], 3)
+    assert_compute_at_condition(sketches[0].stages[4], "iter")
+    assert_is_tiled(sketches[0].stages[5])
+    assert_compute_at_condition(sketches[0].stages[5], "iter")
+    assert_compute_at_condition(sketches[0].stages[7], "inlined")
+    assert_compute_at_condition(sketches[0].stages[9], "inlined")
+    assert_compute_at_condition(sketches[0].stages[11], "inlined")
+    assert_is_tiled(sketches[0].stages[12])
+
+
+def test_cuda_max_pool2d_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 0), 'cuda')
+    ''' 1 default sketch '''
+    assert len(sketches) == 1
+    assert len(sketches[0].transform_steps) == 0
+
+
+def test_cuda_min_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'cuda')
+    ''' 1 cross thread reduction sketch + 1 default sketch '''
+    assert len(sketches) == 2
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 1)
+    # Sketch 1
+    assert len(sketches[1].transform_steps) == 0
+
+
+def test_cuda_softmax_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(softmax_nm_auto_scheduler_test, (2, 1024), 'cuda')
+    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+    assert len(sketches) == (2 * 2)
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 1)
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[0], 3)
+    # Sketch 1
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[1], 3)
+    # Sketch 2
+    assert_has_cross_thread_reduction(sketches[2], 1)
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    # Sketch 3
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+
+    sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'cuda')
+    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+    assert len(sketches) == (2 * 2)
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 1)
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[0], 3)
+    # Sketch 1
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[1], 3)
+    # Sketch 2
+    assert_has_cross_thread_reduction(sketches[2], 1)
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+    # Sketch 3
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+
+
+def test_cuda_conv2d_winograd_sketch():
+    if not tvm.context("cuda", 0).exist:
+        return
+
+    sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
+                                 (1, 28, 28, 128, 128, 3, 1, 1), 'cuda')
+    ''' 1 multi-level tiling sketch '''
+    assert len(sketches) == 1
+    assert_compute_at_condition(sketches[0].stages[1], "inlined")
+    assert_compute_at_condition(sketches[0].stages[2], "inlined")
+    assert_compute_at_condition(sketches[0].stages[3], "inlined")
+    assert_is_tiled(sketches[0].stages[4])
+    assert_has_cache_read(sketches[0], 4)
+    assert_compute_at_condition(sketches[0].stages[5], "iter")
+    assert_has_cache_read(sketches[0], 6)
+    assert_compute_at_condition(sketches[0].stages[7], "iter")
+    assert_is_not_tiled(sketches[0].stages[8])
+    assert_compute_at_condition(sketches[0].stages[8], "iter")
+    assert_compute_at_condition(sketches[0].stages[9], "inlined")
+    assert_is_tiled(sketches[0].stages[10])
+    assert_is_not_tiled(sketches[0].stages[11])
+
 
 if __name__ == "__main__":
     test_cpu_matmul_sketch()
@@ -100,3 +376,9 @@ if __name__ == "__main__":
     test_cpu_min_sketch()
     test_cpu_softmax_sketch()
     test_cpu_conv2d_winograd_sketch()
+    test_cuda_matmul_sketch()
+    test_cuda_conv2d_bn_relu_sketch()
+    test_cuda_max_pool2d_sketch()
+    test_cuda_min_sketch()
+    test_cuda_softmax_sketch()
+    test_cuda_conv2d_winograd_sketch()