*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
+ * \brief Schedule primitive corresponds to `te.Stage.pragma`.
+ * The step is appended to the transform steps of this state and then applied to the state.
+ * \param stage_id The index of the stage to add pragma.
+ * \param it The iterator to add pragma. It is located among the iterators of the target stage.
+ * \param pragma_type The pragma string. When applied to a State, only "debug_skip_region" and
+ * strings starting with "auto_unroll_max_step" (the step value follows a '$' character) are
+ * accepted; any other string is reported as a fatal error.
+ */
+ TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
+ /*!
* \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);
+ /*!
+ * \brief Schedule primitive corresponds to `te.Stage.storage_align`.
+ * The step is appended to the transform steps of this state and then applied to the state.
+ * \param stage_id The index of the stage to be aligned.
+ * \param it The iterator to be aligned. It is located among the iterators of the target stage.
+ * \param factor The factor in alignment specification.
+ * \param offset The offset in the alignment specification. On the State side this value is
+ * recorded as the `storage_offset` attribute of the stage.
+ */
+ TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
/********** Step APIs working on multiple stages **********/
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
- int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
- const ComputeDAG& dag);
+ TVM_DLL int cache_read(int stage_id, const String& scope_name,
+ const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
* \param stage_id The index of the stage to be cache write.
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
- int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+ TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+ /*!
+ * \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
+ * \param stage_id The index of the stage to be factored.
+ * \param it The iterator to be factored.
+ * \param factor_iter_id The position where the new iterator is placed.
+ * \param dag The original ComputeDAG of this state.
+ * \return The index of the new added stage.
+ * \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the
+ * target stage), an up-to-date ComputeDAG is stored in State's `current_compute_dag`.
+ */
+ TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
};
+/*! \brief Pragma step that corresponds to te::Stage::pragma */
+class PragmaStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to add pragma. */
+  int iter_id;
+  /*! \brief The pragma string. */
+  String pragma_type;
+
+  /*!
+   * \brief Write this step to a JSON record.
+   * \param writer The output JSONWriter.
+   */
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to state, which will be updated.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*! \brief The prefix that identifies this step type in a JSON record. */
+  static constexpr const char* record_prefix_str = "PR";
+
+  static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
+  // The declared parent must be the direct base class so that runtime type checks against
+  // StepNode succeed for this node.
+  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
+};
+
+/*!
+ * \brief Managed reference to PragmaStepNode.
+ * \sa PragmaStepNode
+ */
+class PragmaStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to add pragma.
+ * \param iter_id The index of the iterator to add pragma.
+ * \param pragma_type The pragma string.
+ */
+ PragmaStep(int stage_id, int iter_id, String pragma_type);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader. The stage index, the iterator index and the pragma
+ * string are read from it, in that order.
+ */
+ explicit PragmaStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
+};
+
/*! \brief Reorder step that corresponds to te::Stage::reorder */
class ReorderStepNode : public StepNode {
public:
/*!
* \brief Extract split lengths.
* \param transform_steps An array record all transform steps.
- * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ * \return The multiple split factors.
*/
- void ExtractSplitLengths(const Array<Step>& transform_steps,
- Array<Optional<Integer>>* lengths) const;
+ Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;
/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to state, which will be updated.
+ * \return The iterator results after split.
*/
Array<Iterator> ApplyToState(State* state) const;
TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
};
+/*! \brief Storage align step that corresponds to te::Stage::storage_align */
+class StorageAlignStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be aligned. */
+  int iter_id;
+  /*! \brief The factor in alignment specification. */
+  int factor;
+  /*! \brief The offset in the alignment specification. */
+  int offset;
+
+  /*!
+   * \brief Write this step to a JSON record.
+   * \param writer The output JSONWriter.
+   */
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State, which will be updated.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*! \brief The prefix that identifies this step type in a JSON record. */
+  static constexpr const char* record_prefix_str = "SA";
+
+  static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
+  // The declared parent must be the direct base class so that runtime type checks against
+  // StepNode succeed for this node.
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
+};
+
+/*!
+ * \brief Managed reference to StorageAlignStepNode.
+ * \sa StorageAlignStepNode
+ */
+class StorageAlignStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be aligned.
+ * \param iter_id The index of the iterator to be aligned.
+ * \param factor The factor in alignment specification.
+ * \param offset The offset in the alignment specification.
+ */
+ StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader. The stage index, the iterator index, the factor and
+ * the offset are read from it, in that order.
+ */
+ explicit StorageAlignStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
+};
+
/********** Steps working on multiple stages **********/
/*! \brief Compute at step that corresponds to te::Stage::compute_at */
TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
/*!
* \brief Cache read step that corresponds to te::Schedule::cache_read.
TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
};
+/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
+class RfactorStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be factored. */
+  int iter_id;
+  /*! \brief The position where the new iterator is placed. */
+  int factor_iter_id;
+
+  /*!
+   * \brief Write this step to a JSON record.
+   * \param writer The output JSONWriter.
+   */
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State, which will be updated.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the new added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensors of the new added stage.
+   */
+  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                    te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages The `te::Stage`s used in TVM scheduler applying.
+   * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  /*! \brief The prefix that identifies this step type in a JSON record. */
+  static constexpr const char* record_prefix_str = "RF";
+
+  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
+  // The declared parent must be the direct base class so that runtime type checks against
+  // StepNode succeed for this node.
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
+};
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode
+ */
+class RfactorStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be factored.
+ * \param iter_id The index of the iterator to be factored.
+ * \param factor_iter_id The position where the new iterator is placed.
+ */
+ RfactorStep(int stage_id, int iter_id, int factor_iter_id);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader. The stage index, the iterator index and the position
+ * of the new iterator are read from it, in that order.
+ */
+ explicit RfactorStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
+};
+
} // namespace auto_scheduler
} // namespace tvm
self._resolve_stage_id(stage), iters)
return res
+ def pragma(self, stage, iterator, pragma_type):
+ """ Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more
+ details.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to add pragma, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to add pragma.
+ pragma_type : str
+ The pragma string.
+ """
+ self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
+ iterator, pragma_type)
+
def reorder(self, stage, order):
""" Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
details.
factor_or_nparts)
return res
+ def storage_align(self, stage, iterator, factor, offset):
+ """ Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for
+ more details.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be storage aligned, which can be specified by the integer index,
+ Operation, or output tensor of the stage.
+ iterator : Iterator
+ The iterator to be aligned.
+ factor : int
+ The factor in alignment specification.
+ offset : int
+ The offset in the alignment specification.
+ """
+ self.state_object = _ffi_api.StateStorageAlign(self.state_object,
+ self._resolve_stage_id(stage), iterator,
+ factor, offset)
+
def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
+ def rfactor(self, stage, iterator, factor_iter_id):
+ """ Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for
+ more details.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be factored, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The reduction iterator to be factored.
+ factor_iter_id : int
+ The position where the new iterator is placed.
+
+ Returns
+ -------
+ new_stage_op : Operator
+ The Operator of the new added stage.
+
+ Notes
+ -----
+ Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the
+ target stage).
+ """
+ self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator, factor_iter_id,
+ self.compute_dag)
+ # Add a new stage will change all ops behind the added stage. But we still want to keep the
+ # original ops map, apply stage id offset to stage_id_map to make them work.
+ self._apply_stage_id_offset(int(new_stage_id))
+ self._update_stage_id_map()
+ return self.stages[int(new_stage_id)].op
+
def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
return step->ApplyToState(this);
}
+// Append a PragmaStep for iterator `it` of stage `stage_id` to the transform steps, then apply
+// it to this state.
+void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
+  // The step stores the position of the iterator inside its stage, not the iterator itself.
+  const auto iter_pos = GetIndex(operator->()->stages[stage_id]->iters, it);
+  PragmaStep step(stage_id, iter_pos, pragma_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
+}
+
void State::reorder(int stage_id, const Array<Iterator>& order) {
const Stage& stage = operator->()->stages[stage_id];
CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
return step->ApplyToState(this);
}
+// Append a StorageAlignStep for iterator `it` of stage `stage_id` to the transform steps, then
+// apply it to this state.
+void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
+  // The step stores the position of the iterator inside its stage, not the iterator itself.
+  const auto iter_pos = GetIndex(operator->()->stages[stage_id]->iters, it);
+  StorageAlignStep step(stage_id, iter_pos, factor, offset);
+  CopyOnWrite()->transform_steps.push_back(step);
+  step->ApplyToState(this);
+}
+
void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
return step->ApplyToState(this, dag);
}
+// Append an RfactorStep for iterator `it` of stage `stage_id` to the transform steps, apply it
+// to this state and return what the step application returns.
+int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
+  // The step stores the position of the iterator inside its stage, not the iterator itself.
+  const auto iter_pos = GetIndex(operator->()->stages[stage_id]->iters, it);
+  RfactorStep step(stage_id, iter_pos, factor_iter_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this, dag);
+}
+
void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
return Array<ObjectRef>{state, res};
});
+// FFI entry point for State::pragma. The state is received by value, the pragma step is
+// applied to that copy, and the updated state is returned to the caller.
+TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
+ state.pragma(stage_id, it, pragma_type);
+ return state;
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
.set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
state.reorder(stage_id, order);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});
+// FFI entry point for State::storage_align. The state is received by value, the storage align
+// step is applied to that copy, and the updated state is returned to the caller.
+TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
+ state.storage_align(stage_id, it, factor, offset);
+ return state;
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {
return Array<ObjectRef>{state, Integer(res)};
});
+// FFI entry point for State::rfactor. The state is received by value and the rfactor step is
+// applied to that copy. The result is a two-element array: the updated state followed by the
+// integer returned by State::rfactor.
+TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
+ const ComputeDAG& dag) {
+ int res = state.rfactor(stage_id, it, factor_iter_id, dag);
+ return Array<ObjectRef>{state, Integer(res)};
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
return std::equal_to<State>()(state1, state2);
});
return AnnotationStep(reader);
} else if (name == FuseStepNode::record_prefix_str) {
return FuseStep(reader);
+ } else if (name == PragmaStepNode::record_prefix_str) {
+ return PragmaStep(reader);
} else if (name == ReorderStepNode::record_prefix_str) {
return ReorderStep(reader);
} else if (name == SplitStepNode::record_prefix_str) {
return FollowSplitStep(reader);
} else if (name == FollowFusedSplitStepNode::record_prefix_str) {
return FollowFusedSplitStep(reader);
+ } else if (name == StorageAlignStepNode::record_prefix_str) {
+ return StorageAlignStep(reader);
} else if (name == ComputeAtStepNode::record_prefix_str) {
return ComputeAtStep(reader);
} else if (name == ComputeInlineStepNode::record_prefix_str) {
return CacheReadStep(reader);
} else if (name == CacheWriteStepNode::record_prefix_str) {
return CacheWriteStep(reader);
+ } else if (name == RfactorStepNode::record_prefix_str) {
+ return RfactorStep(reader);
} else {
LOG(FATAL) << "Invalid step format: " << name;
}
ps->ApplyToState(state);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<PragmaStepNode>()) {
+ ps->ApplyToState(state);
} else if (auto ps = step.as<ReorderStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<StorageAlignStepNode>()) {
+ ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
ps->ApplyToState(state, dag);
} else if (auto ps = step.as<CacheWriteStepNode>()) {
ps->ApplyToState(state, dag);
+ } else if (auto ps = step.as<RfactorStepNode>()) {
+ ps->ApplyToState(state, dag);
} else {
LOG(FATAL) << "Invalid step: " << step;
}
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<PragmaStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ReorderStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<StorageAlignStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes, schedule);
} else if (auto ps = step.as<CacheWriteStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes, schedule);
+ } else if (auto ps = step.as<RfactorStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, schedule);
} else {
LOG(FATAL) << "Invalid Step: " << step;
}
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<PragmaStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ReorderStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<StorageAlignStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
} else if (auto ps = step.as<CacheWriteStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
+ } else if (auto ps = step.as<RfactorStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
} else {
LOG(FATAL) << "Invalid Step: " << step;
}
return ss.str();
}
+/********** Pragma **********/
+// Build a pragma step from explicit field values.
+PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) {
+  auto step_node = make_object<PragmaStepNode>();
+  step_node->pragma_type = std::move(pragma_type);
+  step_node->iter_id = iter_id;
+  step_node->stage_id = stage_id;
+  data_ = std::move(step_node);
+}
+
+// Build a pragma step from a JSON record: stage_id, iter_id and the pragma string are read as
+// consecutive array items.
+PragmaStep::PragmaStep(dmlc::JSONReader* reader) {
+  // Move to the next array item; a record that ends early is a fatal error.
+  const auto next_item = [reader]() {
+    bool s = reader->NextArrayItem();
+    CHECK(s);
+  };
+
+  auto step_node = make_object<PragmaStepNode>();
+  next_item();
+  reader->Read(&step_node->stage_id);
+  next_item();
+  reader->Read(&step_node->iter_id);
+  next_item();
+  // The pragma is read into a std::string first and then moved into the String field.
+  std::string pragma_text;
+  reader->Read(&pragma_text);
+  step_node->pragma_type = std::move(pragma_text);
+  data_ = std::move(step_node);
+}
+
+// Serialize this step as consecutive JSON array items: the record prefix, stage_id, iter_id and
+// the pragma string. The order of the fields after the prefix mirrors the order in which
+// PragmaStep(dmlc::JSONReader*) reads them back.
+void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ // Strings are written with an explicit separator followed by WriteString.
+ writer->WriteArraySeperator();
+ writer->WriteString(pragma_type);
+}
+
+// Reflect the pragma on the State. Two pragmas are understood:
+//  - "debug_skip_region": the stage is removed from the attach map;
+//  - "auto_unroll_max_step...$<value>": <value> is stored in the stage attributes.
+// Any other pragma string is a fatal error.
+void PragmaStepNode::ApplyToState(State* state) const {
+  if (pragma_type == "debug_skip_region") {
+    state->CopyOnWrite()->attach_map.DeleteStage(stage_id);
+  } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    StateNode* mutable_state = state->CopyOnWrite();
+    Stage target = mutable_state->stages[stage_id];
+    // Scan for the '$' separator; the max step value is the text right after it.
+    const char* text = pragma_type.c_str();
+    size_t pos = 0;
+    while (pos < pragma_type.size() && text[pos] != '$') {
+      ++pos;
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    target.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(text + pos + 1);
+    mutable_state->stages.Set(stage_id, std::move(target));
+  } else {
+    LOG(FATAL) << "Unsupported pragma: " << pragma_type;
+  }
+}
+
+// Apply the pragma to the te::Stage. An "auto_unroll_max_step...$<value>" pragma is expanded
+// into two pragmas ("auto_unroll_max_step" with <value>, and "unroll_explicit"); every other
+// pragma string is forwarded unchanged.
+void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                     StageToAxesMap* stage_to_axes) const {
+  te::Stage target = (*stages)[stage_id];
+  const Array<IterVar>& stage_axes = (*stage_to_axes)[target];
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    // Scan for the '$' separator; the max step value is the text right after it.
+    const char* text = pragma_type.c_str();
+    size_t pos = 0;
+    while (pos < pragma_type.size() && text[pos] != '$') {
+      ++pos;
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(text + pos + 1);
+    target.pragma(stage_axes[iter_id], "auto_unroll_max_step", value);
+    target.pragma(stage_axes[iter_id], "unroll_explicit", true);
+  } else {
+    target.pragma(stage_axes[iter_id], pragma_type);
+  }
+  // Write the updated stage back into the array.
+  stages->Set(stage_id, std::move(target));
+}
+
+// Emit the python schedule code equivalent to this step, then apply the step to the te
+// stages so that the stages stay in sync with the printed code.
+String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  const auto& target = (*stages)[stage_id];
+  // Both names are needed by every emitted line, so compute them once.
+  const auto op_name = CleanName(target->op->name);
+  const auto axis_name = CleanName((*stage_to_axes)[target][iter_id]->var->name_hint);
+
+  std::stringstream code;
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    // Scan for the '$' separator; the max step value is the text right after it.
+    const char* text = pragma_type.c_str();
+    size_t pos = 0;
+    while (pos < pragma_type.size() && text[pos] != '$') {
+      ++pos;
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(text + pos + 1);
+    code << "s[" << op_name << "].pragma(" << axis_name << ", \"auto_unroll_max_step\", "
+         << value << ")\n";
+    code << "s[" << op_name << "].pragma(" << axis_name << ", \"unroll_explicit\", True)\n";
+  } else {
+    code << "s[" << op_name << "].pragma(" << axis_name << ", \"" << pragma_type << "\")\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return code.str();
+}
+
/********** Reorder **********/
ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
auto node = make_object<ReorderStepNode>();
writer->WriteArrayItem(n_split);
}
-void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
- Array<Optional<Integer>>* lengths) const {
+Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
+ const Array<Step>& transform_steps) const {
// Make sure src_step_id is within the range of transform_steps.
CHECK_LT(src_step_id, transform_steps.size());
auto ps = transform_steps[src_step_id].as<SplitStepNode>();
CHECK_LE(n_split, ps->lengths.size() + 1);
CHECK(ps != nullptr);
- lengths->reserve(n_split);
+ Array<Optional<Integer>> lengths;
+ lengths.reserve(n_split);
int j = 0;
// Get the first (n_split-1) split factors of followed src_step.
for (; j < n_split - 1; ++j) {
- lengths->push_back(ps->lengths[j]);
+ lengths.push_back(ps->lengths[j]);
}
// Get the last split factor of src_step for splitting level if n_split is smaller than
}
}
if (last_factor.defined()) {
- lengths->push_back(Downcast<Integer>(last_factor));
+ lengths.push_back(Downcast<Integer>(last_factor));
} else {
- lengths->push_back(NullOpt);
+ lengths.push_back(NullOpt);
}
+
+ return lengths;
}
FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->n_split);
-
data_ = std::move(node);
}
Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
- Array<Optional<Integer>> lengths;
- ExtractSplitLengths((*state)->transform_steps, &lengths);
- return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+ return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps),
+ true);
}
Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const {
- Array<Optional<Integer>> lengths;
- ExtractSplitLengths(transform_steps, &lengths);
- return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+ ExtractSplitLengths(transform_steps), true);
}
String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const {
- Array<Optional<Integer>> lengths;
- ExtractSplitLengths(transform_steps, &lengths);
- return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+ ExtractSplitLengths(transform_steps), true);
}
/********** Follow Fused Split **********/
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&node->factor_or_nparts);
-
::tvm::Array<::tvm::Integer> src_step_ids;
for (const auto& i : int_list) {
src_step_ids.push_back(i);
}
Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
- const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
- return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+ return ApplySplitToState(state, stage_id, iter_id,
+ {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts);
}
Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const {
- const Optional<Integer>& length = ExtractSplitLength(transform_steps);
- return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+ {ExtractSplitLength(transform_steps)}, factor_or_nparts);
}
String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const {
- const Optional<Integer>& length = ExtractSplitLength(transform_steps);
- return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
- factor_or_nparts);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+ {ExtractSplitLength(transform_steps)}, factor_or_nparts);
+}
+
+/********** Storage Align **********/
+// Build a storage align step from explicit field values.
+StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) {
+  auto step_node = make_object<StorageAlignStepNode>();
+  step_node->offset = offset;
+  step_node->factor = factor;
+  step_node->iter_id = iter_id;
+  step_node->stage_id = stage_id;
+  data_ = std::move(step_node);
+}
+
+// Build a storage align step from a JSON record: stage_id, iter_id, factor and offset are
+// read as consecutive array items.
+StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) {
+  // Move to the next array item; a record that ends early is a fatal error.
+  const auto next_item = [reader]() {
+    bool s = reader->NextArrayItem();
+    CHECK(s);
+  };
+
+  auto step_node = make_object<StorageAlignStepNode>();
+  next_item();
+  reader->Read(&step_node->stage_id);
+  next_item();
+  reader->Read(&step_node->iter_id);
+  next_item();
+  reader->Read(&step_node->factor);
+  next_item();
+  reader->Read(&step_node->offset);
+  data_ = std::move(step_node);
+}
+
+// Serialize this step as consecutive JSON array items: the record prefix, stage_id, iter_id,
+// factor and offset. The order of the fields after the prefix mirrors the order in which
+// StorageAlignStep(dmlc::JSONReader*) reads them back.
+void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(factor);
+ writer->WriteArrayItem(offset);
+}
+
+// Reflect the step on the State: only the offset is recorded, as the `storage_offset`
+// attribute of the target stage.
+void StorageAlignStepNode::ApplyToState(State* state) const {
+  StateNode* mutable_state = state->CopyOnWrite();
+  Stage aligned = mutable_state->stages[stage_id];
+  aligned.CopyOnWrite()->attrs.storage_offset = offset;
+  // Write the updated stage back into the state.
+  mutable_state->stages.Set(stage_id, std::move(aligned));
+}
+
+// Call storage_align on the te::Stage for the axis at position `iter_id`, then write the
+// stage back into the array.
+void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) const {
+  te::Stage target = (*stages)[stage_id];
+  const Array<IterVar>& stage_axes = (*stage_to_axes)[target];
+  const auto aligned_axis = stage_axes[iter_id];
+  target.storage_align(aligned_axis, factor, offset);
+  stages->Set(stage_id, std::move(target));
+}
+
+// Emit the python schedule code equivalent to this step, then apply the step to the te
+// stages so that the stages stay in sync with the printed code.
+String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes) const {
+  const auto& target = (*stages)[stage_id];
+  const auto op_name = CleanName(target->op->name);
+  const auto axis_name = CleanName((*stage_to_axes)[target][iter_id]->var->name_hint);
+
+  std::stringstream code;
+  code << "s[" << op_name << "].storage_align(" << axis_name << ", " << factor << ", "
+       << offset << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return code.str();
}
/********** Steps working on multiple stages **********/
return ss.str();
}
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
/*!
* \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,
*/
Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
Array<Step> ret_steps;
- for (const Step& step : transform_steps) {
+ for (size_t i = 0; i < transform_steps.size(); ++i) {
+ const Step& step = transform_steps[i];
if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
ret_steps.push_back(step);
+ } else if (step->IsInstance<RfactorStepNode>()) {
+ // add FuseStepNode required by rfactor
+ if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
+ const Step& fuse_step = transform_steps[i - 2];
+ if (fuse_step->stage_id == step->stage_id) {
+ ret_steps.push_back(fuse_step);
+ }
+ }
+ // add SplitStepNode required by rfactor
+ CHECK_GE(i, 1);
+ CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
+ const Step& split_step = transform_steps[i - 1];
+ CHECK_EQ(split_step->stage_id, step->stage_id);
+ ret_steps.push_back(split_step);
+ // add RfactorStepNode
+ ret_steps.push_back(step);
}
- // TODO(jcf94): add rfactor support
// A state may have multiple stage modifiable steps, stop by the current step to avoid
// replaying excess steps
if (step.same_as(current_step)) {
return ss.str();
}
+/********** Rfactor **********/
+/*!
+ * \brief Build an rfactor step from its three integer fields.
+ * \param stage_id Index of the stage to be factored (used to index the stage array).
+ * \param iter_id Index, within that stage's axes, of the axis handed to rfactor.
+ * \param factor_iter_id Forwarded as the third argument of the schedule's `rfactor` call.
+ */
+RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) {
+ auto node = make_object<RfactorStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->factor_iter_id = factor_iter_id;
+ data_ = std::move(node);
+}
+
+/*!
+ * \brief Build an rfactor step by reading its fields from a JSON array.
+ * \param reader The JSON reader, positioned so that the next three array items are
+ * `stage_id`, `iter_id` and `factor_iter_id`, in that order (the order WriteToRecord emits).
+ * \note Each `NextArrayItem` result is CHECKed, so a record with fewer items fails the check.
+ * NOTE(review): the record prefix string is not read here; presumably the caller has
+ * already consumed it - confirm against the dispatch code.
+ */
+RfactorStep::RfactorStep(dmlc::JSONReader* reader) {
+ auto node = make_object<RfactorStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->factor_iter_id);
+ data_ = std::move(node);
+}
+
+/*!
+ * \brief Serialize this step into a JSON array.
+ * \param writer The JSON writer to append to.
+ * \note Writes an array separator, then `record_prefix_str` as a string, then `stage_id`,
+ * `iter_id` and `factor_iter_id` as array items.
+ */
+void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(factor_iter_id);
+}
+
+/*!
+ * \brief Apply the rfactor step to a State, inserting the extra stage that rfactor creates.
+ * \param state The State to update; it is modified through `CopyOnWrite`.
+ * \param dag The ComputeDAG on which the stage-modifying steps up to this one are replayed, to
+ * obtain the ops of the new stage and of the updated stages.
+ * \return `stage_id`, which is the position of the newly inserted stage.
+ */
+int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  // Take the compute_at kind by value: the stage array is mutated below, so no reference into
+  // the old target stage is kept alive across those mutations.
+  const auto compute_at_type = pstate->stages[stage_id]->compute_at;
+  // Keep the replayed DAG in a non-const local so that the std::move at the end is a real move
+  // instead of silently degrading to a copy.
+  ComputeDAG current_compute_dag = dag.ReplayAndGetDAG(
+      GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
+
+  // target_stage -> rfactor_compute + target_stage
+  // Insert a new compute stage, update the target stage and later stage, then update the stage_id
+  // mapping in AttachMap
+  pstate->stages.insert(pstate->stages.begin() + stage_id,
+                        Stage(current_compute_dag->ops[stage_id]));
+  // Maintain the compute_at type of the target stage
+  Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]);
+  target_stage.CopyOnWrite()->compute_at = compute_at_type;
+  pstate->stages.Set(stage_id + 1, std::move(target_stage));
+  // Stages after the target keep their state but are re-pointed at the ops of the replayed DAG.
+  for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) {
+    Stage stage = pstate->stages[i];
+    stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return stage_id;
+}
+
+/*!
+ * \brief Apply the rfactor step to a `te::Schedule` and register the stage it creates.
+ * \param stages The stage array; the new stage is inserted at position `stage_id`.
+ * \param stage_to_axes Maps stages to their axes; refreshed for the target and the new stage.
+ * \param schedule The schedule on which `rfactor` is called.
+ * \return The tensors returned by the schedule's `rfactor` call.
+ */
+Array<te::Tensor> RfactorStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ const auto& stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = (*stage_to_axes)[stage];
+
+ // rfactor is applied to the first output of the stage's original op, along axes[iter_id].
+ const te::Tensor& tensor = stage->origin_op.output(0);
+ const IterVar& axis = axes[iter_id];
+ auto outs = schedule->rfactor(tensor, axis, factor_iter_id);
+
+ // Refresh the axes recorded for the target stage, then record and insert the stage that
+ // owns the first rfactor output.
+ UpdateStageToAxesMap(stage, stage_to_axes);
+ const auto& new_stage = (*schedule)[outs[0]->op];
+ UpdateStageToAxesMap(new_stage, stage_to_axes);
+ stages->insert(stages->begin() + stage_id, new_stage);
+
+ return outs;
+}
+
+/*!
+ * \brief Print this rfactor step as Python schedule code, applying it to the schedule as well.
+ * \param stages The stage array; updated by ApplyToSchedule.
+ * \param stage_to_axes Maps stages to their axes; updated by ApplyToSchedule.
+ * \param schedule The schedule the step is applied to.
+ * \return Python text made of: an `<outs> = s.rfactor(<tensor>, <axis>, <factor_iter_id>)` line,
+ * one line per rfactor output binding its root iterator names to `op.axis + op.reduce_axis`,
+ * and one such line for the stage at `stage_id + 1` after the step is applied.
+ */
+String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ te::Schedule* schedule) const {
+ std::stringstream ss;
+ const auto& stage = (*stages)[stage_id];
+
+ // Capture the printed names before ApplyToSchedule changes the stages and axes.
+ const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name);
+ const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint);
+
+ const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+ // Left-hand side: the comma-separated names of the rfactor outputs.
+ for (size_t i = 0; i < outs.size(); ++i) {
+ ss << CleanName(outs[i]->op->name);
+ if (i != outs.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n";
+
+ // For every rfactor output, bind its root iterator names to its axes and reduce axes.
+ for (const auto& out : outs) {
+ const auto& iters = out->op->root_iter_vars();
+ for (size_t i = 0; i < iters.size(); ++i) {
+ ss << CleanName(iters[i]->var->name_hint);
+ if (i != iters.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+ << " + "
+ << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+ }
+
+ // The new stage was inserted at stage_id, so stage_id + 1 is the stage that followed it;
+ // its iterator names are re-bound as well, this time through `s[...]`.
+ const auto& output = (*stages)[stage_id + 1]->op.output(0);
+ const auto& iters = output->op->root_iter_vars();
+ for (size_t i = 0; i < iters.size(); ++i) {
+ ss << CleanName(iters[i]->var->name_hint);
+ if (i != iters.size() - 1) {
+ ss << ", ";
+ }
+ }
+ ss << " = "
+ << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+ << " + "
+ << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+
+ return ss.str();
+}
+
} // namespace auto_scheduler
} // namespace tvm
return sum / float_array.size();
}
+/*!
+ * \brief Test whether one string begins with another.
+ * \param a The string to inspect.
+ * \param b The candidate prefix.
+ * \return true when the first `b.size()` characters of `a` equal `b`. An empty `b` always
+ * matches; a `b` longer than `a` never does.
+ */
+inline bool StrStartsWith(const String& a, const String& b) {
+  const auto prefix_len = b.size();
+  return prefix_len <= a.size() && std::equal(b.c_str(), b.c_str() + prefix_len, a.c_str());
+}
+
/********** Other Utilities **********/
/*! \brief Get an int value from an Expr */
inline int64_t GetIntImm(const PrimExpr& expr) {
assert tmp[C].iters[level + 1].range.extent == \
tmp[C_global].iters[1].range.extent
+
+def test_rfactor():
+ """Check the iterator extents of the rfactor stage and of the target stage after rfactor."""
+ A, B, C = matmul_auto_scheduler_test(8, 8, 512)
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ s0 = dag.get_init_state()
+
+ # Split iterator 2 of C by 16; the asserts below expect extents 32 (ko) and 16 (ki).
+ ko, ki = s0.split(C, s0[C].iters[2], [16])
+
+ # Case 1: rfactor on the outer iterator ko. The string below is the expected state.
+ s1 = s0.copy()
+ C_r = s1.rfactor(C, ko, 2)
+ """
+ Placeholder: A, B
+ for i (0,8)
+ for j (0,8)
+ for k_o (0,32)
+ for k_i (0,16)
+ C.rf = ...
+ for ax0 (0,8)
+ for ax1 (0,8)
+ for k_o_v (0,32)
+ C.repl = ...
+ """
+ assert s1[C_r].iters[0].range.extent == 8
+ assert s1[C_r].iters[1].range.extent == 8
+ assert s1[C_r].iters[2].range.extent == 32
+ assert s1[C_r].iters[3].range.extent == 16
+ assert s1[C].iters[0].range.extent == 8
+ assert s1[C].iters[1].range.extent == 8
+ assert s1[C].iters[2].range.extent == 32
+
+ # Case 2: rfactor on the inner iterator ki. The string below is the expected state.
+ s2 = s0.copy()
+ C_r = s2.rfactor(C, ki, 2)
+ """
+ Placeholder: A, B
+ for i (0,8)
+ for j (0,8)
+ for k_i (0,16)
+ for k_o (0,32)
+ C.rf = ...
+ for ax0 (0,8)
+ for ax1 (0,8)
+ for k_i_v (0,16)
+ C.repl = ...
+ """
+ assert s2[C_r].iters[0].range.extent == 8
+ assert s2[C_r].iters[1].range.extent == 8
+ assert s2[C_r].iters[2].range.extent == 16
+ assert s2[C_r].iters[3].range.extent == 32
+ assert s2[C].iters[0].range.extent == 8
+ assert s2[C].iters[1].range.extent == 8
+ assert s2[C].iters[2].range.extent == 16
+
+
if __name__ == "__main__":
test_split_fuse_reorder_annotation()
test_compute_at_root_inline()
test_cache_read_write()
test_follow_split_follow_fused_split()
+ test_rfactor()
from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
-def test_record():
+def record_common(dag, s):
+ """Save state `s` as a measure record, read it back, and check it round-trips.
+
+ The state read from the record must equal `s` after bound inference on `dag`,
+ and must differ from the initial state of `dag`.
+ """
+ target = tvm.target.create("llvm")
+ task = auto_scheduler.SearchTask(dag, "test", target)
+
+ # A single input/result pair is enough to exercise record writing and reading.
+ inp = auto_scheduler.measure.MeasureInput(task, s)
+ res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
+
+ with tempfile.NamedTemporaryFile() as fp:
+ auto_scheduler.save_records(fp.name, [inp], [res])
+
+ log_reader = auto_scheduler.RecordReader(fp.name)
+ inputs, results = log_reader.read_lines()
+ assert len(inputs) == 1
+
+ # Compare after bound inference on both the original and the restored state.
+ s1 = dag.infer_bound_from_state(s)
+ s2 = dag.infer_bound_from_state(inputs[0].state)
+
+ assert s1 == s2
+ assert not (s1 == dag.get_init_state())
+
+
+def test_record_split_reorder_fuse_annotation():
if not tvm.runtime.enabled("llvm"):
return
B = te.placeholder((512, 512), name='B')
k = te.reduce_axis((0, 512), name='k')
C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
- D = topi.nn.relu(C)
- k = te.reduce_axis((0, 512), name='k')
- E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
- F = topi.nn.relu(E)
- k = te.reduce_axis((0, 512), name='k')
- G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
- H = topi.nn.relu(G)
- I = topi.nn.relu(H)
- dag = auto_scheduler.ComputeDAG([A, B, I])
+ dag = auto_scheduler.ComputeDAG([A, B, C])
s = dag.get_init_state()
# Split
its1[3]])
# Fuse
s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
- # Compute at
- s.split(F, s[F].iters[0], [2])
- s.compute_at(E, F, s[F].iters[0])
- # Compute inline
- s.compute_inline(D)
- # Compute root
- s.compute_root(D)
# Parallel
s.parallel(C, s[C].iters[0])
# Thread bind(The blockIdx & threadIdx are used in GPU, just for record testing here)
s.unroll(C, s[C].iters[4])
# Vectorize
s.vectorize(C, s[C].iters[6])
- # Cache Read
- D_global = s.cache_read(D, "global", [E])
- s.compute_at(D_global, E, s[E].iters[2])
+
+ record_common(dag, s)
+
+
+def test_record_compute_at_root_inline_cache_read_write():
+ """Round-trip a state using cache_write, compute_at, cache_read, compute_inline and
+ compute_root through record_common."""
+ # Skipped when the llvm runtime is not enabled.
+ if not tvm.runtime.enabled("llvm"):
+ return
+
+ A = te.placeholder((512, 512), name='A')
+ AA = topi.nn.relu(A)
+ B = te.placeholder((512, 512), name='B')
+ k = te.reduce_axis((0, 512), name='k')
+ C = te.compute((512, 512), lambda i, j: te.sum(AA[i][k] * B[k][j], axis=[k]), name='C')
+
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ s = dag.get_init_state()
+
# Cache Write
- s.cache_write(D, "shared")
- #follow_split
- its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True)
+ C_shared = s.cache_write(C, "shared")
+ # Compute At
+ s.compute_at(C_shared, C, s[C].iters[0])
+ # Cache Read
+ B_global = s.cache_read(B, "global", [C_shared])
+ s.compute_at(B_global, C_shared, s[C_shared].iters[2])
+ # Compute Inline
+ s.compute_inline(AA)
+ # Compute Root
+ s.compute_root(C_shared)
+
+ record_common(dag, s)
+
+
+def test_record_follow_split_follow_fused_split():
+ """Round-trip a state using follow_split and follow_fused_split through record_common."""
+ # Skipped when the llvm runtime is not enabled.
+ if not tvm.runtime.enabled("llvm"):
+ return
+
+ A = te.placeholder((512, 512), name='A')
+ B = te.placeholder((512, 512), name='B')
+ k = te.reduce_axis((0, 512), name='k')
+ C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+ D = topi.nn.relu(C)
+ E = topi.nn.relu(D)
+
+ dag = auto_scheduler.ComputeDAG([A, B, E])
+ s = dag.get_init_state()
+
+ # Follow Split
+ s.split(C, s[C].iters[0], [4, 2, 8, 4], True)
+ # Index of the split step just recorded; follow_split refers to it below.
split_step0 = len(s.transform_steps) - 1
- s.follow_split(G, s[G].iters[5], split_step0, 4)
- #follow_fused_split
- its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True)
+ s.follow_split(C, s[C].iters[5], split_step0, 4)
+ # Follow Fused Split
+ its0 = s.split(E, s[E].iters[0], [4, 2, 8, 4], True)
split_step1 = len(s.transform_steps) - 1
- its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True)
+ its1 = s.split(E, s[E].iters[5], [2, 4, 2, 4], True)
split_step2 = len(s.transform_steps) - 1
its = []
- for i0, i1 in zip(its2, its3):
+ for i0, i1 in zip(its0, its1):
its.append(i0)
its.append(i1)
for i in range(0, 5):
- s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]])
- s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False)
+ s.fuse(E, [s[E].iters[i], s[E].iters[i + 1]])
+ # Both recorded split steps are passed to follow_fused_split as its source steps.
+ s.follow_fused_split(D, s[D].iters[0], [split_step1, split_step2], 2, True)
- target = tvm.target.create("llvm")
- task = auto_scheduler.SearchTask(dag, "test", target)
+ record_common(dag, s)
- inp = auto_scheduler.measure.MeasureInput(task, s)
- res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
- with tempfile.NamedTemporaryFile() as fp:
- auto_scheduler.save_records(fp.name, [inp], [res])
+def test_record_pragma_storage_align_rfactor():
+ """Round-trip a state using rfactor, pragma and storage_align through record_common."""
+ # Skipped when the llvm runtime is not enabled.
+ if not tvm.runtime.enabled("llvm"):
+ return
- log_reader = auto_scheduler.RecordReader(fp.name)
- inputs, results = log_reader.read_lines()
- assert len(inputs) == 1
+ A = te.placeholder((512, 512), name='A')
+ B = te.placeholder((512, 512), name='B')
+ k = te.reduce_axis((0, 512), name='k')
+ C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
- s1 = dag.infer_bound_from_state(s)
- s2 = dag.infer_bound_from_state(inputs[0].state)
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ s = dag.get_init_state()
- assert s1 == s2
- assert not (s1 == dag.get_init_state())
+ # Rfactor
+ # The iterator is split first; rfactor is then applied to the outer part.
+ ko, _ = s.split(C, s[C].iters[2], [16])
+ s.rfactor(C, ko, 2)
+ # Pragma
+ s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
+ # StorageAlign
+ s.storage_align(C, s[C].iters[-1], 8, 4)
+
+ record_common(dag, s)
def test_measure_local_builder_runner():
if __name__ == "__main__":
- test_record()
+ test_record_split_reorder_fuse_annotation()
+ test_record_compute_at_root_inline_cache_read_write()
+ test_record_follow_split_follow_fused_split()
+ test_record_pragma_storage_align_rfactor()
test_measure_local_builder_runner()
test_measure_local_builder_rpc_runner()