[mlir][Vector] Significantly improve VectorToGPU.cpp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 10 Feb 2023 15:22:30 +0000 (07:22 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 15 Feb 2023 00:49:36 +0000 (16:49 -0800)
This revision performs a bunch of cleanups and tracks free-flowing IR mutations.
APIs are systematized around RewriterBase and relevant debug messages are added.
Deliberate use of OpBuilder::InsertionGuard is added where needed.

Differential Revision: https://reviews.llvm.org/D143738

mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

index 6899134..d8231fc 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
+class LogicalResult;
 class MLIRContext;
 class Pass;
 class RewritePatternSet;
@@ -29,13 +30,14 @@ void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
 /// convert slice of operations that can be legally converted to MMA operations.
 /// The rest of the vector operations are left untouched.
-void convertVectorToMMAOps(Operation *rootOp);
+LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp);
 
 /// Convert vector ops ops nested under `rootOp` to vector and GPU operaitons
 /// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
 /// of operations that can be legally lowered on this path while the rest of
 /// the vector operations are left untouched.
-LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+                                                   Operation *rootOp);
 
 /// Convert from vector to GPU ops.
 std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
index fac99dc..5880b09 100644 (file)
@@ -69,7 +69,7 @@ getMmaSyncRegisterType(const WarpMatrixInfo &type);
 /// please see NVIDIA's PTX documentation:
 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
 FailureOr<AffineMap>
-getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                   const WarpMatrixInfo &fragmentType);
 
 /// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
@@ -90,7 +90,7 @@ FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
 /// to two results representing offsets within the matrix operand that should
 /// be the pointer locations a thread should pass to the ldmatrix instruction.
 FailureOr<AffineMap>
-getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
 /// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
index b0fa50d..0266ba1 100644 (file)
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Region.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#define DEBUG_TYPE "vector-to-gpu"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
 #include "mlir/Conversion/Passes.h.inc"
@@ -45,7 +53,7 @@ using namespace mlir;
 /// the `offsetMap` has dimension placeholders, those should be provided in
 /// `dimValues`.
 template <typename TransferOpType>
-static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                            AffineMap offsetMap, ArrayRef<Value> dimValues,
                            SmallVector<Value, 4> &indices) {
   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
@@ -56,9 +64,9 @@ static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
       Value prevIdx = indices[dim.getPosition()];
       SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
       dims.push_back(prevIdx);
-      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
+      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
       indices[dim.getPosition()] = makeComposedAffineApply(
-          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
+          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
       continue;
     }
   }
@@ -94,8 +102,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
 
 // Return true if the given map represents a transposed matrix load,
 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
-  MLIRContext *ctx = b.getContext();
+static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+  MLIRContext *ctx = permutationMap.getContext();
+  // Local OpBuilder is fine here, we just build attributes.
+  OpBuilder b(ctx);
   auto nDim = permutationMap.getNumDims();
   AffineExpr zero = b.getAffineConstantExpr(0);
   if (nDim < 2) {
@@ -148,15 +158,16 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
       return false;
 
   AffineMap map = readOp.getPermutationMap();
-  OpBuilder b(readOp.getContext());
-  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
-  AffineExpr zero = b.getAffineConstantExpr(0);
-  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
-                                          readOp.getContext());
+
+  MLIRContext *ctx = readOp.getContext();
+  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
+  AffineExpr zero = getAffineConstantExpr(0, ctx);
+  auto broadcastInnerDim =
+      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
 
   if (!useNvGpu) {
     bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
-                  isTransposeMatrixLoadMap(b, map);
+                  isTransposeMatrixLoadMap(map);
     return result;
   }
 
@@ -383,14 +394,13 @@ struct PrepareContractToGPUMMA
     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
           vector::isParallelIterator(iteratorTypes[1]) &&
           vector::isReductionIterator(iteratorTypes[2])))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
     //
     // Two outer parallel, one inner reduction (matmat flavor).
     //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      // This is the classical row-major matmul, nothing to do.
