namespace {
-/// Contains all the common LLVM types which are used across the lowerings of
-/// GPU subgroup ops to NVVM dialect.
-struct CommonLLVMAndBuiltInMLIRTypes {
-public:
- CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) {
- numHalfsInOpFrags.resize(4);
- numHalfsInOpFrags[A] = 8;
- numHalfsInOpFrags[B] = 8;
- numHalfsInOpFrags[C] = 4;
- i32Ty = IntegerType::get(context, 32);
- f16Ty = FloatType::getF16(context);
- f32Ty = FloatType::getF32(context);
- f16x2Ty = VectorType::get(2, f16Ty);
- fragArrayABTy = LLVM::LLVMStructType::getLiteral(
- context, SmallVector<Type>(8, f16x2Ty));
- fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
- context, SmallVector<Type>(4, f16x2Ty));
- fragArrayCDF32Ty =
- LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
- };
-
- Type i32Ty;
- Type f16Ty;
- Type f32Ty;
- Type f16x2Ty;
- /// Type for the fragment of A and B operands that a single thread holds for
- /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
- /// (beta*C).
- Type fragArrayABTy;
- /// Type for the fragment of C and D operands that a single thread holds for
- /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
- /// (beta*C).
- Type fragArrayCDTy;
- /// Type for the fragment of C and D operands that a single thread holds for
- /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) +
- /// (beta*C).
- Type fragArrayCDF32Ty;
- /// Represents the number of f16 elements a single thread holds in a WMMA
- /// operation of the form D = (alpha*(A*B)) + (beta*C) .
- SmallVector<unsigned, 4> numHalfsInOpFrags;
- /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
- /// (beta*C).
- enum OperandMap { A, B, C };
-};
-
/// Checks if all the operands of the op being lowered are of LLVM Types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted,
/// this indicates a missing type conversion and the lowering fails.
static constexpr StringRef kInvalidCaseStr =
"Unimplemented WMMA variant, Only M16N16K16 version implemented.";
+/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
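+/// For example (an illustrative sketch of the mapping implemented below):
+///   !gpu.mma_matrix<16x16xf16, "AOp"> -> struct of 8 x vector<2xf16>
+///   !gpu.mma_matrix<16x16xf16, "COp"> -> struct of 4 x vector<2xf16>
+///   !gpu.mma_matrix<16x16xf32, "COp"> -> struct of 8 x f32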
+static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
+ StringRef operandStr = type.getOperand();
+ assert(type.getElementType().isa<FloatType>());
+ Type baseType = type.getElementType().isF16()
+ ? VectorType::get(2, type.getElementType())
+ : type.getElementType();
+ auto getLLVMType = [&](int64_t numElements) {
+ return LLVM::LLVMStructType::getLiteral(
+ type.getContext(), SmallVector<Type, 8>(numElements, baseType));
+ };
+ if (operandStr.equals("AOp") || operandStr.equals("BOp"))
+ return getLLVMType(8);
+ if (type.getElementType().isF16())
+ return getLLVMType(4);
+ return getLLVMType(8);
+}
+
/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the ops necessary to compute the address of the data to be loaded
/// from the source memref.
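/// For instance (a rough sketch; exact syntax abbreviated):
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
///       : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
/// lowers to address arithmetic on the memref descriptor followed by a
/// NVVM::WMMALoadAM16N16K16Op returning the struct type produced by
/// convertMMAToLLVMType.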
struct WmmaLoadOpToNVVMLowering
- : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>,
- private CommonLLVMAndBuiltInMLIRTypes {
-public:
- explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter)
- : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
- CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
- }
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
if (indexTypeBitwidth != 32)
return rewriter.notifyMatchFailure(
op, "Expected indices to the memref to be 32-bit wide.");
-
- // Source memref of the original op.
- MemRefType srcMemrefType =
- subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
Location loc = op->getLoc();
auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
+ gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
- MemRefDescriptor promotedSrcOp(
- gpu::SubgroupMmaLoadMatrixOpAdaptor(operands).srcMemref());
+ MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
// Emit ops which compute the load offset using `srcOffsetI` and
// `srcOffsetJ`. The actual offset is (memrefOffset + ((leadDimension *
// srcOffsetI) + srcOffsetJ)); the load address is `alignedPtr` plus this
// offset. The memrefs here are assumed to be normalized, hence the simple
// conversion works.
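// For example, with leadDimension = 32 and indices (%i, %j), the emitted ops
// compute memrefOffset + (32 * %i) + %j.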
- SmallVector<Value> indices(subgroupMmaLoadMatrixOp.indices());
+ SmallVector<Value> indices(adaptor.indices());
Value srcOffsetIVal = indices[0];
Value srcOffsetJVal = indices[1];
+ Type i32Ty = rewriter.getI32Type();
Value leadingDim32 =
rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
Value numElemsLeadDim =
Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
promotedSrcOpToUse);
Value loadAddress = rewriter.create<LLVM::GEPOp>(
- loc,
- LLVM::LLVMPointerType::get(f16Ty, srcMemrefType.getMemorySpaceAsInt()),
+ loc, promotedSrcOp.getElementPtrType(),
promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
// Bitcast the base address pointer of the source memref so that values can
// be loaded in 32-bit wide chunks, matching the intrinsic exposed by the
// NVPTX backend.
Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
- LLVM::LLVMPointerType::get(i32Ty, srcMemrefType.getMemorySpaceAsInt()),
+ LLVM::LLVMPointerType::get(
+ i32Ty, promotedSrcOp.getElementPtrType().getAddressSpace()),
loadAddress);
// Get the shape of the MMAMatrix type being returned; the shape selects the
// WMMA variant below.
gpu::MMAMatrixType retType =
    subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> retTypeShape = retType.getShape();
- Type resType;
+ Type resType = convertMMAToLLVMType(retType);
StringRef operandStr = retType.getOperand();
- if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
- resType = fragArrayABTy;
- } else {
- if (srcMemrefType.getElementType().isF16())
- resType = fragArrayCDTy;
- else if (srcMemrefType.getElementType().isF32())
- resType = fragArrayCDF32Ty;
- else
- return failure();
- }
// Create nvvm.mma_load op according to the operand types.
SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
if (operandStr.equals("AOp")) {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
- NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp =
- rewriter.create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
- loadOpOperands);
- rewriter.replaceOp(op, wmmaLoadAOp.getResult());
+ rewriter.replaceOpWithNewOp<NVVM::WMMALoadAM16N16K16Op>(op, resType,
+ loadOpOperands);
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
} else if (operandStr.equals("BOp")) {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
- NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp =
- rewriter.create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
- loadOpOperands);
- rewriter.replaceOp(op, wmmaLoadBOp.getResult());
+ rewriter.replaceOpWithNewOp<NVVM::WMMALoadBM16N16K16Op>(op, resType,
+ loadOpOperands);
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
} else {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
- if (srcMemrefType.getElementType().isF16()) {
- NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp =
- rewriter.create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
- loadOpOperands);
- rewriter.replaceOp(op, wmmaLoadCOp.getResult());
- } else if (srcMemrefType.getElementType().isF32()) {
- NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp =
- rewriter.create<NVVM::WMMALoadCF32M16N16K16Op>(loc, resType,
- loadOpOperands);
- rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+ if (retType.getElementType().isF16()) {
+ rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF16M16N16K16Op>(
+ op, resType, loadOpOperands);
+ } else if (retType.getElementType().isF32()) {
+ rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF32M16N16K16Op>(
+ op, resType, loadOpOperands);
}
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
/// emits the ops necessary to unpack the source data and convert it into the
/// format needed by the NVVM op.
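/// For instance (a rough sketch; exact syntax abbreviated):
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j] {leadDimension = 32 : index}
///       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
/// lowers to LLVM::ExtractValueOps that unpack the struct elements, which
/// then feed a NVVM::WMMAStoreF16M16N16K16Op.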
struct WmmaStoreOpToNVVMLowering
- : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>,
- private CommonLLVMAndBuiltInMLIRTypes {
-public:
- explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter)
- : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
- CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
- }
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
Location loc = op->getLoc();
- // Destination memref of the original op.
- MemRefType dstMemrefType =
- subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
-
+ gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
// MemRefDescriptor to extract alignedPtr and offset.
- MemRefDescriptor promotedDstOp(
- gpu::SubgroupMmaStoreMatrixOpAdaptor(operands).dstMemref());
+ MemRefDescriptor promotedDstOp(adaptor.dstMemref());
auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
// Emit ops which compute the store offset using `dstOffsetI` and
// `dstOffsetJ`. The actual offset is (memrefOffset + ((leadDimension *
// dstOffsetI) + dstOffsetJ)); the store address is `alignedPtr` plus this
// offset.
- SmallVector<Value> indices(subgroupMmaStoreMatrixOp.indices());
+ SmallVector<Value> indices(adaptor.indices());
Value dstOffsetIVal = indices[0];
Value dstOffsetJVal = indices[1];
+ Type i32Ty = rewriter.getI32Type();
Value leadingDim32 =
rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
Value numElemsLeadDim =
Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
promotedDstOpToUse);
Value storeAddress = rewriter.create<LLVM::GEPOp>(
- loc,
- LLVM::LLVMPointerType::get(f16Ty, dstMemrefType.getMemorySpaceAsInt()),
+ loc, promotedDstOp.getElementPtrType(),
promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
// Bitcast the base address pointer of the destination memref so that values
// can be stored in 32-bit wide chunks, matching the intrinsic exposed by the
// NVPTX backend.
Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
loc,
- LLVM::LLVMPointerType::get(i32Ty, dstMemrefType.getMemorySpaceAsInt()),
+ LLVM::LLVMPointerType::get(
+ i32Ty, promotedDstOp.getElementPtrType().getAddressSpace()),
storeAddress);
SmallVector<Value, 4> storeOpOperands;
gpu::MMAMatrixType srcType =
    subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> srcTypeShape = srcType.getShape();
+ auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
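+ // Unpack the lowered source struct: each element (a <2 x half> vector for
+ // f16, or an f32) becomes one operand of the NVVM store op.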
+ for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
+ Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+ loc, matrixType.getBody()[i], adaptor.src(),
+ rewriter.getI32ArrayAttr(i));
+ storeOpOperands.push_back(toUse);
+ }
+ storeOpOperands.push_back(leadingDim32);
- // Unpack the results from the source.
- if (subgroupMmaStoreMatrixOp.src()
- .getType()
- .cast<gpu::MMAMatrixType>()
- .getElementType() == f16Ty) {
- for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(
- loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
- storeOpOperands.push_back(toUse);
- }
- storeOpOperands.push_back(leadingDim32);
-
+ if (srcType.getElementType().isF16()) {
// Create nvvm.mma_store op.
if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
}
rewriter.eraseOp(op);
return success();
- } else if (subgroupMmaStoreMatrixOp.src()
- .getType()
- .cast<gpu::MMAMatrixType>()
- .getElementType() == f32Ty) {
- for (unsigned i = 0, e = 8; i < e; ++i) {
- Value toUse = rewriter.create<LLVM::ExtractValueOp>(
- loc, f32Ty, operands[0], rewriter.getI32ArrayAttr(i));
- storeOpOperands.push_back(toUse);
- }
- storeOpOperands.push_back(leadingDim32);
-
+ }
+ if (srcType.getElementType().isF32()) {
// Create nvvm.mma_store op.
if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
rewriter.eraseOp(op);
return success();
}
-
return failure();
}
};
/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
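/// For instance (a rough sketch; exact syntax abbreviated):
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///       : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
///       -> !gpu.mma_matrix<16x16xf16, "COp">
/// lowers to LLVM::ExtractValueOps flattening the three structs, followed by
/// a NVVM::WMMAMmaF16F16M16N16K16Op.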
struct WmmaMmaOpToNVVMLowering
- : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>,
- private CommonLLVMAndBuiltInMLIRTypes {
- explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter)
- : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
- CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
- }
+ : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
// values from the lowered operands.
SmallVector<Value> unpackedOps;
- auto unpackOp = [&](CommonLLVMAndBuiltInMLIRTypes::OperandMap op,
- Value operand, unsigned numElems, Type elemType) {
- for (unsigned i = 0; i < numElems; ++i) {
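+ // Flattens a lowered LLVM struct operand into its individual elements; the
+ // element types are taken from the struct body itself.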
+ auto unpackOp = [&](Value operand) {
+ auto structType = operand.getType().cast<LLVM::LLVMStructType>();
+ for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(
- loc, elemType, operand, rewriter.getI32ArrayAttr(i));
+ loc, structType.getBody()[i], operand, rewriter.getI32ArrayAttr(i));
unpackedOps.push_back(toUse);
}
};
gpu::MMAMatrixType aType =
    subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> aTypeShape = aType.getShape();
gpu::MMAMatrixType bType =
- subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+ subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> bTypeShape = bType.getShape();
gpu::MMAMatrixType cType =
- subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+ subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> cTypeShape = cType.getShape();
gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
- if (subgroupMmaComputeOp.opC()
- .getType()
- .cast<gpu::MMAMatrixType>()
- .getElementType() == f16Ty) {
- unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
- unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
- unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
+ unpackOp(transformedOperands.opA());
+ unpackOp(transformedOperands.opB());
+ unpackOp(transformedOperands.opC());
+ if (cType.getElementType().isF16()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
- NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
- rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
- unpackedOps);
- rewriter.replaceOp(op, wmmaMmaOp.getResult());
+ rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
+ op, transformedOperands.opC().getType(), unpackedOps);
return success();
- } else {
- return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
- } else if (subgroupMmaComputeOp.opC()
- .getType()
- .cast<gpu::MMAMatrixType>()
- .getElementType() == f32Ty) {
- unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
- unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
- unpackOp(C, transformedOperands.opC(), 8, f32Ty);
-
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ }
+ if (cType.getElementType().isF32()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
- NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
- rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(
- loc, fragArrayCDF32Ty, unpackedOps);
- rewriter.replaceOp(op, wmmaMmaOp.getResult());
+ rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
+ op, transformedOperands.opC().getType(), unpackedOps);
return success();
- } else {
- return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
+ return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
-
return failure();
}
};
namespace mlir {
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
- patterns.insert<WmmaLoadOpToNVVMLowering>(converter);
- patterns.insert<WmmaMmaOpToNVVMLowering>(converter);
- patterns.insert<WmmaStoreOpToNVVMLowering>(converter);
+ patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
+ WmmaStoreOpToNVVMLowering>(converter);
}
} // namespace mlir
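
// A minimal sketch of how these patterns are typically driven (assuming the
// standard MLIR dialect-conversion setup; `m`, `target`, and the surrounding
// pass are illustrative, not part of this patch):
//
//   LLVMTypeConverter converter(m.getContext());
//   RewritePatternSet patterns(m.getContext());
//   populateGpuWMMAToNVVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*m.getContext());
//   if (failed(applyPartialConversion(m, target, std::move(patterns))))
//     signalPassFailure();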