From 9fbf52e330faa9f310855d7e4a02d48c3a1ccd41 Mon Sep 17 00:00:00 2001 From: MLIR Team Date: Fri, 8 Nov 2019 19:12:40 -0800 Subject: [PATCH] Look for SymbolRefAttr in KernelOutlining instead of hard-coding CallOp This code should be exercised using the existing kernel outlining unit test, but let me know if I should add a dedicated unit test using a fake call instruction as well. PiperOrigin-RevId: 279436321 --- mlir/include/mlir/IR/Module.h | 6 ++++ .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 32 ++++++++++++---------- mlir/test/Dialect/GPU/outlining.mlir | 18 +++++++++--- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index de27a49..9ac985f 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -135,6 +135,12 @@ public: return symbolTable.lookup<T>(name); } + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Names must never include the @ on them. + template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const { + return symbolTable.lookup(name); + } + /// Insert a new symbol into the module, auto-renaming it as necessary. void insert(Operation *op) { symbolTable.insert(op); diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d9a1106..672beee 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -182,21 +182,23 @@ private: builder.getUnitAttr()); ModuleManager moduleManager(kernelModule); - llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc}; - while (!funcsToInsert.empty()) { - FuncOp func = funcsToInsert.pop_back_val(); - moduleManager.insert(func); - - // TODO(b/141098412): Support any op with a callable interface.
- func.walk([&](CallOp call) { - auto callee = call.callee(); - if (moduleManager.lookupSymbol<FuncOp>(callee)) - return; - - auto calleeFromParent = - parentModuleManager.lookupSymbol<FuncOp>(callee); - funcsToInsert.push_back(calleeFromParent.clone()); - }); + moduleManager.insert(kernelFunc); + + llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (Optional<SymbolTable::UseRange> symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = symbolUse.getSymbolRef().getValue(); + if (moduleManager.lookupSymbol(symbolName)) + continue; + + Operation *symbolDefClone = + parentModuleManager.lookupSymbol(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + moduleManager.insert(symbolDefClone); + } + } } return kernelModule; diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir index 8398907..d138cfe 100644 --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -111,6 +111,8 @@ func @extra_constants(%arg0 : memref<?xf32>) { // ----- +llvm.mlir.global @global(42 : i64) : !llvm.i64 + func @function_call(%arg0 : memref<?xf32>) { %cst = constant 8 : index gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, @@ -119,6 +121,7 @@ func @function_call(%arg0 : memref<?xf32>) { %block_z = %cst) { call @device_function() : () -> () call @device_function() : () -> () + %0 = llvm.mlir.addressof @global : !llvm<"i64*"> gpu.return } return @@ -134,7 +137,14 @@ func @recursive_device_function() { gpu.return } -// CHECK: @device_function -// CHECK: @recursive_device_function -// CHECK: @device_function -// CHECK: @recursive_device_function +// CHECK: module @function_call_kernel attributes {gpu.kernel_module} { +// CHECK: func @function_call_kernel() +// CHECK: call @device_function() : () -> () +// CHECK: call @device_function() : () -> () +// CHECK: llvm.mlir.addressof @global : !llvm<"i64*"> +//
CHECK: llvm.mlir.global @global(42 : i64) : !llvm.i64 +// +// CHECK: func @device_function() +// CHECK: func @recursive_device_function() +// CHECK-NOT: func @device_function -- 2.7.4