ConvertLaunchFuncToCudaCalls: use LLVM dialect globals
authorAlex Zinenko <zinenko@google.com>
Tue, 20 Aug 2019 14:51:32 +0000 (07:51 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 20 Aug 2019 14:52:01 +0000 (07:52 -0700)
This conversion has been using a stack-allocated array of i8 to store the
null-terminated kernel name in order to pass it to the CUDA wrappers expecting
a C string because the LLVM dialect was missing support for globals.  Now that
the suport is introduced, use a global instead.

Refactor global string construction from GenerateCubinAccessors into a common
utility function living in the LLVM namespace.

PiperOrigin-RevId: 264382489

mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Conversion/GPUToCUDA/lower-launch-func-to-cuda.mlir

index bd1a3fe..b8b7a1e 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
 #define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
 
+#include "mlir/Support/LLVM.h"
 #include <functional>
 #include <memory>
 #include <string>
 
 namespace mlir {
 
-class ModulePassBase;
 class FuncOp;
+class Location;
+class ModulePassBase;
+class OpBuilder;
+class Value;
+
+namespace LLVM {
+class LLVMDialect;
+}
 
 using OwnedCubin = std::unique_ptr<std::vector<char>>;
 using CubinGenerator = std::function<OwnedCubin(const std::string &, FuncOp &)>;
@@ -53,6 +61,7 @@ std::unique_ptr<ModulePassBase> createConvertGpuLaunchFuncToCudaCallsPass();
 /// Creates a pass to augment a module with getter functions for all contained
 /// cubins as encoded via the 'nvvm.cubin' attribute.
 std::unique_ptr<ModulePassBase> createGenerateCubinAccessorPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
index 754fb48..403fade 100644 (file)
@@ -179,6 +179,13 @@ private:
   std::unique_ptr<detail::LLVMDialectImpl> impl;
 };
 
+/// Create an LLVM global containing the string "value" at the module containing
+/// surrounding the insertion point of builder. Obtain the address of that
+/// global and use it to compute the address of the first character in the
+/// string (operations inserted at the builder insertion point).
+Value *createGlobalString(Location loc, OpBuilder &builder, StringRef name,
+                          StringRef value, LLVM::LLVMDialect *llvmDialect);
+
 } // end namespace LLVM
 } // end namespace mlir
 
index 7073e5e..d4293ba 100644 (file)
@@ -39,6 +39,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -253,43 +254,28 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
   return array;
 }
 
-// Generates LLVM IR that produces a value representing the name of the
-// given kernel function. The generated IR consists essentially of the
-// following:
+// Generates an LLVM IR dialect global that contains the name of the given
+// kernel function as a C string, and returns a pointer to its beginning.
+// The code is essentially:
 //
-// %0 = alloca(strlen(name) + 1)
-// %0[0] = constant name[0]
-// ...
-// %0[n] = constant name[n]
-// %0[n+1] = 0
+// llvm.global constant @kernel_name("function_name\00")
+// func(...) {
+//   %0 = llvm.addressof @kernel_name
+//   %1 = llvm.constant (0 : index)
+//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
+// }
 Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
     FuncOp kernelFunction, Location &loc, OpBuilder &builder) {
-  // TODO(herhut): Make this a constant once this is supported.
-  auto kernelNameSize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(),
-      builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
-  auto kernelName = builder.create<LLVM::AllocaOp>(
-      loc, getPointerType(), kernelNameSize, /*alignment=*/1);
-  for (auto byte : llvm::enumerate(kernelFunction.getName())) {
-    auto index = builder.create<LLVM::ConstantOp>(
-        loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
-    auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
-                                           ArrayRef<Value *>{index});
-    auto value = builder.create<LLVM::ConstantOp>(
-        loc, getInt8Type(),
-        builder.getIntegerAttr(builder.getIntegerType(8), byte.value()));
-    builder.create<LLVM::StoreOp>(loc, value, gep);
-  }
-  // Add trailing zero to terminate string.
-  auto index = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(),
-      builder.getI32IntegerAttr(kernelFunction.getName().size()));
-  auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName,
-                                         ArrayRef<Value *>{index});
-  auto value = builder.create<LLVM::ConstantOp>(
-      loc, getInt8Type(), builder.getIntegerAttr(builder.getIntegerType(8), 0));
-  builder.create<LLVM::StoreOp>(loc, value, gep);
-  return kernelName;
+  // Make sure the trailing zero is included in the constant.
+  std::vector<char> kernelName(kernelFunction.getName().begin(),
+                               kernelFunction.getName().end());
+  kernelName.push_back('\0');
+
+  std::string globalName =
+      llvm::formatv("{0}_kernel_name", kernelFunction.getName());
+  return LLVM::createGlobalString(
+      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
+      llvmDialect);
 }
 
 // Emits LLVM IR to launch a kernel function. Expects the module that contains
index 12f65c7..c4daf8a 100644 (file)
@@ -20,6 +20,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
@@ -63,35 +64,25 @@ private:
     auto module = orig.getParentOfType<ModuleOp>();
     assert(module && "function must belong to a module");
 
