From b8f8b8d9a804f5867daddd9c5d7ddc0ed4d0d199 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 27 Jul 2020 12:35:52 +0800 Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps (#6107) * Add cache_read/cache_write step * Update * Update * Update * Update state->current_compute_dag to Optional * Update * Update doc * Update * Update * Doc update * Update --- include/tvm/auto_scheduler/compute_dag.h | 11 + include/tvm/auto_scheduler/loop_state.h | 82 ++++-- include/tvm/auto_scheduler/transform_step.h | 248 +++++++++++++--- python/tvm/auto_scheduler/compute_dag.py | 10 +- python/tvm/auto_scheduler/loop_state.py | 113 +++++++- src/auto_scheduler/compute_dag.cc | 39 ++- src/auto_scheduler/loop_state.cc | 58 ++++ src/auto_scheduler/transform_step.cc | 313 ++++++++++++++++++++- .../unittest/test_auto_scheduler_loop_state.py | 275 ++++++++++++++++++ .../python/unittest/test_auto_scheduler_measure.py | 7 +- 10 files changed, 1043 insertions(+), 113 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 71652fd..69b74bf 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -238,6 +238,17 @@ class ComputeDAG : public ObjectRef { */ State InferBound(const State& state) const; + /*! + * \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial + * ComputeDAG may not be up-to-date. This function replays the given transform steps from the + * initial state and returns an up-to-date ComputeDAG. + * \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up + * the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage + * structure. + * \return The up-to-date ComputeDAG. + */ + ComputeDAG ReplayAndGetDAG(const Array& steps) const; + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); }; diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 4e9cb9b..1c8ea77 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -182,16 +182,18 @@ class AttachMap : public ObjectRef { public: /*! * \brief Process the stage/iterator mapping after compute at. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + /*! * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. */ void DeleteStage(int stage_id); + /*! * \brief Find the relations of original iterators in AttachMap, and update them with the new * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated. @@ -201,6 +203,17 @@ class AttachMap : public ObjectRef { void UpdateIters(const std::vector& original_iters, const std::vector& new_iters); + /*! + * \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset + * to stage indexes that are larger than the start_id. Used for steps that insert new stages to + * ComputeDAG(e.g. 
CacheRead/CacheWrite step). + * \param start_id The index threshold, stage indexes in AttachMap which are larger than this + * will be applied the extra offset. + * \param offset The index offset to be added to the stage index. + * \return The updated AttachMap after applying stage index offset. + */ + AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const; + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); @@ -231,6 +244,12 @@ class StateNode : public Object { * operation. */ AttachMap attach_map; + /*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means + * no modification to the original ComputeDAG. + * Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the + * ComputeDAG, the stored value is the up-to-date ComputeDAG for this state. + */ + Optional current_compute_dag; /*! * \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all * tile sizes of the state is filled. Only concrete state can be apply to TVM schedule. @@ -245,15 +264,6 @@ class StateNode : public Object { static constexpr const char* _type_key = "auto_scheduler.State"; TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); - - private: - /*! - * \brief The up-to-date ComputeDAG of this state, used for some steps that may change the - * stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added - * later). - * The default value is an empty ObjectRef. (means no modification to the original DAG) - */ - ObjectRef current_compute_dag; }; /*! @@ -290,7 +300,7 @@ class State : public ObjectRef { /********** Step APIs working on single stage **********/ /*! - * \brief Schedule primitive corresponds to te.bind. + * \brief Schedule primitive corresponds to `te::Stage::bind`. * \param stage_id The index of the stage to be binded. * \param it The iterator to be binded. * \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as @@ -299,14 +309,14 @@ class State : public ObjectRef { */ TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! - * \brief Schedule primitive corresponds to te.parallel. + * \brief Schedule primitive corresponds to `te::Stage::parallel`. * \param stage_id The index of the stage to be paralleled. * \param it The iterator to be paralleled. * \return The iterator result after parallel. */ TVM_DLL Iterator parallel(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to te.unroll. + * \brief Schedule primitive corresponds to `te::Stage::unroll`. * \param stage_id The index of the stage to be unrolled. * \param it The iterator to be unrolled. * \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be @@ -315,14 +325,14 @@ class State : public ObjectRef { */ TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! - * \brief Schedule primitive corresponds to te.vectorize. + * \brief Schedule primitive corresponds to `te::Stage::vectorize`. * \param stage_id The index of the stage to be vectorized. * \param it The iterator to be vectorized. * \return The iterator result after vectorize. */ TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); /*! - * \brief Schedule primitive corresponds to te.fuse. + * \brief Schedule primitive corresponds to `te::Stage::fuse`. * \param stage_id The index of the stage to be fused. 
* \param iters The iterators to be fused. * \return The iterator result after fuse. @@ -331,13 +341,13 @@ class State : public ObjectRef { */ TVM_DLL Iterator fuse(int stage_id, const Array& iters); /*! - * \brief Schedule primitive corresponds to te.reorder. + * \brief Schedule primitive corresponds to `te::Stage::reorder`. * \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. */ TVM_DLL void reorder(int stage_id, const Array& order); /*! - * \brief Schedule primitive corresponds to te.split. + * \brief Schedule primitive corresponds to `te::Stage::split`. * \param stage_id The index of the stage to be split. * \param it The iterator to be split. * \param lengths The multiple split factors. Can be None to be filled by search policy. @@ -353,8 +363,8 @@ class State : public ObjectRef { /********** Step APIs working on multiple stages **********/ /*! - * \brief Schedule primitive corresponds to te.compute_at. - * \param stage_id The index of the stage to be reordered. + * \brief Schedule primitive corresponds to `te::Stage::compute_at`. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter The iterator in target stage that this step will compute at to. * \note After compute_at, we need careful dependency analysis to compute the accurate bound @@ -364,13 +374,13 @@ class State : public ObjectRef { */ TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! - * \brief Schedule primitive corresponds to te.compute_inline. - * \param stage_id The index of the stage to be reordered. + * \brief Schedule primitive corresponds to `te::Stage::compute_inline`. + * \param stage_id The index of the stage to be marked compute inlined. */ TVM_DLL void compute_inline(int stage_id); /*! - * \brief Schedule primitive corresponds to te.compute_root. - * \param stage_id The index of the stage to be reordered. + * \brief Schedule primitive corresponds to `te::Stage::compute_root`. + * \param stage_id The index of the stage to be marked compute at root. * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -378,6 +388,30 @@ class State : public ObjectRef { */ TVM_DLL void compute_root(int stage_id); + /********** Step APIs adding new stages **********/ + + /*! + * \brief Schedule primitive corresponds to `te::Schedule::cache_read`. + * \param stage_id The index of the stage to be cache read. + * \param scope_name The scope name of the newly added read stage. + * \param reader_stage_ids The indices of read stages. + * \param dag The original ComputeDAG of this state. + * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the + * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + */ + int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, + const ComputeDAG& dag); + /*! + * \brief Schedule primitive corresponds to `te::Schedule::cache_write`. + * \param stage_id The index of the stage to be cache write. + * \param scope_name The scope name of the newly added compute stage. + * \param dag The original ComputeDAG of this state. 
+ * \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the + * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + * This step will cache write all output tensors of the target stage. + */ + int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); }; diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index b23137a..83d6e29 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -56,6 +56,13 @@ namespace auto_scheduler { typedef Map, ObjectHash, ObjectEqual> StageToAxesMap; +/*! + * \brief Update the current stage IterVar information to StageToAxesMap. + * \param stage A te::Stage Object. + * \param stage_to_axes A mutable pointer to StageToAxesMap, this map will be updated. + */ +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes); + /*! \brief The type of an iterator. */ enum class IteratorKind : int { /*! \brief Spatial iterator. */ @@ -183,7 +190,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); /*! * \brief Apply the step to State. * \param step The step to be applied to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \param dag The original ComputeDAG of this state. */ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); @@ -191,20 +198,25 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); /*! * \brief Apply the step to tvm.schedule. * \param step The step to be applied to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need + * `te::Schedule` API. (e.g. CacheRead/CacheWrite step) */ -void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); +void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule); /*! * \brief Print the step as equivalent python schedule API. * \param step The step to be applied to python API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g. + * CacheRead/CacheWrite step) * \return Python schedule code. */ String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes); + StageToAxesMap* stage_to_axes, te::Schedule* schedule); /********** Steps working on single stage **********/ @@ -223,22 +235,22 @@ class AnnotationStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator result after annotate. */ Iterator ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. 
- * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -283,7 +295,7 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator result after fuse. * \note If the iterators to be fused have stages attached at them(by compute_at), the fused * result will become the new attach point. @@ -292,16 +304,16 @@ class FuseStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. */ tir::IterVar ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -348,21 +360,21 @@ class ReorderStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -418,7 +430,7 @@ class SplitStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \return The iterator results after split. * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. @@ -427,8 +439,8 @@ class SplitStepNode : public StepNode { /*! 
* \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator results after split. */ Array ApplyToSchedule(Array* stages, @@ -436,8 +448,8 @@ class SplitStepNode : public StepNode { /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -489,7 +501,7 @@ class ComputeAtStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. * \note After compute_at, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. @@ -499,15 +511,15 @@ class ComputeAtStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -526,7 +538,7 @@ class ComputeAtStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute at. + * \param stage_id The index of the stage to be computed at. * \param target_stage_id The index of stage that this step will compute at to. * \param target_iter_id The index of iterator in target stage that this step will compute at to. */ @@ -549,22 +561,22 @@ class ComputeInlineStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. + * \param state A mutable pointer to state, which will be updated. */ void ApplyToState(State* state) const; /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. 
* \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -583,7 +595,7 @@ class ComputeInlineStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute inline. + * \param stage_id The index of the stage to be marked compute inlined. */ explicit ComputeInlineStep(int stage_id); @@ -604,8 +616,8 @@ class ComputeRootStepNode : public StepNode { /*! * \brief Apply the current step to State. - * \param state A mutable pointer to State. - * \note After compute_at, we need careful dependency analysis to compute the accurate bound + * \param state A mutable pointer to state, which will be updated. + * \note After compute_root, we need careful dependency analysis to compute the accurate bound * information. However, it is relatively expensive and complicated, so we just fill "None" as * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. @@ -614,16 +626,16 @@ class ComputeRootStepNode : public StepNode { /*! * \brief Apply the current step to tvm.schedule. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return The iterator result after fuse. */ void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; /*! * \brief Print the current step as equivalent python schedule API. - * \param stages A pointer to a `te::Stage` Array. - * \param stage_to_axes A pointer to a StageToAxesMap. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. * \return Python schedule code. */ String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; @@ -642,7 +654,7 @@ class ComputeRootStep : public Step { public: /*! * \brief The constructor. - * \param stage_id The index of the stage to be compute root + * \param stage_id The index of the stage to be marked compute at root. */ explicit ComputeRootStep(int stage_id); @@ -656,6 +668,150 @@ class ComputeRootStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; +/********** Primitives adding new stages **********/ + +/*! + * \brief Cache read step that corresponds to te::Schedule::cache_read. + * \note Cache read step will add an extra stage to the original ComputeDAG, a up-to-date ComputeDAG + * is stored in State's `current_compute_dag`. + */ +class CacheReadStepNode : public StepNode { + public: + /*! \brief The scope name of the newly added read stage. (e.g. local, shared, global) */ + String scope_name; + /*! \brief The indices of read stages. */ + Array reader_stage_ids; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to state, which will be updated. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensor of the new added stage. 
+ */ + te::Tensor ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "CHR"; + + static constexpr const char* _type_key = "auto_scheduler.CacheReadStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); +}; + +/*! + * \brief Managed reference to CacheReadStepNode. + * \sa CacheReadStepNode + */ +class CacheReadStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be cache read. + * \param scope_name The scope name of the newly added read stage. + * \param reader_stage_ids The indices of read stages. + */ + CacheReadStep(int stage_id, String scope_name, const Array& reader_stage_ids); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit CacheReadStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); +}; + +/*! + * \brief Cache write step that corresponds to te::Schedule::cache_write. + * \note Cache write step will add an extra stage to the original ComputeDAG, a up-to-date + * ComputeDAG is stored in State's `current_compute_dag`. + * This step will cache write all output tensors of the target stage. + */ +class CacheWriteStepNode : public StepNode { + public: + /*! \brief The scope name of the newly added compute stage. (e.g. local, shared, global) */ + String scope_name; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to state, which will be updated. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensors of the new added stage. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "CHW"; + + static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); +}; + +/*! + * \brief Managed reference to CacheWriteStepNode. + * \sa CacheWriteStepNode + */ +class CacheWriteStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be cache write. 
+ * \param scope_name The scope name of the newly added compute stage. + */ + CacheWriteStep(int stage_id, String scope_name); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit CacheWriteStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index 115d28b..e08454f 100644 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -126,11 +126,17 @@ class ComputeDAG(Object): Returns ------- - state : State + updated_state : State The State with complete bound information. """ state_obj = state if isinstance(state, StateObject) else state.state_object - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + # Copy the stage_id_map from the original state to make sure the old indices are still + # valid + if isinstance(state, State): + for k, v in state.stage_id_map.items(): + updated_state.stage_id_map[k] = v + return updated_state def __hash__(self): # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index ab041cf..8c3a936 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -127,7 +127,8 @@ class State: return [stage.op for stage in self.stages] def bind(self, stage, iterator, thread_name): - """ Schedule primitive corresponds to te.bind. + """ Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more + details. Parameters ---------- @@ -160,7 +161,8 @@ class State: return res def parallel(self, stage, iterator): - """ Schedule primitive corresponds to te.parallel. + """ Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more + details. Parameters ---------- @@ -180,7 +182,8 @@ class State: return res def unroll(self, stage, iterator, max_unroll=None): - """ Schedule primitive corresponds to te.unroll. + """ Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more + details. Parameters ---------- @@ -203,7 +206,8 @@ class State: return res def vectorize(self, stage, iterator): - """ Schedule primitive corresponds to te.vectorize. + """ Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for + more details. Parameters ---------- @@ -223,7 +227,8 @@ class State: return res def fuse(self, stage, iters): - """ Schedule primitive corresponds to te.fuse. + """ Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more + details. Parameters ---------- @@ -248,7 +253,8 @@ class State: return res def reorder(self, stage, order): - """ Schedule primitive corresponds to te.reorder. + """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more + details. Parameters ---------- @@ -262,7 +268,8 @@ class State: order) def split(self, stage, iterator, lengths, inner_to_outer=True): - """ Schedule primitive corresponds to te.split. + """ Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more + details. This API supports multiple split factors. (e.g. 
with 2 split factors, the original iterator will be split to 3 parts, use `inner_to_outer` to control the split order) @@ -295,12 +302,13 @@ class State: return res def compute_at(self, stage, target_stage, target_iter): - """ Schedule primitive corresponds to te.compute_at. + """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for + more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute at, which can be specified by the integer index, Operation, + The Stage to be computed at, which can be specified by the integer index, Operation, or output tensor of the stage. target_stage : Union[int, Operation, Tensor] The target stage of compute_at, which can be specified by the integer index, Operation, @@ -321,25 +329,27 @@ class State: target_iter) def compute_inline(self, stage): - """ Schedule primitive corresponds to te.compute_inline. + """ Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage` + for more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute inlined, which can be specified by the integer index, Operation, - or output tensor of the stage. + The Stage to be marked compute inlined, which can be specified by the integer index, + Operation, or output tensor of the stage. """ self.state_object = _ffi_api.StateComputeInline(self.state_object, self._resolve_stage_id(stage)) def compute_root(self, stage): - """ Schedule primitive corresponds to te.compute_root. + """ Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for + more details. Parameters ---------- stage : Union[int, Operation, Tensor] - The Stage to be compute root, which can be specified by the integer index, Operation, - or output tensor of the stage. + The Stage to be marked compute at root, which can be specified by the integer index, + Operation, or output tensor of the stage. Notes ----- @@ -351,6 +361,74 @@ class State: self.state_object = _ffi_api.StateComputeRoot(self.state_object, self._resolve_stage_id(stage)) + def cache_read(self, stage, scope_name, reader_stages): + """ Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule` + for more details. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be cache read, which can be specified by the integer index, Operation, + or output tensor of the stage. + scope_name : str + The scope name of the newly added read stage. + reader_stages : List[Union[int, Operation, Tensor]] + The reader stages. Each of the list can be specified by the integer index, Operation, + or output tensor of the stage. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Cache read step will insert an extra stage to the original ComputeDAG (at the back of the + target stage). + """ + reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages] + self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, + self._resolve_stage_id(stage), + scope_name, reader_stage_ids, + self.compute_dag) + # Add a new stage will change all ops behind the added stage. But we still want to keep the + # original ops map, apply stage id offset to stage_id_map to make them work. 
+ self._apply_stage_id_offset(int(new_stage_id)) + self._update_stage_id_map() + return self.stages[int(new_stage_id)].op + + def cache_write(self, stage, scope_name): + """ Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule` + for more details. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be cache write, which can be specified by the integer index, Operation, + or output tensor of the stage. + scope_name : str + The scope name of the newly added compute stage. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Cache write step will insert an extra stage to the original ComputeDAG (in the front of the + target stage). + This step will cache write all output tensors of the target stage. + """ + self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, + self._resolve_stage_id(stage), + scope_name, self.compute_dag) + # Add a new stage will change all ops behind the added stage. But we still want to keep the + # original ops map, apply stage id offset to stage_id_map to make them work. + self._apply_stage_id_offset(int(new_stage_id)) + self._update_stage_id_map() + return self.stages[int(new_stage_id)].op + def copy(self): """ Do deep copy of this State. """ state = State(self.state_object, self.compute_dag) @@ -371,6 +449,11 @@ class State: for index, stage in enumerate(self.stages): self.stage_id_map[stage.op] = index + def _apply_stage_id_offset(self, start_id, offset=1): + for key, value in self.stage_id_map.items(): + if value >= start_id: + self.stage_id_map[key] = value + offset + def __getitem__(self, key): if isinstance(key, Tensor): key = key.op diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 68d1bb4..2f6e948 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -645,24 +645,6 @@ ComputeDAG::ComputeDAG(Array tensors) { data_ = std::move(node); } -// Update the te::stage to tir::IterVar axis mapping -void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { - if (auto pop = stage->op.as()) { - Array axes; - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - stage_to_axes->Set(stage, std::move(axes)); - } else if (stage->op->IsInstance()) { - {} // do nothing on Placeholder - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - std::pair> ComputeDAG::ApplySteps( const Array& transform_steps, Array* stages, StageToAxesMap* stage_to_axes) const { @@ -696,7 +678,7 @@ std::pair> ComputeDAG::ApplySteps( // Apply the history steps to TVM schedule // Call each step's ApplyToSchedule method for (const auto& step : transform_steps) { - StepApplyToSchedule(step, stages, stage_to_axes); + StepApplyToSchedule(step, stages, stage_to_axes, &schedule); } return std::make_pair(schedule, operator->()->tensors); @@ -740,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const } // Call each step's PrintAsPythonAPI method for (const auto& step : transform_steps) { - ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes); + ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule); } return ss.str(); @@ -806,6 +788,23 @@ State ComputeDAG::InferBound(const State& state) const { return ret_state; } +ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const { + te::Schedule sch; + Array old_tensors; + 
std::tie(sch, old_tensors) = ApplySteps(transform_steps); + + Array new_tensors; + for (auto stage : sch->stages) { + if (stage->op->IsInstance() || stage->is_output) { + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + new_tensors.push_back(stage->op.output(i)); + } + } + } + + return ComputeDAG(new_tensors); +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 35d899a..67c6b38 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -23,6 +23,7 @@ * see auto_scheduler/loop_state.h for more explanation. */ +#include #include #include #include @@ -150,6 +151,36 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { } } +AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const { + AttachMap map = AttachMap(make_object()); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + /********** State **********/ State::State(const Array& ops) { auto node = make_object(); @@ -257,6 +288,19 @@ void State::compute_root(int stage_id) { step->ApplyToState(this); } +int State::cache_read(int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag) { + CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + +int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) { + CacheWriteStep step = CacheWriteStep(stage_id, scope_name); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; @@ -429,6 +473,20 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot") return state; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead") + .set_body_typed([](State state, int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag) { + int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag); + return Array{state, Integer(res)}; + }); + +TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite") + .set_body_typed([](State state, int stage_id, const String& scope_name, + const ComputeDAG& task_dag) { + int res = state.cache_write(stage_id, scope_name, task_dag); + return Array{state, Integer(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index b1b3b94..5c5cc4b 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -23,6 +23,7 @@ * They are similar to the schedule 
primitives in te::Stage. */ +#include #include #include #include @@ -37,6 +38,24 @@ namespace tvm { namespace auto_scheduler { +// Update the te::stage to tir::IterVar axis mapping +void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { + if (auto pop = stage->op.as()) { + Array axes; + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + stage_to_axes->Set(stage, std::move(axes)); + } else if (stage->op->IsInstance()) { + {} // do nothing on Placeholder + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + const char* IteratorAnnotationString[] = { "for", // kNone = 0 "unroll", // kUnroll = 1 @@ -72,6 +91,10 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return ComputeInlineStep(reader); } else if (name == ComputeRootStepNode::record_prefix_str) { return ComputeRootStep(reader); + } else if (name == CacheReadStepNode::record_prefix_str) { + return CacheReadStep(reader); + } else if (name == CacheWriteStepNode::record_prefix_str) { + return CacheWriteStep(reader); } else { LOG(FATAL) << "Invalid step format: " << name; } @@ -94,14 +117,17 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); } else { LOG(FATAL) << "Invalid step: " << step; } } -void StepApplyToSchedule(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes) { - // We need this runtime dispatcher because different steps have different function signatures +void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) { if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -116,14 +142,17 @@ void StepApplyToSchedule(const Step& step, Array* stages, ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } } String StepPrintAsPythonAPI(const Step& step, Array* stages, - StageToAxesMap* stage_to_axes) { - // We need this runtime dispatcher because different steps have different function signatures + StageToAxesMap* stage_to_axes, te::Schedule* schedule) { if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -138,6 +167,10 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -925,5 +958,275 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } +/********** Primitives adding new stages **********/ + +/*! + * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, + * RfactorStep). 
This will return all steps that can change the number of stages in a ComputeDAG, + * and stop by the current step. + */ +Array GetFormerStageModifiableSteps(Step current_step, const Array& transform_steps) { + Array ret_steps; + for (const Step& step : transform_steps) { + if (step->IsInstance() || step->IsInstance()) { + ret_steps.push_back(step); + } + // TODO(jcf94): add rfactor support + // A state may have multiple stage modifiable steps, stop by the current step to avoid + // replaying excess steps + if (step.same_as(current_step)) { + break; + } + } + return ret_steps; +} + +/********** Cache Read **********/ +CacheReadStep::CacheReadStep(int stage_id, String scope_name, + const Array& reader_stage_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + node->reader_stage_ids = reader_stage_ids; + data_ = std::move(node); +} + +CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->scope_name = std::move(string_value); + s = reader->NextArrayItem(); + CHECK(s); + std::vector int_list; + reader->Read(&int_list); + Array reader_stage_ids; + for (int i : int_list) { + reader_stage_ids.push_back(i); + } + node->reader_stage_ids = std::move(reader_stage_ids); + data_ = std::move(node); +} + +void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArraySeperator(); + writer->WriteString(scope_name); + writer->WriteArrayItem(IntArrayToVector(reader_stage_ids)); +} + +int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( + GetFormerStageModifiableSteps(GetRef(this), (*state)->transform_steps)); + + // target_stage -> target_stage + target_store + // Update the op of the target stage, insert a new cache read stage behind, update the op of + // later stages, then update the stage_id mapping in AttachMap + int added_stage_id = stage_id + 1; + Stage tmp_stage = pstate->stages[stage_id]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id]; + pstate->stages.Set(stage_id, std::move(tmp_stage)); + pstate->stages.insert(pstate->stages.begin() + added_stage_id, + Stage(current_compute_dag->ops[added_stage_id])); + for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) { + tmp_stage = pstate->stages[i]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(tmp_stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id); + pstate->current_compute_dag = std::move(current_compute_dag); + + return added_stage_id; +} + +te::Tensor CacheReadStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const te::Stage& stage = (*stages)[stage_id]; + Array readers; + for (const auto& i : reader_stage_ids) { + readers.push_back((*stages)[i]->origin_op); + } + auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); + + const auto& new_stage = (*schedule)[out->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id + 1, new_stage); + + return out; +} + 
+String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + // Since the original stage will be changed after schedule apply, keep a copy here + // These information will be used to print Python API string later + auto stage = (*stages)[stage_id]; + Array reader_stages; + for (size_t i = 0; i < reader_stage_ids.size(); ++i) { + reader_stages.push_back((*stages)[reader_stage_ids[i]]); + } + auto out = ApplyToSchedule(stages, stage_to_axes, schedule); + + ss << CleanName(out->op->name) << " = " + << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" + << CleanName(reader_stages[0]->op->name); + for (size_t i = 1; i < reader_stage_ids.size(); ++i) { + ss << ", " << CleanName(reader_stages[i]->op->name); + } + ss << "])\n"; + + // Print the iterators of the new added stage + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)\n"; + + return ss.str(); +} + +/********** Cache Write **********/ +CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + data_ = std::move(node); +} + +CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->scope_name = std::move(string_value); + data_ = std::move(node); +} + +void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArraySeperator(); + writer->WriteString(scope_name); +} + +int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + int last_dag_op_size = pstate->current_compute_dag + ? pstate->current_compute_dag.value().as()->ops.size() + : dag->ops.size(); + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( + GetFormerStageModifiableSteps(GetRef(this), (*state)->transform_steps)); + int added_ops = current_compute_dag->ops.size() - last_dag_op_size; + // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM + CHECK_GE(added_ops, 1); + + // target_stage -> cache_write_stage + target_stage + // Assume no step has been applied to the target stage before cache write. + // Insert a new cache write stage ahead, update the op of the target stage and later stages, then + // update the stage_id mapping in AttachMap + pstate->stages.insert(pstate->stages.begin() + stage_id, + Stage(current_compute_dag->ops[stage_id])); + pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1])); + int next_stage_id = stage_id + 2; + // TODO(jc94): Fix the cache write bug in TVM and remove added_op == 2 support. + // TVM's cache_write has a bug with multi outputs. 
See + // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test + // for more details + if (added_ops == 2) { + pstate->stages.insert(pstate->stages.begin() + next_stage_id, + Stage(current_compute_dag->ops[next_stage_id])); + next_stage_id++; + } else if (added_ops > 2) { + LOG(ERROR) << "Unexpected behavior of CacheWrite."; + } + for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) { + Stage tmp_stage = pstate->stages[i]; + tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(tmp_stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops); + pstate->current_compute_dag = std::move(current_compute_dag); + + return stage_id; +} + +Array CacheWriteStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const te::Stage& stage = (*stages)[stage_id]; + Array tensor_array; + // If the target stage has multi outputs, TVM requires to cache_write + // all of them or schedule.cache_write will raise an error + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + tensor_array.push_back(stage->origin_op.output(i)); + } + auto outs = schedule->cache_write(tensor_array, scope_name); + + UpdateStageToAxesMap(stage, stage_to_axes); + // Even if there is multi outputs, TVM schedule only generate one + // new stage + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + // Since the original stage will be changed after schedule apply, keep a copy here + // These information will be used to print Python API string later + te::Stage stage = (*stages)[stage_id]; + auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->name) << ", "; + } + ss << "= " + << "s.cache_write([" << CleanName(stage->op.output(0)->op->name); + for (auto i = 1; i < stage->op->num_outputs(); ++i) { + ss << ", " << CleanName(stage->op.output(i)->op->name); + } + ss << "], \"" << scope_name << "\")\n"; + + // Print the iterators of the new added stage + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << " + " + << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + } + + return ss.str(); +} + } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index 32ea8fa..8282d4a 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -143,6 +143,281 @@ def test_compute_at_root_inline(): assert s0[conv].iters[6].range.extent == 7 +def test_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + 
name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + add = topi.add(data, relu) + + dag = auto_scheduler.ComputeDAG([data, kernel_data, add]) + s0 = dag.get_init_state() + + pad_temp = s0.stage_ops[1] + kernel_split = s0.stage_ops[3] + + # 0: init state + ori_its = s0[add].iters + its = s0.split(add, s0[add].iters[0], [2]) + s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) + s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + conv_global = s0.cache_write(conv, "global") + s0.compute_at(conv_global, conv, s0[conv].iters[3]) + + # 2: simple cache_read with compute_at + kernel_global = s0.cache_read(kernel, "global", [conv_global]) + s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) + """ + Placeholder: Data, Kernel_data + for i0 (0,4) + for i1 (0,512) + for i2 (0,9) + for i3 (0,9) + pad_temp = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel = ... + for nn (0,4) + for ff (0,512) + for yy (0,7) + for xx (0,7) + for nn_c (None) + for ff_c (None) + for yy_c (None) + for xx_c (None) + for rc (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + Kernel.global = ... + for ry (None) + for rx (None) + compute.global = ... + compute = ... + for ax0.0 (0,2) + for ax1 (0,512) + for ax0.1 (0,2) + for ax2 (0,7) + for ax3 (0,7) + T_add = ... + """ + s1 = dag.infer_bound_from_state(s0) + assert s1[conv].iters[0].range.extent == 4 + assert s1[conv].iters[1].range.extent == 512 + assert s1[conv].iters[2].range.extent == 7 + assert s1[conv].iters[3].range.extent == 7 + assert s1[kernel_global].iters[0].range.extent == 1 + assert s1[kernel_global].iters[1].range.extent == 1 + assert s1[kernel_global].iters[2].range.extent == 3 + assert s1[kernel_global].iters[3].range.extent == 3 + assert s1[conv_global].iters[0].range.extent == 1 + assert s1[conv_global].iters[1].range.extent == 1 + assert s1[conv_global].iters[2].range.extent == 1 + assert s1[conv_global].iters[3].range.extent == 1 + assert s1[conv_global].iters[4].range.extent == 512 + assert s1[conv_global].iters[5].range.extent == 3 + assert s1[conv_global].iters[6].range.extent == 3 + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) + pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) + s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) + s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) + + # 4: cache_read with multi readers + # This stage cannot be compute at to its consumer + s0.cache_read(data, "global", [pad_temp, add]) + """ + Placeholder: Data, Kernel_data + for ax0 (0,4) + for ax1 (0,512) + for ax2 (0,7) + for ax3 (0,7) + Data.global = ... + for i0 (0,4) + for i1 (0,512) + for i2 (0,9) + for i3 (0,9) + pad_temp = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel_split = ... + for i0 (0,512) + for i1 (0,512) + for i2 (0,3) + for i3 (0,3) + Kernel = ... + for nn (0,4) + for ff (0,512) + for yy (0,7) + for xx (0,7) + for nn_c (None) + for ff_c (None) + for yy_c (None) + for ax0 (None) + for ax1 (None) + for ax2 (None) + for ax3 (None) + pad_temp.global = ... 
+                  for xx_c (None)
+                    for rc (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              Kernel.global = ...
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global.shared = ...
+                      for ry (None)
+                        for rx (None)
+                          compute.global = ...
+            compute = ...
+    for ax0.0 (0,2)
+      for ax1 (0,512)
+        for ax0.1 (0,2)
+          for ax2 (0,7)
+            for ax3 (0,7)
+              T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+    assert s1[pad_temp_global].iters[0].range.extent == 1
+    assert s1[pad_temp_global].iters[1].range.extent == 512
+    assert s1[pad_temp_global].iters[2].range.extent == 3
+    assert s1[pad_temp_global].iters[3].range.extent == 3
+    assert s1[pad_temp_shared].iters[0].range.extent == 1
+    assert s1[pad_temp_shared].iters[1].range.extent == 1
+    assert s1[pad_temp_shared].iters[2].range.extent == 3
+    assert s1[pad_temp_shared].iters[3].range.extent == 3
+
+    # 5: cache_write with multiple outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generates one new stage:
+    #   From: kernel_data -> kernel_split -> kernel
+    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with topo-sort analysis, we get:
+    #   //   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #         \                                               /
+    #          ----------------> kernel_split ---------------->
+    #
+    # TODO(jcf94): Seems there's a bug with the input/output tensor. Such a multi-output case
+    # should be unusual, so we apply a hack in DoCacheWrite. This should be fixed later.
+    kernel_split_global = s0.cache_write(kernel_split, "global")
+    """
+    Placeholder: Data, Kernel_data
+    for ax0 (0,4)
+      for ax1 (0,512)
+        for ax2 (0,7)
+          for ax3 (0,7)
+            Data.global = ...
+    for i0 (0,4)
+      for i1 (0,512)
+        for i2 (0,9)
+          for i3 (0,9)
+            pad_temp = ...
+    for i0_c (0,512)
+      for i1_c (0,512)
+        for i2_c (0,3)
+          for i3_c (0,3)
+            Kernel_split.global = ...
+    for i0 (0,512)
+      for i1 (0,512)
+        for i2 (0,3)
+          for i3 (0,3)
+            Kernel_split = ...
+    (******* Bug here, there should not be two kernel_split stages *******)
+    for i0 (0,512)
+      for i1 (0,512)
+        for i2 (0,3)
+          for i3 (0,3)
+            Kernel_split = ...
+    (******* Bug here, there should not be two kernel_split stages *******)
+    for i0 (0,512)
+      for i1 (0,512)
+        for i2 (0,3)
+          for i3 (0,3)
+            Kernel = ...
+    for nn (0,4)
+      for ff (0,512)
+        for yy (0,7)
+          for xx (0,7)
+            for nn_c (None)
+              for ff_c (None)
+                for yy_c (None)
+                  for ax0 (None)
+                    for ax1 (None)
+                      for ax2 (None)
+                        for ax3 (None)
+                          pad_temp.global = ...
+                  for xx_c (None)
+                    for rc (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              Kernel.global = ...
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global.shared = ...
+                      for ry (None)
+                        for rx (None)
+                          compute.global = ...
+            compute = ...
+    for ax0.0 (0,2)
+      for ax1 (0,512)
+        for ax0.1 (0,2)
+          for ax2 (0,7)
+            for ax3 (0,7)
+              T_add = ...
+    """
+    assert len(s0[kernel_split].iters) == len(s0[kernel_split_global].iters)
+    for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
+        assert it0.range == it1.range
+
 
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_compute_at_root_inline()
+    test_cache_read_write()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 333d20e..5f2f87a 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -35,7 +35,7 @@ def test_record():
     C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
     D = topi.nn.relu(C)
     k = te.reduce_axis((0, 512), name='k')
-    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C')
+    E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
     F = topi.nn.relu(E)
 
     dag = auto_scheduler.ComputeDAG([A, B, F])
@@ -66,6 +66,11 @@ def test_record():
     s.unroll(C, s[C].iters[4])
     # Vectorize
     s.vectorize(C, s[C].iters[6])
+    # Cache Read
+    D_global = s.cache_read(D, "global", [E])
+    s.compute_at(D_global, E, s[E].iters[2])
+    # Cache Write
+    s.cache_write(D, "shared")
 
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)
-- 
2.7.4
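
For readers who want to try the loop-state primitives this patch introduces without working through the full conv2d test above, the snippet below is a minimal sketch (an editor's illustration, not part of the patch) of how cache_write and cache_read compose on a small matmul DAG. The shapes, tensor names (A, B, C, D), the chosen memory scopes, and the function name cache_read_write_sketch are illustrative assumptions only; the call pattern mirrors the test_cache_read_write usage above.

# Minimal sketch of the new State.cache_read / State.cache_write primitives.
# All names and shapes below are hypothetical; only the call pattern matters.
from tvm import te, auto_scheduler

def cache_read_write_sketch():
    A = te.placeholder((128, 128), name='A')
    B = te.placeholder((128, 128), name='B')
    k = te.reduce_axis((0, 128), name='k')
    C = te.compute((128, 128), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
    D = te.compute((128, 128), lambda i, j: C[i][j] + 1.0, name='D')

    dag = auto_scheduler.ComputeDAG([A, B, D])
    s = dag.get_init_state()

    # cache_write inserts a new stage that computes C into a "global" buffer
    # right before the original C stage; later stage indices shift, which is
    # what AttachMap::ApplyStageIdOffset keeps consistent on the C++ side.
    C_global = s.cache_write(C, "global")

    # cache_read inserts a cache stage for A and redirects the listed readers
    # (here only C_global) to read from it.
    A_global = s.cache_read(A, "global", [C_global])
    s.compute_at(A_global, C_global, s[C_global].iters[1])

    # Steps like these change the op/stage structure, so the state keeps an
    # up-to-date ComputeDAG (current_compute_dag) that bound inference uses.
    print(dag.infer_bound_from_state(s))

As in the test above, tensors or ops can be passed directly to these methods; the Python wrappers resolve them to stage indices before recording the transform steps.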