struct LinalgFusionOptions;
struct LinalgTilingOptions;
+/// Default function to control reshape folding. Skips folding unit dimension
+/// reshapes.
+bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+using ControlElementwiseOpsFusionFn =
+ std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+ RewritePatternSet &patterns,
+ ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-using ControlElementwiseOpsFusionFn =
- std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
-
/// Options that control fusion of elementwise operations.
struct LinalgElementwiseFusionOptions {
- /// Enable fusion of reshapes that are introducing unit-dimensions into the
- /// shape with elementwise operations. By default this is disabled.
- bool allowFoldingUnitDimReshapes = false;
+ /// Enable fusion of reshapes into the shape with elementwise operations. By
+ /// default it is disabled for unit dimensions reshape.
+ ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
- LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
- allowFoldingUnitDimReshapes = val;
+ LinalgElementwiseFusionOptions &
+ setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
+ controlFoldingReshapesFn = std::move(fun);
return *this;
}
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOpTy> {
public:
- FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
- bool foldUnitDimReshapes,
- PatternBenefit benefit = 1)
+ FoldWithProducerReshapeOpByExpansion(
+ MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<GenericOpTy>(context, benefit),
- allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
+ controlFoldingReshapes(foldReshapes) {}
LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
operand.value().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
-
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
- (!allowFoldingUnitDimReshapes &&
- isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps())))
+ (!controlFoldingReshapes(
+ reshapeOp->getResult(0),
+ linalgOp.getInputOpOperands()[operand.index()])))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
}
private:
- bool allowFoldingUnitDimReshapes;
+ ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
controlFn, rewriter);
}
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+ const OpOperand &consumer) {
+ auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
+ return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps());
+}
+
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
+ ControlElementwiseOpsFusionFn allowFoldingFn =
+ [](const OpResult &producer, const OpOperand &consumer) {
+ return true;
+ };
populateElementwiseOpsFusionPatterns(
patterns,
- LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
- allowFoldingUnitDimReshapes));
+ LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+ allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+ RewritePatternSet &patterns,
+ ControlElementwiseOpsFusionFn controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- patterns.getContext(), allowFoldingUnitDimReshapes);
+ patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context, options.controlElementwiseOpsFusionFn);
- populateFoldReshapeOpsByExpansionPatterns(
- patterns, options.allowFoldingUnitDimReshapes);
+ populateFoldReshapeOpsByExpansionPatterns(patterns,
+ options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);