[mlir][Vector]Fix bug where vector::WarpExecuteOnLane0Op are created with 2 blocks...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 24 Jun 2022 14:31:47 +0000 (07:31 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 24 Jun 2022 14:33:58 +0000 (07:33 -0700)
Differential Revision: https://reviews.llvm.org/D128534

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

index 9f308d4..08ea442 100644 (file)
@@ -46,8 +46,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     Value bbArg = warpOpBody->getArgument(it.index());
 
     rewriter.setInsertionPoint(ifOp);
-    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
-                                            bbArg.getType());
+    Value buffer =
+        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
 
     // Store arg vector into buffer.
     rewriter.setInsertionPoint(ifOp);
@@ -68,7 +68,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   // Insert sync after all the stores and before all the loads.
   if (!warpOp.getArgs().empty()) {
     rewriter.setInsertionPoint(ifOp);
-    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+    options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
   // Move body of warpOp to ifOp.
@@ -82,8 +82,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     Value val = it.value();
     Type resultType = warpOp->getResultTypes()[it.index()];
     rewriter.setInsertionPoint(ifOp);
-    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
-                                            val.getType());
+    Value buffer =
+        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
 
     // Store yielded value into buffer.
     rewriter.setInsertionPoint(yieldOp);
@@ -121,7 +121,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
   // Insert sync after all the stores and before all the loads.
   if (!yieldOp.operands().empty()) {
     rewriter.setInsertionPointAfter(ifOp);
-    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+    options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
   // Delete terminator and add empty scf.yield.
@@ -148,7 +148,12 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
 
   Region &opBody = warpOp.getBodyRegion();
   Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
   auto yield =
       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());