[mlir][GPU] Allow bare pointer memrefs when calling GPU kernels
authorKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Mon, 11 Jul 2022 18:29:01 +0000 (18:29 +0000)
committerKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Tue, 2 Aug 2022 20:58:34 +0000 (20:58 +0000)
In the ROCm runtime (and probably CUDA as well), all kernel arguments
are aligned. Therefore, enable using bare pointers for memref
arguments to kernels when these memrefs have static shape and a
trivial layout.

This is a substantial optimization to launching kernels that use
memrefs with known, static sizes, since it causes the kernel launch
packet to no longer include information already known to the kernel,
which can enable packing the kernel launch arguments into launch
packets instead of having to allocate an entire separate structure to
hold unneeded memref information.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D130716

mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/test/Conversion/GPUToROCDL/memref.mlir [new file with mode: 0644]
mlir/test/Integration/GPU/ROCM/vecadd.mlir

index b048a76..79be00a 100644 (file)
@@ -50,13 +50,15 @@ using LoweringCallback = std::function<std::unique_ptr<llvm::Module>(
 /// This pass does not generate code to call GPU runtime APIs directly but
 /// instead uses a small wrapper library that exports a stable and conveniently
 /// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).
-std::unique_ptr<OperationPass<ModuleOp>> createGpuToLLVMConversionPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createGpuToLLVMConversionPass(bool kernelBarePtrCallConv = false);
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
 void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns,
-                                         StringRef gpuBinaryAnnotation = {});
+                                         StringRef gpuBinaryAnnotation = {},
+                                         bool kernelBarePtrCallConv = false);
 
 } // namespace mlir
 
index 2879cd5..8e45d74 100644 (file)
@@ -41,6 +41,7 @@ std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 createLowerGpuOpsToROCDLOpsPass(
     const std::string &chipset = "gfx900",
     unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    bool useBarePtrCallConv = false,
     gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
 
 } // namespace mlir
index f9ad358..f9c1511 100644 (file)
@@ -80,6 +80,13 @@ public:
 
   const LowerToLLVMOptions &getOptions() const { return options; }
 
+  /// Set the lowering options to `newOptions`. Note: using this after some
+  /// some conversions have been performed can lead to inconsistencies in the
+  /// IR.
+  void dangerousSetOptions(LowerToLLVMOptions newOptions) {
+    options = std::move(newOptions);
+  }
+
   /// Promote the LLVM representation of all operands including promoting MemRef
   /// descriptors to stack and use pointers to struct to avoid the complexity
   /// of the platform-specific C/C++ ABI lowering related to struct argument
@@ -126,7 +133,7 @@ public:
                                            const DataLayout &layout);
 
   /// Check if a memref type can be converted to a bare pointer.
-  bool canConvertToBarePtr(BaseMemRefType type);
+  static bool canConvertToBarePtr(BaseMemRefType type);
 
 protected:
   /// Pointer to the LLVM dialect.
index 8230e9d..26ec204 100644 (file)
@@ -373,6 +373,10 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
            "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
+           /*default=*/"false",
+           "Replace memref arguments in GPU functions with bare pointers."
+           "All memrefs must have static shape">,
     Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
           "::mlir::gpu::amd::Runtime::Unknown",
           "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
index e115d9c..62f1f50 100644 (file)
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -137,6 +138,34 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                          &signatureConversion)))
     return failure();
 
+  // If bare memref pointers are being used, remap them back to memref
+  // descriptors This must be done after signature conversion to get rid of the
+  // unrealized casts.
+  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
+    for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
+      auto memrefTy = en.value().dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+      assert(memrefTy.hasStaticShape() &&
+             "Bare pointer convertion used with dynamically-shaped memrefs");
+      // Use a placeholder when replacing uses of the memref argument to prevent
+      // circular replacements.
+      auto remapping = signatureConversion.getInputMapping(en.index());
+      assert(remapping && remapping->size == 1 &&
+             "Type converter should produce 1-to-1 mapping for bare memrefs");
+      BlockArgument newArg =
+          llvmFuncOp.getBody().getArgument(remapping->inputNo);
+      auto placeholder = rewriter.create<LLVM::UndefOp>(
+          loc, getTypeConverter()->convertType(memrefTy));
+      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), memrefTy, newArg);
+      rewriter.replaceOp(placeholder, {desc});
+    }
+  }
+
   rewriter.eraseOp(gpuFuncOp);
   return success();
 }
index aaee821..40159aa 100644 (file)
@@ -49,6 +49,12 @@ class GpuToLLVMConversionPass
 public:
   GpuToLLVMConversionPass() = default;
 
