Give StmtBlocks a use-def list, and give OperationStmt's the ability to have
authorChris Lattner <clattner@google.com>
Sun, 23 Dec 2018 16:27:55 +0000 (08:27 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:35:34 +0000 (14:35 -0700)
optional successor operands when they are terminator operations.

This isn't used yet, but is part 2/n towards merging BasicBlock into StmtBlock
and Instruction into OperationStmt.

PiperOrigin-RevId: 226684636

mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/IR/StmtBlock.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Statement.cpp

index 4197e4c95e83ccd34eac75d48600376031650e64..a466559257f86b76c608fe72ac8d46bd9a131716 100644 (file)
@@ -40,6 +40,7 @@ class OpAsmParserResult;
 class OpAsmPrinter;
 class Pattern;
 class RewritePattern;
+class StmtBlock;
 class SSAValue;
 class Type;
 
@@ -209,6 +210,9 @@ struct OperationState {
   /// Successors of this operation and their respective operands.
   SmallVector<BasicBlock *, 1> successors;
 
+  // TODO: rename to successors when CFG and ML Functions are merged.
+  SmallVector<StmtBlock *, 1> successorsS;
+
 public:
   OperationState(MLIRContext *context, Location location, StringRef name)
       : context(context), location(location), name(name, context) {}
@@ -226,6 +230,16 @@ public:
         attributes(attributes.begin(), attributes.end()),
         successors(successors.begin(), successors.end()) {}
 
+  OperationState(MLIRContext *context, Location location, StringRef name,
+                 ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
+                 ArrayRef<NamedAttribute> attributes = {},
+                 ArrayRef<StmtBlock *> successors = {})
+      : context(context), location(location), name(name, context),
+        operands(operands.begin(), operands.end()),
+        types(types.begin(), types.end()),
+        attributes(attributes.begin(), attributes.end()),
+        successorsS(successors.begin(), successors.end()) {}
+
   void addOperands(ArrayRef<SSAValue *> newOperands) {
     assert(successors.empty() &&
            "Non successor operands should be added first.");
index 28f5a14540dc14852fc6eb1f56c105795b42b03d..6b8aa0355c5ae0d620adbb693fdd15841cd0829a 100644 (file)
@@ -34,19 +34,23 @@ namespace mlir {
 class AffineBound;
 class IntegerSet;
 class AffineCondition;
+class OperationStmt;
+
+/// The operand of a Terminator contains a StmtBlock.
+using StmtBlockOperand = IROperandImpl<StmtBlock, OperationStmt>;
 
 /// Operation statements represent operations inside ML functions.
 class OperationStmt final
     : public Operation,
       public Statement,
-      private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult> {
+      private llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult,
+                                    StmtBlockOperand, unsigned> {
 public:
   /// Create a new OperationStmt with the specific fields.
-  static OperationStmt *create(Location location, OperationName name,
-                               ArrayRef<MLValue *> operands,
-                               ArrayRef<Type> resultTypes,
-                               ArrayRef<NamedAttribute> attributes,
-                               MLIRContext *context);
+  static OperationStmt *
+  create(Location location, OperationName name, ArrayRef<MLValue *> operands,
+         ArrayRef<Type> resultTypes, ArrayRef<NamedAttribute> attributes,
+         ArrayRef<StmtBlock *> successors, MLIRContext *context);
 
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const;
@@ -183,6 +187,61 @@ public:
     return {result_type_begin(), result_type_end()};
   }
 
+  //===--------------------------------------------------------------------===//
+  // Terminators
+  //===--------------------------------------------------------------------===//
+
+  MutableArrayRef<StmtBlockOperand> getBlockOperands() {
+    assert(isTerminator() && "Only terminators have a block operands list");
+    return {getTrailingObjects<StmtBlockOperand>(), numSuccs};
+  }
+  ArrayRef<StmtBlockOperand> getBlockOperands() const {
+    return const_cast<OperationStmt *>(this)->getBlockOperands();
+  }
+
+  MutableArrayRef<StmtOperand> getSuccessorOperands(unsigned index) {
+    assert(isTerminator() && "Only terminators have successors");
+    assert(index < getNumSuccessors());
+    unsigned succOpIndex = getSuccessorOperandIndex(index);
+    auto *operandBegin = getStmtOperands().data() + succOpIndex;
+    return {operandBegin, getNumSuccessorOperands(index)};
+  }
+  ArrayRef<StmtOperand> getSuccessorOperands(unsigned index) const {
+    return const_cast<OperationStmt *>(this)->getSuccessorOperands(index);
+  }
+
+  unsigned getNumSuccessors() const { return numSuccs; }
+  unsigned getNumSuccessorOperands(unsigned index) const {
+    assert(isTerminator() && "Only terminators have successors");
+    assert(index < getNumSuccessors());
+    return getTrailingObjects<unsigned>()[index];
+  }
+
+  StmtBlock *getSuccessor(unsigned index) {
+    assert(index < getNumSuccessors());
+    return getBlockOperands()[index].get();
+  }
+  const StmtBlock *getSuccessor(unsigned index) const {
+    return const_cast<OperationStmt *>(this)->getSuccessor(index);
+  }
+  void setSuccessor(BasicBlock *block, unsigned index);
+
+  /// Get the index of the first operand of the successor at the provided
+  /// index.
+  unsigned getSuccessorOperandIndex(unsigned index) const {
+    assert(isTerminator() && "Only terminators have successors.");
+    assert(index < getNumSuccessors());
+
+    // Count the number of operands for each of the successors after, and
+    // including, the one at 'index'. This is based upon the assumption that all
+    // non successor operands are placed at the beginning of the operand list.
+    auto *successorOpCountBegin = getTrailingObjects<unsigned>();
+    unsigned postSuccessorOpCount =
+        std::accumulate(successorOpCountBegin + index,
+                        successorOpCountBegin + getNumSuccessors(), 0);
+    return getNumOperands() - postSuccessorOpCount;
+  }
+
   //===--------------------------------------------------------------------===//
   // Other
   //===--------------------------------------------------------------------===//
@@ -199,21 +258,26 @@ public:
   }
 
 private:
-  const unsigned numOperands, numResults;
+  const unsigned numOperands, numResults, numSuccs;
 
   OperationStmt(Location location, OperationName name, unsigned numOperands,
-                unsigned numResults, ArrayRef<NamedAttribute> attributes,
-                MLIRContext *context);
+                unsigned numResults, unsigned numSuccessors,
+                ArrayRef<NamedAttribute> attributes, MLIRContext *context);
   ~OperationStmt();
 
   // This stuff is used by the TrailingObjects template.
-  friend llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult>;
+  friend llvm::TrailingObjects<OperationStmt, StmtOperand, StmtResult,
+                               StmtBlockOperand, unsigned>;
   size_t numTrailingObjects(OverloadToken<StmtOperand>) const {
     return numOperands;
   }
   size_t numTrailingObjects(OverloadToken<StmtResult>) const {
     return numResults;
   }
+  size_t numTrailingObjects(OverloadToken<StmtBlockOperand>) const {
+    return numSuccs;
+  }
+  size_t numTrailingObjects(OverloadToken<unsigned>) const { return numSuccs; }
 };
 
 /// A ForStmtBody represents statements contained within a ForStmt.
index a9a35e564c77acc7324a428a041eb9115e173a47..c9f2638d56d2f38a87dd818e9a316dfd50970678 100644 (file)
@@ -32,7 +32,7 @@ class MLValue;
 /// Statement block represents an ordered list of statements, with the order
 /// being the contiguous lexical order in which the statements appear as
 /// children of a parent statement in the ML Function.
-class StmtBlock {
+class StmtBlock : public IRObjectWithUseList {
 public:
   enum class StmtBlockKind {
     MLFunc,  // MLFunction
index a6391777ea7ae337f38f9a9c88d044ee727e5ad7..d3c59493718008018b674471cd172a86a03e986b 100644 (file)
@@ -312,8 +312,9 @@ OperationStmt *MLFuncBuilder::createOperation(const OperationState &state) {
   for (auto elt : state.operands)
     operands.push_back(cast<MLValue>(elt));
 
-  auto *op = OperationStmt::create(state.location, state.name, operands,
-                                   state.types, state.attributes, context);
+  auto *op =
+      OperationStmt::create(state.location, state.name, operands, state.types,
+                            state.attributes, state.successorsS, context);
   block->getStatements().insert(insertPoint, op);
   return op;
 }
@@ -324,7 +325,7 @@ OperationStmt *MLFuncBuilder::createOperation(Location location,
                                               ArrayRef<MLValue *> operands,
                                               ArrayRef<Type> types,
                                               ArrayRef<NamedAttribute> attrs) {
-  auto *op = OperationStmt::create(location, name, operands, types, attrs,
+  auto *op = OperationStmt::create(location, name, operands, types, attrs, {},
                                    getContext());
   block->getStatements().insert(insertPoint, op);
   return op;
index f63c76605de865216668b9ef1dfff35054ea3d2d..0f7e90bec3201fcf05670acdee316793ec4ba622 100644 (file)
@@ -29,7 +29,7 @@ using namespace mlir;
 
 //===----------------------------------------------------------------------===//
 // StmtResult
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 
 /// Return the result number of this result.
 unsigned StmtResult::getResultNumber() const {
@@ -40,7 +40,7 @@ unsigned StmtResult::getResultNumber() const {
 
 //===----------------------------------------------------------------------===//
 // StmtOperand
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 
 /// Return which operand this is in the operand list.
 template <> unsigned StmtOperand::getOperandNumber() const {
@@ -49,7 +49,7 @@ template <> unsigned StmtOperand::getOperandNumber() const {
 
 //===----------------------------------------------------------------------===//
 // Statement
-//===------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 
 // Statements are deleted through the destroy() member because we don't have
 // a virtual destructor.
@@ -171,6 +171,7 @@ void Statement::emitWarning(const Twine &message) const {
 bool Statement::emitError(const Twine &message) const {
   return getContext()->emitError(getLoc(), message);
 }
+
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -249,33 +250,93 @@ OperationStmt *OperationStmt::create(Location location, OperationName name,
                                      ArrayRef<MLValue *> operands,
                                      ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
+                                     ArrayRef<StmtBlock *> successors,
                                      MLIRContext *context) {
-  auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
-                                                            resultTypes.size());
+  unsigned numSuccessors = successors.size();
+  auto byteSize =
+      totalSizeToAlloc<StmtOperand, StmtResult, StmtBlockOperand, unsigned>(
+          operands.size(), resultTypes.size(), numSuccessors, numSuccessors);
   void *rawMem = malloc(byteSize);
 
   // Initialize the OperationStmt part of the statement.
-  auto stmt = ::new (rawMem) OperationStmt(
-      location, name, operands.size(), resultTypes.size(), attributes, context);
-
-  // Initialize the operands and results.
-  auto stmtOperands = stmt->getStmtOperands();
-  for (unsigned i = 0, e = operands.size(); i != e; ++i)
-    new (&stmtOperands[i]) StmtOperand(stmt, operands[i]);
+  auto stmt = ::new (rawMem)
+      OperationStmt(location, name, operands.size(), resultTypes.size(),
+                    numSuccessors, attributes, context);
 
+  // Initialize the results and operands.
   auto stmtResults = stmt->getStmtResults();
   for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
     new (&stmtResults[i]) StmtResult(resultTypes[i], stmt);
+
+  auto stmtOperands = stmt->getStmtOperands();
+
+  // Initialize normal operands.
+  unsigned operandIt = 0, operandE = operands.size();
+  unsigned nextOperand = 0;
+  for (; operandIt != operandE; ++operandIt) {
+    // Null operands are used as sentinals between successor operand lists. If
+    // we encounter one here, break and handle the successor operands lists
+    // separately below.
+    if (!operands[operandIt])
+      break;
+    new (&stmtOperands[nextOperand++]) StmtOperand(stmt, operands[operandIt]);
+  }
+
+  unsigned currentSuccNum = 0;
+  if (operandIt == operandE) {
+    // Verify that the amount of sentinal operands is equivalent to the number
+    // of successors.
+    assert(currentSuccNum == numSuccessors);
+
+    return stmt;
+  }
+
+  assert(stmt->isTerminator() &&
+         "Sentinal operand found in non terminator operand list.");
+  auto instBlockOperands = stmt->getBlockOperands();
+  unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>();
+  unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
+  (void)succOperandCountE;
+
+  for (; operandIt != operandE; ++operandIt) {
+    // If we encounter a sentinal branch to the next operand update the count
+    // variable.
+    if (!operands[operandIt]) {
+      assert(currentSuccNum < numSuccessors);
+
+      // After the first iteration update the successor operand count
+      // variable.
+      if (currentSuccNum != 0) {
+        ++succOperandCountIt;
+        assert(succOperandCountIt != succOperandCountE &&
+               "More sentinal operands than successors.");
+      }
+
+      new (&instBlockOperands[currentSuccNum])
+          StmtBlockOperand(stmt, successors[currentSuccNum]);
+      *succOperandCountIt = 0;
+      ++currentSuccNum;
+      continue;
+    }
+    new (&stmtOperands[nextOperand++]) StmtOperand(stmt, operands[operandIt]);
+    ++(*succOperandCountIt);
+  }
+
+  // Verify that the amount of sentinal operands is equivalent to the number of
+  // successors.
+  assert(currentSuccNum == numSuccessors);
+
   return stmt;
 }
 
 OperationStmt::OperationStmt(Location location, OperationName name,
                              unsigned numOperands, unsigned numResults,
+                             unsigned numSuccessors,
                              ArrayRef<NamedAttribute> attributes,
                              MLIRContext *context)
     : Operation(/*isInstruction=*/false, name, attributes, context),
       Statement(Kind::Operation, location), numOperands(numOperands),
-      numResults(numResults) {}
+      numResults(numResults), numSuccs(numSuccessors) {}
 
 OperationStmt::~OperationStmt() {
   // Explicitly run the destructors for the operands and results.
@@ -512,24 +573,57 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
   };
 
   SmallVector<MLValue *, 8> operands;
-  operands.reserve(getNumOperands());
-  for (auto *opValue : getOperands())
-    operands.push_back(remapOperand(opValue));
-
+  SmallVector<StmtBlock *, 2> successors;
   if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
+    operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
+
+    if (!opStmt->isTerminator()) {
+      // Non-terminators just add all the operands.
+      for (auto *opValue : getOperands())
+        operands.push_back(remapOperand(opValue));
+    } else {
+      // We add the operands separated by nullptr's for each successor.
+      unsigned firstSuccOperand = opStmt->getNumSuccessors()
+                                      ? opStmt->getSuccessorOperandIndex(0)
+                                      : opStmt->getNumOperands();
+      auto stmtOperands = opStmt->getStmtOperands();
+
+      unsigned i = 0;
+      for (; i != firstSuccOperand; ++i)
+        operands.push_back(remapOperand(stmtOperands[i].get()));
+
+      successors.reserve(opStmt->getNumSuccessors());
+      for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e;
+           ++succ) {
+        successors.push_back(
+            const_cast<StmtBlock *>(opStmt->getSuccessor(succ)));
+
+        // Add sentinel to delineate successor operands.
+        operands.push_back(nullptr);
+
+        // Remap the successors operands.
+        for (auto &operand : opStmt->getSuccessorOperands(succ))
+          operands.push_back(remapOperand(operand.get()));
+      }
+    }
+
     SmallVector<Type, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
-    auto *newOp =
-        OperationStmt::create(getLoc(), opStmt->getName(), operands,
-                              resultTypes, opStmt->getAttrs(), context);
+    auto *newOp = OperationStmt::create(getLoc(), opStmt->getName(), operands,
+                                        resultTypes, opStmt->getAttrs(),
+                                        successors, context);
     // Remember the mapping of any results.
     for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
       operandMap[opStmt->getResult(i)] = newOp->getResult(i);
     return newOp;
   }
 
+  operands.reserve(getNumOperands());
+  for (auto *opValue : getOperands())
+    operands.push_back(remapOperand(opValue));
+
   if (auto *forStmt = dyn_cast<ForStmt>(this)) {
     auto lbMap = forStmt->getLowerBoundMap();
     auto ubMap = forStmt->getUpperBoundMap();