-    // Create a global at the top of the module.
-    OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
-    auto type = LLVM::LLVMType::getArrayTy(
-        LLVM::LLVMType::getInt8Ty(llvmDialect), blob.getValue().size());
-    nameBuffer.append(kCubinStorageSuffix);
-    auto cubinGlobalString = moduleBuilder.create<LLVM::GlobalOp>(
-        loc, type, /*isConstant=*/true, StringRef(nameBuffer), blob);
-
     // Insert the getter function just after the original function.
+    OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
     moduleBuilder.setInsertionPoint(orig.getOperation()->getNextNode());
     auto getterType = moduleBuilder.getFunctionType(
         llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect));
-    // Drop the storage suffix before appending the getter suffix.
-    nameBuffer.resize(orig.getName().size());
     nameBuffer.append(kCubinGetterSuffix);
     auto result = moduleBuilder.create<FuncOp>(
         loc, StringRef(nameBuffer), getterType, ArrayRef<NamedAttribute>());
     Block *entryBlock = result.addEntryBlock();
 
+    // Drop the getter suffix before appending the storage suffix.
+    nameBuffer.resize(orig.getName().size());
+    nameBuffer.append(kCubinStorageSuffix);
+
     // Obtain the address of the first character of the global string containing
-    // the cubin and return from the getter (addressof will return [? x i8]*).
+    // the cubin and return from the getter.
     OpBuilder builder(entryBlock);
-    Value *cubinGlobalStringPtr =
-        builder.create<LLVM::AddressOfOp>(loc, cubinGlobalString);
-    Value *cst0 = builder.create<LLVM::ConstantOp>(
-        loc, getIndexType(), builder.getIntegerAttr(builder.getIndexType(), 0));
-    Value *startPtr = builder.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), cubinGlobalStringPtr,
-        ArrayRef<Value *>({cst0, cst0}));
+    Value *startPtr = LLVM::createGlobalString(
+        loc, builder, StringRef(nameBuffer), blob.getValue(), llvmDialect);
     builder.create<LLVM::ReturnOp>(loc, startPtr);
 
     // Store the name of the getter on the function for easier lookup.
index 7a2d4f4..27ee2f6 100644 (file)
@@ -1397,3 +1397,34 @@ LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
 LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
   return dialect->impl->voidTy;
 }
+
+//===----------------------------------------------------------------------===//
+// Utility functions.
+//===----------------------------------------------------------------------===//
+
+Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
+                                      StringRef name, StringRef value,
+                                      LLVM::LLVMDialect *llvmDialect) {
+  assert(builder.getInsertionBlock() &&
+         builder.getInsertionBlock()->getParentOp() &&
+         "expected builder to point to a block constained in an op");
+  auto module =
+      builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
+  assert(module && "builder points to an op outside of a module");
+
+  // Create the global at the entry of the module.
+  OpBuilder moduleBuilder(module.getBodyRegion());
+  auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
+                                         value.size());
+  auto global = moduleBuilder.create<LLVM::GlobalOp>(
+      loc, type, /*isConstant=*/true, name, builder.getStringAttr(value));
+
+  // Get the pointer to the first character in the global string.
+  Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
+  Value *cst0 = builder.create<LLVM::ConstantOp>(
+      loc, LLVM::LLVMType::getInt64Ty(llvmDialect),
+      builder.getIntegerAttr(builder.getIndexType(), 0));
+  return builder.create<LLVM::GEPOp>(
+      loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr,
+      ArrayRef<Value *>({cst0, cst0}));
+}
index 61ab945..39c4429 100644 (file)
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --launch-func-to-cuda | FileCheck %s
 
+// CHECK: llvm.global constant @[[kernel_name:.*]]("kernel\00")
+
 func @cubin_getter() -> !llvm<"i8*">
 
 func @kernel(!llvm.float, !llvm<"float*">)
@@ -11,15 +13,15 @@ func @foo() {
   %1 = "op"() : () -> (!llvm<"float*">)
   %cst = constant 8 : index
 
-  // CHECK: %5 = llvm.alloca %4 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
-  // CHECK: %6 = llvm.call @mcuModuleLoad(%5, %3) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
-  // CHECK: %32 = llvm.alloca %31 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
-  // CHECK: %33 = llvm.call @mcuModuleGetFunction(%32, %7, %9) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
-  // CHECK: %34 = llvm.call @mcuGetStreamHelper() : () -> !llvm<"i8*">
-  // CHECK: %48 = llvm.call @mcuLaunchKernel(%35, %c8, %c8, %c8, %c8, %c8, %c8, %2, %34, %38, %47) : (!llvm<"i8*">, index, index, index, index, index, index, !llvm.i32, !llvm<"i8*">, !llvm<"i8**">, !llvm<"i8**">) -> !llvm.i32
-  // CHECK: %49 = llvm.call @mcuStreamSynchronize(%34) : (!llvm<"i8*">) -> !llvm.i32
+  // CHECK: [[module_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+  // CHECK: llvm.call @mcuModuleLoad([[module_ptr]], {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
+  // CHECK: [[func_ptr:%.*]] = llvm.alloca {{.*}} x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
+  // CHECK: llvm.call @mcuModuleGetFunction([[func_ptr]], {{.*}}, {{.*}}) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
+  // CHECK: llvm.call @mcuGetStreamHelper
+  // CHECK: llvm.call @mcuLaunchKernel
+  // CHECK: llvm.call @mcuStreamSynchronize
   "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel }
       : (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
 
   return
-}
\ No newline at end of file
+}