[mlir][linalg] Unify generic vectorization interface.
authorHanhan Wang <hanchung@google.com>
Fri, 12 May 2023 23:30:59 +0000 (16:30 -0700)
committerHanhan Wang <hanchung@google.com>
Thu, 18 May 2023 19:58:50 +0000 (12:58 -0700)
It breaks the logic of maskedVectorize (on tensor.pad ops) into
precondition checks and vectorization implementation; unifies the
interface.

The revision also rename`s vectorizeLinalgOpPrecondition` to
`vectorizeOpPrecondition` because we can vectorize ops other
than LinalgOps.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D150495

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index 11cdaa2..693b89a 100644 (file)
@@ -291,10 +291,9 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
 /// Return success if the operation can be vectorized.
-LogicalResult
-vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes = {},
-                              bool vectorizeNDExtract = false);
+LogicalResult vectorizeOpPrecondition(Operation *op,
+                                      ArrayRef<int64_t> inputVectorSizes = {},
+                                      bool vectorizeNDExtract = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.
@@ -576,26 +575,19 @@ LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
 /// memory is freed when going outside of the scope.
 LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 
-/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
-/// are used to vectorize this operation. `inputVectorSizes` must match the rank
-/// of the iteration space of the operation and the sizes must be smaller or
-/// equal than their counterpart interation space sizes, if static.
-/// `inputVectorShapes` also allows the vectorization of operations with dynamic
-/// shapes.
-LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+/// Emit a suitable vector form for an operation. If provided,
+/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
+/// must match the rank of the iteration space of the operation and the sizes
+/// must be smaller or equal than their counterpart interation space sizes, if
+/// static. `inputVectorShapes` also allows the vectorization of operations with
+/// dynamic shapes.
+LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         bool vectorizeNDExtract = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
-/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
-/// and (3) all-zero lowPad to
-///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
-FailureOr<vector::TransferWriteOp>
-maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
-                ArrayRef<int64_t> inputVectorSizes);
-
 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
index 4fe9b9f..c1a24a4 100644 (file)
@@ -2950,26 +2950,15 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
   // TODO: Check that the correct number of vectorSizes was provided.
 
   for (Operation *target : targets) {
-    if (auto padOp = dyn_cast<tensor::PadOp>(target)) {
-      FailureOr<vector::TransferWriteOp> maybeWriteOp =
-          maskedVectorize(rewriter, padOp, vectorSizes);
-      if (failed(maybeWriteOp)) {
-        return mlir::emitSilenceableFailure(target->getLoc())
-               << "failed to vectorize padOp";
-      }
-      continue;
-    }
-
-    auto linalgOp = dyn_cast<LinalgOp>(target);
-    if (!linalgOp) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "cannot vectorize non-Linalg op";
+             << "Unsupported Op, cannot vectorize";
     }
 
-    if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
+    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
                                  getVectorizeNdExtract()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to vectorize linalg op";
+             << "Attempted to vectorize, but failed";
     }
   }
 
index 2236d1b..fbdc7d7 100644 (file)
@@ -1280,6 +1280,49 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
+/// and (3) all-zero lowPad to
+///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
+static LogicalResult
+vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
+                       ArrayRef<int64_t> inputVectorSizes,
+                       SmallVectorImpl<Value> &newResults) {
+  auto padValue = padOp.getConstantPaddingValue();
+  Location loc = padOp.getLoc();
+  int64_t rank = inputVectorSizes.size();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(padOp);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
+                                       /*dynamicSizes=*/ValueRange{});
+  SmallVector<OpFoldResult> mixedSourceDims =
+      getMixedDimensions(rewriter, loc, padOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/padOp.getSource(),
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/maskedOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(rank, zero),
+      /*inBounds=*/SmallVector<bool>(rank, true));
+  newResults.push_back(transferWriteOp.getResult());
+  return success();
+}
+
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1311,10 +1354,10 @@ static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            ArrayRef<int64_t> inputVectorSizes,
-                                            bool vectorizeNDExtract) {
+static LogicalResult
+vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes,
+                              bool vectorizeNDExtract) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::any_of(linalgOp.getStaticShape(),
                    [](int64_t dim) { return dim == 0; }))
