Look for SymbolRefAttr in KernelOutlining instead of hard-coding CallOp
authorMLIR Team <no-reply@google.com>
Sat, 9 Nov 2019 03:12:40 +0000 (19:12 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 9 Nov 2019 03:13:13 +0000 (19:13 -0800)
This code should be exercised using the existing kernel outlining unit test, but
let me know if I should add a dedicated unit test using a fake call instruction
as well.

PiperOrigin-RevId: 279436321

mlir/include/mlir/IR/Module.h
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/test/Dialect/GPU/outlining.mlir

index de27a49..9ac985f 100644 (file)
@@ -135,6 +135,12 @@ public:
     return symbolTable.lookup<T>(name);
   }
 
+  /// Look up a symbol with the specified name, returning null if no such
+  /// name exists. Names must never include the @ on them.
+  template <typename NameTy> Operation *lookupSymbol(NameTy &&name) const {
+    return symbolTable.lookup(name);
+  }
+
   /// Insert a new symbol into the module, auto-renaming it as necessary.
   void insert(Operation *op) {
     symbolTable.insert(op);
index d9a1106..672beee 100644 (file)
@@ -182,21 +182,23 @@ private:
                          builder.getUnitAttr());
     ModuleManager moduleManager(kernelModule);
 
-    llvm::SmallVector<FuncOp, 8> funcsToInsert = {kernelFunc};
-    while (!funcsToInsert.empty()) {
-      FuncOp func = funcsToInsert.pop_back_val();
-      moduleManager.insert(func);
-
-      // TODO(b/141098412): Support any op with a callable interface.
-      func.walk([&](CallOp call) {
-        auto callee = call.callee();
-        if (moduleManager.lookupSymbol<FuncOp>(callee))
-          return;
-
-        auto calleeFromParent =
-            parentModuleManager.lookupSymbol<FuncOp>(callee);
-        funcsToInsert.push_back(calleeFromParent.clone());
-      });
+    moduleManager.insert(kernelFunc);
+
+    llvm::SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
+    while (!symbolDefWorklist.empty()) {
+      if (Optional<SymbolTable::UseRange> symbolUses =
+              SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
+        for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+          StringRef symbolName = symbolUse.getSymbolRef().getValue();
+          if (moduleManager.lookupSymbol(symbolName))
+            continue;
+
+          Operation *symbolDefClone =
+              parentModuleManager.lookupSymbol(symbolName)->clone();
+          symbolDefWorklist.push_back(symbolDefClone);
+          moduleManager.insert(symbolDefClone);
+        }
+      }
     }
 
     return kernelModule;
index 8398907..d138cfe 100644 (file)
@@ -111,6 +111,8 @@ func @extra_constants(%arg0 : memref<?xf32>) {
 
 // -----
 
+llvm.mlir.global @global(42 : i64) : !llvm.i64
+
 func @function_call(%arg0 : memref<?xf32>) {
   %cst = constant 8 : index
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
@@ -119,6 +121,7 @@ func @function_call(%arg0 : memref<?xf32>) {
                                         %block_z = %cst) {
     call @device_function() : () -> ()
     call @device_function() : () -> ()
+    %0 = llvm.mlir.addressof @global : !llvm<"i64*">
     gpu.return
   }
   return
@@ -134,7 +137,14 @@ func @recursive_device_function() {
   gpu.return
 }
 
-// CHECK: @device_function
-// CHECK: @recursive_device_function
-// CHECK: @device_function
-// CHECK: @recursive_device_function
+// CHECK: module @function_call_kernel attributes {gpu.kernel_module} {
+// CHECK:   func @function_call_kernel()
+// CHECK:     call @device_function() : () -> ()
+// CHECK:     call @device_function() : () -> ()
+// CHECK:     llvm.mlir.addressof @global : !llvm<"i64*">
+//
+// CHECK:   llvm.mlir.global @global(42 : i64) : !llvm.i64
+//
+// CHECK:   func @device_function()
+// CHECK:   func @recursive_device_function()
+// CHECK-NOT:   func @device_function