[mlir][gpu] Add support for unsigned integer extend in vector to gpu.subgroup_mma...
authorQuinn Dawkins <quinn@nod-labs.com>
Mon, 13 Feb 2023 01:33:10 +0000 (20:33 -0500)
committerQuinn Dawkins <quinn@nod-labs.com>
Tue, 14 Feb 2023 18:09:46 +0000 (13:09 -0500)
Unsigned integer types are supported in subgroup mma ops by matching
against arith.extui ops. This allows for subgroup_mma_compute ops with
mixed signedness which requires later conversions to handle this. SPIR-V
cooperative matrix ops support this while the lowering to WMMA does not.

Differential Revision: https://reviews.llvm.org/D143922

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

index 46d40a7..bf5be54 100644 (file)
@@ -60,6 +60,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
 
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
+  if (type.getElementType().isUnsignedInteger(8))
+    return NVVM::MMATypes::u8;
   // Accumulator type is signless and implies signed.
   if (type.getElementType().isInteger(32))
     return NVVM::MMATypes::s32;
@@ -112,11 +114,8 @@ struct WmmaLoadOpToNVVMLowering
     }
     NVVM::MMAFrag frag = convertOperand(retType.getOperand());
     // Check that there is an exisiting instruction for the combination we need.
-    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
-      llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
-                   << "\n";
+    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
-    }
 
     Type resType = convertMMAToLLVMType(retType);
     Location loc = op->getLoc();
@@ -245,6 +244,12 @@ struct WmmaMmaOpToNVVMLowering
                                         destType) == 0)
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
 
+    NVVM::MMATypes bElementType = getElementType(
+        subgroupMmaComputeOp.getOpB().getType().cast<gpu::MMAMatrixType>());
+    if (bElementType != sourceType)
+      return rewriter.notifyMatchFailure(
+          op, "WMMA compute op input matrix element types must match.");
+
     unpackOp(adaptor.getOpA());
     unpackOp(adaptor.getOpB());
     unpackOp(adaptor.getOpC());
index cdd8cd7..b0fa50d 100644 (file)
@@ -143,7 +143,8 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
 
   // Only allow integer types if the signedness can be inferred.
   if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
-    if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
+    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
+                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
       return false;
 
   AffineMap map = readOp.getPermutationMap();
@@ -194,8 +195,9 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
   return broadcastOp.getVectorType().getRank() == 2;
 }
 
-/// Return true if this signed extend op can be folded into a contract op.
-static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+/// Return true if this integer extend op can be folded into a contract op.
+template <typename ExtOpTy>
+static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
   if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
     return false;
   return llvm::all_of(extOp->getUsers(), [](Operation *user) {
@@ -282,8 +284,10 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
-  if (auto extend = dyn_cast<arith::ExtSIOp>(op))
-    return signedExtendSupportsMMAMatrixType(extend);
+  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
+    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
+  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
+    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -429,10 +433,11 @@ struct CombineTransferReadOpTranspose final
                                 PatternRewriter &rewriter) const override {
     // Look through integer extend ops.
     Value source = op.getVector();
-    auto extOp = source.getDefiningOp<arith::ExtSIOp>();
     auto resultType = op.getVectorType();
-    if (extOp) {
-      source = extOp.getOperand();
+    Operation *extOp;
+    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
+        (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
+      source = extOp->getOperand(0);
       resultType =
           VectorType::get(resultType.getShape(),
                           source.getType().cast<VectorType>().getElementType());
@@ -469,9 +474,14 @@ struct CombineTransferReadOpTranspose final
             .getResult();
 
     // Fuse through the integer extend op.
-    if (extOp)
-      result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
-                   .getResult();
+    if (extOp) {
+      if (isa<arith::ExtSIOp>(extOp))
+        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+                     .getResult();
+      else
+        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
+                     .getResult();
+    }
 
     rewriter.replaceOp(op, result);
     return success();
@@ -484,15 +494,15 @@ struct CombineTransferReadOpTranspose final
 // Figure the right layout to use by looking at op uses.
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
-template <typename OpTy>
-static const char *inferFragType(OpTy op) {
+static const char *inferFragType(Operation *op) {
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
       continue;
-    if (contract.getLhs() == op.getResult())
+    assert(op->getNumResults() == 1);
+    if (contract.getLhs() == op->getResult(0))
       return "AOp";
-    if (contract.getRhs() == op.getResult())
+    if (contract.getRhs() == op->getResult(0))
       return "BOp";
   }
   return "COp";
@@ -521,14 +531,15 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
   if (op->hasOneUse()) {
-    auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
-    // Infer the signedness of the mma type from the signed extend.
-    if (extOp) {
-      elType = IntegerType::get(op.getContext(),
-                                elType.cast<IntegerType>().getWidth(),
-                                IntegerType::Signed);
-      mappingResult = extOp.getResult();
-      fragType = inferFragType(extOp);
+    auto user = *op->user_begin();
+    // Infer the signedness of the mma type from the integer extend.
+    bool isSignedExtend = isa<arith::ExtSIOp>(user);
+    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+      elType = IntegerType::get(
+          op.getContext(), elType.cast<IntegerType>().getWidth(),
+          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+      mappingResult = user->getResult(0);
+      fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
index d634677..92ab0cb 100644 (file)
@@ -4028,9 +4028,19 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
       typeR.getScope() != typeB.getScope() ||
       typeR.getScope() != typeC.getScope())
     return op.emitOpError("matrix scope must match");
-  if (typeA.getElementType() != typeB.getElementType() ||
-      typeR.getElementType() != typeC.getElementType())
-    return op.emitOpError("matrix element type must match");
+  auto elementTypeA = typeA.getElementType();
+  auto elementTypeB = typeB.getElementType();
+  if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
+    if (elementTypeA.cast<IntegerType>().getWidth() !=
+        elementTypeB.cast<IntegerType>().getWidth())
+      return op.emitOpError(
+          "matrix A and B integer element types must be the same bit width");
+  } else if (elementTypeA != elementTypeB) {
+    return op.emitOpError(
+        "matrix A and B non-integer element types must match");
+  }
+  if (typeR.getElementType() != typeC.getElementType())
+    return op.emitOpError("matrix accumulator element type must match");
   return success();
 }
 
index 93cfd76..c742150 100644 (file)
@@ -266,3 +266,24 @@ func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2:
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
   return
 }
+
+// CHECK-LABEL: func @matmul_mixed_signedness_int8
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xui8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
+func.func @matmul_mixed_signedness_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %Ae = arith.extui %Ar : vector<16x16xi8> to vector<16x16xi32>
+  %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}
index 723d7d4..de31458 100644 (file)
@@ -136,13 +136,21 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>
 // -----
 
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{matrix element type must match}}
+  // expected-error @+1 {{matrix A and B non-integer element types must match}}
   %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
   spirv.Return
 }
 
 // -----
 
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+  // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
+  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
   // expected-error @+1 {{Pointer must point to a scalar or vector type}}
   %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>