[Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps (#6107)
authorChenfan <chengfan.jcf@alibaba-inc.com>
Mon, 27 Jul 2020 04:35:52 +0000 (12:35 +0800)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 04:35:52 +0000 (21:35 -0700)
* Add cache_read/cache_write step

* Update

* Update

* Update

* Update state->current_compute_dag to Optional

* Update

* Update doc

* Update

* Update

* Doc update

* Update

include/tvm/auto_scheduler/compute_dag.h
include/tvm/auto_scheduler/loop_state.h
include/tvm/auto_scheduler/transform_step.h
python/tvm/auto_scheduler/compute_dag.py
python/tvm/auto_scheduler/loop_state.py
src/auto_scheduler/compute_dag.cc
src/auto_scheduler/loop_state.cc
src/auto_scheduler/transform_step.cc
tests/python/unittest/test_auto_scheduler_loop_state.py
tests/python/unittest/test_auto_scheduler_measure.py

index 71652fd..69b74bf 100644 (file)
@@ -238,6 +238,17 @@ class ComputeDAG : public ObjectRef {
    */
   State InferBound(const State& state) const;
 
+  /*!
+   * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
+   * ComputeDAG may not be up-to-date. This function replays the given transform steps from the
+   * initial state and returns an up-to-date ComputeDAG.
+   * \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up
+   * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage
+   * structure.
+   * \return The up-to-date ComputeDAG.
+   */
+  ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;
+
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
 };
index 4e9cb9b..1c8ea77 100644 (file)
@@ -182,16 +182,18 @@ class AttachMap : public ObjectRef {
  public:
   /*!
    * \brief Process the stage/iterator mapping after compute at.
-   * \param stage_id The index of the stage to be compute at.
+   * \param stage_id The index of the stage to be computed at.
    * \param target_stage_id The index of stage that this step will compute at to.
    * \param target_iter_id The index of iterator in target stage that this step will compute at to.
    */
   void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
+
   /*!
    * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
-   * \param stage_id The index of the stage to be compute at.
+   * \param stage_id The index of the stage to be computed at.
    */
   void DeleteStage(int stage_id);
+
   /*!
    * \brief Find the relations of original iterators in AttachMap, and update them with the new
    * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
@@ -201,6 +203,17 @@ class AttachMap : public ObjectRef {
   void UpdateIters(const std::vector<IterKey>& original_iters,
                    const std::vector<IterKey>& new_iters);
 
+  /*!
+   * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
+   * to stage indexes that are larger than the start_id. Used for steps that insert new stages to
+   * ComputeDAG(e.g. CacheRead/CacheWrite step).
+   * \param start_id The index threshold, stage indexes in AttachMap which are larger than this
+   * will be applied the extra offset.
+   * \param offset The index offset to be added to the stage index.
+   * \return The updated AttachMap after applying stage index offset.
+   */
+  AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;
+
   TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);
 
@@ -231,6 +244,12 @@ class StateNode : public Object {
    * operation.
    */
   AttachMap attach_map;
+  /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means
+   * no modification to the original ComputeDAG.
+   * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
+   * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
+   */
+  Optional<ObjectRef> current_compute_dag;
   /*!
    * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
    * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
@@ -245,15 +264,6 @@ class StateNode : public Object {
 
   static constexpr const char* _type_key = "auto_scheduler.State";
   TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
-
- private:
-  /*!
-   * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
-   * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added
-   * later).
-   * The default value is an empty ObjectRef. (means no modification to the original DAG)
-   */
-  ObjectRef current_compute_dag;
 };
 
 /*!
@@ -290,7 +300,7 @@ class State : public ObjectRef {
   /********** Step APIs working on single stage **********/
 
   /*!
-   * \brief Schedule primitive corresponds to te.bind.
+   * \brief Schedule primitive corresponds to `te::Stage::bind`.
    * \param stage_id The index of the stage to be binded.
    * \param it The iterator to be binded.
    * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
@@ -299,14 +309,14 @@ class State : public ObjectRef {
    */
   TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
   /*!
-   * \brief Schedule primitive corresponds to te.parallel.
+   * \brief Schedule primitive corresponds to `te::Stage::parallel`.
    * \param stage_id The index of the stage to be paralleled.
    * \param it The iterator to be paralleled.
    * \return The iterator result after parallel.
    */
   TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
   /*!
-   * \brief Schedule primitive corresponds to te.unroll.
+   * \brief Schedule primitive corresponds to `te::Stage::unroll`.
    * \param stage_id The index of the stage to be unrolled.
    * \param it The iterator to be unrolled.
    * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
@@ -315,14 +325,14 @@ class State : public ObjectRef {
    */
   TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
   /*!
-   * \brief Schedule primitive corresponds to te.vectorize.
+   * \brief Schedule primitive corresponds to `te::Stage::vectorize`.
    * \param stage_id The index of the stage to be vectorized.
    * \param it The iterator to be vectorized.
    * \return The iterator result after vectorize.
    */
   TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
   /*!
-   * \brief Schedule primitive corresponds to te.fuse.
+   * \brief Schedule primitive corresponds to `te::Stage::fuse`.
    * \param stage_id The index of the stage to be fused.
    * \param iters The iterators to be fused.
    * \return The iterator result after fuse.
@@ -331,13 +341,13 @@ class State : public ObjectRef {
    */
   TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
   /*!
-   * \brief Schedule primitive corresponds to te.reorder.
+   * \brief Schedule primitive corresponds to `te::Stage::reorder`.
    * \param stage_id The index of the stage to be reordered.
    * \param order The expected iterator order.
    */
   TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
   /*!
-   * \brief Schedule primitive corresponds to te.split.
+   * \brief Schedule primitive corresponds to `te::Stage::split`.
    * \param stage_id The index of the stage to be split.
    * \param it The iterator to be split.
    * \param lengths The multiple split factors. Can be None to be filled by search policy.
@@ -353,8 +363,8 @@ class State : public ObjectRef {
   /********** Step APIs working on multiple stages **********/
 
   /*!
-   * \brief Schedule primitive corresponds to te.compute_at.
-   * \param stage_id The index of the stage to be reordered.
+   * \brief Schedule primitive corresponds to `te::Stage::compute_at`.
+   * \param stage_id The index of the stage to be computed at.
    * \param target_stage_id The index of stage that this step will compute at to.
    * \param target_iter The iterator in target stage that this step will compute at to.
    * \note After compute_at, we need careful dependency analysis to compute the accurate bound
@@ -364,13 +374,13 @@ class State : public ObjectRef {
    */
   TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
   /*!
-   * \brief Schedule primitive corresponds to te.compute_inline.
-   * \param stage_id The index of the stage to be reordered.
+   * \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
+   * \param stage_id The index of the stage to be marked compute inlined.
    */
   TVM_DLL void compute_inline(int stage_id);
   /*!
-   * \brief Schedule primitive corresponds to te.compute_root.
-   * \param stage_id The index of the stage to be reordered.
+   * \brief Schedule primitive corresponds to `te::Stage::compute_root`.
+   * \param stage_id The index of the stage to be marked compute at root.
    * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
@@ -378,6 +388,30 @@ class State : public ObjectRef {
    */
   TVM_DLL void compute_root(int stage_id);
 
+  /********** Step APIs adding new stages **********/
+
+  /*!
+   * \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
+   * \param stage_id The index of the stage to be cache read.
+   * \param scope_name The scope name of the newly added read stage.
+   * \param reader_stage_ids The indices of read stages.
+   * \param dag The original ComputeDAG of this state.
+   * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
+   * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   */
+  int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
+                 const ComputeDAG& dag);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
+   * \param stage_id The index of the stage to be cache write.
+   * \param scope_name The scope name of the newly added compute stage.
+   * \param dag The original ComputeDAG of this state.
+   * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
+   * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+   * This step will cache write all output tensors of the target stage.
+   */
+  int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+
   TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
 };
