[WPD] Extend checking mode to support fallback to indirect call
authorTeresa Johnson <tejohnson@google.com>
Thu, 10 Mar 2022 22:17:53 +0000 (14:17 -0800)
committerTeresa Johnson <tejohnson@google.com>
Mon, 14 Mar 2022 17:16:28 +0000 (10:16 -0700)
Extend -wholeprogramdevirt-check to support both the existing
trapping mode on an incorrect devirtualization, as well as a new
mode to fallback to an indirect call on a mismatch. The new mode is

The new mode is useful in cases where we want to enable
devirtualization but cannot fully guarantee whole program visibility
(e.g in the case where LTO has been disabled for a small set of objects
that could potentially override virtual methods without having a symbol
reference to anything in the base class including the vtable).

Remove !prof and !callees metadata (which are used by indirect call
promotion) from both the new direct call and the fallback indirect call
(so that we don't perform another round of promotion on the latter).
Also remove it from the direct call in the non-fallback cases, which
was an oversight, although it didn't seem to cause any issues. Add tests
for the metadata removal covering the various cases.

Differential Revision: https://reviews.llvm.org/D121419

llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
llvm/test/ThinLTO/X86/devirt.ll
llvm/test/ThinLTO/X86/devirt_check.ll

index daa8898..fcb384e 100644 (file)
@@ -19,6 +19,7 @@ class CallBase;
 class CastInst;
 class Function;
 class MDNode;
+class Value;
 
 /// Return true if the given indirect call site can be made to call \p Callee.
 ///
@@ -73,6 +74,15 @@ CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
 ///
 bool tryPromoteCall(CallBase &CB);
 
+/// Predicate and clone the given call site.
+///
+/// This function creates an if-then-else structure at the location of the call
+/// site. The "if" condition compares the call site's called value to the given
+/// callee. The original call site is moved into the "else" block, and a clone
+/// of the call site is placed in the "then" block. The cloned instruction is
+/// returned.
+CallBase &versionCallSite(CallBase &CB, Value *Callee, MDNode *BranchWeights);
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
index fab080f..d4b669e 100644 (file)
@@ -79,6 +79,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ModuleSummaryIndexYAML.h"
@@ -95,6 +96,7 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/IPO/FunctionAttrs.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/Transforms/Utils/Evaluator.h"
 #include <algorithm>
 #include <cstddef>
@@ -163,13 +165,19 @@ static cl::list<std::string>
                       cl::desc("Prevent function(s) from being devirtualized"),
                       cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated);
 
