return MadeChange;
}
-/// Removes convergent attributes where we can prove that none of the SCC's
-/// callees are themselves convergent. Returns true if successful at removing
-/// the attribute.
+/// Remove the convergent attribute from all functions in the SCC if every
+/// callsite within the SCC is not convergent (except for calls to functions
+/// within the SCC). Returns true if changes were made.
static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) {
- // Determines whether a function can be made non-convergent, ignoring all
- // other functions in SCC. (A function can *actually* be made non-convergent
- // only if all functions in its SCC can be made convergent.)
- auto CanRemoveConvergent = [&](Function *F) {
- if (!F->isConvergent())
- return true;
-
- // Can't remove convergent from declarations.
- if (F->isDeclaration())
- return false;
-
- for (Instruction &I : instructions(*F))
- if (auto CS = CallSite(&I)) {
- // Can't remove convergent if any of F's callees -- ignoring functions
- // in the SCC itself -- are convergent. This needs to consider both
- // function calls and intrinsic calls. We also assume indirect calls
- // might call a convergent function.
- // FIXME: We should revisit this when we put convergent onto calls
- // instead of functions so that indirect calls which should be
- // convergent are required to be marked as such.
- Function *Callee = CS.getCalledFunction();
- if (!Callee || (SCCNodes.count(Callee) == 0 && Callee->isConvergent()))
- return false;
- }
-
- return true;
- };
+ // For every function in SCC, ensure that either
+ // * it is not convergent, or
+ // * we can remove its convergent attribute.
+ bool HasConvergentFn = false;
+ for (Function *F : SCCNodes) {
+ if (!F->isConvergent()) continue;
+ HasConvergentFn = true;
+
+ // Can't remove convergent from function declarations.
+ if (F->isDeclaration()) return false;
+
+ // Can't remove convergent if any of our functions has a convergent call to a
+ // function not in the SCC.
+ for (Instruction &I : instructions(*F)) {
+ CallSite CS(&I);
+ // Bail if CS is a convergent call to a function not in the SCC.
+ if (CS && CS.isConvergent() &&
+ SCCNodes.count(CS.getCalledFunction()) == 0)
+ return false;
+ }
+ }
- // We can remove the convergent attr from functions in the SCC if they all
- // can be made non-convergent (because they call only non-convergent
- // functions, other than each other).
- if (!llvm::all_of(SCCNodes, CanRemoveConvergent))
- return false;
+ // If the SCC doesn't have any convergent functions, we have nothing to do.
+ if (!HasConvergentFn) return false;
- // If we got here, all of the SCC's callees are non-convergent. Therefore all
- // of the SCC's functions can be marked as non-convergent.
+ // If we got here, all of the calls the SCC makes to functions not in the SCC
+ // are non-convergent. Therefore all of the SCC's functions can also be made
+ // non-convergent. We'll remove the attr from the callsites in
+ // InstCombineCalls.
for (Function *F : SCCNodes) {
- if (F->isConvergent())
- DEBUG(dbgs() << "Removing convergent attr from " << F->getName() << "\n");
+ if (!F->isConvergent()) continue;
+
+ DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName()
+ << "\n");
F->setNotConvergent();
}
return true;
if (!isa<Function>(Callee) && transformConstExprCastCall(CS))
return nullptr;
- if (Function *CalleeF = dyn_cast<Function>(Callee))
+ if (Function *CalleeF = dyn_cast<Function>(Callee)) {
+ // Remove the convergent attr on calls when the callee is not convergent.
+ if (CS.isConvergent() && !CalleeF->isConvergent()) {
+ DEBUG(dbgs() << "Removing convergent attr from instr "
+ << CS.getInstruction() << "\n");
+ CS.setNotConvergent();
+ return CS.getInstruction();
+ }
+
// If the call and callee calling conventions don't match, this call must
// be unreachable, as the call is undefined.
if (CalleeF->getCallingConv() != CS.getCallingConv() &&
Constant::getNullValue(CalleeF->getType()));
return nullptr;
}
+ }
if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
// If CS does not return void then replaceAllUsesWith undef.
-; RUN: opt < %s -basicaa -functionattrs -rpo-functionattrs -S | FileCheck %s
+; RUN: opt -functionattrs -S < %s | FileCheck %s
; CHECK: Function Attrs
; CHECK-NOT: convergent
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @extern()
define i32 @extern() convergent {
- %a = call i32 @k()
+ %a = call i32 @k() convergent
ret i32 %a
}
+; Convergent should not be removed on the function here. Although the call is
+; not explicitly convergent, it picks up the convergent attr from the callee.
+;
; CHECK: Function Attrs
; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @call_extern()
-define i32 @call_extern() convergent {
- %a = call i32 @extern()
+; CHECK-NEXT: define i32 @extern_non_convergent_call()
+define i32 @extern_non_convergent_call() convergent {
+ %a = call i32 @k()
ret i32 %a
}
; CHECK: Function Attrs
; CHECK-SAME: convergent
+; CHECK-NEXT: define i32 @indirect_convergent_call(
+define i32 @indirect_convergent_call(i32 ()* %f) convergent {
+ %a = call i32 %f() convergent
+ ret i32 %a
+}
+; Give indirect_non_convergent_call the norecurse attribute so we get a
+; "Function Attrs" comment in the output.
+;
+; CHECK: Function Attrs
+; CHECK-NOT: convergent
+; CHECK-NEXT: define i32 @indirect_non_convergent_call(
+define i32 @indirect_non_convergent_call(i32 ()* %f) convergent norecurse {
+ %a = call i32 %f()
+ ret i32 %a
+}
+
+; CHECK: Function Attrs
+; CHECK-SAME: convergent
; CHECK-NEXT: declare void @llvm.cuda.syncthreads()
declare void @llvm.cuda.syncthreads() convergent
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @intrinsic()
define i32 @intrinsic() convergent {
+ ; Implicitly convergent, because the intrinsic is convergent.
call void @llvm.cuda.syncthreads()
ret i32 0
}
-@xyz = global i32 ()* null
-; CHECK: Function Attrs
-; CHECK-SAME: convergent
-; CHECK-NEXT: define i32 @functionptr()
-define i32 @functionptr() convergent {
- %1 = load i32 ()*, i32 ()** @xyz
- %2 = call i32 %1()
- ret i32 %2
-}
-
; CHECK: Function Attrs
; CHECK-NOT: convergent
; CHECK-NEXT: define i32 @recursive1()
define i32 @recursive1() convergent {
- %a = call i32 @recursive2()
+ %a = call i32 @recursive2() convergent
ret i32 %a
}
; CHECK-NOT: convergent
; CHECK-NEXT: define i32 @recursive2()
define i32 @recursive2() convergent {
- %a = call i32 @recursive1()
+ %a = call i32 @recursive1() convergent
ret i32 %a
}
; CHECK-SAME: convergent
; CHECK-NEXT: define i32 @noopt()
define i32 @noopt() convergent optnone noinline {
- %a = call i32 @noopt_friend()
+ %a = call i32 @noopt_friend() convergent
ret i32 0
}
--- /dev/null
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+declare i32 @k() convergent
+declare i32 @f()
+
+define i32 @extern() {
+ ; Convergent attr shouldn't be removed here; k is convergent.
+ ; CHECK: call i32 @k() [[CONVERGENT_ATTR:#[0-9]+]]
+ %a = call i32 @k() convergent
+ ret i32 %a
+}
+
+define i32 @extern_no_attr() {
+ ; Convergent attr shouldn't be added here, even though k is convergent.
+ ; CHECK: call i32 @k(){{$}}
+ %a = call i32 @k()
+ ret i32 %a
+}
+
+define i32 @no_extern() {
+ ; Convergent should be removed here, as the target is not convergent.
+ ; CHECK: call i32 @f(){{$}}
+ %a = call i32 @f() convergent
+ ret i32 %a
+}
+
+define i32 @indirect_call(i32 ()* %f) {
+ ; CHECK: call i32 %f() [[CONVERGENT_ATTR]]
+ %a = call i32 %f() convergent
+ ret i32 %a
+}
+
+; CHECK: [[CONVERGENT_ATTR]] = { convergent }