-      return failure();
-    }
+    // This is the classical row-major matmul, nothing to do.
+    if (maps == infer({{m, k}, {k, n}, {m, n}}))
+      return rewriter.notifyMatchFailure(op, "contraction already prepared");
     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
@@ -411,7 +421,8 @@ struct PrepareContractToGPUMMA
     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
       std::swap(lhs, rhs);
     } else {
-      return failure();
+      // TODO: llvm_unreachable ?
+      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
     }
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         op, lhs, rhs, res,
@@ -445,14 +456,15 @@ struct CombineTransferReadOpTranspose final
 
     auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
     if (!transferReadOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "no transfer read");
 
     // TODO: support 0-d corner case.
     if (transferReadOp.getTransferRank() == 0)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "0-D transfer read");
 
     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
+
     SmallVector<int64_t, 2> perm;
     op.getTransp(perm);
     SmallVector<unsigned, 2> permU;
@@ -508,17 +520,24 @@ static const char *inferFragType(Operation *op) {
   return "COp";
 }
 
-static void convertTransferReadOp(vector::TransferReadOp op,
-                                  llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
+                      llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
   assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
 
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  if (!stride.has_value()) {
+    LLVM_DEBUG(DBGS() << "no stride\n");
+    return rewriter.notifyMatchFailure(op, "no stride");
+  }
 
   AffineMap map = op.getPermutationMap();
-  OpBuilder b(op);
-  bool isTranspose = isTransposeMatrixLoadMap(b, map);
+  bool isTranspose = isTransposeMatrixLoadMap(map);
 
   // Handle broadcast by setting the stride to 0.
   if (auto cstExpr =
@@ -526,7 +545,7 @@ static void convertTransferReadOp(vector::TransferReadOp op,
     assert(cstExpr.getValue() == 0);
     stride = 0;
   }
-  assert(stride);
+
   Value mappingResult = op.getResult();
   auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
@@ -544,24 +563,47 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   }
   gpu::MMAMatrixType type =
       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
-  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+      rewriter.getIndexAttr(*stride),
+      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
   valueMapping[mappingResult] = load;
+
+  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+  return success();
 }
 
