TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
+ /*!
+ * \brief The schedule primitive similar to split, but uses the split factors from a previous
+ * split step.
+ * \param stage_id The index of the stage to be split.
+ * \param it The iterator to be split.
+ * \param src_step_id The index of the split step to be followed in the history.
+ * \param n_split The number of split levels.
+ * \return The new Iterators after splitting.
+ */
+ TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
+ int n_split);
+ /*!
+ * \brief The schedule primitive similar to split, but uses the split factors from previous
+ * split and fuse steps.
+ * \param stage_id The index of the stage to be split.
+ * \param it The iterator to be split.
+ * \param src_step_ids The indices of the split steps to be followed in the history.
+ * \param level Use the length at this split level.
+ * \param factor_or_nparts True to use `factor` for splitting from inner to outer,
+ * False to use `nparts` for splitting from outer to inner.
+ * \return The new Iterators after splitting.
+ */
+ TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts);
/********** Step APIs working on multiple stages **********/
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
* `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
+ * \param transform_steps An array that records all transform steps.
*/
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
- te::Schedule* schedule);
+ te::Schedule* schedule, const Array<Step>& transform_steps);
/*!
* \brief Print the step as equivalent python schedule API.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
* CacheRead/CacheWrite step)
+ * \param transform_steps An array that records all transform steps.
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes, te::Schedule* schedule);
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+ const Array<Step>& transform_steps);
/********** Steps working on single stage **********/
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses the split factors from another step
+ * (i.e. follows another split step). */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split levels. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array that records all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ * \return The new Iterators after splitting.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array that records all transform steps.
+ * \return The iterator results after split.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array that records all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_id The index of the split step to follow in the history.
+ * \param n_split The number of split levels.
+ */
+ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The indices of the split steps to follow in the history. */
+ Array<Integer> src_step_ids;
+ /*! \brief Use the length at this split level. */
+ int level;
+ /*! \brief If this is true, use factor. Otherwise, use nparts. */
+ bool factor_or_nparts;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract the split length.
+ * \param transform_steps An array that records all transform steps.
+ * \return The split factor.
+ */
+ Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to state, which will be updated.
+ * \return The iterator results after split.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array that records all transform steps.
+ * \return The iterator results after split.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages The `te::Stage`s used in TVM scheduler applying.
+ * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
+ * \param transform_steps An array that records all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FFSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_ids The indices of the split steps to follow in the history.
+ * \param level Use the length at this split level.
+ * \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
+ */
+ FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowFusedSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+};
+
/********** Steps working on multiple stages **********/
/*! \brief Compute at step that corresponds to te::Stage::compute_at */
return self.state_object.stages
@property
+ def transform_steps(self):
+ """
+ Returns
+ -------
+ transform_steps : List[Step]
+ """
+ return self.state_object.transform_steps
+
+ @property
def stage_ops(self):
"""
Returns
iterator, lengths, inner_to_outer)
return res
+
+ def follow_split(self, stage, iterator, src_step_id, n_split):
+ """ The schedule primitive similar to split, but uses split factors from a previous step.
+
+ This step splits the iterator by the same factors as the given SplitStep.
+
+ Notes
+ -----
+ This step is useful in a scenario where we have a subgraph Dense -> ReLU,
+ and we want to compute the Dense stage at the ReLU stage. In this case, the two
+ stages need to share the same tiling structure of the common outer loops.
+ The follow_split step can be used here to split the Dense stage and make sure its
+ split factors are the same as those of the given split step for the ReLU stage.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be split, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to split.
+ src_step_id : int
+ The index of the split step to follow in the history.
+ n_split : int
+ The number of split levels.
+
+ Returns
+ -------
+ res_its : List[Iterator]
+ The new Iterators after splitting.
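+
+ Examples
+ --------
+ A minimal sketch following the unit tests; `s`, `C` and `C_global` are
+ illustrative names, with `C_global` produced by an earlier cache_write:
+
+ its = s.split(C, s[C].iters[0], [4, 2, 8, 4])
+ src_step_id = len(s.transform_steps) - 1
+ # Reproduce the first 2 levels of the recorded split on C_global.
+ s.follow_split(C_global, s[C_global].iters[0], src_step_id, 2)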
+ """
+
+ self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator,
+ src_step_id, n_split)
+ return res
+
+ def follow_fused_split(self, stage, iterator, src_step_ids, level,
+ factor_or_nparts):
+ """ Schedule primitive extends to split step.
+
+ This step is used to split an iterator by the same factors
+ as the given list of SplitSteps and FuseSteps.
+
+ Notes
+ -----
+ This step is useful in a scenario where we have a subgraph in a GPU schedule:
+ Input -> Dense
+ for i.0@j.0 = ... : Bind to blockIdx.x
+ for i.1@j.1 = ... : Bind to threadIdx.x
+ for i.2@j.2 = ...
+ Input_shared = Input ...
+ for k = ...
+ Dense = ...
+ We intend to apply cooperative fetching with the input stage, while the threadIdx.x
+ axis is bound to an iterator generated by a split & fuse step.
+ The follow_fused_split step is used to split the iterator into two parts, where the
+ split factor matches the final extent of the iterator bound to threadIdx.x.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be split, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to split.
+ src_step_ids : List[int]
+ The indices of the split steps to follow in the history.
+ level : int
+ Use the length at this split level.
+ factor_or_nparts : bool
+ True to use `factor` for splitting from inner to outer,
+ False to use `nparts` for splitting from outer to inner.
+
+ Returns
+ -------
+ res_its : List[Iterator]
+ The new Iterators after splitting.
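+
+ Examples
+ --------
+ A minimal sketch following the unit tests; `s`, `C` and `C_global` are
+ illustrative names, and the two splits below are assumed to be reordered
+ and fused afterwards:
+
+ its0 = s.split(C, s[C].iters[0], [4, 2, 8, 4])
+ split_step0 = len(s.transform_steps) - 1
+ its1 = s.split(C, s[C].iters[5], [2, 2, 4, 8])
+ split_step1 = len(s.transform_steps) - 1
+ # Split C_global by the product of the level-1 factors (2 * 2 = 4).
+ s.follow_fused_split(C_global, s[C_global].iters[0],
+ [split_step0, split_step1], 1, True)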
+ """
+
+ self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator,
+ src_step_ids, level,
+ factor_or_nparts)
+ return res
+
def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
- StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
+ StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}
return std::make_pair(schedule, operator->()->tensors);
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
- ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
+ ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}
return ss.str();
return step->ApplyToState(this);
}
+Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
+ int n_split) {
+ const Stage& stage = operator->()->stages[stage_id];
+ FollowSplitStep step =
+ FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
+Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts) {
+ const Stage& stage = operator->()->stages[stage_id];
+ FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
+ src_step_ids, level, factor_or_nparts);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
return Array<ObjectRef>{state, res};
});
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
+ int n_split) {
+ const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
+ return Array<ObjectRef>{state, Array<Iterator>(res)};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
+ .set_body_typed([](State state, int stage_id, const Iterator& it,
+ const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
+ const auto& res =
+ state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
+ return Array<ObjectRef>{state, Array<Iterator>(res)};
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {
return ReorderStep(reader);
} else if (name == SplitStepNode::record_prefix_str) {
return SplitStep(reader);
+ } else if (name == FollowSplitStepNode::record_prefix_str) {
+ return FollowSplitStep(reader);
+ } else if (name == FollowFusedSplitStepNode::record_prefix_str) {
+ return FollowFusedSplitStep(reader);
} else if (name == ComputeAtStepNode::record_prefix_str) {
return ComputeAtStep(reader);
} else if (name == ComputeInlineStepNode::record_prefix_str) {
ps->ApplyToState(state);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToState(state);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
}
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
- te::Schedule* schedule) {
+ te::Schedule* schedule, const Array<Step>& transform_steps) {
if (auto ps = step.as<AnnotationStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
}
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes, te::Schedule* schedule) {
+ StageToAxesMap* stage_to_axes, te::Schedule* schedule,
+ const Array<Step>& transform_steps) {
if (auto ps = step.as<AnnotationStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
+ } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
} else if (auto ps = step.as<ComputeAtStepNode>()) {
return ps->PrintAsPythonAPI(stages, stage_to_axes);
} else if (auto ps = step.as<ComputeInlineStepNode>()) {
return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
}
+/********** Follow Split **********/
+FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
+ auto node = make_object<FollowSplitStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->src_step_id = src_step_id;
+ node->n_split = n_split;
+ data_ = std::move(node);
+}
+
+void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
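+ // The produced record is a JSON array: ["FSP", stage_id, iter_id, src_step_id, n_split].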
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(src_step_id);
+ writer->WriteArrayItem(n_split);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const {
+ // Make sure src_step_id is within the range of transform_steps.
+ CHECK_LT(src_step_id, transform_steps.size());
+ auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+ CHECK(ps != nullptr);
+
+ // Make sure the size of ps->lengths is not smaller than n_split - 1.
+ // Note that the number of actual split factors of src_step is ps->lengths.size() + 1.
+ CHECK_LE(n_split, ps->lengths.size() + 1);
+
+ lengths->reserve(n_split);
+ int j = 0;
+ // Copy the first (n_split - 1) split factors of the followed src_step.
+ for (; j < n_split - 1; ++j) {
+ lengths->push_back(ps->lengths[j]);
+ }
+
+ // If n_split is smaller than the number of split levels of src_step
+ // (i.e. ps->lengths.size() + 1), merge the remaining split factors into the last level.
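+ // For example, if the followed src_step has lengths [4, 2, 8, 4] and n_split is 3,
+ // the extracted lengths are [4, 2, 32]: the trailing factors 8 and 4 are merged
+ // into the last level.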
+ PrimExpr last_factor = 1;
+ for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+ if (ps->lengths[j]) {
+ last_factor *= ps->lengths[j].value();
+ } else {
+ last_factor = PrimExpr();
+ break;
+ }
+ }
+ if (last_factor.defined()) {
+ lengths->push_back(Downcast<Integer>(last_factor));
+ } else {
+ lengths->push_back(NullOpt);
+ }
+}
+
+FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
+ auto node = make_object<FollowSplitStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->src_step_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->n_split);
+
+ data_ = std::move(node);
+}
+
+Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths((*state)->transform_steps, &lengths);
+ return ApplySplitToState(state, stage_id, iter_id, lengths, true);
+}
+
+Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths(transform_steps, &lengths);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ Array<Optional<Integer>> lengths;
+ ExtractSplitLengths(transform_steps, &lengths);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
+ const Array<Integer>& src_step_ids, int level,
+ bool factor_or_nparts) {
+ auto node = make_object<FollowFusedSplitStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->src_step_ids = src_step_ids;
+ node->level = level;
+ node->factor_or_nparts = factor_or_nparts;
+ data_ = std::move(node);
+}
+
+FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
+ auto node = make_object<FollowFusedSplitStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::vector<int> int_list;
+ reader->Read(&int_list);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->level);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->factor_or_nparts);
+
+ ::tvm::Array<::tvm::Integer> src_step_ids;
+ for (const auto& i : int_list) {
+ src_step_ids.push_back(i);
+ }
+ node->src_step_ids = src_step_ids;
+ data_ = std::move(node);
+}
+
+void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
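+ // The produced record is a JSON array:
+ // ["FFSP", stage_id, iter_id, [src_step_ids...], level, factor_or_nparts].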
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(IntArrayToVector(src_step_ids));
+ writer->WriteArrayItem(level);
+ writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
+}
+
+Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
+ const Array<Step>& transform_steps) const {
+ PrimExpr ret(1);
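+ // For example, if the two src_steps have lengths [4, 2, 8, 4] and [2, 2, 4, 8],
+ // level = 1 extracts 2 * 2 = 4 as the split length.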
+
+ for (int src_step_id : src_step_ids) {
+ // Make sure the src_step_id is within the range of transform_steps.
+ CHECK_LT(src_step_id, transform_steps.size());
+ auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+ CHECK(ps != nullptr);
+ // Multiply the split factor at the corresponding split level of each src_step.
+ if (ps->lengths[level] && ret.defined()) {
+ ret *= ps->lengths[level].value();
+ } else {
+ return NullOpt;
+ }
+ }
+ return Downcast<Integer>(ret);
+}
+
+Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
+ const Optional<Integer>& length = ExtractSplitLength((*state)->transform_steps);
+ return ApplySplitToState(state, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+ return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, {length}, factor_or_nparts);
+}
+
+String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const {
+ const Optional<Integer>& length = ExtractSplitLength(transform_steps);
+ return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, {length},
+ factor_or_nparts);
+}
+
/********** Steps working on multiple stages **********/
/********** Compute At **********/
assert res == s1[C].iters[5]
assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
def test_compute_at_root_inline():
dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
kernel_size=7, strides=2, padding=3))
assert s0[conv].iters[5].range.extent == 7
assert s0[conv].iters[6].range.extent == 7
def test_cache_read_write():
N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, (
1, 1), (1, 1)
for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
assert it0.range == it1.range
+
+
+def test_follow_split_follow_fused_split():
+ A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+ dag = auto_scheduler.ComputeDAG([A, B, C])
+ s0 = dag.get_init_state()
+
+ C_global = s0.cache_write(C, "global")
+ its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True)
+ split_step0 = len(s0.transform_steps) - 1
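+ # follow_split with n_split = level reproduces the first `level` split factors,
+ # so the first `level` iterators of C and C_global should have equal extents.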
+ for level in range(1, 6):
+ tmp = s0.copy()
+ tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level)
+ for i in range(0, level):
+ assert tmp[C].iters[i].range.extent == \
+ tmp[C_global].iters[i].range.extent
+
+ its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8])
+ split_step1 = len(s0.transform_steps) - 1
+ its = []
+ for i0, i1 in zip(its0, its1):
+ its.append(i0)
+ its.append(i1)
+ s0.reorder(C, its)
+ for i in range(0, 5):
+ s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]])
+
+ for level in range(0, 4):
+ tmp = s0.copy()
+ tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+ [split_step0, split_step1], level, False)
+ assert tmp[C].iters[level + 1].range.extent == \
+ tmp[C_global].iters[0].range.extent
+
+ for level in range(0, 4):
+ tmp = s0.copy()
+ tmp.follow_fused_split(C_global, tmp[C_global].iters[0],
+ [split_step0, split_step1], level, True)
+ assert tmp[C].iters[level + 1].range.extent == \
+ tmp[C_global].iters[1].range.extent
+
if __name__ == "__main__":
test_split_fuse_reorder_annotation()
test_compute_at_root_inline()
test_cache_read_write()
+ test_follow_split_follow_fused_split()
from tvm import te, auto_scheduler
import tempfile
-from test_auto_scheduler_common import get_tiled_matmul
+from test_auto_scheduler_common import matmul_auto_scheduler_test, get_tiled_matmul
def test_record():
if not tvm.runtime.enabled("llvm"):
k = te.reduce_axis((0, 512), name='k')
E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
F = topi.nn.relu(E)
+ k = te.reduce_axis((0, 512), name='k')
+ G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
+ H = topi.nn.relu(G)
+ I = topi.nn.relu(H)
- dag = auto_scheduler.ComputeDAG([A, B, F])
+ dag = auto_scheduler.ComputeDAG([A, B, I])
s = dag.get_init_state()
# Split
s.compute_at(D_global, E, s[E].iters[2])
# Cache Write
s.cache_write(D, "shared")
+ # Follow Split
+ its2 = s.split(G, s[G].iters[0], [4, 2, 8, 4], True)
+ split_step0 = len(s.transform_steps) - 1
+ s.follow_split(G, s[G].iters[5], split_step0, 4)
+ # Follow Fused Split
+ its2 = s.split(H, s[H].iters[0], [4, 2, 8, 4], True)
+ split_step1 = len(s.transform_steps) - 1
+ its3 = s.split(H, s[H].iters[5], [2, 4, 2, 4], True)
+ split_step2 = len(s.transform_steps) - 1
+ its = []
+ for i0, i1 in zip(its2, its3):
+ its.append(i0)
+ its.append(i1)
+ s.reorder(H, its)
+ for i in range(0, 5):
+ s.fuse(H, [s[H].iters[i], s[H].iters[i + 1]])
+ s.follow_fused_split(I, s[I].iters[0], [split_step1, split_step2], 0, False)
target = tvm.target.create("llvm")
task = auto_scheduler.SearchTask(dag, "test", target)