From 114ba722c1e58d23bafdf3654e4f8e537150c318 Mon Sep 17 00:00:00 2001
From: Manish Gupta
Date: Wed, 12 Oct 2022 05:17:32 +0000
Subject: [PATCH] [mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes

This patch handles native `mma.sync` sizes and enables issuing `ldmatrix`
on the largest possible tiles for matrixB. It requires handling
`vector.extract_strided_slice` in the vector-to-nvgpu lowering.

Differential Revision: https://reviews.llvm.org/D135749
---
 mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h  |  15 ++-
 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp   | 119 ++++++++++++++++++++-
 mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp         |  36 ++++---
 .../VectorToGPU/vector-to-mma-ops-mma-sync.mlir   |  62 ++++++++++-
 4 files changed, 202 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index 699e9fd..fac99dc 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -13,25 +13,22 @@
 #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
 #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
 
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
-namespace vector {
-enum class IteratorType : uint32_t;
-class ContractionOp;
-} // namespace vector
-
-namespace NVVM {
-enum class MMALayout : uint32_t;
-} // namespace NVVM
-
 namespace nvgpu {
 
 /// Represents the role of an operand in an MMA instruction:
 /// `result := matmul(A, B) + C`
 enum class MatMulOperandRole : int32_t { A = 0, B, C };
 
+/// Returns the first user of `op` that is a vector.contract. If no
+/// vector.contract user exists, returns failure.
+FailureOr<vector::ContractionOp> getUserContract(Operation *op);
+
 /// Collects information about a warp-level matrix operand represented by a
 /// VectorType.
 struct WarpMatrixInfo {
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index f4528b1..01654fd 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -192,6 +192,33 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
   return convertElementwiseOpToMMA(op).has_value();
 }
 
+/// Returns true if the extract strided slice op is supported with the
+/// `mma.sync` path.
+static bool
+extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return false;
+
+  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
+  if (failed(contractOp))
+    return false;
+
+  // Handle vector.extract_strided_slice on registers containing
+  // matrixB and matrixC operands. vector.extract_strided_slice op
+  // is not supported on registers containing matrixA operands.
+  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
+    return (op->getResult(0).getType().cast<VectorType>() ==
+            (*contractOp).getRhs().getType().cast<VectorType>());
+  else if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
+    return (op->getResult(0).getType().cast<VectorType>() ==
+            (*contractOp).getAcc().getType().cast<VectorType>());
+
+  return false;
+}
+
 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -199,6 +226,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
+  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
+    return useNvGpu &&
+           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
@@ -338,8 +368,10 @@ struct PrepareContractToGPUMMA
   }
 };
 
-// Merge transpose op into the transfer read op. Transpose are not supported on
-// MMA types but MMA load can transpose the matrix when loading.
+// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
+// row-, column-, and row-major layouts for matrixA, matrixB, and matrixC,
+// respectively. We can fold the transpose operation when loading the data from
+// Shared Memory to registers.
 struct CombineTransferReadOpTranspose final
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
@@ -620,7 +652,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
 
   // When we are transposing the B operand, ldmatrix will only work if we have
-  // at least 8 rows to read and the width to read for the transpose is 128 
+  // at least 8 rows to read and the width to read for the transpose is 128
   // bits.
   if (!op.getPermutationMap().isMinorIdentity() &&
       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
@@ -671,6 +703,83 @@ convertTransferWriteToStores(vector::TransferWriteOp op,
   return success();
 }
 
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+                                       SmallVectorImpl<int64_t> &results) {
+  for (auto attr : arrayAttr)
+    results.push_back(attr.cast<IntegerAttr>().getInt());
+}
+
+static LogicalResult
+convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(mmaSyncFragmentInfo))
+    return failure();
+
+  // Find the vector.transfer_read whose result vector is being sliced.
+  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
+  if (!transferReadOp)
+    return failure();
+
+  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(ldFragmentInfo))
+    return failure();
+
+  assert(
+      (mmaSyncFragmentInfo->elementsPerRegister ==
+       ldFragmentInfo->elementsPerRegister) &&
+      "Number of elements per register should be same for load and mma.sync");
+
+  // Create vector.extract_strided_slice op for thread-owned fragments.
+  std::array<int64_t, 2> strides = {1,
+                                    1}; // stride for extract slice is always 1.
+  std::array<int64_t, 2> sliceShape = {
+      mmaSyncFragmentInfo->numRegistersPerFragment,
+      mmaSyncFragmentInfo->elementsPerRegister};
+  auto sourceVector = valueMapping.find(transferReadOp)->second;
+
+  // Offsets and sizes at the warp level of ownership.
+  SmallVector<int64_t> offsets;
+  populateFromInt64AttrArray(op.getOffsets(), offsets);
+
+  SmallVector<int64_t> sizes;
+  populateFromInt64AttrArray(op.getSizes(), sizes);
+  ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
+
+  // Compute offset in vector registers. Note that the mma.sync vector registers
+  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
+  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
+  std::array<int64_t, 2> sliceOffset = {0, 0};
+
+  if (offsets[0] && offsets[1])
+    return op->emitError() << "Slicing fragments in 2D is not supported. ";
+  else if (offsets[0])
+    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
+  else if (offsets[1])
+    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
+
+  Value newOp = b.create<vector::ExtractStridedSliceOp>(
+      loc, sourceVector, sliceOffset, sliceShape, strides);
+
+  valueMapping[op] = newOp;
+  return success();
+}
+
 static void convertContractOp(vector::ContractionOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   OpBuilder b(op);
@@ -858,6 +967,10 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
           return convertTransferWriteToStores(transferWriteOp,
                                               valueMapping);
         })
+        .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
+          return convertExtractStridedSlice(extractStridedSliceOp,
+                                            valueMapping);
+        })
         .Case([&](vector::ContractionOp contractionOp) {
          return convertContractOpToMmaSync(contractionOp, valueMapping);
        })
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 18fc4e6..6de16f8 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -45,14 +45,24 @@ static std::array getTileShape(ArrayRef operandShape,
                           lineSizeBits};
 }
 
+/// Returns the first user of `op` that is a vector.contract. If no
+/// vector.contract user exists, returns failure.
+FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
+  for (Operation *user : op->getUsers()) {
+    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
+      return contractOp;
+  }
+  return failure();
+}
+
 FailureOr<nvgpu::WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
   WarpMatrixInfo info;
 
-  // Determine the vector type.
+  // Determine the vector type at the warp level.
   if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
     info.vectorType = writeOp.getVectorType();
   } else if (isa<vector::TransferReadOp, vector::ContractionOp,
-                 arith::ConstantOp>(op)) {
+                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
     info.vectorType = op->getResult(0).getType().cast<VectorType>();
   } else {
     return op->emitError()
@@ -62,19 +72,15 @@ FailureOr nvgpu::getWarpMatrixInfo(Operation *op) {
 
   // Determine the operand role. We assume it is an accumulator/result unless it
   // is directly consumed by a `vector.contract` op.
   info.operandRole = MatMulOperandRole::C;
-  for (Operation *user : op->getUsers()) {
-    auto contract = dyn_cast<vector::ContractionOp>(user);
-    if (!contract)
-      continue;
-    if (contract.getLhs() == op->getResult(0)) {
-      info.operandRole = MatMulOperandRole::A;
-      break;
-    }
-    if (contract.getRhs() == op->getResult(0)) {
-      info.operandRole = MatMulOperandRole::B;
-      break;
-    }
-  }
+  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
+  if (failed(contractOp))
+    return info;
+
+  if ((*contractOp).getLhs() == op->getResult(0))
+    info.operandRole = MatMulOperandRole::A;
+  else if ((*contractOp).getRhs() == op->getResult(0))
+    info.operandRole = MatMulOperandRole::B;
+
   return info;
 }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
index 42dc06c..ae6329c 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -164,9 +164,9 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
 
 // -----
 
-//#########################################################
-// FP16 row-row-row
-//#########################################################
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x2 for matrixB)
+//#########################################################################
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -203,6 +203,62 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
 
 // -----
 
+//#########################################################################
+// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//#########################################################################
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n16k16_mmasync16816_fp16_f16_row_row_row
+func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16, 3>, %arg1: memref<32x64xf16, 3>, %arg2: memref<42x64xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
+
+  // CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+  // CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
+  // CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
+  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
+
+  return
+}
+// -----
+
 // CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
 // CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
 // CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
-- 
2.7.4
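The following standalone C++ sketch is not part of the patch; it only illustrates how the sliceOffset computation in convertExtractStridedSlice rescales a warp-level slice offset into a per-thread register offset, using the values exercised by the m16n16k16 test above. The shapes, offsets, and variable names are assumptions taken from that test, not an MLIR API.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Warp-level matrixB operand from the test: vector<16x16xf16>. Each thread
  // holds the ldmatrix result as vector<4x2xf16> (4 registers x 2 elements);
  // one f16 m16n8k16 mma.sync consumes a vector<2x2xf16> slice of it.
  std::array<int64_t, 2> warpVectorShape = {16, 16};
  std::array<int64_t, 2> offsets = {8, 0};    // offsets of %B1 in the test
  std::array<int64_t, 2> sliceShape = {2, 2}; // registers x elements per mma.sync fragment
  std::array<int64_t, 2> sliceOffset = {0, 0};

  // Mirrors the patch: thread-level registers may only be sliced along the
  // fragment dimension, so the non-zero warp-level offset is rescaled into
  // sliceOffset[0].
  if (offsets[0])
    sliceOffset[0] = warpVectorShape[0] / offsets[0];
  else if (offsets[1])
    sliceOffset[0] = warpVectorShape[1] / offsets[1];

  // Prints "offsets = [2, 0], sizes = [2, 2]", matching the [[fragmentB1]]
  // CHECK-DAG line in the test above.
  std::cout << "offsets = [" << sliceOffset[0] << ", " << sliceOffset[1]
            << "], sizes = [" << sliceShape[0] << ", " << sliceShape[1]
            << "]\n";
  return 0;
}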