From fdcecefe30d8c54b51c8c796adbc9c60bb47088d Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Fri, 24 Jan 2020 15:17:04 +0100 Subject: [PATCH] Add lowering for loop.parallel to cfg. Summary: This also removes the explicit pattern for loop.terminator to ensure that the terminator is only erased if the parent op is rewritten. Reductions are not yet supported. Reviewers: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73348 --- mlir/include/mlir/Dialect/LoopOps/LoopOps.td | 7 +++ .../LoopToStandard/ConvertLoopToStandard.cpp | 50 ++++++++++++++++++---- mlir/test/Conversion/convert-to-cfg.mlir | 33 ++++++++++++++ 3 files changed, 82 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td index d445c68..c9ca460 100644 --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -177,6 +177,13 @@ def ParallelOp : Loop_Op<"parallel", Variadic:$step); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); + + let extraClassDeclaration = [{ + iterator_range getInductionVars() { + Block &block = body().front(); + return {block.args_begin(), block.args_end()}; + } + }]; } def ReduceOp : Loop_Op<"reduce", [HasParent<"ParallelOp">]> { diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp index 341c9c5..d0fd8f0 100644 --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" @@ -142,14 +143,11 @@ struct IfLowering : public OpRewritePattern { PatternRewriter &rewriter) const override; }; -struct TerminatorLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ParallelLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TerminatorOp op, - PatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return matchSuccess(); - } + PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp, + PatternRewriter &rewriter) const override; }; } // namespace @@ -178,6 +176,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { // Append the induction variable stepping logic to the last body block and // branch back to the condition block. Construct an expression f : // (x -> x+step) and apply this expression to the induction variable. + rewriter.eraseOp(lastBodyBlock->getTerminator()); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); auto stepped = rewriter.create(loc, iv, step).getResult(); @@ -220,6 +219,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { // place it before the continuation block, and branch to it. auto &thenRegion = ifOp.thenRegion(); auto *thenBlock = &thenRegion.front(); + rewriter.eraseOp(thenRegion.back().getTerminator()); rewriter.setInsertionPointToEnd(&thenRegion.back()); rewriter.create(loc, continueBlock); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -231,6 +231,7 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { auto &elseRegion = ifOp.elseRegion(); if (!elseRegion.empty()) { elseBlock = &elseRegion.front(); + rewriter.eraseOp(elseRegion.back().getTerminator()); rewriter.setInsertionPointToEnd(&elseRegion.back()); rewriter.create(loc, continueBlock); rewriter.inlineRegionBefore(elseRegion, continueBlock); @@ -246,9 +247,42 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { return matchSuccess(); } +PatternMatchResult +ParallelLowering::matchAndRewrite(ParallelOp parallelOp, + PatternRewriter &rewriter) const { + Location loc = parallelOp.getLoc(); + BlockAndValueMapping mapping; + + if (parallelOp.getNumResults() != 0) { + // TODO: Implement lowering of parallelOp with reductions. + return matchFailure(); + } + + // For a parallel loop, we essentially need to create an n-dimensional loop + // nest. We do this by translating to loop.for ops and have those lowered in + // a further rewrite. + for (auto loop_operands : + llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(), + parallelOp.upperBound(), parallelOp.step())) { + Value iv, lower, upper, step; + std::tie(iv, lower, upper, step) = loop_operands; + ForOp forOp = rewriter.create(loc, lower, upper, step); + mapping.map(iv, forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Now copy over the contents of the body. + for (auto &op : parallelOp.body().front().without_terminator()) + rewriter.clone(op, mapping); + + rewriter.eraseOp(parallelOp); + + return matchSuccess(); +} + void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } void LoopToStandardPass::runOnOperation() { diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir index 8cf0bb2..b53dc23 100644 --- a/mlir/test/Conversion/convert-to-cfg.mlir +++ b/mlir/test/Conversion/convert-to-cfg.mlir @@ -147,3 +147,36 @@ func @simple_std_for_loop_with_2_ifs(%arg0 : index, %arg1 : index, %arg2 : index } return } + +// CHECK-LABEL: func @parallel_loop( +// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) { +// CHECK: [[VAL_5:%.*]] = constant 1 : index +// CHECK: br ^bb1([[VAL_0]] : index) +// CHECK: ^bb1([[VAL_6:%.*]]: index): +// CHECK: [[VAL_7:%.*]] = cmpi "slt", [[VAL_6]], [[VAL_2]] : index +// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6 +// CHECK: ^bb2: +// CHECK: br ^bb3([[VAL_1]] : index) +// CHECK: ^bb3([[VAL_8:%.*]]: index): +// CHECK: [[VAL_9:%.*]] = cmpi "slt", [[VAL_8]], [[VAL_3]] : index +// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5 +// CHECK: ^bb4: +// CHECK: [[VAL_10:%.*]] = constant 1 : index +// CHECK: [[VAL_11:%.*]] = addi [[VAL_8]], [[VAL_5]] : index +// CHECK: br ^bb3([[VAL_11]] : index) +// CHECK: ^bb5: +// CHECK: [[VAL_12:%.*]] = addi [[VAL_6]], [[VAL_4]] : index +// CHECK: br ^bb1([[VAL_12]] : index) +// CHECK: ^bb6: +// CHECK: return +// CHECK: } + +func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) { + %step = constant 1 : index + loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) { + %c1 = constant 1 : index + } + return +} -- 2.7.4