+  GpuToLLVMConversionPass(bool kernelBarePtrCallConv)
+      : GpuToLLVMConversionPass() {
+    if (this->kernelBarePtrCallConv.getNumOccurrences() == 0)
+      this->kernelBarePtrCallConv = kernelBarePtrCallConv;
+  }
+
   GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
       : GpuToLLVMConversionPassBase(other) {}
 
@@ -60,6 +66,11 @@ private:
       *this, "gpu-binary-annotation",
       llvm::cl::desc("Annotation attribute string for GPU binary"),
       llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
+  Option<bool> kernelBarePtrCallConv{
+      *this, "use-bare-pointers-for-kernels",
+      llvm::cl::desc("Use bare pointers to pass memref arguments to kernels. "
+                     "The kernel must use the same setting for this option."),
+      llvm::cl::init(false)};
 };
 
 struct FunctionCallBuilder {
@@ -290,9 +301,11 @@ class ConvertLaunchFuncOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
-                                             StringRef gpuBinaryAnnotation)
+                                             StringRef gpuBinaryAnnotation,
+                                             bool kernelBarePtrCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
-        gpuBinaryAnnotation(gpuBinaryAnnotation) {}
+        gpuBinaryAnnotation(gpuBinaryAnnotation),
+        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
 
 private:
   Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
@@ -305,6 +318,7 @@ private:
                   ConversionPatternRewriter &rewriter) const override;
 
   llvm::SmallString<32> gpuBinaryAnnotation;
+  bool kernelBarePtrCallConv;
 };
 
 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
@@ -377,7 +391,8 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFuncToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
+  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
+                                      kernelBarePtrCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -635,9 +650,24 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
   auto loc = launchOp.getLoc();
   auto numKernelOperands = launchOp.getNumKernelOperands();
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getOperands().take_back(numKernelOperands),
-      adaptor.getOperands().take_back(numKernelOperands), builder);
+  SmallVector<Value, 4> arguments;
+  if (kernelBarePtrCallConv) {
+    // Hack the bare pointer value on just for the argument promotion
+    LLVMTypeConverter *converter = getTypeConverter();
+    LowerToLLVMOptions options = converter->getOptions();
+    LowerToLLVMOptions overrideToMatchKernelOpts = options;
+    overrideToMatchKernelOpts.useBarePtrCallConv = true;
+    converter->dangerousSetOptions(overrideToMatchKernelOpts);
+    arguments = converter->promoteOperands(
+        loc, launchOp.getOperands().take_back(numKernelOperands),
+        adaptor.getOperands().take_back(numKernelOperands), builder);
+    converter->dangerousSetOptions(options);
+  } else {
+    arguments = getTypeConverter()->promoteOperands(
+        loc, launchOp.getOperands().take_back(numKernelOperands),
+        adaptor.getOperands().take_back(numKernelOperands), builder);
+  }
+
   auto numArguments = arguments.size();
   SmallVector<Type, 4> argumentTypes;
   argumentTypes.reserve(numArguments);
@@ -870,13 +900,14 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
 }
 
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createGpuToLLVMConversionPass() {
-  return std::make_unique<GpuToLLVMConversionPass>();
+mlir::createGpuToLLVMConversionPass(bool kernelBarePtrCallConv) {
+  return std::make_unique<GpuToLLVMConversionPass>(kernelBarePtrCallConv);
 }
 
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
-                                               StringRef gpuBinaryAnnotation) {
+                                               StringRef gpuBinaryAnnotation,
+                                               bool kernelBarePtrCallConv) {
   converter.addConversion(
       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
         return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
@@ -890,7 +921,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
                ConvertWaitOpToGpuRuntimeCallPattern,
                ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
-  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
-                                                           gpuBinaryAnnotation);
+  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
+      converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
 }
index 57e4a15..163c4a3 100644 (file)
 
 using namespace mlir;
 