-static void convertTransferWriteOp(vector::TransferWriteOp op,
-                                   llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
+                       llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(transferWriteSupportsMMAMatrixType(op));
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
-  assert(stride);
-  OpBuilder b(op);
-  Value matrix = valueMapping.find(op.getVector())->second;
-  b.create<gpu::SubgroupMmaStoreMatrixOp>(
+  if (!stride.has_value()) {
+    LLVM_DEBUG(DBGS() << "no stride\n");
+    return rewriter.notifyMatchFailure(op, "no stride");
+  }
+
+  auto it = valueMapping.find(op.getVector());
+  if (it == valueMapping.end()) {
+    LLVM_DEBUG(DBGS() << "no mapping\n");
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  }
+
+  Value matrix = it->second;
+  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
       op.getLoc(), matrix, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
-  op.erase();
+      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
+  (void)store;
+
+  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
+  return success();
 }
 
 /// Returns the vector type which represents a matrix fragment.
@@ -577,24 +619,33 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
 
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
 static LogicalResult
-convertConstantOpMmaSync(arith::ConstantOp op,
+convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                          llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
-  if (failed(warpMatrixInfo))
-    return failure();
+  if (failed(warpMatrixInfo)) {
+    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+  }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
-  if (failed(regInfo))
-    return failure();
+  if (failed(regInfo)) {
+    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+  }
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
   auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
-  if (!dense)
-    return failure();
-  Value result = b.create<arith::ConstantOp>(
+  if (!dense) {
+    LLVM_DEBUG(DBGS() << "not a splat\n");
+    return rewriter.notifyMatchFailure(op, "not a splat");
+  }
+
+  Value result = rewriter.create<arith::ConstantOp>(
       op.getLoc(), vectorType,
       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
   valueMapping[op.getResult()] = result;
@@ -602,43 +653,54 @@ convertConstantOpMmaSync(arith::ConstantOp op,
 }
 
 static LogicalResult
-creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
   Location loc = op->getLoc();
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
-  if (failed(warpMatrixInfo))
-    return failure();
+  if (failed(warpMatrixInfo)) {
+    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+  }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
-  if (failed(regInfo))
-    return failure();
+  if (failed(regInfo)) {
+    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+  }
 
   FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError()
-           << "failed to convert vector.transfer_read to ldmatrix; this op "
-              "likely "
-              "should not be converted to a nvgpu.ldmatrix call.";
+    LLVM_DEBUG(
+        DBGS()
+        << "failed to convert vector.transfer_read to ldmatrix. "
+        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+    return rewriter.notifyMatchFailure(
+        op, "failed to convert vector.transfer_read to ldmatrix; this op "
+            "likely should not be converted to a nvgpu.ldmatrix call.");
   }
 
   // Adjust the load offset.
-  auto laneId = builder.create<gpu::LaneIdOp>(loc);
+  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
   FailureOr<AffineMap> offsets =
-      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
-  if (failed(offsets))
-    return failure();
+      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
+  if (failed(offsets)) {
+    LLVM_DEBUG(DBGS() << "no offsets\n");
+    return rewriter.notifyMatchFailure(op, "no offsets");
+  }
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
 
   SmallVector<Value, 4> indices;
-  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                          indices);
-  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
       loc, vectorType, op.getSource(), indices,
       !op.getPermutationMap().isMinorIdentity(), params->numTiles);
   valueMapping[op] = newOp->getResult(0);
@@ -646,32 +708,36 @@ creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
 }
 
 static LogicalResult
-createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                        llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   Location loc = op.getLoc();
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
-    op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible load";
-    return failure();
+    rewriter.notifyMatchFailure(
+        op, "Failed to deduce register fragment type during "
+            "conversion to distributed non-ldmatrix compatible load");
   }
 
-  Value laneId = builder.create<gpu::LaneIdOp>(loc);
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
   SmallVector<Value, 4> elements;
 
   // This is the individual element type.
   Type loadedElType = regInfo->registerLLVMType;
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
 
-  Value fill = builder.create<arith::ConstantOp>(
+  Value fill = rewriter.create<arith::ConstantOp>(
       op.getLoc(), vectorType.getElementType(),
-      builder.getZeroAttr(vectorType.getElementType()));
-  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.getZeroAttr(vectorType.getElementType()));
+  Value result =
+      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
@@ -684,20 +750,21 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
 
     for (int i = 0; i < vectorType.getShape()[0]; i++) {
       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-          op.getLoc(), builder, *warpMatrixInfo);
+          rewriter, op.getLoc(), *warpMatrixInfo);
       if (failed(coords))
-        return failure();
-      Value logicalValueId = builder.create<arith::ConstantOp>(
-          loc, builder.getIndexType(),
-          builder.getIndexAttr(i * regInfo->elementsPerRegister));
+        return rewriter.notifyMatchFailure(op, "no coords");
+
+      Value logicalValueId = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexType(),
+          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
       SmallVector<Value, 4> newIndices;
       getXferIndices<vector::TransferReadOp>(
-          builder, op, *coords, {laneId, logicalValueId}, newIndices);
+          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
 
-      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
-                                                op.getSource(), newIndices);
-      result = builder.create<vector::InsertOp>(loc, el, result,
-                                                builder.getI64ArrayAttr(i));
+      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
+                                                 op.getSource(), newIndices);
+      result = rewriter.create<vector::InsertOp>(loc, el, result,
+                                                 rewriter.getI64ArrayAttr(i));
     }
   } else {
     if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
@@ -707,21 +774,21 @@ createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
            innerIdx++) {
 
-        Value logicalValueId = builder.create<arith::ConstantOp>(
-            loc, builder.getIndexType(),
-            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+        Value logicalValueId = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIndexType(),
+            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-            op.getLoc(), builder, *warpMatrixInfo);
+            rewriter, op.getLoc(), *warpMatrixInfo);
         if (failed(coords))
