[MLIR][GPU] Add NvGpu mma.sync path to the VectorToGPU pass
authorChristopher Bate <cbate@nvidia.com>
Tue, 17 May 2022 23:54:29 +0000 (17:54 -0600)
committerChristopher Bate <cbate@nvidia.com>
Fri, 20 May 2022 15:42:55 +0000 (09:42 -0600)
This changes adds the option to lower to NvGpu dialect ops during the
VectorToGPU convsersion pass. Because this transformation reuses
existing VectorToGPU logic, a seperate VectorToNvGpu conversion pass is
not created. The option `use-nvgpu` is added to the VectorToGPU pass.
When this is true, the pass will attempt to convert slices rooted at
`vector.contract` operations into `nvgpu.mma.sync` ops, and
`vector.transfer_read` ops are converted to either `nvgpu.ldmatrix` or
one or more `vector.load` operations.  The specific data loaded will
depend on the thread id within a subgroup (warp). These index
calculations depend on data type and shape of the MMA op
according to the downstream PTX specification. The code for supporting
these details is separated into `NvGpuSupport.cpp|h`.

Differential Revision: https://reviews.llvm.org/D122940

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp [new file with mode: 0644]
mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h [new file with mode: 0644]
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir [new file with mode: 0644]

index 41e7b29..6d9863e 100644 (file)
@@ -851,8 +851,13 @@ def ConvertVectorToGPU : Pass<"convert-vector-to-gpu"> {
                 "dialect";
   let constructor = "mlir::createConvertVectorToGPUPass()";
   let dependentDialects = [
-    "memref::MemRefDialect",
-    "gpu::GPUDialect"
+    "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect", 
+    "vector::VectorDialect", "nvgpu::NVGPUDialect"
+  ];
+
+  let options = [
+    Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false", 
+      "convert to NvGPU ops instead of GPU dialect ops">
   ];
 }
 
index 266fa0e..1ba5b3f 100644 (file)
@@ -17,16 +17,25 @@ class Pass;
 class RewritePatternSet;
 
 /// Patterns to transform vector ops into a canonical form to convert to MMA
-/// matrix operations.
-void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+/// matrix operations. If `useNvGpu` is true, then the patterns will populated
+/// will prepare for conversion to `nvgpu` mma operations rather than the `gpu`
+/// dialect WMMA operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                        bool useNvGpu = false);
 
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
 /// convert slice of operations that can be legally converted to MMA operations.
 /// The rest of the vector operations are left untouched.
 void convertVectorToMMAOps(Operation *rootOp);
 
+/// Convert vector ops ops nested under `rootOp` to vector and GPU operaitons
+/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
+/// of operations that can be legally lowered on this path while the rest of
+/// the vector operations are left untouched.
+LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+
 /// Convert from vector to GPU ops.
-std::unique_ptr<Pass> createConvertVectorToGPUPass();
+std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
 
 } // namespace mlir
 
index e050040..530e156 100644 (file)
@@ -55,6 +55,10 @@ namespace LLVM {
 class LLVMDialect;
 } // namespace LLVM
 
+namespace nvgpu {
+class NVGPUDialect;
+}
+
 namespace NVVM {
 class NVVMDialect;
 } // namespace NVVM
index 06758c5..778f2c4 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRVectorToGPU
   VectorToGPU.cpp
