[mlir][spirv] Convert gpu.barrier to spv.ControlBarrier
author: Lei Zhang <antiagainst@google.com>
Tue, 1 Mar 2022 17:00:38 +0000 (12:00 -0500)
committer: Lei Zhang <antiagainst@google.com>
Tue, 1 Mar 2022 17:04:00 +0000 (12:04 -0500)
Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120722

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/test/Conversion/GPUToSPIRV/simple.mlir

index edfc37f..8c5627c 100644 (file)
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
 
@@ -109,6 +109,16 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Rewrites gpu.barrier as a Workgroup-scoped spv.ControlBarrier op.
+class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -327,14 +337,33 @@ LogicalResult GPUReturnOpConversion::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
+// Barrier.
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUBarrierConversion::matchAndRewrite(
+    gpu::BarrierOp barrierOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  MLIRContext *context = getContext();
+  // Both execution and memory scope should be workgroup.
+  auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
+  // Require acquire and release memory semantics for workgroup memory.
+  auto memorySemantics = spirv::MemorySemanticsAttr::get(
+      context, spirv::MemorySemantics::WorkgroupMemory |
+                   spirv::MemorySemantics::AcquireRelease);
+  rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
+                                                       memorySemantics);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // GPU To SPIRV Patterns.
 //===----------------------------------------------------------------------===//
 
 void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
-      GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
-      GPUReturnOpConversion,
+      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
+      GPUModuleEndConversion, GPUReturnOpConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::ThreadIdOp,
index d2d983f..b590d41 100644 (file)
@@ -104,3 +104,27 @@ module attributes {gpu.container_module} {
     }
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    // CHECK-LABEL: spv.func @barrier
+    gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
+      attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
+      // CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+      gpu.barrier  // Op under test; the CHECK line above pins its expected lowering.
+      gpu.return
+    }
+  }
+
+  func @main() {  // Launches @kernels::@barrier with 1x1x1 blocks and 1x1x1 threads.
+    %0 = "op"() : () -> (f32)  // Placeholder producer for the f32 launch argument.
+    %1 = "op"() : () -> (memref<12xf32>)  // Placeholder producer for the memref launch argument.
+    %cst = arith.constant 1 : index
+    gpu.launch_func @kernels::@barrier
+        blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+        args(%0 : f32, %1 : memref<12xf32>)
+    return
+  }
+}