From: Nicolas Vasilache Date: Fri, 24 Jun 2022 14:31:47 +0000 (-0700) Subject: [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op is created with 2 blocks... X-Git-Tag: upstream/15.0.7~3674 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f6c79c6ae49f3a642bebe32a2346186c38bb83d7;p=platform%2Fupstream%2Fllvm.git [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op is created with 2 blocks in the region Differential Revision: https://reviews.llvm.org/D128534 --- diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 9f308d4..08ea442 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -46,8 +46,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, Value bbArg = warpOpBody->getArgument(it.index()); rewriter.setInsertionPoint(ifOp); - Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, - bbArg.getType()); + Value buffer = + options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType()); // Store arg vector into buffer. rewriter.setInsertionPoint(ifOp); @@ -68,7 +68,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, // Insert sync after all the stores and before all the loads. if (!warpOp.getArgs().empty()) { rewriter.setInsertionPoint(ifOp); - options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); + options.warpSyncronizationFn(loc, rewriter, warpOp); } // Move body of warpOp to ifOp. @@ -82,8 +82,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, Value val = it.value(); Type resultType = warpOp->getResultTypes()[it.index()]; rewriter.setInsertionPoint(ifOp); - Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp, - val.getType()); + Value buffer = + options.warpAllocationFn(loc, rewriter, warpOp, val.getType()); // Store yielded value into buffer. 
rewriter.setInsertionPoint(yieldOp); @@ -121,7 +121,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, // Insert sync after all the stores and before all the loads. if (!yieldOp.operands().empty()) { rewriter.setInsertionPointAfter(ifOp); - options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp); + options.warpSyncronizationFn(loc, rewriter, warpOp); } // Delete terminator and add empty scf.yield. @@ -148,7 +148,12 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns( Region &opBody = warpOp.getBodyRegion(); Region &newOpBody = newWarpOp.getBodyRegion(); + Block &newOpFirstBlock = newOpBody.front(); rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin()); + rewriter.eraseBlock(&newOpFirstBlock); + assert(newWarpOp.getWarpRegion().hasOneBlock() && + "expected WarpOp with single block"); + auto yield = cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());