[mlir] Fix for gpu-async-region pass.
authorChristian Sigg <csigg@google.com>
Tue, 15 Dec 2020 20:15:28 +0000 (21:15 +0100)
committerChristian Sigg <csigg@google.com>
Wed, 16 Dec 2020 18:08:10 +0000 (19:08 +0100)
- the !gpu.async.token is the second result of 'gpu.alloc async', not the first.
- async.execute construction takes operand types not yet wrapped in !async.value.
- fix typo

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D93156

mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
mlir/test/Dialect/GPU/async-region.mlir

index eaa777c..c8378ae 100644 (file)
@@ -85,18 +85,19 @@ private:
     asyncOp.addAsyncDependency(currentToken);
 
     // Clone the op to return a token in addition to the other results.
-    SmallVector<Type, 1> resultTypes = {tokenType};
+    SmallVector<Type, 1> resultTypes;
     resultTypes.reserve(1 + op->getNumResults());
     copy(op->getResultTypes(), std::back_inserter(resultTypes));
+    resultTypes.push_back(tokenType);
     auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
                                     op->getOperands(), op->getMutableAttrDict(),
                                     op->getSuccessors());
 
     // Replace the op with the async clone.
     auto results = newOp->getResults();
-    currentToken = results.front();
+    currentToken = results.back();
     builder.insert(newOp);
-    op->replaceAllUsesWith(results.drop_front());
+    op->replaceAllUsesWith(results.drop_back());
     op->erase();
 
     return success();
@@ -165,7 +166,14 @@ private:
     // Construct new result type list with `count` additional types.
     SmallVector<Type, 2> resultTypes;
     resultTypes.reserve(numResults);
-    copy(executeOp.getResultTypes(), std::back_inserter(resultTypes));
+    transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
+              [](Type type) {
+                // Extract value type from !async.value.
+                if (auto valueType = type.dyn_cast<async::ValueType>())
+                  return valueType.getValueType();
+                assert(type.isa<async::TokenType>() && "expected token type");
+                return type;
+              });
     OpBuilder builder(executeOp);
     auto tokenType = builder.getType<gpu::AsyncTokenType>();
     resultTypes.resize(numResults, tokenType);
@@ -266,7 +274,7 @@ void GpuAsyncRegionPass::runOnFunction() {
           .wasInterrupted())
     return signalPassFailure();
 
-  // Collect gpu.wait ops that we can move out of gpu.execute regions.
+  // Collect gpu.wait ops that we can move out of async.execute regions.
   getFunction().getRegion().walk(DeferWaitCallback());
 }
 
index 2fc58cf..216ccce 100644 (file)
@@ -18,7 +18,11 @@ module attributes {gpu.container_module} {
     // CHECK: %[[t2:.*]] = gpu.launch_func async [%[[t1]]]
     gpu.launch_func @kernels::@kernel
         blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
-    // CHECK: gpu.wait [%[[t2]]]
+    // CHECK: %[[m:.*]], %[[t3:.*]] = gpu.alloc async [%[[t2]]] ()
+    %0 = gpu.alloc() : memref<7xf32>
+    // CHECK: %[[t4:.*]] = gpu.dealloc async [%[[t3]]] %[[m]]
+    gpu.dealloc %0 : memref<7xf32>
+    // CHECK: gpu.wait [%[[t4]]]
     // CHECK: call @foo
     call @foo() : () -> ()
     return
@@ -98,4 +102,27 @@ module attributes {gpu.container_module} {
     async.await %a1 : !async.token
     return
   }
+
+ // CHECK-LABEL:func @async_execute_with_result(%{{.*}}: index)
+  func @async_execute_with_result(%sz : index) -> index {
+    // CHECK: %[[a0:.*]], %[[f0:.*]]:2 = async.execute
+    // CHECK-SAME: -> (!async.value<index>, !async.value<!gpu.async.token>)
+    %a0, %f0 = async.execute -> !async.value<index> {
+      // CHECK: %[[t:.*]] = gpu.launch_func async
+      gpu.launch_func @kernels::@kernel
+          blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+      // CHECK-NOT: gpu.wait
+      // CHECK: async.yield {{.*}}, %[[t]] : index, !gpu.async.token
+      async.yield %sz : index
+    }
+
+    // CHECK: async.await %[[a0]] : !async.token
+    // CHECK: %[[t:.*]] = async.await %[[f0]]#1 : !async.value<!gpu.async.token>
+    // CHECK: gpu.wait [%[[t]]]
+    async.await %a0 : !async.token
+    // CHECK: %[[x:.*]] = async.await %[[f0]]#0 : !async.value<index>
+    %x = async.await %f0 : !async.value<index>
+    // CHECK: return %[[x]] : index
+    return %x : index
+  }
 }