"dialect";
let constructor = "mlir::createConvertVectorToGPUPass()";
let dependentDialects = [
- "memref::MemRefDialect",
- "gpu::GPUDialect"
+ "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect",
+ "vector::VectorDialect", "nvgpu::NVGPUDialect"
+ ];
+
+ let options = [
+ Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false",
+ "convert to NvGPU ops instead of GPU dialect ops">
];
}
class RewritePatternSet;
/// Patterns to transform vector ops into a canonical form to convert to MMA
-/// matrix operations.
-void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+/// matrix operations. If `useNvGpu` is true, the patterns populated will
+/// prepare for conversion to `nvgpu` mma operations rather than the `gpu`
+/// dialect WMMA operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+ bool useNvGpu = false);
/// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
/// convert slice of operations that can be legally converted to MMA operations.
/// The rest of the vector operations are left untouched.
void convertVectorToMMAOps(Operation *rootOp);
+/// Convert vector ops nested under `rootOp` to vector and GPU operations
+/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
+/// of operations that can be legally lowered on this path while the rest of
+/// the vector operations are left untouched.
+LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+
/// Convert from vector to GPU ops.
-std::unique_ptr<Pass> createConvertVectorToGPUPass();
+std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
} // namespace mlir
class LLVMDialect;
} // namespace LLVM
+namespace nvgpu {
+class NVGPUDialect;
+}
+
namespace NVVM {
class NVVMDialect;
} // namespace NVVM
add_mlir_conversion_library(MLIRVectorToGPU
VectorToGPU.cpp
+ NvGpuSupport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
--- /dev/null
+//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NvGpuSupport.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+namespace nvgpu {
+namespace {
+
+/// There are always 4 threads per [128|256|512] bit row.
+constexpr int64_t kThreadsPerRow = 4;
+
+constexpr int64_t kNumRowsPerTile = 8;
+
+bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+ return operandType == MatMulOperandRole::C;
+}
+
+/// Returns the number of registers which compose a matrix fragment held by a
+/// single thread.
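+/// As a worked example (assuming the 16x16xf16 A operand of an m16n8k16
+/// `mma.sync`): the tile width is 128 bits, so each thread holds
+/// (16 / 8) * (16 * 16) / 128 = 4 registers of the fragment.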
+int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+ int64_t lineSize = inferTileWidthInBits(type);
+ auto shape = type.vectorType.getShape();
+ return (shape[0] / kNumRowsPerTile) *
+ (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
+ lineSize;
+}
+
+/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
+/// operand shape.
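+/// For instance, a 16x16xf16 operand with 128-bit lines decomposes into
+/// {16 / 8, (16 * 16) / 128} = {2, 2} such tiles.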
+std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
+ Type elementType, int64_t lineSizeBits) {
+ // For each 8x128bit square, a thread is responsible for one 32bit register.
+ return {operandShape[0] / kNumRowsPerTile,
+ (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
+ lineSizeBits};
+}
+
+} // namespace
+
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
+ WarpMatrixInfo info;
+
+ // Determine the vector type.
+ if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+ info.vectorType = writeOp.getVectorType();
+ } else if (isa<vector::TransferReadOp, vector::ContractionOp,
+ arith::ConstantOp>(op)) {
+ info.vectorType = op->getResult(0).getType().cast<VectorType>();
+ } else {
+ return op->emitError()
+ << "unhandled operation type in nvgpu.mma.sync conversion path";
+ }
+
+ // Determine the operand role. We assume it is an accumulator/result unless it
+ // is directly consumed by a `vector.contract` op.
+ info.operandRole = MatMulOperandRole::C;
+ for (Operation *user : op->getUsers()) {
+ auto contract = dyn_cast<vector::ContractionOp>(user);
+ if (!contract)
+ continue;
+ if (contract.getLhs() == op->getResult(0)) {
+ info.operandRole = MatMulOperandRole::A;
+ break;
+ }
+ if (contract.getRhs() == op->getResult(0)) {
+ info.operandRole = MatMulOperandRole::B;
+ break;
+ }
+ }
+ return info;
+}
+
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
+ bool isAcc = isAccumulatorOrResult(type.operandRole);
+ Type elType = type.vectorType.getElementType();
+ if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
+ return 256;
+ }
+ if (elType.getIntOrFloatBitWidth() == 64) {
+ return isAcc ? 512 : 256;
+ }
+ return 128;
+}
+
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type) {
+ MLIRContext *ctx = type.vectorType.getContext();
+ const bool isAccum = isAccumulatorOrResult(type.operandRole);
+
+ Type elType = type.vectorType.getElementType();
+ if (elType.isF16()) {
+ return FragmentElementInfo{
+ LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
+ inferNumRegistersPerMatrixFragment(type)};
+ }
+
+ // f64 operand
+ Type f64Ty = Float64Type::get(ctx);
+ if (elType.isF64()) {
+ return isAccum
+ ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+ inferNumRegistersPerMatrixFragment(type)}
+ : FragmentElementInfo{f64Ty, 1, 64,
+ inferNumRegistersPerMatrixFragment(type)};
+ }
+
+ // int8 operand
+ if (elType.isInteger(8)) {
+ return FragmentElementInfo{
+ LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
+ inferNumRegistersPerMatrixFragment(type)};
+ }
+ // Integer 32bit acc operands
+ if (elType.isInteger(32)) {
+ return FragmentElementInfo{
+ LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
+ inferNumRegistersPerMatrixFragment(type)};
+ }
+
+ // Floating point 32bit operands
+ if (elType.isF32()) {
+ Type f32Ty = Float32Type::get(ctx);
+ return isAccum
+ ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+ inferNumRegistersPerMatrixFragment(type)}
+ : FragmentElementInfo{f32Ty, 1, 32,
+ inferNumRegistersPerMatrixFragment(type)};
+ }
+ return failure();
+}
+
+static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
+ Type elementType,
+ ArrayRef<int64_t> operandShape,
+ bool isAccumulator,
+ int64_t elementsPerRegister,
+ AffineExpr logicalValueId) {
+ const int64_t elementsPerLine =
+ lineSize / elementType.getIntOrFloatBitWidth();
+ const std::array<int64_t, 2> num8x128bTiles =
+ getTileShape(operandShape, elementType, lineSize);
+ AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
+ return AffineMap::get(
+ 2, 0,
+ {(registerIdx % num8x128bTiles[0]) * 8,
+ (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
+ elementType.getContext());
+}
+
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+ const WarpMatrixInfo &fragmentType) {
+ Type elementType = fragmentType.vectorType.getElementType();
+ ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
+ FailureOr<nvgpu::FragmentElementInfo> regInfo =
+ getMmaSyncRegisterType(fragmentType);
+ if (failed(regInfo))
+ return failure();
+
+ const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
+ const int64_t elementsPerRegister =
+ regInfo->registerWidthBits / elementBitWidth;
+ const int64_t lineSize = inferTileWidthInBits(fragmentType);
+
+ AffineExpr laneId, logicalValueIdDim;
+ bindDims(builder.getContext(), laneId, logicalValueIdDim);
+
+ // Determine what register logicalValueId corresponds to. Use that as a
+ // linear index into the coordinate mapping `index -> (tile row, tile col)`.
+ AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
+ lineSize, elementType, operandShape,
+ isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
+ logicalValueIdDim);
+
+ auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
+ return AffineMap::get(2, 0, dimExprs, builder.getContext());
+ };
+
+ auto tileRow = registerIndexToTileCoord.getResult(0);
+ auto tileCol = registerIndexToTileCoord.getResult(1);
+ return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
+ tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
+ (logicalValueIdDim % elementsPerRegister)});
+}
+
+FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+ bool transpose) {
+ LdMatrixParams params;
+ Type elType = type.vectorType.getElementType();
+ params.fragmentType = type.vectorType;
+ if (type.operandRole == MatMulOperandRole::A ||
+ type.operandRole == MatMulOperandRole::C) {
+ params.targetLayout = NVVM::MMALayout::row;
+ } else {
+ params.targetLayout = NVVM::MMALayout::col;
+ }
+ ArrayRef<int64_t> shape = type.vectorType.getShape();
+ params.contiguousDimType =
+ transpose ? IteratorType::Parallel : IteratorType::Reduction;
+
+ if (params.targetLayout == NVVM::MMALayout::row) {
+ params.numTiles = (shape[0] / kNumRowsPerTile) *
+ ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
+ } else {
+ params.numTiles = (shape[1] / kNumRowsPerTile) *
+ ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
+ }
+
+ if (params.numTiles == 0)
+ return failure();
+
+ return params;
+}
+
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+ const LdMatrixParams ¶ms) {
+ // One thread per 128b row.
+ const int64_t kNumThreadsPerTile = kNumRowsPerTile;
+ const int bitsPerElement = static_cast<int>(
+ params.fragmentType.getElementType().getIntOrFloatBitWidth());
+ const int kElementsPer128b = (128 / bitsPerElement);
+ ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
+ AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
+
+ auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
+ return AffineMap::get(1, 0, dimExprs, builder.getContext());
+ };
+
+ // This case corresponds to row-major A|C or col-major B operands.
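+  // As an illustration (assuming a row-major 16x16xf16 A operand, so
+  // kElementsPer128b = 8): lane 9 is mapped to (9 mod 16, (9 floordiv 16) * 8)
+  // = (9, 0), the start of the 128-bit row that lane supplies to `ldmatrix`.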
+ if (params.contiguousDimType == IteratorType::Reduction) {
+ AffineExpr row = d0 % (operandShape[0]);
+ AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
+ return makeMap({row, col});
+ }
+
+  // This case corresponds to col-major A|C or row-major B operands. The
+ // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
+ if (params.contiguousDimType == IteratorType::Parallel) {
+ const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
+ // Threads are assigned in groups of 8 first across columns, then to
+    // rows. This is the transpose of what `ldmatrix` expects, but when
+    // `ldmatrix` gets the `.trans` qualifier, the final effect will be to
+ // transpose just the blocks.
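+    // As a sketch (assuming an 8x16xf16 B operand loaded with transpose, so
+    // num8x128bCols = 1): lane 11 falls in group 1 and is mapped to
+    // (0 * 8, 1 * 8 + 11 mod 8) = (0, 11).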
+ auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
+ auto tileCol = (groupIdx % num8x128bCols);
+ auto tileRow = groupIdx.floorDiv(num8x128bCols);
+ return makeMap({tileCol * kElementsPer128b,
+ tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
+ }
+ return failure();
+}
+
+LogicalResult
+PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+ Value res = op.getAcc();
+
+ // Set up the parallel/reduction structure in right form.
+ using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+ auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+ AffineExpr m;
+ AffineExpr n;
+ AffineExpr k;
+ bindDims(rewriter.getContext(), m, n, k);
+ static constexpr std::array<int64_t, 2> perm = {1, 0};
+ auto iteratorTypes = op.getIteratorTypes().getValue();
+ SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+ if (iteratorTypes.size() != 3)
+ return failure();
+ if (!(isParallelIterator(iteratorTypes[0]) &&
+ isParallelIterator(iteratorTypes[1]) &&
+ isReductionIterator(iteratorTypes[2])))
+ return failure();
+
+ // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+ const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+ if (maps == canonicalForm) {
+ return failure();
+ }
+ if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+ std::swap(rhs, lhs);
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+ std::swap(rhs, lhs);
+ rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+ } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+ std::swap(lhs, rhs);
+ lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+ } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+ std::swap(lhs, rhs);
+ } else {
+ return failure();
+ }
+ rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+ op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+ op.getIteratorTypes());
+ return success();
+}
+
+} // namespace nvgpu
+} // namespace mlir
\ No newline at end of file
--- /dev/null
+//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace nvgpu {
+
+enum class MatMulOperandRole : int32_t { A = 0, B, C };
+
+/// Collects information about a warp-level matrix operand represented by a
+/// VectorType.
+struct WarpMatrixInfo {
+ VectorType vectorType;
+ MatMulOperandRole operandRole;
+};
+
+/// Given an op that operates on a VectorType representing a warp-level matrix
+/// operand, the function returns a struct containing relevant type information.
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
+
+/// Returns the number of bits in a single tile row. It is either 128, 256, or
+/// 512 bits depending on the data type and whether the operand is an
+/// accumulator/result operand.
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
+
+/// Specifies information about the registers which compose a matrix fragment
+/// according to the PTX documentation.
+struct FragmentElementInfo {
+ Type registerLLVMType;
+ int64_t elementsPerRegister;
+ int64_t registerWidthBits;
+ int64_t numRegistersPerFragment;
+};
+
+/// Returns a FragmentElementInfo struct describing the register types for the
+/// given matrix fragment type.
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type);
+
+/// Returns an AffineMap which maps a two dimensions representing (laneId,
+/// logicalValueId) and returns two results representing offsets within a
+/// matrix operand. The offsets point to the values the thread is responsible
+/// for (AKA the matrix fragment values) during a warp-collective matrix
+/// operation. For a visual reference of this LaneId -> (row, col) mapping,
+/// please see NVIDIA's PTX documentation:
+/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
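+///
+/// For example (assuming the 16x16xf16 A operand of an m16n8k16 `mma.sync`),
+/// the returned map sends (laneId = 5, logicalValueId = 0) to the operand
+/// coordinate (5 floordiv 4, (5 mod 4) * 2) = (1, 2).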
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+ const WarpMatrixInfo &fragmentType);
+
+struct LdMatrixParams {
+ VectorType fragmentType;
+ bool isAccum;
+ int64_t numTiles;
+ IteratorType contiguousDimType;
+ NVVM::MMALayout targetLayout;
+};
+
+FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+ bool transpose);
+/// Returns an AffineMap which maps a single dimension representing the laneId
+/// to two results representing offsets within the matrix operand that should
+/// be the pointer locations a thread should pass to the ldmatrix instruction.
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+ const LdMatrixParams ¶ms);
+
+// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted
+// to MMA matmul.
+struct PrepareContractToGPUMMASync
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+} // namespace nvgpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#include <type_traits>
+#include "NvGpuSupport.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
+/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
+/// AffineMap representing offsets to apply to indices, the function fills
+/// `indices` with the original indices plus the offsets. The offsets are
+/// applied by taking into account the permutation map of the transfer op. If
+/// the `offsetMap` has dimension placeholders, those should be provided in
+/// `dimValues`.
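+///
+/// As an illustrative sketch: for a transfer op with indices (%i, %j), a
+/// permutation map (d0, d1) -> (d1, d0), and an `offsetMap` producing
+/// (r, c), the resulting `indices` are (%i + c, %j + r).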
+template <typename TransferOpType>
+static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+ AffineMap offsetMap, ArrayRef<Value> dimValues,
+ SmallVector<Value, 4> &indices) {
+ indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
+ Location loc = xferOp.getLoc();
+ unsigned offsetsIdx = 0;
+ for (auto expr : xferOp.getPermutationMap().getResults()) {
+ if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
+ Value prevIdx = indices[dim.getPosition()];
+ SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
+ dims.push_back(prevIdx);
+ AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
+ indices[dim.getPosition()] = makeComposedAffineApply(
+ b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
+ continue;
+ }
+ }
+}
+
// Return true if the contract op can be convert to MMA matmul.
-static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
+static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
+ bool useNvGpu) {
if (llvm::size(contract.getMasks()) != 0)
return false;
// The contract needs to represent a matmul to be able to convert to
// MMAMatrix matmul.
- if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+ if (!useNvGpu &&
+ contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+ return false;
+ if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
return false;
return true;
if (!memrefType)
return false;
// If the memref is 0 or 1D the horizontal stride is 0.
- if(memrefType.getRank() < 2)
+ if (memrefType.getRank() < 2)
return 0;
int64_t offset = 0;
SmallVector<int64_t, 2> strides;
}
// Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
+ bool useNvGpu) {
if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
readOp.getVectorType().getRank() != 2)
return false;
AffineExpr zero = b.getAffineConstantExpr(0);
auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
readOp.getContext());
- // TODO: Support transpose once it is added to GPU dialect ops.
- // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
- return !(!map.isMinorIdentity() && map != broadcastInnerDim);
+
+ if (!useNvGpu) {
+ // TODO: Support transpose once it is added to GPU dialect ops.
+ // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+ return map.isMinorIdentity() || map == broadcastInnerDim;
+ }
+
+ return true;
}
// Return true if the transfer op can be converted to a MMA matrix store.
return convertElementwiseOpToMMA(op).hasValue();
}
-static bool supportsMMaMatrixType(Operation *op) {
+static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
- return transferReadSupportsMMAMatrixType(transferRead);
+ return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto contract = dyn_cast<vector::ContractionOp>(op))
- return contractSupportsMMAMatrixType(contract);
+ return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
-static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
+static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
+ bool useNvGpu) {
auto hasVectorDest = [](Operation *op) {
return llvm::any_of(op->getResultTypes(),
[](Type t) { return t.isa<VectorType>(); });
// If any instruction cannot use MMA matrix type drop the whole
// chain. MMA matrix are stored in an opaque type so they cannot be used
// by all operations.
- if (llvm::any_of(dependentOps,
- [](Operation *op) { return !supportsMMaMatrixType(op); }))
+ if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
+ return !supportsMMaMatrixType(op, useNvGpu);
+ }))
return;
opToConvert.insert(dependentOps.begin(), dependentOps.end());
});
static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
- assert(transferReadSupportsMMAMatrixType(op));
+ assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType());
AffineMap map = op.getPermutationMap();
op.erase();
}
+/// Returns the vector type which represents a matrix fragment.
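+/// For example, the 16x16xf16 A fragment of an m16n8k16 `mma.sync` is held as
+/// 4 registers of 2 halves each, i.e. vector<4x2xf16>.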
+static VectorType
+getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) {
+ SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
+ regInfo.elementsPerRegister};
+ Type elType = regInfo.registerLLVMType;
+ if (auto vecType = elType.dyn_cast<VectorType>())
+ elType = vecType.getElementType();
+ return VectorType::get(shape, elType);
+}
+
+/// Convert a 2D splat ConstantOp to a splat `arith.constant` of the
+/// distributed `nvgpu.mma.sync` fragment vector type.
+static LogicalResult
+convertConstantOpMmaSync(arith::ConstantOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder b(op);
+ FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+ nvgpu::getWarpMatrixInfo(op);
+ if (failed(warpMatrixInfo))
+ return failure();
+
+ FailureOr<nvgpu::FragmentElementInfo> regInfo =
+ nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+ if (failed(regInfo))
+ return failure();
+
+ VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+ auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
+ if (!dense)
+ return failure();
+ Value result = b.create<arith::ConstantOp>(
+ op.getLoc(), vectorType,
+ DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
+ valueMapping[op.getResult()] = result;
+ return success();
+}
+
+static LogicalResult
+createLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ Location loc = op->getLoc();
+
+ FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+ nvgpu::getWarpMatrixInfo(op);
+ if (failed(warpMatrixInfo))
+ return failure();
+
+ FailureOr<nvgpu::FragmentElementInfo> regInfo =
+ nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+ if (failed(regInfo))
+ return failure();
+
+ FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
+ *warpMatrixInfo,
+ /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+ if (failed(params)) {
+ return op->emitError()
+           << "failed to convert vector.transfer_read to ldmatrix; this op "
+              "likely should not be converted to a nvgpu.ldmatrix call.";
+ }
+
+ // Adjust the load offset.
+ auto laneId = builder.create<gpu::LaneIdOp>(loc);
+ FailureOr<AffineMap> offsets =
+ nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
+ if (failed(offsets))
+ return failure();
+
+ VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+
+ SmallVector<Value, 4> indices;
+ getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+ indices);
+ nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+ loc, vectorType, op.getSource(), indices,
+ !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+ valueMapping[op] = newOp->getResult(0);
+ return success();
+}
+
+static LogicalResult
+createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ Location loc = op.getLoc();
+ FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+ nvgpu::getWarpMatrixInfo(op);
+ if (failed(warpMatrixInfo))
+ return failure();
+ FailureOr<nvgpu::FragmentElementInfo> regInfo =
+ nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+ if (failed(regInfo)) {
+    op->emitError() << "failed to deduce register fragment type during "
+ "conversion to distributed non-ldmatrix compatible load";
+ return failure();
+ }
+
+ NVVM::MMALayout targetLayout =
+ warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
+ ? NVVM::MMALayout::col
+ : NVVM::MMALayout::row;
+
+ Value laneId = builder.create<gpu::LaneIdOp>(loc);
+ SmallVector<Value, 4> elements;
+
+ // This is the individual element type.
+ Type loadedElType = regInfo->registerLLVMType;
+ VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+
+ Value fill = builder.create<arith::ConstantOp>(
+ op.getLoc(), vectorType.getElementType(),
+ builder.getZeroAttr(vectorType.getElementType()));
+ Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+
+ bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
+
+ // Vectorized loads.
+ if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
+ if (!loadedElType.isa<VectorType>()) {
+ loadedElType = VectorType::get({1}, loadedElType);
+ }
+
+ for (int i = 0; i < vectorType.getShape()[0]; i++) {
+ FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+ op.getLoc(), builder, *warpMatrixInfo);
+ if (failed(coords))
+ return failure();
+ Value logicalValueId = builder.create<arith::ConstantOp>(
+ loc, builder.getIndexType(),
+ builder.getIndexAttr(i * regInfo->elementsPerRegister));
+ SmallVector<Value, 4> newIndices;
+ getXferIndices<vector::TransferReadOp>(
+ builder, op, *coords, {laneId, logicalValueId}, newIndices);
+
+ Value el = builder.create<vector::LoadOp>(loc, loadedElType,
+ op.getSource(), newIndices);
+ result = builder.create<vector::InsertOp>(loc, el, result,
+ builder.getI64ArrayAttr(i));
+ }
+ } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
+ if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
+ loadedElType = vecType.getElementType();
+ }
+ // Load each element individually.
+ for (int i = 0; i < vectorType.getShape()[0]; i++) {
+ for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
+ innerIdx++) {
+
+ Value logicalValueId = builder.create<arith::ConstantOp>(
+ loc, builder.getIndexType(),
+ builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+ FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+ op.getLoc(), builder, *warpMatrixInfo);
+ if (failed(coords))
+ return failure();
+
+ SmallVector<Value, 4> newIndices;
+ getXferIndices<vector::TransferReadOp>(
+ builder, op, *coords, {laneId, logicalValueId}, newIndices);
+ Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
+ op.getSource(), newIndices);
+ result = builder.create<vector::InsertOp>(
+ op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
+ }
+ }
+ } else {
+ return failure();
+ }
+
+ valueMapping[op.getResult()] = result;
+ return success();
+}
+
+/// Converts a `vector.transfer_read` operation directly to either a
+/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
+/// used when converting to `nvgpu.mma.sync` operations.
+static LogicalResult
+convertTransferReadToLoads(vector::TransferReadOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder b(op);
+
+ FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+ nvgpu::getWarpMatrixInfo(op);
+ if (failed(warpMatrixInfo))
+ return failure();
+
+ bool isLdMatrixCompatible =
+ op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
+ nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
+
+ VectorType vecTy = op.getVectorType();
+ int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
+
+ // When we are transposing the B operand, ldmatrix will only work if we have
+ // at least 8 rows to read and the width to read for the transpose is 128
+ // bits.
+ if (!op.getPermutationMap().isMinorIdentity() &&
+ (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128))
+ isLdMatrixCompatible = false;
+
+ if (!isLdMatrixCompatible)
+ return createNonLdMatrixLoads(op, b, valueMapping);
+
+  return createLdMatrixCompatibleLoads(op, b, valueMapping);
+}
+
+static LogicalResult
+convertTransferWriteToStores(vector::TransferWriteOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder b(op);
+ Location loc = op->getLoc();
+ Value matrix = valueMapping.find(op.getVector())->second;
+
+ FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+ nvgpu::getWarpMatrixInfo(op);
+ if (failed(warpMatrixInfo))
+ return failure();
+ FailureOr<nvgpu::FragmentElementInfo> regInfo =
+ nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+ if (failed(regInfo))
+ return failure();
+
+ VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+ Value laneId = b.create<gpu::LaneIdOp>(loc);
+
+ for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
+ Value logicalValueId = b.create<arith::ConstantOp>(
+ loc, b.getIndexType(),
+ b.getIndexAttr(i * regInfo->elementsPerRegister));
+ FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+ op.getLoc(), b, *warpMatrixInfo);
+ if (failed(coords))
+ return failure();
+
+ Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+ SmallVector<Value, 4> newIndices;
+ getXferIndices<vector::TransferWriteOp>(
+ b, op, *coords, {laneId, logicalValueId}, newIndices);
+ b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+ }
+ op->erase();
+ return success();
+}
+
static void convertContractOp(vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op);
valueMapping[op.getResult()] = matmul;
}
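+/// Convert a `vector.contract` in the canonical (m, k) x (n, k) -> (m, n) form
+/// to an `nvgpu.mma.sync` op, taking the already-distributed fragment values
+/// from `valueMapping`. For example, contracting vector<16x32xi8> with
+/// vector<8x32xi8> yields mmaShape = [16, 8, 32].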
+static LogicalResult
+convertContractOpToMmaSync(vector::ContractionOp op,
+ llvm::DenseMap<Value, Value> &valueMapping) {
+ OpBuilder b(op);
+ Value opA = valueMapping.find(op.getLhs())->second;
+ Value opB = valueMapping.find(op.getRhs())->second;
+ Value opC = valueMapping.find(op.getAcc())->second;
+ int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
+ int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
+ int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
+ Value matmul = b.create<nvgpu::MmaSyncOp>(
+ op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+ valueMapping[op.getResult()] = matmul;
+ return success();
+}
+
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
valueMapping[op->getResult(0)] = newOp;
}
-void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
- patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
- patterns.getContext());
+void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+ bool useNvGpu) {
+ if (!useNvGpu) {
+ patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
+ patterns.getContext());
+ return;
+ }
+ patterns
+ .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
+ patterns.getContext());
}
void mlir::convertVectorToMMAOps(Operation *rootOp) {
- SetVector<Operation *> ops = getOpToConvert(rootOp);
+ SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
llvm::DenseMap<Value, Value> valueMapping;
for (Operation *op : ops) {
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
}
}
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+ SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
+ llvm::DenseMap<Value, Value> valueMapping;
+ for (Operation *op : ops) {
+ if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ .Case([&](vector::TransferReadOp transferReadOp) {
+ return convertTransferReadToLoads(transferReadOp, valueMapping);
+ })
+ .Case([&](vector::TransferWriteOp transferWriteOp) {
+ return convertTransferWriteToStores(transferWriteOp,
+ valueMapping);
+ })
+ .Case([&](vector::ContractionOp contractionOp) {
+ return convertContractOpToMmaSync(contractionOp, valueMapping);
+ })
+ .Case([&](scf::ForOp forOp) {
+ convertForOp(forOp, valueMapping);
+ return success();
+ })
+ .Case([&](scf::YieldOp yieldOp) {
+ convertYieldOp(yieldOp, valueMapping);
+ return success();
+ })
+ .Case([&](arith::ConstantOp constOp) {
+ return convertConstantOpMmaSync(constOp, valueMapping);
+ })
+ .Default([&](Operation *op) {
+ op->emitError() << "unhandled vector to mma type: " << *op;
+ return failure();
+ })
+ .failed()) {
+ op->emitError() << "Failed to convert op " << *op;
+ return failure();
+ }
+ }
+ return success();
+}
+
namespace {
struct ConvertVectorToGPUPass
: public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
+
+ explicit ConvertVectorToGPUPass(bool useNvGpu_) {
+ useNvGpu.setValue(useNvGpu_);
+ }
+
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populatePrepareVectorToMMAPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+
+ if (useNvGpu.getValue()) {
+ if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+ return signalPassFailure();
+ }
- convertVectorToMMAOps(getOperation());
+ (void)convertVectorToMMAOps(getOperation());
}
};
} // namespace
-std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
- return std::make_unique<ConvertVectorToGPUPass>();
+std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
+ return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s
+
+//#########################################################
+// INT8 row-row-row
+//#########################################################
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
+
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)>
+// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)>
+// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)>
+// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)>
+// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)>
+// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)>
+// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)>
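+
+// Note: each rowB map above simplifies to (s0 mod 4) * 4 plus a constant: a
+// given lane covers two groups of 4 consecutive K positions of the transposed
+// B read (at +0..3 and +16..19 past its per-lane base), while colB0_map picks
+// that lane's N position, laneId floordiv 4 + 40.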
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)>
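+
+// Note: the C maps match the PTX accumulator layout: each lane owns two
+// consecutive i32 values at column (s0 mod 4) * 2 + 40 in rows
+// laneId floordiv 4 + 49 and laneId floordiv 4 + 57.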
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k32_int8_row_row_row
+func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
+ %cst_0 = arith.constant dense<0> : vector<32x8xi8>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c17 = arith.constant 17 : index
+ %c39 = arith.constant 39 : index
+ %c40 = arith.constant 40 : index
+ %c49 = arith.constant 49 : index
+ %c50 = arith.constant 50 : index
+ %cst = arith.constant 0 : i8
+ %cst0 = arith.constant 0 : i32
+
+ // Verify that the operand A is distributed to loads correctly.
+
+ // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
+ // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+
+  // Verify that the operand B is distributed to loads correctly. Its elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+ // CHECK-NOT: memref.load %arg1
+
+ // Verify that the operand C is distributed to loads correctly.
+ // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK-NOT: vector.load %arg2{{.*}}
+
+ %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
+ %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
+
+ // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+ return
+}
+
+// -----
+
+//#########################################################
+// f64 row-row-row
+//#########################################################
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)>
+
+// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)>
+// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m8n8k4_f64_row_row_row
+func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) {
+ %cst_0 = arith.constant dense<0.0> : vector<4x8xf64>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c17 = arith.constant 17 : index
+ %c39 = arith.constant 39 : index
+ %c40 = arith.constant 40 : index
+ %c49 = arith.constant 49 : index
+ %c50 = arith.constant 50 : index
+ %cst = arith.constant 0.0 : f64
+ %cst0 = arith.constant 0.0 : f64
+
+ // Verify that the operand A is distributed to loads correctly.
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+ // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64>
+
+  // Verify that the operand B is distributed to loads correctly. Its elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+
+ %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
+ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
+ %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
+ // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+ vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64>
+ return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-row-row
+//#########################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_row_row
+func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+ // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+ return
+}
+
+// -----
+
+// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row
+func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
+ // CHECK: [[C0:%.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+ // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+ %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]
+ // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+ %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+ // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+ %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+ return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-col-row
+//#########################################################
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
+// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_col_row
+func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+ // CHECK-SAME: transpose = false
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+ // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+ // CHECK-SAME: transpose = false
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+ // CHECK-SAME: transpose = false
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+ %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+ return
+}
+
+// -----
+
+//#########################################################
+// TF32 (multiplicand) F32 (accumulator) row-row-row
+//#########################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
+
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+
+// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row
+func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f32
+
+ // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+ // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
+
+ // b and c are not loaded by ldmatrix in this test.
+ // CHECK-NOT: nvgpu.ldmatrix
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+ // CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+ // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
+
+ // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+ // CHECK-SAME: mmaShape = [16, 8, 4]
+ // CHECK-SAME: -> vector<2x2xf32>
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32>
+ %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32>
+
+ // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+ // CHECK: affine.apply [[$rowC_map]]
+ // CHECK: affine.apply [[$colC_map]]
+ // CHECK: vector.store
+ // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+ // CHECK: affine.apply [[$rowC8_map]]
+ // CHECK: affine.apply [[$colC_map]]
+ // CHECK: vector.store
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+ return
+}