From aff6bf4ff81a35a85034b478cccc7015499ce427 Mon Sep 17 00:00:00 2001
From: Alex Zinenko
Date: Wed, 4 Mar 2020 11:44:22 +0100
Subject: [PATCH] [mlir] support conversion of parallel reduction loops to std

Recently introduced support for converting sequential reduction loops to a
CFG of basic blocks in the Standard dialect makes it possible to perform a
staged conversion of parallel reduction loops into a similar CFG by using
sequential loops as an intermediate step. This is already the case for
parallel loops without reduction, so extend the pattern to support an
additional use case.

Differential Revision: https://reviews.llvm.org/D75599
---
 mlir/include/mlir/Dialect/LoopOps/LoopOps.td  |  3 +-
 .../LoopToStandard/ConvertLoopToStandard.cpp  | 66 ++++++++++++++---
 mlir/lib/Dialect/LoopOps/LoopOps.cpp          |  7 +-
 mlir/test/Conversion/convert-to-cfg.mlir      | 85 ++++++++++++++++++++++
 4 files changed, 149 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index f92a399..8850349 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -131,7 +131,8 @@ def ForOp : Loop_Op<"for",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
-              "Value lowerBound, Value upperBound, Value step">
+              "Value lowerBound, Value upperBound, Value step, "
+              "ValueRange iterArgs = llvm::None">
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
index b4d25e0..a16c4a0 100644
--- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -274,29 +274,75 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   Location loc = parallelOp.getLoc();
   BlockAndValueMapping mapping;
 
-  if (parallelOp.getNumResults() != 0) {
-    // TODO: Implement lowering of parallelOp with reductions.
-    return matchFailure();
-  }
-
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to loop.for ops and have those lowered in
-  // a further rewrite.
+  // a further rewrite. If a parallel loop contains reductions (and thus returns
+  // values), forward the initial values for the reductions down the loop
+  // hierarchy and bubble up the results by modifying the "yield" terminator.
+  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
+  bool first = true;
+  SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                  parallelOp.upperBound(), parallelOp.step())) {
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
-    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
     mapping.map(iv, forOp.getInductionVar());
+    auto iterRange = forOp.getRegionIterArgs();
+    iterArgs.assign(iterRange.begin(), iterRange.end());
+
+    if (first) {
+      // Store the results of the outermost loop that will be used to replace
+      // the results of the parallel loop when it is fully rewritten.
+      loopResults.assign(forOp.result_begin(), forOp.result_end());
+      first = false;
+    } else {
+      // A loop is constructed with an empty "yield" terminator by default.
+      // Replace it with another "yield" that forwards the results of the
+      // nested loop to the parent loop.
+      // We need to explicitly make sure the new terminator is the last
+      // operation in the block because further transforms rely on this.
+      rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+      rewriter.replaceOpWithNewOp<YieldOp>(
+          rewriter.getInsertionBlock()->getTerminator(), forOp.getResults());
+    }
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
   // Now copy over the contents of the body.
-  for (auto &op : parallelOp.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  SmallVector<Value, 4> yieldOperands;
+  yieldOperands.reserve(parallelOp.getNumResults());
+  for (auto &op : parallelOp.getBody()->without_terminator()) {
+    // Reduction blocks are handled differently.
+    auto reduce = dyn_cast<ReduceOp>(op);
+    if (!reduce) {
+      rewriter.clone(op, mapping);
+      continue;
+    }
+
+    // Clone the body of the reduction operation into the body of the loop,
+    // using operands of "loop.reduce" and iteration arguments corresponding
+    // to the reduction value to replace arguments of the reduction block.
+    // Collect operands of "loop.reduce.return" to be returned by a final
+    // "loop.yield" instead.
+    Value arg = iterArgs[yieldOperands.size()];
+    Block &reduceBlock = reduce.reductionOperator().front();
+    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
+    mapping.map(reduceBlock.getArgument(1),
+                mapping.lookupOrDefault(reduce.operand()));
+    for (auto &nested : reduceBlock.without_terminator())
+      rewriter.clone(nested, mapping);
+    yieldOperands.push_back(
+        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+  }
+
+  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+  rewriter.replaceOpWithNewOp<YieldOp>(
+      rewriter.getInsertionBlock()->getTerminator(), yieldOperands);
 
-  rewriter.eraseOp(parallelOp);
+  rewriter.replaceOp(parallelOp, loopResults);
   return matchSuccess();
 }
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index e9e9397..c0cb149b 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -61,11 +61,16 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
 //===----------------------------------------------------------------------===//
 
 void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
-                  Value step) {
+                  Value step, ValueRange iterArgs) {
   result.addOperands({lb, ub, step});
+  result.addOperands(iterArgs);
+  for (Value v : iterArgs)
+    result.addTypes(v.getType());
   Region *bodyRegion = result.addRegion();
   ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
   bodyRegion->front().addArgument(builder->getIndexType());
+  for (Value v : iterArgs)
+    bodyRegion->front().addArgument(v.getType());
 }
 
 static LogicalResult verify(ForOp op) {
diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
index c6be3fdb..54c5d4c 100644
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -236,3 +236,88 @@ func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
   }
   return %r : f32
 }
+
+func @generate() -> i64
+
+// CHECK-LABEL: @simple_parallel_reduce_loop
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
+func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
+                                  %arg2: index, %arg3: f32) -> f32 {
+  // A parallel loop with reduction is converted through sequential loops with
+  // reductions into a CFG of blocks where the partially reduced value is
+  // passed across as a block argument.
+
+  // Branch to the condition block passing in the initial reduction value.
+  // CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
+
+  // Condition branch takes as arguments the current value of the iteration
+  // variable and the current partially reduced value.
+  // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
+  // CHECK: %[[COMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]]
+  // CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+
+  // Bodies of loop.reduce operations are folded into the main loop body. The
+  // result of this partial reduction is passed as argument to the condition
+  // block.
+  // CHECK: ^[[BODY]]:
+  // CHECK: %[[CST:.*]] = constant 4.2
+  // CHECK: %[[PROD:.*]] = mulf %[[ITER_ARG]], %[[CST]]
+  // CHECK: %[[INCR:.*]] = addi %[[ITER]], %[[STEP]]
+  // CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
+
+  // The continuation block has access to the (last value of) reduction.
+  // CHECK: ^[[CONTINUE]]:
+  // CHECK: return %[[ITER_ARG]]
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+    %cst = constant 42.0 : f32
+    loop.reduce(%cst) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = mulf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+  } : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: parallel_reduce_loop
+// CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
+func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                           %arg3 : index, %arg4 : index, %arg5 : f32) -> (f32, i64) {
+  // Multiple reduction blocks should be folded in the same body, and the
+  // reduction value must be forwarded through block structures.
+  // CHECK: %[[INIT2:.*]] = constant 42
+  // CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
+  // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // CHECK: ^[[BODY_OUT]]:
+  // CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // CHECK: ^[[BODY_IN]]:
+  // CHECK: %[[REDUCE1:.*]] = addf %[[ITER_ARG1_IN]], %{{.*}}
+  // CHECK: %[[REDUCE2:.*]] = or %[[ITER_ARG2_IN]], %{{.*}}
+  // CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
+  // CHECK: ^[[CONT_IN]]:
+  // CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
+  // CHECK: ^[[CONT_OUT]]:
+  // CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  %step = constant 1 : index
+  %init = constant 42 : i64
+  %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                       step (%arg4, %step) init(%arg5, %init) {
+    %cf = constant 42.0 : f32
+    loop.reduce(%cf) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = addf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+
+    %2 = call @generate() : () -> i64
+    loop.reduce(%2) {
+    ^bb0(%lhs: i64, %rhs: i64):
+      %3 = or %lhs, %rhs : i64
+      loop.reduce.return %3 : i64
+    } : i64
+  } : f32, i64
+  return %0#0, %0#1 : f32, i64
+}
-- 
2.7.4
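
Review note: the staged conversion is easiest to see on the first test case
above. The following sketch is hand-written (it is not FileCheck output from
the patch) and shows approximately the intermediate IR after ParallelLowering
has run but before ForLowering rewrites the sequential loop into a CFG; the
value names are illustrative. The reduction block's arguments are replaced by
the loop-carried iteration argument (%lhs -> %acc) and the "loop.reduce"
operand (%rhs -> %cst), and the partial result is forwarded by "loop.yield":

  func @simple_parallel_reduce_loop(%lb: index, %ub: index, %step: index,
                                    %init: f32) -> f32 {
    // Sequential form: the reduction value travels as an iteration argument.
    %0 = loop.for %i = %lb to %ub step %step iter_args(%acc = %init) -> (f32) {
      %cst = constant 42.0 : f32
      // Inlined body of the "loop.reduce" region.
      %1 = mulf %acc, %cst : f32
      loop.yield %1 : f32
    }
    return %0 : f32
  }

The pre-existing ForLowering pattern then turns this loop.for into the
br/cond_br block structure that the CHECK lines verify.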
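
For the two-dimensional @parallel_reduce_loop test, the intermediate form
would look roughly as follows (again a hand-written approximation, not patch
output): the initial values flow inward through iter_args and the partial
results bubble outward through the yields that the pattern installs on each
nested loop.

  %step = constant 1 : index
  %init = constant 42 : i64
  %0:2 = loop.for %i0 = %arg0 to %arg2 step %arg4
         iter_args(%a0 = %arg5, %a1 = %init) -> (f32, i64) {
    %1:2 = loop.for %i1 = %arg1 to %arg3 step %step
           iter_args(%b0 = %a0, %b1 = %a1) -> (f32, i64) {
      %cf = constant 42.0 : f32
      // First reduction: f32 "addf" against the first iteration argument.
      %2 = addf %b0, %cf : f32
      %3 = call @generate() : () -> i64
      // Second reduction: i64 "or" against the second iteration argument.
      %4 = or %b1, %3 : i64
      loop.yield %2, %4 : f32, i64
    }
    // Forward the results of the inner loop to the outer loop's yield.
    loop.yield %1#0, %1#1 : f32, i64
  }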