#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
+#include "mlir/Support/LLVM.h"
#include <functional>
#include <memory>
#include <string>
namespace mlir {
-class ModulePassBase;
class FuncOp;
+class Location;
+class ModulePassBase;
+class OpBuilder;
+class Value;
+
+namespace LLVM {
+class LLVMDialect;
+}
/// Owning handle to a cubin blob, stored as a heap-allocated vector of chars.
/// NOTE(review): std::vector is used here, but <vector> is not among the
/// includes visible in this header -- confirm it is provided transitively.
using OwnedCubin = std::unique_ptr<std::vector<char>>;
/// Callback type that produces an OwnedCubin from a string and a FuncOp.
/// NOTE(review): presumably the string carries the device code to compile and
/// the FuncOp is the kernel it belongs to -- confirm against the callers.
using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
/// Creates a pass to augment a module with getter functions for all contained
/// cubins as encoded via the 'nvvm.cubin' attribute.
std::unique_ptr<ModulePassBase> createGenerateCubinAccessorPass();
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
std::unique_ptr<detail::LLVMDialectImpl> impl;
};
+/// Create an LLVM global containing the string "value" in the module
+/// surrounding the insertion point of builder. Obtain the address of that
+/// global and use it to compute the address of the first character in the
+/// string (operations inserted at the builder insertion point).
+Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
+ StringRef value, LLVM::LLVMDialect *llvmDialect);
+
} // end namespace LLVM
} // end namespace mlir
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
return array;
}
-// Generates LLVM IR that produces a value representing the name of the
-// given kernel function. The generated IR consists essentially of the
-// following:
+// Generates an LLVM IR dialect global that contains the name of the given
+// kernel function as a C string, and returns a pointer to its beginning.
+// The code is essentially:
//
-// %0 = alloca(strlen(name) + 1)
-// %0[0] = constant name[0]
-// ...
-// %0[n] = constant name[n]
-// %0[n+1] = 0
+// llvm.global constant @kernel_name("function_name\00")
+// func(...) {
+// %0 = llvm.addressof @kernel_name
+// %1 = llvm.constant (0 : index)
+// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
+// }
Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
FuncOp kernelFunction, Location &loc, OpBuilder &builder) {
- // TODO(herhut): Make this a constant once this is supported.
- auto kernelNameSize = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
- auto kernelName = builder.create<LLVM::AllocaOp>(
- loc, getPointerType(), kernelNameSize, /*alignment=*/1);
- for (auto byte : llvm::enumerate(kernelFunction.getName())) {
- auto index = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
- auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
- ArrayRef<Value *>{index});
- auto value = builder.create<LLVM::ConstantOp>(
- loc, getInt8Type(),
- builder.getIntegerAttr(builder.getIntegerType(8), byte.value()));
- builder.create<LLVM::StoreOp>(loc, value, gep);
- }
- // Add trailing zero to terminate string.
- auto index = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(),
- builder.getI32IntegerAttr(kernelFunction.getName().size()));
- auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
- ArrayRef<Value *>{index});
- auto value = builder.create<LLVM::ConstantOp>(
- loc, getInt8Type(), builder.getIntegerAttr(builder.getIntegerType(8), 0));
- builder.create<LLVM::StoreOp>(loc, value, gep);
- return kernelName;
+  const StringRef functionName = kernelFunction.getName();
+
+  // The global must hold a C string, so copy the name into an owning buffer
+  // and append the terminating zero explicitly.
+  std::vector<char> nulTerminatedName;
+  nulTerminatedName.reserve(functionName.size() + 1);
+  nulTerminatedName.insert(nulTerminatedName.end(), functionName.begin(),
+                           functionName.end());
+  nulTerminatedName.push_back('\0');
+
+  // The symbol of the global is derived from the kernel function's name.
+  const std::string globalName =
+      llvm::formatv("{0}_kernel_name", functionName);
+
+  // Pass an explicit length so that the trailing zero is part of the value.
+  const StringRef nameWithTerminator(nulTerminatedName.data(),
+                                     nulTerminatedName.size());
+  return LLVM::createGlobalString(loc, builder, globalName,
+                                  nameWithTerminator, llvmDialect);
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
auto module = orig.getParentOfType<ModuleOp>();
assert(module && "function must belong to a module");
- // Create a global at the top of the module.
- OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
- auto type = LLVM::LLVMType::getArrayTy(
- LLVM::LLVMType::getInt8Ty(llvmDialect), blob.getValue().size());
- nameBuffer.append(kCubinStorageSuffix);
- auto cubinGlobalString = moduleBuilder.create<LLVM::GlobalOp>(
- loc, type, /*isConstant=*/true, StringRef(nameBuffer), blob);
-
// Insert the getter function just after the original function.
+ OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
moduleBuilder.setInsertionPoint(orig.getOperation()->getNextNode());
auto getterType = moduleBuilder.getFunctionType(
llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect));
- // Drop the storage suffix before appending the getter suffix.
- nameBuffer.resize(orig.getName().size());
nameBuffer.append(kCubinGetterSuffix);
auto result = moduleBuilder.create<FuncOp>(
loc, StringRef(nameBuffer), getterType, ArrayRef<NamedAttribute>());
Block *entryBlock = result.addEntryBlock();
+ // Drop the getter suffix before appending the storage suffix.
+ nameBuffer.resize(orig.getName().size());
+ nameBuffer.append(kCubinStorageSuffix);
+
// Obtain the address of the first character of the global string containing
- // the cubin and return from the getter (addressof will return [? x i8]*).
+ // the cubin and return from the getter.
OpBuilder builder(entryBlock);
- Value *cubinGlobalStringPtr =
- builder.create<LLVM::AddressOfOp>(loc, cubinGlobalString);
- Value *cst0 = builder.create<LLVM::ConstantOp>(
- loc, getIndexType(), builder.getIntegerAttr(builder.getIndexType(), 0));
- Value *startPtr = builder.create<LLVM::GEPOp>(
- loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), cubinGlobalStringPtr,
- ArrayRef<Value *>({cst0, cst0}));
+ Value *startPtr = LLVM::createGlobalString(
+ loc, builder, StringRef(nameBuffer), blob.getValue(), llvmDialect);
builder.create<LLVM::ReturnOp>(loc, startPtr);
// Store the name of the getter on the function for easier lookup.
/// Returns the void type stored in the implementation object of `dialect`.
LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
  auto &dialectImpl = *dialect->impl;
  return dialectImpl.voidTy;
}
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+/// Builds a constant LLVM dialect global named `name` that holds the bytes of
+/// `value`, and returns a pointer to its first character.
+///
+/// The global is created in the module enclosing the insertion point of
+/// `builder`; the address computation (addressof, zero constant, GEP) is
+/// emitted at the insertion point of `builder` itself. `value` is stored
+/// as-is: the array type has exactly `value.size()` elements, so callers that
+/// need a C string must include the trailing zero in `value`.
+Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
+                                      StringRef name, StringRef value,
+                                      LLVM::LLVMDialect *llvmDialect) {
+  // Without an insertion block nested in an operation there is no way to
+  // locate the enclosing module.
+  assert(builder.getInsertionBlock() &&
+         builder.getInsertionBlock()->getParentOp() &&
+         "expected builder to point to a block contained in an op");
+  auto module =
+      builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
+  assert(module && "builder points to an op outside of a module");
+
+  // Create the global at the entry of the module, using a separate builder so
+  // that the insertion point of `builder` is left untouched.
+  OpBuilder moduleBuilder(module.getBodyRegion());
+  auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
+                                         value.size());
+  auto global = moduleBuilder.create<LLVM::GlobalOp>(
+      loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+
+  // Get the pointer to the first character in the global string: take the
+  // address of the global and index it with two zero indices to obtain an i8*.
+  Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+  Value *cst0 = builder.create<LLVM::ConstantOp>(
+      loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+      builder.getIntegerAttr(builder.getIndexType(), 0));
+  return builder.create<LLVM::GEPOp>(
+      loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+      ArrayRef<Value *>({cst0, cst0}));
+}
// RUN: mlir-opt %s --launch-func-to-cuda | FileCheck %s
+// CHECK: llvm.global constant @[[kernel_name:.*]]("kernel\00")
+
func @cubin_getter() -> !llvm<"i8*">
func @kernel(!llvm.float, !llvm<"float*">)
%1 = "op"() : () -> (!llvm<"float*">)
%cst = constant 8 : index
- // CHECK: %5 = llvm.alloca %4 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
- // CHECK: %6 = llvm.call @mcuModuleLoad(%5, %3) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
- // CHECK: %32 = llvm.alloca %31 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
- // CHECK: %33 = llvm.call @mcuModuleGetFunction(%32, %7, %9) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
- // CHECK: %34 = llvm.call @mcuGetStreamHelper() : () -> !llvm<"i8*">
- // CHECK: %48 = llvm.call @mcuLaunchKernel(%35, %c8, %c8, %c8, %c8, %c8, %c8, %2, %34, %38, %47) : (!llvm<"i8*">, index, index, index, index, index, index, !llvm.i32, !llvm<"i8*">, !llvm<"i8**">, !llvm<"i8**">) -> !llvm.i32
- // CHECK: %49 = llvm.call @mcuStreamSynchronize(%34) : (!llvm<"i8*">) -> !llvm.i32
+ // CHECK: [[module_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+ // CHECK: llvm.call @mcuModuleLoad([[module_ptr]], {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
+ // CHECK: [[func_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+ // CHECK: llvm.call @mcuModuleGetFunction([[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
+ // CHECK: llvm.call @mcuGetStreamHelper
+ // CHECK: llvm.call @mcuLaunchKernel
+ // CHECK: llvm.call @mcuStreamSynchronize
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel }
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
return
-}
\ No newline at end of file
+}