[mlir] support conversion of parallel reduction loops to std
authorAlex Zinenko <zinenko@google.com>
Wed, 4 Mar 2020 10:44:22 +0000 (11:44 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 4 Mar 2020 15:37:17 +0000 (16:37 +0100)
Recently introduced support for converting sequential reduction loops to
CFG of basic blocks in the Standard dialect makes it possible to perform
a staged conversion of parallel reduction loops into a similar CFG by
using sequential loops as an intermediate step. This is already the case
for parallel loops without reduction, so extend the pattern to support
an additional use case.

Differential Revision: https://reviews.llvm.org/D75599

mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
mlir/lib/Dialect/LoopOps/LoopOps.cpp
mlir/test/Conversion/convert-to-cfg.mlir

index f92a399dce70073ea41073d8be7bc401e34c336b..8850349af574bcf2c7f40c04453e54fc81ad0b3e 100644 (file)
@@ -131,7 +131,8 @@ def ForOp : Loop_Op<"for",
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
-              "Value lowerBound, Value upperBound, Value step">
+              "Value lowerBound, Value upperBound, Value step, "
+              "ValueRange iterArgs = llvm::None">
   ];
 
   let extraClassDeclaration = [{
index b4d25e04e3896f9765237d5402e9779702ed8fe8..a16c4a0c5cfb89ec3dcc7a8bff1eb537897eb6c4 100644 (file)
@@ -274,29 +274,75 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   Location loc = parallelOp.getLoc();
   BlockAndValueMapping mapping;
 
-  if (parallelOp.getNumResults() != 0) {
-    // TODO: Implement lowering of parallelOp with reductions.
-    return matchFailure();
-  }
-
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to loop.for ops and have those lowered in
-  // a further rewrite.
+  // a further rewrite. If a parallel loop contains reductions (and thus returns
+  // values), forward the initial values for the reductions down the loop
+  // hierarchy and bubble up the results by modifying the "yield" terminator.
+  SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
+  bool first = true;
+  SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                  parallelOp.upperBound(), parallelOp.step())) {
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
-    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
     mapping.map(iv, forOp.getInductionVar());
+    auto iterRange = forOp.getRegionIterArgs();
+    iterArgs.assign(iterRange.begin(), iterRange.end());
+
+    if (first) {
+      // Store the results of the outermost loop that will be used to replace
+      // the results of the parallel loop when it is fully rewritten.
+      loopResults.assign(forOp.result_begin(), forOp.result_end());
+      first = false;
+    } else {
+      // A loop is constructed with an empty "yield" terminator by default.
+      // Replace it with another "yield" that forwards the results of the nested
+      // loop to the parent loop. We need to explicitly make sure the new
+      // terminator is the last operation in the block because further transfoms
+      // rely on this.
+      rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+      rewriter.replaceOpWithNewOp<YieldOp>(
+          rewriter.getInsertionBlock()->getTerminator(), forOp.getResults());
+    }
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
   // Now copy over the contents of the body.
-  for (auto &op : parallelOp.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  SmallVector<Value, 4> yieldOperands;
+  yieldOperands.reserve(parallelOp.getNumResults());
+  for (auto &op : parallelOp.getBody()->without_terminator()) {
+    // Reduction blocks are handled differently.
+    auto reduce = dyn_cast<ReduceOp>(op);
+    if (!reduce) {
+      rewriter.clone(op, mapping);
+      continue;
+    }
+
+    // Clone the body of the reduction operation into the body of the loop,
+    // using operands of "loop.reduce" and iteration arguments corresponding
+    // to the reduction value to replace arguments of the reduction block.
+    // Collect operands of "loop.reduce.return" to be returned by a final
+    // "loop.yield" instead.
+    Value arg = iterArgs[yieldOperands.size()];
+    Block &reduceBlock = reduce.reductionOperator().front();
+    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
+    mapping.map(reduceBlock.getArgument(1),
+                mapping.lookupOrDefault(reduce.operand()));
+    for (auto &nested : reduceBlock.without_terminator())
+      rewriter.clone(nested, mapping);
+    yieldOperands.push_back(
+        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+  }
+
+  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+  rewriter.replaceOpWithNewOp<YieldOp>(
+      rewriter.getInsertionBlock()->getTerminator(), yieldOperands);
 
-  rewriter.eraseOp(parallelOp);
+  rewriter.replaceOp(parallelOp, loopResults);
 
   return matchSuccess();
 }
index e9e9397a6d1932e0f98b4b03074a2f72fbc726ae..c0cb149bf815dfeb20a81d1dbfbc8337f5e77852 100644 (file)
@@ -61,11 +61,16 @@ LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
 //===----------------------------------------------------------------------===//
 
 void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
-                  Value step) {
+                  Value step, ValueRange iterArgs) {
   result.addOperands({lb, ub, step});
+  result.addOperands(iterArgs);
+  for (Value v : iterArgs)
+    result.addTypes(v.getType());
   Region *bodyRegion = result.addRegion();
   ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
   bodyRegion->front().addArgument(builder->getIndexType());
+  for (Value v : iterArgs)
+    bodyRegion->front().addArgument(v.getType());
 }
 
 static LogicalResult verify(ForOp op) {
index c6be3fdb8953579bab887bbfeb65a9d9af9b1238..54c5d4c4a9cfee01a09da02118fbf7337847c597 100644 (file)
@@ -236,3 +236,88 @@ func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
   }
   return %r : f32
 }