+  NvGpuSupport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
new file mode 100644 (file)
index 0000000..a2820c3
--- /dev/null
@@ -0,0 +1,327 @@
+//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NvGpuSupport.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+namespace nvgpu {
+namespace {
+
+/// There are always 4 threads per [128|256|512] bit row.
+constexpr int64_t kThreadsPerRow = 4;
+
+constexpr int64_t kNumRowsPerTile = 8;
+
+bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+  return operandType == MatMulOperandRole::C;
+}
+
+/// Returns the number of registers which compose a matrix fragment held by a
+/// single thread.
+int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+  int64_t lineSize = inferTileWidthInBits(type);
+  auto shape = type.vectorType.getShape();
+  return (shape[0] / kNumRowsPerTile) *
+         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
+         lineSize;
+}
+
+/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
+/// operand shape.
+std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
+                                    Type elementType, int64_t lineSizeBits) {
+  // For each 8x128bit square, a thread is responsible for one 32bit register.
+  return {operandShape[0] / kNumRowsPerTile,
+          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
+              lineSizeBits};
+}
+
+} // namespace
+
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
+  WarpMatrixInfo info;
+
+  // Determine the vector type.
+  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    info.vectorType = writeOp.getVectorType();
+  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
+                 arith::ConstantOp>(op)) {
+    info.vectorType = op->getResult(0).getType().cast<VectorType>();
+  } else {
+    return op->emitError()
+           << "unhandled operation type in nvgpu.mma.sync conversion path";
+  }
+
+  // Determine the operand role. We assume it is an accumulator/result unless it
+  // is directly consumed by a `vector.contract` op.
+  info.operandRole = MatMulOperandRole::C;
+  for (Operation *user : op->getUsers()) {
+    auto contract = dyn_cast<vector::ContractionOp>(user);
+    if (!contract)
+      continue;
+    if (contract.getLhs() == op->getResult(0)) {
+      info.operandRole = MatMulOperandRole::A;
+      break;
+    }
+    if (contract.getRhs() == op->getResult(0)) {
+      info.operandRole = MatMulOperandRole::B;
+      break;
+    }
+  }
+  return info;
+}
+
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
+  bool isAcc = isAccumulatorOrResult(type.operandRole);
+  Type elType = type.vectorType.getElementType();
+  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
+    return 256;
+  }
+  if (elType.getIntOrFloatBitWidth() == 64) {
+    return isAcc ? 512 : 256;
+  }
+  return 128;
+}
+
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type) {
+  MLIRContext *ctx = type.vectorType.getContext();
+  const bool isAccum = isAccumulatorOrResult(type.operandRole);
+
+  Type elType = type.vectorType.getElementType();
+  if (elType.isF16()) {
+    return FragmentElementInfo{
+        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
+        inferNumRegistersPerMatrixFragment(type)};
+  }
+
+  // f64 operand
+  Type f64Ty = Float64Type::get(ctx);
+  if (elType.isF64()) {
+    return isAccum
+               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+                                     inferNumRegistersPerMatrixFragment(type)}
+               : FragmentElementInfo{f64Ty, 1, 64,
+                                     inferNumRegistersPerMatrixFragment(type)};
+  }
+
+  // int8 operand
+  if (elType.isInteger(8)) {
+    return FragmentElementInfo{
+        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
+        inferNumRegistersPerMatrixFragment(type)};
+  }
+  // Integer 32bit acc operands
+  if (elType.isInteger(32)) {
+    return FragmentElementInfo{
+        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
+        inferNumRegistersPerMatrixFragment(type)};
+  }
+
+  // Floating point 32bit operands
+  if (elType.isF32()) {
+    Type f32Ty = Float32Type::get(ctx);
+    return isAccum
+               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+                                     inferNumRegistersPerMatrixFragment(type)}
+               : FragmentElementInfo{f32Ty, 1, 32,
+                                     inferNumRegistersPerMatrixFragment(type)};
+  }
+  return failure();
+}
+
+static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
+                                                 Type elementType,
+                                                 ArrayRef<int64_t> operandShape,
+                                                 bool isAccumulator,
+                                                 int64_t elementsPerRegister,
+                                                 AffineExpr logicalValueId) {
+  const int64_t elementsPerLine =
+      lineSize / elementType.getIntOrFloatBitWidth();
+  const std::array<int64_t, 2> num8x128bTiles =
+      getTileShape(operandShape, elementType, lineSize);
+  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
+  return AffineMap::get(
+      2, 0,
+      {(registerIdx % num8x128bTiles[0]) * 8,
+       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
+      elementType.getContext());
+}
+
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                  const WarpMatrixInfo &fragmentType) {
+  Type elementType = fragmentType.vectorType.getElementType();
+  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      getMmaSyncRegisterType(fragmentType);
+  if (failed(regInfo))
+    return failure();
+
+  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
+  const int64_t elementsPerRegister =
+      regInfo->registerWidthBits / elementBitWidth;
+  const int64_t lineSize = inferTileWidthInBits(fragmentType);
+
+  AffineExpr laneId, logicalValueIdDim;
+  bindDims(builder.getContext(), laneId, logicalValueIdDim);
+
+  // Determine what register logicalValueId corresponds to. Use that as a
+  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
+  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
+      lineSize, elementType, operandShape,
+      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
+      logicalValueIdDim);
+
+  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
+    return AffineMap::get(2, 0, dimExprs, builder.getContext());
+  };
+
+  auto tileRow = registerIndexToTileCoord.getResult(0);
+  auto tileCol = registerIndexToTileCoord.getResult(1);
+  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
+                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
+                      (logicalValueIdDim % elementsPerRegister)});
+}
+
+FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+                                                   bool transpose) {
+  LdMatrixParams params;
+  Type elType = type.vectorType.getElementType();
+  params.fragmentType = type.vectorType;
+  if (type.operandRole == MatMulOperandRole::A ||
+      type.operandRole == MatMulOperandRole::C) {
+    params.targetLayout = NVVM::MMALayout::row;
+  } else {
+    params.targetLayout = NVVM::MMALayout::col;
+  }
+  ArrayRef<int64_t> shape = type.vectorType.getShape();
+  params.contiguousDimType =
+      transpose ? IteratorType::Parallel : IteratorType::Reduction;
+
+  if (params.targetLayout == NVVM::MMALayout::row) {
+    params.numTiles = (shape[0] / kNumRowsPerTile) *
+                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
+  } else {
+    params.numTiles = (shape[1] / kNumRowsPerTile) *
+                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
+  }
+
+  if (params.numTiles == 0)
+    return failure();
+
+  return params;
+}
+
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                               const LdMatrixParams &params) {
+  // One thread per 128b row.
+  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
+  const int bitsPerElement = static_cast<int>(
+      params.fragmentType.getElementType().getIntOrFloatBitWidth());
+  const int kElementsPer128b = (128 / bitsPerElement);
+  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
+  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
+
+  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
+    return AffineMap::get(1, 0, dimExprs, builder.getContext());
+  };
+
+  // This case corresponds to row-major A|C or col-major B operands.
+  if (params.contiguousDimType == IteratorType::Reduction) {
+    AffineExpr row = d0 % (operandShape[0]);
+    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
+    return makeMap({row, col});
+  }
+
+  // This case Corresponds to col-major A|C or row-major B operands. The
+  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
+  if (params.contiguousDimType == IteratorType::Parallel) {
+    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
+    // Threads are assigned in groups of 8 first across columns, then to
+    // rows. This is transpose of what `ldmatrix` expects, but when
+    // `ldmatrix` gets the `.trans` qualifier, final the effect will be to
+    // transpose just the blocks.
+    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
+    auto tileCol = (groupIdx % num8x128bCols);
+    auto tileRow = groupIdx.floorDiv(num8x128bCols);
+    return makeMap({tileCol * kElementsPer128b,
+                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
+  }
+  return failure();
+}
+
+LogicalResult
+PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value lhs = op.getLhs();
+  Value rhs = op.getRhs();
+  Value res = op.getAcc();
+
+  // Set up the parallel/reduction structure in right form.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m;
+  AffineExpr n;
+  AffineExpr k;
+  bindDims(rewriter.getContext(), m, n, k);
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (iteratorTypes.size() != 3)
+    return failure();
+  if (!(isParallelIterator(iteratorTypes[0]) &&
+        isParallelIterator(iteratorTypes[1]) &&
+        isReductionIterator(iteratorTypes[2])))
+    return failure();
+
+  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+  if (maps == canonicalForm) {
+    return failure();
+  }
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    std::swap(lhs, rhs);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    std::swap(lhs, rhs);
+  } else {
+    return failure();
+  }
+  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+      op.getIteratorTypes());
+  return success();
+}
+
+} // namespace nvgpu
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
new file mode 100644 (file)
index 0000000..9902faa
--- /dev/null
@@ -0,0 +1,100 @@
+//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to GPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace nvgpu {
+
+enum class MatMulOperandRole : int32_t { A = 0, B, C };
+
+/// Collects information about a warp-level matrix operand represented by a
+/// VectorType.
+struct WarpMatrixInfo {
+  VectorType vectorType;
+  MatMulOperandRole operandRole;
+};
+
+/// Given an op that operates on a VectorType representing a warp-level matrix
+/// operand, the function returns a struct containing relevant type information.
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
+
+/// Returns the number of bits in a single tile row. It is either 128, 256, or
+/// 512 bits depending on the data type and` whether the operand is an
+/// accumulator/result operand
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
+
+/// Specifies information about the registers which compose a matrix fragment
+/// according to the PTX documentation.
+struct FragmentElementInfo {
+  Type registerLLVMType;
+  int64_t elementsPerRegister;
+  int64_t registerWidthBits;
+  int64_t numRegistersPerFragment;
+};
+
+/// Returns a FragmentElementInfo struct describing the register types for the
+/// given matrix fragment type.
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type);
+
+/// Returns an AffineMap which maps a two dimensions representing (laneId,
+/// logicalValueId) and returns two results representing offsets within a
+/// matrix operand. The offsets point to the values the thread is responsible
+/// for (AKA the matrix fragment values) during a warp-collective matrix
+/// operation. For a visual reference of this LaneId -> (row, col) mapping,
+/// please see NVIDIA's PTX documentation:
+/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                  const WarpMatrixInfo &fragmentType);
+
+struct LdMatrixParams {
+  VectorType fragmentType;
+  bool isAccum;
+  int64_t numTiles;
+  IteratorType contiguousDimType;
+  NVVM::MMALayout targetLayout;
+};
+
+FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+                                            bool transpose);
+/// Returns an AffineMap which maps a single dimension representing the laneId
+/// to two results representing offsets within the matrix operand that should
+/// be the pointer locations a thread should pass to the ldmatrix instruction.
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                               const LdMatrixParams &params);
+
+// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted
+// to MMA matmul.
+struct PrepareContractToGPUMMASync
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace nvgpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
index 9ed1c34..a6e122c 100644 (file)
@@ -12,6 +12,7 @@
 
 #include <type_traits>
 
