From: Chenfan Date: Wed, 29 Jul 2020 06:39:21 +0000 (+0800) Subject: [Ansor][AutoTVM v2.0] Phase 1: Add pragma/storage_align/rfactor steps (#6141) X-Git-Tag: upstream/0.7.0~337 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=96f601183f16686092e558bceceb5e6d507f1357;p=platform%2Fupstream%2Ftvm.git [Ansor][AutoTVM v2.0] Phase 1: Add pragma/storage_align/rfactor steps (#6141) * Add pragma/storage_align/rfactor step * Update * Update * Update UT * Update --- diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 9850620..34e7e56 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -341,6 +341,13 @@ class State : public ObjectRef { */ TVM_DLL Iterator fuse(int stage_id, const Array& iters); /*! + * \brief Schedule primitive corresponds to `te.Stage.pragma`. + * \param stage_id The index of the stage to add pragma. + * \param it The iterator to add pragma. + * \param pragma_type The pragma string. + */ + TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type); + /*! * \brief Schedule primitive corresponds to `te::Stage::reorder`. * \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. @@ -382,6 +389,14 @@ class State : public ObjectRef { TVM_DLL Array follow_fused_split(int stage_id, const Iterator& it, const Array& src_step_ids, int level, bool factor_or_nparts); + /*! + * \brief Schedule primitive corresponds to `te.Stage.storage_align`. + * \param stage_id The index of the stage to be aligned. + * \param it The iterator to be aligned. + * \param factor The factor in alignment specification. + * \param offset The offset in the alignment specification. + */ + TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset); /********** Step APIs working on multiple stages **********/ @@ -422,8 +437,8 @@ class State : public ObjectRef { * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. */ - int cache_read(int stage_id, const String& scope_name, const Array& reader_stage_ids, - const ComputeDAG& dag); + TVM_DLL int cache_read(int stage_id, const String& scope_name, + const Array& reader_stage_ids, const ComputeDAG& dag); /*! * \brief Schedule primitive corresponds to `te::Schedule::cache_write`. * \param stage_id The index of the stage to be cache write. @@ -433,7 +448,17 @@ class State : public ObjectRef { * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. * This step will cache write all output tensors of the target stage. */ - int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag); + /*! + * \brief Schedule primitive corresponds to `te::Schedule::rfactor`. + * \param stage_id The index of the iterator to be factored. + * \param it The iterator to be factored. + * \param factor_iter_id The position where the new iterator is placed. + * \param dag The original ComputeDAG of this state. + * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the + * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`. + */ + TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h index f91505c..a31765a 100644 --- a/include/tvm/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -350,6 +350,67 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; +/*! \brief Pragma step that corresponds to te::Stage::pragma */ +class PragmaStepNode : public StepNode { + public: + /*! \brief The index of the iterator to add pragma. */ + int iter_id; + /*! \brief The pragma string. */ + String pragma_type; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to state, which will be updated. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* record_prefix_str = "PR"; + + static constexpr const char* _type_key = "auto_scheduler.PragmaStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); +}; + +/*! + * \brief Managed reference to PragmaStepNode. + * \sa PragmaStepNode + */ +class PragmaStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be fused. + * \param iter_id The index of the iterator to add pragma. + * \param pragma_type The pragma string. + */ + PragmaStep(int stage_id, int iter_id, String pragma_type); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit PragmaStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); +}; + /*! \brief Reorder step that corresponds to te::Stage::reorder */ class ReorderStepNode : public StepNode { public: @@ -506,14 +567,14 @@ class FollowSplitStepNode : public StepNode { /*! * \brief Extract split lengths. * \param transform_steps An array record all transform steps. - * \param lengths The multiple split factors. Can be None to be filled by search policy. + * \return The multiple split factors. */ - void ExtractSplitLengths(const Array& transform_steps, - Array>* lengths) const; + Array> ExtractSplitLengths(const Array& transform_steps) const; /*! * \brief Apply the current step to State. * \param state A mutable pointer to state, which will be updated. + * \return The iterator results after split. */ Array ApplyToState(State* state) const; @@ -651,6 +712,70 @@ class FollowFusedSplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); }; +/*! \brief Storage align step that corresponds to te::Stage::storage_align */ +class StorageAlignStepNode : public StepNode { + public: + /*! \brief The iterator to be aligned. */ + int iter_id; + /*! \brief The factor in alignment specification. */ + int factor; + /*! \brief The offset in the alignment specification. */ + int offset; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State, which will be updated. + */ + void ApplyToState(State* state) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + */ + void ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const; + + static constexpr const char* record_prefix_str = "SA"; + + static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); +}; + +/*! + * \brief Managed reference to StorageAlignStepNode. + * \sa StorageAlignStepNode + */ +class StorageAlignStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be aligned. + * \param iter_id The index of the iterator to be aligned. + * \param factor The factor in alignment specification. + * \param offset The offset in the alignment specification. + */ + StorageAlignStep(int stage_id, int iter_id, int factor, int offset); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit StorageAlignStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); +}; + /********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ @@ -832,7 +957,7 @@ class ComputeRootStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); }; -/********** Primitives adding new stages **********/ +/********** Steps adding new stages **********/ /*! * \brief Cache read step that corresponds to te::Schedule::cache_read. @@ -976,6 +1101,74 @@ class CacheWriteStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); }; +/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ +class RfactorStepNode : public StepNode { + public: + /*! \brief The index of the iterator to be factored. */ + int iter_id; + /*! \brief The position where the new iterator is placed. */ + int factor_iter_id; + + void WriteToRecord(dmlc::JSONWriter* writer) const final; + + /*! + * \brief Apply the current step to State. + * \param state A mutable pointer to State, which will be updated. + * \param dag The original ComputeDAG of this state. + * \return The index of the new added stage. + */ + int ApplyToState(State* state, const ComputeDAG& dag) const; + + /*! + * \brief Apply the current step to tvm.schedule. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return The output Tensors of the new added stage. + */ + Array ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + /*! + * \brief Print the current step as equivalent python schedule API. + * \param stages The `te::Stage`s used in TVM scheduler applying. + * \param stage_to_axes The `te::Stage` and `tir::IterVar` map. + * \param schedule A mutable pointer to a te::Schedule. + * \return Python schedule code. + */ + String PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const; + + static constexpr const char* record_prefix_str = "RF"; + + static constexpr const char* _type_key = "auto_scheduler.RfactorStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); +}; + +/*! + * \brief Managed reference to RfactorStepNode. + * \sa RfactorStepNode + */ +class RfactorStep : public Step { + public: + /*! + * \brief The constructor. + * \param stage_id The index of the stage to be factored. + * \param iter_id The index of the iterator to be factored. + * \param factor_iter_id The position where the new iterator is placed. + */ + RfactorStep(int stage_id, int iter_id, int factor_iter_id); + + /*! + * \brief The constructor used to read a step record from JSONReader and create the + * corresponding step. + * \param reader The input JSONReader. + */ + explicit RfactorStep(dmlc::JSONReader* reader); + + TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py index 9ec26c3..35ecacc 100644 --- a/python/tvm/auto_scheduler/loop_state.py +++ b/python/tvm/auto_scheduler/loop_state.py @@ -261,6 +261,23 @@ class State: self._resolve_stage_id(stage), iters) return res + def pragma(self, stage, iterator, pragma_type): + """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more + details. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to add pragma, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The iterator to add pragma. + pragma_type : str + The pragma string. + """ + self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage), + iterator, pragma_type) + def reorder(self, stage, order): """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more details. @@ -397,6 +414,26 @@ class State: factor_or_nparts) return res + def storage_align(self, stage, iterator, factor, offset): + """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for + more details. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be storage aligned, which can be specified by the integer index, + Operation, or output tensor of the stage. + iterator : Iterator + The iterator to be aligned. + factor : int + The factor in alignment specification. + offset : int + The offset in the alignment specification. + """ + self.state_object = _ffi_api.StateStorageAlign(self.state_object, + self._resolve_stage_id(stage), iterator, + factor, offset) + def compute_at(self, stage, target_stage, target_iter): """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for more details. @@ -525,6 +562,40 @@ class State: self._update_stage_id_map() return self.stages[int(new_stage_id)].op + def rfactor(self, stage, iterator, factor_iter_id): + """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for + more details. + + Parameters + ---------- + stage : Union[int, Operation, Tensor] + The Stage to be factored, which can be specified by the integer index, Operation, + or output tensor of the stage. + iterator : Iterator + The reduction iterator to be factored. + factor_iter_id : int + The position where the new iterator is placed. + + Returns + ------- + new_stage_op : Operator + The Operator of the new added stage. + + Notes + ----- + Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the + target stage). + """ + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, + self._resolve_stage_id(stage), + iterator, factor_iter_id, + self.compute_dag) + # Add a new stage will change all ops behind the added stage. But we still want to keep the + # original ops map, apply stage id offset to stage_id_map to make them work. + self._apply_stage_id_offset(int(new_stage_id)) + self._update_stage_id_map() + return self.stages[int(new_stage_id)].op + def copy(self): """ Do deep copy of this State. """ state = State(self.state_object, self.compute_dag) diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index 636066a..f9d1f82 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -247,6 +247,13 @@ Iterator State::fuse(int stage_id, const Array& iters) { return step->ApplyToState(this); } +void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) { + const Stage& stage = operator->()->stages[stage_id]; + PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::reorder(int stage_id, const Array& order) { const Stage& stage = operator->()->stages[stage_id]; CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " @@ -287,6 +294,13 @@ Array State::follow_fused_split(int stage_id, const Iterator& it, return step->ApplyToState(this); } +void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { + const Stage& stage = operator->()->stages[stage_id]; + StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this); +} + void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = @@ -320,6 +334,13 @@ int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& return step->ApplyToState(this, dag); } +int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) { + const Stage& stage = operator->()->stages[stage_id]; + RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id); + CopyOnWrite()->transform_steps.push_back(step); + return step->ApplyToState(this, dag); +} + void State::ApplySteps(const ComputeDAG& dag) { CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages."; @@ -460,6 +481,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse") return Array{state, res}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma") + .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder") .set_body_typed([](State state, int stage_id, const Array& order) { state.reorder(stage_id, order); @@ -488,6 +515,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit") return Array{state, Array(res)}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign") + .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt") .set_body_typed([](State state, int stage_id, int target_stage_id, const Iterator& target_iter) { @@ -521,6 +554,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite") return Array{state, Integer(res)}; }); +TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor") + .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id, + const ComputeDAG& dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, dag); + return Array{state, Integer(res)}; + }); + TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); }); diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index d43d0af..e533a7c 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -81,6 +81,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return AnnotationStep(reader); } else if (name == FuseStepNode::record_prefix_str) { return FuseStep(reader); + } else if (name == PragmaStepNode::record_prefix_str) { + return PragmaStep(reader); } else if (name == ReorderStepNode::record_prefix_str) { return ReorderStep(reader); } else if (name == SplitStepNode::record_prefix_str) { @@ -89,6 +91,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return FollowSplitStep(reader); } else if (name == FollowFusedSplitStepNode::record_prefix_str) { return FollowFusedSplitStep(reader); + } else if (name == StorageAlignStepNode::record_prefix_str) { + return StorageAlignStep(reader); } else if (name == ComputeAtStepNode::record_prefix_str) { return ComputeAtStep(reader); } else if (name == ComputeInlineStepNode::record_prefix_str) { @@ -99,6 +103,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { return CacheReadStep(reader); } else if (name == CacheWriteStepNode::record_prefix_str) { return CacheWriteStep(reader); + } else if (name == RfactorStepNode::record_prefix_str) { + return RfactorStep(reader); } else { LOG(FATAL) << "Invalid step format: " << name; } @@ -111,6 +117,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -119,6 +127,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); + } else if (auto ps = step.as()) { + ps->ApplyToState(state); } else if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -129,6 +139,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { ps->ApplyToState(state, dag); } else if (auto ps = step.as()) { ps->ApplyToState(state, dag); + } else if (auto ps = step.as()) { + ps->ApplyToState(state, dag); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -140,6 +152,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -148,6 +162,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -158,6 +174,8 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -170,6 +188,8 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -178,6 +198,8 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -188,6 +210,8 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); + } else if (auto ps = step.as()) { + return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); } else { LOG(FATAL) << "Invalid Step: " << step; } @@ -488,6 +512,113 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } +/********** Pragma **********/ +PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->pragma_type = std::move(pragma_type); + data_ = std::move(node); +} + +PragmaStep::PragmaStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + std::string string_value; + reader->Read(&string_value); + node->pragma_type = std::move(string_value); + data_ = std::move(node); +} + +void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArraySeperator(); + writer->WriteString(pragma_type); +} + +void PragmaStepNode::ApplyToState(State* state) const { + if (pragma_type == "debug_skip_region") { + StateNode* pstate = state->CopyOnWrite(); + pstate->attach_map.DeleteStage(stage_id); + } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + StateNode* pstate = state->CopyOnWrite(); + Stage stage = pstate->stages[stage_id]; + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1); + pstate->stages.Set(stage_id, std::move(stage)); + } else { + LOG(FATAL) << "Unsupported pragma: " << pragma_type; + } +} + +void PragmaStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + int value = atoi(pragma_type.c_str() + pos + 1); + stage.pragma(axes[iter_id], "auto_unroll_max_step", value); + stage.pragma(axes[iter_id], "unroll_explicit", true); + } else { + stage.pragma(axes[iter_id], pragma_type); + } + stages->Set(stage_id, std::move(stage)); +} + +String PragmaStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = 0; + for (; pos < pragma_type.size(); ++pos) { + if ((*(pragma_type.c_str() + pos)) == '$') { + break; + } + } + CHECK_LT(pos, pragma_type.size()) << "max step value not found."; + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->name) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type + << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + /********** Reorder **********/ ReorderStep::ReorderStep(int stage_id, const Array& after_ids) { auto node = make_object(); @@ -812,8 +943,8 @@ void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { writer->WriteArrayItem(n_split); } -void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps, - Array>* lengths) const { +Array> FollowSplitStepNode::ExtractSplitLengths( + const Array& transform_steps) const { // Make sure src_step_id is within the range of transform_steps. CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); @@ -824,11 +955,12 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps CHECK_LE(n_split, ps->lengths.size() + 1); CHECK(ps != nullptr); - lengths->reserve(n_split); + Array> lengths; + lengths.reserve(n_split); int j = 0; // Get the first (n_split-1) split factors of followed src_step. for (; j < n_split - 1; ++j) { - lengths->push_back(ps->lengths[j]); + lengths.push_back(ps->lengths[j]); } // Get the last split factor of src_step for splitting level if n_split is smaller than @@ -843,10 +975,12 @@ void FollowSplitStepNode::ExtractSplitLengths(const Array& transform_steps } } if (last_factor.defined()) { - lengths->push_back(Downcast(last_factor)); + lengths.push_back(Downcast(last_factor)); } else { - lengths->push_back(NullOpt); + lengths.push_back(NullOpt); } + + return lengths; } FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { @@ -864,30 +998,26 @@ FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { s = reader->NextArrayItem(); CHECK(s); reader->Read(&node->n_split); - data_ = std::move(node); } Array FollowSplitStepNode::ApplyToState(State* state) const { - Array> lengths; - ExtractSplitLengths((*state)->transform_steps, &lengths); - return ApplySplitToState(state, stage_id, iter_id, lengths, true); + return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps), + true); } Array FollowSplitStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const { - Array> lengths; - ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + ExtractSplitLengths(transform_steps), true); } String FollowSplitStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const { - Array> lengths; - ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + ExtractSplitLengths(transform_steps), true); } /********** Follow Fused Split **********/ @@ -922,7 +1052,6 @@ FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { s = reader->NextArrayItem(); CHECK(s); reader->Read(&node->factor_or_nparts); - ::tvm::Array<::tvm::Integer> src_step_ids; for (const auto& i : int_list) { src_step_ids.push_back(i); @@ -961,23 +1090,86 @@ Optional FollowFusedSplitStepNode::ExtractSplitLength( } Array FollowFusedSplitStepNode::ApplyToState(State* state) const { - const Optional& length = ExtractSplitLength((*state)->transform_steps); - return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts); + return ApplySplitToState(state, stage_id, iter_id, + {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts); } Array FollowFusedSplitStepNode::ApplyToSchedule(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const { - const Optional& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + {ExtractSplitLength(transform_steps)}, factor_or_nparts); } String FollowFusedSplitStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, const Array& transform_steps) const { - const Optional& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length}, - factor_or_nparts); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + {ExtractSplitLength(transform_steps)}, factor_or_nparts); +} + +/********** Storage Align **********/ +StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor = factor; + node->offset = offset; + data_ = std::move(node); +} + +StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->offset); + data_ = std::move(node); +} + +void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(factor); + writer->WriteArrayItem(offset); +} + +void StorageAlignStepNode::ApplyToState(State* state) const { + StateNode* pstate = state->CopyOnWrite(); + Stage stage = pstate->stages[stage_id]; + stage.CopyOnWrite()->attrs.storage_offset = offset; + pstate->stages.Set(stage_id, std::move(stage)); +} + +void StorageAlignStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes) const { + te::Stage stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + stage.storage_align(axes[iter_id], factor, offset); + stages->Set(stage_id, std::move(stage)); +} + +String StorageAlignStepNode::PrintAsPythonAPI(Array* stages, + StageToAxesMap* stage_to_axes) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->name) << "].storage_align(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", " + << offset << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); } /********** Steps working on multiple stages **********/ @@ -1162,7 +1354,7 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, return ss.str(); } -/********** Primitives adding new stages **********/ +/********** Steps adding new stages **********/ /*! * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep, @@ -1171,11 +1363,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array* stages, */ Array GetFormerStageModifiableSteps(Step current_step, const Array& transform_steps) { Array ret_steps; - for (const Step& step : transform_steps) { + for (size_t i = 0; i < transform_steps.size(); ++i) { + const Step& step = transform_steps[i]; if (step->IsInstance() || step->IsInstance()) { ret_steps.push_back(step); + } else if (step->IsInstance()) { + // add FuseStepNode required by rfactor + if (i >= 2 && transform_steps[i - 2]->IsInstance()) { + const Step& fuse_step = transform_steps[i - 2]; + if (fuse_step->stage_id == step->stage_id) { + ret_steps.push_back(fuse_step); + } + } + // add SplitStepNode required by rfactor + CHECK_GE(i, 1); + CHECK(transform_steps[i - 1]->IsInstance()); + const Step& split_step = transform_steps[i - 1]; + CHECK_EQ(split_step->stage_id, step->stage_id); + ret_steps.push_back(split_step); + // add RfactorStepNode + ret_steps.push_back(step); } - // TODO(jcf94): add rfactor support // A state may have multiple stage modifiable steps, stop by the current step to avoid // replaying excess steps if (step.same_as(current_step)) { @@ -1432,5 +1640,130 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxe return ss.str(); } +/********** Rfactor **********/ +RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + data_ = std::move(node); +} + +RfactorStep::RfactorStep(dmlc::JSONReader* reader) { + auto node = make_object(); + bool s; + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->stage_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->iter_id); + s = reader->NextArrayItem(); + CHECK(s); + reader->Read(&node->factor_iter_id); + data_ = std::move(node); +} + +void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { + writer->WriteArraySeperator(); + writer->WriteString(record_prefix_str); + writer->WriteArrayItem(stage_id); + writer->WriteArrayItem(iter_id); + writer->WriteArrayItem(factor_iter_id); +} + +int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const { + StateNode* pstate = state->CopyOnWrite(); + const auto& compute_at_type = pstate->stages[stage_id]->compute_at; + const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG( + GetFormerStageModifiableSteps(GetRef(this), (*state)->transform_steps)); + + // target_stage -> rfactor_compute + target_stage + // Insert a new compute stage, update the target stage and later stage, then update the stage_id + // mapping in AttachMap + pstate->stages.insert(pstate->stages.begin() + stage_id, + Stage(current_compute_dag->ops[stage_id])); + // Maintain the compute_at type of the target stage + Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]); + target_stage.CopyOnWrite()->compute_at = compute_at_type; + pstate->stages.Set(stage_id + 1, std::move(target_stage)); + for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) { + Stage stage = pstate->stages[i]; + stage.CopyOnWrite()->op = current_compute_dag->ops[i]; + pstate->stages.Set(i, std::move(stage)); + } + pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id); + pstate->current_compute_dag = std::move(current_compute_dag); + + return stage_id; +} + +Array RfactorStepNode::ApplyToSchedule(Array* stages, + StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + const auto& stage = (*stages)[stage_id]; + const Array& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageToAxesMap(stage, stage_to_axes); + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageToAxesMap(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +String RfactorStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes, + te::Schedule* schedule) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->name); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << " + " + << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " + << "tuple(s[" << CleanName(output->op->name) << "].op.axis)" + << " + " + << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n"; + + return ss.str(); +} + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index da5032e..aacdcf4 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -162,6 +162,12 @@ inline double FloatArrayMean(const Array& float_array) { return sum / float_array.size(); } +/*! \brief Return whether a string starts with another substring */ +inline bool StrStartsWith(const String& a, const String& b) { + if (b.size() > a.size()) return false; + return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str()); +} + /********** Other Utilities **********/ /*! \brief Get an int value from an Expr */ inline int64_t GetIntImm(const PrimExpr& expr) { diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py index e35dfe3..5c501ac 100644 --- a/tests/python/unittest/test_auto_scheduler_loop_state.py +++ b/tests/python/unittest/test_auto_scheduler_loop_state.py @@ -454,8 +454,62 @@ def test_follow_split_follow_fused_split(): assert tmp[C].iters[level + 1].range.extent == \ tmp[C_global].iters[1].range.extent + +def test_rfactor(): + A, B, C = matmul_auto_scheduler_test(8, 8, 512) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + + ko, ki = s0.split(C, s0[C].iters[2], [16]) + + s1 = s0.copy() + C_r = s1.rfactor(C, ko, 2) + """ + Placeholder: A, B + for i (0,8) + for j (0,8) + for k_o (0,32) + for k_i (0,16) + C.rf = ... + for ax0 (0,8) + for ax1 (0,8) + for k_o_v (0,32) + C.repl = ... + """ + assert s1[C_r].iters[0].range.extent == 8 + assert s1[C_r].iters[1].range.extent == 8 + assert s1[C_r].iters[2].range.extent == 32 + assert s1[C_r].iters[3].range.extent == 16 + assert s1[C].iters[0].range.extent == 8 + assert s1[C].iters[1].range.extent == 8 + assert s1[C].iters[2].range.extent == 32 + + s2 = s0.copy() + C_r = s2.rfactor(C, ki, 2) + """ + Placeholder: A, B + for i (0,8) + for j (0,8) + for k_i (0,16) + for k_o (0,32) + C.rf = ... + for ax0 (0,8) + for ax1 (0,8) + for k_i_v (0,16) + C.repl = ... + """ + assert s2[C_r].iters[0].range.extent == 8 + assert s2[C_r].iters[1].range.extent == 8 + assert s2[C_r].iters[2].range.extent == 16 + assert s2[C_r].iters[3].range.extent == 32 + assert s2[C].iters[0].range.extent == 8 + assert s2[C].iters[1].range.extent == 8 + assert s2[C].iters[2].range.extent == 16 + + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_compute_at_root_inline() test_cache_read_write() test_follow_split_follow_fused_split() + test_rfactor() diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index 39d01e0..e65f191 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -24,7 +24,28 @@ import tempfile from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul -def test_record(): +def record_common(dag, s): + target = tvm.target.create("llvm") + task = auto_scheduler.SearchTask(dag, "test", target) + + inp = auto_scheduler.measure.MeasureInput(task, s) + res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) + + with tempfile.NamedTemporaryFile() as fp: + auto_scheduler.save_records(fp.name, [inp], [res]) + + log_reader = auto_scheduler.RecordReader(fp.name) + inputs, results = log_reader.read_lines() + assert len(inputs) == 1 + + s1 = dag.infer_bound_from_state(s) + s2 = dag.infer_bound_from_state(inputs[0].state) + + assert s1 == s2 + assert not (s1 == dag.get_init_state()) + + +def test_record_split_reorder_fuse_annotation(): if not tvm.runtime.enabled("llvm"): return @@ -32,16 +53,8 @@ def test_record(): B = te.placeholder((512, 512), name='B') k = te.reduce_axis((0, 512), name='k') C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') - D = topi.nn.relu(C) - k = te.reduce_axis((0, 512), name='k') - E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E') - F = topi.nn.relu(E) - k = te.reduce_axis((0, 512), name='k') - G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G') - H = topi.nn.relu(G) - I = topi.nn.relu(H) - dag = auto_scheduler.ComputeDAG([A, B, I]) + dag = auto_scheduler.ComputeDAG([A, B, C]) s = dag.get_init_state() # Split @@ -52,13 +65,6 @@ def test_record(): its1[3]]) # Fuse s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]]) - # Compute at - s.split(F, s[F].iters[0], [2]) - s.compute_at(E, F, s[F].iters[0]) - # Compute inline - s.compute_inline(D) - # Compute root - s.compute_root(D) # Parallel s.parallel(C, s[C].iters[0]) # Thread bind(The blockIdx & threadIdx are used in GPU, just for record testing here) @@ -69,46 +75,93 @@ def test_record(): s.unroll(C, s[C].iters[4]) # Vectorize s.vectorize(C, s[C].iters[6]) - # Cache Read - D_global = s.cache_read(D, "global", [E]) - s.compute_at(D_global, E, s[E].iters[2]) + + record_common(dag, s) + + +def test_record_compute_at_root_inline_cache_read_write(): + if not tvm.runtime.enabled("llvm"): + return + + A = te.placeholder((512, 512), name='A') + AA = topi.nn.relu(A) + B = te.placeholder((512, 512), name='B') + k = te.reduce_axis((0, 512), name='k') + C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name='C') + + dag = auto_scheduler.ComputeDAG([A, B, C]) + s = dag.get_init_state() + # Cache Write - s.cache_write(D, "shared") - #follow_split - its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True) + C_shared = s.cache_write(C, "shared") + # Compute At + s.compute_at(C_shared, C, s[C].iters[0]) + # Cache Read + B_global = s.cache_read(B, "global", [C_shared]) + s.compute_at(B_global, C_shared, s[C_shared].iters[2]) + # Compute Inline + s.compute_inline(AA) + # Compute Root + s.compute_root(C_shared) + + record_common(dag, s) + + +def test_record_follow_split_follow_fused_split(): + if not tvm.runtime.enabled("llvm"): + return + + A = te.placeholder((512, 512), name='A') + B = te.placeholder((512, 512), name='B') + k = te.reduce_axis((0, 512), name='k') + C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') + D = topi.nn.relu(C) + E = topi.nn.relu(D) + + dag = auto_scheduler.ComputeDAG([A, B, E]) + s = dag.get_init_state() + + # Follow Split + s.split(C, s[C].iters[0], [4, 2, 8, 4], True) split_step0 = len(s.transform_steps) - 1 - s.follow_split(G, s[G].iters[5], split_step0, 4) - #follow_fused_split - its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True) + s.follow_split(C, s[C].iters[5], split_step0, 4) + # Follow Fused Split + its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True) split_step1 = len(s.transform_steps) - 1 - its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True) + its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True) split_step2 = len(s.transform_steps) - 1 its = [] - for i0, i1 in zip(its2, its3): + for i0, i1 in zip(its0, its1): its.append(i0) its.append(i1) for i in range(0, 5): - s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]]) - s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False) + s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]]) + s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True) - target = tvm.target.create("llvm") - task = auto_scheduler.SearchTask(dag, "test", target) + record_common(dag, s) - inp = auto_scheduler.measure.MeasureInput(task, s) - res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1) - with tempfile.NamedTemporaryFile() as fp: - auto_scheduler.save_records(fp.name, [inp], [res]) +def test_record_pragma_storage_align_rfactor(): + if not tvm.runtime.enabled("llvm"): + return - log_reader = auto_scheduler.RecordReader(fp.name) - inputs, results = log_reader.read_lines() - assert len(inputs) == 1 + A = te.placeholder((512, 512), name='A') + B = te.placeholder((512, 512), name='B') + k = te.reduce_axis((0, 512), name='k') + C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') - s1 = dag.infer_bound_from_state(s) - s2 = dag.infer_bound_from_state(inputs[0].state) + dag = auto_scheduler.ComputeDAG([A, B, C]) + s = dag.get_init_state() - assert s1 == s2 - assert not (s1 == dag.get_init_state()) + # Rfactor + ko, _ = s.split(C, s[C].iters[2], [16]) + s.rfactor(C, ko, 2) + # Pragma + s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64") + # StorageAlign + s.storage_align(C, s[C].iters[-1], 8, 4) + + record_common(dag, s) def test_measure_local_builder_runner(): @@ -149,6 +202,9 @@ def test_measure_local_builder_rpc_runner(): if __name__ == "__main__": - test_record() + test_record_split_reorder_fuse_annotation() + test_record_compute_at_root_inline_cache_read_write() + test_record_follow_split_follow_fused_split() + test_record_pragma_storage_align_rfactor() test_measure_local_builder_runner() test_measure_local_builder_rpc_runner()