[mlir] Async: update condition for dispatching block-aligned compute function
authorEugene Zhulenev <ezhulenev@google.com>
Thu, 17 Feb 2022 18:22:18 +0000 (10:22 -0800)
committerEugene Zhulenev <ezhulenev@google.com>
Wed, 23 Feb 2022 18:29:55 +0000 (10:29 -0800)
+ compare block size with the unrollable inner dimension
+ reduce nesting in the code and simplify a bit IR building

Reviewed By: cota

Differential Revision: https://reviews.llvm.org/D120075

mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir

index e596fc3..c4ba141 100644 (file)
@@ -779,10 +779,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     // and we can elide dynamic loop boundaries, and give LLVM an opportunity to
     // unroll the loops. The constant `512` is arbitrary, it should depend on
     // how many iterations LLVM will typically decide to unroll.
-    static constexpr int64_t maxIterations = 512;
+    static constexpr int64_t maxUnrollableIterations = 512;
 
     // The number of inner loops with statically known number of iterations less
-    // than the `maxIterations` value.
+    // than the `maxUnrollableIterations` value.
     int numUnrollableLoops = 0;
 
     auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
@@ -796,7 +796,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
       numIterations[i] = tripCount * innerIterations;
 
       // Update the number of inner loops that we can potentially unroll.
-      if (innerIterations > 0 && innerIterations <= maxIterations)
+      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
         numUnrollableLoops++;
     }
 
@@ -856,9 +856,6 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
     Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
 
-    ParallelComputeFunction notUnrollableParallelComputeFunction =
-        createParallelComputeFunction(op, staticBounds, 0, rewriter);
-
     // Dispatch parallel compute function using async recursive work splitting,
     // or by submitting compute task sequentially from a caller thread.
     auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
@@ -869,42 +866,47 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
     // Compute the number of parallel compute blocks.
     Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
 
-    // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
-    bool staticShouldUnroll = numUnrollableLoops > 0;
-    auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+    // Dispatch parallel compute function without hints to unroll inner loops.
+    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute =
+          createParallelComputeFunction(op, staticBounds, 0, rewriter);
+
+      ImplicitLocOpBuilder b(loc, nestedBuilder);
+      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
+      b.create<scf::YieldOp>();
+    };
+
+    // Dispatch parallel compute function with hints for unrolling inner loops.
+    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
+      ParallelComputeFunction compute = createParallelComputeFunction(
+          op, staticBounds, numUnrollableLoops, rewriter);
+
       ImplicitLocOpBuilder b(loc, nestedBuilder);
-      doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
-                 blockSize, blockCount, tripCounts);
+      // Align the block size to be a multiple of the statically known
+      // number of iterations in the inner loops.
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value alignedBlockSize = b.create<arith::MulIOp>(
+          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
+                 tripCounts);
       b.create<scf::YieldOp>();
     };
 
-    if (staticShouldUnroll) {
-      Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
-          arith::CmpIPredicate::sge, blockSize,
-          b.create<arith::ConstantIndexOp>(maxIterations));
-
-      ParallelComputeFunction unrollableParallelComputeFunction =
-          createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
-                                        rewriter);
-
-      auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
-        ImplicitLocOpBuilder b(loc, nestedBuilder);
-        // Align the block size to be a multiple of the statically known
-        // number of iterations in the inner loops.
-        Value numIters = b.create<arith::ConstantIndexOp>(
-            numIterations[op.getNumLoops() - numUnrollableLoops]);
-        Value alignedBlockSize = b.create<arith::MulIOp>(
-            b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
-        doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
-                   alignedBlockSize, blockCount, tripCounts);
-        b.create<scf::YieldOp>();
-      };
-
-      b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
-                          dispatchNotUnrollable);
+    // Dispatch to block aligned compute function only if the computed block
+    // size is larger than the number of iterations in the unrollable inner
+    // loops, because otherwise it can reduce the available parallelism.
+    if (numUnrollableLoops > 0) {
+      Value numIters = b.create<arith::ConstantIndexOp>(
+          numIterations[op.getNumLoops() - numUnrollableLoops]);
+      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sge, blockSize, numIters);
+
+      b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
+                          dispatchBlockAligned, dispatchDefault);
       b.create<scf::YieldOp>();
     } else {
-      dispatchNotUnrollable(b, loc);
+      dispatchDefault(b, loc);
     }
   };
 
index 217e63b..8fc1c66 100644 (file)
@@ -87,7 +87,7 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
   return
 }
 
-// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
 // CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -100,12 +100,14 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
 // CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
 // CHECK-SAME: ) {
+// CHECK:        %[[C0:.*]] = arith.constant 0 : index
+// CHECK:        %[[C1:.*]] = arith.constant 1 : index
+// CHECK:        %[[C10:.*]] = arith.constant 10 : index
 // CHECK:        scf.for %[[I:arg[0-9]+]]
-// CHECK:          arith.select
-// CHECK:          scf.for %[[J:arg[0-9]+]]
-// CHECK:          memref.store
+// CHECK-NOT:      arith.select
+// CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
 
-// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-LABEL: func private @parallel_compute_fn(
 // CHECK-SAME:   %[[BLOCK_INDEX:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[BLOCK_SIZE:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -118,9 +120,7 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
 // CHECK-SAME:   %[[STEP1:arg[0-9]+]]: index,
 // CHECK-SAME:   %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
 // CHECK-SAME: ) {
-// CHECK:        %[[C0:.*]] = arith.constant 0 : index
-// CHECK:        %[[C1:.*]] = arith.constant 1 : index
-// CHECK:        %[[C10:.*]] = arith.constant 10 : index
 // CHECK:        scf.for %[[I:arg[0-9]+]]
-// CHECK-NOT:      arith.select
-// CHECK:          scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
+// CHECK:          arith.select
+// CHECK:          scf.for %[[J:arg[0-9]+]]
+// CHECK:          memref.store