index b23137a..83d6e29 100644 (file)
@@ -56,6 +56,13 @@ namespace auto_scheduler {
 
 typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
 
+/*!
+ * \brief Update the current stage IterVar information to StageToAxesMap.
+ * \param stage A te::Stage Object.
+ * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated.
+ */
+void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
+
 /*! \brief The type of an iterator. */
 enum class IteratorKind : int {
   /*! \brief Spatial iterator. */
@@ -183,7 +190,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader);
 /*!
  * \brief Apply the step to State.
  * \param step The step to be applied to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
  * \param dag The original ComputeDAG of this state.
  */
 void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
@@ -191,20 +198,25 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
 /*!
  * \brief Apply the step to tvm.schedule.
  * \param step The step to be applied to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
+ * `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
  */
-void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes);
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                         te::Schedule* schedule);
 
 /*!
  * \brief Print the step as equivalent python schedule API.
  * \param step The step to be applied to python API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
+ * CacheRead/CacheWrite step)
  * \return Python schedule code.
  */
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
-                            StageToAxesMap* stage_to_axes);
+                            StageToAxesMap* stage_to_axes, te::Schedule* schedule);
 
 /********** Steps working on single stage **********/
 
@@ -223,22 +235,22 @@ class AnnotationStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    * \return The iterator result after annotate.
    */
   Iterator ApplyToState(State* state) const;
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -283,7 +295,7 @@ class FuseStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    * \return The iterator result after fuse.
    * \note If the iterators to be fused have stages attached at them(by compute_at), the fused
    * result will become the new attach point.
