[mlir][linalg] Break up linalg vectorization pre-condition
authorThomas Raoux <thomasraoux@google.com>
Tue, 14 Dec 2021 20:52:36 +0000 (12:52 -0800)
committerThomas Raoux <thomasraoux@google.com>
Tue, 14 Dec 2021 21:38:14 +0000 (13:38 -0800)
Break up the vectorization pre-condition into the part checking for
static shape and the rest checking if the linalg op is supported by
vectorization. This allows checking if an op could be vectorized if it
had static shapes.

Differential Revision: https://reviews.llvm.org/D115754

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index 82e6080..c14259f 100644 (file)
@@ -401,9 +401,15 @@ LogicalResult generalizeNamedOpPrecondition(Operation *op);
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
-/// Rewrite a linalg.generic into a suitable vector.contraction op.
+/// Return success if the operation can be vectorized.
 LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
 
+/// Return success if `op` can be vectorized assuming it is static. This allows
+/// checking if an op will be vectorizable once all the dimensions are folded to
+/// static values.
+/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
+LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
index bc02298..d4aa16e 100644 (file)
@@ -599,34 +599,39 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
-  auto linalgOp = cast<linalg::LinalgOp>(op);
-  // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape()) {
-    LDBG("precondition failed: dynamic shape");
-    return failure();
-  }
+LogicalResult
+mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
-  if (isa<ConvolutionOpInterface>(op))
+  if (isa<ConvolutionOpInterface>(op.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+  if (!allIndexingsAreProjectedPermutation(op)) {
     LDBG("precondition failed: not projected permutations");
     return failure();
   }
-  if (failed(reductionPreconditions(linalgOp))) {
+  if (failed(reductionPreconditions(op))) {
     LDBG("precondition failed: reduction preconditions");
     return failure();
   }
   return success();
 }
 
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  // All types must be static shape to go to vector.
+  if (linalgOp.hasDynamicShape()) {
+    LDBG("precondition failed: dynamic shape");
+    return failure();
+  }
+  return vectorizeStaticLinalgOpPrecondition(linalgOp);
+}
+
 LogicalResult
 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
                                 SmallVectorImpl<Value> &newResults) {