[AutoScheduler] Improve doc string (#6176)
authorLianmin Zheng <lianminzheng@gmail.com>
Thu, 30 Jul 2020 16:35:29 +0000 (09:35 -0700)
committerGitHub <noreply@github.com>
Thu, 30 Jul 2020 16:35:29 +0000 (09:35 -0700)
include/tvm/auto_scheduler/auto_schedule.h
include/tvm/auto_scheduler/compute_dag.h
include/tvm/auto_scheduler/loop_state.h
include/tvm/auto_scheduler/transform_step.h
python/tvm/auto_scheduler/compute_dag.py
python/tvm/auto_scheduler/loop_state.py
src/auto_scheduler/compute_dag.cc
src/auto_scheduler/loop_state.cc

index 8477966..8d458f1 100644 (file)
@@ -100,9 +100,9 @@ class TuningOptions : public ObjectRef {
 /*!
  * \brief Run schedule search for a given compute declaration.
  * \param task The search task of the compute declaration.
- * \param search_policy The search policy to be used.
+ * \param search_policy The search policy.
  * \param tuning_options Tuning and measurement options.
- * \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or
+ * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
  * `tvm.build`.
  */
 TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchTask task,
index 69b74bf..16bc729 100644 (file)
  * \brief The auto-scheduler's computational graph and related program analyses.
  *
  * We convert a compute declaration described by `tvm.compute` (could be a single operator or a
- * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
- * a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
- * total float operation count, consumer/producer relations of each operation stage, whether an
- * operation stage should be tiled/compute inlined ...). These analyses can help the search policy
- * to make decisions during search process.
- * ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
- * TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+ * subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
+ * some static analysis results for the DAG (e.g. the total float operation count, consumer/producer
+ * relations of operations, whether an operation stage should be tiled/compute inlined ...).
+ * These analyses can help the search policy to make decisions during the search.
+ * ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
+ * TVM schedule (e.g. applying the `LoopState` transform steps to a TVM schedule, providing
  * `LoopState` with extra information got from TVM schedule ...).
  */
 
 namespace tvm {
 namespace auto_scheduler {
 
-/*! \brief Static analysis result for a ComputeDAG */
+/*! \brief Static analyzer for a ComputeDAG */
 class AccessAnalyzerNode : public Object {
  public:
   template <class T>
   using OperationMap = std::unordered_map<te::Operation, T, ObjectPtrHash, ObjectPtrEqual>;
 
   /*! \brief Map an operation to all operations it reads from.
-   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+   * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
    * The inner vector represents the indices of multi-dimensional access.*/
   OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_from;
   /*! \brief Map an operation to all operations it is read by.
-   * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses
+   * For each operation pair, use a two-dimentional array for multiple multi-dimentional accesses
    * The inner vector represents the indices of multi-dimensional access.*/
   OperationMap<OperationMap<std::vector<std::vector<PrimExpr>>>> read_by;
   /*! \brief Store the number of common outer iterators for operation pairs that have
@@ -92,7 +91,7 @@ class AccessAnalyzer : public ObjectRef {
   explicit AccessAnalyzer(const Array<te::Tensor>& tensors);
 
   /*!
-   * \brief Return whether this operation is an injective operation
+   * \brief Return whether this operation is an op with simple access
    * (e.g., injective, broadcast and elementwise ops without reduction)
    * \param op The operation
    */
@@ -113,13 +112,13 @@ class AccessAnalyzer : public ObjectRef {
   TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const;
 
   /*!
-   * \brief Return whether this operation is an output op
+   * \brief Return whether this operation is an output operation
    * \param op The operation
    */
   TVM_DLL bool IsOutput(const te::Operation& op) const;
 
   /*!
-   * \brief Get all consumers of on operation
+   * \brief Get all consumers of an operation
    * \param state The current loop state
    * \param op The operation
    * \return The set of consumers
@@ -129,7 +128,7 @@ class AccessAnalyzer : public ObjectRef {
       const State& state, const te::Operation& op) const;
 
   /*!
-   * \brief Get all producers of on operation
+   * \brief Get all producers of an operation
    * \param state The current loop state
    * \param op The operation
    * \return The set of producers
@@ -139,7 +138,7 @@ class AccessAnalyzer : public ObjectRef {
       const State& state, const te::Operation& op) const;
 
   /*!
-   * \brief Get all direct producers of on operation
+   * \brief Get all direct producers of an operation
    * \param op The operation
    * \return The set of direct producers
    * \note This function DOES NOT propagate the relation for inlined ops
@@ -158,7 +157,7 @@ class AccessAnalyzer : public ObjectRef {
 
   /*!
    * \brief Return whether two operations are elementwise-matched
-   *  (e.g. conv2d and relu are elementwise matched)
+   *  (e.g. conv2d and relu are elementwise-matched)
    * \note This function propagates the relation for chains with multiple ops.
    */
   TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const;
@@ -166,7 +165,7 @@ class AccessAnalyzer : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode);
 };
 
-/*! \brief The TVM Auto-scheduler computational graph and related program analyses. */
+/*! \brief The auto-scheduler's computational graph and related program analyses. */
 class ComputeDAGNode : public Object {
  public:
   /*!
@@ -174,9 +173,9 @@ class ComputeDAGNode : public Object {
    * This is used as the input of `tvm.lower` or `tvm.build`.
    */
   Array<te::Tensor> tensors;
-  /*! \brief All related operations in topo order. */
+  /*! \brief All used operations in topo order. */
   Array<te::Operation> ops;
-  /*! \brief The number of total float operations for this ComputeDAG. */
+  /*! \brief The number of float operations in this ComputeDAG. */
   double flop_ct;
   /*! \brief The initial state without any transform steps. */
   State init_state;
index 34e7e56..ba58f37 100644 (file)
@@ -19,7 +19,7 @@
 
 /*!
  * \file auto_scheduler/loop_state.h
- * \brief The definition of the "state" in search.
+ * \brief The definition of the "state" in the search.
  *
  * Each LoopState corresponds to a schedule for its ComputeDAG.
  * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
@@ -30,7 +30,7 @@
  * During the schedule search process, the loop structure can provide search policy with necessary
  * information on how to manipulate the current state.
  * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
- * schedule primitives. The steps can also be used for the serialization of a state.
+ * schedule primitives. The steps are also used for the serialization of a state.
  *
  * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
  * We don't use the existing TVM IR but to extend a new structure on it is because:
@@ -40,7 +40,7 @@
  * 3. We may create some macro schedule primitives that represent the combination of several
  * TVM schedule primitives.
  *
- * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+ * When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
  * Since we share a lot of common objects during search, the transformation is implemented in
  * copy on write style. All objects are immutable, which is similar to TVM IR.
  */
@@ -131,7 +131,7 @@ class Stage : public ObjectRef {
   explicit Stage(te::Operation op);
   /*!
    * \brief The constructor.
-   * \param op A `te::Operation`.
+   * \param op The source operation
    * \param op_type The stage type of this op.
    * \param iters The iterators of this op.
    * \param compute_at The compute at type of this op.
@@ -167,7 +167,7 @@ class AttachMapNode : public Object {
 
   /*! \brief A Map to store the mapping of stage to its attached iterator. */
   std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
-  /*! \brief A Map to store the mapping of iterator to the stage attached to it. */
+  /*! \brief A Map to store the mapping of iterator to the stages attached to it. */
   std::unordered_map<IterKey, std::vector<StageKey>, IterKeyHash> iter_to_attached_stages;
 
   static constexpr const char* _type_key = "auto_scheduler.AttachMap";
@@ -182,15 +182,15 @@ class AttachMap : public ObjectRef {
  public:
   /*!
    * \brief Process the stage/iterator mapping after compute at.
-   * \param stage_id The index of the stage to be computed at.
+   * \param stage_id The index of the source stage of computed at.
    * \param target_stage_id The index of stage that this step will compute at to.
-   * \param target_iter_id The index of iterator in target stage that this step will compute at to.
+   * \param target_iter_id The index of target iterator in the target stage.
    */
   void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
 
   /*!
-   * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
-   * \param stage_id The index of the stage to be computed at.
+   * \brief Delete the entry of a specific stage. This is a public wrapper of `DeleteStageEntry`.
+   * \param stage_id The index of the stage to be deleted.
    */
   void DeleteStage(int stage_id);
 
@@ -198,7 +198,7 @@ class AttachMap : public ObjectRef {
    * \brief Find the relations of original iterators in AttachMap, and update them with the new
    * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
    * \param original_iters The original IterKey.
-   * \param new_iters The new IterKey to update.
+   * \param new_iters The new IterKey for replacing the old ones.
    */
   void UpdateIters(const std::vector<IterKey>& original_iters,
                    const std::vector<IterKey>& new_iters);
@@ -206,9 +206,9 @@ class AttachMap : public ObjectRef {
   /*!
    * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
    * to stage indexes that are larger than the start_id. Used for steps that insert new stages to
-   * ComputeDAG(e.g. CacheRead/CacheWrite step).
-   * \param start_id The index threshold, stage indexes in AttachMap which are larger than this
-   * will be applied the extra offset.
+   * ComputeDAG (e.g., CacheRead/CacheWrite step).
+   * \param start_id The index threshold. This function only adds offset for stages
+   * with indices larger then this threshold.
    * \param offset The index offset to be added to the stage index.
    * \return The updated AttachMap after applying stage index offset.
    */
@@ -219,7 +219,7 @@ class AttachMap : public ObjectRef {
 
  private:
   /*!
-   * \brief To delete the entry of a specific stage. This will remove the items related to this
+   * \brief Delete the entry of a specific stage. This will remove the items related to this
    * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map.
    * \param pnode A mutable pointer to AttachMapNode.
    * \param stage_id The index of stage that will be removed from the map.
@@ -244,10 +244,10 @@ class StateNode : public Object {
    * operation.
    */
   AttachMap attach_map;
-  /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means
-   * no modification to the original ComputeDAG.
-   * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
-   * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
+  /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt,
+   * meaning the dag of this state is the same as the original ComputeDAG in the SearchTask.
+   * Otherwise, the stored value is the up-to-date ComputeDAG for this state, meaning some steps
+   * (e.g., CacheReadStep/CacheWriteStep) have modified the ComputeDAG.
    */
   Optional<ObjectRef> current_compute_dag;
   /*!
@@ -279,60 +279,47 @@ class State : public ObjectRef {
   explicit State(const Array<te::Operation>& ops);
 
   /*!
-   * \brief Print the state to a human readable string.
+   * \brief Pretty-print the state to a human readable string.
    * \param delete_trivial_loop True for skipping the trivial loops.
    * (undefined or extent == 1, default set to True)
-   * \return The human readable state structure.
+   * \return The human readable string.
    */
   String ToStr(bool delete_trivial_loop = true) const;
 
+  /********** Step APIs working on a single stage **********/
   /*!
-   * \brief General call step functions with a runtime dynamic dispatcher. This will re-apply all
-   * the transform steps from the initial state.
-   * \param dag The original ComputeDAG of this state.
-   * \note The input `dag` is different from the class member `current_compute_dag`.
-   * This function takes the initial ComputeDAG as input to replay all the history. While the
-   * `current_compute_dag` is used to track the current stage status, for some transform step may
-   * change the op stage structure.
-   */
-  void ApplySteps(const ComputeDAG& dag);
-
-  /********** Step APIs working on single stage **********/
-
-  /*!
-   * \brief Schedule primitive corresponds to `te::Stage::bind`.
+   * \brief The schedule primitive corresponding to `te::Stage::bind`.
    * \param stage_id The index of the stage to be binded.
    * \param it The iterator to be binded.
-   * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
-   * this input.
-   * \return The iterator result after binded.
+   * \param thread_type The thread type.
+   * \return The new iterator after binding.
    */
   TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::parallel`.
+   * \brief The schedule primitive corresponding to `te::Stage::parallel`.
    * \param stage_id The index of the stage to be paralleled.
    * \param it The iterator to be paralleled.
-   * \return The iterator result after parallel.
+   * \return The new iterator after parallel.
    */
   TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::unroll`.
+   * \brief The schedule primitive corresponding to `te::Stage::unroll`.
    * \param stage_id The index of the stage to be unrolled.
    * \param it The iterator to be unrolled.
    * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
    * skipped.
-   * \return The iterator result after unrolled.
+   * \return The new iterator after unroll.
    */
   TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::vectorize`.
+   * \brief The schedule primitive corresponding to `te::Stage::vectorize`.
    * \param stage_id The index of the stage to be vectorized.
    * \param it The iterator to be vectorized.
-   * \return The iterator result after vectorize.
+   * \return The new iterator after vectorization.
    */
   TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::fuse`.
+   * \brief The schedule primitive corresponding to `te::Stage::fuse`.
    * \param stage_id The index of the stage to be fused.
    * \param iters The iterators to be fused.
    * \return The iterator result after fuse.
@@ -341,25 +328,25 @@ class State : public ObjectRef {
    */
   TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
   /*!
-   * \brief Schedule primitive corresponds to `te.Stage.pragma`.
+   * \brief The schedule primitive corresponding to `te.Stage.pragma`.
    * \param stage_id The index of the stage to add pragma.
    * \param it The iterator to add pragma.
    * \param pragma_type The pragma string.
    */
   TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::reorder`.
+   * \brief The schedule primitive corresponding to `te::Stage::reorder`.
    * \param stage_id The index of the stage to be reordered.
    * \param order The expected iterator order.
    */
   TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::split`.
+   * \brief The schedule primitive corresponding to `te::Stage::split`.
    * \param stage_id The index of the stage to be split.
    * \param it The iterator to be split.
    * \param lengths The multiple split factors. Can be None to be filled by search policy.
-   * \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner.
-   * \return The iterator results after split.
+   * \param inner_to_outer Whether the factors go from inner to outer, or from outer to inner.
+   * \return The new iterator after splitting.
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
    */
@@ -367,30 +354,31 @@ class State : public ObjectRef {
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
   /*!
-   * \brief Schedule primitive extends to split step.
+   * \brief The schedule primitive similar to split, but uses split factors from previous steps.
    * \param stage_id The index of the stage to be split.
    * \param it The iterator to be split.
    * \param src_step_id The index of the split step to be followed in the history.
    * \param n_split The number of split level.
-   * \return The splitted new Iterators.
+   * \return The split new Iterators.
    */
   TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
                                        int n_split);
   /*!
-   * \brief Schedule primitive extends to split step.
+   * \brief The schedule primitive similar to split, but uses split factors from
+   * fused previous steps.
    * \param stage_id The index of the stage to be split.
    * \param it The iterator to be split.
    * \param src_step_ids The indices of the split steps to be followed in the history.
    * \param level Use the length in this split level.
    * \param factor_or_nparts True to use `factor` for split from inner to outer,
       False to use `nparts` for split from outer to inner.
-   * \return The splitted new Iterators.
+   * \return The split new Iterators.
    */
   TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
                                              const Array<Integer>& src_step_ids, int level,
                                              bool factor_or_nparts);
   /*!
-   * \brief Schedule primitive corresponds to `te.Stage.storage_align`.
+   * \brief The schedule primitive corresponding to `te.Stage.storage_align`.
    * \param stage_id The index of the stage to be aligned.
    * \param it The iterator to be aligned.
    * \param factor The factor in alignment specification.
@@ -399,64 +387,62 @@ class State : public ObjectRef {
   TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
 
   /********** Step APIs working on multiple stages **********/
-
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::compute_at`.
-   * \param stage_id The index of the stage to be computed at.
+   * \brief The schedule primitive corresponding to `te::Stage::compute_at`.
+   * \param stage_id The index of the source stage of computed at.
    * \param target_stage_id The index of stage that this step will compute at to.
-   * \param target_iter The iterator in target stage that this step will compute at to.
+   * \param target_iter The indiex of the target iterator in the target stage.
    * \note After compute_at, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
-   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
    */
   TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
+   * \brief The schedule primitive corresponding to `te::Stage::compute_inline`.
    * \param stage_id The index of the stage to be marked compute inlined.
    */
   TVM_DLL void compute_inline(int stage_id);
   /*!
-   * \brief Schedule primitive corresponds to `te::Stage::compute_root`.
+   * \brief The schedule primitive corresponding to `te::Stage::compute_root`.
    * \param stage_id The index of the stage to be marked compute at root.
    * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
-   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
    */
   TVM_DLL void compute_root(int stage_id);
 
   /********** Step APIs adding new stages **********/
-
   /*!
-   * \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
-   * \param stage_id The index of the stage to be cache read.
-   * \param scope_name The scope name of the newly added read stage.
-   * \param reader_stage_ids The indices of read stages.
+   * \brief The schedule primitive corresponding to `te::Schedule::cache_read`.
+   * \param stage_id The index of the stage to be cache_read.
+   * \param scope_name The scope name of the newly added stage.
+   * \param reader_stage_ids The indices of reader stages.
    * \param dag The original ComputeDAG of this state.
    * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
-   * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    */
   TVM_DLL int cache_read(int stage_id, const String& scope_name,
                          const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
   /*!
-   * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
-   * \param stage_id The index of the stage to be cache write.
-   * \param scope_name The scope name of the newly added compute stage.
+   * \brief The schedule primitive corresponding to `te::Schedule::cache_write`.
+   * \param stage_id The index of the stage to be cache_write.
+   * \param scope_name The scope name of the newly added stage.
    * \param dag The original ComputeDAG of this state.
    * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
-   * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    * This step will cache write all output tensors of the target stage.
    */
   TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
   /*!
-   * \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
+   * \brief The schedule primitive corresponding to `te::Schedule::rfactor`.
    * \param stage_id The index of the iterator to be factored.
    * \param it The iterator to be factored.
    * \param factor_iter_id The position where the new iterator is placed.
    * \param dag The original ComputeDAG of this state.
    * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the
-   * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    */
   TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
 
index a31765a..d4ef032 100644 (file)
@@ -19,7 +19,7 @@
 
 /*!
  * \file auto_scheduler/transform_step.h
- * \brief Transformation steps. These steps are used to manipulate the LoopState.
+ * \brief Transformation steps. These steps are used to manipulate `LoopState`.
  *        They are similar to the schedule primitives in te::Stage.
  *
  * \note How to add a new transform step:
  * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
  *    - In these two functions you need to incrementally update all data structures in State with
  *      CopyOnWrite style.
- * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and
- *    `StepPrintAsPythonAPI`, make sure it works.
+ * 4. Add your step to `StepApplyToState`, `StepApplyToSchedule`, and `StepPrintAsPythonAPI`.
  * 5. Log record serialization support:
  *    - Add `FuseStepNode::WriteToRecord` which takes a mutable JSONWriter pointer as input and
  *      output the record to it.
  *    - Add another construction function that takes a mutable JSONReader as input, this will get a
  *      step record from the reader and create the step.
  *    - Add the step implementation to `StepReadFromRecord`.
- * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test, the test should
- *    at lease consists of two parts: the functional test and the record serialization test.
+ * 6. Add its corresponding Python API to `loop_state.py` with necessary unit tests. The test should
+ *    at lease cover two parts: the functional test and the record serialization test.
  */
 
 #ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
@@ -58,8 +57,8 @@ typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageT
 
 /*!
  * \brief Update the current stage IterVar information to StageToAxesMap.
- * \param stage A te::Stage Object.
- * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated.
+ * \param stage The stage to be updated.
+ * \param stage_to_axes The map to be updated.
  */
 void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
 
@@ -106,7 +105,7 @@ enum class IteratorAnnotation : int {
 extern const char* IteratorAnnotationString[];
 
 /*!
- * \brief A for loop iterator
+ * \brief An iterator of a for-loop
  * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
  */
 class IteratorNode : public Object {
@@ -188,7 +187,7 @@ class ComputeDAG;
 Step StepReadFromRecord(dmlc::JSONReader* reader);
 
 /*!
- * \brief Apply the step to State.
+ * \brief Apply a general step to a State with runtime dynamic dispatching.
  * \param step The step to be applied to State.
  * \param state A mutable pointer to state, which will be updated.
  * \param dag The original ComputeDAG of this state.
@@ -196,25 +195,23 @@ Step StepReadFromRecord(dmlc::JSONReader* reader);
 void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
 
 /*!
- * \brief Apply the step to tvm.schedule.
+ * \brief Apply a general step to tvm.schedule with runtime dynamic dispatching.
  * \param step The step to be applied to tvm.schedule.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
- * `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
- * \param transform_steps An array record all transform steps.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param schedule A mutable point to the current schedule
+ * \param transform_steps An array of all history transform steps.
  */
 void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
                          te::Schedule* schedule, const Array<Step>& transform_steps);
 
 /*!
- * \brief Print the step as equivalent python schedule API.
- * \param step The step to be applied to python API.
- * \param stages The `te::Stage`s used in TVM scheduler applying.
- * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
- * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
- * CacheRead/CacheWrite step)
- * \param transform_steps An array record all transform steps.
+ * \brief Print a general step as equivalent python schedule API with runtime dynamic dispatching.
+ * \param step The step to be printed as python API.
+ * \param stages The list of current stages
+ * \param stage_to_axes A map that maps stage ot all its iterators.
+ * \param schedule A mutable point to the current schedule
+ * \param transform_steps An array of all history transform steps.
  * \return Python schedule code.
  */
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
@@ -245,15 +242,15 @@ class AnnotationStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -307,16 +304,16 @@ class FuseStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return The iterator result after fuse.
    */
   tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -368,15 +365,15 @@ class PragmaStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -430,15 +427,15 @@ class ReorderStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -503,8 +500,8 @@ class SplitStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return The iterator results after split.
    */
   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
@@ -512,8 +509,8 @@ class SplitStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -557,7 +554,7 @@ class FollowSplitStepNode : public StepNode {
  public:
   /*! \brief The id of the iter to be split. */
   int iter_id;
-  /*! \brief The index of the split step to follow in the history. */
+  /*! \brief The index of the split step to be followed in the history. */
   int src_step_id;
   /*! \brief The number of split level. */
   int n_split;
@@ -566,7 +563,7 @@ class FollowSplitStepNode : public StepNode {
 
   /*!
    * \brief Extract split lengths.
-   * \param transform_steps An array record all transform steps.
+   * \param transform_steps An array of history transform steps.
    * \return The multiple split factors.
    */
   Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
@@ -580,9 +577,9 @@ class FollowSplitStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
-   * \param transform_steps An array record all transform steps.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
+   * \param transform_steps An array of history transform steps.
    * \return The iterator results after split.
    */
   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
@@ -590,9 +587,9 @@ class FollowSplitStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
-   * \param transform_steps An array record all transform steps.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
+   * \param transform_steps An array of history transform steps.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
@@ -614,7 +611,7 @@ class FollowSplitStep : public Step {
    * \brief The constructor.
    * \param stage_id The index of the stage to be split.
    * \param iter_id The index of the iterator to be split.
-   * \param src_step_id The index of the split step to follow in the history.
+   * \param src_step_id The index of the split step to be followed in the history.
    * \param n_split The number of split level.
    */
   FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
@@ -636,7 +633,7 @@ class FollowFusedSplitStepNode : public StepNode {
  public:
   /*! \brief The id of the iter to split. */
   int iter_id;
-  /*! \brief The indices of the split steps to follow in the history. */
+  /*! \brief The indices of the split steps to be followed in the history. */
   Array<Integer> src_step_ids;
   /*! \brief  Use the length in this split level. */
   int level;
@@ -647,7 +644,7 @@ class FollowFusedSplitStepNode : public StepNode {
 
   /*!
    * \brief Extract split length.
-   * \param transform_steps An array record all transform steps.
+   * \param transform_steps An array of history transform steps.
    * \return Split factor.
    */
   Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
@@ -661,9 +658,9 @@ class FollowFusedSplitStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
-   * \param transform_steps An array record all transform steps.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
+   * \param transform_steps An array of history transform steps.
    * \return The iterator results after split.
    */
   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
@@ -671,9 +668,9 @@ class FollowFusedSplitStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
-   * \param transform_steps An array record all transform steps.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
+   * \param transform_steps An array of history transform steps.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
@@ -695,7 +692,7 @@ class FollowFusedSplitStep : public Step {
    * \brief The constructor.
    * \param stage_id The index of the stage to be split.
    * \param iter_id The index of the iterator to be split.
-   * \param src_step_ids An array of index for split step to follow in the history.
+   * \param src_step_ids An array of index for split step to be followed in the history.
    * \param level Use the length in this split level.
    * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
    */
@@ -732,15 +729,15 @@ class StorageAlignStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -794,21 +791,21 @@ class ComputeAtStepNode : public StepNode {
    * \note After compute_at, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
-   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
    */
   void ApplyToState(State* state) const;
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -827,7 +824,7 @@ class ComputeAtStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be computed at.
+   * \param stage_id The index of the source stage.
    * \param target_stage_id The index of stage that this step will compute at to.
    * \param target_iter_id The index of iterator in target stage that this step will compute at to.
    */
@@ -856,16 +853,16 @@ class ComputeInlineStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return The iterator result after fuse.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -909,22 +906,22 @@ class ComputeRootStepNode : public StepNode {
    * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
-   * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+   * Call ComputeDAG::InferBound on the updated state if you need the complete bound information.
    */
   void ApplyToState(State* state) const;
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return The iterator result after fuse.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -961,12 +958,12 @@ class ComputeRootStep : public Step {
 
 /*!
  * \brief Cache read step that corresponds to te::Schedule::cache_read.
- * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG
- * is stored in State's `current_compute_dag`.
+ * \note Cache read step adds an extra stage to the original ComputeDAG,
+ * an up-to-date ComputeDAG will be stored in State's `current_compute_dag`.
  */
 class CacheReadStepNode : public StepNode {
  public:
-  /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */
+  /*! \brief The scope name of the newly added read stage. (e.g., local, shared, global) */
   String scope_name;
   /*! \brief The indices of read stages. */
   Array<Integer> reader_stage_ids;
@@ -983,8 +980,8 @@ class CacheReadStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return The output Tensor of the new added stage.
    */
@@ -993,8 +990,8 @@ class CacheReadStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return Python schedule code.
    */
@@ -1015,9 +1012,9 @@ class CacheReadStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be cache read.
-   * \param scope_name The scope name of the newly added read stage.
-   * \param reader_stage_ids The indices of read stages.
+   * \param stage_id The index of the stage to be cache_read.
+   * \param scope_name The scope name of the newly added stage.
+   * \param reader_stage_ids The indices of reader stages.
    */
   CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);
 
@@ -1054,8 +1051,8 @@ class CacheWriteStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return The output Tensors of the new added stage.
    */
@@ -1064,8 +1061,8 @@ class CacheWriteStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return Python schedule code.
    */
@@ -1086,8 +1083,8 @@ class CacheWriteStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be cache write.
-   * \param scope_name The scope name of the newly added compute stage.
+   * \param stage_id The index of the stage to be cache_write.
+   * \param scope_name The scope name of the newly added stage.
    */
   CacheWriteStep(int stage_id, String scope_name);
 
@@ -1121,8 +1118,8 @@ class RfactorStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return The output Tensors of the new added stage.
    */
@@ -1131,8 +1128,8 @@ class RfactorStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages The `te::Stage`s used in TVM scheduler applying.
-   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param stages The list of current stages
+   * \param stage_to_axes A map that maps stage ot all its iterators.
    * \param schedule A mutable pointer to a te::Schedule.
    * \return Python schedule code.
    */
index e08454f..f56f430 100644 (file)
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-""" The TVM Auto-scheduler computational graph and related program analyses. """
+""" The auto-scheduler's computational graph and related program analyses. """
 
 import hashlib
 
@@ -33,16 +33,16 @@ from . import _ffi_api
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
     """
-    The TVM Auto-scheduler computational graph and related program analyses.
+    The auto-scheduler's computational graph and related program analyses.
 
     We convert a compute declaration described by `tvm.compute` (could be a single operator or a
-    subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration,
-    a list of all operations in the DAG as well as static analysis results for the DAG (e.g. the
-    total float operation count, consumer/producer relations of each operation stage, whether an
-    operation stage should be tiled/compute inlined ...). These analyses can help the search policy
-    to make decisions during search process.
-    ComputeDAG is also responsible for the interaction between TVM Auto-scheduler `LoopState` and
-    TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
+    subgraph) to a ComputeDAG. It keeps the input/output tensors, all operations in the DAG, and
+    some static analysis results for the DAG (e.g. the total float operation count,
+    consumer/producer relations of operations, whether an operation stage should
+    be tiled/compute inlined ...).
+    These analyses can help the search policy to make decisions during the search.
+    ComputeDAG is also responsible for the interaction between auto-scheduler's `LoopState` and
+    TVM schedule (e.g. applying the `LoopState` transform steps to TVM schedule, providing
     `LoopState` with extra information got from TVM schedule ...).
 
     Parameters
@@ -90,7 +90,7 @@ class ComputeDAG(Object):
 
     def print_python_code_from_state(self, state):
         """
-        Print transform steps in the history of a State as TVM's python schedule primitive.
+        Print transform steps in the history of a State as TVM's python schedule code.
 
         This is used to print transformation steps for debugging.
         Use `apply_steps_from_state` if you want to get a schedule for code generation.
index 35ecacc..da3a4bf 100644 (file)
 # pylint: disable=unused-import
 
 """
-The definition of the "state" in search.
+The definition of the "state" in the search.
 
 Each LoopState corresponds to a schedule for its ComputeDAG.
 A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to
 construct the loop structure.
 The loop structure keeps a preview of how the schedule will finally look like after lowering the
-current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...).
+current state (e.g. number of iterators, the extent of each iterator, the compute_at locations
+...).
 During the schedule search process, the loop structure can provide search policy with necessary
 information on how to manipulate the current state.
-The transform history is a sequence of `TransformStep` which will finally be mapped to TVM schedule
-primitives. The steps can also be used for the serialization of a state.
+The transform history is a sequence of `TransformStep` which will finally be mapped to TVM
+schedule primitives. The steps are also used for the serialization of a state.
 
 The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
 We don't use the existing TVM IR but to extend a new structure on it is because:
@@ -37,7 +38,7 @@ immediate loop structures update rather than after TVM lowering;
 3. We may create some macro schedule primitives that represent the combination of several
 TVM schedule primitives.
 
-When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+When the search is finished, we will lower the state to TVM IR with TVM's schedule primitives.
 Since we share a lot of common objects during search, the transformation is implemented in
 copy on write style. All objects are immutable, which is similar to TVM IR.
 """
@@ -136,8 +137,8 @@ class State:
         return [stage.op for stage in self.stages]
 
     def bind(self, stage, iterator, thread_name):
-        """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.bind`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -170,8 +171,8 @@ class State:
         return res
 
     def parallel(self, stage, iterator):
-        """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.parallel`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -191,8 +192,8 @@ class State:
         return res
 
     def unroll(self, stage, iterator, max_unroll=None):
-        """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.unroll`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -215,8 +216,8 @@ class State:
         return res
 
     def vectorize(self, stage, iterator):
-        """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for
-        more details.
+        """Schedule primitive corresponding to `te.Stage.vectorize`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -236,8 +237,8 @@ class State:
         return res
 
     def fuse(self, stage, iters):
-        """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.fuse`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -262,8 +263,8 @@ class State:
         return res
 
     def pragma(self, stage, iterator, pragma_type):
-        """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.pragma`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -279,8 +280,8 @@ class State:
                                                  iterator, pragma_type)
 
     def reorder(self, stage, order):
-        """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.reorder`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -294,8 +295,8 @@ class State:
                                                   order)
 
     def split(self, stage, iterator, lengths, inner_to_outer=True):
-        """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more
-        details.
+        """Schedule primitive corresponding to `te.Stage.split`.
+        See also the `te.Stage` for more details.
 
         This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
         will be split to 3 parts, use `inner_to_outer` to control the split order)
@@ -328,7 +329,7 @@ class State:
         return res
 
     def follow_split(self, stage, iterator, src_step_id, n_split):
-        """ Schedule primitive extends to split step.
+        """The schedule primitive similar to split, but uses split factors from previous steps.
 
         This step splits the iterator by the same factors as the given SplitStep.
 
@@ -348,7 +349,7 @@ class State:
         iterator : Iterator
             The iterator to split.
         src_step_id : int
-            The index of the split step to follow in the history.
+            The index of the split step to be followed in the history.
         n_split : int
             The number of split level.
 
@@ -394,7 +395,7 @@ class State:
         iterator : Iterator
             The iterator to split.
         src_step_ids : List[int]
-            The indices of the split steps to follow in the history.
+            The indices of the split steps to be followed in the history.
         level : int
             Use the length in this split level.
         factor_or_nparts : bool
@@ -415,8 +416,8 @@ class State:
         return res
 
     def storage_align(self, stage, iterator, factor, offset):
-        """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for
-        more details.
+        """Schedule primitive corresponding to `te.Stage.storage_align`.
+        See also the `te.Stage` for  more details.
 
         Parameters
         ----------
@@ -435,14 +436,14 @@ class State:
                                                        factor, offset)
 
     def compute_at(self, stage, target_stage, target_iter):
-        """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
-        more details.
+        """Schedule primitive corresponding to `te.Stage.compute_at`.
+        See also the `te.Stage` for more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be computed at, which can be specified by the integer index, Operation,
-            or output tensor of the stage.
+            The source Stage of computed at, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
         target_stage : Union[int, Operation, Tensor]
             The target stage of compute_at, which can be specified by the integer index, Operation,
             or output tensor of the stage.
@@ -462,7 +463,7 @@ class State:
                                                     target_iter)
 
     def compute_inline(self, stage):
-        """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage`
+        """Schedule primitive corresponding to `te.Stage.compute_inline`, see also the `te.Stage`
         for more details.
 
         Parameters
@@ -475,8 +476,8 @@ class State:
                                                         self._resolve_stage_id(stage))
 
     def compute_root(self, stage):
-        """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for
-        more details.
+        """Schedule primitive corresponding to `te.Stage.compute_root`.
+        Ssee also the `te.Stage` for more details.
 
         Parameters
         ----------
@@ -495,13 +496,13 @@ class State:
                                                       self._resolve_stage_id(stage))
 
     def cache_read(self, stage, scope_name, reader_stages):
-        """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule`
-        for more details.
+        """Schedule primitive corresponding to `te.Schedule.cache_read`.
+        See also the `te.Schedule` for more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be cache read, which can be specified by the integer index, Operation,
+            The Stage to be cache_read, which can be specified by the integer index, Operation,
             or output tensor of the stage.
         scope_name : str
             The scope name of the newly added read stage.
@@ -531,13 +532,13 @@ class State:
         return self.stages[int(new_stage_id)].op
 
     def cache_write(self, stage, scope_name):
-        """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule`
-        for more details.
+        """Schedule primitive corresponding to `te.Schedule.cache_write`.
+        See also the `te.Schedule` for more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be cache write, which can be specified by the integer index, Operation,
+            The Stage to be cache_write, which can be specified by the integer index, Operation,
             or output tensor of the stage.
         scope_name : str
             The scope name of the newly added compute stage.
@@ -563,8 +564,8 @@ class State:
         return self.stages[int(new_stage_id)].op
 
     def rfactor(self, stage, iterator, factor_iter_id):
-        """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for
-        more details.
+        """Schedule primitive corresponding to `te.Schedule.rfactor`.
+        See also the `te.Schedule` for more details.
 
         Parameters
         ----------
index f2815fb..b11dd73 100644 (file)
@@ -24,6 +24,7 @@
 
 #include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/auto_scheduler/loop_state.h>
+#include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
@@ -740,7 +741,9 @@ State ComputeDAG::InferBound(const State& state) const {
     ret_state = operator->()->init_state;
     pstate = ret_state.CopyOnWrite();
     pstate->transform_steps = state->transform_steps;
-    ret_state.ApplySteps(*this);
+    for (const auto& step : pstate->transform_steps) {
+      StepApplyToState(step, &ret_state, *this);
+    }
   } else {
     ret_state = state;
     pstate = ret_state.CopyOnWrite();
index f9d1f82..9e1a54f 100644 (file)
@@ -341,15 +341,6 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const C
   return step->ApplyToState(this, dag);
 }
 
-void State::ApplySteps(const ComputeDAG& dag) {
-  CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
-
-  // Call each step's ApplyToState method
-  for (const auto& step : operator->()->transform_steps) {
-    StepApplyToState(step, this, dag);
-  }
-}
-
 // Print stage to ostream
 void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
                 bool delete_trivial_loop) {