verify(root_stmt_);
}
-const std::unordered_set<BufPtr> LoopNest::getIntermediateBufs() const {
+std::unordered_set<BufPtr> LoopNest::getIntermediateBufs() const {
std::unordered_set<BufPtr> result;
auto input_bufs = getInputBufs();
auto bufs = NodeFinder<Buf>::find(root_stmt_);
return b;
}
-StmtPtr LoopNest::insertAllocFree(StmtPtr stmt) {
- auto intermediate_bufs = getIntermediateBufs();
+StmtPtr LoopNest::insertAllocFree(
+ StmtPtr stmt,
+ const c10::optional<std::unordered_set<BufPtr>>&
+ interm_bufs /* = c10::nullopt*/) {
+ std::unordered_set<BufPtr> intermediate_bufs;
+ if (interm_bufs) {
+ intermediate_bufs = *interm_bufs;
+ } else {
+ intermediate_bufs = getIntermediateBufs();
+ }
+
if (intermediate_bufs.size() == 0ULL) {
return stmt;
}
root_stmt_ = root_stmt_->accept_mutator(&deleter);
}
-void LoopNest::prepareForCodegen() {
+void LoopNest::prepareForCodegen(
+ const c10::optional<std::unordered_set<BufPtr>>&
+ interm_bufs /*= c10::nullopt*/) {
// Expand reduction ops.
ReductionExpander reduceExpander;
root_stmt_ = reduceExpander.expand(root_stmt_);
root_stmt_ = FlattenIndexes(root_stmt_);
// Add allocs and frees for intermediate buffers at the global level.
- root_stmt_ = insertAllocFree(root_stmt_);
+ root_stmt_ = insertAllocFree(root_stmt_, interm_bufs);
}
namespace {
void vectorizeInnerLoops();
void eliminateDeadStores();
- void prepareForCodegen();
+
+ // Make the stmt ready for codegen. The optional argument 'interm_bufs' allows
+ // users to specify intermediate buffers that need runtime allocation. In
+ // default, we will insert 'Alloc/Free' stmts to allocate all intermediate
+ // buffers at runtime but users may have pre-allocated some of them at compile
+ // time, and in that case the user can specify what buffers to insert
+ // 'Alloc/Free' stmts for using 'interm_bufs'.
+ // TODO: refactor function 'prepareForCodegen' to remove argument
+ // 'interm_bufs'.
+ void prepareForCodegen(
+ const c10::optional<std::unordered_set<BufPtr>>& interm_bufs =
+ c10::nullopt);
const std::unordered_set<BufPtr> getInputBufs() const;
const std::unordered_set<BufPtr> getOutputBufs() const {
return output_bufs_;
}
+ std::unordered_set<BufPtr> getIntermediateBufs() const;
private:
void initialize(
const std::vector<Tensor>& output_tensors,
const std::vector<Tensor>& tensors_to_compute);
- StmtPtr insertAllocFree(StmtPtr stmt);
- const std::unordered_set<BufPtr> getIntermediateBufs() const;
+ StmtPtr insertAllocFree(
+ StmtPtr stmt,
+ const c10::optional<std::unordered_set<BufPtr>>& interm_bufs =
+ c10::nullopt);
StmtPtr root_stmt_;
return std::make_unique<LoopNest>(s, buf_nodes);
}))
.def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops)
- .def("prepare_for_codegen", &LoopNest::prepareForCodegen)
+ .def(
+ "prepare_for_codegen",
+ [](LoopNest& self) { return self.prepareForCodegen(); },
+ py::return_value_policy::reference)
.def(
"get_loop_body_for",
[](const LoopNest& self, Tensor t) { return self.getLoopBodyFor(t); },