From 1ac874c9aa1859fe67fad110c278588a5a670d78 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Wed, 15 Feb 2023 06:00:12 +0000 Subject: [PATCH] [mlir][Vector] Add support for masked vector gather ops This patch adds support for masked vector.gather ops using the vector.mask representation. It includes the implementation of the MaskableOpInterface, Linalg vectorizer support and lowering to LLVM. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D143939 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- .../Dialect/Linalg/Transforms/Vectorization.cpp | 58 ++++++++++++---------- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++ .../Dialect/Vector/Transforms/LowerVectorMask.cpp | 28 ++++++++++- mlir/test/Dialect/Vector/lower-vector-mask.mlir | 29 +++++++++++ 5 files changed, 97 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6f6d80c..94d4b64 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1846,7 +1846,7 @@ def Vector_MaskedStoreOp : } def Vector_GatherOp : - Vector_Op<"gather">, + Vector_Op<"gather", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOf<[AnyInteger, Index]>:$index_vec, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 5a20d23..fc36477 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -416,8 +416,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, vector::BroadcastableToResult::Success) return value; Location loc = b.getInsertionPoint()->getLoc(); - return b.createOrFold(loc, targetVectorType, - value); + return b.createOrFold(loc, targetVectorType, value); } /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This @@ -532,14 +531,16 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult -vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) { +static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, + VectorizationState &state, + Operation *op, + LinalgOp linalgOp) { IndexOp indexOp = dyn_cast(op); if (!indexOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; auto loc = indexOp.getLoc(); // Compute the static loop sizes of the index op. - auto targetShape = linalgOp.computeStaticLoopSizes(); + auto targetShape = llvm::to_vector(state.getCanonicalVecShape()); // Compute a one-dimensional index vector for the index op dimension. SmallVector constantSeq = llvm::to_vector<16>(llvm::seq(0, targetShape[indexOp.getDim()])); @@ -597,32 +598,33 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) { /// /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 -static Value -calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp, - const IRMapping &bvm, - const SmallVectorImpl &targetShape) { +static Value calculateGatherOffset(RewriterBase &rewriter, + tensor::ExtractOp extractOp, + const IRMapping &bvm, + const ArrayRef targetShape) { // The vector of indices for GatherOp should be shaped as the output vector - auto indexVecType = VectorType::get(targetShape, b.getIndexType()); + auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType()); auto loc = extractOp.getLoc(); - Value offset = b.create( - loc, indexVecType, bvm.lookup(extractOp.getIndices()[0])); + Value offset = broadcastIfNeeded( + rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape()); const size_t numIndices = extractOp.getIndices().size(); for (size_t i = 1; i < numIndices; i++) { auto dimSize = broadcastIfNeeded( - b, - b.create( + rewriter, + rewriter.create( loc, extractOp.getTensor().getType().cast().getDimSize(i)), indexVecType.getShape()); - offset = b.create(loc, offset, dimSize); + offset = rewriter.create(loc, offset, dimSize); - auto extractOpIndex = broadcastIfNeeded( - b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape()); + auto extractOpIndex = + broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]), + indexVecType.getShape()); - offset = b.create(loc, extractOpIndex, offset); + offset = rewriter.create(loc, extractOpIndex, offset); } return offset; @@ -632,17 +634,16 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp, /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, - Operation *op, - LinalgOp linalgOp, - const IRMapping &bvm) { +static VectorizationResult +vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, + Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { tensor::ExtractOp extractOp = dyn_cast(op); if (!extractOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; auto loc = extractOp.getLoc(); // Compute the static loop sizes of the extract op. - auto targetShape = linalgOp.computeStaticLoopSizes(); + auto targetShape = state.getCanonicalVecShape(); auto resultType = VectorType::get(targetShape, extractOp.getResult().getType()); @@ -662,9 +663,10 @@ static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape); // Generate the gather load - auto gatherOp = rewriter.create( + Operation *gatherOp = rewriter.create( loc, resultType, extractOp.getTensor(), baseIndices, offset, maskConstantOp, passThruConstantOp); + gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; } @@ -904,14 +906,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 4b. Register CustomVectorizationHook for indexOp. CustomVectorizationHook vectorizeIndex = [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { - return vectorizeLinalgIndex(rewriter, op, linalgOp); + return vectorizeLinalgIndex(rewriter, state, op, linalgOp); }; hooks.push_back(vectorizeIndex); // 4c. Register CustomVectorizationHook for extractOp. CustomVectorizationHook vectorizeExtract = [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { - return vectorizeTensorExtract(rewriter, op, linalgOp, bvm); + return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); }; hooks.push_back(vectorizeExtract); @@ -1007,8 +1009,10 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp, return failure(); if (linalgOp.hasDynamicShape() && - failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) + failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) { + LDBG("Dynamically-shaped op failed vectorization pre-conditions\n"); return failure(); + } SmallVector customPreconditions; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3e145f1..64125ef 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4597,6 +4597,16 @@ LogicalResult GatherOp::verify() { return success(); } +// MaskableOpInterface methods. + +/// Returns the mask type expected by this operation. Mostly used for +/// verification purposes. It requires the operation to be vectorized." +Type GatherOp::getExpectedMaskType() { + auto vecType = this->getIndexVectorType(); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1)); +} + namespace { class GatherFolder final : public OpRewritePattern { public: diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index eaba097..7c66e65 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" @@ -109,6 +110,29 @@ public: } }; +/// Lowers a masked `vector.gather` operation. +struct MaskedGatherOpPattern : public MaskOpRewritePattern { +public: + using MaskOpRewritePattern::MaskOpRewritePattern; + + LogicalResult + matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + Value passthru = maskingOp.hasPassthru() + ? maskingOp.getPassthru() + : rewriter.create( + gatherOp.getLoc(), + rewriter.getZeroAttr(gatherOp.getVectorType())); + + // Replace the `vector.mask` operation. + rewriter.replaceOpWithNewOp( + maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), + gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), + passthru); + return success(); + } +}; + struct LowerVectorMaskPass : public vector::impl::LowerVectorMaskPassBase { using Base::Base; @@ -136,8 +160,8 @@ struct LowerVectorMaskPass /// not its nested `MaskableOpInterface`. void vector::populateVectorMaskLoweringPatternsForSideEffectingOps( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns.add(patterns.getContext()); } std::unique_ptr mlir::vector::createLowerVectorMaskPass() { diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir index 360e35d..8f8fae0 100644 --- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir +++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir @@ -48,3 +48,32 @@ func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor // CHECK: } +// ----- + +func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %c3 = arith.constant 3 : index + %0 = vector.create_mask %c3 : vector<4xi1> + %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> + %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + %cst_1 = arith.constant dense : vector<4xi1> + %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32> + %c0_3 = arith.constant 0 : index + %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32> + %c0_4 = arith.constant 0 : index + %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32> + return %3 : tensor<3xf32> +} + +// CHECK-LABEL: func.func @vector_gather( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1> +// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> + -- 2.7.4