[mlir][NVGPU] Add support for structured sparsity MMA variants
authorChristopher Bate <cbate@nvidia.com>
Sat, 5 Nov 2022 15:43:24 +0000 (09:43 -0600)
committerChristopher Bate <cbate@nvidia.com>
Mon, 7 Nov 2022 16:43:03 +0000 (09:43 -0700)
This change adds a new NVGPU operation that targets the PTX `mma.sp.sync`
instruction variants. A lowering to NVVM is provided using inline
assembly.

Reviewed By: ThomasRaoux, manishucsd

Differential Revision: https://reviews.llvm.org/D137202

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/NVGPU/roundtrip.mlir

index 138ffc8..db4ee53 100644 (file)
@@ -98,10 +98,24 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
   let hasVerifier = 1;
 }
 
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
-                               Pure,
-                               PredOpTrait<"matrixA and matrixB have same element type", 
-                                           TCopVTEtIsSameAs<0, 1>>]> {
+class NVGPU_MmaSyncOp<string mnemonic> :
+        NVGPU_Op<mnemonic,  [Pure,
+                             PredOpTrait<"matrixA and matrixB have same element type", 
+                                         TCopVTEtIsSameAs<0, 1>>]> {
+  code extraBaseClassDeclaration = [{
+    std::array<int64_t, 3> getMmaShapeAsArray() {      
+      ArrayAttr mmaShape = this->getMmaShape();
+      assert(mmaShape.size() == 3 && "mmaShape should be three integers");
+      return {mmaShape[0].cast<IntegerAttr>().getInt(),
+              mmaShape[1].cast<IntegerAttr>().getInt(),
+              mmaShape[2].cast<IntegerAttr>().getInt()};
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
   let description = [{
     The `nvgpu.mma.sync` op represents the warp-level matrix-multiply-and-
     accumulate (mma) operation that is compatible with `nvvm.mma.sync`. 
@@ -143,9 +157,63 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
 
-  let hasVerifier = 1;
+  let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
+def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
+                        BuildableType<"::mlir::VectorType::get("
+                          "{2},$_builder.getI16Type())">;
+
+def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
+  let description = [{
+  The `nvgu.mma.sp.sync` operation performs a warp-distributed MMA operation
+  where operand A is "structured sparse". In this case, the `matrixA` operand
+  represents the (warp-distributed) non-zero values of operand A, and the
+  `sparse_metadata` operand provides the indices.
+
+  The full description of the sparsity storage format and distribution scheme is
+  described in the PTX docs. This operation is meant to follow the semantic
+  described in the PTX documentation here:
+  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
+
+  The way the indices are distributed among the threads in a warp is controlled
+  by the optional `sparsity_selector` operand, which is `0` by default. For
+  more information, please consult the PTX documentation linked above.
+
+  Example (targetingthe f16 16x8x32 `mma.sp` PTX instruction):
+
+  ```mlir  
+  nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  ```  
+  }];
+
+  let arguments = (ins AnyVector:$matrixA,                       
+                       AnyVector:$matrixB,
+                       AnyVector:$matrixC,
+                       NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
+                       I64ArrayAttr:$mmaShape,
+                       DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
+                       OptionalAttr<UnitAttr>:$tf32Enabled
+                       );
+
+  let results = (outs AnyVector:$res);
+
+  let builders = [
+    OpBuilder<(ins "Value":$matrixA,
+                   "Value":$matrixB,
+                   "Value":$matrixC,
+                   "Value":$sparseMetadata,
+                   "ArrayRef<int64_t>":$mmaShape)>
+  ];
+
+  let assemblyFormat = [{
+    `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
+    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration;
+}
 
 def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
                                        AttrSizedOperandSegments]> {
index c4c49f2..d9f54b8 100644 (file)
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -253,6 +255,23 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
   }
 };
 
+/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
+/// enum).
+static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
+  Type elType = getElementTypeOrSelf(t);
+  if (elType.isInteger(8))
+    return NVVM::MMATypes::s8;
+  if (elType.isInteger(4))
+    return NVVM::MMATypes::s4;
+  if (elType.isF16())
+    return NVVM::MMATypes::f16;
+  if (elType.isF64())
+    return NVVM::MMATypes::f64;
+  if (elType.isF32())
+    return NVVM::MMATypes::tf32;
+  return failure();
+}
+
 struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
   using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
 
@@ -262,53 +281,38 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     Location loc = op->getLoc();
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
-    auto aType = op.getMatrixA().getType().cast<VectorType>();
-    auto cType = op.getMatrixC().getType().cast<VectorType>();
+    VectorType aType = op.getMatrixA().getType();
+    VectorType bType = op.getMatrixA().getType();
+    VectorType cType = op.getMatrixC().getType();
 
-    int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
-    int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
-    int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
-    std::array<int64_t, 3> gemmShape{m, n, k};
+    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+
+    // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
+    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+    if (aType.getElementType().isF32() && !tf32Enabled)
+      return failure();
 
-    NVVM::MMATypes ptxTypeA;
-    NVVM::MMATypes ptxTypeB;
+    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
+    if (failed(ptxTypeA))
+      return op->emitOpError("failed to deduce operand PTX types");
+    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
+    if (failed(ptxTypeB))
+      return op->emitOpError("failed to deduce operand PTX types");
     Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
         cType.getElementType(), /*isAccumulator=*/true);
     if (!ptxTypeC)
       return op->emitError(
           "could not infer the PTX type for the accumulator/result");
 
-    // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
-    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
-    if (aType.getElementType().isF32() && !tf32Enabled)
-      return failure();
-
+    // TODO: add an attribute to the op to customize this behavior.
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
-    if (aType.getElementType().isInteger(8)) {
-      ptxTypeA = NVVM::MMATypes::s8;
-      ptxTypeB = NVVM::MMATypes::s8;
+    if (aType.getElementType().isa<IntegerType>())
       overflow = NVVM::MMAIntOverflow::satfinite;
-    } else if (aType.getElementType().isInteger(4)) {
-      ptxTypeA = NVVM::MMATypes::s4;
-      ptxTypeB = NVVM::MMATypes::s4;
-      overflow = NVVM::MMAIntOverflow::satfinite;
-    } else if (aType.getElementType().isF16()) {
-      ptxTypeA = NVVM::MMATypes::f16;
-      ptxTypeB = NVVM::MMATypes::f16;
-    } else if (aType.getElementType().isF64()) {
-      ptxTypeA = NVVM::MMATypes::f64;
-      ptxTypeB = NVVM::MMATypes::f64;
-    } else if (aType.getElementType().isF32()) {
-      ptxTypeA = NVVM::MMATypes::tf32;
-      ptxTypeB = NVVM::MMATypes::tf32;
-    } else {
-      return op->emitError("could not deduce operand PTX types");
-    }
 
     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
+        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
+        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
         unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
 
@@ -321,7 +325,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
         /*b1Op=*/llvm::None,
         /*intOverflow=*/overflow,
         /*multiplicandPtxTypes=*/
-        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
         /*multiplicandLayouts=*/
         std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                        NVVM::MMALayout::col});
@@ -376,13 +380,182 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
   SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
 
   rewriter.create<LLVM::InlineAsmOp>(
-      loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+      loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
+      /*operands=*/asmVals,
       /*asm_string=*/asmStr,
       /*constraints=*/asmConstraints, /*has_side_effects=*/true,
       /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
       /*operand_attrs=*/ArrayAttr());
 }
 
+/// Returns the constraints for the sparse MMA inline assembly instruction.
+static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
+                                                     unsigned matBSize,
+                                                     unsigned matCSize) {
+  std::string str;
+  llvm::raw_string_ostream ss(str);
+  for (unsigned i = 0; i < matCSize; i++)
+    ss << "=r,";
+  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
+    ss << "r,";
+  // The final two operands are for the sparsity metadata and sparsity selector.
+  ss << "r,r";
+  ss.flush();
+  return str;
+}
+
+/// Returns the string for the `mma.sp.sync` instruction that corresponds to
+/// the give parameters. Note that this function doesn't do any validation,
+/// it's expected that the provided parameters correspond to a valid
+/// instruction.
+static std::string
+buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
+                        unsigned matBSize, unsigned matCSize,
+                        NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+                        NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+                        Optional<NVVM::MMAIntOverflow> overflow) {
+  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
+    return NVVM::stringifyMMATypes(ptxType);
+  };
+
+  std::string asmStr;
+  llvm::raw_string_ostream ss(asmStr);
+  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
+     << shape[2] << ".row.col.";
+
+  if (overflow)
+    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
+
+  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
+     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
+  unsigned asmArgIdx = 0;
+
+  // The operand string is structured into sections `{matC elements...},
+  // {matA elements...}, {matB elements...}, {matC elements}`.
+  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
+    ss << "{";
+    for (unsigned i = 0; i < arrSize; i++)
+      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
+    ss << "},";
+  }
+  ss << "$" << asmArgIdx++ << ",$" << asmArgIdx++ << ";";
+  ss.flush();
+  return asmStr;
+}
+
+/// Builds an inline assembly operation corresponding to the specified MMA
+/// sparse sync operation.
+static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
+    Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+    Optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
+    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
+    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+    Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
+  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+                                                  LLVM::AsmDialect::AD_ATT);
+
+  std::string asmStr = buildMmaSparseAsmString(
+      shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
+      ptxTypeB, ptxTypeC, ptxTypeD, overflow);
+  std::string constraintStr = buildMmaSparseAsmConstraintString(
+      unpackedAData.size(), unpackedB.size(), unpackedC.size());
+
+  Value selectorVal = rewriter.create<LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));
+
+  SmallVector<Value> asmVals;
+  asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
+                  2);
+  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
+    llvm::append_range(asmVals, args);
+  asmVals.push_back(indexData);
+  asmVals.push_back(selectorVal);
+
+  return rewriter.create<LLVM::InlineAsmOp>(loc,
+                                            /*resultTypes=*/intrinsicResultType,
+                                            /*operands=*/asmVals,
+                                            /*asm_string=*/asmStr,
+                                            /*constraints=*/constraintStr,
+                                            /*has_side_effects=*/true,
+                                            /*is_align_stack=*/false,
+                                            /*asm_dialect=*/asmDialectAttr,
+                                            /*operand_attrs=*/ArrayAttr());
+}
+
+/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
+struct NVGPUMmaSparseSyncLowering
+    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
+  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    // Get the shapes of the MMAMatrix type being used. The shapes will
+    // choose which intrinsic this op will be lowered to.
+    VectorType aType = op.getMatrixA().getType();
+    VectorType bType = op.getMatrixB().getType();
+    VectorType cType = op.getMatrixC().getType();
+
+    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
+    if (failed(ptxTypeA))
+      return op->emitOpError("failed to deduce operand PTX types");
+    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
+    if (failed(ptxTypeB))
+      return op->emitOpError("failed to deduce operand PTX types");
+    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
+        cType.getElementType(), /*isAccumulator=*/true);
+    if (!ptxTypeC)
+      return op->emitError(
+          "could not infer the PTX type for the accumulator/result");
+
+    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
+    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+    if (aType.getElementType().isF32() && !tf32Enabled)
+      return failure();
+
+    // TODO: add an attribute to the op to customize this behavior.
+    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+    if (aType.getElementType().isa<IntegerType>())
+      overflow = NVVM::MMAIntOverflow::satfinite;
+
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+
+    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+    Type intrinsicResTy = inferIntrinsicResultType(
+        typeConverter->convertType(op->getResultTypes()[0]));
+
+    // Bitcast the sparse metadata from vector<2xf16> to an i32.
+    Value sparseMetadata = adaptor.getSparseMetadata();
+    if (sparseMetadata.getType() !=
+        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+      return op->emitOpError() << "Expected metadata type to be LLVM "
+                                  "VectorType of 2 i16 elements";
+    sparseMetadata = rewriter.create<LLVM::BitcastOp>(
+        loc, rewriter.getI32Type(), sparseMetadata);
+
+    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
+        loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+        intrinsicResTy, rewriter);
+    if (failed(intrinsicResult))
+      return failure();
+
+    assert((*intrinsicResult).getNumResults() == 1 &&
+           "expected inline asm op returns a single LLVM struct type");
+    rewriter.replaceOp(
+        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
+                                   (*intrinsicResult)->getResult(0), rewriter));
+    return success();
+  }
+};
+
 struct NVGPUAsyncCopyLowering
     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
   using ConvertOpToLLVMPattern<