@@ -292,16 +304,16 @@ class FuseStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return The iterator result after fuse.
    */
   tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -348,21 +360,21 @@ class ReorderStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    */
   void ApplyToState(State* state) const;
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -418,7 +430,7 @@ class SplitStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    * \return The iterator results after split.
    * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
    * most iterator of split results will become the new attach point.
@@ -427,8 +439,8 @@ class SplitStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return The iterator results after split.
    */
   Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
@@ -436,8 +448,8 @@ class SplitStepNode : public StepNode {
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -489,7 +501,7 @@ class ComputeAtStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    * \note After compute_at, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
@@ -499,15 +511,15 @@ class ComputeAtStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -526,7 +538,7 @@ class ComputeAtStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be compute at.
+   * \param stage_id The index of the stage to be computed at.
    * \param target_stage_id The index of stage that this step will compute at to.
    * \param target_iter_id The index of iterator in target stage that this step will compute at to.
    */
@@ -549,22 +561,22 @@ class ComputeInlineStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
+   * \param state A mutable pointer to state, which will be updated.
    */
   void ApplyToState(State* state) const;
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return The iterator result after fuse.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -583,7 +595,7 @@ class ComputeInlineStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be compute inline.
+   * \param stage_id The index of the stage to be marked compute inlined.
    */
   explicit ComputeInlineStep(int stage_id);
 
@@ -604,8 +616,8 @@ class ComputeRootStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to State.
-   * \param state A mutable pointer to State.
-   * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+   * \param state A mutable pointer to state, which will be updated.
+   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
    * information. However, it is relatively expensive and complicated, so we just fill "None" as
    * bound for the newly created iterators.
    * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
@@ -614,16 +626,16 @@ class ComputeRootStepNode : public StepNode {
 
   /*!
    * \brief Apply the current step to tvm.schedule.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return The iterator result after fuse.
    */
   void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
 
   /*!
    * \brief Print the current step as equivalent python schedule API.
-   * \param stages A pointer to a `te::Stage` Array.
-   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
    * \return Python schedule code.
    */
   String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
@@ -642,7 +654,7 @@ class ComputeRootStep : public Step {
  public:
   /*!
    * \brief The constructor.
-   * \param stage_id The index of the stage to be compute root
+   * \param stage_id The index of the stage to be marked compute at root.
    */
   explicit ComputeRootStep(int stage_id);
 
@@ -656,6 +668,150 @@ class ComputeRootStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Cache read step that corresponds to te::Schedule::cache_read.
+ * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG
+ * is stored in State's `current_compute_dag`.
+ */
+class CacheReadStepNode : public StepNode {
+ public:
+  /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */
+  String scope_name;
+  /*! \brief The indices of read stages. */
+  Array<Integer> reader_stage_ids;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the new added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensor of the new added stage.
+   */
+  te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                             te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "CHR";
+
+  static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CacheReadStepNode.
+ * \sa CacheReadStepNode
+ */
+class CacheReadStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be cache read.
+   * \param scope_name The scope name of the newly added read stage.
+   * \param reader_stage_ids The indices of read stages.
+   */
+  CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit CacheReadStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
+};
+
+/*!
+ * \brief Cache write step that corresponds to te::Schedule::cache_write.
+ * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date
+ * ComputeDAG is stored in State's `current_compute_dag`.
+ * This step will cache write all output tensors of the target stage.
+ */
+class CacheWriteStepNode : public StepNode {
+ public:
+  /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */
+  String scope_name;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the new added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensors of the new added stage.
+   */
+  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                    te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "CHW";
+
+  static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CacheWriteStepNode.
+ * \sa CacheWriteStepNode
+ */
+class CacheWriteStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be cache write.
+   * \param scope_name The scope name of the newly added compute stage.
+   */
+  CacheWriteStep(int stage_id, String scope_name);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit CacheWriteStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
+};
+
 }  // namespace auto_scheduler
 }  // namespace tvm
 
index 115d28b..e08454f 100644 (file)
@@ -126,11 +126,17 @@ class ComputeDAG(Object):
 
         Returns
         -------
-        state : State
+        updated_state : State
             The State with complete bound information.
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
-        return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+        updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+        # Copy the stage_id_map from the original state to make sure the old indices are still
+        # valid
+        if isinstance(state, State):
+            for k, v in state.stage_id_map.items():
+                updated_state.stage_id_map[k] = v
+        return updated_state
 
     def __hash__(self):
         # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
index ab041cf..8c3a936 100644 (file)
@@ -127,7 +127,8 @@ class State:
         return [stage.op for stage in self.stages]
 
     def bind(self, stage, iterator, thread_name):
-        """ Schedule primitive corresponds to te.bind.
+        """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more
+        details.
 
         Parameters
         ----------
@@ -160,7 +161,8 @@ class State:
         return res
 
     def parallel(self, stage, iterator):
-        """ Schedule primitive corresponds to te.parallel.
+        """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more
+        details.
 
         Parameters
         ----------
@@ -180,7 +182,8 @@ class State:
         return res
 
     def unroll(self, stage, iterator, max_unroll=None):
-        """ Schedule primitive corresponds to te.unroll.
+        """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more
+        details.
 
         Parameters
         ----------
@@ -203,7 +206,8 @@ class State:
         return res
 
     def vectorize(self, stage, iterator):
-        """ Schedule primitive corresponds to te.vectorize.
+        """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for
+        more details.
 
         Parameters
         ----------
@@ -223,7 +227,8 @@ class State:
         return res
 
     def fuse(self, stage, iters):
-        """ Schedule primitive corresponds to te.fuse.
+        """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more
+        details.
 
         Parameters
         ----------
@@ -248,7 +253,8 @@ class State:
         return res
 
     def reorder(self, stage, order):
-        """ Schedule primitive corresponds to te.reorder.
+        """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
+        details.
 
         Parameters
         ----------
@@ -262,7 +268,8 @@ class State:
                                                   order)
 
     def split(self, stage, iterator, lengths, inner_to_outer=True):
