Add lowering for loop.parallel to cfg.
authorStephan Herhut <herhut@google.com>
Fri, 24 Jan 2020 14:17:04 +0000 (15:17 +0100)
committerStephan Herhut <herhut@google.com>
Tue, 28 Jan 2020 10:55:51 +0000 (11:55 +0100)
Summary:
This also removes the explicit pattern for loop.terminator to ensure
that the terminator is only erased if the parent op is rewritten.

Reductions are not yet supported.

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73348

mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
mlir/test/Conversion/convert-to-cfg.mlir

index d445c68..c9ca460 100644 (file)
@@ -177,6 +177,13 @@ def ParallelOp : Loop_Op<"parallel",
                        Variadic<Index>:$step);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$body);
+
+  let extraClassDeclaration = [{
+    iterator_range<Block::args_iterator> getInductionVars() {
+      Block &block = body().front();
+      return {block.args_begin(), block.args_end()};
+    }
+  }];
 }
 
 def ReduceOp : Loop_Op<"reduce", [HasParent<"ParallelOp">]> {
index 341c9c5..d0fd8f0 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
@@ -142,14 +143,11 @@ struct IfLowering : public OpRewritePattern<IfOp> {
                                      PatternRewriter &rewriter) const override;
 };
 
-struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
-  using OpRewritePattern<TerminatorOp>::OpRewritePattern;
+struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
+  using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
 
-  PatternMatchResult matchAndRewrite(TerminatorOp op,
-                                     PatternRewriter &rewriter) const override {
-    rewriter.eraseOp(op);
-    return matchSuccess();
-  }
+  PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
+                                     PatternRewriter &rewriter) const override;
 };
 } // namespace
 
@@ -178,6 +176,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   // Append the induction variable stepping logic to the last body block and
   // branch back to the condition block.  Construct an expression f :
   // (x -> x+step) and apply this expression to the induction variable.
+  rewriter.eraseOp(lastBodyBlock->getTerminator());
   rewriter.setInsertionPointToEnd(lastBodyBlock);
   auto step = forOp.step();
   auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
@@ -220,6 +219,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
   // place it before the continuation block, and branch to it.
   auto &thenRegion = ifOp.thenRegion();
   auto *thenBlock = &thenRegion.front();
+  rewriter.eraseOp(thenRegion.back().getTerminator());
   rewriter.setInsertionPointToEnd(&thenRegion.back());
   rewriter.create<BranchOp>(loc, continueBlock);
   rewriter.inlineRegionBefore(thenRegion, continueBlock);
@@ -231,6 +231,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
   auto &elseRegion = ifOp.elseRegion();
   if (!elseRegion.empty()) {
     elseBlock = &elseRegion.front();
+    rewriter.eraseOp(elseRegion.back().getTerminator());
     rewriter.setInsertionPointToEnd(&elseRegion.back());
     rewriter.create<BranchOp>(loc, continueBlock);
     rewriter.inlineRegionBefore(elseRegion, continueBlock);
@@ -246,9 +247,42 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
   return matchSuccess();
 }
 
+PatternMatchResult
+ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
+                                  PatternRewriter &rewriter) const {
+  Location loc = parallelOp.getLoc();
+  BlockAndValueMapping mapping;
+
+  if (parallelOp.getNumResults() != 0) {
+    // TODO: Implement lowering of parallelOp with reductions.
+    return matchFailure();
+  }
+
+  // For a parallel loop, we essentially need to create an n-dimensional loop
+  // nest. We do this by translating to loop.for ops and have those lowered in
+  // a further rewrite.
+  for (auto loop_operands :
+       llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
+                 parallelOp.upperBound(), parallelOp.step())) {
+    Value iv, lower, upper, step;
+    std::tie(iv, lower, upper, step) = loop_operands;
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    mapping.map(iv, forOp.getInductionVar());
+    rewriter.setInsertionPointToStart(forOp.getBody());
+  }
+
+  // Now copy over the contents of the body.
+  for (auto &op : parallelOp.body().front().without_terminator())
+    rewriter.clone(op, mapping);
+
+  rewriter.eraseOp(parallelOp);
+
+  return matchSuccess();
+}
+
 void mlir::populateLoopToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
+  patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
 }
 
 void LoopToStandardPass::runOnOperation() {
index 8cf0bb2..b53dc23 100644 (file)
@@ -147,3 +147,36 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index
   }
   return
 }
+
+// CHECK-LABEL:   func @parallel_loop(
+// CHECK-SAME:                        [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
+// CHECK:           [[VAL_5:%.*]] = constant 1 : index
+// CHECK:           br ^bb1([[VAL_0]] : index)
+// CHECK:         ^bb1([[VAL_6:%.*]]: index):
+// CHECK:           [[VAL_7:%.*]] = cmpi "slt", [[VAL_6]], [[VAL_2]] : index
+// CHECK:           cond_br [[VAL_7]], ^bb2, ^bb6
+// CHECK:         ^bb2:
+// CHECK:           br ^bb3([[VAL_1]] : index)
+// CHECK:         ^bb3([[VAL_8:%.*]]: index):
+// CHECK:           [[VAL_9:%.*]] = cmpi "slt", [[VAL_8]], [[VAL_3]] : index
+// CHECK:           cond_br [[VAL_9]], ^bb4, ^bb5
+// CHECK:         ^bb4:
+// CHECK:           [[VAL_10:%.*]] = constant 1 : index
+// CHECK:           [[VAL_11:%.*]] = addi [[VAL_8]], [[VAL_5]] : index
+// CHECK:           br ^bb3([[VAL_11]] : index)
+// CHECK:         ^bb5:
+// CHECK:           [[VAL_12:%.*]] = addi [[VAL_6]], [[VAL_4]] : index
+// CHECK:           br ^bb1([[VAL_12]] : index)
+// CHECK:         ^bb6:
+// CHECK:           return
+// CHECK:         }
+
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                        %arg3 : index, %arg4 : index) {
+  %step = constant 1 : index
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                                          step (%arg4, %step) {
+    %c1 = constant 1 : index
+  }
+  return
+}