From 2a99e700e0f337c34c2d9d1cb5e4dc1d312fa248 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Sat, 11 Apr 2020 23:02:09 -0700 Subject: [PATCH] [mlir][Linalg] NFC: Add utility function to tile, fuse and set marker to use loop.parallel. This change is NFC since the facility to tile and generate loop.parallel loops already exists in Linalg. Differential Revision: https://reviews.llvm.org/D77965 --- .../Dialect/Linalg/Transforms/LinalgTransforms.h | 6 +++ .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 49 ++++++++++++++++++---- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index 3bff0f1..4340366 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -63,12 +63,18 @@ LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op, ArrayRef sizes, StringRef linalgMarker, ArrayRef permutation); +LogicalResult tileLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + StringRef linalgMarker, ArrayRef permutation); /// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and /// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. LogicalResult tileAndFuseLinalgOpAndSetMarker( PatternRewriter &rewriter, Operation *op, ArrayRef sizes, ArrayRef operandIndicesToFuse, StringRef linalgMarker); +LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker); using LinalgLoops = SmallVector; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 8511c8e..2e7043d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -40,11 +40,16 @@ using llvm::SetVector; const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { +using TileFn = Optional(OpBuilder &, LinalgOp, ArrayRef, + ArrayRef, OperationFolder *); + +static LogicalResult +tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, + Operation *op, ArrayRef sizes, + StringRef linalgMarker, + ArrayRef permutation) { assert(permutation.empty() || permutation.size() == sizes.size()); - auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation); + auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr); if (!tileRes) return failure(); tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, @@ -52,10 +57,26 @@ LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( return success(); } -LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( +LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - auto tileRes = tileLinalgOperation(rewriter, op, sizes); + StringRef linalgMarker, ArrayRef permutation) { + return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes, + linalgMarker, permutation); +} +LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + StringRef linalgMarker, ArrayRef permutation) { + return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op, + sizes, linalgMarker, permutation); +} + +static LogicalResult +tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, + Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, + StringRef linalgMarker) { + auto tileRes = + tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr); if (!tileRes) return failure(); tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, @@ -89,6 +110,20 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( return success(); } +LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker) { + return tileAndFuseLinalgOpAndSetMarkerImpl( + tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker); +} +LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker) { + return tileAndFuseLinalgOpAndSetMarkerImpl( + tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse, + linalgMarker); +} + bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value consumedView, function_ref isaOpType) { -- 2.7.4