/// This pass does not generate code to call GPU runtime APIs directly but
/// instead uses a small wrapper library that exports a stable and conveniently
/// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).
-std::unique_ptr<OperationPass<ModuleOp>> createGpuToLLVMConversionPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createGpuToLLVMConversionPass(bool kernelBarePtrCallConv = false);
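// A minimal usage sketch (the pipeline wiring here is an illustrative
// assumption, not part of this patch): the new parameter lets a host-side
// pipeline opt into the bare-pointer kernel ABI programmatically, matching
// the mlir-opt flag added below.
//
//   void buildHostLoweringPipeline(mlir::PassManager &pm) {
//     // Pass kernel memref arguments as bare pointers; the device-side
//     // lowering must be configured with the matching setting.
//     pm.addPass(mlir::createGpuToLLVMConversionPass(
//         /*kernelBarePtrCallConv=*/true));
//   }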
/// Collect a set of patterns to convert from the GPU dialect to LLVM and
/// populate the converter for GPU types.
void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
- StringRef gpuBinaryAnnotation = {});
+ StringRef gpuBinaryAnnotation = {},
+ bool kernelBarePtrCallConv = false);
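// For callers invoking the pattern-population entry point directly, a sketch
// of the surrounding driver (the converter/target setup is assumed here,
// mirroring what GpuToLLVMConversionPass itself does):
//
//   LLVMTypeConverter converter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   LLVMConversionTarget target(*ctx);
//   // Forward the same bare-pointer choice the kernels were compiled with.
//   populateGpuToLLVMConversionPatterns(converter, patterns,
//                                       gpu::getDefaultGpuBinaryAnnotation(),
//                                       /*kernelBarePtrCallConv=*/true);
//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
//     signalPassFailure();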
} // namespace mlir
createLowerGpuOpsToROCDLOpsPass(
const std::string &chipset = "gfx900",
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+ bool useBarePtrCallConv = false,
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
} // namespace mlir
const LowerToLLVMOptions &getOptions() const { return options; }
+ /// Set the lowering options to `newOptions`. Note: using this after some
+ /// conversions have been performed can lead to inconsistencies in the IR.
+ void dangerousSetOptions(LowerToLLVMOptions newOptions) {
+ options = std::move(newOptions);
+ }
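// Illustrative save/override/restore usage of the new hook (this mirrors the
// launch-op lowering later in this patch; any other use risks the
// inconsistencies noted above):
//
//   LowerToLLVMOptions saved = converter.getOptions();
//   LowerToLLVMOptions tweaked = saved;
//   tweaked.useBarePtrCallConv = true;
//   converter.dangerousSetOptions(tweaked);
//   // ... the one conversion step that needs the override ...
//   converter.dangerousSetOptions(saved);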
+
/// Promote the LLVM representation of all operands, including promoting
/// MemRef descriptors to the stack and using pointers to structs, to avoid
/// the complexity of the platform-specific C/C++ ABI lowering for struct argument
const DataLayout &layout);
/// Check if a memref type can be converted to a bare pointer.
- bool canConvertToBarePtr(BaseMemRefType type);
+ static bool canConvertToBarePtr(BaseMemRefType type);
protected:
/// Pointer to the LLVM dialect.
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
"Bitwidth of the index type, 0 to use size of machine word">,
+ Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
+ /*default=*/"false",
+ "Replace memref arguments in GPU functions with bare pointers."
+ "All memrefs must have static shape">,
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
"::mlir::gpu::amd::Runtime::Unknown",
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
#include "GPUOpsLowering.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
&signatureConversion)))
return failure();
+ // If bare memref pointers are being used, remap them back to memref
+ // descriptors. This must be done after signature conversion to get rid of
+ // the unrealized casts.
+ if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
+ for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
+ auto memrefTy = en.value().dyn_cast<MemRefType>();
+ if (!memrefTy)
+ continue;
+ assert(memrefTy.hasStaticShape() &&
+ "Bare pointer convertion used with dynamically-shaped memrefs");
+ // Use a placeholder when replacing uses of the memref argument to prevent
+ // circular replacements.
+ auto remapping = signatureConversion.getInputMapping(en.index());
+ assert(remapping && remapping->size == 1 &&
+ "Type converter should produce 1-to-1 mapping for bare memrefs");
+ BlockArgument newArg =
+ llvmFuncOp.getBody().getArgument(remapping->inputNo);
+ auto placeholder = rewriter.create<LLVM::UndefOp>(
+ loc, getTypeConverter()->convertType(memrefTy));
+ rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+ Value desc = MemRefDescriptor::fromStaticShape(
+ rewriter, loc, *getTypeConverter(), memrefTy, newArg);
+ rewriter.replaceOp(placeholder, {desc});
+ }
+ }
+
rewriter.eraseOp(gpuFuncOp);
return success();
}
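// For illustration (an assumed example, not taken from this patch): lowering
//   gpu.func @kern(%arg0 : memref<8xf32>) kernel { ... }
// with bare pointers leaves a single !llvm.ptr<f32> argument, and the code
// above rebuilds the static descriptor at the top of the entry block, roughly:
//   llvm.func @kern(%arg0: !llvm.ptr<f32>) {
//     // struct<(allocated ptr, aligned ptr, offset, sizes[1], strides[1])>
//     %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64,
//                                          array<1 x i64>, array<1 x i64>)>
//     %1 = llvm.insertvalue %arg0, %0[0] ...   // allocated pointer
//     %2 = llvm.insertvalue %arg0, %1[1] ...   // aligned pointer
//     // offset 0, size 8, and stride 1 follow as llvm.mlir.constant inserts
//     ...
//   }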
public:
GpuToLLVMConversionPass() = default;
+ GpuToLLVMConversionPass(bool kernelBarePtrCallConv)
+ : GpuToLLVMConversionPass() {
+ if (this->kernelBarePtrCallConv.getNumOccurrences() == 0)
+ this->kernelBarePtrCallConv = kernelBarePtrCallConv;
+ }
+
GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
: GpuToLLVMConversionPassBase(other) {}
*this, "gpu-binary-annotation",
llvm::cl::desc("Annotation attribute string for GPU binary"),
llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
+ Option<bool> kernelBarePtrCallConv{
+ *this, "use-bare-pointers-for-kernels",
+ llvm::cl::desc("Use bare pointers to pass memref arguments to kernels. "
+ "The kernel must use the same setting for this option."),
+ llvm::cl::init(false)};
};
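// Command-line equivalent of the new pass option, as exercised by the
// integration test updated below:
//   mlir-opt in.mlir -gpu-to-llvm=use-bare-pointers-for-kernels=true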
struct FunctionCallBuilder {
: public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
- StringRef gpuBinaryAnnotation)
+ StringRef gpuBinaryAnnotation,
+ bool kernelBarePtrCallConv)
: ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
- gpuBinaryAnnotation(gpuBinaryAnnotation) {}
+ gpuBinaryAnnotation(gpuBinaryAnnotation),
+ kernelBarePtrCallConv(kernelBarePtrCallConv) {}
private:
Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
llvm::SmallString<32> gpuBinaryAnnotation;
+ bool kernelBarePtrCallConv;
};
class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
populateFuncToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
- populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
+ populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
+ kernelBarePtrCallConv);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
auto loc = launchOp.getLoc();
auto numKernelOperands = launchOp.getNumKernelOperands();
- auto arguments = getTypeConverter()->promoteOperands(
- loc, launchOp.getOperands().take_back(numKernelOperands),
- adaptor.getOperands().take_back(numKernelOperands), builder);
+ SmallVector<Value, 4> arguments;
+ if (kernelBarePtrCallConv) {
+ // Temporarily enable the bare-pointer calling convention just for the
+ // argument promotion; the original options are restored afterwards.
+ LLVMTypeConverter *converter = getTypeConverter();
+ LowerToLLVMOptions options = converter->getOptions();
+ LowerToLLVMOptions overrideToMatchKernelOpts = options;
+ overrideToMatchKernelOpts.useBarePtrCallConv = true;
+ converter->dangerousSetOptions(overrideToMatchKernelOpts);
+ arguments = converter->promoteOperands(
+ loc, launchOp.getOperands().take_back(numKernelOperands),
+ adaptor.getOperands().take_back(numKernelOperands), builder);
+ converter->dangerousSetOptions(options);
+ } else {
+ arguments = getTypeConverter()->promoteOperands(
+ loc, launchOp.getOperands().take_back(numKernelOperands),
+ adaptor.getOperands().take_back(numKernelOperands), builder);
+ }
+
auto numArguments = arguments.size();
SmallVector<Type, 4> argumentTypes;
argumentTypes.reserve(numArguments);
}
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createGpuToLLVMConversionPass() {
- return std::make_unique<GpuToLLVMConversionPass>();
+mlir::createGpuToLLVMConversionPass(bool kernelBarePtrCallConv) {
+ return std::make_unique<GpuToLLVMConversionPass>(kernelBarePtrCallConv);
}
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
- StringRef gpuBinaryAnnotation) {
+ StringRef gpuBinaryAnnotation,
+ bool kernelBarePtrCallConv) {
converter.addConversion(
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
ConvertWaitAsyncOpToGpuRuntimeCallPattern,
ConvertWaitOpToGpuRuntimeCallPattern,
ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
- patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
- gpuBinaryAnnotation);
+ patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
+ converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}
using namespace mlir;
+/// Returns true if the given `gpu.func` can be safely called using the bare
+/// pointer calling convention.
+static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
+ bool canBeBare = true;
+ for (Type type : func.getArgumentTypes())
+ if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+ canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
+ return canBeBare;
+}
+
namespace {
/// Import the GPU Ops to ROCDL Patterns.
: public ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
LowerGpuOpsToROCDLOpsPass() = default;
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
+ bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
- this->chipset = chipset;
- this->indexBitwidth = indexBitwidth;
- this->runtime = runtime;
+ if (this->chipset.getNumOccurrences() == 0)
+ this->chipset = chipset;
+ if (this->indexBitwidth.getNumOccurrences() == 0)
+ this->indexBitwidth = indexBitwidth;
+ if (this->useBarePtrCallConv.getNumOccurrences() == 0)
+ this->useBarePtrCallConv = useBarePtrCallConv;
+ if (this->runtime.getNumOccurrences() == 0)
+ this->runtime = runtime;
}
void runOnOperation() override {
ctx, DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
+
+ if (useBarePtrCallConv) {
+ options.useBarePtrCallConv = true;
+ WalkResult canUseBarePointers =
+ m.walk([](gpu::GPUFuncOp func) -> WalkResult {
+ if (canBeCalledWithBarePointers(func))
+ return WalkResult::advance();
+ return WalkResult::interrupt();
+ });
+ if (canUseBarePointers.wasInterrupted()) {
+ emitError(UnknownLoc::get(ctx),
+ "bare pointer calling convention requires all memrefs to "
+ "have static shape and use the identity map");
+ return signalPassFailure();
+ }
+ }
+
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
unsigned indexBitwidth,
+ bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
- return std::make_unique<LowerGpuOpsToROCDLOpsPass>(chipset, indexBitwidth,
- runtime);
+ return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
+ chipset, indexBitwidth, useBarePtrCallConv, runtime);
}
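// A matching device-side sketch (the chipset and runtime values here are
// illustrative assumptions): host and kernel lowering must agree on the
// bare-pointer ABI, per the option descriptions above.
//
//   pm.addNestedPass<gpu::GPUModuleOp>(createLowerGpuOpsToROCDLOpsPass(
//       /*chipset=*/"gfx906",
//       /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout,
//       /*useBarePtrCallConv=*/true,
//       /*runtime=*/gpu::amd::Runtime::HIP));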
--- /dev/null
+// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
+// RUN: mlir-opt %s \
+// RUN: -convert-gpu-to-rocdl=use-bare-ptr-memref-call-conv=true \
+// RUN: -split-input-file \
+// RUN: | FileCheck %s --check-prefix=BARE
+
+gpu.module @memref_conversions {
+ // CHECK: llvm.func @kern
+ // CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64)
+ // BARE: llvm.func @kern
+ // BARE-SAME: (%{{.*}}: !llvm.ptr<f32>)
+ gpu.func @kern(%arg0: memref<8xf32>) kernel {
+ gpu.return
+ }
+}
// RUN: mlir-opt %s \
// RUN: -convert-scf-to-cf \
// RUN: -gpu-kernel-outlining \
-// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl,gpu-to-hsaco{chip=%chip})' \
-// RUN: -gpu-to-llvm \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true},gpu-to-hsaco{chip=%chip})' \
+// RUN: -gpu-to-llvm=use-bare-pointers-for-kernels=true \
// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s
-func.func @vecadd(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>, %arg2 : memref<?xf32>) {
+func.func @vecadd(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %arg2 : memref<5xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %block_dim = memref.dim %arg0, %c0 : memref<?xf32>
+ %block_dim = arith.constant 5 : index
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %block_dim, %block_y = %c1, %block_z = %c1) {
- %a = memref.load %arg0[%tx] : memref<?xf32>
- %b = memref.load %arg1[%tx] : memref<?xf32>
+ %a = memref.load %arg0[%tx] : memref<5xf32>
+ %b = memref.load %arg1[%tx] : memref<5xf32>
%c = arith.addf %a, %b : f32
- memref.store %c, %arg2[%tx] : memref<?xf32>
+ memref.store %c, %arg2[%tx] : memref<5xf32>
gpu.terminator
}
return
%9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
%10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
%11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
+ %12 = memref.cast %9 : memref<?xf32> to memref<5xf32>
+ %13 = memref.cast %10 : memref<?xf32> to memref<5xf32>
+ %14 = memref.cast %11 : memref<?xf32> to memref<5xf32>
- call @vecadd(%9, %10, %11) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
+ call @vecadd(%12, %13, %14) : (memref<5xf32>, memref<5xf32>, memref<5xf32>) -> ()
call @printMemrefF32(%8) : (memref<*xf32>) -> ()
return
}