-        """ Schedule primitive corresponds to te.split.
+        """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more
+        details.
 
         This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
         will be split to 3 parts, use `inner_to_outer` to control the split order)
@@ -295,12 +302,13 @@ class State:
         return res
 
     def compute_at(self, stage, target_stage, target_iter):
-        """ Schedule primitive corresponds to te.compute_at.
+        """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
+        more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be compute at, which can be specified by the integer index, Operation,
+            The Stage to be computed at, which can be specified by the integer index, Operation,
             or output tensor of the stage.
         target_stage : Union[int, Operation, Tensor]
             The target stage of compute_at, which can be specified by the integer index, Operation,
@@ -321,25 +329,27 @@ class State:
                                                     target_iter)
 
     def compute_inline(self, stage):
-        """ Schedule primitive corresponds to te.compute_inline.
+        """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage`
+        for more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be compute inlined, which can be specified by the integer index, Operation,
-            or output tensor of the stage.
+            The Stage to be marked compute inlined, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
         """
         self.state_object = _ffi_api.StateComputeInline(self.state_object,
                                                         self._resolve_stage_id(stage))
 
     def compute_root(self, stage):
-        """ Schedule primitive corresponds to te.compute_root.
+        """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for
+        more details.
 
         Parameters
         ----------
         stage : Union[int, Operation, Tensor]
-            The Stage to be compute root, which can be specified by the integer index, Operation,
-            or output tensor of the stage.
+            The Stage to be marked compute at root, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
 
         Notes
         -----
@@ -351,6 +361,74 @@ class State:
         self.state_object = _ffi_api.StateComputeRoot(self.state_object,
                                                       self._resolve_stage_id(stage))
 
+    def cache_read(self, stage, scope_name, reader_stages):
+        """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule`
+        for more details.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be cache read, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        scope_name : str
+            The scope name of the newly added read stage.
+        reader_stages : List[Union[int, Operation, Tensor]]
+            The reader stages. Each of the list can be specified by the integer index, Operation,
+            or output tensor of the stage.
+
+        Returns
+        -------
+        new_stage_op : Operator
+            The Operator of the new added stage.
+
+        Notes
+        -----
+        Cache read step will insert an extra stage to the original ComputeDAG (at the back of the
+        target stage).
+        """
+        reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
+        self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
+                                                                  self._resolve_stage_id(stage),
+                                                                  scope_name, reader_stage_ids,
+                                                                  self.compute_dag)
+        # Add a new stage will change all ops behind the added stage. But we still want to keep the
+        # original ops map, apply stage id offset to stage_id_map to make them work.
+        self._apply_stage_id_offset(int(new_stage_id))
+        self._update_stage_id_map()
+        return self.stages[int(new_stage_id)].op
+
+    def cache_write(self, stage, scope_name):
+        """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule`
+        for more details.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be cache write, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        scope_name : str
+            The scope name of the newly added compute stage.
+
+        Returns
+        -------
+        new_stage_op : Operator
+            The Operator of the new added stage.
+
+        Notes
+        -----
+        Cache write step will insert an extra stage to the original ComputeDAG (in the front of the
+        target stage).
+        This step will cache write all output tensors of the target stage.
+        """
+        self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object,
+                                                                   self._resolve_stage_id(stage),
+                                                                   scope_name, self.compute_dag)
+        # Add a new stage will change all ops behind the added stage. But we still want to keep the
+        # original ops map, apply stage id offset to stage_id_map to make them work.
+        self._apply_stage_id_offset(int(new_stage_id))
+        self._update_stage_id_map()
+        return self.stages[int(new_stage_id)].op
+
     def copy(self):
         """ Do deep copy of this State. """
         state = State(self.state_object, self.compute_dag)
@@ -371,6 +449,11 @@ class State:
         for index, stage in enumerate(self.stages):
             self.stage_id_map[stage.op] = index
 
+    def _apply_stage_id_offset(self, start_id, offset=1):
+        for key, value in self.stage_id_map.items():
+            if value >= start_id:
+                self.stage_id_map[key] = value + offset
+
     def __getitem__(self, key):
         if isinstance(key, Tensor):
             key = key.op
index 68d1bb4..2f6e948 100644 (file)
@@ -645,24 +645,6 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
   data_ = std::move(node);
 }
 
-// Update the te::stage to tir::IterVar axis mapping
-void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
-  if (auto pop = stage->op.as<te::ComputeOpNode>()) {
-    Array<IterVar> axes;
-    for (const auto& axis : pop->axis) {
-      axes.push_back(axis);
-    }
-    for (const auto& axis : pop->reduce_axis) {
-      axes.push_back(axis);
-    }
-    stage_to_axes->Set(stage, std::move(axes));
-  } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
-    {}  // do nothing on Placeholder
-  } else {
-    LOG(FATAL) << "Invalid op " << stage->op;
-  }
-}
-
 std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
     const Array<Step>& transform_steps, Array<te::Stage>* stages,
     StageToAxesMap* stage_to_axes) const {
@@ -696,7 +678,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
   // Apply the history steps to TVM schedule
   // Call each step's ApplyToSchedule method
   for (const auto& step : transform_steps) {
-    StepApplyToSchedule(step, stages, stage_to_axes);
+    StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
   }
 
   return std::make_pair(schedule, operator->()->tensors);
@@ -740,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   }
   // Call each step's PrintAsPythonAPI method
   for (const auto& step : transform_steps) {
-    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
+    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
   }
 
   return ss.str();
@@ -806,6 +788,23 @@ State ComputeDAG::InferBound(const State& state) const {
   return ret_state;
 }
 
+ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
+  te::Schedule sch;
+  Array<te::Tensor> old_tensors;
+  std::tie(sch, old_tensors) = ApplySteps(transform_steps);
+
+  Array<te::Tensor> new_tensors;
+  for (auto stage : sch->stages) {
+    if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
+      for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+        new_tensors.push_back(stage->op.output(i));
+      }
+    }
+  }
+
+  return ComputeDAG(new_tensors);
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const ComputeDAGNode*>(ref.get());
index 35d899a..67c6b38 100644 (file)
@@ -23,6 +23,7 @@
  * see auto_scheduler/loop_state.h for more explanation.
  */
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/auto_scheduler/loop_state.h>
 #include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
@@ -150,6 +151,36 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
   }
 }
 
+AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
+  AttachMap map = AttachMap(make_object<AttachMapNode>());
+  auto pmap = map.CopyOnWrite();
+  for (const auto& x : operator->()->stage_to_attach_iter) {
+    auto key = x.first;
+    if (key >= start_id) {
+      key += offset;
+    }
+    auto value = x.second;
+    if (value.first >= start_id) {
+      value.first += offset;
+    }
+    pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
+  }
+  for (const auto& x : operator->()->iter_to_attached_stages) {
+    auto key = x.first;
+    if (key.first >= start_id) {
+      key.first += offset;
+    }
+    auto value = x.second;
+    for (auto& i : value) {
+      if (i >= start_id) {
+        i += offset;
+      }
+    }
+    pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
+  }
+  return map;
+}
+
 /********** State **********/
 State::State(const Array<te::Operation>& ops) {
   auto node = make_object<StateNode>();
@@ -257,6 +288,19 @@ void State::compute_root(int stage_id) {
   step->ApplyToState(this);
 }
 
+int State::cache_read(int stage_id, const String& scope_name,
+                      const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
+  CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this, dag);
+}
+
+int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
+  CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this, dag);
+}
+
 void State::ApplySteps(const ComputeDAG& dag) {
   CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
 
@@ -429,6 +473,20 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
       return state;
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
+    .set_body_typed([](State state, int stage_id, const String& scope_name,
+                       const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
+      int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
+      return Array<ObjectRef>{state, Integer(res)};
+    });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
+    .set_body_typed([](State state, int stage_id, const String& scope_name,
+                       const ComputeDAG& task_dag) {
+      int res = state.cache_write(stage_id, scope_name, task_dag);
+      return Array<ObjectRef>{state, Integer(res)};
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
   return std::equal_to<State>()(state1, state2);
 });
index b1b3b94..5c5cc4b 100644 (file)
@@ -23,6 +23,7 @@
  *        They are similar to the schedule primitives in te::Stage.
  */
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/auto_scheduler/loop_state.h>
 #include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
 namespace tvm {
 namespace auto_scheduler {
 
+// Update the te::stage to tir::IterVar axis mapping
+void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
+  if (auto pop = stage->op.as<te::ComputeOpNode>()) {
+    Array<IterVar> axes;
+    for (const auto& axis : pop->axis) {
+      axes.push_back(axis);
+    }
+    for (const auto& axis : pop->reduce_axis) {
+      axes.push_back(axis);
+    }
+    stage_to_axes->Set(stage, std::move(axes));
+  } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
+    {}  // do nothing on Placeholder
+  } else {
+    LOG(FATAL) << "Invalid op " << stage->op;
+  }
+}
+
 const char* IteratorAnnotationString[] = {
     "for",          // kNone = 0
     "unroll",       // kUnroll = 1
@@ -72,6 +91,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return ComputeInlineStep(reader);
   } else if (name == ComputeRootStepNode::record_prefix_str) {
     return ComputeRootStep(reader);
+  } else if (name == CacheReadStepNode::record_prefix_str) {
+    return CacheReadStep(reader);
+  } else if (name == CacheWriteStepNode::record_prefix_str) {
+    return CacheWriteStep(reader);
   } else {
     LOG(FATAL) << "Invalid step format: " << name;
   }
@@ -94,14 +117,17 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeRootStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<CacheReadStepNode>()) {
+    ps->ApplyToState(state, dag);
+  } else if (auto ps = step.as<CacheWriteStepNode>()) {
+    ps->ApplyToState(state, dag);
   } else {
     LOG(FATAL) << "Invalid step: " << step;
   }
 }
 
-void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
-                         StageToAxesMap* stage_to_axes) {
-  // We need this runtime dispatcher because different steps have different function signatures
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                         te::Schedule* schedule) {
   if (auto ps = step.as<AnnotationStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -116,14 +142,17 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeRootStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<CacheReadStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<CacheWriteStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
 }
 
 String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
-                            StageToAxesMap* stage_to_axes) {
-  // We need this runtime dispatcher because different steps have different function signatures
+                            StageToAxesMap* stage_to_axes, te::Schedule* schedule) {
   if (auto ps = step.as<AnnotationStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
@@ -138,6 +167,10 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeRootStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<CacheReadStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<CacheWriteStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
@@ -925,5 +958,275 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,
+ * RfactorStep). This will return all steps that can change the number of stages in a ComputeDAG,
+ * and stop by the current step.
+ */
+Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
+  Array<Step> ret_steps;
+  for (const Step& step : transform_steps) {
+    if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
+      ret_steps.push_back(step);
+    }
+    // TODO(jcf94): add rfactor support
+    // A state may have multiple stage modifiable steps, stop by the current step to avoid
+    // replaying excess steps
+    if (step.same_as(current_step)) {
+      break;
+    }
+  }
+  return ret_steps;
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
+      GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+
+  // target_stage -> target_stage + target_store
+  // Update the op of the target stage, insert a new cache read stage behind, update the op of
+  // later stages, then update the stage_id mapping in AttachMap
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Since the original stage will be changed after schedule apply, keep a copy here
+  // These information will be used to print Python API string later
+  auto stage = (*stages)[stage_id];
+  Array<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  // Print the iterators of the new added stage
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  int last_dag_op_size = pstate->current_compute_dag
+                             ? pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
+      GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM
+  CHECK_GE(added_ops, 1);
+
+  // target_stage -> cache_write_stage + target_stage
+  // Assume no step has been applied to the target stage before cache write.
+  // Insert a new cache write stage ahead, update the op of the target stage and later stages, then
+  // update the stage_id mapping in AttachMap
+  pstate->stages.insert(pstate->stages.begin() + stage_id,
+                        Stage(current_compute_dag->ops[stage_id]));
+  pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1]));
+  int next_stage_id = stage_id + 2;
+  // TODO(jc94): Fix the cache write bug in TVM and remove added_op == 2 support.
+  // TVM's cache_write has a bug with multi outputs. See
+  // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test
+  // for more details
+  if (added_ops == 2) {
+    pstate->stages.insert(pstate->stages.begin() + next_stage_id,
+                          Stage(current_compute_dag->ops[next_stage_id]));
+    next_stage_id++;
+  } else if (added_ops > 2) {
+    LOG(ERROR) << "Unexpected behavior of CacheWrite.";
+  }
+  for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) {
+    Stage tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return stage_id;
+}
+
+Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                      StageToAxesMap* stage_to_axes,
+                                                      te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+  Array<te::Tensor> tensor_array;
+  // If the target stage has multi outputs, TVM requires to cache_write
+  // all of them or schedule.cache_write will raise an error
+  for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+    tensor_array.push_back(stage->origin_op.output(i));
+  }
+  auto outs = schedule->cache_write(tensor_array, scope_name);
+
+  UpdateStageToAxesMap(stage, stage_to_axes);
+  // Even if there is multi outputs, TVM schedule only generate one
+  // new stage
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                            te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Since the original stage will be changed after schedule apply, keep a copy here
+  // These information will be used to print Python API string later
+  te::Stage stage = (*stages)[stage_id];
+  auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->name) << ", ";
+  }
+  ss << "= "
+     << "s.cache_write([" << CleanName(stage->op.output(0)->op->name);
+  for (auto i = 1; i < stage->op->num_outputs(); ++i) {
+    ss << ", " << CleanName(stage->op.output(i)->op->name);
+  }
+  ss << "], \"" << scope_name << "\")\n";
+
+  // Print the iterators of the new added stage
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = "
+       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << " + "
+       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+  }
+
+  return ss.str();
+}
+
 }  // namespace auto_scheduler
 }  // namespace tvm
index 32ea8fa..8282d4a 100644 (file)
@@ -143,6 +143,281 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[6].range.extent == 7
 
 
+def test_cache_read_write():
+    N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
+        1, 1), (1, 1)
+
+    data = te.placeholder((N, CI, H, W), name='Data')
+    kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
+    k0, k1 = te.compute(kernel_data.shape,
+                        lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
+                        name='Kernel_split')
+    kernel = te.compute(kernel_data.shape,
+                        lambda *i: k0(*i) + k1(*i),
+                        name='Kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
+    relu = topi.nn.relu(conv)
+    add = topi.add(data, relu)
+
+    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
+    s0 = dag.get_init_state()
+
+    pad_temp = s0.stage_ops[1]
+    kernel_split = s0.stage_ops[3]
+
+    # 0: init state
+    ori_its = s0[add].iters
+    its = s0.split(add, s0[add].iters[0], [2])
+    s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
+    s0.compute_inline(relu)
+
+    # 1: simple cache_write with compute_at
+    conv_global = s0.cache_write(conv, "global")
+    s0.compute_at(conv_global, conv, s0[conv].iters[3])
+
+    # 2: simple cache_read with compute_at
+    kernel_global = s0.cache_read(kernel, "global", [conv_global])
+    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
+    """
+        Placeholder: Data, Kernel_data
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+
+    # 3: two level cache_read with compute_at
+    #    preparing for GPU's shared memory & local memory
+    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
+    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
+    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
+    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])
+
+    # 4: cache_read with multi readers
+    #    This stage cannot be compute at to its consumer
+    s0.cache_read(data, "global", [pad_temp, add])
+    """
+        Placeholder: Data, Kernel_data
+        for ax0 (0,4)
+          for ax1 (0,512)
+            for ax2 (0,7)
+              for ax3 (0,7)
+                Data.global = ...
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global = ...
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  pad_temp.global.shared = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+    assert s1[pad_temp_global].iters[0].range.extent == 1
+    assert s1[pad_temp_global].iters[1].range.extent == 512
+    assert s1[pad_temp_global].iters[2].range.extent == 3
+    assert s1[pad_temp_global].iters[3].range.extent == 3
+    assert s1[pad_temp_shared].iters[0].range.extent == 1
+    assert s1[pad_temp_shared].iters[1].range.extent == 1
+    assert s1[pad_temp_shared].iters[2].range.extent == 3
+    assert s1[pad_temp_shared].iters[3].range.extent == 3
+
+    # 5: cache_write with multi outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generate one new stage:
+    #   From: kernel_data -> kernel_split -> kernel
+    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with topo sort analyse, we get:
+    #  //   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #         \                                                /
+    #          ----------------> kernel_split ---------------->
+    #
+    # TODO(jcf94): Seems there's bug with the input/output tensor. Such multi outputs case
+    # should be unusual, so we make some hack on DoCacheWrite. This should be fixed later.
+    kernel_split_global = s0.cache_write(kernel_split, "global")
+    """
+        Placeholder: Data, Kernel_data
+        for ax0 (0,4)
+          for ax1 (0,512)
+            for ax2 (0,7)
+              for ax3 (0,7)
+                Data.global = ...
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0_c (0,512)
+          for i1_c (0,512)
+            for i2_c (0,3)
+              for i3_c (0,3)
+                Kernel_split.global = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        (******* Bug here, there should not be two kernel_split stage *******)
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        (******* Bug here, there should not be two kernel_split stage *******)
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global = ...
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  pad_temp.global.shared = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
+    for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
+        assert it0.range == it1.range
+
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_compute_at_root_inline()
+    test_cache_read_write()
index 333d20e..5f2f87a 100644 (file)
@@ -35,7 +35,7 @@ def test_record():
     C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
     D = topi.nn.relu(C)
     k = te.reduce_axis((0, 512), name='k')
-    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C')
+    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
     F = topi.nn.relu(E)
 
     dag = auto_scheduler.ComputeDAG([A, B, F])
@@ -66,6 +66,11 @@ def test_record():
     s.unroll(C, s[C].iters[4])
     # Vectorize
     s.vectorize(C, s[C].iters[6])
+    # Cache Read
+    D_global = s.cache_read(D, "global", [E])
+    s.compute_at(D_global, E, s[E].iters[2])
+    # Cache Write
+    s.cache_write(D, "shared")
 
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)