ExprPtr value = v->value();
ExprPtr new_value = value->accept_mutator(this);
if (v->indices().size() == 1 && value == new_value) {
- return (StmtPtr)v;
+ return v;
}
- return alloc<Store>(
- v->buf(),
- std::vector<ExprPtr>({flatten_index(v->buf()->dims(), v->indices())}),
- new_value);
+ std::vector<ExprPtr> indices = {
+ flatten_index(v->buf()->dims(), v->indices())};
+ v->set_indices(indices);
+ v->set_value(new_value);
+ return v;
}
};
ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
newIndices.push_back(sub);
}
-
- return alloc<Load>(cache_, newIndices);
+ v->set_buf(cache_);
+ v->set_indices(newIndices);
+ return v;
}
StmtPtr mutate(StorePtr v) override {
ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
newIndices.push_back(sub);
}
-
- return alloc<Store>(cache_, newIndices, newValue);
+ v->set_buf(cache_);
+ v->set_indices(newIndices);
+ v->set_value(newValue);
+ return v;
}
BufPtr buf_;
// Replace acceses to the producer in the consumer with the cache.
CacheReplacer replacer(producer, tmp_buf, info.start);
- // TODO: Can we reuse 'consumer' below without cloning?
- StmtPtr new_consumer =
- IRSimplifier::simplify(Stmt::clone(consumer)->accept_mutator(&replacer));
+ consumer->accept_mutator(&replacer);
// replace the old consumer with the replaced consumer.
- BlockPtr consumer_block = nullptr;
+ BlockPtr consumer_block = to<Block>(consumer);
+ BlockPtr parent_block = to<Block>(consumer->get_parent());
// if the consumer is a block, we should mutate it in place.
- if ((consumer_block = to<Block>(consumer))) {
- consumer_block->clear();
- consumer_block->append_stmt(new_consumer);
- } else {
- consumer_block = to<Block>(consumer->get_parent());
- assert(consumer_block);
- consumer_block->replace_stmt(consumer, new_consumer);
- }
+ bool is_block = consumer_block != nullptr;
// If there's a reduction and we are operating on the reduce axis, we need to
// initialize the cache with 0s. Also, we can't just write the result straight
alloc<For>(new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_init);
}
- consumer_block->insert_stmt_before(tmp_init, new_consumer);
+ if (is_block) {
+ consumer_block->prepend_stmt(tmp_init);
+ } else {
+ parent_block->insert_stmt_before(tmp_init, consumer);
+ }
// Reduce back to the original buffer:
StmtPtr tmp_store = alloc<Store>(
new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
- consumer_block->insert_stmt_after(tmp_store, new_consumer);
+ if (is_block) {
+ consumer_block->append_stmt(tmp_store);
+ } else {
+ parent_block->insert_stmt_after(tmp_store, consumer);
+ }
- return std::make_pair(tmp_buf, new_consumer);
+ return std::make_pair(tmp_buf, consumer);
}
if (hasReads) {
new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
- consumer_block->insert_stmt_before(tmp_store, new_consumer);
+ if (is_block) {
+ consumer_block->prepend_stmt(tmp_store);
+ } else {
+ parent_block->insert_stmt_before(tmp_store, consumer);
+ }
}
if (hasWrites) {
new_loop_vars[i], alloc<IntImm>(0), tmp_dims[i], tmp_store);
}
- consumer_block->insert_stmt_after(tmp_store, new_consumer);
+ if (is_block) {
+ consumer_block->append_stmt(tmp_store);
+ } else {
+ parent_block->insert_stmt_after(tmp_store, consumer);
+ }
}
- return std::make_pair(tmp_buf, new_consumer);
+ return std::make_pair(tmp_buf, consumer);
}
/*