-/// Mechanism to add runtime checking of devirtualization decisions, trapping on
-/// any that are not correct. Useful for debugging undefined behavior leading to
-/// failures with WPD.
-static cl::opt<bool>
-    CheckDevirt("wholeprogramdevirt-check", cl::init(false), cl::Hidden,
-                cl::ZeroOrMore,
-                cl::desc("Add code to trap on incorrect devirtualizations"));
+/// Mechanism to add runtime checking of devirtualization decisions, optionally
+/// trapping or falling back to indirect call on any that are not correct.
+/// Trapping mode is useful for debugging undefined behavior leading to failures
+/// with WPD. Fallback mode is useful for ensuring safety when whole program
+/// visibility may be compromised.
+enum WPDCheckMode { None, Trap, Fallback };
+static cl::opt<WPDCheckMode> DevirtCheckMode(
+    "wholeprogramdevirt-check", cl::Hidden, cl::ZeroOrMore,
+    cl::desc("Type of checking for incorrect devirtualizations"),
+    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
+               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
+               clEnumValN(WPDCheckMode::Fallback, "fallback",
+                          "Fallback to indirect when incorrect")));
 
 namespace {
 struct PatternList {
@@ -1140,10 +1148,10 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
       Value *Callee =
           Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());
 
-      // If checking is enabled, add support to compare the virtual function
-      // pointer to the devirtualized target. In case of a mismatch, perform a
-      // debug trap.
-      if (CheckDevirt) {
+      // If trap checking is enabled, add support to compare the virtual
+      // function pointer to the devirtualized target. In case of a mismatch,
+      // perform a debug trap.
+      if (DevirtCheckMode == WPDCheckMode::Trap) {
         auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
         Instruction *ThenTerm =
             SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false);
@@ -1153,8 +1161,38 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
         CallTrap->setDebugLoc(CB.getDebugLoc());
       }
 
-      // Devirtualize.
-      CB.setCalledOperand(Callee);
+      // If fallback checking is enabled, add support to compare the virtual
+      // function pointer to the devirtualized target. In case of a mismatch,
+      // fall back to indirect call.
+      if (DevirtCheckMode == WPDCheckMode::Fallback) {
+        MDNode *Weights =
+            MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
+        // Version the indirect call site. If the called value is equal to the
+        // given callee, 'NewInst' will be executed, otherwise the original call
+        // site will be executed.
+        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
+        NewInst.setCalledOperand(Callee);
+        // Since the new call site is direct, we must clear metadata that
+        // is only appropriate for indirect calls. This includes !prof and
+        // !callees metadata.
+        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
+        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
+        // Additionally, we should remove them from the fallback indirect call,
+        // so that we don't attempt to perform indirect call promotion later.
+        CB.setMetadata(LLVMContext::MD_prof, nullptr);
+        CB.setMetadata(LLVMContext::MD_callees, nullptr);
+      }
+
+      // In either trapping or non-checking mode, devirtualize original call.
+      else {
+        // Devirtualize unconditionally.
+        CB.setCalledOperand(Callee);
+        // Since the call site is now direct, we must clear metadata that
+        // is only appropriate for indirect calls. This includes !prof and
+        // !callees metadata.
+        CB.setMetadata(LLVMContext::MD_prof, nullptr);
+        CB.setMetadata(LLVMContext::MD_callees, nullptr);
+      }
 
       // This use is no longer unsafe.
       if (VCallSite.NumUnsafeUses)
index 56b6e4b..e530afc 100644 (file)
@@ -279,8 +279,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
 ///     ; The original call instruction stays in its original block.
 ///     %t0 = musttail call i32 %ptr()
 ///     ret %t0
-static CallBase &versionCallSite(CallBase &CB, Value *Callee,
-                                 MDNode *BranchWeights) {
+CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee,
+                                MDNode *BranchWeights) {
 
   IRBuilder<> Builder(&CB);
   CallBase *OrigInst = &CB;
index 66adec0..9ba1dc2 100644 (file)
@@ -154,7 +154,10 @@ entry:
 
   ; Check that the call was devirtualized.
   ; CHECK-IR: %call = tail call i32 @_ZN1A1nEi
-  %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a)
+  ; Ensure !prof and !callees metadata for indirect call promotion removed.
+  ; CHECK-IR-NOT: prof
+  ; CHECK-IR-NOT: callees
+  %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a), !prof !5, !callees !6
 
   %3 = bitcast i8** %vtable to i32 (%struct.A*, i32)**
   %fptr22 = load i32 (%struct.A*, i32)*, i32 (%struct.A*, i32)** %3, align 8
@@ -207,3 +210,5 @@ attributes #0 = { noinline optnone }
 !2 = !{i64 16, !"_ZTS1C"}
 !3 = !{i64 16, !4}
 !4 = distinct !{}
+!5 = !{!"VP", i32 0, i64 1, i64 1621563287929432257, i64 1}
+!6 = !{i32 (%struct.A*, i32)* @_ZN1A1nEi}
index 0ede1e1..a16c828 100644 (file)
@@ -1,21 +1,33 @@
 ; REQUIRES: x86-registered-target
 
 ; Test that devirtualization option -wholeprogramdevirt-check adds code to check
-; that the devirtualization decision was correct and trap if not.
+; that the devirtualization decision was correct and trap or fallback if not.
 
 ; The vtables have vcall_visibility metadata with hidden visibility, to enable
 ; devirtualization.
 
 ; Generate unsplit module with summary for ThinLTO index-based WPD.
 ; RUN: opt -thinlto-bc -o %t2.o %s
+
+; Check first in trapping mode.
 ; RUN: llvm-lto2 run %t2.o -save-temps -use-new-pm -pass-remarks=. \
-; RUN:  -wholeprogramdevirt-check \
+; RUN:  -wholeprogramdevirt-check=trap \
 ; RUN:   -o %t3 \
 ; RUN:   -r=%t2.o,test,px \
 ; RUN:   -r=%t2.o,_ZN1A1nEi,p \
 ; RUN:   -r=%t2.o,_ZN1B1fEi,p \
 ; RUN:   -r=%t2.o,_ZTV1B,px 2>&1 | FileCheck %s --check-prefix=REMARK
-; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK-IR
+; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK --check-prefix=TRAP
+
+; Check next in fallback mode.
+; RUN: llvm-lto2 run %t2.o -save-temps -use-new-pm -pass-remarks=. \
+; RUN:  -wholeprogramdevirt-check=fallback \
+; RUN:   -o %t3 \
+; RUN:   -r=%t2.o,test,px \
+; RUN:   -r=%t2.o,_ZN1A1nEi,p \
+; RUN:   -r=%t2.o,_ZN1B1fEi,p \
+; RUN:   -r=%t2.o,_ZTV1B,px 2>&1 | FileCheck %s --check-prefix=REMARK
+; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s --check-prefix=CHECK --check-prefix=FALLBACK
 
 ; REMARK-DAG: single-impl: devirtualized a call to _ZN1A1nEi
 
@@ -28,7 +40,7 @@ target triple = "x86_64-grtev4-linux-gnu"
 @_ZTV1B = constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* undef, i8* bitcast (i32 (%struct.B*, i32)* @_ZN1B1fEi to i8*), i8* bitcast (i32 (%struct.A*, i32)* @_ZN1A1nEi to i8*)] }, !type !0, !type !1, !vcall_visibility !5
 
 
-; CHECK-IR-LABEL: define i32 @test
+; CHECK-LABEL: define i32 @test
 define i32 @test(%struct.A* %obj, i32 %a) {
 entry:
   %0 = bitcast %struct.A* %obj to i8***
@@ -42,19 +54,40 @@ entry:
 
   ; Check that the call was devirtualized, but preceeded by a check guarding
   ; a trap if the function pointer doesn't match.
-  ; CHECK-IR:   %.not = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi
-  ; CHECK-IR:   br i1 %.not, label %3, label %2
-  ; CHECK-IR: 2:
-  ; CHECK-IR:   tail call void @llvm.debugtrap()
-  ; CHECK-IR:   br label %3
-  ; CHECK-IR: 3:
-  ; CHECK-IR:   tail call i32 @_ZN1A1nEi
-  %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a)
+  ; TRAP:   %.not = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi
+  ; Ensure !prof and !callees metadata for indirect call promotion removed.
+  ; TRAP-NOT: prof
+  ; TRAP-NOT: callees
+  ; TRAP:   br i1 %.not, label %3, label %2
+  ; TRAP: 2:
+  ; TRAP:   tail call void @llvm.debugtrap()
+  ; TRAP:   br label %3
+  ; TRAP: 3:
+  ; TRAP:   tail call i32 @_ZN1A1nEi
+  ; Check that the call was devirtualized, but preceeded by a check guarding
+  ; a fallback if the function pointer doesn't match.
+  ; FALLBACK:   %2 = icmp eq i32 (%struct.A*, i32)* %fptr1, @_ZN1A1nEi
+  ; FALLBACK:   br i1 %2, label %if.true.direct_targ, label %if.false.orig_indirect
+  ; FALLBACK: if.true.direct_targ:
+  ; FALLBACK:   tail call i32 @_ZN1A1nEi
+  ; Ensure !prof and !callees metadata for indirect call promotion removed.
+  ; FALLBACK-NOT: prof
+  ; FALLBACK-NOT: callees
+  ; FALLBACK:   br label %if.end.icp
+  ; FALLBACK: if.false.orig_indirect:
+  ; FALLBACK:   tail call i32 %fptr1
+  ; Ensure !prof and !callees metadata for indirect call promotion removed.
+  ; In particular, if left on the fallback indirect call ICP may perform an
+  ; additional round of promotion.
+  ; FALLBACK-NOT: prof
+  ; FALLBACK-NOT: callees
+  ; FALLBACK:   br label %if.end.icp
+  %call = tail call i32 %fptr1(%struct.A* nonnull %obj, i32 %a), !prof !6, !callees !7
 
   ret i32 %call
 }
-; CHECK-IR-LABEL:   ret i32
-; CHECK-IR-LABEL: }
+; CHECK-LABEL:   ret i32
+; CHECK-LABEL: }
 
 declare i1 @llvm.type.test(i8*, metadata)
 declare void @llvm.assume(i1)
@@ -75,3 +108,5 @@ attributes #0 = { noinline optnone }
 !3 = !{i64 16, !4}
 !4 = distinct !{}
 !5 = !{i64 1}
+!6 = !{!"VP", i32 0, i64 1, i64 1621563287929432257, i64 1}
+!7 = !{i32 (%struct.A*, i32)* @_ZN1A1nEi}