From 4abccd3913be0fc56e0383e04b3c0a4b872db767 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 5 Jun 2023 08:40:20 +0200 Subject: [PATCH] [mlir][memref][transform] Register memref dialect patterns Differential Revision: https://reviews.llvm.org/D151998 --- .../mlir/Dialect/MemRef/Transforms/Transforms.h | 8 ++++---- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 4 ++-- .../Dialect/MemRef/TransformOps/MemRefTransformOps.cpp | 18 ++++++++++++++++++ .../MemRef/Transforms/ResolveShapedTypeResultDims.cpp | 6 +++--- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index 1fe1342..91ef162 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -/// This header declares functions that assit transformations in the MemRef +/// This header declares functions that assist transformations in the MemRef /// dialect. // //===----------------------------------------------------------------------===// @@ -44,9 +44,9 @@ void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns); /// Appends patterns that resolve `memref.dim` operations with values that are /// defined by operations that implement the -/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input +/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input /// operands. -void populateResolveRankedShapeTypeResultDimsPatterns( +void populateResolveRankedShapedTypeResultDimsPatterns( RewritePatternSet &patterns); /// Appends patterns that resolve `memref.dim` operations with values that are @@ -68,7 +68,7 @@ void populateMemRefWideIntEmulationPatterns( arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns); -/// Appends type converions for emulating wide integer memref operations with +/// Appends type conversions for emulating wide integer memref operations with /// ops over narrowe integer types. void populateMemRefWideIntEmulationConversions( arith::WideIntEmulationConverter &typeConverter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index d8eccb9..f238306 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -675,7 +675,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns( tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); tensor::populateFoldTensorEmptyPatterns(patterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); } @@ -689,7 +689,7 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns( linalg::FillOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::populateFoldTensorEmptyPatterns(patterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); } diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index f8b4491..7b63613 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -161,6 +162,23 @@ public: #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); + + addDialectDataInitializer( + [&](transform::PatternRegistry ®istry) { + registry.registerPatterns("memref.expand_ops", + memref::populateExpandOpsPatterns); + registry.registerPatterns("memref.fold_memref_alias_ops", + memref::populateFoldMemRefAliasOpPatterns); + registry.registerPatterns( + "memref.resolve_ranked_shaped_type_result_dims", + memref::populateResolveRankedShapedTypeResultDimsPatterns); + registry.registerPatterns( + "memref.expand_strided_metadata", + memref::populateExpandStridedMetadataPatterns); + registry.registerPatterns( + "memref.extract_address_computations", + memref::populateExtractAddressComputationsPatterns); + }); } }; } // namespace diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 526c1c6..9e5fc73 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -121,7 +121,7 @@ struct ResolveShapedTypeResultDimsPass final } // namespace -void memref::populateResolveRankedShapeTypeResultDimsPatterns( +void memref::populateResolveRankedShapedTypeResultDimsPatterns( RewritePatternSet &patterns) { patterns.add, DimOfReifyRankedShapedTypeOpInterface>( @@ -138,14 +138,14 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); -- 2.7.4