*/
State InferBound(const State& state) const;
+ /*!
+ * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
+ * ComputeDAG may not be up-to-date. This function replays the given transform steps from the
+ * initial state and returns an up-to-date ComputeDAG.
+ * \param steps The steps to be replayed. Usually we'll filter out the unused steps to speed up
+ * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage
+ * structure.
+ * \return The up-to-date ComputeDAG.
+ */
+ ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;
+
TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
public:
/*!
* \brief Process the stage/iterator mapping after compute at.
- * \param stage_id The index of the stage to be compute at.
+ * \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
+
/*!
* \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
- * \param stage_id The index of the stage to be compute at.
+ * \param stage_id The index of the stage to be computed at.
*/
void DeleteStage(int stage_id);
+
/*!
* \brief Find the relations of original iterators in AttachMap, and update them with the new
* iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
void UpdateIters(const std::vector<IterKey>& original_iters,
const std::vector<IterKey>& new_iters);
+ /*!
+ * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
+ * to stage indexes that are larger than the start_id. Used for steps that insert new stages to
+ * ComputeDAG (e.g. CacheRead/CacheWrite step).
+ * \param start_id The index threshold, stage indexes in AttachMap which are larger than this
+ * will be applied the extra offset.
+ * \param offset The index offset to be added to the stage index.
+ * \return The updated AttachMap after applying stage index offset.
+ */
+ AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;
+
TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);
* operation.
*/
AttachMap attach_map;
+ /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt,
+ * which means no modification to the original ComputeDAG.
+ * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
+ * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
+ */
+ Optional<ObjectRef> current_compute_dag;
/*!
* \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
* tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
static constexpr const char* _type_key = "auto_scheduler.State";
TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
-
- private:
- /*!
- * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
- * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added
- * later).
- * The default value is an empty ObjectRef. (means no modification to the original DAG)
- */
- ObjectRef current_compute_dag;
};
/*!
/********** Step APIs working on single stage **********/
/*!
- * \brief Schedule primitive corresponds to te.bind.
+ * \brief Schedule primitive corresponds to `te::Stage::bind`.
* \param stage_id The index of the stage to be binded.
* \param it The iterator to be binded.
* \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
*/
TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
/*!
- * \brief Schedule primitive corresponds to te.parallel.
+ * \brief Schedule primitive corresponds to `te::Stage::parallel`.
* \param stage_id The index of the stage to be paralleled.
* \param it The iterator to be paralleled.
* \return The iterator result after parallel.
*/
TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
/*!
- * \brief Schedule primitive corresponds to te.unroll.
+ * \brief Schedule primitive corresponds to `te::Stage::unroll`.
* \param stage_id The index of the stage to be unrolled.
* \param it The iterator to be unrolled.
* \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
*/
TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
/*!
- * \brief Schedule primitive corresponds to te.vectorize.
+ * \brief Schedule primitive corresponds to `te::Stage::vectorize`.
* \param stage_id The index of the stage to be vectorized.
* \param it The iterator to be vectorized.
* \return The iterator result after vectorize.
*/
TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
/*!
- * \brief Schedule primitive corresponds to te.fuse.
+ * \brief Schedule primitive corresponds to `te::Stage::fuse`.
* \param stage_id The index of the stage to be fused.
* \param iters The iterators to be fused.
* \return The iterator result after fuse.
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
- * \brief Schedule primitive corresponds to te.reorder.
+ * \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
*/
TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
/*!
- * \brief Schedule primitive corresponds to te.split.
+ * \brief Schedule primitive corresponds to `te::Stage::split`.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
/********** Step APIs working on multiple stages **********/
/*!
- * \brief Schedule primitive corresponds to te.compute_at.
- * \param stage_id The index of the stage to be reordered.
+ * \brief Schedule primitive corresponds to `te::Stage::compute_at`.
+ * \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter The iterator in target stage that this step will compute at to.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
*/
TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
- * \brief Schedule primitive corresponds to te.compute_inline.
- * \param stage_id The index of the stage to be reordered.
+ * \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
+ * \param stage_id The index of the stage to be marked compute inlined.
*/
TVM_DLL void compute_inline(int stage_id);
/*!
- * \brief Schedule primitive corresponds to te.compute_root.
- * \param stage_id The index of the stage to be reordered.
+ * \brief Schedule primitive corresponds to `te::Stage::compute_root`.
+ * \param stage_id The index of the stage to be marked compute at root.
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
*/
TVM_DLL void compute_root(int stage_id);
+ /********** Step APIs adding new stages **********/
+
+ /*!
+ * \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
+ * \param stage_id The index of the stage to be cache read.
+ * \param scope_name The scope name of the newly added read stage.
+ * \param reader_stage_ids The indices of read stages.
+ * \param dag The original ComputeDAG of this state.
+ * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ */
+ int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
+ const ComputeDAG& dag);
+ /*!
+ * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
+ * \param stage_id The index of the stage to be cache write.
+ * \param scope_name The scope name of the newly added compute stage.
+ * \param dag The original ComputeDAG of this state.
+ * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ * This step will cache write all output tensors of the target stage.
+ */
+ int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+
TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
+/*!
+ * \brief Update the current stage IterVar information to StageToAxesMap.
+ * \param stage A te::Stage Object.
+ * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated.
+ */
+void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes);
+
/*! \brief The type of an iterator. */
enum class IteratorKind : int {
/*! \brief Spatial iterator. */
/*!
* \brief Apply the step to State.
* \param step The step to be applied to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
* \param dag The original ComputeDAG of this state.
*/
void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
/*!
* \brief Apply the step to tvm.schedule.
* \param step The step to be applied to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
+ * `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
*/
-void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes);
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule);
/*!
* \brief Print the step as equivalent python schedule API.
* \param step The step to be applied to python API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
+ * CacheRead/CacheWrite step)
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes);
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule);
/********** Steps working on single stage **********/
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
* \return The iterator result after annotate.
*/
Iterator ApplyToState(State* state) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
* \return The iterator result after fuse.
* \note If the iterators to be fused have stages attached at them(by compute_at), the fused
* result will become the new attach point.
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return The iterator result after fuse.
*/
tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
*/
void ApplyToState(State* state) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
* \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
* most iterator of split results will become the new attach point.
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages,
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be compute at.
+ * \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
+ * \param state A mutable pointer to state, which will be updated.
*/
void ApplyToState(State* state) const;
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return The iterator result after fuse.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be compute inline.
+ * \param stage_id The index of the stage to be marked compute inlined.
*/
explicit ComputeInlineStep(int stage_id);
/*!
* \brief Apply the current step to State.
- * \param state A mutable pointer to State.
- * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+ * \param state A mutable pointer to state, which will be updated.
+ * \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
/*!
* \brief Apply the current step to tvm.schedule.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return The iterator result after fuse.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
* \brief Print the current step as equivalent python schedule API.
- * \param stages A pointer to a `te::Stage` Array.
- * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be compute root
+ * \param stage_id The index of the stage to be marked compute at root.
*/
explicit ComputeRootStep(int stage_id);
TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Cache read step that corresponds to te::Schedule::cache_read.
+ * \note Cache read step will add an extra stage to the original ComputeDAG, an up-to-date
+ * is stored in State's `current_compute_dag`.
+ */
+class CacheReadStepNode : public StepNode {
+ public:
+ /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */
+ String scope_name;
+ /*! \brief The indices of read stages. */
+ Array<Integer> reader_stage_ids;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ * \param dag The original ComputeDAG of this state.
+ * \return The index of the newly added stage.
+ */
+ int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule.
+ * \return The output Tensor of the newly added stage.
+ */
+ te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const;
+
+ static constexpr const char* record_prefix_str = "CHR";
+
+ static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CacheReadStepNode.
+ * \sa CacheReadStepNode
+ */
+class CacheReadStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be cache read.
+ * \param scope_name The scope name of the newly added read stage.
+ * \param reader_stage_ids The indices of read stages.
+ */
+ CacheReadStep(int stage_id, String scope_name, const Array<Integer>& reader_stage_ids);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit CacheReadStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
+};
+
+/*!
+ * \brief Cache write step that corresponds to te::Schedule::cache_write.
+ * \note Cache write step will add an extra stage to the original ComputeDAG, an up-to-date
+ * ComputeDAG is stored in State's `current_compute_dag`.
+ * This step will cache write all output tensors of the target stage.
+ */
+class CacheWriteStepNode : public StepNode {
+ public:
+ /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */
+ String scope_name;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ * \param dag The original ComputeDAG of this state.
+ * \return The index of the newly added stage.
+ */
+ int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule.
+ * \return The output Tensors of the newly added stage.
+ */
+ Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param schedule A mutable pointer to a te::Schedule.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const;
+
+ static constexpr const char* record_prefix_str = "CHW";
+
+ static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CacheWriteStepNode.
+ * \sa CacheWriteStepNode
+ */
+class CacheWriteStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be cache write.
+ * \param scope_name The scope name of the newly added compute stage.
+ */
+ CacheWriteStep(int stage_id, String scope_name);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit CacheWriteStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
+};
+
} // namespace auto_scheduler
} // namespace tvm
Returns
-------
- state : State
+ updated_state : State
The State with complete bound information.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
- return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+ updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
+ # Copy the stage_id_map from the original state to make sure the old indices are still
+ # valid
+ if isinstance(state, State):
+ for k, v in state.stage_id_map.items():
+ updated_state.stage_id_map[k] = v
+ return updated_state
def __hash__(self):
# TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
return [stage.op for stage in self.stages]
def bind(self, stage, iterator, thread_name):
- """ Schedule primitive corresponds to te.bind.
+ """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more
+ details.
Parameters
----------
return res
def parallel(self, stage, iterator):
- """ Schedule primitive corresponds to te.parallel.
+ """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more
+ details.
Parameters
----------
return res
def unroll(self, stage, iterator, max_unroll=None):
- """ Schedule primitive corresponds to te.unroll.
+ """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more
+ details.
Parameters
----------
return res
def vectorize(self, stage, iterator):
- """ Schedule primitive corresponds to te.vectorize.
+ """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for
+ more details.
Parameters
----------
return res
def fuse(self, stage, iters):
- """ Schedule primitive corresponds to te.fuse.
+ """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more
+ details.
Parameters
----------
return res
def reorder(self, stage, order):
- """ Schedule primitive corresponds to te.reorder.
+ """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
+ details.
Parameters
----------
order)
def split(self, stage, iterator, lengths, inner_to_outer=True):
- """ Schedule primitive corresponds to te.split.
+ """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more
+ details.
This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
will be split to 3 parts, use `inner_to_outer` to control the split order)
return res
def compute_at(self, stage, target_stage, target_iter):
- """ Schedule primitive corresponds to te.compute_at.
+ """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
+ more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be compute at, which can be specified by the integer index, Operation,
+ The Stage to be computed at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
target_iter)
def compute_inline(self, stage):
- """ Schedule primitive corresponds to te.compute_inline.
+ """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage`
+ for more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be compute inlined, which can be specified by the integer index, Operation,
- or output tensor of the stage.
+ The Stage to be marked compute inlined, which can be specified by the integer index,
+ Operation, or output tensor of the stage.
"""
self.state_object = _ffi_api.StateComputeInline(self.state_object,
self._resolve_stage_id(stage))
def compute_root(self, stage):
- """ Schedule primitive corresponds to te.compute_root.
+ """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for
+ more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be compute root, which can be specified by the integer index, Operation,
- or output tensor of the stage.
+ The Stage to be marked compute at root, which can be specified by the integer index,
+ Operation, or output tensor of the stage.
Notes
-----
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))
+ def cache_read(self, stage, scope_name, reader_stages):
+ """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule`
+ for more details.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be cache read, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ scope_name : str
+ The scope name of the newly added read stage.
+ reader_stages : List[Union[int, Operation, Tensor]]
+ The reader stages. Each of the list can be specified by the integer index, Operation,
+ or output tensor of the stage.
+
+ Returns
+ -------
+ new_stage_op : Operator
+ The Operator of the newly added stage.
+
+ Notes
+ -----
+ Cache read step will insert an extra stage to the original ComputeDAG (at the back of the
+ target stage).
+ """
+ reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
+ self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
+ self._resolve_stage_id(stage),
+ scope_name, reader_stage_ids,
+ self.compute_dag)
+ # Adding a new stage will change all ops behind the added stage. But we still want to keep the
+ # original ops map, apply stage id offset to stage_id_map to make them work.
+ self._apply_stage_id_offset(int(new_stage_id))
+ self._update_stage_id_map()
+ return self.stages[int(new_stage_id)].op
+
+ def cache_write(self, stage, scope_name):
+ """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule`
+ for more details.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be cache write, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ scope_name : str
+ The scope name of the newly added compute stage.
+
+ Returns
+ -------
+ new_stage_op : Operator
+ The Operator of the newly added stage.
+
+ Notes
+ -----
+ Cache write step will insert an extra stage to the original ComputeDAG (in the front of the
+ target stage).
+ This step will cache write all output tensors of the target stage.
+ """
+ self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object,
+ self._resolve_stage_id(stage),
+ scope_name, self.compute_dag)
+ # Adding a new stage will change all ops behind the added stage. But we still want to keep the
+ # original ops map, apply stage id offset to stage_id_map to make them work.
+ self._apply_stage_id_offset(int(new_stage_id))
+ self._update_stage_id_map()
+ return self.stages[int(new_stage_id)].op
+
def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
for index, stage in enumerate(self.stages):
self.stage_id_map[stage.op] = index
+ def _apply_stage_id_offset(self, start_id, offset=1):
+ for key, value in self.stage_id_map.items():
+ if value >= start_id:
+ self.stage_id_map[key] = value + offset
+
def __getitem__(self, key):
if isinstance(key, Tensor):
key = key.op
data_ = std::move(node);
}
-// Update the te::stage to tir::IterVar axis mapping
-void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
- if (auto pop = stage->op.as<te::ComputeOpNode>()) {
- Array<IterVar> axes;
- for (const auto& axis : pop->axis) {
- axes.push_back(axis);
- }
- for (const auto& axis : pop->reduce_axis) {
- axes.push_back(axis);
- }
- stage_to_axes->Set(stage, std::move(axes));
- } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
- {} // do nothing on Placeholder
- } else {
- LOG(FATAL) << "Invalid op " << stage->op;
- }
-}
-
std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
- StepApplyToSchedule(step, stages, stage_to_axes);
+ StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
}
return std::make_pair(schedule, operator->()->tensors);
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
- ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
+ ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
}
return ss.str();
return ret_state;
}
+ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
+ te::Schedule sch;
+ Array<te::Tensor> old_tensors;
+ std::tie(sch, old_tensors) = ApplySteps(transform_steps);
+
+ Array<te::Tensor> new_tensors;
+ for (auto stage : sch->stages) {
+ if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
+ for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+ new_tensors.push_back(stage->op.output(i));
+ }
+ }
+ }
+
+ return ComputeDAG(new_tensors);
+}
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ComputeDAGNode*>(ref.get());
* see auto_scheduler/loop_state.h for more explanation.
*/
+#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
}
}
+AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
+ AttachMap map = AttachMap(make_object<AttachMapNode>());
+ auto pmap = map.CopyOnWrite();
+ for (const auto& x : operator->()->stage_to_attach_iter) {
+ auto key = x.first;
+ if (key >= start_id) {
+ key += offset;
+ }
+ auto value = x.second;
+ if (value.first >= start_id) {
+ value.first += offset;
+ }
+ pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
+ }
+ for (const auto& x : operator->()->iter_to_attached_stages) {
+ auto key = x.first;
+ if (key.first >= start_id) {
+ key.first += offset;
+ }
+ auto value = x.second;
+ for (auto& i : value) {
+ if (i >= start_id) {
+ i += offset;
+ }
+ }
+ pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
+ }
+ return map;
+}
+
/********** State **********/
State::State(const Array<te::Operation>& ops) {
auto node = make_object<StateNode>();
step->ApplyToState(this);
}
+int State::cache_read(int stage_id, const String& scope_name,
+ const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
+ CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this, dag);
+}
+
+int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
+ CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this, dag);
+}
+
void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
return state;
});
+TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
+ .set_body_typed([](State state, int stage_id, const String& scope_name,
+ const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
+ int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
+ return Array<ObjectRef>{state, Integer(res)};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
+ .set_body_typed([](State state, int stage_id, const String& scope_name,
+ const ComputeDAG& task_dag) {
+ int res = state.cache_write(stage_id, scope_name, task_dag);
+ return Array<ObjectRef>{state, Integer(res)};
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
return std::equal_to<State>()(state1, state2);
});
* They are similar to the schedule primitives in te::Stage.
*/
+#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace auto_scheduler {
+// Update the te::stage to tir::IterVar axis mapping
+void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
+ if (auto pop = stage->op.as<te::ComputeOpNode>()) {
+ Array<IterVar> axes;
+ for (const auto& axis : pop->axis) {
+ axes.push_back(axis);
+ }
+ for (const auto& axis : pop->reduce_axis) {
+ axes.push_back(axis);
+ }
+ stage_to_axes->Set(stage, std::move(axes));
+ } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
+ {} // do nothing on Placeholder
+ } else {
+ LOG(FATAL) << "Invalid op " << stage->op;
+ }
+}
+
const char* IteratorAnnotationString[] = {
"for", // kNone = 0
"unroll", // kUnroll = 1
return ComputeInlineStep(reader);
} else if (name == ComputeRootStepNode::record_prefix_str) {
return ComputeRootStep(reader);
+ } else if (name == CacheReadStepNode::record_prefix_str) {
+ return CacheReadStep(reader);
+ } else if (name == CacheWriteStepNode::record_prefix_str) {
+ return CacheWriteStep(reader);
} else {
LOG(FATAL) << "Invalid step format: " << name;
}
ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeRootStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<CacheReadStepNode>()) {
+ ps->ApplyToState(state, dag);
+ } else if (auto ps = step.as<CacheWriteStepNode>()) {
+ ps->ApplyToState(state, dag);
} else {
LOG(FATAL) << "Invalid step: " << step;
}
}
-void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) {
- // We need this runtime dispatcher because different steps have different function signatures
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) {
if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeRootStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<CacheReadStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, schedule);
+ } else if (auto ps = step.as<CacheWriteStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, schedule);
} else {
LOG(FATAL) << "Invalid Step: " << step;
}
}
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) {
- // We need this runtime dispatcher because different steps have different function signatures
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule) {
if (auto ps = step.as<AnnotationStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeRootStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<CacheReadStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
+ } else if (auto ps = step.as<CacheWriteStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
} else {
LOG(FATAL) << "Invalid Step: " << step;
}
return ss.str();
}
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages (e.g. CacheReadStep, CacheWriteStep,
+ * RfactorStep). This returns all steps that can change the number of stages in a ComputeDAG,
+ * stopping at the current step.
+ */
+Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
+ Array<Step> ret_steps;
+ for (const Step& step : transform_steps) {
+ if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
+ ret_steps.push_back(step);
+ }
+ // TODO(jcf94): add rfactor support
+ // A state may have multiple stage-modifiable steps, so stop at the current step to avoid
+ // replaying excess steps
+ if (step.same_as(current_step)) {
+ break;
+ }
+ }
+ return ret_steps;
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+ const Array<Integer>& reader_stage_ids) {
+ auto node = make_object<CacheReadStepNode>();
+ node->stage_id = stage_id;
+ node->scope_name = std::move(scope_name);
+ node->reader_stage_ids = reader_stage_ids;
+ data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+ auto node = make_object<CacheReadStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::string string_value;
+ reader->Read(&string_value);
+ node->scope_name = std::move(string_value);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::vector<int> int_list;
+ reader->Read(&int_list);
+ Array<Integer> reader_stage_ids;
+ for (int i : int_list) {
+ reader_stage_ids.push_back(i);
+ }
+ node->reader_stage_ids = std::move(reader_stage_ids);
+ data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArraySeperator();
+ writer->WriteString(scope_name);
+ writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+ StateNode* pstate = state->CopyOnWrite();
+ const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
+ GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+
+ // target_stage -> target_stage + target_store
+ // Update the op of the target stage, insert a new cache read stage behind, update the op of
+ // later stages, then update the stage_id mapping in AttachMap
+ int added_stage_id = stage_id + 1;
+ Stage tmp_stage = pstate->stages[stage_id];
+ tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+ pstate->stages.Set(stage_id, std::move(tmp_stage));
+ pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+ Stage(current_compute_dag->ops[added_stage_id]));
+ for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+ tmp_stage = pstate->stages[i];
+ tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+ pstate->stages.Set(i, std::move(tmp_stage));
+ }
+ pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+ pstate->current_compute_dag = std::move(current_compute_dag);
+
+ return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ const te::Stage& stage = (*stages)[stage_id];
+ Array<te::Operation> readers;
+ for (const auto& i : reader_stage_ids) {
+ readers.push_back((*stages)[i]->origin_op);
+ }
+ auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);
+
+ const auto& new_stage = (*schedule)[out->op];
+ UpdateStageToAxesMap(new_stage, stage_to_axes);
+ stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+ return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ std::stringstream ss;
+ // Since the original stage will be changed after the schedule is applied, keep a copy here.
+ // This information will be used to print the Python API string later.
+ auto stage = (*stages)[stage_id];
+ Array<te::Stage> reader_stages;
+ for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+ reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+ }
+ auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+ ss << CleanName(out->op->name) << " = "
+ << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
+ << CleanName(reader_stages[0]->op->name);
+ for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+ ss << ", " << CleanName(reader_stages[i]->op->name);
+ }
+ ss << "])\n";
+
+ // Print the iterators of the new added stage
+ const auto& iters = out->op->root_iter_vars();
+ for (size_t i = 0; i < iters.size(); ++i) {
+ ss << CleanName(iters[i]->var->name_hint);
+ if (i != iters.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+ return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+ auto node = make_object<CacheWriteStepNode>();
+ node->stage_id = stage_id;
+ node->scope_name = std::move(scope_name);
+ data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+ auto node = make_object<CacheWriteStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::string string_value;
+ reader->Read(&string_value);
+ node->scope_name = std::move(string_value);
+ data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArraySeperator();
+ writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+ StateNode* pstate = state->CopyOnWrite();
+ int last_dag_op_size = pstate->current_compute_dag
+ ? pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size()
+ : dag->ops.size();
+ const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
+ GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+ int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+ // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM
+ CHECK_GE(added_ops, 1);
+
+ // target_stage -> cache_write_stage + target_stage
+ // Assume no step has been applied to the target stage before cache write.
+ // Insert a new cache write stage ahead, update the op of the target stage and later stages, then
+ // update the stage_id mapping in AttachMap
+ pstate->stages.insert(pstate->stages.begin() + stage_id,
+ Stage(current_compute_dag->ops[stage_id]));
+ pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1]));
+ int next_stage_id = stage_id + 2;
+ // TODO(jcf94): Fix the cache write bug in TVM and remove added_ops == 2 support.
+ // TVM's cache_write has a bug with multi outputs. See
+ // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test
+ // for more details
+ if (added_ops == 2) {
+ pstate->stages.insert(pstate->stages.begin() + next_stage_id,
+ Stage(current_compute_dag->ops[next_stage_id]));
+ next_stage_id++;
+ } else if (added_ops > 2) {
+ LOG(ERROR) << "Unexpected behavior of CacheWrite.";
+ }
+ for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) {
+ Stage tmp_stage = pstate->stages[i];
+ tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+ pstate->stages.Set(i, std::move(tmp_stage));
+ }
+ pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops);
+ pstate->current_compute_dag = std::move(current_compute_dag);
+
+ return stage_id;
+}
+
+Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ const te::Stage& stage = (*stages)[stage_id];
+ Array<te::Tensor> tensor_array;
+ // If the target stage has multiple outputs, TVM requires cache_write to be applied
+ // to all of them, or schedule.cache_write will raise an error
+ for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+ tensor_array.push_back(stage->origin_op.output(i));
+ }
+ auto outs = schedule->cache_write(tensor_array, scope_name);
+
+ UpdateStageToAxesMap(stage, stage_to_axes);
+ // Even if there are multiple outputs, the TVM schedule only generates one
+ // new stage
+ const auto& new_stage = (*schedule)[outs[0]->op];
+ UpdateStageToAxesMap(new_stage, stage_to_axes);
+ stages->insert(stages->begin() + stage_id, new_stage);
+
+ return outs;
+}
+
+String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ std::stringstream ss;
+ // Since the original stage will be changed after the schedule is applied, keep a copy here.
+ // This information will be used to print the Python API string later.
+ te::Stage stage = (*stages)[stage_id];
+ auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+ for (size_t i = 0; i < outs.size(); ++i) {
+ ss << CleanName(outs[i]->op->name) << ", ";
+ }
+ ss << "= "
+ << "s.cache_write([" << CleanName(stage->op.output(0)->op->name);
+ for (auto i = 1; i < stage->op->num_outputs(); ++i) {
+ ss << ", " << CleanName(stage->op.output(i)->op->name);
+ }
+ ss << "], \"" << scope_name << "\")\n";
+
+ // Print the iterators of the new added stage
+ for (const auto& out : outs) {
+ const auto& iters = out->op->root_iter_vars();
+ for (size_t i = 0; i < iters.size(); ++i) {
+ ss << CleanName(iters[i]->var->name_hint);
+ if (i != iters.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+ << " + "
+ << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+ }
+
+ return ss.str();
+}
+
} // namespace auto_scheduler
} // namespace tvm
assert s0[conv].iters[6].range.extent == 7
+def test_cache_read_write():
+ N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
+ 1, 1), (1, 1)
+
+ data = te.placeholder((N, CI, H, W), name='Data')
+ kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
+ k0, k1 = te.compute(kernel_data.shape,
+ lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
+ name='Kernel_split')
+ kernel = te.compute(kernel_data.shape,
+ lambda *i: k0(*i) + k1(*i),
+ name='Kernel')
+ conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
+ relu = topi.nn.relu(conv)
+ add = topi.add(data, relu)
+
+ dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
+ s0 = dag.get_init_state()
+
+ pad_temp = s0.stage_ops[1]
+ kernel_split = s0.stage_ops[3]
+
+ # 0: init state
+ ori_its = s0[add].iters
+ its = s0.split(add, s0[add].iters[0], [2])
+ s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
+ s0.compute_inline(relu)
+
+ # 1: simple cache_write with compute_at
+ conv_global = s0.cache_write(conv, "global")
+ s0.compute_at(conv_global, conv, s0[conv].iters[3])
+
+ # 2: simple cache_read with compute_at
+ kernel_global = s0.cache_read(kernel, "global", [conv_global])
+ s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
+ """
+ Placeholder: Data, Kernel_data
+ for i0 (0,4)
+ for i1 (0,512)
+ for i2 (0,9)
+ for i3 (0,9)
+ pad_temp = ...
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel_split = ...
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel = ...
+ for nn (0,4)
+ for ff (0,512)
+ for yy (0,7)
+ for xx (0,7)
+ for nn_c (None)
+ for ff_c (None)
+ for yy_c (None)
+ for xx_c (None)
+ for rc (None)
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ Kernel.global = ...
+ for ry (None)
+ for rx (None)
+ compute.global = ...
+ compute = ...
+ for ax0.0 (0,2)
+ for ax1 (0,512)
+ for ax0.1 (0,2)
+ for ax2 (0,7)
+ for ax3 (0,7)
+ T_add = ...
+ """
+ s1 = dag.infer_bound_from_state(s0)
+ assert s1[conv].iters[0].range.extent == 4
+ assert s1[conv].iters[1].range.extent == 512
+ assert s1[conv].iters[2].range.extent == 7
+ assert s1[conv].iters[3].range.extent == 7
+ assert s1[kernel_global].iters[0].range.extent == 1
+ assert s1[kernel_global].iters[1].range.extent == 1
+ assert s1[kernel_global].iters[2].range.extent == 3
+ assert s1[kernel_global].iters[3].range.extent == 3
+ assert s1[conv_global].iters[0].range.extent == 1
+ assert s1[conv_global].iters[1].range.extent == 1
+ assert s1[conv_global].iters[2].range.extent == 1
+ assert s1[conv_global].iters[3].range.extent == 1
+ assert s1[conv_global].iters[4].range.extent == 512
+ assert s1[conv_global].iters[5].range.extent == 3
+ assert s1[conv_global].iters[6].range.extent == 3
+
+ # 3: two level cache_read with compute_at
+ # preparing for GPU's shared memory & local memory
+ pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
+ pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
+ s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
+ s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])
+
+ # 4: cache_read with multi readers
+ # This stage cannot be compute at to its consumer
+ s0.cache_read(data, "global", [pad_temp, add])
+ """
+ Placeholder: Data, Kernel_data
+ for ax0 (0,4)
+ for ax1 (0,512)
+ for ax2 (0,7)
+ for ax3 (0,7)
+ Data.global = ...
+ for i0 (0,4)
+ for i1 (0,512)
+ for i2 (0,9)
+ for i3 (0,9)
+ pad_temp = ...
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel_split = ...
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel = ...
+ for nn (0,4)
+ for ff (0,512)
+ for yy (0,7)
+ for xx (0,7)
+ for nn_c (None)
+ for ff_c (None)
+ for yy_c (None)
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ pad_temp.global = ...
+ for xx_c (None)
+ for rc (None)
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ Kernel.global = ...
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ pad_temp.global.shared = ...
+ for ry (None)
+ for rx (None)
+ compute.global = ...
+ compute = ...
+ for ax0.0 (0,2)
+ for ax1 (0,512)
+ for ax0.1 (0,2)
+ for ax2 (0,7)
+ for ax3 (0,7)
+ T_add = ...
+ """
+ s1 = dag.infer_bound_from_state(s0)
+ assert s1[conv].iters[0].range.extent == 4
+ assert s1[conv].iters[1].range.extent == 512
+ assert s1[conv].iters[2].range.extent == 7
+ assert s1[conv].iters[3].range.extent == 7
+ assert s1[kernel_global].iters[0].range.extent == 1
+ assert s1[kernel_global].iters[1].range.extent == 1
+ assert s1[kernel_global].iters[2].range.extent == 3
+ assert s1[kernel_global].iters[3].range.extent == 3
+ assert s1[conv_global].iters[0].range.extent == 1
+ assert s1[conv_global].iters[1].range.extent == 1
+ assert s1[conv_global].iters[2].range.extent == 1
+ assert s1[conv_global].iters[3].range.extent == 1
+ assert s1[conv_global].iters[4].range.extent == 512
+ assert s1[conv_global].iters[5].range.extent == 3
+ assert s1[conv_global].iters[6].range.extent == 3
+ assert s1[pad_temp_global].iters[0].range.extent == 1
+ assert s1[pad_temp_global].iters[1].range.extent == 512
+ assert s1[pad_temp_global].iters[2].range.extent == 3
+ assert s1[pad_temp_global].iters[3].range.extent == 3
+ assert s1[pad_temp_shared].iters[0].range.extent == 1
+ assert s1[pad_temp_shared].iters[1].range.extent == 1
+ assert s1[pad_temp_shared].iters[2].range.extent == 3
+ assert s1[pad_temp_shared].iters[3].range.extent == 3
+
+ # 5: cache_write with multi outputs
+ # TVM's cache_write actually has a bug with this case:
+ #
+ # After schedule.cache_write, TVM generate one new stage:
+ # From: kernel_data -> kernel_split -> kernel
+ # To: kernel_data -> kernel_split_global -> kernel_split -> kernel
+ #
+ # But with topo sort analysis, we get:
+ # // kernel_data -> kernel_split_global -> kernel_split -> kernel
+ # \ /
+ # ----------------> kernel_split ---------------->
+ #
+ # TODO(jcf94): Seems there's bug with the input/output tensor. Such multi outputs case
+ # should be unusual, so we make some hack on DoCacheWrite. This should be fixed later.
+ kernel_split_global = s0.cache_write(kernel_split, "global")
+ """
+ Placeholder: Data, Kernel_data
+ for ax0 (0,4)
+ for ax1 (0,512)
+ for ax2 (0,7)
+ for ax3 (0,7)
+ Data.global = ...
+ for i0 (0,4)
+ for i1 (0,512)
+ for i2 (0,9)
+ for i3 (0,9)
+ pad_temp = ...
+ for i0_c (0,512)
+ for i1_c (0,512)
+ for i2_c (0,3)
+ for i3_c (0,3)
+ Kernel_split.global = ...
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel_split = ...
+ (******* Bug here, there should not be two kernel_split stage *******)
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel_split = ...
+ (******* Bug here, there should not be two kernel_split stage *******)
+ for i0 (0,512)
+ for i1 (0,512)
+ for i2 (0,3)
+ for i3 (0,3)
+ Kernel = ...
+ for nn (0,4)
+ for ff (0,512)
+ for yy (0,7)
+ for xx (0,7)
+ for nn_c (None)
+ for ff_c (None)
+ for yy_c (None)
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ pad_temp.global = ...
+ for xx_c (None)
+ for rc (None)
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ Kernel.global = ...
+ for ax0 (None)
+ for ax1 (None)
+ for ax2 (None)
+ for ax3 (None)
+ pad_temp.global.shared = ...
+ for ry (None)
+ for rx (None)
+ compute.global = ...
+ compute = ...
+ for ax0.0 (0,2)
+ for ax1 (0,512)
+ for ax0.1 (0,2)
+ for ax2 (0,7)
+ for ax3 (0,7)
+ T_add = ...
+ """
+ assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
+ for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
+ assert it0.range == it1.range
+
if __name__ == "__main__":
test_split_fuse_reorder_annotation()
test_compute_at_root_inline()
+ test_cache_read_write()
C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
D = topi.nn.relu(C)
k = te.reduce_axis((0, 512), name='k')
- E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C')
+ E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
F = topi.nn.relu(E)
dag = auto_scheduler.ComputeDAG([A, B, F])
s.unroll(C, s[C].iters[4])
# Vectorize
s.vectorize(C, s[C].iters[6])
+ # Cache Read
+ D_global = s.cache_read(D, "global", [E])
+ s.compute_at(D_global, E, s[E].iters[2])
+ # Cache Write
+ s.cache_write(D, "shared")
target = tvm.target.create("llvm")
task = auto_scheduler.SearchTask(dag, "test", target)