[mlir][gpu][NFC] Simplify conversion of MMA type to NVVM
author thomasraoux <thomasraoux@google.com>
Tue, 8 Jun 2021 05:25:30 +0000 (22:25 -0700)
committer thomasraoux <thomasraoux@google.com>
Wed, 9 Jun 2021 16:33:38 +0000 (09:33 -0700)
Consolidate the type conversion into a single function to make it simpler
to use. This allows reusing the type conversion for upcoming ops.

Differential Revision: https://reviews.llvm.org/D103868
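For context, the consolidated helper introduced in this change maps a
gpu::MMAMatrixType to the LLVM struct type a single thread holds. A minimal
sketch of how a later lowering pattern might reuse it; only
convertMMAToLLVMType comes from this commit, the surrounding names are
illustrative:

    // Hypothetical reuse in a future lowering pattern (newOp is a placeholder).
    gpu::MMAMatrixType mmaType =
        newOp.res().getType().cast<gpu::MMAMatrixType>();
    LLVM::LLVMStructType llvmType = convertMMAToLLVMType(mmaType);
    // f16 "AOp"/"BOp" -> struct of 8 x vector<2xf16>
    // f16 "COp"       -> struct of 4 x vector<2xf16>
    // f32 "COp"       -> struct of 8 x f32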

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

index 5f5213a..d72c8c2 100644
@@ -20,51 +20,6 @@ using namespace mlir;
 
 namespace {
 
-/// Contains all the common LLVM types which are used across the lowerings of
-/// GPU subgroup ops to NVVM dialect.
-struct CommonLLVMAndBuiltInMLIRTypes {
-public:
-  CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) {
-    numHalfsInOpFrags.resize(4);
-    numHalfsInOpFrags[A] = 8;
-    numHalfsInOpFrags[B] = 8;
-    numHalfsInOpFrags[C] = 4;
-    i32Ty = IntegerType::get(context, 32);
-    f16Ty = FloatType::getF16(context);
-    f32Ty = FloatType::getF32(context);
-    f16x2Ty = VectorType::get(2, f16Ty);
-    fragArrayABTy = LLVM::LLVMStructType::getLiteral(
-        context, SmallVector<Type>(8, f16x2Ty));
-    fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
-        context, SmallVector<Type>(4, f16x2Ty));
-    fragArrayCDF32Ty =
-        LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
-  };
-
-  Type i32Ty;
-  Type f16Ty;
-  Type f32Ty;
-  Type f16x2Ty;
-  /// Type for the fragment of A and B operands that a single thread holds for
-  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
-  /// (beta*C).
-  Type fragArrayABTy;
-  /// Type for the fragment of C and D operands that a single thread holds for
-  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
-  /// (beta*C).
-  Type fragArrayCDTy;
-  /// Type for the fragment of C and D operands that a single thread holds for
-  /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) +
-  /// (beta*C).
-  Type fragArrayCDF32Ty;
-  /// Represents the number of f16 elements a single thread holds in a WMMA
-  /// operation of the form D = (alpha*(A*B)) + (beta*C) .
-  SmallVector<unsigned, 4> numHalfsInOpFrags;
-  /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
-  /// (beta*C).
-  enum OperandMap { A, B, C };
-};
-
 /// Checks if all the operands of the op being lowered are of LLVM Types. The
 /// types are expected to be converted by the `LLVMTypeConverter` before the op
 /// is actually lowered. If the type of an operands is not already converted it
@@ -85,18 +40,32 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
 static constexpr StringRef kInvalidCaseStr =
     "Unimplemented WMMA variant, Only M16N16K16 version implemented.";
 
+/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
+static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
+  StringRef operandStr = type.getOperand();
+  assert(type.getElementType().isa<FloatType>());
+  Type baseType = type.getElementType().isF16()
+                      ? VectorType::get(2, type.getElementType())
+                      : type.getElementType();
+  auto getLLVMType = [&](int64_t numElements) {
+    return LLVM::LLVMStructType::getLiteral(
+        type.getContext(), SmallVector<Type, 8>(numElements, baseType));
+  };
+  if (operandStr.equals("AOp") || operandStr.equals("BOp"))
+    return getLLVMType(8);
+  if (type.getElementType().isF16())
+    return getLLVMType(4);
+  return getLLVMType(8);
+}
+
 /// This class implements the conversion of GPU MMA loadOp to wmma.load op
 /// in the NVVM dialect. The conversion not only emits the NVVM op but also
 /// emits code that is necessary to store the data in the destination memref
 /// after it has been loaded.
 struct WmmaLoadOpToNVVMLowering
-    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>,
-      private CommonLLVMAndBuiltInMLIRTypes {
-public:
-  explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
-        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
-  }
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
@@ -115,25 +84,22 @@ public:
     if (indexTypeBitwidth != 32)
       return rewriter.notifyMatchFailure(
           op, "Expected indices to the memref to be 32-bit wide.");
-
-    // Source memref of the original op.
-    MemRefType srcMemrefType =
-        subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
     Location loc = op->getLoc();
 
     auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
 
+    gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
     // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedSrcOp(
-        gpu::SubgroupMmaLoadMatrixOpAdaptor(operands).srcMemref());
+    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
 
     // Emit ops which compute the load offset using `srcOffsetI`,
     // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
     // assumed to be normalized and hence the simple conversion works.
-    SmallVector<Value> indices(subgroupMmaLoadMatrixOp.indices());
+    SmallVector<Value> indices(adaptor.indices());
     Value srcOffsetIVal = indices[0];
     Value srcOffsetJVal = indices[1];
+    Type i32Ty = rewriter.getI32Type();
     Value leadingDim32 =
         rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
     Value numElemsLeadDim =
@@ -146,8 +112,7 @@ public:
     Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
                                                       promotedSrcOpToUse);
     Value loadAddress = rewriter.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(f16Ty, srcMemrefType.getMemorySpaceAsInt()),
+        loc, promotedSrcOp.getElementPtrType(),
         promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
 
     // Bitcast the base address pointer of the destination memref, So that
@@ -155,7 +120,8 @@ public:
     // intrinsic exposed by NVPTX backend.
     Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
-        LLVM::LLVMPointerType::get(i32Ty, srcMemrefType.getMemorySpaceAsInt()),
+        LLVM::LLVMPointerType::get(
+            i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
         loadAddress);
 
     // Get the shape of the MMAMatrix type being returned. The shape will
@@ -164,51 +130,33 @@ public:
         subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> retTypeShape = retType.getShape();
 
-    Type resType;
+    Type resType = convertMMAToLLVMType(retType);
     StringRef operandStr = retType.getOperand();
-    if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
-      resType = fragArrayABTy;
-    } else {
-      if (srcMemrefType.getElementType().isF16())
-        resType = fragArrayCDTy;
-      else if (srcMemrefType.getElementType().isF32())
-        resType = fragArrayCDF32Ty;
-      else
-        return failure();
-    }
 
     // Create nvvm.mma_load op according to the operand types.
     SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
     if (operandStr.equals("AOp")) {
       if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
-        NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp =
-            rewriter.create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
-                                                        loadOpOperands);
-        rewriter.replaceOp(op, wmmaLoadAOp.getResult());
+        rewriter.replaceOpWithNewOp<NVVM::WMMALoadAM16N16K16Op>(op, resType,
+                                                                loadOpOperands);
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
     } else if (operandStr.equals("BOp")) {
       if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
-        NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp =
-            rewriter.create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
-                                                        loadOpOperands);
-        rewriter.replaceOp(op, wmmaLoadBOp.getResult());
+        rewriter.replaceOpWithNewOp<NVVM::WMMALoadBM16N16K16Op>(op, resType,
+                                                                loadOpOperands);
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
     } else {
       if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
-        if (srcMemrefType.getElementType().isF16()) {
-          NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp =
-              rewriter.create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
-                                                             loadOpOperands);
-          rewriter.replaceOp(op, wmmaLoadCOp.getResult());
-        } else if (srcMemrefType.getElementType().isF32()) {
-          NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp =
-              rewriter.create<NVVM::WMMALoadCF32M16N16K16Op>(loc, resType,
-                                                             loadOpOperands);
-          rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+        if (retType.getElementType().isF16()) {
+          rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF16M16N16K16Op>(
+              op, resType, loadOpOperands);
+        } else if (retType.getElementType().isF32()) {
+          rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF32M16N16K16Op>(
+              op, resType, loadOpOperands);
         }
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
@@ -223,13 +171,9 @@ public:
 /// emits code that is necessary to unpack the data in the source and
 /// convert the data in the format that is needed by the NVVM op.
 struct WmmaStoreOpToNVVMLowering
-    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>,
-      private CommonLLVMAndBuiltInMLIRTypes {
-public:
-  explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
-        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
-  }
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
@@ -250,22 +194,19 @@ public:
 
     Location loc = op->getLoc();
 
-    // Destination memref of the original op.
-    MemRefType dstMemrefType =
-        subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
-
+    gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
     // MemRefDescriptor to extract alignedPtr and offset.
-    MemRefDescriptor promotedDstOp(
-        gpu::SubgroupMmaStoreMatrixOpAdaptor(operands).dstMemref());
+    MemRefDescriptor promotedDstOp(adaptor.dstMemref());
 
     auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
 
     // Emit ops which compute the store offset using `dstOffsetI`,
     // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
     // ((leadDimension * dstOffsetI) + dstOffsetJ)).
-    SmallVector<Value> indices(subgroupMmaStoreMatrixOp.indices());
+    SmallVector<Value> indices(adaptor.indices());
     Value dstOffsetIVal = indices[0];
     Value dstOffsetJVal = indices[1];
+    Type i32Ty = rewriter.getI32Type();
     Value leadingDim32 =
         rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
     Value numElemsLeadDim =
@@ -278,8 +219,7 @@ public:
     Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
                                                       promotedDstOpToUse);
     Value storeAddress = rewriter.create<LLVM::GEPOp>(
-        loc,
-        LLVM::LLVMPointerType::get(f16Ty, dstMemrefType.getMemorySpaceAsInt()),
+        loc, promotedDstOp.getElementPtrType(),
         promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
 
     // Bitcast the base address pointer of the destination memref, So that
@@ -287,7 +227,8 @@ public:
     // intrinsic exposed by NVPTX backend.
     Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
         loc,
-        LLVM::LLVMPointerType::get(i32Ty, dstMemrefType.getMemorySpaceAsInt()),
+        LLVM::LLVMPointerType::get(
+            i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
         storeAddress);
 
     SmallVector<Value, 4> storeOpOperands;
@@ -299,18 +240,16 @@ public:
         subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> srcTypeShape = srcType.getShape();
 
+    auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
+    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
+      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+          loc, matrixType.getBody()[i], adaptor.src(),
+          rewriter.getI32ArrayAttr(i));
+      storeOpOperands.push_back(toUse);
+    }
+    storeOpOperands.push_back(leadingDim32);
     // Unpack the results from the source.
-    if (subgroupMmaStoreMatrixOp.src()
-            .getType()
-            .cast<gpu::MMAMatrixType>()
-            .getElementType() == f16Ty) {
-      for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
-        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
-            loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
-        storeOpOperands.push_back(toUse);
-      }
-      storeOpOperands.push_back(leadingDim32);
-
+    if (srcType.getElementType().isF16()) {
       // Create nvvm.mma_store op.
       if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
         rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
@@ -319,17 +258,8 @@ public:
       }
       rewriter.eraseOp(op);
       return success();
-    } else if (subgroupMmaStoreMatrixOp.src()
-                   .getType()
-                   .cast<gpu::MMAMatrixType>()
-                   .getElementType() == f32Ty) {
-      for (unsigned i = 0, e = 8; i < e; ++i) {
-        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
-            loc, f32Ty, operands[0], rewriter.getI32ArrayAttr(i));
-        storeOpOperands.push_back(toUse);
-      }
-      storeOpOperands.push_back(leadingDim32);
-
+    }
+    if (srcType.getElementType().isF32()) {
       // Create nvvm.mma_store op.
       if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
         rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
@@ -339,7 +269,6 @@ public:
       rewriter.eraseOp(op);
       return success();
     }
-
     return failure();
   }
 };
@@ -347,12 +276,9 @@ public:
 /// This class implements the conversion of GPU MMA computeOp to wmma.mma op
 /// in the NVVM dialect.
 struct WmmaMmaOpToNVVMLowering
-    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>,
-      private CommonLLVMAndBuiltInMLIRTypes {
-  explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter)
-      : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
-        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
-  }
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
@@ -370,11 +296,11 @@ struct WmmaMmaOpToNVVMLowering
     // values form lowered memrefs.
     SmallVector<Value> unpackedOps;
 
-    auto unpackOp = [&](CommonLLVMAndBuiltInMLIRTypes::OperandMap op,
-                        Value operand, unsigned numElems, Type elemType) {
-      for (unsigned i = 0; i < numElems; ++i) {
+    auto unpackOp = [&](Value operand) {
+      auto structType = operand.getType().cast<LLVM::LLVMStructType>();
+      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
         Value toUse = rewriter.create<LLVM::ExtractValueOp>(
-            loc, elemType, operand, rewriter.getI32ArrayAttr(i));
+            loc, structType.getBody()[i], operand, rewriter.getI32ArrayAttr(i));
         unpackedOps.push_back(toUse);
       }
     };
@@ -385,55 +311,39 @@ struct WmmaMmaOpToNVVMLowering
         subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> aTypeShape = aType.getShape();
     gpu::MMAMatrixType bType =
-        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+        subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> bTypeShape = bType.getShape();
     gpu::MMAMatrixType cType =
-        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+        subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> cTypeShape = cType.getShape();
 
     gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
-    if (subgroupMmaComputeOp.opC()
-            .getType()
-            .cast<gpu::MMAMatrixType>()
-            .getElementType() == f16Ty) {
-      unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
-      unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
-      unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
+    unpackOp(transformedOperands.opA());
+    unpackOp(transformedOperands.opB());
+    unpackOp(transformedOperands.opC());
 
+    if (cType.getElementType().isF16()) {
       if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
           bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
         // Create nvvm.wmma.mma op.
-        NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
-            rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
-                                                            unpackedOps);
+        rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
+            op, transformedOperands.opC().getType(), unpackedOps);
 
-        rewriter.replaceOp(op, wmmaMmaOp.getResult());
         return success();
-      } else {
-        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
-    } else if (subgroupMmaComputeOp.opC()
-                   .getType()
-                   .cast<gpu::MMAMatrixType>()
-                   .getElementType() == f32Ty) {
-      unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
-      unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
-      unpackOp(C, transformedOperands.opC(), 8, f32Ty);
-
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+    }
+    if (cType.getElementType().isF32()) {
       if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
           bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
         // Create nvvm.wmma.mma op.
-        NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
-            rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(
-                loc, fragArrayCDF32Ty, unpackedOps);
+        rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
+            op, transformedOperands.opC().getType(), unpackedOps);
 
-        rewriter.replaceOp(op, wmmaMmaOp.getResult());
         return success();
-      } else {
-        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
     }
-
     return failure();
   }
 };
@@ -443,8 +353,7 @@ struct WmmaMmaOpToNVVMLowering
 namespace mlir {
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns) {
-  patterns.insert<WmmaLoadOpToNVVMLowering>(converter);
-  patterns.insert<WmmaMmaOpToNVVMLowering>(converter);
-  patterns.insert<WmmaStoreOpToNVVMLowering>(converter);
+  patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
+                  WmmaStoreOpToNVVMLowering>(converter);
 }
 } // namespace mlir