@@ -1392,135 +1435,133 @@ mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
-/// Converts affine.apply Ops to arithmetic operations.
-static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
-  OpBuilder::InsertionGuard g(rewriter);
-  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
-
-  for (auto op : make_early_inc_range(toReplace)) {
-    rewriter.setInsertionPoint(op);
-    auto expanded = affine::expandAffineExpr(
-        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
-        op.getOperands().take_front(op.getAffineMap().getNumDims()),
-        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
-    rewriter.replaceOp(op, expanded);
-  }
-}
-
-FailureOr<vector::TransferWriteOp>
-mlir::linalg::maskedVectorize(RewriterBase &rewriter, tensor::PadOp padOp,
-                              ArrayRef<int64_t> inputVectorSizes) {
+static LogicalResult
+vectorizePadOpPrecondition(tensor::PadOp padOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = padOp.getConstantPaddingValue();
   if (!padValue) {
     LDBG("pad value is not constant: " << padOp << "\n");
-    return rewriter.notifyMatchFailure(padOp, "pad value is not constant");
+    return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
   if (!(resultTensorShape == inputVectorSizes)) {
     LDBG("result tensor shape must match input vector sizes: " << padOp
                                                                << "\n");
-    return rewriter.notifyMatchFailure(
-        padOp, "result tensor shape must match input vector sizes");
+    return failure();
   }
+
   if (llvm::any_of(padOp.getLow(), [](Value v) {
         std::optional<int64_t> res = getConstantIntValue(v);
         return !res.has_value() || res.value() != 0;
       })) {
     LDBG("low pad must all be zero: " << padOp << "\n");
-    return rewriter.notifyMatchFailure(padOp, "low pad must all be zero");
+    return failure();
   }
 
-  Location loc = padOp.getLoc();
-  int64_t rank = inputVectorSizes.size();
-  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  return success();
+}
 
-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+LogicalResult
+mlir::linalg::vectorizeOpPrecondition(Operation *op,
+                                      ArrayRef<int64_t> inputVectorSizes,
+                                      bool vectorizeNDExtract) {
+  return TypeSwitch<Operation *, LogicalResult>(op)
+      .Case<linalg::LinalgOp>([&](auto linalgOp) {
+        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                             vectorizeNDExtract);
+      })
+      .Case<tensor::PadOp>([&](auto padOp) {
+        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
+      })
+      .Default([](auto) { return failure(); });
+}
+
+/// Converts affine.apply Ops to arithmetic operations.
+static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(padOp);
-  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, padOp.getResultType(),
-                                       /*dynamicSizes=*/ValueRange{});
-  SmallVector<OpFoldResult> mixedSourceDims =
-      getMixedDimensions(rewriter, loc, padOp.getSource());
-  Value mask =
-      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc,
-      /*vectorType=*/vectorType,
-      /*source=*/padOp.getSource(),
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  auto maskedOp = cast<vector::MaskOp>(
-      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
-  auto transferWriteOp = rewriter.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/maskedOp->getResult(0),
-      /*source=*/emptyOp,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/SmallVector<bool>(rank, true));
-  rewriter.replaceOp(padOp, transferWriteOp->getResults());
-  return transferWriteOp;
+  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
+
+  for (auto op : make_early_inc_range(toReplace)) {
+    rewriter.setInsertionPoint(op);
+    auto expanded = affine::expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+    rewriter.replaceOp(op, expanded);
+  }
 }
 
-/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
-/// are used to vectorize this operation. `inputVectorSizes` must match the rank
-/// of the iteration space of the operation and the input vector sizes must be
-/// greater than or equal to their counterpart iteration space sizes, if static.
-/// `inputVectorShapes` also allows the vectorization of operations with dynamic
-/// shapes.
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+/// Emit a suitable vector form for an operation. If provided,
+/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
+/// must match the rank of the iteration space of the operation and the input
+/// vector sizes must be greater than or equal to their counterpart iteration
+/// space sizes, if static. `inputVectorShapes` also allows the vectorization of
+/// operations with dynamic shapes.
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       bool vectorizeNDExtract) {
-  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+  LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                           vectorizeNDExtract))) {
+  if (failed(
+          vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
 
   // Initialize vectorization state.
   VectorizationState state(rewriter);
-  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
-    LDBG("Vectorization state couldn't be initialized\n");
-    return failure();
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+      LDBG("Vectorization state couldn't be initialized\n");
+      return failure();
+    }
   }
 
   SmallVector<Value> results;
-  // TODO: isaConvolutionOpInterface that can also infer from generic
-  // features. Will require stride/dilation attributes inference.
-  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
-  if (succeeded(convOr)) {
-    llvm::append_range(results, (*convOr)->getResults());
-  } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract)))
-      return failure();
-    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
-
-    // Pre-process before proceeding.
-    convertAffineApply(rewriter, linalgOp);
-
-    // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
-    // 'OpBuilder' when it is passed over to some methods like
-    // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
-    // within these methods, the actual rewriter won't be notified and we will
-    // end up with read-after-free issues!
-    if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
-      return failure();
+  auto vectorizeResult =
+      TypeSwitch<Operation *, LogicalResult>(op)
+          .Case<linalg::LinalgOp>([&](auto linalgOp) {
+            // TODO: isaConvolutionOpInterface that can also infer from generic
+            // features. Will require stride/dilation attributes inference.
+            FailureOr<Operation *> convOr =
+                vectorizeConvolution(rewriter, linalgOp);
+            if (succeeded(convOr)) {
+              llvm::append_range(results, (*convOr)->getResults());
+              return success();
+            }
+
+            LDBG("Vectorize generic by broadcasting to the canonical vector "
+                 "shape\n");
+
+            // Pre-process before proceeding.
+            convertAffineApply(rewriter, linalgOp);
+
+            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+            // to 'OpBuilder' when it is passed over to some methods like
+            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+            // erase an op within these methods, the actual rewriter won't be
+            // notified and we will end up with read-after-free issues!
+            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
+          })
+          .Case<tensor::PadOp>([&](auto padOp) {
+            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+                                          results);
+          })
+          .Default([](auto) { return failure(); });
+
+  if (failed(vectorizeResult)) {
+    LDBG("Vectorization failed\n");
+    return failure();
   }
 
   if (!results.empty())
-    rewriter.replaceOp(linalgOp, results);
+    rewriter.replaceOp(op, results);
   else
-    rewriter.eraseOp(linalgOp);
+    rewriter.eraseOp(op);
 
   return success();
 }