//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
-///
-/// Linalg lowering patterns.
-///
-/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `linalgLowerOpToLoops` for more details.
-enum class LinalgLoweringType {
- LibraryCall = 0,
- Loops = 1,
- AffineLoops = 2,
- ParallelLoops = 3
-};
-
-template <typename OpTy>
-struct LinalgLoweringPattern : public RewritePattern {
- LinalgLoweringPattern(
- MLIRContext *context, LinalgLoweringType loweringType,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : RewritePattern(OpTy::getOperationName(), benefit, context),
- filter(std::move(f)), loweringType(loweringType) {}
-
- // TODO: Move implementation to .cpp once named ops are auto-generated.
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
- return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
-
- switch (loweringType) {
- case LinalgLoweringType::LibraryCall:
- // TODO: Move lowering to library calls here.
- return failure();
- case LinalgLoweringType::Loops:
- if (failed(linalgOpToLoops(rewriter, op)))
- return failure();
- break;
- case LinalgLoweringType::AffineLoops:
- if (failed(linalgOpToAffineLoops(rewriter, op)))
- return failure();
- break;
- case LinalgLoweringType::ParallelLoops:
- if (failed(linalgOpToParallelLoops(rewriter, op)))
- return failure();
- break;
- }
-
- rewriter.eraseOp(op);
- return success();
- }
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
- /// Controls whether the pattern lowers to library calls, scf.for, affine.for
- /// or scf.parallel.
- LinalgLoweringType loweringType;
-};
/// Linalg generalization patterns
RewritePatternSet patterns(ctx);
//===--------------------------------------------------------------------===//
- // Linalg to loops patterns.
- //===--------------------------------------------------------------------===//
- patterns.add<LinalgLoweringPattern<DotOp>>(
- ctx,
- /*loweringType=*/LinalgLoweringType::Loops);
-
- //===--------------------------------------------------------------------===//
// Linalg distribution patterns.
//===--------------------------------------------------------------------===//
LinalgLoopDistributionOptions distributionOptions;