[mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes
authorManish Gupta <manigupta@google.com>
Wed, 12 Oct 2022 05:17:32 +0000 (05:17 +0000)
committerManish Gupta <manigupta@google.com>
Thu, 20 Oct 2022 00:10:21 +0000 (17:10 -0700)
This patch handles native `mma.sync` sizes and enables issuing `ldmatrix` on
largest possible tiles for matrixB. It requires handling
`vector.extract_strided_slice` from vector to ngpu lowering.

Differential Revision: https://reviews.llvm.org/D135749

mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

index 699e9fd..fac99dc 100644 (file)
 #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
 #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
 
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
-namespace vector {
-enum class IteratorType : uint32_t;
-class ContractionOp;
-} // namespace vector
-
-namespace NVVM {
-enum class MMALayout : uint32_t;
-} // namespace NVVM
-
 namespace nvgpu {
 
 /// Represents the role of an operand in an MMA instruction:
 /// `result := matmul(A, B) + C`
 enum class MatMulOperandRole : int32_t { A = 0, B, C };
 
+/// Returns the first user of the `op` that is vector.contract. If no
+/// vector.contract user exists, return failure.
+FailureOr<vector::ContractionOp> getUserContract(Operation *op);
+
 /// Collects information about a warp-level matrix operand represented by a
 /// VectorType.
 struct WarpMatrixInfo {
index f4528b1..01654fd 100644 (file)
@@ -192,6 +192,33 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
   return convertElementwiseOpToMMA(op).has_value();
 }
 
+/// Returns true if the extract strided slice op is supported with `mma.sync`
+/// path.
+static bool
+extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return false;
+
+  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
+  if (failed(contractOp))
+    return false;
+
+  // Handle vector.extract_strided_slice on registers containing
+  // matrixB and matrixC operands. vector.extract_strided_slice op
+  // is not supported on registers containing matrixA operands.
+  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
+    return (op->getResult(0).getType().cast<VectorType>() ==
+            (*contractOp).getRhs().getType().cast<VectorType>());
+  else if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
+    return (op->getResult(0).getType().cast<VectorType>() ==
+            (*contractOp).getAcc().getType().cast<VectorType>());
+
+  return false;
+}
+
 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -199,6 +226,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
+  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
+    return useNvGpu &&
+           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
@@ -338,8 +368,10 @@ struct PrepareContractToGPUMMA
   }
 };
 
-// Merge transpose op into the transfer read op. Transpose are not supported on
-// MMA types but MMA load can transpose the matrix when loading.
+// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
+// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
+// respectively. We can fold the transpose operation when loading the data from
+// Shared Memory to registers.
 struct CombineTransferReadOpTranspose final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
@@ -620,7 +652,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
 
   // When we are transposing the B operand, ldmatrix will only work if we have
-  // at least 8 rows to read and  the width to read for the transpose is 128
+  // at least 8 rows to read and the width to read for the transpose is 128
   // bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