+#include "NvGpuSupport.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
 #include "../PassDetail.h"
@@ -19,6 +20,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
+/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
+/// AffineMap representing offsets to apply to indices, the function fills
+/// `indices` with the original indices plus the offsets. The offsets are
+/// applied by taking into account the permutation map of the transfer op. If
+/// the `offsetMap` has dimension placeholders, those should be provided in
+/// `dimValues`.
+template <typename TransferOpType>
+static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+                           AffineMap offsetMap, ArrayRef<Value> dimValues,
+                           SmallVector<Value, 4> &indices) {
+  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
+  Location loc = xferOp.getLoc();
+  unsigned offsetsIdx = 0;
+  for (auto expr : xferOp.getPermutationMap().getResults()) {
+    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
+      Value prevIdx = indices[dim.getPosition()];
+      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
+      dims.push_back(prevIdx);
+      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
+      indices[dim.getPosition()] = makeComposedAffineApply(
+          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
+      continue;
+    }
+  }
+}
+
 // Return true if the contract op can be convert to MMA matmul.
-static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
+static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
+                                          bool useNvGpu) {
   if (llvm::size(contract.getMasks()) != 0)
     return false;
 
@@ -47,7 +77,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
 
   // The contract needs to represent a matmul to be able to convert to
   // MMAMatrix matmul.
-  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+  if (!useNvGpu &&
+      contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+    return false;
+  if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
     return false;
 
   return true;
@@ -61,7 +94,7 @@ getMemrefConstantHorizontalStride(ShapedType type) {
   if (!memrefType)
     return false;
   // If the memref is 0 or 1D the horizontal stride is 0.
-  if(memrefType.getRank() < 2)
+  if (memrefType.getRank() < 2)
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
@@ -75,7 +108,8 @@ getMemrefConstantHorizontalStride(ShapedType type) {
 }
 
 // Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
+                                              bool useNvGpu) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
@@ -87,9 +121,14 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   AffineExpr zero = b.getAffineConstantExpr(0);
   auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                           readOp.getContext());
-  // TODO: Support transpose once it is added to GPU dialect ops.
-  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
-  return !(!map.isMinorIdentity() && map != broadcastInnerDim);
+
+  if (!useNvGpu) {
+    // TODO: Support transpose once it is added to GPU dialect ops.
+    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+    return map.isMinorIdentity() || map == broadcastInnerDim;
+  }
+
+  return true;
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.
@@ -147,15 +186,15 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
   return convertElementwiseOpToMMA(op).hasValue();
 }
 
