NFC: Update the LoopToStd conversion patterns to use RewritePattern instead of Conver...
authorRiver Riddle <riverriddle@google.com>
Mon, 22 Jul 2019 20:22:24 +0000 (13:22 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 22 Jul 2019 20:22:49 +0000 (13:22 -0700)
These patterns don't require type changes so they don't need to be using ConversionPattern.

PiperOrigin-RevId: 259393151

mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp

index 1515d95..c280ed9 100644 (file)
@@ -100,13 +100,11 @@ struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
 //      |   <code after the ForOp> |
 //      +--------------------------------+
 //
-struct ForLowering : public ConversionPattern {
-  ForLowering(MLIRContext *ctx)
-      : ConversionPattern(ForOp::getOperationName(), 1, ctx) {}
+struct ForLowering : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
 
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override;
+  PatternMatchResult matchAndRewrite(ForOp forOp,
+                                     PatternRewriter &rewriter) const override;
 };
 
 // Create a CFG subgraph for the loop.if operation (including its "then" and
@@ -151,22 +149,18 @@ struct ForLowering : public ConversionPattern {
 //      |   <code after the IfOp>  |
 //      +--------------------------------+
 //
-struct IfLowering : public ConversionPattern {
-  IfLowering(MLIRContext *ctx)
-      : ConversionPattern(IfOp::getOperationName(), 1, ctx) {}
+struct IfLowering : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
 
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override;
+  PatternMatchResult matchAndRewrite(IfOp ifOp,
+                                     PatternRewriter &rewriter) const override;
 };
 
-struct TerminatorLowering : public ConversionPattern {
-  TerminatorLowering(MLIRContext *ctx)
-      : ConversionPattern(TerminatorOp::getOperationName(), 1, ctx) {}
+struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
+  using OpRewritePattern<TerminatorOp>::OpRewritePattern;
 
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
+  PatternMatchResult matchAndRewrite(TerminatorOp op,
+                                     PatternRewriter &rewriter) const override {
     rewriter.replaceOp(op, {});
     return matchSuccess();
   }
@@ -174,12 +168,10 @@ struct TerminatorLowering : public ConversionPattern {
 } // namespace
 
 PatternMatchResult
-ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                             ConversionPatternRewriter &rewriter) const {
-  auto forOp = cast<ForOp>(op);
-  Location loc = op->getLoc();
+ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
 
-  // Start by splitting the block containing the 'affine.for' into two parts.
+  // Start by splitting the block containing the 'loop.for' into two parts.
   // The part before will get the init code, the part after will be the end
   // point.
   auto *initBlock = rewriter.getInsertionBlock();
@@ -201,8 +193,7 @@ ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   // branch back to the condition block.  Construct an expression f :
   // (x -> x+step) and apply this expression to the induction variable.
   rewriter.setInsertionPointToEnd(lastBodyBlock);
-  ForOpOperandAdaptor newOperands(operands);
-  auto *step = newOperands.step();
+  auto *step = forOp.step();
   auto *stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
   if (!stepped)
     return matchFailure();
@@ -210,8 +201,8 @@ ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
 
   // Compute loop bounds before branching to the condition.
   rewriter.setInsertionPointToEnd(initBlock);
-  Value *lowerBound = operands[0];
-  Value *upperBound = operands[1];
+  Value *lowerBound = forOp.lowerBound();
+  Value *upperBound = forOp.upperBound();
   if (!lowerBound || !upperBound)
     return matchFailure();
   rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
@@ -225,15 +216,13 @@ ForLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                 ArrayRef<Value *>(), endBlock,
                                 ArrayRef<Value *>());
   // Ok, we're done!
-  rewriter.replaceOp(op, {});
+  rewriter.replaceOp(forOp, {});
   return matchSuccess();
 }
 
 PatternMatchResult
-IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                            ConversionPatternRewriter &rewriter) const {
-  auto ifOp = cast<IfOp>(op);
-  auto loc = op->getLoc();
+IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
+  auto loc = ifOp.getLoc();
 
   // Start by splitting the block containing the 'loop.if' into two parts.
   // The part before will contain the condition, the part after will be the
@@ -263,13 +252,12 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   }
 
   rewriter.setInsertionPointToEnd(condBlock);
-  IfOpOperandAdaptor newOperands(operands);
-  rewriter.create<CondBranchOp>(loc, newOperands.condition(), thenBlock,
+  rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
                                 /*trueArgs=*/ArrayRef<Value *>(), elseBlock,
                                 /*falseArgs=*/ArrayRef<Value *>());
 
   // Ok, we're done!
-  rewriter.replaceOp(op, {});
+  rewriter.replaceOp(ifOp, {});
   return matchSuccess();
 }