-          return failure();
+          return rewriter.notifyMatchFailure(op, "no coords");
 
         SmallVector<Value, 4> newIndices;
         getXferIndices<vector::TransferReadOp>(
-            builder, op, *coords, {laneId, logicalValueId}, newIndices);
-        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
-                                                  op.getSource(), newIndices);
-        result = builder.create<vector::InsertOp>(
-            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
+            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
+                                                   op.getSource(), newIndices);
+        result = rewriter.create<vector::InsertOp>(
+            op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
       }
     }
   }
@@ -744,14 +811,15 @@ static bool isSharedMemory(MemRefType type) {
 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
 /// used when converting to `nvgpu.mma.sync` operations.
 static LogicalResult
-convertTransferReadToLoads(vector::TransferReadOp op,
+convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
@@ -769,46 +837,54 @@ convertTransferReadToLoads(vector::TransferReadOp op,
     isLdMatrixCompatible = false;
 
   if (!isLdMatrixCompatible)
-    return createNonLdMatrixLoads(op, b, valueMapping);
+    return createNonLdMatrixLoads(rewriter, op, valueMapping);
 
-  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
+  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
 }
 
 static LogicalResult
-convertTransferWriteToStores(vector::TransferWriteOp op,
+convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   Location loc = op->getLoc();
-  Value matrix = valueMapping.find(op.getVector())->second;
+  auto it = valueMapping.find(op.getVector());
+  if (it == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value matrix = it->second;
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
-  Value laneId = b.create<gpu::LaneIdOp>(loc);
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
 
   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
-    Value logicalValueId = b.create<arith::ConstantOp>(
-        loc, b.getIndexType(),
-        b.getIndexAttr(i * regInfo->elementsPerRegister));
+    Value logicalValueId = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexType(),
+        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-        op.getLoc(), b, *warpMatrixInfo);
+        rewriter, op.getLoc(), *warpMatrixInfo);
     if (failed(coords))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "no coords");
 
-    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+    Value el =
+        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
     SmallVector<Value, 4> newIndices;
     getXferIndices<vector::TransferWriteOp>(
-        b, op, *coords, {laneId, logicalValueId}, newIndices);
-    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
   }
-  op->erase();
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
   return success();
 }
 
@@ -819,35 +895,37 @@ static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
 }
 
 static LogicalResult
-convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+convertExtractStridedSlice(RewriterBase &rewriter,
+                           vector::ExtractStridedSliceOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
 
-  OpBuilder b(op);
   Location loc = op->getLoc();
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(mmaSyncFragmentInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
 
   // Find the vector.transer_read whose result vector is being sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no transfer read");
 
   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(ldFragmentInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
 
   assert(
       (mmaSyncFragmentInfo->elementsPerRegister ==
@@ -860,7 +938,10 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   std::array<int64_t, 2> sliceShape = {
       mmaSyncFragmentInfo->numRegistersPerFragment,
       mmaSyncFragmentInfo->elementsPerRegister};
-  auto sourceVector = valueMapping.find(transferReadOp)->second;
+  auto it = valueMapping.find(transferReadOp);
+  if (it == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  auto sourceVector = it->second;
 
   // offset and sizes at warp-level of onwership.
   SmallVector<int64_t> offsets;
@@ -882,86 +963,114 @@ convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
   else if (offsets[1])
     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
 
-  Value newOp = b.create<vector::ExtractStridedSliceOp>(
+  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
       loc, sourceVector, sliceOffset, sliceShape, strides);
 
   valueMapping[op] = newOp;
   return success();
 }
 
-static void convertContractOp(vector::ContractionOp op,
-                              llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
-  Value opA = valueMapping.find(op.getLhs())->second;
-  Value opB = valueMapping.find(op.getRhs())->second;
-  Value opC = valueMapping.find(op.getAcc())->second;
-  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
+static LogicalResult
+convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
+                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto itA = valueMapping.find(op.getLhs());
+  auto itB = valueMapping.find(op.getRhs());
+  auto itC = valueMapping.find(op.getAcc());
+  if (itA == valueMapping.end() || itB == valueMapping.end() ||
+      itC == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value opA = itA->second, opB = itB->second, opC = itC->second;
+  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
       /*b_transpose=*/UnitAttr());
   valueMapping[op.getResult()] = matmul;
+  return success();
 }
 
 static LogicalResult
-convertContractOpToMmaSync(vector::ContractionOp op,
+convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
-  Value opA = valueMapping.find(op.getLhs())->second;
-  Value opB = valueMapping.find(op.getRhs())->second;
-  Value opC = valueMapping.find(op.getAcc())->second;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto itA = valueMapping.find(op.getLhs());
+  auto itB = valueMapping.find(op.getRhs());
+  auto itC = valueMapping.find(op.getAcc());
+  if (itA == valueMapping.end() || itB == valueMapping.end() ||
+      itC == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value opA = itA->second, opB = itB->second, opC = itC->second;
   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
-  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
-                                            b.getI64ArrayAttr({m, n, k}));
+  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
+      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
   valueMapping[op.getResult()] = matmul;
   return success();
 }
 
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
-static void convertConstantOp(arith::ConstantOp op,
-                              llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
+                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(constantSupportsMMAMatrixType(op));
-  OpBuilder b(op);
+
   auto splat =
       op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
   auto scalarConstant =
-      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
+      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
   const char *fragType = inferFragType(op);
   auto vecType = op.getType().cast<VectorType>();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
-  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
-                                                           scalarConstant);
+  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
+      op.getLoc(), type, scalarConstant);
   valueMapping[op.getResult()] = matrix;
+  return success();
 }
 
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void convertBroadcastOp(vector::BroadcastOp op,
-                               llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
+                   llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(broadcastSupportsMMAMatrixType(op));
-  OpBuilder b(op);
+
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
-  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
-                                                           op.getSource());
+  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
+      op.getLoc(), type, op.getSource());
   valueMapping[op.getResult()] = matrix;
+  return success();
 }
 
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separatly for the loop to be correct.
-static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
+static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
+                                               scf::ForOp loop,
                                                ValueRange newIterOperands) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop);
+
   // Create a new loop before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(loop);
+  rewriter.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
   operands.append(newIterOperands.begin(), newIterOperands.end());
-  scf::ForOp newLoop =
-      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
-                           loop.getUpperBound(), loop.getStep(), operands);
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      operands);
   newLoop.getBody()->erase();
+
   newLoop.getLoopBody().getBlocks().splice(
       newLoop.getLoopBody().getBlocks().begin(),
       loop.getLoopBody().getBlocks());
@@ -970,25 +1079,35 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
 
   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                   loop.getNumResults())))
-    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
-  loop.erase();
+    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+
+  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
+  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
+  LLVM_DEBUG(DBGS() << "erase: " << loop);
+
+  rewriter.eraseOp(loop);
   return newLoop;
 }
 
-static void convertForOp(scf::ForOp op,
-                         llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
+                                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   SmallVector<Value> newOperands;
   SmallVector<std::pair<size_t, size_t>> argMapping;
   for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
     auto it = valueMapping.find(operand.value());
-    if (it == valueMapping.end())
+    if (it == valueMapping.end()) {
+      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
       continue;
+    }
     argMapping.push_back(std::make_pair(
         operand.index(), op.getNumIterOperands() + newOperands.size()));
     newOperands.push_back(it->second);
   }