+/// Returns true if the given `gpu.func` can be safely called using the bare
+/// pointer calling convention.
+static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
+  bool canBeBare = true;
+  for (Type type : func.getArgumentTypes())
+    if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+      canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
+  return canBeBare;
+}
+
 namespace {
 
 /// Import the GPU Ops to ROCDL Patterns.
@@ -55,10 +65,16 @@ struct LowerGpuOpsToROCDLOpsPass
     : public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
   LowerGpuOpsToROCDLOpsPass() = default;
   LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
+                            bool useBarePtrCallConv,
                             gpu::amd::Runtime runtime) {
-    this->chipset = chipset;
-    this->indexBitwidth = indexBitwidth;
-    this->runtime = runtime;
+    if (this->chipset.getNumOccurrences() == 0)
+      this->chipset = chipset;
+    if (this->indexBitwidth.getNumOccurrences() == 0)
+      this->indexBitwidth = indexBitwidth;
+    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
+      this->useBarePtrCallConv = useBarePtrCallConv;
+    if (this->runtime.getNumOccurrences() == 0)
+      this->runtime = runtime;
   }
 
   void runOnOperation() override {
@@ -82,6 +98,23 @@ struct LowerGpuOpsToROCDLOpsPass
         ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
+
+    if (useBarePtrCallConv) {
+      options.useBarePtrCallConv = true;
+      WalkResult canUseBarePointers =
+          m.walk([](gpu::GPUFuncOp func) -> WalkResult {
+            if (canBeCalledWithBarePointers(func))
+              return WalkResult::advance();
+            return WalkResult::interrupt();
+          });
+      if (canUseBarePointers.wasInterrupted()) {
+        emitError(UnknownLoc::get(ctx),
+                  "bare pointer calling convention requires all memrefs to "
+                  "have static shape and use the identity map");
+        return signalPassFailure();
+      }
+    }
+
     LLVMTypeConverter converter(ctx, options);
 
     RewritePatternSet patterns(ctx);
@@ -189,7 +222,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                       unsigned indexBitwidth,
+                                      bool useBarePtrCallConv,
                                       gpu::amd::Runtime runtime) {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(chipset, indexBitwidth,
-                                                     runtime);
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
+      chipset, indexBitwidth, useBarePtrCallConv, runtime);
 }
diff --git a/mlir/test/Conversion/GPUToROCDL/memref.mlir b/mlir/test/Conversion/GPUToROCDL/memref.mlir
new file mode 100644 (file)
index 0000000..9d179ee
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s \
+// RUN:   -convert-gpu-to-rocdl=use-bare-ptr-memref-call-conv=true \
+// RUN:   -split-input-file \
+// RUN: | FileCheck %s --check-prefix=BARE
+
+gpu.module @memref_conversions {
+  // CHECK: llvm.func @kern
+  // CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64)
+  // BARE: llvm.func @kern
+  // BARE-SAME: (%{{.*}}: !llvm.ptr<f32>)
+  gpu.func @kern(%arg0: memref<8xf32>) kernel {
+    gpu.return
+  }
+}
index 52a440e..4175825 100644 (file)
@@ -1,24 +1,24 @@
 // RUN: mlir-opt %s \
 // RUN:   -convert-scf-to-cf \
 // RUN:   -gpu-kernel-outlining \
-// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl,gpu-to-hsaco{chip=%chip})' \
-// RUN:   -gpu-to-llvm \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true},gpu-to-hsaco{chip=%chip})' \
+// RUN:   -gpu-to-llvm=use-bare-pointers-for-kernels=true \
 // RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
 // RUN:   --entry-point-result=void \
 // RUN: | FileCheck %s
 
-func.func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>, %arg2 : memref<?xf32>) {
+func.func @vecadd(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %arg2 : memref<5xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %block_dim = memref.dim %arg0, %c0 : memref<?xf32>
+  %block_dim = arith.constant 5 : index
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %block_dim, %block_y = %c1, %block_z = %c1) {
-    %a = memref.load %arg0[%tx] : memref<?xf32>
-    %b = memref.load %arg1[%tx] : memref<?xf32>
+    %a = memref.load %arg0[%tx] : memref<5xf32>
+    %b = memref.load %arg1[%tx] : memref<5xf32>
     %c = arith.addf %a, %b : f32
-    memref.store %c, %arg2[%tx] : memref<?xf32>
+    memref.store %c, %arg2[%tx] : memref<5xf32>
     gpu.terminator
   }
   return
@@ -49,8 +49,11 @@ func.func @main() {
   %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
   %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
   %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
+  %12 = memref.cast %9 : memref<?xf32> to memref<5xf32>
+  %13 = memref.cast %10 : memref<?xf32> to memref<5xf32>
+  %14 = memref.cast %11 : memref<?xf32> to memref<5xf32>
 
-  call @vecadd(%9, %10, %11) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
+  call @vecadd(%12, %13, %14) : (memref<5xf32>, memref<5xf32>, memref<5xf32>) -> ()
   call @printMemrefF32(%8) : (memref<*xf32>) -> ()
   return
 }