@@ -488,8 +661,8 @@ struct NVGPUAsyncWaitLowering
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
-               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
-      converter);
+               NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
+               NVGPUMmaSparseSyncLowering>(converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
index 9ed04b4..24f70cb 100644 (file)
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Verifier.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -80,13 +82,21 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
         mmaShape, UnitAttr());
 }
 
-LogicalResult MmaSyncOp::verify() {
-
-  // Fundamental tensor core mma.sync op
-  // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core
-  // operation is of shape: 8-by-8-by-128b. F64 is an exception. The
-  // verification for mma.sync covering various shapes and data types is based
-  // on the fundamental tensor core operionation.
+/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
+static LogicalResult verifyMmaSyncOp(Operation *op,
+                                     TypedValue<VectorType> matrixA,
+                                     TypedValue<VectorType> matrixB,
+                                     TypedValue<VectorType> matrixC,
+                                     const std::array<int64_t, 3> &mmaShape,
+                                     bool tf32Enabled, bool sparse = false) {
+
+  // The verification for mma.sync covering various shapes and data types is
+  // based on the fundamental tensor core shape.
+
+  // "Fundamental" tensor core shapes:
+  //  - For F32 (TF32), F16, S8, and S4 data
+  //    types the fundamental tensor core operation is of shape 8-by-8-by-128b.
+  //  - F64 is an exception and is of shape 8-by-8-by-256b.
   constexpr int kThreads = 32; // 32 threads per warp
   int64_t shapeM = 8;
   int64_t shapeN = 8;
@@ -98,9 +108,9 @@ LogicalResult MmaSyncOp::verify() {
   int64_t numElementC{2}; // two accumulator elements per fundamental tile
 
   // nvgpu.mma.sync vector operands (per thread)
-  auto aVector = getMatrixA().getType().cast<VectorType>();
-  auto bVector = getMatrixB().getType().cast<VectorType>();
-  auto cVector = getMatrixC().getType().cast<VectorType>();
+  auto aVector = matrixA.getType();
+  auto bVector = matrixB.getType();
+  auto cVector = matrixC.getType();
 
   // vector shapes
   ArrayRef<int64_t> aShape = aVector.getShape();
@@ -110,13 +120,9 @@ LogicalResult MmaSyncOp::verify() {
   // vector element type
   Type aType = aVector.getElementType();
 
-  // tensor float32 (TF32) enabled
-  bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
-
-  // nvgpu.mma.sync shape (per 32 threads or per warp)
-  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
-  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
-  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+  // Certain data types are not allowed in sparse mode.
+  if (sparse && aType.isF64())
+    return op->emitError() << "f64 is not supported for sparse mode";
 
   if (aType.isF64()) {
     // exception to 8-by-8-128b fundamental tensor core tile size
@@ -127,36 +133,43 @@ LogicalResult MmaSyncOp::verify() {
              aType.isInteger(8) || aType.isInteger(4)) {
     // 8-by-8-128b fundamental tensor core tile size
     int operandBitwidth = aType.getIntOrFloatBitWidth();
-    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    shapeK = 128 / operandBitwidth; // 128b wide shapeK
+
     numElementA = 32 / operandBitwidth; // 32b wide operand A
     numElementB = 32 / operandBitwidth; // 32b wide operand B
   } else {
-    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
-                          "supported by nvgpu.mma.sync";
+    return op->emitError()
+           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+              "supported by "
+           << op->getName();
   }
 
   //
   // Basic verification
   //
 
+  auto [m, n, k] = mmaShape;
+
   // verify warp-wide size for vector a
-  if (aShape[0] * aShape[1] * kThreads != m * k)
-    return emitOpError() << "expected " << m * k
-                         << " warp-wide matrix A elements";
+  int64_t sparseFactor = sparse ? 2 : 1;
+  if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+    return op->emitOpError()
+           << "expected " << m * k << " warp-wide matrix A elements";
 
   // verify warp-wide size for vector b
   if (bShape[0] * bShape[1] * kThreads != k * n)
-    return emitOpError() << "expected " << k * n
-                         << " warp-wide matrix B elements";
+    return op->emitOpError()
+           << "expected " << k * n << " warp-wide matrix B elements";
 
   // verify warp-wide size for vector c
   if (cShape[0] * cShape[1] * kThreads != m * n)
-    return emitOpError() << "expected " << m * n
-                         << " warp-wide matrix C elements";
+    return op->emitOpError()
+           << "expected " << m * n << " warp-wide matrix C elements";
 
   // verify tf32 tensor cores are enabled for only F32 datatype
   if (tf32Enabled && !(aType.isF32()))
-    return emitOpError() << "expected tf32 tensor cores only for F32 operands";
+    return op->emitOpError()
+           << "expected tf32 tensor cores only for F32 operands";
 
   //
   // Extended verification
@@ -168,23 +181,48 @@ LogicalResult MmaSyncOp::verify() {
   int64_t kTile = k / shapeK;
 
   // verify shape of aVector
-  if ((aShape[0] != mTile * kTile) || (aShape[1] != numElementA))
-    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
-                         << " x " << numElementA << ")";
+  if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
+      (aShape[1] != numElementA))
+    return op->emitOpError() << "expected matrix A to be shaped ("
+                             << mTile * kTile << " x " << numElementA << ")";
 
   // verify shape of bVector
   if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
-    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
-                         << " x " << numElementB << ")";
+    return op->emitOpError() << "expected matrix B to be shaped ("
+                             << kTile * nTile << " x " << numElementB << ")";
 
   // verify shape of cVector
   if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
-    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
-                         << " x " << numElementC << ")";
+    return op->emitOpError() << "expected matrix C to be shaped ("
+                             << mTile * nTile << " x " << numElementC << ")";
 
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
+                         getMatrixC(), getMmaShapeAsArray(),
+                         getOperation()->hasAttr(getTf32EnabledAttrName()));
+}
+
+//===----------------------------------------------------------------------===//
+// NVGPU_MmaSparseSyncOp
+//===----------------------------------------------------------------------===//
+void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
+                            ::mlir::OperationState &odsState, Value matrixA,
+                            Value matrixB, Value matrixC, Value sparseMetadata,
+                            ArrayRef<int64_t> mmaShape) {
+  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+}
+
+LogicalResult MmaSparseSyncOp::verify() {
+  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
+                         getMatrixC(), getMmaShapeAsArray(),
+                         getOperation()->hasAttr(getTf32EnabledAttrName()),
+                         true);
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU_LdMatrixOp
 //===----------------------------------------------------------------------===//
index 0a9f8d5..c95b2fc 100644 (file)
@@ -313,3 +313,119 @@ func.func @async_cp_zfill(
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16832(
+func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
+                                 %arg1: vector<4x2xf16>,
+                                 %arg2: vector<2x2xf16>,
+                                 %arg3: vector<2xi16>) -> vector<2x2xf16> {
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+  // CHECK-NOT llvm.extractvalue
+
+  // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+    (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+
+  // CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>    
+  // CHECK-DAG: llvm.extractvalue %[[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  //     CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+  //     CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[0] : !llvm.array<2 x vector<2xf16>>
+  //     CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[1] : !llvm.array<2 x vector<2xf16>>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16816(
+func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
+                                 %arg1: vector<2x2xf16>,
+                                 %arg2: vector<2x2xf16>,
+                                 %arg3: vector<2xi16>) -> vector<2x2xf16> {
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>  
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>  
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+  // CHECK-NOT llvm.extractvalue
+
+  // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+    (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_i8_16864(
+func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
+                                %arg1: vector<4x4xi8>,
+                                %arg2: vector<2x2xi32>,
+                                %arg3: vector<2xi16>) -> vector<2x2xi32> {
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+  // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+  // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<4xi8>>
+
+
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>>  
+  // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+
+  // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+  // CHECK-NOT llvm.extractvalue
+
+  // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32
+  // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
+
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+    (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
index 524f1fd..ad516b4 100644 (file)
@@ -19,6 +19,44 @@ func.func @mma_sync(%arg0: vector<4x2xf16>,
   return %d : vector<2x2xf16>
 }
 
+// CHECK-LABEL: func @mma_sp_sync_f16_16832(
+func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
+                                 %arg1: vector<4x2xf16>,
+                                 %arg2: vector<2x2xf16>,
+                                 %arg3: vector<2xi16>) -> vector<2x2xf16> {
+  //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+  // CHECK-SAME:   mmaShape = [16, 8, 32]
+  // CHECK-SAME: (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+    (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16816(
+func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
+                                 %arg1: vector<2x2xf16>,
+                                 %arg2: vector<2x2xf16>,
+                                 %arg3: vector<2xi16>) -> vector<2x2xf16> {
+  //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+  // CHECK-SAME:   mmaShape = [16, 8, 16]
+  // CHECK-SAME: (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+    (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func @mma_sp_sync_i8_16864(
+func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
+                                %arg1: vector<4x4xi8>,
+                                %arg2: vector<2x2xi32>,
+                                %arg3: vector<2xi16>) -> vector<2x2xi32> {
+  //      CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+  // CHECK-SAME:   mmaShape = [16, 8, 64]
+  // CHECK-SAME: (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+    (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
 
 func.func @async_cp(%dst : memref<2x7x5xf32, 3>, %src : memref<4x5xf32>){
   // CHECK-LABEL: func @async_cp