-static bool supportsMMaMatrixType(Operation *op) {
+static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
-    return contractSupportsMMAMatrixType(contract);
+    return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
@@ -203,7 +242,8 @@ static SetVector<Operation *> getSliceContract(Operation *op,
 
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
-static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
+static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
+                                             bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
     return llvm::any_of(op->getResultTypes(),
                         [](Type t) { return t.isa<VectorType>(); });
@@ -221,8 +261,9 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     // If any instruction cannot use MMA matrix type drop the whole
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
-    if (llvm::any_of(dependentOps,
-                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
+    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
+          return !supportsMMaMatrixType(op, useNvGpu);
+        }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
@@ -351,7 +392,7 @@ static const char *inferFragType(OpTy op) {
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op));
+  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -386,6 +427,250 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
   op.erase();
 }
 
+/// Returns the vector type which represents a matrix fragment.
+static VectorType
+getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
+  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
+                             regInfo.elementsPerRegister};
+  Type elType = regInfo.registerLLVMType;
+  if (auto vecType = elType.dyn_cast<VectorType>())
+    elType = vecType.getElementType();
+  return VectorType::get(shape, elType);
+}
+
+/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
+static LogicalResult
+convertConstantOpMmaSync(arith::ConstantOp op,
+                         llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
+  if (!dense)
+    return failure();
+  Value result = b.create<arith::ConstantOp>(
+      op.getLoc(), vectorType,
+      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
+  valueMapping[op.getResult()] = result;
+  return success();
+}
+
+static LogicalResult
+creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+                             llvm::DenseMap<Value, Value> &valueMapping) {
+  Location loc = op->getLoc();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
+      *warpMatrixInfo,
+      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+  if (failed(params)) {
+    return op->emitError()
+           << "failed to convert vector.transfer_read to ldmatrix; this op "
+              "likely "
+              "should not be converted to a nvgpu.ldmatrix call.";
+  }
+
+  // Adjust the load offset.
+  auto laneId = builder.create<gpu::LaneIdOp>(loc);
+  FailureOr<AffineMap> offsets =
+      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
+  if (failed(offsets))
+    return failure();
+
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+
+  SmallVector<Value, 4> indices;
+  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+                                         indices);
+  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+      loc, vectorType, op.getSource(), indices,
+      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+  valueMapping[op] = newOp->getResult(0);
+  return success();
+}
+
+static LogicalResult
+createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+                       llvm::DenseMap<Value, Value> &valueMapping) {
+  Location loc = op.getLoc();
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo)) {
+    op->emitError() << "Failed to deduce register fragment type during "
+                       "conversion to distributed non-ldmatrix compatible load";
+    return failure();
+  }
+
+  NVVM::MMALayout targetLayout =
+      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
+          ? NVVM::MMALayout::col
+          : NVVM::MMALayout::row;
+
+  Value laneId = builder.create<gpu::LaneIdOp>(loc);
+  SmallVector<Value, 4> elements;
+
+  // This is the individual element type.
+  Type loadedElType = regInfo->registerLLVMType;
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+
+  Value fill = builder.create<arith::ConstantOp>(
+      op.getLoc(), vectorType.getElementType(),
+      builder.getZeroAttr(vectorType.getElementType()));
+  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+
+  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
+
+  // Vectorized loads.
+  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
+    if (!loadedElType.isa<VectorType>()) {
+      loadedElType = VectorType::get({1}, loadedElType);
+    }
+
+    for (int i = 0; i < vectorType.getShape()[0]; i++) {
+      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+          op.getLoc(), builder, *warpMatrixInfo);
+      if (failed(coords))
+        return failure();
+      Value logicalValueId = builder.create<arith::ConstantOp>(
+          loc, builder.getIndexType(),
+          builder.getIndexAttr(i * regInfo->elementsPerRegister));
+      SmallVector<Value, 4> newIndices;
+      getXferIndices<vector::TransferReadOp>(
+          builder, op, *coords, {laneId, logicalValueId}, newIndices);
+
+      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
+                                                op.getSource(), newIndices);
+      result = builder.create<vector::InsertOp>(loc, el, result,
+                                                builder.getI64ArrayAttr(i));
+    }
+  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
+    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
+      loadedElType = vecType.getElementType();
+    }
+    // Load each element individually.
+    for (int i = 0; i < vectorType.getShape()[0]; i++) {
+      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
+           innerIdx++) {
+
+        Value logicalValueId = builder.create<arith::ConstantOp>(
+            loc, builder.getIndexType(),
+            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+            op.getLoc(), builder, *warpMatrixInfo);
+        if (failed(coords))
+          return failure();
+
+        SmallVector<Value, 4> newIndices;
+        getXferIndices<vector::TransferReadOp>(
+            builder, op, *coords, {laneId, logicalValueId}, newIndices);
+        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
+                                                  op.getSource(), newIndices);
+        result = builder.create<vector::InsertOp>(
+            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
+      }
+    }
+  } else {
+    return failure();
+  }
+
+  valueMapping[op.getResult()] = result;
+  return success();
+}
+
+/// Converts a `vector.transfer_read` operation directly to either a
+/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
+/// used when converting to `nvgpu.mma.sync` operations.
+static LogicalResult
+convertTransferReadToLoads(vector::TransferReadOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  bool isLdMatrixCompatible =
+      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
+      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
+
+  VectorType vecTy = op.getVectorType();
+  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
+
+  // When we are transposing the B operand, ldmatrix will only work if we have
+  // at least 8 rows to read and  the width to read for the transpose is 128
+  // bits.
+  if (!op.getPermutationMap().isMinorIdentity() &&
+      (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128))
+    isLdMatrixCompatible = false;
+
+  if (!isLdMatrixCompatible)
+    return createNonLdMatrixLoads(op, b, valueMapping);
+
+  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
+}
+
+static LogicalResult
+convertTransferWriteToStores(vector::TransferWriteOp op,
+                             llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+  Value matrix = valueMapping.find(op.getVector())->second;
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+  Value laneId = b.create<gpu::LaneIdOp>(loc);
+
+  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
+    Value logicalValueId = b.create<arith::ConstantOp>(
+        loc, b.getIndexType(),
+        b.getIndexAttr(i * regInfo->elementsPerRegister));
+    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+        op.getLoc(), b, *warpMatrixInfo);
+    if (failed(coords))
+      return failure();
+
+    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+    SmallVector<Value, 4> newIndices;
+    getXferIndices<vector::TransferWriteOp>(
+        b, op, *coords, {laneId, logicalValueId}, newIndices);
+    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+  }
+  op->erase();
+  return success();
+}
+
 static void convertContractOp(vector::ContractionOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   OpBuilder b(op);
@@ -397,6 +682,22 @@ static void convertContractOp(vector::ContractionOp op,
   valueMapping[op.getResult()] = matmul;
 }
 
+static LogicalResult
+convertContractOpToMmaSync(vector::ContractionOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Value opA = valueMapping.find(op.getLhs())->second;
+  Value opB = valueMapping.find(op.getRhs())->second;
+  Value opC = valueMapping.find(op.getAcc())->second;
+  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
+  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
+  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
+  Value matmul = b.create<nvgpu::MmaSyncOp>(
+      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+  valueMapping[op.getResult()] = matmul;
+  return success();
+}
+
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
 static void convertConstantOp(arith::ConstantOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
@@ -509,13 +810,20 @@ static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
   valueMapping[op->getResult(0)] = newOp;
 }
 
-void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
-  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
-      patterns.getContext());
+void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                              bool useNvGpu) {
+  if (!useNvGpu) {
+    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
+        patterns.getContext());
+    return;
+  }
+  patterns
+      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
+          patterns.getContext());
 }
 
 void mlir::convertVectorToMMAOps(Operation *rootOp) {
-  SetVector<Operation *> ops = getOpToConvert(rootOp);
+  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
   llvm::DenseMap<Value, Value> valueMapping;
   for (Operation *op : ops) {
     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
@@ -538,21 +846,71 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
   }
 }
 
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
+  llvm::DenseMap<Value, Value> valueMapping;
+  for (Operation *op : ops) {
+    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
+            .Case([&](vector::TransferReadOp transferReadOp) {
+              return convertTransferReadToLoads(transferReadOp, valueMapping);
+            })
+            .Case([&](vector::TransferWriteOp transferWriteOp) {
+              return convertTransferWriteToStores(transferWriteOp,
+                                                  valueMapping);
+            })
+            .Case([&](vector::ContractionOp contractionOp) {
+              return convertContractOpToMmaSync(contractionOp, valueMapping);
+            })
+            .Case([&](scf::ForOp forOp) {
+              convertForOp(forOp, valueMapping);
+              return success();
+            })
+            .Case([&](scf::YieldOp yieldOp) {
+              convertYieldOp(yieldOp, valueMapping);
+              return success();
+            })
+            .Case([&](arith::ConstantOp constOp) {
+              return convertConstantOpMmaSync(constOp, valueMapping);
+            })
+            .Default([&](Operation *op) {
+              op->emitError() << "unhandled vector to mma type: " << *op;
+              return failure();
+            })
+            .failed()) {
+      op->emitError() << "Failed to convert op " << *op;
+      return failure();
+    }
+  }
+  return success();
+}
+
 namespace {
 
 struct ConvertVectorToGPUPass
     : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
+
+  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
+    useNvGpu.setValue(useNvGpu_);
+  }
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populatePrepareVectorToMMAPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+
+    if (useNvGpu.getValue()) {
+      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+        return signalPassFailure();
+    }
 
