From edd9515bd125634f40ebc2e783d6a127345e7c0d Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Fri, 11 Jun 2021 07:39:01 -0700 Subject: [PATCH] [mlir][VectorToGPU] First step to convert vector ops to GPU MMA ops This is the first step to convert vector ops to MMA operations in order to target GPUs tensor core ops. This currently only support simple cases, transpose and element-wise operation will be added later. Differential Revision: https://reviews.llvm.org/D102962 --- mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 14 + .../mlir/Conversion/VectorToGPU/VectorToGPU.h | 34 +++ mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Conversion/VectorToGPU/CMakeLists.txt | 15 + mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 338 +++++++++++++++++++++ .../Conversion/VectorToGPU/vector-to-mma-ops.mlir | 43 +++ 7 files changed, 446 insertions(+) create mode 100644 mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h create mode 100644 mlir/lib/Conversion/VectorToGPU/CMakeLists.txt create mode 100644 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp create mode 100644 mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index a78b728..0cac29f 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -38,6 +38,7 @@ #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Conversion/TosaToSCF/TosaToSCF.h" #include "mlir/Conversion/TosaToStandard/TosaToStandard.h" +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index ba5e27a..47f328b 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -514,6 +514,20 @@ def TosaToStandard : Pass<"tosa-to-standard"> { } //===----------------------------------------------------------------------===// +// VectorToGPU +//===----------------------------------------------------------------------===// + +def ConvertVectorToGPU : FunctionPass<"convert-vector-to-gpu"> { + let summary = "Lower the operations from the vector dialect into the GPU " + "dialect"; + let constructor = "mlir::createConvertVectorToGPUPass()"; + let dependentDialects = [ + "memref::MemRefDialect", + "gpu::GPUDialect" + ]; +} + +//===----------------------------------------------------------------------===// // VectorToSCF //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h new file mode 100644 index 0000000..5f6f7aa --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h @@ -0,0 +1,34 @@ +//===- VectorToGPU.h - Convert vector to GPU dialect ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSCF_VECTORTOGPU_H_ +#define MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSCF_VECTORTOGPU_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +class MLIRContext; +class Pass; +class FuncOp; +class RewritePatternSet; + +/// Patterns to transform vector ops into a canonical form to convert to MMA +/// matrix operations. +void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns); + +/// Convert vector ops to MMA matrix operations. This will convert slice of +/// operations that can be legally converted to MMA operations. The rest of the +/// vector operations are left untouched. +void convertVectorToMMAOps(FuncOp funcOp); + +/// Convert from vector to GPU ops. +std::unique_ptr createConvertVectorToGPUPass(); + +} // namespace mlir + +#endif // MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSCF_VECTORTOGPU_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 72cfb08..66b3895 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -29,5 +29,6 @@ add_subdirectory(TosaToSCF) add_subdirectory(TosaToStandard) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) +add_subdirectory(VectorToGPU) add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt new file mode 100644 index 0000000..484ad54 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_conversion_library(MLIRVectorToGPU + VectorToGPU.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRGPU + MLIRLLVMIR + MLIRMemRef + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp new file mode 100644 index 0000000..227890b --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -0,0 +1,338 @@ +//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering of vector operations to GPU dialect ops. +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" + +#include "../PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +// Return true if the contract op can be convert to MMA matmul. +static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { + if (llvm::size(contract.masks()) != 0) + return false; + + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m, n, k; + bindDims(contract.getContext(), m, n, k); + auto iteratorTypes = contract.iterator_types().getValue(); + if (!(isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2]))) + return false; + + // The contract needs to represent a matmul to be able to convert to + // MMAMatrix matmul. + if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + return false; + + // Check that the size matches what is natively supported. + VectorType lhsType = contract.lhs().getType().cast(); + VectorType rhsType = contract.rhs().getType().cast(); + VectorType accType = contract.acc().getType().cast(); + + std::tuple dim(lhsType.getDimSize(0), rhsType.getDimSize(1), + lhsType.getDimSize(1)); + if (lhsType.getElementType().isInteger(8) && + rhsType.getElementType().isInteger(8) && + accType.getElementType().isInteger(32) && + (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) || + dim == std::make_tuple(16, 8, 32))) + return true; + + if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() && + (accType.getElementType().isF16() || accType.getElementType().isF32()) && + (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) || + dim == std::make_tuple(16, 8, 16))) + return true; + return false; +} + +// Return the stide for the dimension 0 of |type| if it is a memref and has a +// constant stride. +static llvm::Optional +getMemrefConstantHorizontalStride(ShapedType type) { + auto memrefType = type.dyn_cast(); + if (!memrefType) + return false; + int64_t offset = 0; + SmallVector strides; + if (failed(getStridesAndOffset(memrefType, strides, offset))) + return llvm::None; + if (strides[0] == ShapedType::kDynamicStrideOrOffset) + return llvm::None; + return strides[0]; +} + +// Return true if the transfer op can be converted to a MMA matrix load. +static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { + if (readOp.mask() || readOp.hasOutOfBoundsDim() || + readOp.getVectorType().getRank() != 2) + return false; + if (!getMemrefConstantHorizontalStride(readOp.getShapedType())) + return false; + // TODO: Support transpose once it is added to GPU dialect ops. + if (!readOp.permutation_map().isMinorIdentity()) + return false; + return true; +} + +// Return true if the transfer op can be converted to a MMA matrix store. +static bool +transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { + if (writeOp.mask() || writeOp.hasOutOfBoundsDim() || + writeOp.getVectorType().getRank() != 2) + return false; + if (!getMemrefConstantHorizontalStride(writeOp.getShapedType())) + return false; + // TODO: Support transpose once it is added to GPU dialect ops. + if (!writeOp.permutation_map().isMinorIdentity()) + return false; + return true; +} + +static bool supportsMMaMatrixType(Operation *op) { + if (auto transferRead = dyn_cast(op)) + return transferReadSupportsMMAMatrixType(transferRead); + if (auto transferWrite = dyn_cast(op)) + return transferWriteSupportsMMAMatrixType(transferWrite); + if (auto contract = dyn_cast(op)) + return contractSupportsMMAMatrixType(contract); + return false; +} + +// Analyze slice of operations based on convert op to figure out if the whole +// slice can be converted to MMA operations. +static SetVector getOpToConvert(mlir::Operation *op) { + auto hasVectorDest = [](Operation *op) { + return op->getNumResults() == 0 || + llvm::any_of(op->getResultTypes(), + [](Type t) { return t.isa(); }); + }; + SetVector opToConvert; + op->walk([&](vector::ContractionOp contract) { + if (opToConvert.contains(contract.getOperation())) + return; + SetVector dependentOps = + getSlice(contract, hasVectorDest, hasVectorDest); + // If any instruction cannot use MMA matrix type drop the whole + // chaine. MMA matrix are stored in an opaque type so they cannot be used + // by all operations. + if (llvm::any_of(dependentOps, + [](Operation *op) { return !supportsMMaMatrixType(op); })) + return; + opToConvert.insert(dependentOps.begin(), dependentOps.end()); + }); + return opToConvert; +} + +namespace { +// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted +// to MMA matmul. +struct PrepareContractToGPUMMA + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc(); + + // Set up the parallel/reduction structure in right form. + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + static constexpr std::array perm = {1, 0}; + auto iteratorTypes = op.iterator_types().getValue(); + SmallVector maps = op.getIndexingMaps(); + if (!(isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2]))) + return failure(); + // + // Two outer parallel, one inner reduction (matmat flavor). + // + if (maps == infer({{m, k}, {k, n}, {m, n}})) { + // This is the classical row-major matmul, nothing to do. + return failure(); + } + if (maps == infer({{m, k}, {n, k}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { + std::swap(lhs, rhs); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { + std::swap(lhs, rhs); + } else { + return failure(); + } + rewriter.replaceOpWithNewOp( + op, lhs, rhs, res, + rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), + op.iterator_types()); + return success(); + } +}; + +// Merge transpose op into the transfer read op. Transpose are not supported on +// MMA types but MMA load can transpose the matrix when loading. +struct CombineTransferReadOpTranspose final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto transferReadOp = op.vector().getDefiningOp(); + if (!transferReadOp) + return failure(); + if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim()) + return failure(); + SmallVector perm; + op.getTransp(perm); + SmallVector permU; + for (int64_t o : perm) + permU.push_back(unsigned(o)); + AffineMap permutationMap = + AffineMap::getPermutationMap(permU, op.getContext()); + AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map()); + rewriter.replaceOpWithNewOp( + op, op.getType(), transferReadOp.source(), transferReadOp.indices(), + newMap, transferReadOp.padding(), transferReadOp.mask(), + transferReadOp.in_boundsAttr()); + return success(); + } +}; + +} // namespace + +// MMA types have different layout based on how they are used in matmul ops. +// Figure the right layout to use by looking at Transfer op uses. +// TODO: Change the GPU dialect to abstract the layout at the this level and +// only care about it during lowering to NVVM. +static const char *inferFragType(vector::TransferReadOp op) { + for (Operation *users : op->getUsers()) { + auto contract = dyn_cast(users); + if (!contract) + continue; + if (contract.lhs() == op.getResult()) + return "AOp"; + if (contract.rhs() == op.getResult()) + return "BOp"; + } + return "COp"; +} + +static void convertTransferReadOp(vector::TransferReadOp op, + llvm::DenseMap &valueMapping) { + assert(transferReadSupportsMMAMatrixType(op)); + Optional stride = + getMemrefConstantHorizontalStride(op.getShapedType()); + assert(stride); + const char *fragType = inferFragType(op); + gpu::MMAMatrixType type = + gpu::MMAMatrixType::get(op.getVectorType().getShape(), + op.getVectorType().getElementType(), fragType); + OpBuilder b(op); + Value load = b.create( + op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride)); + valueMapping[op.getResult()] = load; +} + +static void convertTransferWriteOp(vector::TransferWriteOp op, + llvm::DenseMap &valueMapping) { + assert(transferWriteSupportsMMAMatrixType(op)); + Optional stride = + getMemrefConstantHorizontalStride(op.getShapedType()); + assert(stride); + OpBuilder b(op); + Value matrix = valueMapping.find(op.vector())->second; + b.create( + op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride)); + op.erase(); +} + +static void convertContractOp(vector::ContractionOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Value opA = valueMapping.find(op.lhs())->second; + Value opB = valueMapping.find(op.rhs())->second; + Value opC = valueMapping.find(op.acc())->second; + Value matmul = b.create(op.getLoc(), opC.getType(), + opA, opB, opC); + valueMapping[op.getResult()] = matmul; +} + +namespace mlir { + +void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void convertVectorToMMAOps(FuncOp funcOp) { + SetVector ops = getOpToConvert(funcOp); + llvm::DenseMap valueMapping; + for (Operation *op : ops) { + if (auto transferRead = dyn_cast(op)) { + convertTransferReadOp(transferRead, valueMapping); + } else if (auto transferWrite = dyn_cast(op)) { + convertTransferWriteOp(transferWrite, valueMapping); + } else if (auto contractOp = dyn_cast(op)) { + convertContractOp(contractOp, valueMapping); + } + } +} + +} // namespace mlir +namespace { + +struct ConvertVectorToGPUPass + : public ConvertVectorToGPUBase { + void runOnFunction() override { + RewritePatternSet patterns(getFunction().getContext()); + populatePrepareVectorToMMAPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + + convertVectorToMMAOps(getFunction()); + } +}; + +} // namespace + +std::unique_ptr mlir::createConvertVectorToGPUPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir new file mode 100644 index 0000000..5005cc6 --- /dev/null +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt %s -convert-vector-to-gpu -canonicalize | FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @matmul +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> +// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16> +func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) { + %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16> + %c0 = constant 0 : index + %cst = constant 0.000000e+00 : f16 + %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> + return +} + +// Negative test until scf.for support is added. +// CHECK-LABEL: func @matmul_loop +// CHECK: vector.contract +func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) { + %c0 = constant 0 : index + %c128 = constant 128 : index + %c32 = constant 32 : index + %cst = constant 0.000000e+00 : f16 + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16> + %14 = scf.for %arg17 = %c0 to %c128 step %c32 iter_args(%arg18 = %C) -> (vector<16x16xf16>) { + %17 = vector.transfer_read %arg0[%c0, %arg17], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16> + %18 = vector.transfer_read %arg1[%arg17, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16> + %19 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %17, %18, %arg18 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + scf.yield %19 : vector<16x16xf16> + } + vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16> + return +} -- 2.7.4