Rename target block size to min task size for clarity.
authorbakhtiyar <bakhtiyar@x.team>
Tue, 28 Sep 2021 21:34:53 +0000 (14:34 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Tue, 28 Sep 2021 21:51:55 +0000 (14:51 -0700)
Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D110604

mlir/include/mlir/Dialect/Async/Passes.td
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-2d.mlir

index f9f9804f244b4e2789c76ed35ee607a43a7623b9..150929eba15118e4feb24f6ebbe01a4d98f5b472 100644 (file)
@@ -27,9 +27,9 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
       "int32_t", /*default=*/"8",
       "The number of available workers to execute async operations.">,
 
-    Option<"targetBlockSize", "target-block-size",
+    Option<"minTaskSize", "min-task-size",
       "int32_t", /*default=*/"1000",
-      "The target block size for sharding parallel operation.">
+      "The minimum task size for sharding parallel operation.">
   ];
 
   let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"];
index cfc1968d523d1dc73983f0ac1e11f47b9439799b..8ccf97e269500ea78068127c17bf48594c0116a4 100644 (file)
@@ -92,10 +92,10 @@ struct AsyncParallelForPass
   AsyncParallelForPass() = default;
 
   AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
-                       int32_t targetBlockSize) {
+                       int32_t minTaskSize) {
     this->asyncDispatch = asyncDispatch;
     this->numWorkerThreads = numWorkerThreads;
-    this->targetBlockSize = targetBlockSize;
+    this->minTaskSize = minTaskSize;
   }
 
   void runOnOperation() override;
@@ -104,9 +104,9 @@ struct AsyncParallelForPass
 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
 public:
   AsyncParallelForRewrite(MLIRContext *ctx, bool asyncDispatch,
-                          int32_t numWorkerThreads, int32_t targetBlockSize)
+                          int32_t numWorkerThreads, int32_t minTaskSize)
       : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
-        numWorkerThreads(numWorkerThreads), targetBlockSize(targetBlockSize) {}
+        numWorkerThreads(numWorkerThreads), minTaskSize(minTaskSize) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp op,
                                 PatternRewriter &rewriter) const override;
@@ -114,7 +114,7 @@ public:
 private:
   bool asyncDispatch;
   int32_t numWorkerThreads;
-  int32_t targetBlockSize;
+  int32_t minTaskSize;
 };
 
 struct ParallelComputeFunctionType {
@@ -564,7 +564,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
 // Dispatch parallel compute functions by submitting all async compute tasks
 // from a simple for loop in the caller thread.
 static void
-doSequantialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
+doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                      ParallelComputeFunction &parallelComputeFunction,
                      scf::ParallelOp op, Value blockSize, Value blockCount,
                      const SmallVector<Value> &tripCounts) {
@@ -684,15 +684,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
         std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
 
     // Target block size from the pass parameters.
-    Value targetComputeBlock = b.create<ConstantIndexOp>(targetBlockSize);
+    Value minTaskSizeCst = b.create<ConstantIndexOp>(minTaskSize);
 
     // Compute parallel block size from the parallel problem size:
     //   blockSize = min(tripCount,
     //                   max(ceil_div(tripCount, maxComputeBlocks),
-    //                       targetComputeBlock))
+    //                       ceil_div(minTaskSize, bodySize)))
     Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
-    Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlock);
-    Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlock);
+    Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, minTaskSizeCst);
+    Value bs2 = b.create<SelectOp>(bs1, bs0, minTaskSizeCst);
     Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
     Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
     Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
@@ -712,7 +712,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
       doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                       blockCount, tripCounts);
     } else {
-      doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+      doSequentialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
                            blockCount, tripCounts);
     }
 
@@ -733,7 +733,7 @@ void AsyncParallelForPass::runOnOperation() {
 
   RewritePatternSet patterns(ctx);
   patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
-                                        targetBlockSize);
+                                        minTaskSize);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
@@ -743,9 +743,9 @@ std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
   return std::make_unique<AsyncParallelForPass>();
 }
 
-std::unique_ptr<Pass>
-mlir::createAsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
-                                 int32_t targetBlockSize) {
+std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
+                                                       int32_t numWorkerThreads,
+                                                       int32_t minTaskSize) {
   return std::make_unique<AsyncParallelForPass>(asyncDispatch, numWorkerThreads,
-                                                targetBlockSize);
+                                                minTaskSize);
 }
index 2321905165bd98cd23dcfa11db8d16a29daaa667..ad3166e3c6ac89ab7e91a74b6121bcfefc65779c 100644 (file)
@@ -31,7 +31,7 @@
 
 // RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
 // RUN:                                    num-workers=20                      \
-// RUN:                                    target-block-size=1"                \
+// RUN:                                    min-task-size=1"                \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
 // RUN:               -async-runtime-ref-counting-opt                          \
index 595a06a1931ac6be97d8495180e6a7b49ec2f6e1..b5ba465f752cb5a1bf179584bbdec749a5c235b3 100644 (file)
@@ -29,7 +29,7 @@
 
 // RUN:   mlir-opt %s -async-parallel-for="async-dispatch=false                \
 // RUN:                                    num-workers=20                      \
-// RUN:                                    target-block-size=1"                \
+// RUN:                                    min-task-size=1"                \
 // RUN:               -async-to-async-runtime                                  \
 // RUN:               -async-runtime-ref-counting                              \
 // RUN:               -async-runtime-ref-counting-opt                          \