Allow specification of the workgroup size for GPUToSPIRV lowering.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 5 Dec 2019 19:31:28 +0000 (11:31 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Dec 2019 19:31:57 +0000 (11:31 -0800)
SPIR-V/Vulkan spec requires the workgroups size to be specified with
the spv.ExecutionMode operation. This was hard-wired to be set to a
particular value. It is now changed to be configurable by clients of
the pass or of the patterns that implement the lowering from GPU to
SPIRV.

PiperOrigin-RevId: 284017482

mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/test/Conversion/GPUToSPIRV/simple.mlir

index f617986cdccbd0ab1da64b25b9ae2e6702947197..134dbf40b4df90f6e0a4b20d687b95ec1153d930 100644 (file)
 namespace mlir {
 class SPIRVTypeConverter;
 /// Appends to a pattern list additional patterns for translating GPU Ops to
-/// SPIR-V ops.
+/// SPIR-V ops. Needs the workgroup size as input since SPIR-V/Vulkan requires
+/// the workgroup size to be statically specified.
 void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
-                                OwningRewritePatternList &patterns);
+                                OwningRewritePatternList &patterns,
+                                ArrayRef<int64_t> workGroupSize);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
index be8cad2b3d181bec5107737ecb1116f143e9733d..8f0a910c74d78116840da0d73e40ab8bf5dd5900 100644 (file)
@@ -22,6 +22,8 @@
 #ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
 #define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
 
+#include "mlir/Support/LLVM.h"
+
 #include <memory>
 
 namespace mlir {
@@ -29,8 +31,10 @@ namespace mlir {
 class ModuleOp;
 template <typename T> class OpPassBase;
 
-/// Pass to convert GPU Ops to SPIR-V ops.
-std::unique_ptr<OpPassBase<ModuleOp>> createConvertGPUToSPIRVPass();
+/// Pass to convert GPU Ops to SPIR-V ops.  Needs the workgroup size as input
+/// since SPIR-V/Vulkan requires the workgroup size to be statically specified.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize);
 
 } // namespace mlir
 #endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
index a562439108d4f6a8192ef58b80c179c207a77945..be82894461d610f93c7648bdd18449cde0361b0a 100644 (file)
@@ -10,5 +10,6 @@ target_link_libraries(MLIRGPUtoSPIRVTransforms
   MLIRSPIRV
   MLIRStandardOps
   MLIRStandardToSPIRVTransforms
+  MLIRSupport
   MLIRTransforms
   )
index 23e7b9166bb671a6f8d5152a539800f4144aa501..2c1847d99ed0de1f541d614af5c9914cd5fa8f38 100644 (file)
@@ -54,11 +54,21 @@ public:
 /// attribute gpu.kernel) within a spv.module.
 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
 public:
-  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+  KernelFnConversion(MLIRContext *context, SPIRVTypeConverter &converter,
+                     ArrayRef<int64_t> workGroupSize,
+                     PatternBenefit benefit = 1)
+      : SPIRVOpLowering<FuncOp>(context, converter, benefit) {
+    auto config = workGroupSize.take_front(3);
+    workGroupSizeAsInt32.assign(config.begin(), config.end());
+    workGroupSizeAsInt32.resize(3, 1);
+  }
 
   PatternMatchResult
   matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
+
+private:
+  SmallVector<int32_t, 3> workGroupSizeAsInt32;
 };
 
 } // namespace
@@ -172,10 +182,10 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
     argABI.push_back(spirv::getInterfaceVarABIAttr(
         0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
   }
-  // TODO(ravishankarm) : For now set this to {32, 1, 1}. This is incorrect. The
-  // actual workgroup size needs to be plumbed through.
+
   auto context = rewriter.getContext();
-  auto entryPointAttr = spirv::getEntryPointABIAttr({32, 1, 1}, context);
+  auto entryPointAttr =
+      spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
   FuncOp newFuncOp = spirv::lowerAsEntryFunction(
       funcOp, typeConverter, rewriter, argABI, entryPointAttr);
   if (!newFuncOp) {
@@ -189,9 +199,11 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
 namespace mlir {
 void populateGPUToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
-                                OwningRewritePatternList &patterns) {
+                                OwningRewritePatternList &patterns,
+                                ArrayRef<int64_t> workGroupSize) {
+  patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
   patterns.insert<
-      ForOpConversion, KernelFnConversion,
+      ForOpConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
index 49f161e3794019d170a4c0cf3056523aeba3b29e..cec71ca9b3fd8cbcf9a808c3cea943b187f30fe2 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
 
 using namespace mlir;
 
@@ -42,7 +43,23 @@ namespace {
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+public:
+  GPUToSPIRVPass(ArrayRef<int64_t> workGroupSize)
+      : workGroupSize(workGroupSize.begin(), workGroupSize.end()) {}
   void runOnModule() override;
+
+private:
+  SmallVector<int64_t, 3> workGroupSize;
+};
+
+/// Command line option to specify the workgroup size.
+struct GPUToSPIRVPassOptions : public PassOptions<GPUToSPIRVPassOptions> {
+  List<unsigned> workGroupSize{
+      *this, "workgroup-size",
+      llvm::cl::desc(
+          "Workgroup Sizes in the SPIR-V module for x, followed by y, followed "
+          "by z dimension of the dispatch (others will be ignored)"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
 };
 } // namespace
 
@@ -80,7 +97,7 @@ void GPUToSPIRVPass::runOnModule() {
   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
   SPIRVTypeConverter typeConverter;
   OwningRewritePatternList patterns;
-  populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+  populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   ConversionTarget target(*context);
@@ -94,9 +111,16 @@ void GPUToSPIRVPass::runOnModule() {
   }
 }
 
-std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
-  return std::make_unique<GPUToSPIRVPass>();
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(ArrayRef<int64_t> workGroupSize) {
+  return std::make_unique<GPUToSPIRVPass>(workGroupSize);
 }
 
-static PassRegistration<GPUToSPIRVPass>
-    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
+static PassRegistration<GPUToSPIRVPass, GPUToSPIRVPassOptions>
+    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect",
+         [](const GPUToSPIRVPassOptions &passOptions) {
+           SmallVector<int64_t, 3> workGroupSize;
+           workGroupSize.assign(passOptions.workGroupSize.begin(),
+                                passOptions.workGroupSize.end());
+           return std::make_unique<GPUToSPIRVPass>(workGroupSize);
+         });
index b0723c69bc6181ebd94bf045ff93800baff39c27..ef136b9bb62c908fcc65e943c542f2f113b24bfe 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -pass-pipeline='convert-gpu-to-spirv{workgroup-size=32,4}' %s -o - | FileCheck %s
 
 module attributes {gpu.container_module} {
 
@@ -7,6 +7,7 @@ module attributes {gpu.container_module} {
     // CHECK-LABEL: func @kernel_1
     // CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: spirv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32>)
         attributes { gpu.kernel } {
       // CHECK: spv.Return