vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
- return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
- value);
+ return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult
-vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
+static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
+ VectorizationState &state,
+ Operation *op,
+ LinalgOp linalgOp) {
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
if (!indexOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = indexOp.getLoc();
-  // Compute the static loop sizes of the index op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  // Use the canonical vector shape computed for this op's iteration space.
+  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
// Compute a one-dimensional index vector for the index op dimension.
SmallVector<int64_t> constantSeq =
llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
///
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 = 1233
-static Value
-calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
- const IRMapping &bvm,
- const SmallVectorImpl<int64_t> &targetShape) {
+static Value calculateGatherOffset(RewriterBase &rewriter,
+ tensor::ExtractOp extractOp,
+ const IRMapping &bvm,
+ const ArrayRef<int64_t> targetShape) {
  // The vector of indices for GatherOp should be shaped as the output vector.
- auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+ auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
auto loc = extractOp.getLoc();
- Value offset = b.create<vector::BroadcastOp>(
- loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
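+  // The mapped index may already be a vector of the target shape (e.g. when
+  // it comes from a vectorized `linalg.index`), so broadcast only if needed.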
+ Value offset = broadcastIfNeeded(
+ rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
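+    // Accumulate the offset one dimension at a time, following the formula
+    // documented above: offset = offset * dimSize(i) + index(i).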
auto dimSize = broadcastIfNeeded(
- b,
- b.create<arith::ConstantIndexOp>(
+ rewriter,
+ rewriter.create<arith::ConstantIndexOp>(
loc,
extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
indexVecType.getShape());
- offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+ offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
- auto extractOpIndex = broadcastIfNeeded(
- b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
+ auto extractOpIndex =
+ broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
+ indexVecType.getShape());
- offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
+ offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
}
return offset;
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
-static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter,
- Operation *op,
- LinalgOp linalgOp,
- const IRMapping &bvm) {
+static VectorizationResult
+vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
+ Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
-  // Compute the static loop sizes of the extract op.
-  auto targetShape = linalgOp.computeStaticLoopSizes();
+  // Use the canonical vector shape computed for this op's iteration space.
+  auto targetShape = state.getCanonicalVecShape();
auto resultType =
VectorType::get(targetShape, extractOp.getResult().getType());
Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
// Generate the gather load
- auto gatherOp = rewriter.create<vector::GatherOp>(
+ Operation *gatherOp = rewriter.create<vector::GatherOp>(
loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
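+  // Mask the gather with the canonical mask for `linalgOp` (a no-op when no
+  // masking is required), so out-of-bounds lanes are never accessed.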
+ gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
}
// 4b. Register CustomVectorizationHook for indexOp.
CustomVectorizationHook vectorizeIndex =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
- return vectorizeLinalgIndex(rewriter, op, linalgOp);
+ return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
};
hooks.push_back(vectorizeIndex);
// 4c. Register CustomVectorizationHook for extractOp.
CustomVectorizationHook vectorizeExtract =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
- return vectorizeTensorExtract(rewriter, op, linalgOp, bvm);
+ return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);
return failure();
if (linalgOp.hasDynamicShape() &&
- failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+ failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+ LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
return failure();
+ }
SmallVector<CustomVectorizationPrecondition> customPreconditions;
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
}
};
+/// Lowers a masked `vector.gather` operation.
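+/// For example (names illustrative):
+///
+///   %0 = vector.mask %mask { vector.gather %base[%c0] [%v], %all_true,
+///          %zeroes : ... } : vector<4xi1> -> vector<4xf32>
+///
+/// is rewritten into an unmasked gather that takes the mask (and the optional
+/// passthru) as explicit operands:
+///
+///   %0 = vector.gather %base[%c0] [%v], %mask, %zeroes : ...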
+struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
+public:
+ using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
+
+ LogicalResult
+ matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
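+    // `vector.mask` may carry an optional passthru; when absent, default to a
+    // zero vector for the masked-off lanes of the gather.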
+ Value passthru = maskingOp.hasPassthru()
+ ? maskingOp.getPassthru()
+ : rewriter.create<arith::ConstantOp>(
+ gatherOp.getLoc(),
+ rewriter.getZeroAttr(gatherOp.getVectorType()));
+
+    // Replace the `vector.mask` op with an unmasked gather that consumes the
+    // mask and passthru operands directly.
+ rewriter.replaceOpWithNewOp<GatherOp>(
+ maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
+ gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ passthru);
+ return success();
+ }
+};
+
struct LowerVectorMaskPass
: public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
using Base::Base;
/// not its nested `MaskableOpInterface`.
void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
RewritePatternSet &patterns) {
- patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
- patterns.getContext());
+ patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
+ MaskedGatherOpPattern>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
// CHECK: return %[[VAL_4]] : tensor<?xf32>
// CHECK: }
+// -----
+
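+// The `vector.mask` around the gather below should be folded into the gather
+// itself: the mask becomes the gather's mask operand and the passthru defaults
+// to a zero vector.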
+func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c3 = arith.constant 3 : index
+ %0 = vector.create_mask %c3 : vector<4xi1>
+ %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+ %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+ %cst_1 = arith.constant dense<true> : vector<4xi1>
+ %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %c0_3 = arith.constant 0 : index
+ %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+ %c0_4 = arith.constant 0 : index
+ %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
+ return %3 : tensor<3xf32>
+}
+
+// CHECK-LABEL: func.func @vector_gather(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<64xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
+// CHECK: %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
+