/*!
* \brief Run schedule search for a given compute declaration.
* \param task The search task of the compute declaration.
- * \param search_policy The search policy to be used.
+ * \param search_policy The search policy.
* \param tuning_options Tuning and measurement options.
- * \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or
+ * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
* `tvm.build`.
*/
TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
* \brief The auto-scheduler's computational graph and related program analyses.
*
* We convert a compute declaration described by `tvm.compute` (could be a single operator or a
- * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
- * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
- * total float operation count, consumer/producer relations of each operation stage, whether an
- * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
- * to make decisions during search process.
- * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
- * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
+ * some static analysis results for the DAG (e.g. the total float operation count, consumer/producer
+ * relations of operations, whether an operation stage should be tiled/compute inlined ...).
+ * These analyses can help the search policy to make decisions during the search.
+ * ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing
* `LoopState` with extra information got from TVM schedule ...).
*/
namespace tvm {
namespace auto_scheduler {
-/*! \brief Static analysis result for a ComputeDAG */
+/*! \brief Static analyzer for a ComputeDAG */
class AccessAnalyzerNode : public Object {
public:
template <class T>
using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
/*! \brief Map an operation to all operations it reads from.
- * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+ * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
/*! \brief Map an operation to all operations it is read by.
- * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+ * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
* The inner vector represents the indices of multi-dimensional access.*/
OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
/*! \brief Store the number of common outer iterators for operation pairs that have
explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
/*!
- * \brief Return whether this operation is an injective operation
+ * \brief Return whether this operation is an op with simple access
* (e.g., injective, broadcast and elementwise ops without reduction)
* \param op The operation
*/
TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
/*!
- * \brief Return whether this operation is an output op
+ * \brief Return whether this operation is an output operation
* \param op The operation
*/
TVM_DLL bool IsOutput(const te::Operation& op) const;
/*!
- * \brief Get all consumers of on operation
+ * \brief Get all consumers of an operation
* \param state The current loop state
* \param op The operation
* \return The set of consumers
const State& state, const te::Operation& op) const;
/*!
- * \brief Get all producers of on operation
+ * \brief Get all producers of an operation
* \param state The current loop state
* \param op The operation
* \return The set of producers
const State& state, const te::Operation& op) const;
/*!
- * \brief Get all direct producers of on operation
+ * \brief Get all direct producers of an operation
* \param op The operation
* \return The set of direct producers
* \note This function DOES NOT propagate the relation for inlined ops
/*!
* \brief Return whether two operations are elementwise-matched
- * (e.g. conv2d and relu are elementwise matched)
+ * (e.g. conv2d and relu are elementwise-matched)
* \note This function propagates the relation for chains with multiple ops.
*/
TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
};
-/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
+/*! \brief The auto-scheduler's computational graph and related program analyses. */
class ComputeDAGNode : public Object {
public:
/*!
* This is used as the input of `tvm.lower` or `tvm.build`.
*/
Array<te::Tensor> tensors;
- /*! \brief All related operations in topo order. */
+ /*! \brief All used operations in topo order. */
Array<te::Operation> ops;
- /*! \brief The number of total float operations for this ComputeDAG. */
+ /*! \brief The number of float operations in this ComputeDAG. */
double flop_ct;
/*! \brief The initial state without any transform steps. */
State init_state;
/*!
* \file auto_scheduler/loop_state.h
- * \brief The definition of the "state" in search.
+ * \brief The definition of the "state" in the search.
*
* Each LoopState corresponds to a schedule for its ComputeDAG.
* A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
* During the schedule search process, the loop structure can provide search policy with necessary
* information on how to manipulate the current state.
* The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
- * schedule primitives. The steps can also be used for the serialization of a state.
+ * schedule primitives. The steps are also used for the serialization of a state.
*
* The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
* We don't use the existing TVM IR but to extend a new structure on it is because:
* 3. We may create some macro schedule primitives that represent the combination of several
* TVM schedule primitives.
*
- * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+ * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
* Since we share a lot of common objects during search, the transformation is implemented in
* copy on write style. All objects are immutable, which is similar to TVM IR.
*/
explicit Stage(te::Operation op);
/*!
* \brief The constructor.
- * \param op A `te::Operation`.
+ * \param op The source operation
* \param op_type The stage type of this op.
* \param iters The iterators of this op.
* \param compute_at The compute at type of this op.
/*! \brief A Map to store the mapping of stage to its attached iterator. */
std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
- /*! \brief A Map to store the mapping of iterator to the stage attached to it. */
+ /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
static constexpr const char* _type_key = "auto_scheduler.AttachMap";
public:
/*!
* \brief Process the stage/iterator mapping after compute at.
- * \param stage_id The index of the stage to be computed at.
+ * \param stage_id The index of the source stage of computed at.
* \param target_stage_id The index of stage that this step will compute at to.
- * \param target_iter_id The index of iterator in target stage that this step will compute at to.
+ * \param target_iter_id The index of target iterator in the target stage.
*/
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
/*!
- * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
- * \param stage_id The index of the stage to be computed at.
+ * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`.
+ * \param stage_id The index of the stage to be deleted.
*/
void DeleteStage(int stage_id);
* \brief Find the relations of original iterators in AttachMap, and update them with the new
* iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
* \param original_iters The original IterKey.
- * \param new_iters The new IterKey to update.
+ * \param new_iters The new IterKey for replacing the old ones.
*/
void UpdateIters(const std::vector<IterKey>& original_iters,
const std::vector<IterKey>& new_iters);
/*!
* \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
* to stage indexes that are larger than the start_id. Used for steps that insert new stages to
- * ComputeDAG(e.g. CacheRead/CacheWrite step).
- * \param start_id The index threshold, stage indexes in AttachMap which are larger than this
- * will be applied the extra offset.
+ * ComputeDAG (e.g., CacheRead/CacheWrite step).
+ * \param start_id The index threshold. This function only adds offset for stages
+ * with indices larger then this threshold.
* \param offset The index offset to be added to the stage index.
* \return The updated AttachMap after applying stage index offset.
*/
private:
/*!
- * \brief To delete the entry of a specific stage. This will remove the items related to this
+ * \brief Delete the entry of a specific stage. This will remove the items related to this
* stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map.
* \param pnode A mutable pointer to AttachMapNode.
* \param stage_id The index of stage that will be removed from the map.
* operation.
*/
AttachMap attach_map;
- /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means
- * no modification to the original ComputeDAG.
- * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
- * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
+ /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt,
+ * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask.
+ * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps
+ * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG.
*/
Optional<ObjectRef> current_compute_dag;
/*!
explicit State(const Array<te::Operation>& ops);
/*!
- * \brief Print the state to a human readable string.
+ * \brief Pretty-print the state to a human readable string.
* \param delete_trivial_loop True for skipping the trivial loops.
* (undefined or extent == 1, default set to True)
- * \return The human readable state structure.
+ * \return The human readable string.
*/
String ToStr(bool delete_trivial_loop = true) const;
+ /********** Step APIs working on a single stage **********/
/*!
- * \brief General call step functions with a runtime dynamic dispatcher. This will re-apply all
- * the transform steps from the initial state.
- * \param dag The original ComputeDAG of this state.
- * \note The input `dag` is different from the class member `current_compute_dag`.
- * This function takes the initial ComputeDAG as input to replay all the history. While the
- * `current_compute_dag` is used to track the current stage status, for some transform step may
- * change the op stage structure.
- */
- void ApplySteps(const ComputeDAG& dag);
-
- /********** Step APIs working on single stage **********/
-
- /*!
- * \brief Schedule primitive corresponds to `te::Stage::bind`.
+ * \brief The schedule primitive corresponding to `te::Stage::bind`.
* \param stage_id The index of the stage to be binded.
* \param it The iterator to be binded.
- * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
- * this input.
- * \return The iterator result after binded.
+ * \param thread_type The thread type.
+ * \return The new iterator after binding.
*/
TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::parallel`.
+ * \brief The schedule primitive corresponding to `te::Stage::parallel`.
* \param stage_id The index of the stage to be paralleled.
* \param it The iterator to be paralleled.
- * \return The iterator result after parallel.
+ * \return The new iterator after parallel.
*/
TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::unroll`.
+ * \brief The schedule primitive corresponding to `te::Stage::unroll`.
* \param stage_id The index of the stage to be unrolled.
* \param it The iterator to be unrolled.
* \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
* skipped.
- * \return The iterator result after unrolled.
+ * \return The new iterator after unroll.
*/
TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::vectorize`.
+ * \brief The schedule primitive corresponding to `te::Stage::vectorize`.
* \param stage_id The index of the stage to be vectorized.
* \param it The iterator to be vectorized.
- * \return The iterator result after vectorize.
+ * \return The new iterator after vectorization.
*/
TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::fuse`.
+ * \brief The schedule primitive corresponding to `te::Stage::fuse`.
* \param stage_id The index of the stage to be fused.
* \param iters The iterators to be fused.
* \return The iterator result after fuse.
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
- * \brief Schedule primitive corresponds to `te.Stage.pragma`.
+ * \brief The schedule primitive corresponding to `te.Stage.pragma`.
* \param stage_id The index of the stage to add pragma.
* \param it The iterator to add pragma.
* \param pragma_type The pragma string.
*/
TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::reorder`.
+ * \brief The schedule primitive corresponding to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
*/
TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::split`.
+ * \brief The schedule primitive corresponding to `te::Stage::split`.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
- * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner.
- * \return The iterator results after split.
+ * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
+ * \return The new iterator after splitting.
* \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
* most iterator of split results will become the new attach point.
*/
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
/*!
- * \brief Schedule primitive extends to split step.
+ * \brief The schedule primitive similar to split, but uses split factors from previous steps.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_id The index of the split step to be followed in the history.
* \param n_split The number of split level.
- * \return The splitted new Iterators.
+ * \return The split new Iterators.
*/
TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split);
/*!
- * \brief Schedule primitive extends to split step.
+ * \brief The schedule primitive similar to split, but uses split factors from
+ * fused previous steps.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param src_step_ids The indices of the split steps to be followed in the history.
* \param level Use the length in this split level.
* \param factor_or_nparts True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
- * \return The splitted new Iterators.
+ * \return The split new Iterators.
*/
TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);
/*!
- * \brief Schedule primitive corresponds to `te.Stage.storage_align`.
+ * \brief The schedule primitive corresponding to `te.Stage.storage_align`.
* \param stage_id The index of the stage to be aligned.
* \param it The iterator to be aligned.
* \param factor The factor in alignment specification.
TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
/********** Step APIs working on multiple stages **********/
-
/*!
- * \brief Schedule primitive corresponds to `te::Stage::compute_at`.
- * \param stage_id The index of the stage to be computed at.
+ * \brief The schedule primitive corresponding to `te::Stage::compute_at`.
+ * \param stage_id The index of the source stage of computed at.
* \param target_stage_id The index of stage that this step will compute at to.
- * \param target_iter The iterator in target stage that this step will compute at to.
+ * \param target_iter The indiex of the target iterator in the target stage.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
- * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
*/
TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
+ * \brief The schedule primitive corresponding to `te::Stage::compute_inline`.
* \param stage_id The index of the stage to be marked compute inlined.
*/
TVM_DLL void compute_inline(int stage_id);
/*!
- * \brief Schedule primitive corresponds to `te::Stage::compute_root`.
+ * \brief The schedule primitive corresponding to `te::Stage::compute_root`.
* \param stage_id The index of the stage to be marked compute at root.
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
- * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
*/
TVM_DLL void compute_root(int stage_id);
/********** Step APIs adding new stages **********/
-
/*!
- * \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
- * \param stage_id The index of the stage to be cache read.
- * \param scope_name The scope name of the newly added read stage.
- * \param reader_stage_ids The indices of read stages.
+ * \brief The schedule primitive corresponding to `te::Schedule::cache_read`.
+ * \param stage_id The index of the stage to be cache_read.
+ * \param scope_name The scope name of the newly added stage.
+ * \param reader_stage_ids The indices of reader stages.
* \param dag The original ComputeDAG of this state.
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
- * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
TVM_DLL int cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
/*!
- * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
- * \param stage_id The index of the stage to be cache write.
- * \param scope_name The scope name of the newly added compute stage.
+ * \brief The schedule primitive corresponding to `te::Schedule::cache_write`.
+ * \param stage_id The index of the stage to be cache_write.
+ * \param scope_name The scope name of the newly added stage.
* \param dag The original ComputeDAG of this state.
* \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
- * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
/*!
- * \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
+ * \brief The schedule primitive corresponding to `te::Schedule::rfactor`.
* \param stage_id The index of the iterator to be factored.
* \param it The iterator to be factored.
* \param factor_iter_id The position where the new iterator is placed.
* \param dag The original ComputeDAG of this state.
* \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the
- * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
/*!
* \file auto_scheduler/transform_step.h
- * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ * \brief Transformation steps. These steps are used to manipulate `LoopState`.
* They are similar to the schedule primitives in te::Stage.
*
* \note How to add a new transform step:
* 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
* - In these two functions you need to incrementally update all data structures in State with
* CopyOnWrite style.
- * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and
- * `StepPrintAsPythonAPI`, make sure it works.
+ * 4. Add your step to `StepApplyToState`, `StepApplyToSchedule`, and `StepPrintAsPythonAPI`.
* 5. Log record serialization support:
* - Add `FuseStepNode::WriteToRecord` which takes a mutable JSONWriter pointer as input and
* output the record to it.
* - Add another construction function that takes a mutable JSONReader as input, this will get a
* step record from the reader and create the step.
* - Add the step implementation to `StepReadFromRecord`.
- * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test, the test should
- * at lease consists of two parts: the functional test and the record serialization test.
+ * 6. Add its corresponding Python API to `loop_state.py` with necessary unit tests. The test should
+ * at lease cover two parts: the functional test and the record serialization test.
*/
#ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
/*!
* \brief Update the current stage IterVar information to StageToAxesMap.
- * \param stage A te::Stage Object.
- * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated.
+ * \param stage The stage to be updated.
+ * \param stage_to_axes The map to be updated.
*/
void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
extern const char* IteratorAnnotationString[];
/*!
- * \brief A for loop iterator
+ * \brief An iterator of a for-loop
* Similar to tvm::IterVar in `include/tvm/tir/expr.h`
*/
class IteratorNode : public Object {
Step StepReadFromRecord(dmlc::JSONReader* reader);
/*!
- * \brief Apply the step to State.
+ * \brief Apply a general step to a State with runtime dynamic dispatching.
* \param step The step to be applied to State.
* \param state A mutable pointer to state, which will be updated.
* \param dag The original ComputeDAG of this state.
void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
/*!
- * \brief Apply the step to tvm.schedule.
+ * \brief Apply a general step to tvm.schedule with runtime dynamic dispatching.
* \param step The step to be applied to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
- * `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param schedule A mutable point to the current schedule
+ * \param transform_steps An array of all history transform steps.
*/
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule, const Array<Step>& transform_steps);
/*!
- * \brief Print the step as equivalent python schedule API.
- * \param step The step to be applied to python API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
- * CacheRead/CacheWrite step)
- * \param transform_steps An array record all transform steps.
+ * \brief Print a general step as equivalent python schedule API with runtime dynamic dispatching.
+ * \param step The step to be printed as python API.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param schedule A mutable point to the current schedule
+ * \param transform_steps An array of all history transform steps.
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return The iterator result after fuse.
*/
tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
public:
/*! \brief The id of the iter to be split. */
int iter_id;
- /*! \brief The index of the split step to follow in the history. */
+ /*! \brief The index of the split step to be followed in the history. */
int src_step_id;
/*! \brief The number of split level. */
int n_split;
/*!
* \brief Extract split lengths.
- * \param transform_steps An array record all transform steps.
+ * \param transform_steps An array of history transform steps.
* \return The multiple split factors.
*/
Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param transform_steps An array of history transform steps.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param transform_steps An array of history transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
- * \param src_step_id The index of the split step to follow in the history.
+ * \param src_step_id The index of the split step to be followed in the history.
* \param n_split The number of split level.
*/
FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
public:
/*! \brief The id of the iter to split. */
int iter_id;
- /*! \brief The indices of the split steps to follow in the history. */
+ /*! \brief The indices of the split steps to be followed in the history. */
Array<Integer> src_step_ids;
/*! \brief Use the length in this split level. */
int level;
/*!
* \brief Extract split length.
- * \param transform_steps An array record all transform steps.
+ * \param transform_steps An array of history transform steps.
* \return Split factor.
*/
Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param transform_steps An array of history transform steps.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param transform_steps An array of history transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
- * \param src_step_ids An array of index for split step to follow in the history.
+ * \param src_step_ids An array of index for split step to be followed in the history.
* \param level Use the length in this split level.
* \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
*/
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
- * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
*/
void ApplyToState(State* state) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be computed at.
+ * \param stage_id The index of the source stage.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return The iterator result after fuse.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
- * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
*/
void ApplyToState(State* state) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return The iterator result after fuse.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Cache read step that corresponds to te::Schedule::cache_read.
- * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG
- * is stored in State's `current_compute_dag`.
+ * \note Cache read step adds an extra stage to the original ComputeDAG,
+ * an up-to-date ComputeDAG will be stored in State's `current_compute_dag`.
*/
class CacheReadStepNode : public StepNode {
public:
- /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */
+ /*! \brief The scope name of the newly added read stage. (e.g., local, shared, global) */
String scope_name;
/*! \brief The indices of read stages. */
Array<Integer> reader_stage_ids;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return The output Tensor of the new added stage.
*/
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return Python schedule code.
*/
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be cache read.
- * \param scope_name The scope name of the newly added read stage.
- * \param reader_stage_ids The indices of read stages.
+ * \param stage_id The index of the stage to be cache_read.
+ * \param scope_name The scope name of the newly added stage.
+ * \param reader_stage_ids The indices of reader stages.
*/
CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return The output Tensors of the new added stage.
*/
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return Python schedule code.
*/
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be cache write.
- * \param scope_name The scope name of the newly added compute stage.
+ * \param stage_id The index of the stage to be cache_write.
+ * \param scope_name The scope name of the newly added stage.
*/
CacheWriteStep(int stage_id, String scope_name);
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return The output Tensors of the new added stage.
*/
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
* \param schedule A mutable pointer to a te::Schedule.
* \return Python schedule code.
*/
# specific language governing permissions and limitations
# under the License.
-""" The TVM Auto-scheduler computational graph and related program analyses. """
+""" The auto-scheduler's computational graph and related program analyses. """
import hashlib
@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
"""
- The TVM Auto-scheduler computational graph and related program analyses.
+ The auto-scheduler's computational graph and related program analyses.
We convert a compute declaration described by `tvm.compute` (could be a single operator or a
- subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
- a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
- total float operation count, consumer/producer relations of each operation stage, whether an
- operation stage should be tiled/compute inlined ...). These analyses can help the search policy
- to make decisions during search process.
- ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
- TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
+ some static analysis results for the DAG (e.g. the total float operation count,
+ consumer/producer relations of operations, whether an operation stage should
+ be tiled/compute inlined ...).
+ These analyses can help the search policy to make decisions during the search.
+ ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
+ TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing
`LoopState` with extra information got from TVM schedule ...).
Parameters
def print_python_code_from_state(self, state):
"""
- Print transform steps in the history of a State as TVM's python schedule primitive.
+ Print transform steps in the history of a State as TVM's python schedule code.
This is used to print transformation steps for debugging.
Use `apply_steps_from_state` if you want to get a schedule for code generation.
# pylint: disable=unused-import
"""
-The definition of the "state" in search.
+The definition of the "state" in the search.
Each LoopState corresponds to a schedule for its ComputeDAG.
A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
construct the loop structure.
The loop structure keeps a preview of how the schedule will finally look like after lowering the
-current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...).
+current state (e.g. number of iterators, the extent of each iterator, the compute_at locations
+...).
During the schedule search process, the loop structure can provide search policy with necessary
information on how to manipulate the current state.
-The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule
-primitives. The steps can also be used for the serialization of a state.
+The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
+schedule primitives. The steps are also used for the serialization of a state.
The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
We don't use the existing TVM IR but to extend a new structure on it is because:
3. We may create some macro schedule primitives that represent the combination of several
TVM schedule primitives.
-When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
Since we share a lot of common objects during search, the transformation is implemented in
copy on write style. All objects are immutable, which is similar to TVM IR.
"""
return [stage.op for stage in self.stages]
def bind(self, stage, iterator, thread_name):
- """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.bind`.
+ See also the `te.Stage` for more details.
Parameters
----------
return res
def parallel(self, stage, iterator):
- """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.parallel`.
+ See also the `te.Stage` for more details.
Parameters
----------
return res
def unroll(self, stage, iterator, max_unroll=None):
- """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.unroll`.
+ See also the `te.Stage` for more details.
Parameters
----------
return res
def vectorize(self, stage, iterator):
- """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for
- more details.
+ """Schedule primitive corresponding to `te.Stage.vectorize`.
+ See also the `te.Stage` for more details.
Parameters
----------
return res
def fuse(self, stage, iters):
- """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.fuse`.
+ See also the `te.Stage` for more details.
Parameters
----------
return res
def pragma(self, stage, iterator, pragma_type):
- """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.pragma`.
+ See also the `te.Stage` for more details.
Parameters
----------
iterator, pragma_type)
def reorder(self, stage, order):
- """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.reorder`.
+ See also the `te.Stage` for more details.
Parameters
----------
order)
def split(self, stage, iterator, lengths, inner_to_outer=True):
- """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more
- details.
+ """Schedule primitive corresponding to `te.Stage.split`.
+ See also the `te.Stage` for more details.
This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
will be split to 3 parts, use `inner_to_outer` to control the split order)
return res
def follow_split(self, stage, iterator, src_step_id, n_split):
- """ Schedule primitive extends to split step.
+ """The schedule primitive similar to split, but uses split factors from previous steps.
This step splits the iterator by the same factors as the given SplitStep.
iterator : Iterator
The iterator to split.
src_step_id : int
- The index of the split step to follow in the history.
+ The index of the split step to be followed in the history.
n_split : int
The number of split level.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
- The indices of the split steps to follow in the history.
+ The indices of the split steps to be followed in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
return res
def storage_align(self, stage, iterator, factor, offset):
- """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for
- more details.
+ """Schedule primitive corresponding to `te.Stage.storage_align`.
+ See also the `te.Stage` for more details.
Parameters
----------
factor, offset)
def compute_at(self, stage, target_stage, target_iter):
- """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
- more details.
+ """Schedule primitive corresponding to `te.Stage.compute_at`.
+ See also the `te.Stage` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be computed at, which can be specified by the integer index, Operation,
- or output tensor of the stage.
+ The source Stage of computed at, which can be specified by the integer index,
+ Operation, or output tensor of the stage.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_iter)
def compute_inline(self, stage):
- """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage`
+ """Schedule primitive corresponding to `te.Stage.compute_inline`, see also the `te.Stage`
for more details.
Parameters
self._resolve_stage_id(stage))
def compute_root(self, stage):
- """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for
- more details.
+ """Schedule primitive corresponding to `te.Stage.compute_root`.
+ Ssee also the `te.Stage` for more details.
Parameters
----------
self._resolve_stage_id(stage))
def cache_read(self, stage, scope_name, reader_stages):
- """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule`
- for more details.
+ """Schedule primitive corresponding to `te.Schedule.cache_read`.
+ See also the `te.Schedule` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be cache read, which can be specified by the integer index, Operation,
+ The Stage to be cache_read, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added read stage.
return self.stages[int(new_stage_id)].op
def cache_write(self, stage, scope_name):
- """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule`
- for more details.
+ """Schedule primitive corresponding to `te.Schedule.cache_write`.
+ See also the `te.Schedule` for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be cache write, which can be specified by the integer index, Operation,
+ The Stage to be cache_write, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added compute stage.
return self.stages[int(new_stage_id)].op
def rfactor(self, stage, iterator, factor_iter_id):
- """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for
- more details.
+ """Schedule primitive corresponding to `te.Schedule.rfactor`.
+ See also the `te.Schedule` for more details.
Parameters
----------
#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
ret_state = operator->()->init_state;
pstate = ret_state.CopyOnWrite();
pstate->transform_steps = state->transform_steps;
- ret_state.ApplySteps(*this);
+ for (const auto& step : pstate->transform_steps) {
+ StepApplyToState(step, &ret_state, *this);
+ }
} else {
ret_state = state;
pstate = ret_state.CopyOnWrite();
return step->ApplyToState(this, dag);
}
-void State::ApplySteps(const ComputeDAG& dag) {
- CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
-
- // Call each step's ApplyToState method
- for (const auto& step : operator->()->transform_steps) {
- StepApplyToState(step, this, dag);
- }
-}
-
// Print stage to ostream
void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
bool delete_trivial_loop) {