[mlir][gpu] NFC: Make room for more than one GPU rewrite pattern.
authorChristian Sigg <csigg@google.com>
Wed, 14 Oct 2020 08:31:08 +0000 (10:31 +0200)
committerChristian Sigg <csigg@google.com>
Mon, 19 Oct 2020 05:52:47 +0000 (07:52 +0200)
AllReduceLowering is currently the only GPU rewrite pattern, but more are coming. This is a preparation change.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D89370

mlir/include/mlir/Dialect/GPU/Passes.h
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/test/Dialect/GPU/all-reduce-max.mlir
mlir/test/Dialect/GPU/all-reduce.mlir
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestGpuRewrite.cpp [moved from mlir/test/lib/Transforms/TestAllReduceLowering.cpp with 81% similarity]

index 64b744b..21c526e 100644 (file)
 namespace mlir {
 std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
 
-/// Collect a set of patterns to rewrite ops within the GPU dialect.
-void populateGpuRewritePatterns(MLIRContext *context,
-                                OwningRewritePatternList &patterns);
+/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
+void populateGpuAllReducePatterns(MLIRContext *context,
+                                  OwningRewritePatternList &patterns);
+
+/// Collect all patterns to rewrite ops within the GPU dialect.
+inline void populateGpuRewritePatterns(MLIRContext *context,
+                                       OwningRewritePatternList &patterns) {
+  populateGpuAllReducePatterns(context, patterns);
+}
 
 //===----------------------------------------------------------------------===//
 // Registration
index 38df9ef..d3ee3e2 100644 (file)
@@ -397,7 +397,7 @@ struct GpuAllReduceConversion : public RewritePattern {
 };
 } // namespace
 
-void mlir::populateGpuRewritePatterns(MLIRContext *context,
-                                      OwningRewritePatternList &patterns) {
+void mlir::populateGpuAllReducePatterns(MLIRContext *context,
+                                        OwningRewritePatternList &patterns) {
   patterns.insert<GpuAllReduceConversion>(context);
 }
index 142228d..fb5fcaf 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
index 491d9b3..758fb44 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt -test-gpu-rewrite %s | FileCheck %s
 
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // CHECK: gpu.module @kernels {
index effa7e2..e3e82f2 100644 (file)
@@ -1,6 +1,5 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
-  TestAllReduceLowering.cpp
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
   TestExpandTanh.cpp
@@ -15,6 +14,7 @@ add_mlir_library(MLIRTestTransforms
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
+  TestGpuRewrite.cpp
   TestInlining.cpp
   TestLinalgCodegenStrategy.cpp
   TestLinalgFusionTransforms.cpp
@@ -18,8 +18,8 @@
 using namespace mlir;
 
 namespace {
-struct TestAllReduceLoweringPass
-    : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+struct TestGpuRewritePass
+    : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<StandardOpsDialect>();
   }
@@ -33,8 +33,8 @@ struct TestAllReduceLoweringPass
 
 namespace mlir {
 void registerTestAllReduceLoweringPass() {
-  PassRegistration<TestAllReduceLoweringPass> pass(
-      "test-all-reduce-lowering",
-      "Lowers gpu.all-reduce ops within the GPU dialect.");
+  PassRegistration<TestGpuRewritePass> pass(
+      "test-gpu-rewrite",
+      "Applies all rewrite patterns within the GPU dialect.");
 }
 } // namespace mlir