@@ -671,6 +703,83 @@ convertTransferWriteToStores(vector::TransferWriteOp op,
   return success();
 }
 
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (auto attr : arrayAttr)
+    results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
+static LogicalResult
+convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(mmaSyncFragmentInfo))
+    return failure();
+
+  // Find the vector.transer_read whose result vector is being sliced.
+  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+  if (!transferReadOp)
+    return failure();
+
+  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(ldFragmentInfo))
+    return failure();
+
+  assert(
+      (mmaSyncFragmentInfo->elementsPerRegister ==
+       ldFragmentInfo->elementsPerRegister) &&
+      "Number of elements per register should be same for load and mma.sync");
+
+  // Create vector.extract_strided_slice op for thread-owned fragments.
+  std::array<int64_t, 2> strides = {1,
+                                    1}; // stride for extract slice is always 1.
+  std::array<int64_t, 2> sliceShape = {
+      mmaSyncFragmentInfo->numRegistersPerFragment,
+      mmaSyncFragmentInfo->elementsPerRegister};
+  auto sourceVector = valueMapping.find(transferReadOp)->second;
+
+  // offset and sizes at warp-level of onwership.
+  SmallVector<int64_t> offsets;
+  populateFromInt64AttrArray(op.getOffsets(), offsets);
+
+  SmallVector<int64_t> sizes;
+  populateFromInt64AttrArray(op.getSizes(), sizes);
+  ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
+
+  // Compute offset in vector registers. Note that the mma.sync vector registers
+  // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
+  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
+  std::array<int64_t, 2> sliceOffset = {0, 0};
+
+  if (offsets[0] && offsets[1])
+    return op->emitError() << "Slicing fragments in 2D is not supported. ";
+  else if (offsets[0])
+    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
+  else if (offsets[1])
+    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
+
+  Value newOp = b.create<vector::ExtractStridedSliceOp>(
+      loc, sourceVector, sliceOffset, sliceShape, strides);
+
+  valueMapping[op] = newOp;
+  return success();
+}
+
 static void convertContractOp(vector::ContractionOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   OpBuilder b(op);
@@ -858,6 +967,10 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
               return convertTransferWriteToStores(transferWriteOp,
                                                   valueMapping);
             })
+            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
+              return convertExtractStridedSlice(extractStridedSliceOp,
+                                                valueMapping);
+            })
             .Case([&](vector::ContractionOp contractionOp) {
               return convertContractOpToMmaSync(contractionOp, valueMapping);
             })
index 18fc4e6..6de16f8 100644 (file)
@@ -45,14 +45,24 @@ static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
               lineSizeBits};
 }
 
+/// Returns the first user of the `op` that is vector.contract. If no
+/// vector.contract user exists, return failure.
+FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
+  for (Operation *user : op->getUsers()) {
+    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
+      return contractOp;
+  }
+  return failure();
+}
+
 FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
   WarpMatrixInfo info;
 
-  // Determine the vector type.
+  // Determine the vector type at warp-level.
   if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
     info.vectorType = writeOp.getVectorType();
   } else if (isa<vector::TransferReadOp, vector::ContractionOp,
-                 arith::ConstantOp>(op)) {
+                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
     info.vectorType = op->getResult(0).getType().cast<VectorType>();
   } else {
     return op->emitError()
@@ -62,19 +72,15 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
   // Determine the operand role. We assume it is an accumulator/result unless it
   // is directly consumed by a `vector.contract` op.
   info.operandRole = MatMulOperandRole::C;
-  for (Operation *user : op->getUsers()) {
-    auto contract = dyn_cast<vector::ContractionOp>(user);
-    if (!contract)
-      continue;
-    if (contract.getLhs() == op->getResult(0)) {
-      info.operandRole = MatMulOperandRole::A;
-      break;
-    }
-    if (contract.getRhs() == op->getResult(0)) {
-      info.operandRole = MatMulOperandRole::B;
-      break;
-    }
-  }
+  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
+  if (failed(contractOp))
+    return info;
+
+  if ((*contractOp).getLhs() == op->getResult(0))
+    info.operandRole = MatMulOperandRole::A;
+  else if ((*contractOp).getRhs() == op->getResult(0))
+    info.operandRole = MatMulOperandRole::B;
+
   return info;
 }
 
index 42dc06c..ae6329c 100644 (file)
@@ -164,9 +164,9 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
 
 // -----
 
-//#########################################################
-// FP16 row-row-row
-//#########################################################
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x2 for matrixB)
+//#########################################################################
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -203,6 +203,62 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
 
 // -----
 
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//#########################################################################
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n16k16_mmasync16816_fp16_f16_row_row_row
+func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16, 3>, %arg1: memref<32x64xf16, 3>, %arg2: memref<42x64xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}  
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+  // CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+  return
+}
+// -----
+
 // CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
 // CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
 // CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>