Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
+ // Return the pad value if it is a constant. Return null value otherwise.
+ Value getConstantPaddingValue();
+
// Return a vector of all the static or dynamic values (low/high padding) of
// the op.
inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
results.add<FoldStaticZeroPadding>(context);
}
+/// Return the padding value of the PadTensorOp if it constant. In this context,
+/// "constant" means an actual constant or "defined outside of the block".
+///
+/// Values are considered constant in three cases:
+/// - A ConstantLike value.
+/// - A basic block argument from a different block.
+/// - A value defined outside of the block.
+///
+/// If the padding value is not constant, an empty Value is returned.
+Value PadTensorOp::getConstantPaddingValue() {
+ auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
+ if (!yieldOp || yieldOp.values().size() != 1)
+ return {};
+ Value padValue = yieldOp.values().front();
+ // Check if yield value is a constant.
+ if (matchPattern(padValue, m_Constant()))
+ return padValue;
+ // Check if yield value is defined inside the PadTensorOp block.
+ if (padValue.getParentBlock() == &getRegion().front())
+ return {};
+ // Else: Yield value defined outside of the PadTensorOp block.
+ return padValue;
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//
-/// Given a block, return the Value that the block yields if that Value is
-/// constant. In this context, "constant" means "defined outside of the block".
-/// Should not be called on blocks that yield more than one value.
-///
-/// Values are considered constant in two cases:
-/// - A basic block argument from a different block.
-/// - A value defined outside of the block.
-///
-/// If the yielded value is not constant, an empty Value is returned.
-static Value getConstantYieldValueFromBlock(Block &block) {
- auto yieldOp = cast<YieldOp>(block.getTerminator());
- assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
- Value result = yieldOp.values().front();
- Operation *definingOp = result.getDefiningOp();
-
- // Check if yield value is defined inside the block.
- if (definingOp && definingOp->getBlock() == &block)
- return Value();
- // Check if the yield value is a BB arg of the block.
- if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
- return Value();
-
- return result;
-}
-
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
// High padding must be static 0.
if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
// Pad value must be a constant.
- auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+ auto padValue = padOp.getConstantPaddingValue();
if (!padValue) return failure();
// Bail on non-static shapes.