class AffineBound;
class IntegerSet;
class AffineCondition;
+class OperationStmt;
+
+/// The operand of a Terminator contains a StmtBlock.
+using StmtBlockOperand = IROperandImpl<StmtBlock, OperationStmt>;
/// Operation statements represent operations inside ML functions.
class OperationStmt final
: public Operation,
public Statement,
- private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult> {
+ private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult,
+ StmtBlockOperand, unsigned> {
public:
/// Create a new OperationStmt with the specific fields.
- static OperationStmt *create(Location location, OperationName name,
- ArrayRef<MLValue *> operands,
- ArrayRef<Type> resultTypes,
- ArrayRef<NamedAttribute> attributes,
- MLIRContext *context);
+ static OperationStmt *
+ create(Location location, OperationName name, ArrayRef<MLValue *> operands,
+ ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
+ ArrayRef<StmtBlock *> successors, MLIRContext *context);
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
return {result_type_begin(), result_type_end()};
}
+ //===--------------------------------------------------------------------===//
+ // Terminators
+ //===--------------------------------------------------------------------===//
+
+ MutableArrayRef<StmtBlockOperand> getBlockOperands() {
+ assert(isTerminator() && "Only terminators have a block operands list");
+ return {getTrailingObjects<StmtBlockOperand>(), numSuccs};
+ }
+ ArrayRef<StmtBlockOperand> getBlockOperands() const {
+ return const_cast<OperationStmt *>(this)->getBlockOperands();
+ }
+
+ MutableArrayRef<StmtOperand> getSuccessorOperands(unsigned index) {
+ assert(isTerminator() && "Only terminators have successors");
+ assert(index < getNumSuccessors());
+ unsigned succOpIndex = getSuccessorOperandIndex(index);
+ auto *operandBegin = getStmtOperands().data() + succOpIndex;
+ return {operandBegin, getNumSuccessorOperands(index)};
+ }
+ ArrayRef<StmtOperand> getSuccessorOperands(unsigned index) const {
+ return const_cast<OperationStmt *>(this)->getSuccessorOperands(index);
+ }
+
+ unsigned getNumSuccessors() const { return numSuccs; }
+ unsigned getNumSuccessorOperands(unsigned index) const {
+ assert(isTerminator() && "Only terminators have successors");
+ assert(index < getNumSuccessors());
+ return getTrailingObjects<unsigned>()[index];
+ }
+
+ StmtBlock *getSuccessor(unsigned index) {
+ assert(index < getNumSuccessors());
+ return getBlockOperands()[index].get();
+ }
+ const StmtBlock *getSuccessor(unsigned index) const {
+ return const_cast<OperationStmt *>(this)->getSuccessor(index);
+ }
+ void setSuccessor(BasicBlock *block, unsigned index);
+
+ /// Get the index of the first operand of the successor at the provided
+ /// index.
+ unsigned getSuccessorOperandIndex(unsigned index) const {
+ assert(isTerminator() && "Only terminators have successors.");
+ assert(index < getNumSuccessors());
+
+ // Count the number of operands for each of the successors after, and
+ // including, the one at 'index'. This is based upon the assumption that all
+ // non successor operands are placed at the beginning of the operand list.
+ auto *successorOpCountBegin = getTrailingObjects<unsigned>();
+ unsigned postSuccessorOpCount =
+ std::accumulate(successorOpCountBegin + index,
+ successorOpCountBegin + getNumSuccessors(), 0);
+ return getNumOperands() - postSuccessorOpCount;
+ }
+
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
}
private:
- const unsigned numOperands, numResults;
+ const unsigned numOperands, numResults, numSuccs;
OperationStmt(Location location, OperationName name, unsigned numOperands,
- unsigned numResults, ArrayRef<NamedAttribute> attributes,
- MLIRContext *context);
+ unsigned numResults, unsigned numSuccessors,
+ ArrayRef<NamedAttribute> attributes, MLIRContext *context);
~OperationStmt();
// This stuff is used by the TrailingObjects template.
- friend llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult>;
+ friend llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult,
+ StmtBlockOperand, unsigned>;
size_t numTrailingObjects(OverloadToken<StmtOperand>) const {
return numOperands;
}
size_t numTrailingObjects(OverloadToken<StmtResult>) const {
return numResults;
}
+ size_t numTrailingObjects(OverloadToken<StmtBlockOperand>) const {
+ return numSuccs;
+ }
+ size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
};
/// A ForStmtBody represents statements contained within a ForStmt.
//===----------------------------------------------------------------------===//
// StmtResult
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
/// Return the result number of this result.
unsigned StmtResult::getResultNumber() const {
//===----------------------------------------------------------------------===//
// StmtOperand
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
/// Return which operand this is in the operand list.
template <> unsigned StmtOperand::getOperandNumber() const {
//===----------------------------------------------------------------------===//
// Statement
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
// Statements are deleted through the destroy() member because we don't have
// a virtual destructor.
bool Statement::emitError(const Twine &message) const {
return getContext()->emitError(getLoc(), message);
}
+
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
ArrayRef<MLValue *> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
+ ArrayRef<StmtBlock *> successors,
MLIRContext *context) {
- auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
- resultTypes.size());
+ unsigned numSuccessors = successors.size();
+ auto byteSize =
+ totalSizeToAlloc<StmtOperand, StmtResult, StmtBlockOperand, unsigned>(
+ operands.size(), resultTypes.size(), numSuccessors, numSuccessors);
void *rawMem = malloc(byteSize);
// Initialize the OperationStmt part of the statement.
- auto stmt = ::new (rawMem) OperationStmt(
- location, name, operands.size(), resultTypes.size(), attributes, context);
-
- // Initialize the operands and results.
- auto stmtOperands = stmt->getStmtOperands();
- for (unsigned i = 0, e = operands.size(); i != e; ++i)
- new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+ auto stmt = ::new (rawMem)
+ OperationStmt(location, name, operands.size(), resultTypes.size(),
+ numSuccessors, attributes, context);
+ // Initialize the results and operands.
auto stmtResults = stmt->getStmtResults();
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+
+ auto stmtOperands = stmt->getStmtOperands();
+
+ // Initialize normal operands.
+ unsigned operandIt = 0, operandE = operands.size();
+ unsigned nextOperand = 0;
+ for (; operandIt != operandE; ++operandIt) {
+ // Null operands are used as sentinals between successor operand lists. If
+ // we encounter one here, break and handle the successor operands lists
+ // separately below.
+ if (!operands[operandIt])
+ break;
+ new (&stmtOperands[nextOperand++]) StmtOperand(stmt, operands[operandIt]);
+ }
+
+ unsigned currentSuccNum = 0;
+ if (operandIt == operandE) {
+ // Verify that the amount of sentinal operands is equivalent to the number
+ // of successors.
+ assert(currentSuccNum == numSuccessors);
+
+ return stmt;
+ }
+
+ assert(stmt->isTerminator() &&
+ "Sentinal operand found in non terminator operand list.");
+ auto instBlockOperands = stmt->getBlockOperands();
+ unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>();
+ unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
+ (void)succOperandCountE;
+
+ for (; operandIt != operandE; ++operandIt) {
+ // If we encounter a sentinal branch to the next operand update the count
+ // variable.
+ if (!operands[operandIt]) {
+ assert(currentSuccNum < numSuccessors);
+
+ // After the first iteration update the successor operand count
+ // variable.
+ if (currentSuccNum != 0) {
+ ++succOperandCountIt;
+ assert(succOperandCountIt != succOperandCountE &&
+ "More sentinal operands than successors.");
+ }
+
+ new (&instBlockOperands[currentSuccNum])
+ StmtBlockOperand(stmt, successors[currentSuccNum]);
+ *succOperandCountIt = 0;
+ ++currentSuccNum;
+ continue;
+ }
+ new (&stmtOperands[nextOperand++]) StmtOperand(stmt, operands[operandIt]);
+ ++(*succOperandCountIt);
+ }
+
+ // Verify that the amount of sentinal operands is equivalent to the number of
+ // successors.
+ assert(currentSuccNum == numSuccessors);
+
return stmt;
}
OperationStmt::OperationStmt(Location location, OperationName name,
unsigned numOperands, unsigned numResults,
+ unsigned numSuccessors,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
: Operation(/*isInstruction=*/false, name, attributes, context),
Statement(Kind::Operation, location), numOperands(numOperands),
- numResults(numResults) {}
+ numResults(numResults), numSuccs(numSuccessors) {}
OperationStmt::~OperationStmt() {
// Explicitly run the destructors for the operands and results.
};
SmallVector<MLValue *, 8> operands;
- operands.reserve(getNumOperands());
- for (auto *opValue : getOperands())
- operands.push_back(remapOperand(opValue));
-
+ SmallVector<StmtBlock *, 2> successors;
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
+ operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
+
+ if (!opStmt->isTerminator()) {
+ // Non-terminators just add all the operands.
+ for (auto *opValue : getOperands())
+ operands.push_back(remapOperand(opValue));
+ } else {
+ // We add the operands separated by nullptr's for each successor.
+ unsigned firstSuccOperand = opStmt->getNumSuccessors()
+ ? opStmt->getSuccessorOperandIndex(0)
+ : opStmt->getNumOperands();
+ auto stmtOperands = opStmt->getStmtOperands();
+
+ unsigned i = 0;
+ for (; i != firstSuccOperand; ++i)
+ operands.push_back(remapOperand(stmtOperands[i].get()));
+
+ successors.reserve(opStmt->getNumSuccessors());
+ for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e;
+ ++succ) {
+ successors.push_back(
+ const_cast<StmtBlock *>(opStmt->getSuccessor(succ)));
+
+ // Add sentinel to delineate successor operands.
+ operands.push_back(nullptr);
+
+ // Remap the successors operands.
+ for (auto &operand : opStmt->getSuccessorOperands(succ))
+ operands.push_back(remapOperand(operand.get()));
+ }
+ }
+
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
- auto *newOp =
- OperationStmt::create(getLoc(), opStmt->getName(), operands,
- resultTypes, opStmt->getAttrs(), context);
+ auto *newOp = OperationStmt::create(getLoc(), opStmt->getName(), operands,
+ resultTypes, opStmt->getAttrs(),
+ successors, context);
// Remember the mapping of any results.
for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
operandMap[opStmt->getResult(i)] = newOp->getResult(i);
return newOp;
}
+ operands.reserve(getNumOperands());
+ for (auto *opValue : getOperands())
+ operands.push_back(remapOperand(opValue));
+
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();