-    convertVectorToMMAOps(getOperation());
+    (void)convertVectorToMMAOps(getOperation());
   }
 };
 
 } // namespace
 
-std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
-  return std::make_unique<ConvertVectorToGPUPass>();
+std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
+  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
new file mode 100644 (file)
index 0000000..be8d08b
--- /dev/null
@@ -0,0 +1,349 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s
+
+//#########################################################
+// INT8 row-row-row
+//#########################################################
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
+
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)>
+// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)>
+// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)>
+// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)>
+// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)>
+// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)>
+// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)>
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k32_int8_row_row_row
+func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<32x8xi8>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c17 = arith.constant 17 : index  
+  %c39 = arith.constant 39 : index  
+  %c40 = arith.constant 40 : index  
+  %c49 = arith.constant 49 : index  
+  %c50 = arith.constant 50 : index  
+  %cst = arith.constant 0 : i8
+  %cst0 = arith.constant 0 : i32
+
+  // Verify that the operand A is distributed to loads correctly.
+
+  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+
+  // Verify that the operand B is distributed to loads correctly. It's elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}]  
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-NOT: memref.load %arg1
+
+  // Verify that the operand C is distributed to loads correctly.
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK-NOT: vector.load %arg2{{.*}}
+
+  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+  %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
+  %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
+
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+  return
+}
+
+// -----
+
+//#########################################################
+// f64 row-row-row
+//#########################################################
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)>
+
+// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)>
+// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m8n8k4_f64_row_row_row
+func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) {
+  %cst_0 = arith.constant dense<0.0> : vector<4x8xf64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c17 = arith.constant 17 : index  
+  %c39 = arith.constant 39 : index  
+  %c40 = arith.constant 40 : index  
+  %c49 = arith.constant 49 : index  
+  %c50 = arith.constant 50 : index  
+  %cst = arith.constant 0.0 : f64
+  %cst0 = arith.constant 0.0 : f64
+
+  // Verify that the operand A is distributed to loads correctly.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64>
+
+  // Verify that the operand B is distributed to loads correctly. It's elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>  
+
+  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
+  %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
+  %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>  
+  vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64>
+  return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-row-row
+//#########################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_row_row
+func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}  
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+  return
+}
+
+// -----
+
+// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row
+func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
+  // CHECK: [[C0:%.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+  %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]  
+  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+  return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-col-row
+//#########################################################
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
+// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_col_row
+func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+  // CHECK-SAME: transpose = false
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-SAME: transpose = false
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]   
+  // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-SAME: transpose = false
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+  return
+}
+
+// -----
+
+//#########################################################
+// TF32 (multiplicand) F32 (accumulator) row-row-row
+//#########################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+
+// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row
+func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}  
+
+  // b and c are not loaded by ldmatrix in this test.
+  // CHECK-NOT: nvgpu.ldmatrix
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>  
+  // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 4]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32>
+  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32>  
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32>
+
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}