TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
};
+/*! \brief Preload measured states from a log file.
+ * This can resume the state of the search policy. */
+class PreloadMeasuredStatesNode : public SearchCallbackNode {
+ public:
+ /*! \brief The name of the record log file. */
+ String filename;
+
+ void Callback(SearchPolicyNode* policy) final;
+
+ static constexpr const char* _type_key = "auto_scheduler.PreloadMeasuredStates";
+ TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode);
+};
+
+/*!
+ * \brief Managed reference to PreloadMeasuredStatesNode.
+ * \sa PreloadMeasuredStatesNode
+ */
+class PreloadMeasuredStates : public SearchCallback {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param filename The name of the record log file.
+ */
+ explicit PreloadMeasuredStates(String filename);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback,
+ PreloadMeasuredStatesNode);
+};
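+
+// A usage sketch, assuming `policy` is an existing SearchPolicy and the log file
+// was written by a previous tuning run:
+//   policy->RunCallbacks({PreloadMeasuredStates("auto_scheduler_tuning.json")});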
+
/*! \brief Attribute keys of ops used for SearchPolicy. */
struct SearchPolicyKey {
/*! \brief Always apply unroll to the innermost iterator of the specified iterators. */
ProgramMeasurer measurer) = 0;
/*!
+ * \brief Preload measured states from a log file to resume the state of the search policy.
+ * \param log_file The name of the record log file.
+ */
+ void PreloadMeasuredStates(const String& log_file);
+
+ /*!
* \brief Call SearchCallback with the current SearchPolicyNode
* \param callbacks SearchCallback to be called.
*/
# Shortcut
from .auto_schedule import SearchTask, TuningOptions, HardwareParams, \
- auto_schedule, EmptyPolicy, SketchPolicy
+ auto_schedule
from .compute_dag import ComputeDAG
from .cost_model import RandomModel, XGBModel
from .measure import MeasureInput, MeasureResult, LocalBuilder, LocalRunner, RPCRunner, \
LocalRPCMeasureContext
from .measure_record import RecordToFile, RecordReader, load_best, \
load_records, save_records
+from .search_policy import EmptyPolicy, SketchPolicy, PreloadMeasuredStates
from .workload_registry import register_workload, make_workload_key
Candidate schedules are measured against the specific hardware target.
"""
-import random
-
import tvm._ffi
from tvm.runtime import Object
from .measure import LocalBuilder, LocalRunner
-from .cost_model import RandomModel
+from .search_policy import EmptyPolicy
from . import _ffi_api
hardware_params)
-@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
-class SearchPolicy(Object):
- """ The base class of search policies. """
-
-
-@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
-class EmptyPolicy(SearchPolicy):
- """ This is an example empty search policy which will always generate
- the init state of ComputeDAG.
-
- Parameters
- ----------
- task : SearchTask
- The SearchTask for the computation declaration.
- init_search_callbacks : Optional[List[SearchCallback]]
- Callback functions called before the search process.
- """
- def __init__(self, task, init_search_callbacks=None):
- self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)
-
-
-@tvm._ffi.register_object("auto_scheduler.SketchPolicy")
-class SketchPolicy(SearchPolicy):
- """ The search policy that searches in a hierarchical search space defined by sketches.
- The policy randomly samples programs from the space defined by sketches
- and use evolutionary search to fine-tune them.
-
- Parameters
- ----------
- task : SearchTask
- The SearchTask for the computation declaration.
- schedule_cost_model : CostModel = RandomModel()
- The cost model to estimate the complete schedules.
- params : Optional[Dict[str, Any]]
- Parameters of the search policy.
- See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions.
- See `DEFAULT_PARAMS` below to find the default values.
- seed : Optional[int]
- Random seed.
- verbose : int = 1
- Verbosity level. 0 for silent, 1 to output information during schedule search.
- init_search_callbacks : Optional[List[SearchCallback]]
- Callback functions called before the search process, usually used to do extra
- initializations.
- Possible callbacks:
- - auto_scheduler.PreloadMeasuredStates
- - auto_scheduler.PreloadCustomSketchRule
- TODO(jcf94): Add these search callback implementations.
- """
-
- DEFAULT_PARAMS = {
- "eps_greedy": 0.05,
-
- 'evolutionary_search_population': 2048,
- "evolutionary_search_use_measured_ratio": 0.2,
-
- 'cpu_multi_level_tiling_structure': 'SSRSRS',
- 'gpu_multi_level_tiling_structure': 'SSSRRSRS',
-
- 'max_innermost_split_factor': 16,
- 'max_vectorize_size': 16,
-
- 'disable_change_compute_location': 0,
- }
-
- def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
- init_search_callbacks=None):
- if params is None:
- params = SketchPolicy.DEFAULT_PARAMS
- else:
- for key, value in SketchPolicy.DEFAULT_PARAMS.items():
- if key not in params:
- params[key] = value
-
- self.__init_handle_by_constructor__(
- _ffi_api.SketchPolicy, task, schedule_cost_model, params,
- seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
-
- def generate_sketches(self, print_for_debug=False):
- """ Generate the sketches.
- This python interface is mainly used for debugging and testing.
- The actual search is all doen in c++.
-
- Parameters
- ----------
- print_for_debug : bool = False
- Whether print out the sketches for debug.
-
- Returns
- -------
- sketches : List[State]
- The generated sketches of this search task.
- """
- sketches = _ffi_api.SketchPolicyGenerateSketches(self)
- if print_for_debug:
- for i, s in enumerate(sketches):
- print("=" * 20 + " %d " % i + "=" * 20)
- print(s)
- return sketches
-
- def sample_initial_population(self, pop_size):
- """Sample initial population.
- This python interface is mainly used for debugging and testing.
- The actual search is all doen in c++.
-
- Parameters
- ----------
- pop_size : int
- The size of sampled population
-
- Returns
- -------
- states: List[State]
- The sampled states
- """
- states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
- return states
-
@tvm._ffi.register_object("auto_scheduler.TuningOptions")
class TuningOptions(Object):
""" This controls the options of performance tuning.
class Stage(Object):
""" A stage in the compute declaration. Similar to tvm.te.schedule.Stage. """
+ # Static trans table for compute_at location
+ # This is used to transform the compute_at location to C++ enum
+ COMPUTE_AT_TRANS_TABLE = {
+ "root": 0,
+ "inlined": 1,
+ "iter": 2
+ }
@tvm._ffi.register_object("auto_scheduler.State")
class StateObject(Object):
This is a wrapper class of StateObject to deal with the copy-on-write property
"""
- # Static trans table for thread bind
+ # Static trans table for thread bind and annotation
# This is used to transform the annotation name to C++ enum
ANNOTATION_TRANS_TABLE = {
"none": 0,
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+The search policies for TVM Auto-scheduler.
+
+This contains the strategies to generate a schedule automatically. We provide an EmptyPolicy
+which always returns an unchanged initial state, and a more advanced SketchPolicy which can
+deal with various ops/subgraphs on different target devices.
+
+Reference:
+L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor: Generating High-Performance Tensor
+Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020).
+"""
+
+import random
+
+import tvm._ffi
+from tvm.runtime import Object
+from .cost_model import RandomModel
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchCallback")
+class SearchCallback(Object):
+ """Callback function before or after search process"""
+
+
+@tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates")
+class PreloadMeasuredStates(SearchCallback):
+ """ A SearchCallback to load measured states from the log file for a search policy.
+
+ This can resume the state of the search policy:
+ - Making sure an already measured state in former searches will never be measured again.
+ - The history states can be used to speed up the search process(e.g. SketchPolicy uses
+ history states as starting point to perform Evolutionary Search).
+
+ Parameters
+ ----------
+ filename : str
+ The name of the record file.
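+
+ Examples
+ --------
+ A minimal usage sketch, assuming ``task`` is an existing SearchTask and the record
+ file below was produced by RecordToFile in a previous tuning run::
+
+ callback = auto_scheduler.PreloadMeasuredStates("auto_scheduler_tuning.json")
+ policy = auto_scheduler.SketchPolicy(task, init_search_callbacks=[callback])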
+ """
+ def __init__(self, filename="auto_scheduler_tuning.json"):
+ self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
+
+
+@tvm._ffi.register_object("auto_scheduler.SearchPolicy")
+class SearchPolicy(Object):
+ """ The base class of search policies. """
+
+
+@tvm._ffi.register_object("auto_scheduler.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+ """ This is an example empty search policy which will always generate
+ the init state of ComputeDAG.
+
+ Parameters
+ ----------
+ task : SearchTask
+ The SearchTask for the computation declaration.
+ init_search_callbacks : Optional[List[SearchCallback]]
+ Callback functions called before the search process.
+ """
+ def __init__(self, task, init_search_callbacks=None):
+ self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks)
+
+
+@tvm._ffi.register_object("auto_scheduler.SketchPolicy")
+class SketchPolicy(SearchPolicy):
+ """ The search policy that searches in a hierarchical search space defined by sketches.
+ The policy randomly samples programs from the space defined by sketches and uses
+ evolutionary search to fine-tune them.
+
+ Parameters
+ ----------
+ task : SearchTask
+ The SearchTask for the computation declaration.
+ schedule_cost_model : CostModel = RandomModel()
+ The cost model to estimate the complete schedules.
+ params : Optional[Dict[str, Any]]
+ Parameters of the search policy.
+ See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions.
+ See `DEFAULT_PARAMS` below to find the default values.
+ seed : Optional[int]
+ Random seed.
+ verbose : int = 1
+ Verbosity level. 0 for silent, 1 to output information during schedule search.
+ init_search_callbacks : Optional[List[SearchCallback]]
+ Callback functions called before the search process, usually used to do extra
+ initializations.
+ Possible callbacks:
+ - auto_scheduler.PreloadMeasuredStates
+ - auto_scheduler.PreloadCustomSketchRule
+ TODO(jcf94): Add the PreloadCustomSketchRule callback implementation.
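+
+ Examples
+ --------
+ A minimal usage sketch, assuming ``task`` is an existing SearchTask; the ``params``
+ dict below overrides a single default parameter, the rest keep their defaults::
+
+ policy = auto_scheduler.SketchPolicy(
+ task,
+ params={"eps_greedy": 0.1},
+ init_search_callbacks=[auto_scheduler.PreloadMeasuredStates()])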
+ """
+
+ DEFAULT_PARAMS = {
+ "eps_greedy": 0.05,
+ "retry_search_one_round_on_empty": 10,
+
+ 'evolutionary_search_population': 2048,
+ "evolutionary_search_use_measured_ratio": 0.2,
+
+ 'cpu_multi_level_tiling_structure': 'SSRSRS',
+ 'gpu_multi_level_tiling_structure': 'SSSRRSRS',
+ # Notice: the default thread bind policy of GPU assumes the tiling structure to have at
+ # least 3 spatial tiling levels at the outermost positions
+
+ 'max_innermost_split_factor': 16,
+ 'max_vectorize_size': 16,
+
+ 'disable_change_compute_location': 0,
+ }
+
+ def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1,
+ init_search_callbacks=None):
+ if params is None:
+ params = SketchPolicy.DEFAULT_PARAMS
+ else:
+ for key, value in SketchPolicy.DEFAULT_PARAMS.items():
+ if key not in params:
+ params[key] = value
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.SketchPolicy, task, schedule_cost_model, params,
+ seed or random.randint(1, 1 << 30), verbose, init_search_callbacks)
+
+ def generate_sketches(self, print_for_debug=False):
+ """ Generate the sketches.
+ This python interface is mainly used for debugging and testing.
+ The actual search is all done in C++.
+
+ Parameters
+ ----------
+ print_for_debug : bool = False
+ Whether to print out the sketches for debugging.
+
+ Returns
+ -------
+ sketches : List[State]
+ The generated sketches of this search task.
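+
+ Examples
+ --------
+ A quick debugging sketch, assuming ``policy`` is a constructed SketchPolicy::
+
+ sketches = policy.generate_sketches(print_for_debug=True)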
+ """
+ sketches = _ffi_api.SketchPolicyGenerateSketches(self)
+ if print_for_debug:
+ for i, s in enumerate(sketches):
+ print("=" * 20 + " %d " % i + "=" * 20)
+ print(s)
+ return sketches
+
+ def sample_initial_population(self, pop_size):
+ """Sample initial population.
+ This python interface is mainly used for debugging and testing.
+ The actual search is all done in C++.
+
+ Parameters
+ ----------
+ pop_size : int
+ The size of the sampled population.
+
+ Returns
+ -------
+ states: List[State]
+ The sampled states
+ """
+ states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
+ return states
#include <tvm/auto_scheduler/auto_schedule.h>
#include <tvm/runtime/registry.h>
+#include "utils.h"
+
namespace tvm {
namespace auto_scheduler {
State state =
search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping,
tuning_options->num_measures_per_round, measurer);
- return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);
+ if (state.defined()) {
+ return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);
+ } else {
+ StdCout(tuning_options->verbose)
+ << "No valid state found in this search round. Check if it has traversed all of the "
+ << "search space." << std::endl;
+ // Return the default schedule
+ return {te::Schedule(search_policy->search_task->compute_dag->ops),
+ search_policy->search_task->compute_dag->tensors};
+ }
}
TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")
* \brief The base class of search policies.
*/
+#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/runtime/registry.h>
+#include "utils.h"
+
namespace tvm {
namespace auto_scheduler {
TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode);
TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
+TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
+
+void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
+ RecordReader reader = RecordReader(log_file);
+ const auto& res = reader->ReadLines(-1);
+ size_t log_size = res.first.size();
+ CHECK_EQ(log_size, res.second.size());
+ if (log_size) {
+ Array<State> measured_states;
+ std::vector<float> measured_throughputs;
+ for (size_t i = 0; i < log_size; i++) {
+ const auto& inp = res.first[i];
+ if (inp->task->workload_key == search_task->workload_key &&
+ inp->task->target->kind->name.compare(search_task->target->kind->name) == 0) {
+ State state = search_task->compute_dag->init_state;
+ auto pstate = state.CopyOnWrite();
+ pstate->transform_steps = inp->state->transform_steps;
+ for (const auto& step : pstate->transform_steps) {
+ StepApplyToState(step, &state, search_task->compute_dag);
+ }
+ measured_states.push_back(std::move(state));
+ measured_throughputs.push_back(
+ res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
+ }
+ }
+ measured_states = search_task->compute_dag.InferBound(measured_states);
+ for (size_t i = 0; i < measured_states.size(); i++) {
+ auto& state = measured_states[i];
+ const auto& state_str = state.ToStr();
+ if (!measured_states_set_.count(state_str)) {
+ measured_states_set_.insert(state_str);
+ if (measured_throughputs[i] != 0.0) {
+ measured_states_vector_.emplace_back(std::move(state));
+ measured_states_throughputs_.emplace_back(measured_throughputs[i]);
+ }
+ }
+ }
+
+ StdCout(verbose) << "SearchPolicy: Loaded " << measured_states_set_.size()
+ << " measurement records from " << log_file << " for "
+ << search_task->workload_key << std::endl;
+ } else {
+ StdCout(verbose) << "SearchPolicy: No measurement records found in " << log_file << " for "
+ << search_task->workload_key << std::endl;
+ }
+}
void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
for (const auto& callback : callbacks) {
}
}
+PreloadMeasuredStates::PreloadMeasuredStates(String filename) {
+ auto node = make_object<PreloadMeasuredStatesNode>();
+ node->filename = std::move(filename);
+ data_ = std::move(node);
+}
+
+void PreloadMeasuredStatesNode::Callback(SearchPolicyNode* policy) {
+ policy->PreloadMeasuredStates(filename);
+}
+
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")
.set_body_typed([](SearchPolicy policy, Optional<Array<SearchCallback>> callbacks) {
if (callbacks) {
TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicySetVerbose")
.set_body_typed([](SearchPolicy policy, int verbose) { policy->verbose = verbose; });
+TVM_REGISTER_GLOBAL("auto_scheduler.PreloadMeasuredStates").set_body_typed([](String filename) {
+ return PreloadMeasuredStates(filename);
+});
+
} // namespace auto_scheduler
} // namespace tvm
static RuleAlwaysInline rule_always_inline;
static RuleMultiLevelTiling rule_multi_level_tiling;
static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
+static RuleAddCacheRead rule_add_cache_read_stage;
static RuleAddCacheWrite rule_add_cache_write_stage;
static RuleAddRfactor rule_add_rfactor;
+static RuleCrossThreadReduction rule_cross_thread_reduction;
static RuleSimplifyComputeWithConstTensor rule_simplify_compute_with_const_tensor;
+static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;
/********** Init population rules **********/
static InitParallel init_parallel;
static InitUnroll init_unroll;
static InitVectorization init_vectorization;
+static InitThreadBind init_thread_bind;
/********** Sketch policy **********/
node->RunCallbacks(init_search_callbacks.value());
}
- // The default sketch rules for CPU policy
// Notice: Some rules require us to skip all the rest rules after they are applied.
// So the rules below should be ordered carefully.
- node->sketch_rules.push_back(&rule_always_inline);
- node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
- node->sketch_rules.push_back(&rule_add_rfactor);
- node->sketch_rules.push_back(&rule_add_cache_write_stage);
- node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
- node->sketch_rules.push_back(&rule_multi_level_tiling);
+ if (IsCPUTask(node->search_task)) {
+ // The default sketch rules for CPU policy
+ node->sketch_rules.push_back(&rule_always_inline);
+ node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
+ node->sketch_rules.push_back(&rule_add_rfactor);
+ node->sketch_rules.push_back(&rule_add_cache_write_stage);
+ node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
+ node->sketch_rules.push_back(&rule_multi_level_tiling);
+ } else if (IsCUDATask(node->search_task)) {
+ // The default sketch rules for CUDA policy
+ node->sketch_rules.push_back(&rule_add_cache_read_stage);
+ node->sketch_rules.push_back(&rule_always_inline);
+ node->sketch_rules.push_back(&rule_special_compute_location_gpu);
+ node->sketch_rules.push_back(&rule_simplify_compute_with_const_tensor);
+ node->sketch_rules.push_back(&rule_cross_thread_reduction);
+ node->sketch_rules.push_back(&rule_add_cache_write_stage);
+ node->sketch_rules.push_back(&rule_multi_level_tiling_with_fusion);
+ node->sketch_rules.push_back(&rule_multi_level_tiling);
+ } else {
+ LOG(FATAL) << "No default sketch rules for target: " << task->target;
+ }
node->sketch_rules.push_back(&rule_skip_stage); // This should always be the last rule
- // The default init population rules for CPU policy
- node->init_rules.push_back(&init_fill_tile_size);
- node->init_rules.push_back(&init_change_compute_location);
- node->init_rules.push_back(&init_parallel);
- node->init_rules.push_back(&init_unroll);
- node->init_rules.push_back(&init_vectorization);
+ node->init_rules.push_back(&init_fill_tile_size); // This should always be the first rule
+ if (IsCPUTask(node->search_task)) {
+ // The default init population rules for CPU policy
+ node->init_rules.push_back(&init_change_compute_location);
+ node->init_rules.push_back(&init_parallel);
+ node->init_rules.push_back(&init_unroll);
+ node->init_rules.push_back(&init_vectorization);
+ } else if (IsCUDATask(node->search_task)) {
+ // The default init population rules for CUDA policy
+ node->init_rules.push_back(&init_thread_bind);
+ node->init_rules.push_back(&init_unroll);
+ } else {
+ LOG(FATAL) << "No default init rules for target: " << task->target;
+ }
data_ = std::move(node);
}
measurer->Reset();
int ct = 0;
+ int empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
Array<MeasureInput> inputs;
Array<MeasureResult> results;
while (ct < n_trials) {
// Also pick some random states to do eps-greedy
inputs = PickStatesWithEpsGreedy(best_states, random_states, n_trials - ct);
- // Have traversed all of the search space
+ // Currently it is hard to detect whether all of the search space has been traversed.
+ // Stop if no extra valid states are found in several retries.
if (inputs.empty()) {
- StdCout(verbose) << "All candidates in the search space have been measured." << std::endl;
- break;
+ if (empty_retry_count-- > 0) {
+ continue;
+ } else {
+ StdCout(verbose) << "It seems all candidates in the search space have been measured."
+ << std::endl;
+ break;
+ }
+ } else {
+ // Reset the retry count
+ empty_retry_count = GetIntParam(params, SketchParamKey::empty_retry_count);
}
// Measure candidate states
}
Array<State> SketchPolicyNode::GenerateSketches() {
- State init_state = search_task->compute_dag->init_state;
+ const State& init_state = search_task->compute_dag->init_state;
// Two ping pong buffers to avoid copy
Array<State> states_buf1{init_state}, states_buf2;
cur_stage_id_map[pair.first] = pair.second;
pnext->push_back(pair.first);
}
- // Skip the reset rules
+ // Skip the rest rules
if (cond == SketchGenerationRule::ConditionKind::kApplyAndSkipRest) {
break;
}
struct SketchParamKey {
/*! \brief Always allocate this percentage of measurements to randomly sampled states. */
static constexpr const char* eps_greedy = "eps_greedy";
+ /*! \brief Retry several times if SearchOneRound gets no valid state. */
+ static constexpr const char* empty_retry_count = "retry_search_one_round_on_empty";
struct EvolutionarySearch {
/*! \brief The population size for evolutionary search. */
/********** RuleSkipStage **********/
SketchGenerationRule::ConditionKind RuleSkipStage::MeetCondition(const SketchPolicyNode& policy,
- const State& state, int stage_id) {
+ const State& state,
+ int stage_id) const {
// This rule should be the last rule, always return true to decrease the stage index count
return ConditionKind::kApply;
}
SketchGenerationRule::ConditionKind RuleAlwaysInline::MeetCondition(const SketchPolicyNode& policy,
const State& state,
- int stage_id) {
+ int stage_id) const {
const Stage& stage = state->stages[stage_id];
// Check the inline limitation of TE first
if (stage->op_type == StageKind::kPlaceholder ||
return ConditionKind::kSkip;
}
- // TODO(jcf94): Greedily inline all inlinable ops on GPU when introducing GPU search policy.
- return IsStrictlyInlineable(policy.search_task, state, stage_id)
+ // Always do compute inline if it's strictly inlineable or is in GPU policy
+ return IsStrictlyInlineable(policy.search_task, state, stage_id) || IsGPUTask(policy.search_task)
? ConditionKind::kApplyAndSkipRest
: ConditionKind::kSkip;
}
/********** RuleMultiLevelTiling **********/
SketchGenerationRule::ConditionKind RuleMultiLevelTiling::MeetCondition(
- const SketchPolicyNode& policy, const State& state, int stage_id) {
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
return NeedsMultilevelTiling(policy.search_task, state, stage_id)
? ConditionKind::kApplyAndSkipRest
: ConditionKind::kSkip;
std::vector<std::pair<State, int>> RuleMultiLevelTiling::Apply(const SketchPolicyNode& policy,
const State& state,
int stage_id) const {
- // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
const std::string& multi_level_tiling_structure =
- GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+ IsGPUTask(policy.search_task)
+ ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
+ : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
State tmp_s = DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure);
return {std::make_pair(std::move(tmp_s), stage_id - 1)};
}
/********** RuleMultiLevelTilingWithFusion **********/
SketchGenerationRule::ConditionKind RuleMultiLevelTilingWithFusion::MeetCondition(
- const SketchPolicyNode& policy, const State& state, int stage_id) {
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
- HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id)) {
- // Always do fusion for stage with cache_write
- // TODO(jcf94): Always do fusion on GPU when introducing GPU search policy.
- return HasCacheWriteStage(state, stage_id) ? ConditionKind::kApplyAndSkipRest
- : ConditionKind::kApply;
+ HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
+ // Always do fusion for stage with cache_write or is in GPU policy
+ return HasCacheWriteStage(state, stage_id) || IsGPUTask(policy.search_task)
+ ? ConditionKind::kApplyAndSkipRest
+ : ConditionKind::kApply;
}
return ConditionKind::kSkip;
}
std::vector<std::pair<State, int>> RuleMultiLevelTilingWithFusion::Apply(
const SketchPolicyNode& policy, const State& state, int stage_id) const {
- // TODO(jcf94): Add support for GPU structure when introducing GPU search policy.
+ int target_stage_id;
+ CHECK(HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id, &target_stage_id));
const std::string& multi_level_tiling_structure =
- GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
+ IsGPUTask(policy.search_task)
+ ? GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::gpu_structure)
+ : GetStringParam(policy.params, SketchParamKey::MultiLevelTiling::cpu_structure);
std::vector<int> spatial_split_step_ids;
State base_state =
DoMultiLevelTiling(state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids);
std::vector<std::pair<State, int>> ret;
- // TODO(jcf94): Add follow_tiling_levels for GPU when introducing GPU search policy.
- std::vector<int> follow_tiling_levels{1, 2};
+ std::vector<int> follow_tiling_levels =
+ IsGPUTask(policy.search_task) ? std::vector<int>{3} : std::vector<int>{1, 2};
for (int level : follow_tiling_levels) {
if (tolower(multi_level_tiling_structure[level - 1]) != 's') {
continue;
return ret;
}
+/********** RuleAddCacheRead **********/
+
+SketchGenerationRule::ConditionKind RuleAddCacheRead::MeetCondition(const SketchPolicyNode& policy,
+ const State& state,
+ int stage_id) const {
+ const SearchTask& task = policy.search_task;
+
+ // Don't cache_read a stage if it has multiple consumers
+ const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+ if (consumers.size() != 1) {
+ return ConditionKind::kSkip;
+ }
+
+ // Don't cache_read a stage if its consumer does not need multi-level tiling
+ int target_stage_id = *consumers.begin();
+ if (!NeedsMultilevelTiling(task, state, target_stage_id)) {
+ return ConditionKind::kSkip;
+ }
+
+ // Don't cache_read a stage if its consumer does cross-thread reduction
+ if (HasCrossThreadReduction(state, target_stage_id)) {
+ return ConditionKind::kSkip;
+ }
+
+ // Only direct producers can be cache read
+ const std::set<int>& producers = GetDirectProducers(task, state, target_stage_id);
+ if (producers.find(stage_id) == producers.end()) {
+ return ConditionKind::kSkip;
+ }
+
+ return ConditionKind::kApplyAndSkipRest;
+}
+
+std::vector<std::pair<State, int>> RuleAddCacheRead::Apply(const SketchPolicyNode& policy,
+ const State& state, int stage_id) const {
+ const SearchTask& task = policy.search_task;
+ const std::set<int>& consumers = GetConsumers(task, state, stage_id);
+ CHECK_EQ(consumers.size(), 1);
+ int target_stage_id = *consumers.begin();
+ State tmp_s = state;
+
+ // Add a cache read stage that reads into shared memory
+ int added_stage_id = tmp_s.cache_read(stage_id, "shared", {target_stage_id}, task->compute_dag);
+ target_stage_id++;
+ const auto& share_read_pos =
+ GetLastReduceIteratorInOutermostReduceTile(tmp_s->stages[target_stage_id]);
+ tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos);
+ return {std::make_pair(tmp_s, stage_id)};
+}
+
/********** RuleAddCacheWrite **********/
SketchGenerationRule::ConditionKind RuleAddCacheWrite::MeetCondition(const SketchPolicyNode& policy,
const State& state,
- int stage_id) {
+ int stage_id) const {
// Add cache write if a stage needs multi-level tiling, but does not have an element-wise
// matched consumer
if (NeedsMultilevelTiling(policy.search_task, state, stage_id) &&
!HasSingleElementwiseMatchedConsumer(policy.search_task, state, stage_id)) {
// An apply and skip rule will be handled in RuleMultiLevelTilingWithFusion
- // TODO(jcf94): Always do cache_write on GPU when introducing GPU search policy.
- return ConditionKind::kApply;
+ return IsGPUTask(policy.search_task) ? ConditionKind::kApplyAndSkipRest : ConditionKind::kApply;
}
return ConditionKind::kSkip;
SketchGenerationRule::ConditionKind RuleAddRfactor::MeetCondition(const SketchPolicyNode& policy,
const State& state,
- int stage_id) {
+ int stage_id) const {
return (NeedsRfactor(policy.search_task, state, stage_id) && !HasCacheWriteStage(state, stage_id))
? ConditionKind::kApply
: ConditionKind::kSkip;
/********** RuleSimplifyComputeWithConstTensor **********/
SketchGenerationRule::ConditionKind RuleSimplifyComputeWithConstTensor::MeetCondition(
- const SketchPolicyNode& policy, const State& state, int stage_id) {
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
return state->stages[stage_id]->op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)
? ConditionKind::kApplyAndSkipRest
: ConditionKind::kSkip;
return {std::make_pair(tmp_s, stage_id - 1)};
}
+/********** RuleCrossThreadReduction **********/
+
+SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
+ CHECK(IsGPUTask(policy.search_task));
+
+ // If it is an intermediate state created by RuleAddCacheWrite,
+ // we just skip it.
+ if (HasCacheWriteStage(state, stage_id)) {
+ return ConditionKind::kSkip;
+ }
+
+ const auto& op = state->stages[stage_id]->op;
+ if (op->IsInstance<te::ComputeOpNode>()) {
+ // Compute the product of lengths of all space iters and all reduce iters
+ int cum_space_len, cum_reduce_len;
+ std::tie(cum_space_len, cum_reduce_len) =
+ GetCumulativeSpaceAndReductionLengh(state->stages[stage_id]);
+
+ if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
+ // Apply cross-thread reduction if we do not have enough parallelism on the space iters
+ return cum_space_len < cum_reduce_len ? ConditionKind::kApply : ConditionKind::kSkip;
+ } else if (cum_reduce_len > 1) {
+ // Try cross-thread reduction for other reduction operators
+ return cum_reduce_len > policy.search_task->hardware_params->warp_size ? ConditionKind::kApply
+ : ConditionKind::kSkip;
+ }
+ }
+
+ return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleCrossThreadReduction::Apply(const SketchPolicyNode& policy,
+ const State& state,
+ int stage_id) const {
+ const SearchTask& task = policy.search_task;
+ State tmp_s = state;
+
+ // Fuse all reduction iters
+ Array<Iterator> space_iters, reduce_iters;
+ Iterator fused_reduce_iter;
+ tmp_s =
+ FuseAllReductionIterators(tmp_s, stage_id, &fused_reduce_iter, &space_iters, &reduce_iters);
+
+ // Check the opportunity for kernel fusion
+ bool fusible = false;
+ int target_stage_id = GetSingleConsumerId(policy.search_task, tmp_s, stage_id);
+ int num_common_outer = -1;
+ if (target_stage_id >= 0) {
+ num_common_outer =
+ GetNumCommonOuterIterator(policy.search_task, tmp_s, stage_id, target_stage_id);
+ if (num_common_outer > 0 &&
+ !NeedsMultilevelTiling(policy.search_task, state, target_stage_id)) {
+ fusible = true;
+ }
+ }
+
+ if (fusible) {
+ const Stage& target_stage = state->stages[target_stage_id];
+ std::vector<int> split_step_ids;
+
+ GetSplitStepIds(tmp_s, target_stage_id, &split_step_ids);
+
+ if (split_step_ids.size() == 0) {
+ // If the target stage does not have a split step,
+ // it must be a simple stage without reduce iters.
+ // We then do a split for it.
+ CHECK(!HasReduceIter(target_stage));
+ const auto& split_res = tmp_s.split(target_stage_id, target_stage->iters.back(),
+ {Integer(task->hardware_params->warp_size)});
+ tmp_s.bind(target_stage_id, split_res[1], IteratorAnnotation::kThreadX);
+ split_step_ids.push_back(tmp_s->transform_steps.size() - 2);
+ }
+
+ CHECK_EQ(split_step_ids.size(), 1);
+
+ const Iterator& target_iter = tmp_s->stages[target_stage_id]->iters[num_common_outer - 1];
+ const auto& split_res = tmp_s.follow_split(stage_id, fused_reduce_iter, split_step_ids[0], 1);
+ tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
+ tmp_s.compute_at(stage_id, target_stage_id, target_iter);
+ } else {
+ const auto& split_res =
+ tmp_s.split(stage_id, fused_reduce_iter, {Integer(task->hardware_params->warp_size)});
+ tmp_s.bind(stage_id, split_res[1], IteratorAnnotation::kThreadX);
+ }
+
+ return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
+/********** RuleSpecialComputeLocationGPU **********/
+
+SketchGenerationRule::ConditionKind RuleSpecialComputeLocationGPU::MeetCondition(
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
+ if (GetProducers(policy.search_task, state, stage_id).empty()) {
+ return ConditionKind::kSkip;
+ }
+
+ const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
+ if (consumers.size() == 1 && state->stages[*consumers.begin()]->op->attrs.count(
+ SearchPolicyKey::simplify_const_tensor_indices)) {
+ return ConditionKind::kApplyAndSkipRest;
+ }
+
+ return ConditionKind::kSkip;
+}
+
+std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
+ const SketchPolicyNode& policy, const State& state, int stage_id) const {
+ State tmp_s = state;
+ const std::set<int>& consumers = GetConsumers(policy.search_task, state, stage_id);
+ CHECK_EQ(consumers.size(), 1);
+
+ // Get the last outer space iterator that is not unrolled.
+ const Stage& target_stage = state->stages[*consumers.begin()];
+ for (size_t i = 0; i < target_stage->iters.size(); ++i) {
+ if (target_stage->iters[i]->annotation == IteratorAnnotation::kUnroll) {
+ CHECK_GT(i, 0);
+
+ tmp_s.compute_at(stage_id, *consumers.begin(), target_stage->iters[i - 1]);
+ break;
+ }
+ }
+
+ return {std::make_pair(std::move(tmp_s), stage_id - 1)};
+}
+
/********** Init Population **********/
InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
}
InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const {
- std::vector<int> auto_unroll_configs = {0, 16, 64, 512};
+ std::vector<int> auto_unroll_configs = IsGPUTask(policy->search_task)
+ ? std::vector<int>({0, 16, 64, 512, 1024})
+ : std::vector<int>({0, 16, 64, 512});
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stage and placeholder stage
return ResultKind::kValid;
}
+InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const {
+ std::set<int> multi_level_tiling_root_set;
+ for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
+ if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
+ const Stage& stage = (*state)->stages[stage_id];
+ if (stage->compute_at != ComputeAtKind::kIter) {
+ // This stage is not multi-level tiled,
+ // so it must be produced by RuleCrossThreadReduction.
+ CHECK(HasCrossThreadReduction(*state, stage_id));
+ } else {
+ const auto res = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
+ CHECK(res != (*state)->attach_map->stage_to_attach_iter.end());
+ multi_level_tiling_root_set.insert(res->second.first);
+ }
+ }
+ }
+
+ *state = policy->search_task->compute_dag.InferBound(*state);
+
+ for (int stage_id = (*state)->stages.size() - 1; stage_id >= 0; --stage_id) {
+ const Stage& stage = (*state)->stages[stage_id];
+
+ if (stage->compute_at == ComputeAtKind::kInlined || stage->op_type == StageKind::kPlaceholder) {
+ continue;
+ }
+
+ // Deal with the cross-thread reduction generated by RuleCrossThreadReduction
+ if (HasCrossThreadReduction(*state, stage_id)) {
+ Iterator fused_it;
+ *state = std::move(FuseAllOuterSpaceIterators(*state, stage_id, &fused_it));
+ state->bind(stage_id, fused_it, IteratorAnnotation::kBlockX);
+ continue;
+ }
+
+ // Skip if this stage has already been annotated with threadIdx.x
+ if (HasAnnotatedIter(stage, IteratorAnnotation::kThreadX)) {
+ continue;
+ }
+
+ if (stage->compute_at == ComputeAtKind::kRoot) {
+ // This stage has not been tiled, but in a GPU schedule we must tile the root stage
+ // to do thread binding
+ if (!multi_level_tiling_root_set.count(stage_id)) {
+ Iterator fused_it;
+ *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
+
+ if (GetExtent(fused_it) <= policy->search_task->hardware_params->warp_size) {
+ state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
+ } else {
+ // Set the extent of threadIdx.x to the warp size by default.
+ // The later evolutionary search will try more possibilities
+ const auto& split_its = state->split(
+ stage_id, fused_it, {Integer(policy->search_task->hardware_params->warp_size)});
+ state->bind(stage_id, split_its[0], IteratorAnnotation::kBlockX);
+ state->bind(stage_id, split_its[1], IteratorAnnotation::kThreadX);
+ }
+ continue;
+ }
+
+ // Otherwise, this is a tiled root stage. We assume it is tiled with at least 3 space
+ // levels in the outermost iterators.
+ // The remaining part deals with the thread binding for multi-level tiled stages.
+ auto pop = stage->op.as<te::ComputeOpNode>();
+ std::vector<Iterator> to_fuse;
+ int total_space_extent = 1;
+ for (const auto& i : pop->root_iter_vars()) {
+ CHECK(i->dom.defined());
+ const auto& pint = i->dom->extent.as<IntImmNode>();
+ CHECK(pint);
+ total_space_extent *= pint->value;
+ }
+
+ // Check if the total space extent is too small for multi-level thread binding
+ if (total_space_extent <= policy->search_task->hardware_params->warp_size) {
+ Iterator fused_it;
+ *state = FuseAllOuterSpaceIterators(*state, stage_id, &fused_it);
+ state->bind(stage_id, fused_it, IteratorAnnotation::kThreadX);
+ continue;
+ }
+
+ // Fuse the outermost space tile as blockIdx
+ for (size_t i = 0; i < pop->axis.size(); i++) {
+ const auto& it = (*state)->stages[stage_id]->iters[i];
+ // Some iterators may be marked as no-split; stop when we reach the next
+ // tiling level
+ if (!StrEndsWith(it->name, ".0")) {
+ break;
+ }
+ to_fuse.push_back(it);
+ }
+ const auto& blockidx_it = state->fuse(stage_id, to_fuse);
+ state->bind(stage_id, blockidx_it, IteratorAnnotation::kBlockX);
+
+ // Fuse the second outermost space tile as vthread
+ to_fuse.clear();
+ for (size_t i = 1; i < pop->axis.size() + 1; i++) {
+ const auto& it = (*state)->stages[stage_id]->iters[i];
+ // Some iterators may be marked as no-split; stop when we reach the next
+ // tiling level
+ if (!StrEndsWith(it->name, ".1")) {
+ break;
+ }
+ to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+ }
+ const auto& vthread_it = state->fuse(stage_id, to_fuse);
+ if (GetExtent(vthread_it) > policy->search_task->hardware_params->max_vthread_extent) {
+ return ResultKind::kInvalid;
+ }
+ state->bind(stage_id, vthread_it, IteratorAnnotation::kVThread);
+
+ // Fuse the third outermost space tile as threadIdx
+ to_fuse.clear();
+ for (size_t i = 2; i < pop->axis.size() + 2; i++) {
+ const auto& it = (*state)->stages[stage_id]->iters[i];
+ // Some iterators may be marked as no-split; stop when we reach the next
+ // tiling level
+ if (!StrEndsWith(it->name, ".2")) {
+ break;
+ }
+ to_fuse.push_back((*state)->stages[stage_id]->iters[i]);
+ }
+ const auto& threadidx_it = state->fuse(stage_id, to_fuse);
+ if (GetExtent(threadidx_it) < policy->search_task->hardware_params->warp_size) {
+ return ResultKind::kInvalid;
+ }
+ state->bind(stage_id, threadidx_it, IteratorAnnotation::kThreadX);
+ } else if (stage->compute_at == ComputeAtKind::kIter &&
+ StrEndsWith(stage->op->name, ".shared")) {
+ // Do cooperative fetching for the cache read stage.
+ // Get spatial_split_step_ids from the root stage
+ const auto& it = (*state)->attach_map->stage_to_attach_iter.find(stage_id);
+ CHECK(it != (*state)->attach_map->stage_to_attach_iter.end());
+ Array<Integer> spatial_split_step_ids = GetSpatialSplitStepIds(*state, it->second.first);
+
+ // Fuse all iterators to do cooperative fetching
+ Iterator fused = state->fuse(stage_id, (*state)->stages[stage_id]->iters);
+ // Split out an extra iterator for vectorization
+ // The later evolutionary search will try more possibilities
+ const auto& iters0 = state->split(stage_id, fused, {Integer(1)});
+ state->vectorize(stage_id, iters0[1]);
+ // Use follow_fused_split to keep the same thread extent as the root stage
+ const auto& iters1 =
+ state->follow_fused_split(stage_id, iters0[0], spatial_split_step_ids, 1, true);
+ state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX);
+ }
+ }
+
+ return ResultKind::kValid;
+}
+
} // namespace auto_scheduler
} // namespace tvm
* \return The condition check result of this rule.
*/
virtual ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) = 0;
+ int stage_id) const = 0;
/*!
* \brief Apply function of this rule.
const State& state, int stage_id) const = 0;
};
+#define DEFINE_SKETCH_GENERATION_RULE(rule_name) \
+ class rule_name : public SketchGenerationRule { \
+ public: \
+ ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state, \
+ int stage_id) const final; \
+ std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, \
+ int stage_id) const final; \
+ };
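+// For example, DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage) expands to:
+//   class RuleSkipStage : public SketchGenerationRule {
+//    public:
+//     ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
+//                                 int stage_id) const final;
+//     std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
+//                                              int stage_id) const final;
+//   };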
+
/*! \brief The rule that simply skips the current stage. It returns an unchanged state and moves
* to the next stage. */
-class RuleSkipStage : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
-
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleSkipStage);
/*! \brief The rule that inlines simple elementwise ops.
* \note This rule only inlines the strictly inlineable stages. Stages marked as not strictly
* inlineable will have a chance to try different compute_at locations in InitPopulation later.
*/
-class RuleAlwaysInline : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
-
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAlwaysInline);
/*! \brief The rule that performs multi-level tiling. */
-class RuleMultiLevelTiling : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
+DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTiling);
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+/*! \brief The rule that performs multi-level tiling and fuses later consumers. */
+DEFINE_SKETCH_GENERATION_RULE(RuleMultiLevelTilingWithFusion);
-/*! The rule that performs multi-level tiling and fuses later consumers. */
-class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
-
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-
- private:
- int target_stage_id;
-};
+/*! \brief The rule that adds a cache read stage. Mainly used for GPU cooperative fetching.
+ * Currently only supports 1-to-1 matched cache read. */
+DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheRead);
/*! \brief The rule that adds a cache write stage. */
-class RuleAddCacheWrite : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
-
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAddCacheWrite);
/*! \brief The rule that adds rfactor stage. */
-class RuleAddRfactor : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
-
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+DEFINE_SKETCH_GENERATION_RULE(RuleAddRfactor);
/*! \brief The rule that deals with compute ops that perform "fake reduction" with const tensors.
- * This kind of op comes from winograd transformation.
- */
-class RuleSimplifyComputeWithConstTensor : public SketchGenerationRule {
- public:
- ConditionKind MeetCondition(const SketchPolicyNode& policy, const State& state,
- int stage_id) final;
+ * This kind of op comes from Winograd transformation. */
+DEFINE_SKETCH_GENERATION_RULE(RuleSimplifyComputeWithConstTensor);
- std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,
- int stage_id) const final;
-};
+/*! \brief The rule that uses cross-thread reduction for GPU. */
+DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
+
+/*! \brief Handle special cases in Winograd transformation for GPU. We need to change the compute
+ * location of the producers of compute ops that perform "fake reduction" with const tensors. */
+DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
/********** Init Population **********/
virtual ResultKind Apply(SketchPolicyNode* policy, State* state) const = 0;
};
+#define DEFINE_INIT_POPULATION_RULE(rule_name) \
+ class rule_name : public InitPopulationRule { \
+ public: \
+ ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \
+ };
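+// For example, DEFINE_INIT_POPULATION_RULE(InitFillTileSize) expands to:
+//   class InitFillTileSize : public InitPopulationRule {
+//    public:
+//     ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
+//   };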
+
/*! \brief The rule that fills the incomplete SplitSteps. */
-class InitFillTileSize : public InitPopulationRule {
- public:
- ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitFillTileSize);
/*! \brief The rule that randomly changes the computation location for some stages, which do not
* need tiling and are not strictly inlineable (e.g. data padding). */
-class InitChangeComputeLocation : public InitPopulationRule {
- public:
- ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitChangeComputeLocation);
/*! \brief The rule that annotates parallel for CPU. */
-class InitParallel : public InitPopulationRule {
- public:
- ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitParallel);
/*! \brief The rule that annotates unroll. */
-class InitUnroll : public InitPopulationRule {
- public:
- ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitUnroll);
/*! \brief The rule that annotates vectorization. */
-class InitVectorization : public InitPopulationRule {
- public:
- ResultKind Apply(SketchPolicyNode* policy, State* state) const final;
-};
+DEFINE_INIT_POPULATION_RULE(InitVectorization);
+
+/*! \brief The rule that annotates thread binding for GPU. */
+DEFINE_INIT_POPULATION_RULE(InitThreadBind);
} // namespace auto_scheduler
} // namespace tvm
namespace tvm {
namespace auto_scheduler {
+Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
+ const auto& stage = s->stages[stage_id];
+ const auto& pop = s->stages[stage_id]->op.as<te::ComputeOpNode>();
+ CHECK(pop != nullptr);
+ const std::set<std::string>& no_split_at_inner_name_set =
+ stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+ ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+ : std::set<std::string>();
+ size_t reduce_count = 0;
+ for (const auto axis : pop->reduce_axis) {
+ if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+ reduce_count++;
+ }
+ }
+
+ Array<Integer> spatial_split_step_ids;
+ for (int i = s->transform_steps.size() - 1; i >= 0; --i) {
+ if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+ s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+ s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+ if (stage_id > s->transform_steps[i]->stage_id) {
+ stage_id--;
+ }
+ } else if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
+ if (stage_id == ps->stage_id) {
+ // Assume SplitSteps on reduction axes always come after SplitSteps on spatial axes.
+ if (reduce_count) {
+ reduce_count--;
+ } else {
+ spatial_split_step_ids.push_back(i);
+ }
+ }
+ }
+ }
+
+ return spatial_split_step_ids;
+}
+
State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
std::vector<int>* spatial_split_step_ids) {
// Temporal object to be used if the input pointer is nullptr
return res;
}
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
+ .set_body_typed([](const Stage& stage) { return IsTiled(stage); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheReadStage")
+ .set_body_typed([](const State& s, int stage_id) { return HasCacheReadStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCacheWriteStage")
+ .set_body_typed([](const State& s, int stage_id) { return HasCacheWriteStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasRfactorStage")
+ .set_body_typed([](const State& s, int stage_id) { return HasRfactorStage(s, stage_id); });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsHasCrossThreadReduction")
+ .set_body_typed([](const State& s, int stage_id) {
+ return HasCrossThreadReduction(s, stage_id);
+ });
+
} // namespace auto_scheduler
} // namespace tvm
namespace tvm {
namespace auto_scheduler {
+/*! \brief Return whether the search task is targeting a CPU. */
+inline bool IsCPUTask(const SearchTask& task) {
+ return (task)->target->kind->device_type == kDLCPU;
+}
+
+/*! \brief Return whether the search task is targeting a GPU. */
+inline bool IsGPUTask(const SearchTask& task) {
+ return (task)->target->kind->device_type == kDLGPU ||
+ (task)->target->kind->device_type == kDLOpenCL ||
+ (task)->target->kind->device_type == kDLVulkan ||
+ (task)->target->kind->device_type == kDLMetal ||
+ (task)->target->kind->device_type == kDLROCM ||
+ (task)->target->kind->device_type == kOpenGL;
+}
+
+/*! \brief Return whether the search task is targeting a CUDA GPU. */
+inline bool IsCUDATask(const SearchTask& task) {
+ return (task)->target->kind->device_type == kDLGPU;
+}
+
+/*! \brief Return whether the search task is targeting an OpenCL GPU. */
+inline bool IsOpenCLTask(const SearchTask& task) {
+ return (task)->target->kind->device_type == kDLOpenCL;
+}
+
/*! \brief Argsort. Order: largest to smallest */
template <typename T>
inline std::vector<int> Argsort(const std::vector<T>& scores) {
return false;
}
+/*! \brief Return whether the state does cache_read for stage_id. */
+inline bool HasCacheReadStage(const State& s, int stage_id) {
+ for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+ if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
+ if (stage_id == ps->stage_id) {
+ return true;
+ }
+ }
+
+ if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+ s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+ s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+ if (stage_id > s->transform_steps[i]->stage_id) {
+ stage_id--;
+ }
+ }
+ }
+ return false;
+}
+
/*! \brief Return whether the state does cache_write for stage_id. */
inline bool HasCacheWriteStage(const State& s, int stage_id) {
for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
return false;
}
+/*! \brief Return whether the state does rfactor for stage_id. */
+inline bool HasRfactorStage(const State& s, int stage_id) {
+ for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+ if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+ if (stage_id == ps->stage_id) {
+ return true;
+ }
+ }
+
+ if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+ s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+ s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+ if (stage_id > s->transform_steps[i]->stage_id) {
+ stage_id--;
+ }
+ }
+ }
+ return false;
+}
+
+/*! \brief Return whether the stage does cross-thread reduction. */
+inline bool HasCrossThreadReduction(const State& state, int stage_id) {
+ std::function<bool(const Stage&)> check_stage = [](const Stage& in_stage) {
+ for (const auto& iter : in_stage->iters) {
+ if (iter->annotation == IteratorAnnotation::kThreadX &&
+ iter->iter_kind == IteratorKind::kReduction) {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Check the stage itself
+ if (check_stage(state->stages[stage_id])) {
+ return true;
+ }
+
+ // Check the attached stages
+ for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); iter_id++) {
+ const auto& res =
+ state->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id));
+ if (res != state->attach_map->iter_to_attached_stages.end()) {
+ for (int attached_stage_id : res->second) {
+ if (check_stage(state->stages[attached_stage_id])) {
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
/*! \brief Return whether the stage has been tiled already. */
inline bool IsTiled(const Stage& stage) {
auto op = stage->op.as<te::ComputeOpNode>();
}
}
+/*! \brief Get the last reduce iterator in the outermost reduce tile. */
+inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
+ auto pop = stage->op.as<te::ComputeOpNode>();
+ CHECK(pop != nullptr);
+ std::set<std::string> original_names;
+
+ const std::set<std::string>& no_split_at_inner_name_set =
+ stage->op->attrs.count(SearchPolicyKey::no_split_at_inner)
+ ? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
+ : std::set<std::string>();
+ size_t reduce_axis_size = 0;
+ for (const auto axis : pop->reduce_axis) {
+ if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+ reduce_axis_size++;
+ }
+ }
+ if (reduce_axis_size) {
+ for (const auto& iter : stage->iters) {
+ if (iter->iter_kind == IteratorKind::kReduction) {
+ ExtractOriginalIterators(iter->name, &original_names);
+ if (original_names.size() == reduce_axis_size) {
+ return iter;
+ }
+ }
+ }
+ } else {
+ // Return the first reduce iterator
+ for (const auto& iter : stage->iters) {
+ if (iter->iter_kind == IteratorKind::kReduction) {
+ return iter;
+ }
+ }
+ }
+
+ LOG(FATAL) << "Cannot find the iterator.";
+ return stage->iters[0];
+}
+
+/*! \brief Get all split steps for one stage. */
+inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
+ for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+ if (auto ps = s->transform_steps[i].as<SplitStepNode>()) {
+ if (stage_id == ps->stage_id) {
+ split_step_ids->push_back(i);
+ }
+ }
+
+ if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+ s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+ s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+ if (stage_id > s->transform_steps[i]->stage_id) {
+ stage_id--;
+ }
+ }
+ }
+}
+
/*! \brief Fuse all reduction iterators. */
inline State FuseAllReductionIterators(const State& state, int stage_id, Iterator* fused_iter,
Array<Iterator>* space_iters,
return tmp_s;
}
+/*! \brief Fuse all outer level space iterators. */
+inline State FuseAllOuterSpaceIterators(const State& state, int stage_id, Iterator* fused_iter) {
+ std::vector<Iterator> to_fuse;
+ for (size_t iter_id = 0; iter_id < state->stages[stage_id]->iters.size(); ++iter_id) {
+ const auto& it = state->stages[stage_id]->iters[iter_id];
+ // Stop at reduce iterator or annotated iterator
+ if (it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) {
+ break;
+ }
+ // Stop at compute_at attach point
+ if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id - 1))) {
+ break;
+ }
+ to_fuse.push_back(it);
+ }
+
+ CHECK(!to_fuse.empty());
+ State tmp_s = state;
+ if (to_fuse.size() > 1) {
+ *fused_iter = tmp_s.fuse(stage_id, to_fuse);
+ } else {
+ *fused_iter = to_fuse[0];
+ }
+ return tmp_s;
+}
+
/*! \brief Random sample states. */
inline Array<State> RandomSampleStates(const Array<State>& in_states, std::mt19937* random_gen,
size_t out_size) {
std::unordered_map<int, std::vector<int>> factor_memory_;
};
+/*! \brief Get the indices of the SplitSteps that operate on spatial iterators. */
+Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);
+
// Apply a multi-level tiling structure according to a string format,
// where "S" stands for a space level and "R" stands for a reduction level.
// For example, if the format is "SSRSRS", then we will
*/
#include <tvm/auto_scheduler/search_task.h>
+#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>
HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
const Target& target_host) {
- if (target->kind->name == "llvm") {
+ if (target->kind->device_type == kDLCPU) {
return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64);
+ } else if (target->kind->device_type == kDLGPU) {
+ auto hardware_params = HardwareParams(-1, 16, 64);
+ auto* p_hardware_params = hardware_params.CopyOnWrite();
+
+ auto ctx = TVMContext{kDLGPU, 0};
+ auto func = tvm::runtime::Registry::Get("device_api.gpu");
+ CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+ auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+ tvm::runtime::TVMRetValue ret;
+ device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
+ p_hardware_params->max_shared_memory_per_block = ret;
+
+ device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret);
+ p_hardware_params->max_registers_per_block = ret;
+
+ device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
+ p_hardware_params->max_threads_per_block = ret;
+
+ device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+ p_hardware_params->warp_size = ret;
+
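+  // Heuristic: limit the maximum virtual thread extent to a quarter of the warp size.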
+ p_hardware_params->max_vthread_extent = p_hardware_params->warp_size / 4;
+
+ return hardware_params;
} else {
LOG(FATAL) << "No default hardware parameters for target: " << target;
}
#include "utils.h"
+namespace dmlc {
+namespace json {
+
+template <>
+struct Handler<::tvm::Array<::tvm::Integer>> {
+ inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::Integer>& array) {
+ writer->BeginArray(false);
+ for (const auto& i : array) {
+ CHECK(i.defined());
+ writer->WriteArrayItem(i->value);
+ }
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::Integer>* array) {
+ array->clear();
+ reader->BeginArray();
+ while (reader->NextArrayItem()) {
+ int value;
+ Handler<int>::Read(reader, &value);
+ array->push_back(value);
+ }
+ }
+};
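+// e.g. an Array<Integer>{1, 2, 3} is serialized as the compact JSON array [1,2,3].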
+
+template <>
+struct Handler<::tvm::Array<::tvm::Optional<::tvm::Integer>>> {
+ inline static void Write(dmlc::JSONWriter* writer,
+ const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& array) {
+ writer->BeginArray(false);
+ for (const auto& i : array) {
+ CHECK(i);
+ writer->WriteArrayItem(i.value()->value);
+ }
+ writer->EndArray();
+ }
+ inline static void Read(dmlc::JSONReader* reader,
+ ::tvm::Array<::tvm::Optional<::tvm::Integer>>* array) {
+ array->clear();
+ reader->BeginArray();
+ while (reader->NextArrayItem()) {
+ int value;
+ Handler<int>::Read(reader, &value);
+ array->push_back(::tvm::Integer(value));
+ }
+ }
+};
+
+} // namespace json
+} // namespace dmlc
+
namespace tvm {
namespace auto_scheduler {
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->stage_id);
- std::vector<int> int_list;
s = reader->NextArrayItem();
CHECK(s);
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> fused_ids;
- for (const auto& i : int_list) {
- fused_ids.push_back(i);
- }
- node->fused_ids = fused_ids;
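+  // Array<Integer> fields can now be read directly via the dmlc::json::Handler
+  // specializations added in this patch.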
+ reader->Read(&node->fused_ids);
data_ = std::move(node);
}
writer->WriteArraySeperator();
writer->WriteString(record_prefix_str);
writer->WriteArrayItem(stage_id);
- writer->WriteArrayItem(IntArrayToVector(fused_ids));
+ writer->WriteArrayItem(fused_ids);
}
Iterator FuseStepNode::ApplyToState(State* state) const {
reader->Read(&node->stage_id);
s = reader->NextArrayItem();
CHECK(s);
- std::vector<int> int_list;
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> after_ids;
- for (const auto& i : int_list) {
- after_ids.push_back(i);
- }
- node->after_ids = after_ids;
+ reader->Read(&node->after_ids);
data_ = std::move(node);
}
writer->WriteArraySeperator();
writer->WriteString(record_prefix_str);
writer->WriteArrayItem(stage_id);
- writer->WriteArrayItem(IntArrayToVector(after_ids));
+ writer->WriteArrayItem(after_ids);
}
void ReorderStepNode::ApplyToState(State* state) const {
}
s = reader->NextArrayItem();
CHECK(s);
- std::vector<int> int_list;
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
- for (const auto& i : int_list) {
- lengths.push_back(::tvm::Integer(i));
- }
- node->lengths = lengths;
+ reader->Read(&node->lengths);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->inner_to_outer);
writer->WriteArrayItem(stage_id);
writer->WriteArrayItem(iter_id);
writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
- writer->WriteArrayItem(IntArrayToVector(lengths));
+ writer->WriteArrayItem(lengths);
writer->WriteArrayItem(static_cast<int>(inner_to_outer));
}
reader->Read(&node->iter_id);
s = reader->NextArrayItem();
CHECK(s);
- std::vector<int> int_list;
- reader->Read(&int_list);
+ reader->Read(&node->src_step_ids);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->level);
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->factor_or_nparts);
- ::tvm::Array<::tvm::Integer> src_step_ids;
- for (const auto& i : int_list) {
- src_step_ids.push_back(i);
- }
- node->src_step_ids = src_step_ids;
data_ = std::move(node);
}
writer->WriteString(record_prefix_str);
writer->WriteArrayItem(stage_id);
writer->WriteArrayItem(iter_id);
- writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+ writer->WriteArrayItem(src_step_ids);
writer->WriteArrayItem(level);
writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
}
node->scope_name = std::move(string_value);
s = reader->NextArrayItem();
CHECK(s);
- std::vector<int> int_list;
- reader->Read(&int_list);
- Array<Integer> reader_stage_ids;
- for (int i : int_list) {
- reader_stage_ids.push_back(i);
- }
- node->reader_stage_ids = std::move(reader_stage_ids);
+ reader->Read(&node->reader_stage_ids);
data_ = std::move(node);
}
writer->WriteArrayItem(stage_id);
writer->WriteArraySeperator();
writer->WriteString(scope_name);
- writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+ writer->WriteArrayItem(reader_stage_ids);
}
int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
}
}
-/*! \brief Convert a Array<Integer> to std::vector<int>. */
-inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
- std::vector<int> out;
- for (const auto& x : data) {
- CHECK(x.defined());
- out.push_back(x);
- }
- return out;
-}
-
-/*! \brief Convert a Array<Optional<Integer>> to std::vector<int>. */
-inline std::vector<int> IntArrayToVector(
- const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
- std::vector<int> out;
- for (const auto& x : data) {
- CHECK(x);
- out.push_back(x.value());
- }
- return out;
-}
-
/*! \brief Return whether two int arrays are elementwise-equal */
inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
if (arr1.size() != arr2.size()) {
target = tvm.target.create(target)
task = auto_scheduler.SearchTask(dag, workload_key, target)
- if search_policy == 'empty':
- search_policy = auto_scheduler.EmptyPolicy(task)
- elif search_policy == 'sketch':
- search_policy = auto_scheduler.SketchPolicy(task,
- init_search_callbacks=init_search_callbacks)
-
with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name
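+        # PreloadMeasuredStates lets the search resume from any records already
+        # stored in log_file instead of starting from scratch.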
+ init_search_callbacks = init_search_callbacks or []
+ init_search_callbacks.append(auto_scheduler.PreloadMeasuredStates(log_file))
+
+ if search_policy == 'empty':
+ search_policy = auto_scheduler.EmptyPolicy(task)
+ elif search_policy == 'sketch':
+ search_policy = auto_scheduler.SketchPolicy(task,
+ init_search_callbacks=init_search_callbacks)
+
tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
runner=runner, verbose=1, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])
sch, args = auto_scheduler.auto_schedule(task, search_policy, tuning_options)
t.join()
+def test_sketch_search_policy_cuda_rpc_runner():
+ if not tvm.runtime.enabled("cuda"):
+ return
+ measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+ # wrap the search in a new thread to avoid the conflict
+ # between python's multiprocessing and tvm's thread pool
+ t = PropagatingThread(target=search_common,
+ kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+ 'runner': measure_ctx.runner})
+ t.start()
+ t.join()
+
+
if __name__ == "__main__":
test_workload_registry_search_basic()
test_sketch_search_policy_basic()
+ test_sketch_search_policy_cuda_rpc_runner()
import tvm
from tvm import te, auto_scheduler
+from tvm.auto_scheduler import _ffi_api
+from tvm.auto_scheduler.loop_state import Stage
from test_auto_scheduler_common import (matmul_auto_scheduler_test, conv2d_nchw_bn_relu_auto_scheduler_test,
max_pool2d_auto_scheduler_test, min_nm_auto_scheduler_test,
softmax_nm_auto_scheduler_test, softmax_abcd_auto_scheduler_test,
conv2d_winograd_nhwc_auto_scheduler_test)
+
def generate_sketches(workload_func, args, target, print_for_debug=False):
workload_key = auto_scheduler.make_workload_key(workload_func, args)
dag = auto_scheduler.ComputeDAG(workload_key)
policy = auto_scheduler.SketchPolicy(task, verbose=0)
return policy.generate_sketches(print_for_debug)
+def assert_compute_at_condition(stage, condition):
+ assert stage.compute_at == Stage.COMPUTE_AT_TRANS_TABLE[condition]
+
+def assert_is_tiled(stage):
+ assert _ffi_api.SearchPolicyUtilsIsTiled(stage)
+
+def assert_is_not_tiled(stage):
+ assert not _ffi_api.SearchPolicyUtilsIsTiled(stage)
+
+def assert_has_cache_write(state, stage_id):
+ assert _ffi_api.SearchPolicyUtilsHasCacheWriteStage(state, stage_id)
+
+def assert_has_cache_read(state, stage_id):
+ assert _ffi_api.SearchPolicyUtilsHasCacheReadStage(state, stage_id)
+
+def assert_has_rfactor(state, stage_id):
+ assert _ffi_api.SearchPolicyUtilsHasRfactorStage(state, stage_id)
+
+def assert_has_cross_thread_reduction(state, stage_id):
+ assert _ffi_api.SearchPolicyUtilsHasCrossThreadReduction(state, stage_id)
+
+
def test_cpu_matmul_sketch():
sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'llvm')
''' 3 multi-level tiling sketches
2 - Multi-level tiling with cache write on position 1
'''
assert len(sketches) == 3
+ # Sketch 0
+ assert_is_tiled(sketches[0].stages[2])
+ # Sketch 1
+ assert_is_tiled(sketches[1].stages[2])
+ assert_has_cache_write(sketches[1], 2)
+ assert_compute_at_condition(sketches[1].stages[2], "iter")
+ # Sketch 2
+ assert_is_tiled(sketches[2].stages[2])
+ assert_has_cache_write(sketches[2], 2)
+ assert_compute_at_condition(sketches[2].stages[2], "iter")
+ assert sketches[1] != sketches[2]
sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 512), 'llvm')
''' 2 rfactor sketches + 3 multi-level tiling sketches
4 - Multi-level tiling with cache write on position 1
'''
assert len(sketches) == 5
+ # Sketch 0
+ assert_has_rfactor(sketches[0], 2)
+ # Sketch 1
+ assert_has_rfactor(sketches[1], 2)
+ assert sketches[0] != sketches[1]
+ # Sketch 2
+ assert_is_tiled(sketches[2].stages[2])
+ # Sketch 3
+ assert_is_tiled(sketches[3].stages[2])
+ assert_has_cache_write(sketches[3], 2)
+ assert_compute_at_condition(sketches[3].stages[2], "iter")
+ # Sketch 4
+ assert_is_tiled(sketches[4].stages[2])
+ assert_has_cache_write(sketches[4], 2)
+ assert_compute_at_condition(sketches[4].stages[2], "iter")
+ assert sketches[3] != sketches[4]
+
def test_cpu_conv2d_bn_relu_sketch():
sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
2 - Conv2d multi-level tiling without fusion
'''
assert len(sketches) == 3
+ # Sketch 0
+ assert_is_not_tiled(sketches[0].stages[1])
+ assert_is_tiled(sketches[0].stages[3])
+ assert_compute_at_condition(sketches[0].stages[3], "iter")
+ assert_compute_at_condition(sketches[0].stages[5], "inlined")
+ assert_compute_at_condition(sketches[0].stages[7], "inlined")
+ assert_compute_at_condition(sketches[0].stages[9], "inlined")
+ assert_is_tiled(sketches[0].stages[10])
+ # Sketch 1
+ assert_is_not_tiled(sketches[1].stages[1])
+ assert_is_tiled(sketches[1].stages[3])
+ assert_compute_at_condition(sketches[1].stages[3], "iter")
+ assert_compute_at_condition(sketches[1].stages[5], "inlined")
+ assert_compute_at_condition(sketches[1].stages[7], "inlined")
+ assert_compute_at_condition(sketches[1].stages[9], "inlined")
+ assert_is_tiled(sketches[1].stages[10])
+ # Sketch 2
+ assert_is_not_tiled(sketches[2].stages[1])
+ assert_is_tiled(sketches[2].stages[3])
+ assert_compute_at_condition(sketches[2].stages[3], "root")
+ assert_compute_at_condition(sketches[2].stages[5], "inlined")
+ assert_compute_at_condition(sketches[2].stages[7], "inlined")
+ assert_compute_at_condition(sketches[2].stages[9], "inlined")
+ assert_is_not_tiled(sketches[2].stages[10])
+
def test_cpu_max_pool2d_sketch():
sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 1), 'llvm')
- assert len(sketches) == 1 # 1 default sketch
+ ''' 1 default sketch '''
+ assert len(sketches) == 1
+ # Sketch 0
+ assert len(sketches[0].transform_steps) == 0
+
def test_cpu_min_sketch():
sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'llvm')
- assert len(sketches) == 3
''' 2 rfactor sketches + 1 default sketch
0 - Rfactor with factor position 0
1 - Rfactor with factor position 1
2 - Default sketch
'''
+ assert len(sketches) == 3
+ # Sketch 0
+ assert_has_rfactor(sketches[0], 1)
+ # Sketch 1
+ assert_has_rfactor(sketches[1], 1)
+ assert sketches[0] != sketches[1]
+ # Sketch 2
+ assert len(sketches[2].transform_steps) == 0
+
def test_cpu_softmax_sketch():
sketches = generate_sketches(softmax_nm_auto_scheduler_test, (1, 1024), 'llvm')
''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
assert len(sketches) == (3 * 3)
+ for i in range(0, 3):
+ for j in range(0, 3):
+ sketch = sketches[i * 3 + j]
+ if j in [0, 1]:
+ assert_has_rfactor(sketch, 1)
+ if i in [0, 1]:
+ assert_has_rfactor(sketch, 4 if j in [0, 1] else 3)
+ assert len(sketches[8].transform_steps) == 0
sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'llvm')
''' (2 rfactor sketches + 1 default sketch) * (2 rfactor sketches + 1 default sketch) '''
assert len(sketches) == (3 * 3)
+ for i in range(0, 3):
+ for j in range(0, 3):
+ sketch = sketches[i * 3 + j]
+ if j in [0, 1]:
+ assert_has_rfactor(sketch, 1)
+ if i in [0, 1]:
+ assert_has_rfactor(sketch, 4 if j in [0, 1] else 3)
+ assert len(sketches[8].transform_steps) == 0
+
def test_cpu_conv2d_winograd_sketch():
sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
2 - Bgemm multi-level tiling with cache write on position 1
'''
assert len(sketches) == 3
+ # Sketch 0
+ assert_is_not_tiled(sketches[0].stages[1])
+ assert_is_not_tiled(sketches[0].stages[2])
+ assert_compute_at_condition(sketches[0].stages[3], "inlined")
+ assert_is_tiled(sketches[0].stages[4])
+ assert_is_tiled(sketches[0].stages[6])
+ assert_compute_at_condition(sketches[0].stages[7], "inlined")
+ assert_is_tiled(sketches[0].stages[8])
+ assert_is_not_tiled(sketches[0].stages[9])
+ # Sketch 1
+ assert_is_not_tiled(sketches[1].stages[1])
+ assert_is_not_tiled(sketches[1].stages[2])
+ assert_compute_at_condition(sketches[1].stages[3], "inlined")
+ assert_is_tiled(sketches[1].stages[4])
+ assert_is_tiled(sketches[1].stages[6])
+ assert_has_cache_write(sketches[1], 6)
+ assert_compute_at_condition(sketches[1].stages[6], "iter")
+ assert_compute_at_condition(sketches[1].stages[8], "inlined")
+ assert_is_tiled(sketches[1].stages[9])
+ assert_is_not_tiled(sketches[1].stages[10])
+ # Sketch 2
+ assert_is_not_tiled(sketches[2].stages[1])
+ assert_is_not_tiled(sketches[2].stages[2])
+ assert_compute_at_condition(sketches[2].stages[3], "inlined")
+ assert_is_tiled(sketches[2].stages[4])
+ assert_is_tiled(sketches[2].stages[6])
+ assert_has_cache_write(sketches[2], 6)
+ assert_compute_at_condition(sketches[2].stages[6], "iter")
+ assert_compute_at_condition(sketches[2].stages[8], "inlined")
+ assert_is_tiled(sketches[2].stages[9])
+ assert_is_not_tiled(sketches[2].stages[10])
+ assert sketches[1] != sketches[2]
+
+
+def test_cuda_matmul_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'cuda')
+ ''' 1 multi-level tiling sketch '''
+ assert len(sketches) == 1
+ assert_has_cache_read(sketches[0], 0)
+ assert_compute_at_condition(sketches[0].stages[1], "iter")
+ assert_has_cache_read(sketches[0], 2)
+ assert_compute_at_condition(sketches[0].stages[3], "iter")
+ assert_has_cache_write(sketches[0], 4)
+ assert_is_tiled(sketches[0].stages[4])
+ assert_compute_at_condition(sketches[0].stages[4], "iter")
+ assert_is_tiled(sketches[0].stages[5])
+
+ sketches = generate_sketches(matmul_auto_scheduler_test, (8, 8, 1024), 'cuda')
+    ''' 1 cross thread reduction sketch + 1 multi-level tiling sketch '''
+ assert len(sketches) == 2
+ # Sketch 0
+ assert_has_cross_thread_reduction(sketches[0], 2)
+ # Sketch 1
+ assert_has_cache_read(sketches[1], 0)
+ assert_compute_at_condition(sketches[1].stages[1], "iter")
+ assert_has_cache_read(sketches[1], 2)
+ assert_compute_at_condition(sketches[1].stages[3], "iter")
+ assert_has_cache_write(sketches[1], 4)
+ assert_is_tiled(sketches[1].stages[4])
+ assert_compute_at_condition(sketches[1].stages[4], "iter")
+ assert_is_tiled(sketches[1].stages[5])
+
+
+def test_cuda_conv2d_bn_relu_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(conv2d_nchw_bn_relu_auto_scheduler_test,
+ (1, 56, 56, 512, 512, 3, 1, 1), 'cuda')
+ ''' 1 multi-level tiling sketch '''
+ assert len(sketches) == 1
+ assert_has_cache_read(sketches[0], 1)
+ assert_compute_at_condition(sketches[0].stages[1], "inlined")
+ assert_compute_at_condition(sketches[0].stages[2], "iter")
+ assert_has_cache_read(sketches[0], 3)
+ assert_compute_at_condition(sketches[0].stages[4], "iter")
+ assert_is_tiled(sketches[0].stages[5])
+ assert_compute_at_condition(sketches[0].stages[5], "iter")
+ assert_compute_at_condition(sketches[0].stages[7], "inlined")
+ assert_compute_at_condition(sketches[0].stages[9], "inlined")
+ assert_compute_at_condition(sketches[0].stages[11], "inlined")
+ assert_is_tiled(sketches[0].stages[12])
+
+
+def test_cuda_max_pool2d_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(max_pool2d_auto_scheduler_test, (1, 56, 56, 512, 0), 'cuda')
+ ''' 1 default sketch '''
+ assert len(sketches) == 1
+ assert len(sketches[0].transform_steps) == 0
+
+
+def test_cuda_min_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(min_nm_auto_scheduler_test, (10, 1024), 'cuda')
+    ''' 1 cross thread reduction sketch + 1 default sketch '''
+ assert len(sketches) == 2
+ # Sketch 0
+ assert_has_cross_thread_reduction(sketches[0], 1)
+ # Sketch 1
+ assert len(sketches[1].transform_steps) == 0
+
+
+def test_cuda_softmax_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(softmax_nm_auto_scheduler_test, (2, 1024), 'cuda')
+    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+ assert len(sketches) == (2 * 2)
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 1)
+    assert_compute_at_condition(sketches[0].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[0], 3)
+    # Sketch 1
+    assert_compute_at_condition(sketches[1].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[1], 3)
+    # Sketch 2
+    assert_has_cross_thread_reduction(sketches[2], 1)
+    assert_compute_at_condition(sketches[2].stages[2], "inlined")
+    # Sketch 3
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+
+ sketches = generate_sketches(softmax_abcd_auto_scheduler_test, (1, 12, 128, 128), 'cuda')
+    ''' (1 cross thread reduction sketch + 1 default sketch) * (1 cross thread reduction sketch + 1 default sketch) '''
+ assert len(sketches) == (2 * 2)
+    # Sketch 0
+    assert_has_cross_thread_reduction(sketches[0], 1)
+    assert_compute_at_condition(sketches[0].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[0], 3)
+    # Sketch 1
+    assert_compute_at_condition(sketches[1].stages[2], "inlined")
+    assert_has_cross_thread_reduction(sketches[1], 3)
+    # Sketch 2
+    assert_has_cross_thread_reduction(sketches[2], 1)
+    assert_compute_at_condition(sketches[2].stages[2], "inlined")
+    # Sketch 3
+    assert_compute_at_condition(sketches[3].stages[2], "inlined")
+
+
+def test_cuda_conv2d_winograd_sketch():
+ if not tvm.context("cuda", 0).exist:
+ return
+
+ sketches = generate_sketches(conv2d_winograd_nhwc_auto_scheduler_test,
+ (1, 28, 28, 128, 128, 3, 1, 1), 'cuda')
+ ''' 1 multi-level tiling sketch '''
+ assert len(sketches) == 1
+ assert_compute_at_condition(sketches[0].stages[1], "inlined")
+ assert_compute_at_condition(sketches[0].stages[2], "inlined")
+ assert_compute_at_condition(sketches[0].stages[3], "inlined")
+ assert_is_tiled(sketches[0].stages[4])
+ assert_has_cache_read(sketches[0], 4)
+ assert_compute_at_condition(sketches[0].stages[5], "iter")
+ assert_has_cache_read(sketches[0], 6)
+ assert_compute_at_condition(sketches[0].stages[7], "iter")
+ assert_is_not_tiled(sketches[0].stages[8])
+ assert_compute_at_condition(sketches[0].stages[8], "iter")
+ assert_compute_at_condition(sketches[0].stages[9], "inlined")
+ assert_is_tiled(sketches[0].stages[10])
+ assert_is_not_tiled(sketches[0].stages[11])
+
if __name__ == "__main__":
test_cpu_matmul_sketch()
test_cpu_min_sketch()
test_cpu_softmax_sketch()
test_cpu_conv2d_winograd_sketch()
+ test_cuda_matmul_sketch()
+ test_cuda_conv2d_bn_relu_sketch()
+ test_cuda_max_pool2d_sketch()
+ test_cuda_min_sketch()
+ test_cuda_softmax_sketch()
+ test_cuda_conv2d_winograd_sketch()