AllReduceLowering is currently the only GPU rewrite pattern, but more are coming. This is a preparation change.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D89370
namespace mlir {
std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
-/// Collect a set of patterns to rewrite ops within the GPU dialect.
-void populateGpuRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
+void populateGpuAllReducePatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
+/// Collect all patterns to rewrite ops within the GPU dialect.
+inline void populateGpuRewritePatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ populateGpuAllReducePatterns(context, patterns);
+}
//===----------------------------------------------------------------------===//
// Registration
};
} // namespace
-void mlir::populateGpuRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+void mlir::populateGpuAllReducePatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
patterns.insert<GpuAllReduceConversion>(context);
}
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: gpu.module @kernels {
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK: gpu.module @kernels {
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
- TestAllReduceLowering.cpp
TestAffineLoopParametricTiling.cpp
TestBufferPlacement.cpp
TestExpandTanh.cpp
TestLoopFusion.cpp
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
+ TestGpuRewrite.cpp
TestInlining.cpp
TestLinalgCodegenStrategy.cpp
TestLinalgFusionTransforms.cpp
using namespace mlir;
namespace {
-struct TestAllReduceLoweringPass
- : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+struct TestGpuRewritePass
+ : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<StandardOpsDialect>();
}
namespace mlir {
void registerTestAllReduceLoweringPass() {
- PassRegistration<TestAllReduceLoweringPass> pass(
- "test-all-reduce-lowering",
- "Lowers gpu.all-reduce ops within the GPU dialect.");
+ PassRegistration<TestGpuRewritePass> pass(
+ "test-gpu-rewrite",
+ "Applies all rewrite patterns within the GPU dialect.");
}
} // namespace mlir