/// Try to compute a static bounding box for `opOperand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is set to the result of a
///      freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    const LinalgTilingOptions &options, Value &result) {
+    const PaddingValueComputationFunction &paddingFunc, Value &result) {
  // Already static shape, no need to pad.
  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
    return success();
opToPad, "No constant bounding box can be found for padding");
staticSizes.push_back(indexAttr.getInt());
}
-  Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
+  Value pad = paddingFunc(rewriter, *opOperand);
  auto staticTensorType = RankedTensorType::get(
      staticSizes, getElementTypeOrSelf(opOperand->get()));
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
  return success();
}
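The padding value itself comes from the `PaddingValueComputationFunction` callback that now replaces the full `LinalgTilingOptions` argument: it receives the rewriter and the operand being padded and returns the scalar value to pad with. A minimal sketch of such a callback, assuming zero-padding via the standard-dialect `ConstantOp` (the name `zeroPaddingFunc` is hypothetical, not part of this patch):

static Value zeroPaddingFunc(OpBuilder &b, OpOperand &opOperand) {
  // Pad with the zero value of the operand's element type (hypothetical
  // helper, not part of this patch).
  Type elementType = getElementTypeOrSelf(opOperand.get());
  return b.create<ConstantOp>(opOperand.getOwner()->getLoc(),
                              b.getZeroAttr(elementType));
}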
-// Try to create a static bounding box around each operand of `res.op`.
-// If successful, `res.op` is rewritten in static form with padded operands.
-// `res.op` is updated to the cloned static form of the op on success.
-static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
-                                       TiledLinalgOp &res,
-                                       const LinalgTilingOptions &options) {
-  LinalgOp opToPad = res.op;
+LogicalResult
+linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
+                          const PaddingValueComputationFunction &paddingFunc,
+                          LinalgOp &paddedOp) {
  Location loc = opToPad->getLoc();
  // Pad each operand to its smallest static bounding box.
  SmallVector<Value> newOperands;
  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
    Value paddedOperand;
    // If the op is fully static, it does not need padding.
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(
-            rewriter, opToPad, opOperand, options, paddedOperand)))
+            rewriter, opToPad, opOperand, paddingFunc, paddedOperand)))
      return failure();
    newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
  }
  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
-  linalg::LinalgOp paddedOp =
-      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
+  paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);
  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });
-
-  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}
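Because `rewriteAsPaddedOp` is now exposed in the `linalg` namespace, it can be driven outside the tiling pattern as well. A minimal call-site sketch, reusing the hypothetical `zeroPaddingFunc` above and assuming `rewriter` and a `linalgOp` with partially dynamic shapes are in scope:

LinalgOp paddedOp;
if (succeeded(linalg::rewriteAsPaddedOp(rewriter, linalgOp, zeroPaddingFunc,
                                        paddedOp))) {
  // `paddedOp` is the statically shaped clone; uses of `linalgOp`'s results
  // now go through slices of the padded results.
}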
  // Consider padding on the fly only if a padding value computation function
  // was provided and the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();
-  // Try to pad on the fly by rewriting res->op as a padded op.
-  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
-    // Set so RAII guard does not propagate TiledLinalgOp to `result`.
-    return failure();
+  // Try to pad on the fly by rewriting res->op as a padded op. If successful,
+  // `res.op` is rewritten in static form with padded operands.
+  LinalgOp paddedOp;
+  if (succeeded(rewriteAsPaddedOp(rewriter, res->op,
+                                  options.paddingValueComputationFunction,
+                                  paddedOp))) {
+    res->op = paddedOp;
+    // Do not perform replacement of `linalgOp`, let the derived patterns
+    // do this as they see fit, from the resulting TiledLinalgOp.
+    return success();
  }
-
-  // Do not perform replacement of `linalgOp`, let the derived patterns
-  // do this as they see fit, from the resulting TiledLinalgOp.
-  return success();
+  // Set so RAII guard does not propagate TiledLinalgOp to `result`.
+  return failure();
}
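The padding hook is still threaded through `LinalgTilingOptions` at this call site. A sketch of how a client would opt in, assuming the options struct follows the usual fluent-setter convention (`setPaddingValueComputationFunction` is assumed here, not shown in this diff):

LinalgTilingOptions options;
options.setTileSizes({8, 16, 32})
    .setPaddingValueComputationFunction(zeroPaddingFunc);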
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {