/// created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
- PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
+ PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
const LinalgTilingOptions &options, Value &result) {
- auto tensorType = operand.get().getType().cast<RankedTensorType>();
// Already static shape, no need to pad.
- if (tensorType.hasStaticShape())
+ if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
return success();
- auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
+ auto subtensor = opOperand->get().getDefiningOp<SubTensorOp>();
// Not a subtensor, cannot construct a static bounding box.
if (!subtensor)
return failure();
SmallVector<int64_t> staticSizes;
- staticSizes.reserve(tensorType.getRank());
+ staticSizes.reserve(opToPad.getRank(opOperand));
auto shapedOp =
cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
for (auto size : shapedOp.getMixedSizes()) {
opToPad, "No constant bounding box can be found for padding");
staticSizes.push_back(indexAttr.getInt());
}
- Value pad = options.paddingValueComputationFunction(rewriter, operand);
- auto staticTensorType =
- RankedTensorType::get(staticSizes, tensorType.getElementType());
+ Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
+ auto staticTensorType = RankedTensorType::get(
+ staticSizes, getElementTypeOrSelf(opOperand->get().getType()));
result = linalg::PadTensorOp::createPadHighOp(
- staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
+ staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
return success();
}
// If padding was requested but the shape cannot be bounded statically then
// the pattern fails to apply.
if (failed(padOperandToSmallestStaticBoundingBox(
- rewriter, opToPad, *opOperand, options, paddedOperand))) {
+ rewriter, opToPad, opOperand, options, paddedOperand)))
return failure();
- }
newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
}