/// Returns a builder for the body of a for Stmt.
static MLFuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
- return MLFuncBuilder(forStmt, forStmt->end());
+ return MLFuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
}
/// Returns the current insertion point of the builder.
assert(index < getNumSuccessors());
return getBasicBlockOperands()[index].get();
}
- BasicBlock *getSuccessor(unsigned index) const {
+ const BasicBlock *getSuccessor(unsigned index) const {
return const_cast<Instruction *>(this)->getSuccessor(index);
}
void setSuccessor(BasicBlock *block, unsigned index);
}
};
+/// A ForStmtBody represents statements contained within a ForStmt.
+class ForStmtBody : public StmtBlock {
+public:
+ explicit ForStmtBody(ForStmt *stmt)
+ : StmtBlock(StmtBlockKind::ForBody), forStmt(stmt) {
+ assert(stmt != nullptr && "ForStmtBody must have non-null parent");
+ }
+
+ ~ForStmtBody() {}
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast
+ static bool classof(const StmtBlock *block) {
+ return block->getStmtBlockKind() == StmtBlockKind::ForBody;
+ }
+
+ /// Returns the 'for' statement that contains this body.
+ ForStmt *getFor() { return forStmt; }
+ const ForStmt *getFor() const { return forStmt; }
+
+private:
+ ForStmt *forStmt;
+};
+
/// For statement represents an affine loop nest.
-class ForStmt : public Statement, public MLValue, public StmtBlock {
+class ForStmt : public Statement, public MLValue {
public:
static ForStmt *create(Location location, ArrayRef<MLValue *> lbOperands,
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
// since child statements need to be destroyed before the MLValue that this
// for stmt represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted.
- clear();
+ getBody()->clear();
}
/// Resolve base class ambiguity.
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
+ /// Get the body of the ForStmt.
+ ForStmtBody *getBody() { return &body; }
+
+ /// Get the body of the ForStmt.
+ const ForStmtBody *getBody() const { return &body; }
+
//===--------------------------------------------------------------------===//
// Bounds and step
//===--------------------------------------------------------------------===//
return ptr->getKind() == IROperandOwner::Kind::ForStmt;
}
- static bool classof(const StmtBlock *block) {
- return block->getStmtBlockKind() == StmtBlockKind::For;
- }
-
// For statement represents implicitly represents induction variable by
// inheriting from MLValue class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
}
private:
+ // The StmtBlock for the body.
+ ForStmtBody body;
+
// Affine map for the lower bound.
AffineMap lbMap;
// Affine map for the upper bound. The upper bound is exclusive.
~IfClause() {}
/// Returns the if statement that contains this clause.
- IfStmt *getIf() const { return ifStmt; }
+ const IfStmt *getIf() const { return ifStmt; }
+
+ IfStmt *getIf() { return ifStmt; }
private:
IfStmt *ifStmt;
public:
enum class StmtBlockKind {
MLFunc, // MLFunction
- For, // ForStmt
+ ForBody, // ForStmtBody
IfClause // IfClause
};
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
- Statement *getContainingStmt() const;
+ Statement *getContainingStmt();
+
+ const Statement *getContainingStmt() const {
+ return const_cast<StmtBlock *>(this)->getContainingStmt();
+ }
/// Returns the function that this statement block is part of.
/// The function is determined by traversing the chain of parent statements.
void walkForStmt(ForStmt *forStmt) {
static_cast<SubClass *>(this)->visitForStmt(forStmt);
- static_cast<SubClass *>(this)->walk(forStmt->begin(), forStmt->end());
+ auto *body = forStmt->getBody();
+ static_cast<SubClass *>(this)->walk(body->begin(), body->end());
}
void walkForStmtPostOrder(ForStmt *forStmt) {
- static_cast<SubClass *>(this)->walkPostOrder(forStmt->begin(),
- forStmt->end());
+ auto *body = forStmt->getBody();
+ static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
static_cast<SubClass *>(this)->visitForStmt(forStmt);
}
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForStmt>(commonForValue));
- return dyn_cast<ForStmt>(commonForValue);
+ return cast<ForStmt>(commonForValue)->getBody();
}
// Returns true if the ancestor operation statement of 'srcAccess' properly
// violation when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
ArrayRef<uint64_t> shifts) {
- assert(shifts.size() == forStmt.getStatements().size());
+ auto *forBody = forStmt.getBody();
+ assert(shifts.size() == forBody->getStatements().size());
unsigned s = 0;
- for (const auto &stmt : forStmt) {
+ for (const auto &stmt : *forBody) {
// A for or if stmt does not produce any def/results (that are used
// outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
// This is a naive way. If performance becomes an issue, a map can
// be used to store 'shifts' - to look up the shift for a statement in
// constant time.
- if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
- if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
+ if (auto *ancStmt = forBody->findAncestorStmtInBlock(*use.getOwner()))
+ if (shifts[s] != shifts[forBody->findStmtPosInBlock(*ancStmt)])
return false;
}
}
if (level == positions.size() - 1)
return &stmt;
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
- return getStmtAtPosition(positions, level + 1, childForStmt);
+ return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
// Clone src loop nest and insert it a the beginning of the statement block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
- MLFuncBuilder b(dstForStmt, dstForStmt->begin());
+ MLFuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
DenseMap<const MLValue *, MLValue *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
Statement *sliceStmt =
- getStmtAtPosition(positions, /*level=*/0, sliceLoopNest);
+ getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceStmt'.
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
- if (auto *forStmt = dyn_cast<ForStmt>(&block))
- liveValues.insert(forStmt, true);
+ if (auto *forStmtBody = dyn_cast<ForStmtBody>(&block))
+ liveValues.insert(forStmtBody->getFor(), true);
for (auto &stmt : block) {
// Verify that each of the operands are live.
return true;
}
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
- if (walkBlock(*forStmt))
+ if (walkBlock(*forStmt->getBody()))
return true;
}
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
- for (auto &childStmt : *forStmt)
+ for (auto &childStmt : *forStmt->getBody())
visitStatement(&childStmt);
}
os << " step " << stmt->getStep();
os << " {\n";
- print(static_cast<const StmtBlock *>(stmt));
+ print(stmt->getBody());
os.indent(numSpaces) << "}";
}
int cloneOperandIt = operands.size() - 1, operandIt = getNumOperands() - 1;
for (int succIt = getNumSuccessors() - 1, succE = 0; succIt >= succE;
--succIt) {
- successors[succIt] = getSuccessor(succIt);
+ successors[succIt] = const_cast<BasicBlock *>(getSuccessor(succIt));
// Add the successor operands in-place in reverse order.
for (unsigned i = 0, e = getNumSuccessorOperands(succIt); i != e;
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
- StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
+ body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
operands.reserve(numOperands);
}
operandMap[forStmt] = newFor;
// Recursively clone the body of the for loop.
- for (auto &subStmt : *forStmt)
- newFor->push_back(subStmt.clone(operandMap, context));
+ for (auto &subStmt : *forStmt->getBody())
+ newFor->getBody()->push_back(subStmt.clone(operandMap, context));
return newFor;
}
// Statement block
//===----------------------------------------------------------------------===//
-Statement *StmtBlock::getContainingStmt() const {
+Statement *StmtBlock::getContainingStmt() {
switch (kind) {
case StmtBlockKind::MLFunc:
return nullptr;
- case StmtBlockKind::For:
- return cast<ForStmt>(const_cast<StmtBlock *>(this));
+ case StmtBlockKind::ForBody:
+ return cast<ForStmtBody>(this)->getFor();
case StmtBlockKind::IfClause:
return cast<IfClause>(this)->getIf();
}
}
MLFunction *StmtBlock::findFunction() const {
+ // FIXME: const incorrect.
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getContainingStmt()) {
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
- if (parseStmtBlock(forStmt))
+ if (parseStmtBlock(forStmt->getBody()))
return ParseFailure;
// Reset insertion point to the current block.
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPoint(loopBodyFirstBlock);
- visitStmtBlock(forStmt);
+ visitStmtBlock(forStmt->getBody());
// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit
replaceAllMemRefUsesWith(memref, cast<MLValue>(fastMemRef),
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
- /*domStmtFilter=*/&*forStmt->begin());
+ /*domStmtFilter=*/&*forStmt->getBody()->begin());
return true;
}
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forStmt->getStep() != 1) {
- if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->begin())) {
+ if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
runOnForStmt(innerFor);
}
return;
// destination's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
StmtBlock::iterator loc) {
- dest->getStatements().splice(loc, src->getStatements());
+ dest->getBody()->getStatements().splice(loc, src->getBody()->getStatements());
}
// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
- moveLoopBody(src, dest, dest->begin());
+ moveLoopBody(src, dest, dest->getBody()->begin());
}
/// Constructs and sets new loop bounds after tiling for the case of
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *pointLoop = b.createFor(loc, 0, 0);
- pointLoop->getStatements().splice(
- pointLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+ pointLoop->getBody()->getStatements().splice(
+ pointLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+ topLoop);
newLoops[2 * width - 1 - i] = pointLoop;
topLoop = pointLoop;
if (i == 0)
MLFuncBuilder b(topLoop);
// Loop bounds will be set later.
auto *tileSpaceLoop = b.createFor(loc, 0, 0);
- tileSpaceLoop->getStatements().splice(
- tileSpaceLoop->begin(), topLoop->getBlock()->getStatements(), topLoop);
+ tileSpaceLoop->getBody()->getStatements().splice(
+ tileSpaceLoop->getBody()->begin(), topLoop->getBlock()->getStatements(),
+ topLoop);
newLoops[2 * width - i - 1] = tileSpaceLoop;
topLoop = tileSpaceLoop;
}
ForStmt *currStmt = root;
do {
band.push_back(currStmt);
- } while (currStmt->getStatements().size() == 1 &&
- (currStmt = dyn_cast<ForStmt>(&*currStmt->begin())));
+ } while (currStmt->getBody()->getStatements().size() == 1 &&
+ (currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
bands->push_back(band);
};
}
bool walkForStmtPostOrder(ForStmt *forStmt) {
- bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
+ bool hasInnerLoops =
+ walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
return true;
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+ if (unrollJamFactor == 1 || forStmt->getBody()->empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
loops.insert(forStmt);
// Setting the insertion point to the innermost loop achieves nesting.
- b.setInsertionPointToStart(loops.back());
+ b.setInsertionPointToStart(loops.back()->getBody());
if (composed == getAffineConstantExpr(0, b.getContext())) {
transfer->emitWarning(
"Redundant copy can be implemented as a vector broadcast");
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
- MLFuncBuilder bInner(forStmt, forStmt->begin());
- bInner.setInsertionPoint(forStmt, forStmt->begin());
+ auto *forBody = forStmt->getBody();
+ MLFuncBuilder bInner(forBody, forBody->begin());
+ bInner.setInsertionPoint(forBody, forBody->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
// non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
ivModTwoOp->getResult(0), AffineMap::Null(), {},
- &*forStmt->begin())) {
+ &*forStmt->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getOperation()->erase();
// Collect outgoing DMA statements - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
}
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt->begin(), *use.getOwner())) {
+ if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
- for (const auto &stmt : *forStmt) {
+ for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forStmt->getStatements().size());
+ std::vector<uint64_t> shifts(forStmt->getBody()->getStatements().size());
unsigned s = 0;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
LLVM_DEBUG(
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getStatements().splice(StmtBlock::iterator(forStmt),
- forStmt->getStatements());
+ forStmt->getBody()->getStatements());
forStmt->erase();
return true;
}
operandMap[srcForStmt] = loopChunk;
}
for (auto *stmt : stmts) {
- loopChunk->push_back(stmt->clone(operandMap, b->getContext()));
+ loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))
// method.
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
- if (forStmt->getStatements().empty())
+ if (forStmt->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
int64_t step = forStmt->getStep();
- unsigned numChildStmts = forStmt->getStatements().size();
+ unsigned numChildStmts = forStmt->getBody()->getStatements().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
// body of the 'for' stmt.
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
unsigned pos = 0;
- for (auto &stmt : *forStmt) {
+ for (auto &stmt : *forStmt->getBody()) {
auto shift = shifts[pos++];
sortedStmtGroups[shift].push_back(&stmt);
}
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
- if (unrollFactor == 1 || forStmt->getStatements().empty())
+ if (unrollFactor == 1 || forStmt->getBody()->empty())
return false;
auto lbMap = forStmt->getLowerBoundMap();
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
- MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+ MLFuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).
- StmtBlock::iterator srcBlockEnd = std::prev(forStmt->end());
+ StmtBlock::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
}
// Clone the original body of 'forStmt'.
- for (auto it = forStmt->begin(); it != std::next(srcBlockEnd); it++) {
+ for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
+ it++) {
builder.clone(*it, operandMap);
}
}