[attrs] Handle convergent CallSites.
authorJustin Lebar <jlebar@google.com>
Mon, 14 Mar 2016 20:18:54 +0000 (20:18 +0000)
committerJustin Lebar <jlebar@google.com>
Mon, 14 Mar 2016 20:18:54 +0000 (20:18 +0000)
Summary:
Previously we had a notion of convergent functions but not of convergent
calls.  This is insufficient to correctly analyze calls where the target
is unknown, e.g. indirect calls.

Now a call is convergent if it targets a known-convergent function, or
if it's explicitly marked as convergent.  As usual, we can remove
convergent where we can prove that no convergent operations are
performed in the call.

Originally landed as r261544, then reverted in r261544 for (incidental)
build breakage.  Re-landed here with no changes.

Reviewers: chandlerc, jingyue

Subscribers: llvm-commits, tra, jhen, hfinkel

Differential Revision: http://reviews.llvm.org/D17739

llvm-svn: 263481

llvm/lib/Transforms/IPO/FunctionAttrs.cpp
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/FunctionAttrs/convergent.ll
llvm/test/Transforms/InstCombine/convergent.ll [new file with mode: 0644]

index 7885ec7..b145771 100644 (file)
@@ -903,49 +903,44 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes,
   return MadeChange;
 }
 
-/// Removes convergent attributes where we can prove that none of the SCC's
-/// callees are themselves convergent.  Returns true if successful at removing
-/// the attribute.
+/// Remove the convergent attribute from all functions in the SCC if every
+/// callsite within the SCC is not convergent (except for calls to functions
+/// within the SCC).  Returns true if changes were made.
 static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
-  // Determines whether a function can be made non-convergent, ignoring all
-  // other functions in SCC.  (A function can *actually* be made non-convergent
-  // only if all functions in its SCC can be made convergent.)
-  auto CanRemoveConvergent = [&](Function *F) {
-    if (!F->isConvergent())
-      return true;
-
-    // Can't remove convergent from declarations.
-    if (F->isDeclaration())
-      return false;
-
-    for (Instruction &I : instructions(*F))
-      if (auto CS = CallSite(&I)) {
-        // Can't remove convergent if any of F's callees -- ignoring functions
-        // in the SCC itself -- are convergent. This needs to consider both
-        // function calls and intrinsic calls. We also assume indirect calls
-        // might call a convergent function.
-        // FIXME: We should revisit this when we put convergent onto calls
-        // instead of functions so that indirect calls which should be
-        // convergent are required to be marked as such.
-        Function *Callee = CS.getCalledFunction();
-        if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
-          return false;
-      }
-
-    return true;
-  };
+  // For every function in SCC, ensure that either
+  //  * it is not convergent, or
+  //  * we can remove its convergent attribute.
+  bool HasConvergentFn = false;
+  for (Function *F : SCCNodes) {
+    if (!F->isConvergent()) continue;
+    HasConvergentFn = true;
+
+    // Can't remove convergent from function declarations.
+    if (F->isDeclaration()) return false;
+
+    // Can't remove convergent if any of our functions has a convergent call to a
+    // function not in the SCC.
+    for (Instruction &I : instructions(*F)) {
+      CallSite CS(&I);
+      // Bail if CS is a convergent call to a function not in the SCC.
+      if (CS && CS.isConvergent() &&
+          SCCNodes.count(CS.getCalledFunction()) == 0)
+        return false;
+    }
+  }
 
-  // We can remove the convergent attr from functions in the SCC if they all
-  // can be made non-convergent (because they call only non-convergent
-  // functions, other than each other).
-  if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
-    return false;
+  // If the SCC doesn't have any convergent functions, we have nothing to do.
+  if (!HasConvergentFn) return false;
 
-  // If we got here, all of the SCC's callees are non-convergent. Therefore all
-  // of the SCC's functions can be marked as non-convergent.
+  // If we got here, all of the calls the SCC makes to functions not in the SCC
+  // are non-convergent.  Therefore all of the SCC's functions can also be made
+  // non-convergent.  We'll remove the attr from the callsites in
+  // InstCombineCalls.
   for (Function *F : SCCNodes) {
-    if (F->isConvergent())
-      DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
+    if (!F->isConvergent()) continue;
+
+    DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
+                 << "\n");
     F->setNotConvergent();
   }
   return true;
index f6ed690..1de05dc 100644 (file)
@@ -2179,7 +2179,15 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
   if (!isa<Function>(Callee) && transformConstExprCastCall(CS))
     return nullptr;
 
