#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
PatternRewriter &rewriter) const override;
};
// This hunk drops the TerminatorLowering pattern (which only erased the
// terminator op it matched) and introduces ParallelLowering in its place.
// Elsewhere in this change the for/if lowering code erases the region
// terminators itself via rewriter.eraseOp(...getTerminator()).
-struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
- using OpRewritePattern<TerminatorOp>::OpRewritePattern;
// Rewrite pattern anchored on mlir::loop::ParallelOp. It inherits the
// OpRewritePattern constructors and only declares matchAndRewrite here; the
// definition is provided out of line.
+struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
+ using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(TerminatorOp op,
- PatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return matchSuccess();
- }
// Takes the matched parallel op and the rewriter used to mutate the IR;
// returns a PatternMatchResult indicating whether the rewrite was applied.
+ PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
+ PatternRewriter &rewriter) const override;
};
} // namespace
// Append the induction variable stepping logic to the last body block and
// branch back to the condition block. Construct an expression f :
// (x -> x+step) and apply this expression to the induction variable.
+ rewriter.eraseOp(lastBodyBlock->getTerminator());
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto step = forOp.step();
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
// place it before the continuation block, and branch to it.
auto &thenRegion = ifOp.thenRegion();
auto *thenBlock = &thenRegion.front();
+ rewriter.eraseOp(thenRegion.back().getTerminator());
rewriter.setInsertionPointToEnd(&thenRegion.back());
rewriter.create<BranchOp>(loc, continueBlock);
rewriter.inlineRegionBefore(thenRegion, continueBlock);
auto &elseRegion = ifOp.elseRegion();
if (!elseRegion.empty()) {
elseBlock = &elseRegion.front();
+ rewriter.eraseOp(elseRegion.back().getTerminator());
rewriter.setInsertionPointToEnd(&elseRegion.back());
rewriter.create<BranchOp>(loc, continueBlock);
rewriter.inlineRegionBefore(elseRegion, continueBlock);
return matchSuccess();
}
+// Rewrites a result-free loop.parallel into a perfectly nested chain of
+// loop.for ops (one per induction variable, outermost first) and moves a copy
+// of the parallel body into the innermost loop. The generated loop.for ops are
+// left for a later rewrite to lower further. Returns matchFailure() without
+// touching the IR when the op produces results.
+PatternMatchResult
+ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
+                                  PatternRewriter &rewriter) const {
+  // Ops with results carry reductions, which this pattern cannot express yet.
+  // TODO: Implement lowering of parallelOp with reductions.
+  if (parallelOp.getNumResults() != 0)
+    return matchFailure();
+
+  Location loc = parallelOp.getLoc();
+  // Records, for each induction variable of the parallel op, the induction
+  // variable of the loop.for that replaces it.
+  BlockAndValueMapping ivMapping;
+
+  // Build the n-dimensional loop nest: every (iv, lower, upper, step) tuple
+  // becomes one loop.for, and each new loop is created inside the body of the
+  // previous one by moving the insertion point into it.
+  for (auto dim :
+       llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
+                 parallelOp.upperBound(), parallelOp.step())) {
+    Value parallelIv = std::get<0>(dim);
+    Value lowerBound = std::get<1>(dim);
+    Value upperBound = std::get<2>(dim);
+    Value stride = std::get<3>(dim);
+
+    ForOp nestedFor =
+        rewriter.create<ForOp>(loc, lowerBound, upperBound, stride);
+    ivMapping.map(parallelIv, nestedFor.getInductionVar());
+    rewriter.setInsertionPointToStart(nestedFor.getBody());
+  }
+
+  // The insertion point is now at the start of the innermost loop body. Clone
+  // every op of the parallel body except its terminator there, remapping uses
+  // of the old induction variables to the new ones.
+  auto &parallelBody = parallelOp.body().front();
+  for (auto &bodyOp : parallelBody.without_terminator())
+    rewriter.clone(bodyOp, ivMapping);
+
+  // The original op has been fully replaced by the loop nest.
+  rewriter.eraseOp(parallelOp);
+  return matchSuccess();
+}
+
+
// Adds the loop-dialect lowering patterns to `patterns`, constructing each
// one with `ctx`. After this change the registered set is ForLowering,
// IfLowering and ParallelLowering; TerminatorLowering is no longer inserted.
void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
+ patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
}
void LoopToStandardPass::runOnOperation() {
}
return
}
+
+// CHECK-LABEL: func @parallel_loop(
+// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) {
+// CHECK: [[VAL_5:%.*]] = constant 1 : index
+// CHECK: br ^bb1([[VAL_0]] : index)
+// CHECK: ^bb1([[VAL_6:%.*]]: index):
+// CHECK: [[VAL_7:%.*]] = cmpi "slt", [[VAL_6]], [[VAL_2]] : index
+// CHECK: cond_br [[VAL_7]], ^bb2, ^bb6
+// CHECK: ^bb2:
+// CHECK: br ^bb3([[VAL_1]] : index)
+// CHECK: ^bb3([[VAL_8:%.*]]: index):
+// CHECK: [[VAL_9:%.*]] = cmpi "slt", [[VAL_8]], [[VAL_3]] : index
+// CHECK: cond_br [[VAL_9]], ^bb4, ^bb5
+// CHECK: ^bb4:
+// CHECK: [[VAL_10:%.*]] = constant 1 : index
+// CHECK: [[VAL_11:%.*]] = addi [[VAL_8]], [[VAL_5]] : index
+// CHECK: br ^bb3([[VAL_11]] : index)
+// CHECK: ^bb5:
+// CHECK: [[VAL_12:%.*]] = addi [[VAL_6]], [[VAL_4]] : index
+// CHECK: br ^bb1([[VAL_12]] : index)
+// CHECK: ^bb6:
+// CHECK: return
+// CHECK: }
+
// Test input for the loop.parallel lowering: a two-dimensional parallel loop
// over (%i0, %i1) with lower bounds (%arg0, %arg1), upper bounds
// (%arg2, %arg3) and steps (%arg4, %step). The body holds a single constant
// so that the contents of the innermost lowered block can be matched by the
// FileCheck expectations preceding this function.
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) {
+ %step = constant 1 : index
+ loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) {
+ %c1 = constant 1 : index
+ }
+ return
+}