[mlir][Vector] Add support for masked vector gather ops
author    Diego Caballero <diegocaballero@google.com>
          Wed, 15 Feb 2023 06:00:12 +0000
committer Diego Caballero <diegocaballero@google.com>
          Wed, 15 Feb 2023 06:10:22 +0000
This patch adds support for masked vector.gather ops using the
vector.mask representation. It includes the implementation of the
MaskableOpInterface for vector.gather, Linalg vectorizer support, and
the lowering of the masked form, which folds the vector.mask into the
gather op's native mask operand.
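
For illustration, a minimal sketch of the masked representation and its
lowering, modeled on the test case below (SSA names are illustrative):

  // Masked form: the vectorizer wraps the gather in a vector.mask region.
  %mask = vector.create_mask %c3 : vector<4xi1>
  %0 = vector.mask %mask { vector.gather %base[%c0] [%idxs], %all_true, %zeroes : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>

  // Lowered form: the mask (and a zero pass-through, when none is provided)
  // becomes the gather's native mask and pass_thru operands.
  %0 = vector.gather %base[%c0] [%idxs], %mask, %zeroes : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>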

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143939

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/test/Dialect/Vector/lower-vector-mask.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6f6d80c..94d4b64 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1846,7 +1846,7 @@ def Vector_MaskedStoreOp :
 }
 
 def Vector_GatherOp :
-  Vector_Op<"gather">,
+  Vector_Op<"gather", [DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOf<[AnyInteger, Index]>:$index_vec,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5a20d23..fc36477 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -416,8 +416,7 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                                    value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -532,14 +531,16 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult
-vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+                                                VectorizationState &state,
+                                                Operation *op,
+                                                LinalgOp linalgOp) {
   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
   if (!indexOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
@@ -597,32 +598,33 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
 ///
 /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
 ///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
-static Value
-calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
-                      const IRMapping &bvm,
-                      const SmallVectorImpl<int64_t> &targetShape) {
+static Value calculateGatherOffset(RewriterBase &rewriter,
+                                   tensor::ExtractOp extractOp,
+                                   const IRMapping &bvm,
+                                   const ArrayRef<int64_t> targetShape) {
   // The vector of indices for GatherOp should be shaped as the output vector
-  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
   auto loc = extractOp.getLoc();
 
-  Value offset = b.create<vector::BroadcastOp>(
-      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+  Value offset = broadcastIfNeeded(
+      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
     auto dimSize = broadcastIfNeeded(
-        b,
-        b.create<arith::ConstantIndexOp>(
+        rewriter,
+        rewriter.create<arith::ConstantIndexOp>(
             loc,
             extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
         indexVecType.getShape());
 
-    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
 
-    auto extractOpIndex = broadcastIfNeeded(
-        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
+    auto extractOpIndex =
+        broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
+                          indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
+    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
@@ -632,17 +634,16 @@ calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
 /// CustomVectorizationHook.
-static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
-                                                  Operation *op,
-                                                  LinalgOp linalgOp,
-                                                  const IRMapping &bvm) {
+static VectorizationResult
+vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
+                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
   // Compute the static loop sizes of the extract op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  auto targetShape = state.getCanonicalVecShape();
 
   auto resultType =
       VectorType::get(targetShape, extractOp.getResult().getType());
@@ -662,9 +663,10 @@ static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
   Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
   // Generate the gather load
-  auto gatherOp = rewriter.create<vector::GatherOp>(
+  Operation *gatherOp = rewriter.create<vector::GatherOp>(
       loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);
+  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
@@ -904,14 +906,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgIndex(rewriter, op, linalgOp);
+    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
   };
   hooks.push_back(vectorizeIndex);
 
   // 4c. Register CustomVectorizationHook for extractOp.
   CustomVectorizationHook vectorizeExtract =
       [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
-    return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
+    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
   };
   hooks.push_back(vectorizeExtract);
 
@@ -1007,8 +1009,10 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
     return failure();
 
   if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
+  }
 
   SmallVector<CustomVectorizationPrecondition> customPreconditions;
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e145f1..64125ef 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4597,6 +4597,16 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type GatherOp::getExpectedMaskType() {
+  auto vecType = this->getIndexVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
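
As an aside (not part of the patch), the mask type reported by
getExpectedMaskType follows the shape of the gather's index vector, so a
hypothetical 2-D gather would be masked as follows (assumed example,
illustrative names):

  // A vector<4x8xindex> index vector implies a vector<4x8xi1> mask on
  // the enclosing vector.mask op.
  %r = vector.mask %m { vector.gather %base[%c0, %c0] [%idxs], %all_true, %pass : memref<16x16xf32>, vector<4x8xindex>, vector<4x8xi1>, vector<4x8xf32> into vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>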
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index eaba097..7c66e65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
@@ -109,6 +110,29 @@ public:
   }
 };
 
+/// Lowers a masked `vector.gather` operation.
+struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
+public:
+  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
+
+  LogicalResult
+  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    Value passthru = maskingOp.hasPassthru()
+                         ? maskingOp.getPassthru()
+                         : rewriter.create<arith::ConstantOp>(
+                               gatherOp.getLoc(),
+                               rewriter.getZeroAttr(gatherOp.getVectorType()));
+
+    // Replace the `vector.mask` operation.
+    rewriter.replaceOpWithNewOp<GatherOp>(
+        maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
+        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        passthru);
+    return success();
+  }
+};
+
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -136,8 +160,8 @@ struct LowerVectorMaskPass
 /// not its nested `MaskableOpInterface`.
 void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
     RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
+               MaskedGatherOpPattern>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
diff --git a/mlir/test/Dialect/Vector/lower-vector-mask.mlir b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
index 360e35d..8f8fae0 100644
--- a/mlir/test/Dialect/Vector/lower-vector-mask.mlir
+++ b/mlir/test/Dialect/Vector/lower-vector-mask.mlir
@@ -48,3 +48,32 @@ func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf
 // CHECK:           return %[[VAL_4]] : tensor<?xf32>
 // CHECK:         }
 
+// -----
+
+func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %c3 = arith.constant 3 : index
+  %0 = vector.create_mask %c3 : vector<4xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_1 = arith.constant dense<true> : vector<4xi1>
+  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %c0_4 = arith.constant 0 : index
+  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
+  return %3 : tensor<3xf32>
+}
+
+// CHECK-LABEL:   func.func @vector_gather(
+// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<64xf32>,
+// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
+// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+