-  if (Function *CalleeF = dyn_cast<Function>(Callee))
+  if (Function *CalleeF = dyn_cast<Function>(Callee)) {
+    // Remove the convergent attr on calls when the callee is not convergent.
+    if (CS.isConvergent() && !CalleeF->isConvergent()) {
+      DEBUG(dbgs() << "Removing convergent attr from instr "
+                   << CS.getInstruction() << "\n");
+      CS.setNotConvergent();
+      return CS.getInstruction();
+    }
+
     // If the call and callee calling conventions don't match, this call must
     // be unreachable, as the call is undefined.
     if (CalleeF->getCallingConv() != CS.getCallingConv() &&
@@ -2204,6 +2212,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
                                     Constant::getNullValue(CalleeF->getType()));
       return nullptr;
     }
+  }
 
   if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
     // If CS does not return void then replaceAllUsesWith undef.
index 46370d7..bc21d85 100644 (file)
@@ -1,4 +1,4 @@
-; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s
+; RUN: opt -functionattrs -S < %s | FileCheck %s
 
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
@@ -24,20 +24,41 @@ declare i32 @k() convergent
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @extern()
 define i32 @extern() convergent {
-  %a = call i32 @k()
+  %a = call i32 @k() convergent
   ret i32 %a
 }
 
+; Convergent should not be removed on the function here.  Although the call is
+; not explicitly convergent, it picks up the convergent attr from the callee.
+;
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @call_extern()
-define i32 @call_extern() convergent {
-  %a = call i32 @extern()
+; CHECK-NEXT: define i32 @extern_non_convergent_call()
+define i32 @extern_non_convergent_call() convergent {
+  %a = call i32 @k()
   ret i32 %a
 }
 
 ; CHECK: Function Attrs
 ; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @indirect_convergent_call(
+define i32 @indirect_convergent_call(i32 ()* %f) convergent {
+   %a = call i32 %f() convergent
+   ret i32 %a
+}
+; Give indirect_non_convergent_call the norecurse attribute so we get a
+; "Function Attrs" comment in the output.
+;
+; CHECK: Function Attrs
+; CHECK-NOT: convergent
+; CHECK-NEXT: define i32 @indirect_non_convergent_call(
+define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
+   %a = call i32 %f()
+   ret i32 %a
+}
+
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
 ; CHECK-NEXT: declare void @llvm.cuda.syncthreads()
 declare void @llvm.cuda.syncthreads() convergent
 
@@ -45,25 +66,16 @@ declare void @llvm.cuda.syncthreads() convergent
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @intrinsic()
 define i32 @intrinsic() convergent {
+  ; Implicitly convergent, because the intrinsic is convergent.
   call void @llvm.cuda.syncthreads()
   ret i32 0
 }
 
-@xyz = global i32 ()* null
-; CHECK: Function Attrs
-; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @functionptr()
-define i32 @functionptr() convergent {
-  %1 = load i32 ()*, i32 ()** @xyz
-  %2 = call i32 %1()
-  ret i32 %2
-}
-
 ; CHECK: Function Attrs
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive1()
 define i32 @recursive1() convergent {
-  %a = call i32 @recursive2()
+  %a = call i32 @recursive2() convergent
   ret i32 %a
 }
 
@@ -71,7 +83,7 @@ define i32 @recursive1() convergent {
 ; CHECK-NOT: convergent
 ; CHECK-NEXT: define i32 @recursive2()
 define i32 @recursive2() convergent {
-  %a = call i32 @recursive1()
+  %a = call i32 @recursive1() convergent
   ret i32 %a
 }
 
@@ -79,7 +91,7 @@ define i32 @recursive2() convergent {
 ; CHECK-SAME: convergent
 ; CHECK-NEXT: define i32 @noopt()
 define i32 @noopt() convergent optnone noinline {
-  %a = call i32 @noopt_friend()
+  %a = call i32 @noopt_friend() convergent
   ret i32 0
 }
 
diff --git a/llvm/test/Transforms/InstCombine/convergent.ll b/llvm/test/Transforms/InstCombine/convergent.ll
new file mode 100644 (file)
index 0000000..4ed40d8
--- /dev/null
@@ -0,0 +1,33 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+declare i32 @k() convergent
+declare i32 @f()
+
+define i32 @extern() {
+  ; Convergent attr shouldn't be removed here; k is convergent.
+  ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
+  %a = call i32 @k() convergent
+  ret i32 %a
+}
+
+define i32 @extern_no_attr() {
+  ; Convergent attr shouldn't be added here, even though k is convergent.
+  ; CHECK: call i32 @k(){{$}}
+  %a = call i32 @k()
+  ret i32 %a
+}
+
+define i32 @no_extern() {
+  ; Convergent should be removed here, as the target is convergent.
+  ; CHECK: call i32 @f(){{$}}
+  %a = call i32 @f() convergent
+  ret i32 %a
+}
+
+define i32 @indirect_call(i32 ()* %f) {
+  ; CHECK call i32 %f() [[CONVERGENT_ATTR]]
+  %a = call i32 %f() convergent
+  ret i32 %a
+}
+
+; CHECK: [[CONVERGENT_ATTR]] = { convergent }