[Ansor][AutoTVM v2.0] Phase 1: Add pragma/storage_align/rfactor steps (#6141)
author Chenfan <chengfan.jcf@alibaba-inc.com>
Wed, 29 Jul 2020 06:39:21 +0000 (14:39 +0800)
committer GitHub <noreply@github.com>
Wed, 29 Jul 2020 06:39:21 +0000 (23:39 -0700)
* Add pragma/storage_align/rfactor step

* Update

* Update

* Update UT

* Update

include/tvm/auto_scheduler/loop_state.h
include/tvm/auto_scheduler/transform_step.h
python/tvm/auto_scheduler/loop_state.py
src/auto_scheduler/loop_state.cc
src/auto_scheduler/transform_step.cc
src/auto_scheduler/utils.h
tests/python/unittest/test_auto_scheduler_loop_state.py
tests/python/unittest/test_auto_scheduler_measure.py

diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h
index 9850620..34e7e56 100644
@@ -341,6 +341,13 @@ class State : public ObjectRef {
    */
   TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
   /*!
+   * \brief Schedule primitive corresponds to `te::Stage::pragma`.
+   * \param stage_id The index of the stage to add the pragma to.
+   * \param it The iterator to add the pragma on.
+   * \param pragma_type The pragma string.
+   */
+  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
+  /*!
    * \brief Schedule primitive corresponds to `te::Stage::reorder`.
    * \param stage_id The index of the stage to be reordered.
    * \param order The expected iterator order.
@@ -382,6 +389,14 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
                                              const Array<Integer>& src_step_ids, int level,
                                              bool factor_or_nparts);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Stage::storage_align`.
+   * \param stage_id The index of the stage to be aligned.
+   * \param it The iterator to be aligned.
+   * \param factor The factor in the alignment specification.
+   * \param offset The offset in the alignment specification.
+   */
+  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
 
   /********** Step APIs working on multiple stages **********/
 
@@ -422,8 +437,8 @@ class State : public ObjectRef {
    * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
   * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    */
-  int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
-                 const ComputeDAG& dag);
+  TVM_DLL int cache_read(int stage_id, const String& scope_name,
+                         const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
   /*!
    * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
    * \param stage_id The index of the stage to be cache write.
@@ -433,7 +448,17 @@ class State : public ObjectRef {
   * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    * This step will cache write all output tensors of the target stage.
    */
-  int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
+   * \param stage_id The index of the stage to be factored.
+   * \param it The iterator to be factored.
+   * \param factor_iter_id The position where the new iterator is placed.
+   * \param dag The original ComputeDAG of this state.
+   * \note The rfactor step will add an extra stage to the original ComputeDAG (in front of the
+   * target stage); an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   */
+  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
 
   TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
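
The three new primitives are exercised end to end by the unit tests added below. A minimal
usage sketch of the new State API (the 512x512 matmul workload mirrors the unit tests in
this patch; shapes and split factors are illustrative):

    import tvm
    from tvm import te, auto_scheduler

    A = te.placeholder((512, 512), name='A')
    B = te.placeholder((512, 512), name='B')
    k = te.reduce_axis((0, 512), name='k')
    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    ko, _ = s.split(C, s[C].iters[2], [16])                # rfactor requires a prior split
    C_rf = s.rfactor(C, ko, 2)                             # inserts a new stage in front of C
    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")  # integer value encoded after '$'
    s.storage_align(C, s[C].iters[-1], 8, 4)               # factor=8, offset=4
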
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index f91505c..a31765a 100644
@@ -350,6 +350,67 @@ class FuseStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
 };
 
+/*! \brief Pragma step that corresponds to te::Stage::pragma */
+class PragmaStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to add the pragma on. */
+  int iter_id;
+  /*! \brief The pragma string. */
+  String pragma_type;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "PR";
+
+  static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to PragmaStepNode.
+ * \sa PragmaStepNode
+ */
+class PragmaStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to add the pragma to.
+   * \param iter_id The index of the iterator to add the pragma on.
+   * \param pragma_type The pragma string.
+   */
+  PragmaStep(int stage_id, int iter_id, String pragma_type);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit PragmaStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
+};
+
 /*! \brief Reorder step that corresponds to te::Stage::reorder */
 class ReorderStepNode : public StepNode {
  public:
@@ -506,14 +567,14 @@ class FollowSplitStepNode : public StepNode {
   /*!
    * \brief Extract split lengths.
    * \param transform_steps An array record all transform steps.
-   * \param lengths The multiple split factors. Can be None to be filled by search policy.
+   * \return The multiple split factors.
    */
-  void ExtractSplitLengths(const Array<Step>& transform_steps,
-                           Array<Optional<Integer>>* lengths) const;
+  Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
 
   /*!
    * \brief Apply the current step to State.
    * \param state A mutable pointer to state, which will be updated.
+   * \return The iterator results after split.
    */
   Array<Iterator> ApplyToState(State* state) const;
 
@@ -651,6 +712,70 @@ class FollowFusedSplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
 };
 
+/*! \brief Storage align step that corresponds to te::Stage::storage_align */
+class StorageAlignStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be aligned. */
+  int iter_id;
+  /*! \brief The factor in the alignment specification. */
+  int factor;
+  /*! \brief The offset in the alignment specification. */
+  int offset;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State, which will be updated.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "SA";
+
+  static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to StorageAlignStepNode.
+ * \sa StorageAlignStepNode
+ */
+class StorageAlignStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be aligned.
+   * \param iter_id The index of the iterator to be aligned.
+   * \param factor The factor in the alignment specification.
+   * \param offset The offset in the alignment specification.
+   */
+  StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit StorageAlignStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
+};
+
 /********** Steps working on multiple stages **********/
 
 /*! \brief Compute at step that corresponds to te::Stage::compute_at */
@@ -832,7 +957,7 @@ class ComputeRootStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
 
 /*!
  * \brief Cache read step that corresponds to te::Schedule::cache_read.
@@ -976,6 +1101,74 @@ class CacheWriteStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
 };
 
+/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
+class RfactorStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be factored. */
+  int iter_id;
+  /*! \brief The position where the new iterator is placed. */
+  int factor_iter_id;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State, which will be updated.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the newly added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensors of the newly added stage.
+   */
+  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                    te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "RF";
+
+  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode
+ */
+class RfactorStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be factored.
+   * \param iter_id The index of the iterator to be factored.
+   * \param factor_iter_id The position where the new iterator is placed.
+   */
+  RfactorStep(int stage_id, int iter_id, int factor_iter_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit RfactorStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
+};
+
 }  // namespace auto_scheduler
 }  // namespace tvm
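
The three new steps reuse the existing JSON record format with the prefixes declared above
("PR", "SA", "RF"). Following the WriteToRecord implementations in
src/auto_scheduler/transform_step.cc below, the per-step fragments of a record line look
like this (field values are illustrative):

    ["PR", 2, 0, "auto_unroll_max_step$64"]  # PragmaStep: stage_id, iter_id, pragma_type
    ["SA", 2, 3, 8, 4]                       # StorageAlignStep: stage_id, iter_id, factor, offset
    ["RF", 2, 2, 2]                          # RfactorStep: stage_id, iter_id, factor_iter_id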
 
diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 9ec26c3..35ecacc 100644
@@ -261,6 +261,23 @@ class State:
                                                     self._resolve_stage_id(stage), iters)
         return res
 
+    def pragma(self, stage, iterator, pragma_type):
+        """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more
+        details.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to add the pragma to, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
+        iterator : Iterator
+            The iterator to add the pragma on.
+        pragma_type : str
+            The pragma string.
+        """
+        self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
+                                                 iterator, pragma_type)
+
     def reorder(self, stage, order):
         """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
         details.
@@ -397,6 +414,26 @@ class State:
                                                                 factor_or_nparts)
         return res
 
+    def storage_align(self, stage, iterator, factor, offset):
+        """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for
+        more details.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be storage aligned, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be aligned.
+        factor : int
+            The factor in the alignment specification.
+        offset : int
+            The offset in the alignment specification.
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object,
+                                                       self._resolve_stage_id(stage), iterator,
+                                                       factor, offset)
+
     def compute_at(self, stage, target_stage, target_iter):
         """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
         more details.
@@ -525,6 +562,40 @@ class State:
         self._update_stage_id_map()
         return self.stages[int(new_stage_id)].op
 
+    def rfactor(self, stage, iterator, factor_iter_id):
+        """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for
+        more details.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be factored, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The reduction iterator to be factored.
+        factor_iter_id : int
+            The position where the new iterator is placed.
+
+        Returns
+        -------
+        new_stage_op : Operation
+            The Operation of the newly added stage.
+
+        Notes
+        -----
+        The rfactor step will insert an extra stage into the original ComputeDAG (in front of
+        the target stage).
+        """
+        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
+                                                                self._resolve_stage_id(stage),
+                                                                iterator, factor_iter_id,
+                                                                self.compute_dag)
+        # Adding a new stage changes all ops behind the added stage. To keep the original ops
+        # map valid, apply the stage id offset to stage_id_map.
+        self._apply_stage_id_offset(int(new_stage_id))
+        self._update_stage_id_map()
+        return self.stages[int(new_stage_id)].op
+
     def copy(self):
         """ Do deep copy of this State. """
         state = State(self.state_object, self.compute_dag)
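
Note the convention used by the pragma primitive for auto unrolling: the integer value is
encoded after a '$' separator in the pragma string, and PragmaStepNode::ApplyToSchedule
(below) expands it into two te-level pragmas. A rough te equivalent of
s.pragma(C, it, "auto_unroll_max_step$64"), assuming `sch` is a plain te schedule and
`axis` the chosen iterator:

    sch[C].pragma(axis, "auto_unroll_max_step", 64)
    sch[C].pragma(axis, "unroll_explicit", True)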
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 636066a..f9d1f82 100644
@@ -247,6 +247,13 @@ Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
   return step->ApplyToState(this);
 }
 
+void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
 void State::reorder(int stage_id, const Array<Iterator>& order) {
   const Stage& stage = operator->()->stages[stage_id];
   CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
@@ -287,6 +294,13 @@ Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
+  const Stage& stage = operator->()->stages[stage_id];
+  StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
 void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
   const Stage& target_stage = operator->()->stages[target_stage_id];
   ComputeAtStep step =
@@ -320,6 +334,13 @@ int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG&
   return step->ApplyToState(this, dag);
 }
 
+int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
+  const Stage& stage = operator->()->stages[stage_id];
+  RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this, dag);
+}
+
 void State::ApplySteps(const ComputeDAG& dag) {
   CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
 
@@ -460,6 +481,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
+      state.pragma(stage_id, it, pragma_type);
+      return state;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
     .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
       state.reorder(stage_id, order);
@@ -488,6 +515,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
       return Array<ObjectRef>{state, Array<Iterator>(res)};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
+      state.storage_align(stage_id, it, factor, offset);
+      return state;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
     .set_body_typed([](State state, int stage_id, int target_stage_id,
                        const Iterator& target_iter) {
@@ -521,6 +554,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
       return Array<ObjectRef>{state, Integer(res)};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
+                       const ComputeDAG& dag) {
+      int res = state.rfactor(stage_id, it, factor_iter_id, dag);
+      return Array<ObjectRef>{state, Integer(res)};
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
   return std::equal_to<State>()(state1, state2);
 });
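
The step globals registered above take the State as their first argument and return the
updated State; the stage-adding steps such as StateRfactor additionally return the new
stage id. The Python wrappers in loop_state.py simply thread self.state_object through. A
sketch of that calling pattern (`stage_id`, `iterator`, and `dag` are assumed to be in
scope):

    state_object = _ffi_api.StateStorageAlign(state_object, stage_id, iterator, 8, 4)
    state_object, new_stage_id = _ffi_api.StateRfactor(state_object, stage_id, iterator, 2, dag)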
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index d43d0af..e533a7c 100644
@@ -81,6 +81,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return AnnotationStep(reader);
   } else if (name == FuseStepNode::record_prefix_str) {
     return FuseStep(reader);
+  } else if (name == PragmaStepNode::record_prefix_str) {
+    return PragmaStep(reader);
   } else if (name == ReorderStepNode::record_prefix_str) {
     return ReorderStep(reader);
   } else if (name == SplitStepNode::record_prefix_str) {
@@ -89,6 +91,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return FollowSplitStep(reader);
   } else if (name == FollowFusedSplitStepNode::record_prefix_str) {
     return FollowFusedSplitStep(reader);
+  } else if (name == StorageAlignStepNode::record_prefix_str) {
+    return StorageAlignStep(reader);
   } else if (name == ComputeAtStepNode::record_prefix_str) {
     return ComputeAtStep(reader);
   } else if (name == ComputeInlineStepNode::record_prefix_str) {
@@ -99,6 +103,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return CacheReadStep(reader);
   } else if (name == CacheWriteStepNode::record_prefix_str) {
     return CacheWriteStep(reader);
+  } else if (name == RfactorStepNode::record_prefix_str) {
+    return RfactorStep(reader);
   } else {
     LOG(FATAL) << "Invalid step format: " << name;
   }
@@ -111,6 +117,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<FuseStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    ps->ApplyToState(state);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
@@ -119,6 +127,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -129,6 +139,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state, dag);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     ps->ApplyToState(state, dag);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    ps->ApplyToState(state, dag);
   } else {
     LOG(FATAL) << "Invalid step: " << step;
   }
@@ -140,6 +152,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
@@ -148,6 +162,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
   } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -158,6 +174,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes, schedule);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
@@ -170,6 +188,8 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
@@ -178,6 +198,8 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
   } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -188,6 +210,8 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
@@ -488,6 +512,113 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Pragma **********/
+PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) {
+  auto node = make_object<PragmaStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->pragma_type = std::move(pragma_type);
+  data_ = std::move(node);
+}
+
+PragmaStep::PragmaStep(dmlc::JSONReader* reader) {
+  auto node = make_object<PragmaStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->pragma_type = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(pragma_type);
+}
+
+void PragmaStepNode::ApplyToState(State* state) const {
+  if (pragma_type == "debug_skip_region") {
+    StateNode* pstate = state->CopyOnWrite();
+    pstate->attach_map.DeleteStage(stage_id);
+  } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    StateNode* pstate = state->CopyOnWrite();
+    Stage stage = pstate->stages[stage_id];
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1);
+    pstate->stages.Set(stage_id, std::move(stage));
+  } else {
+    LOG(FATAL) << "Unsupported pragma: " << pragma_type;
+  }
+}
+
+void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                     StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(pragma_type.c_str() + pos + 1);
+    stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
+    stage.pragma(axes[iter_id], "unroll_explicit", true);
+  } else {
+    stage.pragma(axes[iter_id], pragma_type);
+  }
+  stages->Set(stage_id, std::move(stage));
+}
+
+String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(pragma_type.c_str() + pos + 1);
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+       << ", \"auto_unroll_max_step\", " << value << ")\n";
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+       << ", \"unroll_explicit\", True)\n";
+  } else {
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type
+       << "\")\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
 /********** Reorder **********/
 ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   auto node = make_object<ReorderStepNode>();
@@ -812,8 +943,8 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
   writer->WriteArrayItem(n_split);
 }
 
-void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
-                                              Array<Optional<Integer>>* lengths) const {
+Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
+    const Array<Step>& transform_steps) const {
   // Make sure src_step_id is within the range of transform_steps.
   CHECK_LT(src_step_id, transform_steps.size());
   auto ps = transform_steps[src_step_id].as<SplitStepNode>();
@@ -824,11 +955,12 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps
   CHECK_LE(n_split, ps->lengths.size() + 1);
   CHECK(ps != nullptr);
 
-  lengths->reserve(n_split);
+  Array<Optional<Integer>> lengths;
+  lengths.reserve(n_split);
   int j = 0;
   // Get the first (n_split-1) split factors of followed src_step.
   for (; j < n_split - 1; ++j) {
-    lengths->push_back(ps->lengths[j]);
+    lengths.push_back(ps->lengths[j]);
   }
 
   // Get the last split factor of src_step for splitting level if n_split is smaller than
@@ -843,10 +975,12 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps
     }
   }
   if (last_factor.defined()) {
-    lengths->push_back(Downcast<Integer>(last_factor));
+    lengths.push_back(Downcast<Integer>(last_factor));
   } else {
-    lengths->push_back(NullOpt);
+    lengths.push_back(NullOpt);
   }
+
+  return lengths;
 }
 
 FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
@@ -864,30 +998,26 @@ FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->n_split);
-
   data_ = std::move(node);
 }
 
 Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
-  Array<Optional<Integer>> lengths;
-  ExtractSplitLengths((*state)->transform_steps, &lengths);
-  return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+  return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps),
+                           true);
 }
 
 Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                                     StageToAxesMap* stage_to_axes,
                                                     const Array<Step>& transform_steps) const {
-  Array<Optional<Integer>> lengths;
-  ExtractSplitLengths(transform_steps, &lengths);
-  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              ExtractSplitLengths(transform_steps), true);
 }
 
 String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                              StageToAxesMap* stage_to_axes,
                                              const Array<Step>& transform_steps) const {
-  Array<Optional<Integer>> lengths;
-  ExtractSplitLengths(transform_steps, &lengths);
-  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               ExtractSplitLengths(transform_steps), true);
 }
 
 /********** Follow Fused Split **********/
@@ -922,7 +1052,6 @@ FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
   s = reader->NextArrayItem();
   CHECK(s);
   reader->Read(&node->factor_or_nparts);
-
   ::tvm::Array<::tvm::Integer> src_step_ids;
   for (const auto& i : int_list) {
     src_step_ids.push_back(i);
@@ -961,23 +1090,86 @@ Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
 }
 
 Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
-  const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
-  return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+  return ApplySplitToState(state, stage_id, iter_id,
+                           {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts);
 }
 
 Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
                                                          StageToAxesMap* stage_to_axes,
                                                          const Array<Step>& transform_steps) const {
-  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
-  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              {ExtractSplitLength(transform_steps)}, factor_or_nparts);
 }
 
 String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                                   StageToAxesMap* stage_to_axes,
                                                   const Array<Step>& transform_steps) const {
-  const Optional<Integer>& length = ExtractSplitLength(transform_steps);
-  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
-                               factor_or_nparts);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               {ExtractSplitLength(transform_steps)}, factor_or_nparts);
+}
+
+/********** Storage Align **********/
+StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) {
+  auto node = make_object<StorageAlignStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor = factor;
+  node->offset = offset;
+  data_ = std::move(node);
+}
+
+StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) {
+  auto node = make_object<StorageAlignStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->offset);
+  data_ = std::move(node);
+}
+
+void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(factor);
+  writer->WriteArrayItem(offset);
+}
+
+void StorageAlignStepNode::ApplyToState(State* state) const {
+  StateNode* pstate = state->CopyOnWrite();
+  Stage stage = pstate->stages[stage_id];
+  stage.CopyOnWrite()->attrs.storage_offset = offset;
+  pstate->stages.Set(stage_id, std::move(stage));
+}
+
+void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+  stage.storage_align(axes[iter_id], factor, offset);
+  stages->Set(stage_id, std::move(stage));
+}
+
+String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", "
+     << offset << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
 }
 
 /********** Steps working on multiple stages **********/
@@ -1162,7 +1354,7 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
 
 /*!
 * \brief Common part for steps that add new stages (e.g. CacheReadStep, CacheWriteStep,
@@ -1171,11 +1363,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
  */
 Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
   Array<Step> ret_steps;
-  for (const Step& step : transform_steps) {
+  for (size_t i = 0; i < transform_steps.size(); ++i) {
+    const Step& step = transform_steps[i];
     if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
       ret_steps.push_back(step);
+    } else if (step->IsInstance<RfactorStepNode>()) {
+      // add FuseStepNode required by rfactor
+      if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
+        const Step& fuse_step = transform_steps[i - 2];
+        if (fuse_step->stage_id == step->stage_id) {
+          ret_steps.push_back(fuse_step);
+        }
+      }
+      // add SplitStepNode required by rfactor
+      CHECK_GE(i, 1);
+      CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
+      const Step& split_step = transform_steps[i - 1];
+      CHECK_EQ(split_step->stage_id, step->stage_id);
+      ret_steps.push_back(split_step);
+      // add RfactorStepNode
+      ret_steps.push_back(step);
     }
-    // TODO(jcf94): add rfactor support
     // A state may have multiple stage modifiable steps, stop by the current step to avoid
     // replaying excess steps
     if (step.same_as(current_step)) {
@@ -1432,5 +1640,130 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxe
   return ss.str();
 }
 
+/********** Rfactor **********/
+RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) {
+  auto node = make_object<RfactorStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor_iter_id = factor_iter_id;
+  data_ = std::move(node);
+}
+
+RfactorStep::RfactorStep(dmlc::JSONReader* reader) {
+  auto node = make_object<RfactorStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_iter_id);
+  data_ = std::move(node);
+}
+
+void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(factor_iter_id);
+}
+
+int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  const auto& compute_at_type = pstate->stages[stage_id]->compute_at;
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
+      GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+
+  // target_stage -> rfactor_compute + target_stage
+  // Insert a new compute stage, update the target stage and later stages, then update the
+  // stage_id mapping in AttachMap.
+  pstate->stages.insert(pstate->stages.begin() + stage_id,
+                        Stage(current_compute_dag->ops[stage_id]));
+  // Maintain the compute_at type of the target stage
+  Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]);
+  target_stage.CopyOnWrite()->compute_at = compute_at_type;
+  pstate->stages.Set(stage_id + 1, std::move(target_stage));
+  for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) {
+    Stage stage = pstate->stages[i];
+    stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return stage_id;
+}
+
+Array<te::Tensor> RfactorStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                   StageToAxesMap* stage_to_axes,
+                                                   te::Schedule* schedule) const {
+  const auto& stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+
+  const te::Tensor& tensor = stage->origin_op.output(0);
+  const IterVar& axis = axes[iter_id];
+  auto outs = schedule->rfactor(tensor, axis, factor_iter_id);
+
+  UpdateStageToAxesMap(stage, stage_to_axes);
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                         te::Schedule* schedule) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name);
+  const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint);
+
+  const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->name);
+    if (i != outs.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n";
+
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = "
+       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << " + "
+       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+  }
+
+  const auto& output = (*stages)[stage_id + 1]->op.output(0);
+  const auto& iters = output->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+     << " + "
+     << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+
+  return ss.str();
+}
+
 }  // namespace auto_scheduler
 }  // namespace tvm
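
One replay constraint worth noting from GetFormerStageModifiableSteps above: an
RfactorStep must be directly preceded by a SplitStep on the same stage, optionally with a
FuseStep two positions earlier (the CHECKs enforce this). In terms of the Python API, a
valid sequence is:

    ko, ki = s.split(C, s[C].iters[2], [16])  # the SplitStep that must directly precede rfactor
    C_rf = s.rfactor(C, ko, 2)                # the RfactorStep, on the same stage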
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index da5032e..aacdcf4 100644
@@ -162,6 +162,12 @@ inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
   return sum / float_array.size();
 }
 
+/*! \brief Return whether a string starts with another string */
+inline bool StrStartsWith(const String& a, const String& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str());
+}
+
 /********** Other Utilities **********/
 /*! \brief Get an int value from an Expr */
 inline int64_t GetIntImm(const PrimExpr& expr) {
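
StrStartsWith backs the auto_unroll_max_step$<value> detection in the new PragmaStep code.
Its behavior, sketched in Python for reference (the helper name here is hypothetical):

    def str_starts_with(a: str, b: str) -> bool:
        # True when `a` begins with `b`; mirrors the std::equal comparison above.
        return len(b) <= len(a) and a[:len(b)] == b

    assert str_starts_with("auto_unroll_max_step$64", "auto_unroll_max_step")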
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index e35dfe3..5c501ac 100644
@@ -454,8 +454,62 @@ def test_follow_split_follow_fused_split():
         assert tmp[C].iters[level + 1].range.extent == \
                tmp[C_global].iters[1].range.extent
 
+
+def test_rfactor():
+    A, B, C = matmul_auto_scheduler_test(8, 8, 512)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    s0 = dag.get_init_state()
+
+    ko, ki = s0.split(C, s0[C].iters[2], [16])
+
+    s1 = s0.copy()
+    C_r = s1.rfactor(C, ko, 2)
+    """
+        Placeholder: A, B
+        for i (0,8)
+          for j (0,8)
+            for k_o (0,32)
+              for k_i (0,16)
+                C.rf = ...
+        for ax0 (0,8)
+          for ax1 (0,8)
+            for k_o_v (0,32)
+              C.repl = ...
+    """
+    assert s1[C_r].iters[0].range.extent == 8
+    assert s1[C_r].iters[1].range.extent == 8
+    assert s1[C_r].iters[2].range.extent == 32
+    assert s1[C_r].iters[3].range.extent == 16
+    assert s1[C].iters[0].range.extent == 8
+    assert s1[C].iters[1].range.extent == 8
+    assert s1[C].iters[2].range.extent == 32
+
+    s2 = s0.copy()
+    C_r = s2.rfactor(C, ki, 2)
+    """
+        Placeholder: A, B
+        for i (0,8)
+          for j (0,8)
+            for k_i (0,16)
+              for k_o (0,32)
+                C.rf = ...
+        for ax0 (0,8)
+          for ax1 (0,8)
+            for k_i_v (0,16)
+              C.repl = ...
+    """
+    assert s2[C_r].iters[0].range.extent == 8
+    assert s2[C_r].iters[1].range.extent == 8
+    assert s2[C_r].iters[2].range.extent == 16
+    assert s2[C_r].iters[3].range.extent == 32
+    assert s2[C].iters[0].range.extent == 8
+    assert s2[C].iters[1].range.extent == 8
+    assert s2[C].iters[2].range.extent == 16
+
+
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_compute_at_root_inline()
     test_cache_read_write()
     test_follow_split_follow_fused_split()
+    test_rfactor()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 39d01e0..e65f191 100644
@@ -24,7 +24,28 @@ import tempfile
 
 from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
 
-def test_record():
+def record_common(dag, s):
+    target = tvm.target.create("llvm")
+    task = auto_scheduler.SearchTask(dag, "test", target)
+
+    inp = auto_scheduler.measure.MeasureInput(task, s)
+    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+    with tempfile.NamedTemporaryFile() as fp:
+        auto_scheduler.save_records(fp.name, [inp], [res])
+
+        log_reader = auto_scheduler.RecordReader(fp.name)
+        inputs, results = log_reader.read_lines()
+        assert len(inputs) == 1
+
+        s1 = dag.infer_bound_from_state(s)
+        s2 = dag.infer_bound_from_state(inputs[0].state)
+
+        assert s1 == s2
+        assert not (s1 == dag.get_init_state())
+
+
+def test_record_split_reorder_fuse_annotation():
     if not tvm.runtime.enabled("llvm"):
         return
 
@@ -32,16 +53,8 @@ def test_record():
     B = te.placeholder((512, 512), name='B')
     k = te.reduce_axis((0, 512), name='k')
     C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
-    D = topi.nn.relu(C)
-    k = te.reduce_axis((0, 512), name='k')
-    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
-    F = topi.nn.relu(E)
-    k = te.reduce_axis((0, 512), name='k')
-    G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
-    H = topi.nn.relu(G)
-    I = topi.nn.relu(H)
 
-    dag = auto_scheduler.ComputeDAG([A, B, I])
+    dag = auto_scheduler.ComputeDAG([A, B, C])
     s = dag.get_init_state()
 
     # Split
@@ -52,13 +65,6 @@ def test_record():
                   its1[3]])
     # Fuse
     s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
-    # Compute at
-    s.split(F, s[F].iters[0], [2])
-    s.compute_at(E, F, s[F].iters[0])
-    # Compute inline
-    s.compute_inline(D)
-    # Compute root
-    s.compute_root(D)
     # Parallel
     s.parallel(C, s[C].iters[0])
     # Thread bind(The blockIdx & threadIdx are used in GPU, just for record testing here)
@@ -69,46 +75,93 @@ def test_record():
     s.unroll(C, s[C].iters[4])
     # Vectorize
     s.vectorize(C, s[C].iters[6])
-    # Cache Read
-    D_global = s.cache_read(D, "global", [E])
-    s.compute_at(D_global, E, s[E].iters[2])
+
+    record_common(dag, s)
+
+
+def test_record_compute_at_root_inline_cache_read_write():
+    if not tvm.runtime.enabled("llvm"):
+        return
+
+    A = te.placeholder((512, 512), name='A')
+    AA = topi.nn.relu(A)
+    B = te.placeholder((512, 512), name='B')
+    k = te.reduce_axis((0, 512), name='k')
+    C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name='C')
+
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    s = dag.get_init_state()
+
     # Cache Write
-    s.cache_write(D, "shared")
-    #follow_split
-    its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True)
+    C_shared = s.cache_write(C, "shared")
+    # Compute At
+    s.compute_at(C_shared, C, s[C].iters[0])
+    # Cache Read
+    B_global = s.cache_read(B, "global", [C_shared])
+    s.compute_at(B_global, C_shared, s[C_shared].iters[2])
+    # Compute Inline
+    s.compute_inline(AA)
+    # Compute Root
+    s.compute_root(C_shared)
+
+    record_common(dag, s)
+
+
+def test_record_follow_split_follow_fused_split():
+    if not tvm.runtime.enabled("llvm"):
+        return
+
+    A = te.placeholder((512, 512), name='A')
+    B = te.placeholder((512, 512), name='B')
+    k = te.reduce_axis((0, 512), name='k')
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+    D = topi.nn.relu(C)
+    E = topi.nn.relu(D)
+
+    dag = auto_scheduler.ComputeDAG([A, B, E])
+    s = dag.get_init_state()
+
+    # Follow Split
+    s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
     split_step0 = len(s.transform_steps) - 1
-    s.follow_split(G, s[G].iters[5], split_step0, 4)
-    #follow_fused_split
-    its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True)
+    s.follow_split(C, s[C].iters[5], split_step0, 4)
+    # Follow Fused Split
+    its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
     split_step1 = len(s.transform_steps) - 1
-    its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True)
+    its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
     split_step2 = len(s.transform_steps) - 1
     its = []
-    for i0, i1 in zip(its2, its3):
+    for i0, i1 in zip(its0, its1):
         its.append(i0)
         its.append(i1)
     for i in range(0, 5):
-        s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]])
-    s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False)
+        s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
+    s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)
 
-    target = tvm.target.create("llvm")
-    task = auto_scheduler.SearchTask(dag, "test", target)
+    record_common(dag, s)
 
-    inp = auto_scheduler.measure.MeasureInput(task, s)
-    res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
 
-    with tempfile.NamedTemporaryFile() as fp:
-        auto_scheduler.save_records(fp.name, [inp], [res])
+def test_record_pragma_storage_align_rfactor():
+    if not tvm.runtime.enabled("llvm"):
+        return
 
-        log_reader = auto_scheduler.RecordReader(fp.name)
-        inputs, results = log_reader.read_lines()
-        assert len(inputs) == 1
+    A = te.placeholder((512, 512), name='A')
+    B = te.placeholder((512, 512), name='B')
+    k = te.reduce_axis((0, 512), name='k')
+    C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
 
-        s1 = dag.infer_bound_from_state(s)
-        s2 = dag.infer_bound_from_state(inputs[0].state)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    s = dag.get_init_state()
 
-        assert s1 == s2
-        assert not (s1 == dag.get_init_state())
+    # Rfactor
+    ko, _ = s.split(C, s[C].iters[2], [16])
+    s.rfactor(C, ko, 2)
+    # Pragma
+    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
+    # StorageAlign
+    s.storage_align(C, s[C].iters[-1], 8, 4)
+
+    record_common(dag, s)
 
 
 def test_measure_local_builder_runner():
@@ -149,6 +202,9 @@ def test_measure_local_builder_rpc_runner():
 
 
 if __name__ == "__main__":
-    test_record()
+    test_record_split_reorder_fuse_annotation()
+    test_record_compute_at_root_inline_cache_read_write()
+    test_record_follow_split_follow_fused_split()
+    test_record_pragma_storage_align_rfactor()
     test_measure_local_builder_runner()
     test_measure_local_builder_rpc_runner()