[mlir][linalg] Add constant padding helper to PadTensorOp
authorMatthias Springer <springerm@google.com>
Mon, 14 Jun 2021 00:43:02 +0000 (09:43 +0900)
committerMatthias Springer <springerm@google.com>
Mon, 14 Jun 2021 00:44:39 +0000 (09:44 +0900)
* Add a helper function that returns the constant padding value (if applicable).
* Remove existing getConstantYieldValueFromBlock function, which does almost the same.
* Adapted from D103243.

Differential Revision: https://reviews.llvm.org/D104004

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

index 51e1ab4..5d7e2cc 100644 (file)
@@ -242,6 +242,9 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
         ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
 
+    // Return the pad value if it is a constant. Return null value otherwise.
+    Value getConstantPaddingValue();
+
     // Return a vector of all the static or dynamic values (low/high padding) of
     // the op.
     inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
index 8830f57..6aa0ed1 100644 (file)
@@ -1141,6 +1141,30 @@ void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<FoldStaticZeroPadding>(context);
 }
 
+/// Return the padding value of the PadTensorOp if it constant. In this context,
+/// "constant" means an actual constant or "defined outside of the block".
+///
+/// Values are considered constant in three cases:
+///  - A ConstantLike value.
+///  - A basic block argument from a different block.
+///  - A value defined outside of the block.
+///
+/// If the padding value is not constant, an empty Value is returned.
+Value PadTensorOp::getConstantPaddingValue() {
+  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
+  if (!yieldOp || yieldOp.values().size() != 1)
+    return {};
+  Value padValue = yieldOp.values().front();
+  // Check if yield value is a constant.
+  if (matchPattern(padValue, m_Constant()))
+    return padValue;
+  // Check if yield value is defined inside the PadTensorOp block.
+  if (padValue.getParentBlock() == &getRegion().front())
+    return {};
+  // Else: Yield value defined outside of the PadTensorOp block.
+  return padValue;
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
index 6da0737..36cdf79 100644 (file)
@@ -650,31 +650,6 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
-/// Given a block, return the Value that the block yields if that Value is
-/// constant. In this context, "constant" means "defined outside of the block".
-/// Should not be called on blocks that yield more than one value.
-///
-/// Values are considered constant in two cases:
-///  - A basic block argument from a different block.
-///  - A value defined outside of the block.
-///
-/// If the yielded value is not constant, an empty Value is returned.
-static Value getConstantYieldValueFromBlock(Block &block) {
-  auto yieldOp = cast<YieldOp>(block.getTerminator());
-  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
-  Value result = yieldOp.values().front();
-  Operation *definingOp = result.getDefiningOp();
-
-  // Check if yield value is defined inside the block.
-  if (definingOp && definingOp->getBlock() == &block)
-    return Value();
-  // Check if the yield value is a BB arg of the block.
-  if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
-    return Value();
-
-  return result;
-}
-
 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
 /// TransferWriteOp. For now, this only applies when all low and high paddings
 /// are determined to be zero.
@@ -693,7 +668,7 @@ struct GenericPadTensorOpVectorizationPattern
     // High padding must be static 0.
     if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
     // Pad value must be a constant.
-    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    auto padValue = padOp.getConstantPaddingValue();
     if (!padValue) return failure();
 
     // Bail on non-static shapes.