-  OpBuilder b(op);
-  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
+
+  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
   Block &loopBody = *newForOp.getBody();
   for (auto mapping : argMapping) {
     valueMapping[newForOp.getResult(mapping.first)] =
@@ -997,11 +1116,17 @@ static void convertForOp(scf::ForOp op,
                                       newForOp.getNumInductionVars())] =
         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
   }
+
+  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+  return success();
 }
 
-static void convertYieldOp(scf::YieldOp op,
-                           llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+static LogicalResult
+convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
+               llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   auto loop = cast<scf::ForOp>(op->getParentOp());
   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
   for (const auto &operand : llvm::enumerate(op.getOperands())) {
@@ -1013,20 +1138,32 @@ static void convertYieldOp(scf::YieldOp op,
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
-  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
-  op.erase();
+  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
+  return success();
 }
 
 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
-static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
-                                 llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+static LogicalResult
+convertElementwiseOp(RewriterBase &rewriter, Operation *op,
+                     gpu::MMAElementwiseOp opType,
+                     llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   SmallVector<Value> matrixOperands;
-  for (Value operand : op->getOperands())
-    matrixOperands.push_back(valueMapping.find(operand)->second);
-  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+  for (Value operand : op->getOperands()) {
+    auto it = valueMapping.find(operand);
+    if (it == valueMapping.end())
+      return rewriter.notifyMatchFailure(op, "no mapping");
+    matrixOperands.push_back(it->second);
+  }
+  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
       op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
   valueMapping[op->getResult(0)] = newOp;
+  return success();
 }
 
 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
@@ -1041,67 +1178,75 @@ void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
           patterns.getContext());
 }
 
-void mlir::convertVectorToMMAOps(Operation *rootOp) {
+LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
+                                          Operation *rootOp) {
   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
   llvm::DenseMap<Value, Value> valueMapping;
+
+  auto globalRes = LogicalResult::success();
   for (Operation *op : ops) {
+    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+    // Apparently callers do not want to early exit on failure here.
+    auto res = LogicalResult::success();
     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
-      convertTransferReadOp(transferRead, valueMapping);
+      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
-      convertTransferWriteOp(transferWrite, valueMapping);
+      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
-      convertContractOp(contractOp, valueMapping);
+      res = convertContractOp(rewriter, contractOp, valueMapping);
     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
-      convertConstantOp(constantOp, valueMapping);
+      res = convertConstantOp(rewriter, constantOp, valueMapping);
     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
-      convertBroadcastOp(broadcastOp, valueMapping);
+      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-      convertForOp(forOp, valueMapping);
-    } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
-      convertYieldOp(yiledOp, valueMapping);
+      res = convertForOp(rewriter, forOp, valueMapping);
+    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+      res = convertYieldOp(rewriter, yieldOp, valueMapping);
     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
-      convertElementwiseOp(op, *elementwiseType, valueMapping);
+      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
     }
+    if (failed(res))
+      globalRes = failure();
   }
+  return globalRes;
 }
 
-LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+                                                         Operation *rootOp) {
   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
   llvm::DenseMap<Value, Value> valueMapping;
   for (Operation *op : ops) {
     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
             .Case([&](vector::TransferReadOp transferReadOp) {
-              return convertTransferReadToLoads(transferReadOp, valueMapping);
+              return convertTransferReadToLoads(rewriter, transferReadOp,
+                                                valueMapping);
             })
             .Case([&](vector::TransferWriteOp transferWriteOp) {
-              return convertTransferWriteToStores(transferWriteOp,
+              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                   valueMapping);
             })
             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
-              return convertExtractStridedSlice(extractStridedSliceOp,
+              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                 valueMapping);
             })
             .Case([&](vector::ContractionOp contractionOp) {
-              return convertContractOpToMmaSync(contractionOp, valueMapping);
+              return convertContractOpToMmaSync(rewriter, contractionOp,
+                                                valueMapping);
             })
             .Case([&](scf::ForOp forOp) {
-              convertForOp(forOp, valueMapping);
-              return success();
+              return convertForOp(rewriter, forOp, valueMapping);
             })
             .Case([&](scf::YieldOp yieldOp) {
-              convertYieldOp(yieldOp, valueMapping);
-              return success();
+              return convertYieldOp(rewriter, yieldOp, valueMapping);
             })
             .Case([&](arith::ConstantOp constOp) {
-              return convertConstantOpMmaSync(constOp, valueMapping);
+              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
             })
             .Default([&](Operation *op) {
-              op->emitError() << "unhandled vector to mma type: " << *op;
-              return failure();
+              return op->emitError() << "unhandled vector to mma type: " << *op;
             })
             .failed()) {
-      op->emitError() << "Failed to convert op " << *op;
-      return failure();
+      return op->emitError() << "Failed to convert op " << *op;
     }
   }
   return success();
@@ -1123,12 +1268,13 @@ struct ConvertVectorToGPUPass
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
 
+    IRRewriter rewriter(&getContext());
     if (useNvGpu.getValue()) {
-      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+      if (failed(
+              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
         return signalPassFailure();
     }
-
-    (void)convertVectorToMMAOps(getOperation());
+    (void)convertVectorToMMAOps(rewriter, getOperation());
   }
 };
 
index 6fdaaad..44f9b6d 100644 (file)
@@ -170,7 +170,7 @@ static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
 }
 
 FailureOr<AffineMap>
-nvgpu::getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                          const WarpMatrixInfo &fragmentType) {
   Type elementType = fragmentType.vectorType.getElementType();
   ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
@@ -235,7 +235,7 @@ nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
 }
 
 FailureOr<AffineMap>
-nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                       const LdMatrixParams &params) {
   // One thread per 128b row.
   const int bitsPerElement = static_cast<int>(
index c742150..0ba9eb4 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -25,6 +25,15 @@ func.func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: mem
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_cst
 //   CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f16
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -43,6 +52,15 @@ func.func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2:
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_broadcast
 //  CHECK-SAME:   (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
 //   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
@@ -61,6 +79,15 @@ func.func @matmul_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_loop
 //       CHECK:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 //       CHECK:   %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
@@ -86,6 +113,15 @@ func.func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_fused_elementwise
 //   CHECK-DAG:   %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
 //   CHECK-DAG:   %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
@@ -109,6 +145,15 @@ func.func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x1
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_fused_broadcast
 //   CHECK-DAG:   %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -134,6 +179,15 @@ func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16x
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_3Dmemref
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -153,6 +207,15 @@ func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %a
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_memref_strided
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -172,6 +235,15 @@ func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1,
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 //   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -190,6 +262,15 @@ func.func @matmul_transposed(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed_broadcasted_1d
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 //   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -208,6 +289,15 @@ func.func @matmul_transposed_broadcasted_1d(%arg0: memref<16xf16>, %arg1: memref
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed_broadcasted_2d
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 //   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -226,12 +316,25 @@ func.func @matmul_transposed_broadcasted_2d(%arg0: memref<32x32xf16>, %arg1: mem
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
+//   CHECK-DAG: #[[$map:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+//   CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//   CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
 // Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred.
 // CHECK-LABEL: func @matmul_no_extend_int8
 //   CHECK-DAG:   %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
 //   CHECK-DAG:   %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
 //   CHECK-DAG:   %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
-//       CHECK:   %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+//       CHECK:   %[[D:.+]] = vector.contract {indexing_maps = [#[[$map]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
 //       CHECK:   vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
 func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
   %cst_0 = arith.constant dense<0> : vector<16x16xi8>
@@ -246,6 +349,15 @@ func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_int8
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
 //   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
@@ -267,6 +379,15 @@ func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2:
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_mixed_signedness_int8
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp">
 //   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">