From 8c5ad0a2f6532cec2f6841cc3e9a1ea043409398 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Thu, 13 Apr 2023 08:27:40 -0700
Subject: [PATCH] [mlir][Vector] Add a masked vectorization of tensor.pad

This revision takes advantage of masking support to introduce a vectorized
version of pad that does not require lowering to a lower-level form.

Lowering to a lower-level form (if/else + generate + fill + copy +
insert_slice) creates unnecessary complexity that can be completely
sidestepped by using masked vectorization properly.

Differential Revision: https://reviews.llvm.org/D148261
---
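Note for reviewers: the rewrite is easiest to see on a small example. The IR
below is an illustrative sketch, not part of the patch; the names %src, %pad,
%h0, %h1 and the 2x4 shape are made up for exposition, and the test added at
the bottom of this patch exercises the same pattern. Given:

  %res = tensor.pad %src low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %pad : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>

masked vectorization with vector_sizes [2, 4] produces, roughly:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %empty = tensor.empty() : tensor<2x4xf32>
  %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<2x4xf32>
  } : vector<2x4xi1> -> vector<2x4xf32>
  %res = vector.transfer_write %read, %empty[%c0, %c0]
      {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>

The mask covers exactly the source extents, so lanes outside the source read
the pad value %pad; this is why no separate fill/copy/insert_slice is needed.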
 .../mlir/Dialect/Linalg/Transforms/Transforms.h    |  7 +++
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td   |  5 ++
 .../Linalg/TransformOps/LinalgTransformOps.cpp     | 12 ++++-
 .../Dialect/Linalg/Transforms/Vectorization.cpp    | 58 ++++++++++++++++++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp           | 15 ++++--
 mlir/test/Dialect/Linalg/vectorization.mlir        | 34 +++++++++++++
 6 files changed, 126 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7eaa2f7..52982c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -589,6 +589,13 @@ LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
+/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
+/// and (3) all-zero lowPad to
+/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
+FailureOr<vector::TransferWriteOp>
+maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+                ArrayRef<int64_t> inputVectorSizes);
+
 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 0b844e1..2a95ff2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2322,6 +2322,11 @@ def Vector_CreateMaskOp :
     ```
   }];
 
+  let builders = [
+    // Build with mixed static/dynamic operands.
+    OpBuilder<(ins "VectorType":$type, "ArrayRef<OpFoldResult>":$mixedOperands)>
+  ];
+
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
   let assemblyFormat = "$operands attr-dict `:` type(results)";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2970d34..39f7802 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3098,6 +3098,16 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
+    if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
+      FailureOr<vector::TransferWriteOp> maybeWriteOp =
+          maskedVectorize(rewriter, padOp, vectorSizes);
+      if (failed(maybeWriteOp)) {
+        return mlir::emitSilenceableFailure(target->getLoc())
+               << "failed to vectorize padOp";
+      }
+      continue;
+    }
+
     auto linalgOp = dyn_cast<LinalgOp>(target);
     if (!linalgOp) {
       return mlir::emitSilenceableFailure(target->getLoc())
@@ -3107,7 +3117,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
                                  getVectorizeNdExtract()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to vectorize op";
+             << "failed to vectorize linalg op";
     }
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b54eb0f..14726a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -1385,6 +1386,63 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
   }
 }
 
+FailureOr<vector::TransferWriteOp>
+mlir::linalg::maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = padOp.getConstantPaddingValue();
+  if (!padValue) {
+    LDBG("pad value is not constant: " << padOp << "\n");
+    return rewriter.notifyMatchFailure(padOp, "pad value is not constant");
+  }
+
+  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
+  if (!(resultTensorShape == inputVectorSizes)) {
+    LDBG("result tensor shape must match input vector sizes: " << padOp
+                                                               << "\n");
+    return rewriter.notifyMatchFailure(
+        padOp, "result tensor shape must match input vector sizes");
+  }
+  if (llvm::any_of(padOp.getStaticLow(),
+                   [](int64_t val) { return val != 0; })) {
+    LDBG("low pad must all be zero: " << padOp << "\n");
+    return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
+  }
+
+  Location loc = padOp.getLoc();
+  int64_t rank = inputVectorSizes.size();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+                                       /*dynamicSizes=*/ValueRange{});
+  SmallVector<OpFoldResult> mixedSourceDims =
+      getMixedDimensions(rewriter, loc, padOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/padOp.getSource(),
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/maskedOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  rewriter.replaceOp(padOp, transferWriteOp->getResults());
+  return transferWriteOp;
+}
+
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
 /// of the iteration space of the operation and the input vector sizes must be
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8ee5965..89ca099 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -640,10 +640,9 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
   auto loc = parser.getCurrentLocation();
   DictionaryAttr dictAttr;
   // TODO: Unify linalg op attribute parsing.
-  if (parser.parseAttribute(dictAttr) ||
-      parser.parseOperand(lhsInfo) || parser.parseComma() ||
-      parser.parseOperand(rhsInfo) || parser.parseComma() ||
-      parser.parseOperand(accInfo) ||
+  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
+      parser.parseComma() || parser.parseOperand(rhsInfo) ||
+      parser.parseComma() || parser.parseOperand(accInfo) ||
       parser.parseTrailingOperandList(masksInfo) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonTypeList(types) ||
@@ -5369,6 +5368,14 @@ LogicalResult ConstantMaskOp::verify() {
 // CreateMaskOp
 //===----------------------------------------------------------------------===//
 
+void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
+                         VectorType type,
+                         ArrayRef<OpFoldResult> mixedOperands) {
+  SmallVector<Value> operands =
+      getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
+  build(builder, result, type, operands);
+}
+
 LogicalResult CreateMaskOp::verify() {
   auto vectorType = getResult().getType().cast<VectorType>();
   // Verify that an operand was specified for each result vector dimension.
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index c407b49..d54a2f5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2757,3 +2757,37 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   transform.structured.masked_vectorize %0 vector_sizes [2, 4]
 }
+
+// -----
+
+// CHECK-LABEL: func @test_masked_vectorize_pad
+func.func @test_masked_vectorize_pad(
+  %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+    -> tensor<2x4xf32>
+{
+  //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[c42:.*]] = arith.constant 4.243000e+01 : f32
+  //  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<2x4xf32>
+  //      CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
+  //      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+  // CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+  // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
+  //      CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
+  // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
+  %cst = arith.constant 42.43 : f32
+  %1 = tensor.pad %0 low[0, 0] high[%h0, %h1] {
+    ^bb0(%hh1: index, %hh2: index):
+      tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<2x4xf32>
+  return %1: tensor<2x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [2, 4]
+}
-- 
2.7.4
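P.S. for reviewers: the preconditions matter. An illustrative pad (names and
shapes made up, not taken from the patch) that maskedVectorize rejects with
notifyMatchFailure("low pad must all be zero"), because its low padding is
non-zero:

  %1 = tensor.pad %0 low[1, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>

Supporting non-zero low padding would presumably require offsetting the
transfer_write indices instead of writing at all-zero indices, which this
revision does not attempt.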