From: Christopher Bate Date: Sat, 5 Nov 2022 15:43:24 +0000 (-0600) Subject: [mlir][NVGPU] Add support for structured sparsity MMA variants X-Git-Tag: upstream/17.0.6~28325 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=708185f03ff480b3481132802b7b63461564f0ab;p=platform%2Fupstream%2Fllvm.git [mlir][NVGPU] Add support for structured sparsity MMA variants This change adds a new NVGPU operation that targets the PTX `mma.sp.sync` instruction variants. A lowering to NVVM is provided using inline assembly. Reviewed By: ThomasRaoux, manishucsd Differential Revision: https://reviews.llvm.org/D137202 --- diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 138ffc8..db4ee53 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -98,10 +98,24 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [ let hasVerifier = 1; } -def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [ - Pure, - PredOpTrait<"matrixA and matrixB have same element type", - TCopVTEtIsSameAs<0, 1>>]> { +class NVGPU_MmaSyncOp : + NVGPU_Op>]> { + code extraBaseClassDeclaration = [{ + std::array getMmaShapeAsArray() { + ArrayAttr mmaShape = this->getMmaShape(); + assert(mmaShape.size() == 3 && "mmaShape should be three integers"); + return {mmaShape[0].cast().getInt(), + mmaShape[1].cast().getInt(), + mmaShape[2].cast().getInt()}; + } + }]; + + let hasVerifier = 1; +} + +def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> { let description = [{ The `nvgpu.mma.sync` op represents the warp-level matrix-multiply-and- accumulate (mma) operation that is compatible with `nvvm.mma.sync`. @@ -143,9 +157,63 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [ `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res) }]; - let hasVerifier = 1; + let extraClassDeclaration = extraBaseClassDeclaration; } +def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>, + BuildableType<"::mlir::VectorType::get(" + "{2},$_builder.getI16Type())">; + +def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> { + let description = [{ + The `nvgu.mma.sp.sync` operation performs a warp-distributed MMA operation + where operand A is "structured sparse". In this case, the `matrixA` operand + represents the (warp-distributed) non-zero values of operand A, and the + `sparse_metadata` operand provides the indices. + + The full description of the sparsity storage format and distribution scheme is + described in the PTX docs. This operation is meant to follow the semantic + described in the PTX documentation here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma + + The way the indices are distributed among the threads in a warp is controlled + by the optional `sparsity_selector` operand, which is `0` by default. For + more information, please consult the PTX documentation linked above. + + Example (targetingthe f16 16x8x32 `mma.sp` PTX instruction): + + ```mlir + nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} : + (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + ``` + }]; + + let arguments = (ins AnyVector:$matrixA, + AnyVector:$matrixB, + AnyVector:$matrixC, + NVGPU_MmaSparseSyncMetadataType:$sparseMetadata, + I64ArrayAttr:$mmaShape, + DefaultValuedAttr:$sparsitySelector, + OptionalAttr:$tf32Enabled + ); + + let results = (outs AnyVector:$res); + + let builders = [ + OpBuilder<(ins "Value":$matrixA, + "Value":$matrixB, + "Value":$matrixC, + "Value":$sparseMetadata, + "ArrayRef":$mmaShape)> + ]; + + let assemblyFormat = [{ + `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict + `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration; +} def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [ AttrSizedOperandSegments]> { diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index c4c49f2..d9f54b8 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -11,8 +11,10 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -253,6 +255,23 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { } }; +/// Convert the given type into the corresponding PTX type (NVVM::MMATypes +/// enum). +static FailureOr getNvvmMmaType(Type t) { + Type elType = getElementTypeOrSelf(t); + if (elType.isInteger(8)) + return NVVM::MMATypes::s8; + if (elType.isInteger(4)) + return NVVM::MMATypes::s4; + if (elType.isF16()) + return NVVM::MMATypes::f16; + if (elType.isF64()) + return NVVM::MMATypes::f64; + if (elType.isF32()) + return NVVM::MMATypes::tf32; + return failure(); +} + struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -262,53 +281,38 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { Location loc = op->getLoc(); // Get the shapes of the MMAMatrix type being used. The shapes will // choose which intrinsic this op will be lowered to. - auto aType = op.getMatrixA().getType().cast(); - auto cType = op.getMatrixC().getType().cast(); + VectorType aType = op.getMatrixA().getType(); + VectorType bType = op.getMatrixA().getType(); + VectorType cType = op.getMatrixC().getType(); - int64_t m = op.getMmaShape()[0].cast().getInt(); - int64_t n = op.getMmaShape()[1].cast().getInt(); - int64_t k = op.getMmaShape()[2].cast().getInt(); - std::array gemmShape{m, n, k}; + std::array gemmShape = op.getMmaShapeAsArray(); + + // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32). + bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); + if (aType.getElementType().isF32() && !tf32Enabled) + return failure(); - NVVM::MMATypes ptxTypeA; - NVVM::MMATypes ptxTypeB; + FailureOr ptxTypeA = getNvvmMmaType(aType); + if (failed(ptxTypeA)) + return op->emitOpError("failed to deduce operand PTX types"); + FailureOr ptxTypeB = getNvvmMmaType(bType); + if (failed(ptxTypeB)) + return op->emitOpError("failed to deduce operand PTX types"); Optional ptxTypeC = NVVM::MmaOp::inferOperandMMAType( cType.getElementType(), /*isAccumulator=*/true); if (!ptxTypeC) return op->emitError( "could not infer the PTX type for the accumulator/result"); - // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32). - bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); - if (aType.getElementType().isF32() && !tf32Enabled) - return failure(); - + // TODO: add an attribute to the op to customize this behavior. Optional overflow(llvm::None); - if (aType.getElementType().isInteger(8)) { - ptxTypeA = NVVM::MMATypes::s8; - ptxTypeB = NVVM::MMATypes::s8; + if (aType.getElementType().isa()) overflow = NVVM::MMAIntOverflow::satfinite; - } else if (aType.getElementType().isInteger(4)) { - ptxTypeA = NVVM::MMATypes::s4; - ptxTypeB = NVVM::MMATypes::s4; - overflow = NVVM::MMAIntOverflow::satfinite; - } else if (aType.getElementType().isF16()) { - ptxTypeA = NVVM::MMATypes::f16; - ptxTypeB = NVVM::MMATypes::f16; - } else if (aType.getElementType().isF64()) { - ptxTypeA = NVVM::MMATypes::f64; - ptxTypeB = NVVM::MMATypes::f64; - } else if (aType.getElementType().isF32()) { - ptxTypeA = NVVM::MMATypes::tf32; - ptxTypeB = NVVM::MMATypes::tf32; - } else { - return op->emitError("could not deduce operand PTX types"); - } SmallVector matA = - unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA); + unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA); SmallVector matB = - unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB); + unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB); SmallVector matC = unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC); @@ -321,7 +325,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { /*b1Op=*/llvm::None, /*intOverflow=*/overflow, /*multiplicandPtxTypes=*/ - std::array{ptxTypeA, ptxTypeB}, + std::array{*ptxTypeA, *ptxTypeB}, /*multiplicandLayouts=*/ std::array{NVVM::MMALayout::row, NVVM::MMALayout::col}); @@ -376,13 +380,182 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr, SmallVector asmVals{dstPtr, srcPtr, dstBytes, srcBytes}; rewriter.create( - loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals, + loc, LLVM::LLVMVoidType::get(rewriter.getContext()), + /*operands=*/asmVals, /*asm_string=*/asmStr, /*constraints=*/asmConstraints, /*has_side_effects=*/true, /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr()); } +/// Returns the constraints for the sparse MMA inline assembly instruction. +static std::string buildMmaSparseAsmConstraintString(unsigned matASize, + unsigned matBSize, + unsigned matCSize) { + std::string str; + llvm::raw_string_ostream ss(str); + for (unsigned i = 0; i < matCSize; i++) + ss << "=r,"; + for (unsigned i = 0; i < matASize + matBSize + matCSize; i++) + ss << "r,"; + // The final two operands are for the sparsity metadata and sparsity selector. + ss << "r,r"; + ss.flush(); + return str; +} + +/// Returns the string for the `mma.sp.sync` instruction that corresponds to +/// the give parameters. Note that this function doesn't do any validation, +/// it's expected that the provided parameters correspond to a valid +/// instruction. +static std::string +buildMmaSparseAsmString(const std::array &shape, unsigned matASize, + unsigned matBSize, unsigned matCSize, + NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, + NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, + Optional overflow) { + auto ptxTypeStr = [](NVVM::MMATypes ptxType) { + return NVVM::stringifyMMATypes(ptxType); + }; + + std::string asmStr; + llvm::raw_string_ostream ss(asmStr); + ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k" + << shape[2] << ".row.col."; + + if (overflow) + ss << NVVM::stringifyMMAIntOverflow(*overflow) << "."; + + ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "." + << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " "; + unsigned asmArgIdx = 0; + + // The operand string is structured into sections `{matC elements...}, + // {matA elements...}, {matB elements...}, {matC elements}`. + for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) { + ss << "{"; + for (unsigned i = 0; i < arrSize; i++) + ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : ""); + ss << "},"; + } + ss << "$" << asmArgIdx++ << ",$" << asmArgIdx++ << ";"; + ss.flush(); + return asmStr; +} + +/// Builds an inline assembly operation corresponding to the specified MMA +/// sparse sync operation. +static FailureOr emitMmaSparseSyncOpAsm( + Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, + NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, + Optional overflow, ArrayRef unpackedAData, + ArrayRef unpackedB, ArrayRef unpackedC, Value indexData, + int64_t metadataSelector, const std::array &shape, + Type intrinsicResultType, ConversionPatternRewriter &rewriter) { + auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT); + + std::string asmStr = buildMmaSparseAsmString( + shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA, + ptxTypeB, ptxTypeC, ptxTypeD, overflow); + std::string constraintStr = buildMmaSparseAsmConstraintString( + unpackedAData.size(), unpackedB.size(), unpackedC.size()); + + Value selectorVal = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector)); + + SmallVector asmVals; + asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() + + 2); + for (ArrayRef args : {unpackedAData, unpackedB, unpackedC}) + llvm::append_range(asmVals, args); + asmVals.push_back(indexData); + asmVals.push_back(selectorVal); + + return rewriter.create(loc, + /*resultTypes=*/intrinsicResultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/constraintStr, + /*has_side_effects=*/true, + /*is_align_stack=*/false, + /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); +} + +/// Lowers `nvgpu.mma.sp.sync` to inline assembly. +struct NVGPUMmaSparseSyncLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // Get the shapes of the MMAMatrix type being used. The shapes will + // choose which intrinsic this op will be lowered to. + VectorType aType = op.getMatrixA().getType(); + VectorType bType = op.getMatrixB().getType(); + VectorType cType = op.getMatrixC().getType(); + + FailureOr ptxTypeA = getNvvmMmaType(aType); + if (failed(ptxTypeA)) + return op->emitOpError("failed to deduce operand PTX types"); + FailureOr ptxTypeB = getNvvmMmaType(bType); + if (failed(ptxTypeB)) + return op->emitOpError("failed to deduce operand PTX types"); + Optional ptxTypeC = NVVM::MmaOp::inferOperandMMAType( + cType.getElementType(), /*isAccumulator=*/true); + if (!ptxTypeC) + return op->emitError( + "could not infer the PTX type for the accumulator/result"); + + // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32). + bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); + if (aType.getElementType().isF32() && !tf32Enabled) + return failure(); + + // TODO: add an attribute to the op to customize this behavior. + Optional overflow(llvm::None); + if (aType.getElementType().isa()) + overflow = NVVM::MMAIntOverflow::satfinite; + + SmallVector matA = + unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA); + SmallVector matB = + unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB); + SmallVector matC = + unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC); + + Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); + Type intrinsicResTy = inferIntrinsicResultType( + typeConverter->convertType(op->getResultTypes()[0])); + + // Bitcast the sparse metadata from vector<2xf16> to an i32. + Value sparseMetadata = adaptor.getSparseMetadata(); + if (sparseMetadata.getType() != + LLVM::getFixedVectorType(rewriter.getI16Type(), 2)) + return op->emitOpError() << "Expected metadata type to be LLVM " + "VectorType of 2 i16 elements"; + sparseMetadata = rewriter.create( + loc, rewriter.getI32Type(), sparseMetadata); + + FailureOr intrinsicResult = emitMmaSparseSyncOpAsm( + loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, + matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(), + intrinsicResTy, rewriter); + if (failed(intrinsicResult)) + return failure(); + + assert((*intrinsicResult).getNumResults() == 1 && + "expected inline asm op returns a single LLVM struct type"); + rewriter.replaceOp( + op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, + (*intrinsicResult)->getResult(0), rewriter)); + return success(); + } +}; + struct NVGPUAsyncCopyLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -488,8 +661,8 @@ struct NVGPUAsyncWaitLowering void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add( - converter); + NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, + NVGPUMmaSparseSyncLowering>(converter); } std::unique_ptr mlir::createConvertNVGPUToNVVMPass() { diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 9ed04b4..24f70cb 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -13,9 +13,11 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Verifier.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -80,13 +82,21 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, mmaShape, UnitAttr()); } -LogicalResult MmaSyncOp::verify() { - - // Fundamental tensor core mma.sync op - // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core - // operation is of shape: 8-by-8-by-128b. F64 is an exception. The - // verification for mma.sync covering various shapes and data types is based - // on the fundamental tensor core operionation. +/// Performs verification for MmaSyncOp and MmaSparseSyncOp. +static LogicalResult verifyMmaSyncOp(Operation *op, + TypedValue matrixA, + TypedValue matrixB, + TypedValue matrixC, + const std::array &mmaShape, + bool tf32Enabled, bool sparse = false) { + + // The verification for mma.sync covering various shapes and data types is + // based on the fundamental tensor core shape. + + // "Fundamental" tensor core shapes: + // - For F32 (TF32), F16, S8, and S4 data + // types the fundamental tensor core operation is of shape 8-by-8-by-128b. + // - F64 is an exception and is of shape 8-by-8-by-256b. constexpr int kThreads = 32; // 32 threads per warp int64_t shapeM = 8; int64_t shapeN = 8; @@ -98,9 +108,9 @@ LogicalResult MmaSyncOp::verify() { int64_t numElementC{2}; // two accumulator elements per fundamental tile // nvgpu.mma.sync vector operands (per thread) - auto aVector = getMatrixA().getType().cast(); - auto bVector = getMatrixB().getType().cast(); - auto cVector = getMatrixC().getType().cast(); + auto aVector = matrixA.getType(); + auto bVector = matrixB.getType(); + auto cVector = matrixC.getType(); // vector shapes ArrayRef aShape = aVector.getShape(); @@ -110,13 +120,9 @@ LogicalResult MmaSyncOp::verify() { // vector element type Type aType = aVector.getElementType(); - // tensor float32 (TF32) enabled - bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName()); - - // nvgpu.mma.sync shape (per 32 threads or per warp) - int64_t m = getMmaShape()[0].cast().getInt(); - int64_t n = getMmaShape()[1].cast().getInt(); - int64_t k = getMmaShape()[2].cast().getInt(); + // Certain data types are not allowed in sparse mode. + if (sparse && aType.isF64()) + return op->emitError() << "f64 is not supported for sparse mode"; if (aType.isF64()) { // exception to 8-by-8-128b fundamental tensor core tile size @@ -127,36 +133,43 @@ LogicalResult MmaSyncOp::verify() { aType.isInteger(8) || aType.isInteger(4)) { // 8-by-8-128b fundamental tensor core tile size int operandBitwidth = aType.getIntOrFloatBitWidth(); - shapeK = 128 / operandBitwidth; // 128b wide shapeK + shapeK = 128 / operandBitwidth; // 128b wide shapeK + numElementA = 32 / operandBitwidth; // 32b wide operand A numElementB = 32 / operandBitwidth; // 32b wide operand B } else { - return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) " - "supported by nvgpu.mma.sync"; + return op->emitError() + << "expected input data type (i4,i8,f16,bf16,tf32,f64) " + "supported by " + << op->getName(); } // // Basic verification // + auto [m, n, k] = mmaShape; + // verify warp-wide size for vector a - if (aShape[0] * aShape[1] * kThreads != m * k) - return emitOpError() << "expected " << m * k - << " warp-wide matrix A elements"; + int64_t sparseFactor = sparse ? 2 : 1; + if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor) + return op->emitOpError() + << "expected " << m * k << " warp-wide matrix A elements"; // verify warp-wide size for vector b if (bShape[0] * bShape[1] * kThreads != k * n) - return emitOpError() << "expected " << k * n - << " warp-wide matrix B elements"; + return op->emitOpError() + << "expected " << k * n << " warp-wide matrix B elements"; // verify warp-wide size for vector c if (cShape[0] * cShape[1] * kThreads != m * n) - return emitOpError() << "expected " << m * n - << " warp-wide matrix C elements"; + return op->emitOpError() + << "expected " << m * n << " warp-wide matrix C elements"; // verify tf32 tensor cores are enabled for only F32 datatype if (tf32Enabled && !(aType.isF32())) - return emitOpError() << "expected tf32 tensor cores only for F32 operands"; + return op->emitOpError() + << "expected tf32 tensor cores only for F32 operands"; // // Extended verification @@ -168,23 +181,48 @@ LogicalResult MmaSyncOp::verify() { int64_t kTile = k / shapeK; // verify shape of aVector - if ((aShape[0] != mTile * kTile) || (aShape[1] != numElementA)) - return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile - << " x " << numElementA << ")"; + if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) || + (aShape[1] != numElementA)) + return op->emitOpError() << "expected matrix A to be shaped (" + << mTile * kTile << " x " << numElementA << ")"; // verify shape of bVector if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB)) - return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile - << " x " << numElementB << ")"; + return op->emitOpError() << "expected matrix B to be shaped (" + << kTile * nTile << " x " << numElementB << ")"; // verify shape of cVector if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC)) - return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile - << " x " << numElementC << ")"; + return op->emitOpError() << "expected matrix C to be shaped (" + << mTile * nTile << " x " << numElementC << ")"; return success(); } +LogicalResult MmaSyncOp::verify() { + return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), + getMatrixC(), getMmaShapeAsArray(), + getOperation()->hasAttr(getTf32EnabledAttrName())); +} + +//===----------------------------------------------------------------------===// +// NVGPU_MmaSparseSyncOp +//===----------------------------------------------------------------------===// +void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, Value matrixA, + Value matrixB, Value matrixC, Value sparseMetadata, + ArrayRef mmaShape) { + build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, + sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr()); +} + +LogicalResult MmaSparseSyncOp::verify() { + return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), + getMatrixC(), getMmaShapeAsArray(), + getOperation()->hasAttr(getTf32EnabledAttrName()), + true); +} + //===----------------------------------------------------------------------===// // NVGPU_LdMatrixOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 0a9f8d5..c95b2fc 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -313,3 +313,119 @@ func.func @async_cp_zfill( return } + +// ----- + +// CHECK-LABEL: func @mma_sp_sync_f16_16832( +func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>, + %arg1: vector<4x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>> + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>> + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + + // CHECK-NOT llvm.extractvalue + + // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32 + // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;" + // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r" + // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] : + // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} : + (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + + // CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK-DAG: llvm.extractvalue %[[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[1] : !llvm.array<2 x vector<2xf16>> + return %d : vector<2x2xf16> +} + +// ----- + +// CHECK-LABEL: func @mma_sp_sync_f16_16816( +func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>, + %arg1: vector<2x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + + // CHECK-NOT llvm.extractvalue + + // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32 + // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;" + // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r" + // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] : + // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} : + (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + return %d : vector<2x2xf16> +} + +// ----- + +// CHECK-LABEL: func @mma_sp_sync_i8_16864( +func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>, + %arg1: vector<4x4xi8>, + %arg2: vector<2x2xi32>, + %arg3: vector<2xi16>) -> vector<2x2xi32> { + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32 + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32 + // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32 + // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<4xi8>> + + + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32 + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32 + + // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>> + // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>> + + // CHECK-NOT llvm.extractvalue + + // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32 + // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32 + + // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32 + // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r" + // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] : + // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32) + + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} : + (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + return %d : vector<2x2xi32> +} diff --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir index 524f1fd..ad516b4 100644 --- a/mlir/test/Dialect/NVGPU/roundtrip.mlir +++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir @@ -19,6 +19,44 @@ func.func @mma_sync(%arg0: vector<4x2xf16>, return %d : vector<2x2xf16> } +// CHECK-LABEL: func @mma_sp_sync_f16_16832( +func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>, + %arg1: vector<4x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) { + // CHECK-SAME: mmaShape = [16, 8, 32] + // CHECK-SAME: (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} : + (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + return %d : vector<2x2xf16> +} + +// CHECK-LABEL: func @mma_sp_sync_f16_16816( +func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>, + %arg1: vector<2x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) { + // CHECK-SAME: mmaShape = [16, 8, 16] + // CHECK-SAME: (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} : + (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + return %d : vector<2x2xf16> +} + +// CHECK-LABEL: func @mma_sp_sync_i8_16864( +func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>, + %arg1: vector<4x4xi8>, + %arg2: vector<2x2xi32>, + %arg3: vector<2xi16>) -> vector<2x2xi32> { + // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) { + // CHECK-SAME: mmaShape = [16, 8, 64] + // CHECK-SAME: (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} : + (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + return %d : vector<2x2xi32> +} func.func @async_cp(%dst : memref<2x7x5xf32, 3>, %src : memref<4x5xf32>){ // CHECK-LABEL: func @async_cp