+
+func @generate() -> i64
+
+// CHECK-LABEL: @simple_parallel_reduce_loop
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
+func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
+                                  %arg2: index, %arg3: f32) -> f32 {
+  // A parallel loop with reduction is converted through sequential loops with
+  // reductions into a CFG of blocks where the partially reduced value is
+  // passed across as a block argument.
+
+  // Branch to the condition block passing in the initial reduction value.
+  // CHECK:   br ^[[COND:.*]](%[[LB]], %[[INIT]]
+
+  // Condition branch takes as arguments the current value of the iteration
+  // variable and the current partially reduced value.
+  // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
+  // CHECK:   %[[COMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]]
+  // CHECK:   cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+
+  // Bodies of loop.reduce operations are folded into the main loop body. The
+  // result of this partial reduction is passed as argument to the condition
+  // block.
+  // CHECK: ^[[BODY]]:
+  // CHECK:   %[[CST:.*]] = constant 4.2
+  // CHECK:   %[[PROD:.*]] = mulf %[[ITER_ARG]], %[[CST]]
+  // CHECK:   %[[INCR:.*]] = addi %[[ITER]], %[[STEP]]
+  // CHECK:   br ^[[COND]](%[[INCR]], %[[PROD]]
+
+  // The continuation block has access to the (last value of) reduction.
+  // CHECK: ^[[CONTINUE]]:
+  // CHECK:   return %[[ITER_ARG]]
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+    %cst = constant 42.0 : f32
+    loop.reduce(%cst) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = mulf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+  } : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: parallel_reduce_loop
+// CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
+func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                           %arg3 : index, %arg4 : index, %arg5 : f32) -> (f32, i64) {
+  // Multiple reduction blocks should be folded in the same body, and the
+  // reduction value must be forwarded through block structures.
+  // CHECK:   %[[INIT2:.*]] = constant 42
+  // CHECK:   br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
+  // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
+  // CHECK:   cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // CHECK: ^[[BODY_OUT]]:
+  // CHECK:   br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
+  // CHECK:   cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // CHECK: ^[[BODY_IN]]:
+  // CHECK:   %[[REDUCE1:.*]] = addf %[[ITER_ARG1_IN]], %{{.*}}
+  // CHECK:   %[[REDUCE2:.*]] = or %[[ITER_ARG2_IN]], %{{.*}}
+  // CHECK:   br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
+  // CHECK: ^[[CONT_IN]]:
+  // CHECK:   br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
+  // CHECK: ^[[CONT_OUT]]:
+  // CHECK:   return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  %step = constant 1 : index
+  %init = constant 42 : i64
+  %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                       step (%arg4, %step) init(%arg5, %init) {
+    %cf = constant 42.0 : f32
+    loop.reduce(%cf) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = addf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+
+    %2 = call @generate() : () -> i64
+    loop.reduce(%2) {
+    ^bb0(%lhs: i64, %rhs: i64):
+      %3 = or %lhs, %rhs : i64
+      loop.reduce.return %3 : i64
+    } : i64
+  } : f32, i64
+  return %0#0, %0#1 : f32, i64
+}