WholeProgramDevirt: Fix call target propagation for ptrauth architectures
authorArnold Schwaighofer <aschwaighofer@apple.com>
Wed, 28 Jun 2023 20:23:04 +0000 (13:23 -0700)
committerArnold Schwaighofer <aschwaighofer@apple.com>
Thu, 29 Jun 2023 15:02:58 +0000 (08:02 -0700)
We can't have a call with a constant target with a ptrauth bundle. Remove the
ptrauth bundle operand in such a case

rdar://105696396

Differential Revision: https://reviews.llvm.org/D144581

llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll [new file with mode: 0644]

index bdd0506..203846d 100644 (file)
@@ -567,6 +567,10 @@ struct DevirtModule {
   // optimize a call more than once.
   SmallPtrSet<CallBase *, 8> OptimizedCalls;
 
+  // Store calls that had their ptrauth bundle removed. They are to be deleted
+  // at the end of the optimization.
+  SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;
+
   // This map keeps track of the number of "unsafe" uses of a loaded function
   // pointer. The key is the associated llvm.type.test intrinsic call generated
   // by this pass. An unsafe use is one that calls the loaded function pointer
@@ -1165,6 +1169,14 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
         // !callees metadata.
         CB.setMetadata(LLVMContext::MD_prof, nullptr);
         CB.setMetadata(LLVMContext::MD_callees, nullptr);
+        if (CB.getCalledOperand() &&
+            CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
+          auto *NewCS =
+              CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB);
+          CB.replaceAllUsesWith(NewCS);
+          // Schedule for deletion at the end of pass run.
+          CallsWithPtrAuthBundleRemoved.push_back(&CB);
+        }
       }
 
       // This use is no longer unsafe.
@@ -2349,6 +2361,9 @@ bool DevirtModule::run() {
   for (GlobalVariable &GV : M.globals())
     GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
 
+  for (auto *CI : CallsWithPtrAuthBundleRemoved)
+    CI->eraseFromParent();
+
   return true;
 }
 
diff --git a/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll b/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-check-ptrauth.ll
new file mode 100644 (file)
index 0000000..b7a90d9
--- /dev/null
@@ -0,0 +1,39 @@
+; RUN: opt -S -passes=wholeprogramdevirt,verify -whole-program-visibility -pass-remarks=wholeprogramdevirt %s 2>&1 | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+; CHECK: remark: <unknown>:0:0: single-impl: devirtualized a call to vf
+; CHECK: remark: <unknown>:0:0: devirtualized vf
+; CHECK-NOT: devirtualized
+
+@vt1 = constant [1 x ptr] [ptr @vf], !type !0
+@vt2 = constant [1 x ptr] [ptr @vf], !type !0
+
+define void @vf(ptr %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(ptr %obj) {
+  %vtable = load ptr, ptr %obj
+  %pair = call {ptr, i1} @llvm.type.checked.load(ptr %vtable, i32 0, metadata !"typeid")
+  %fptr = extractvalue {ptr, i1} %pair, 0
+  %p = extractvalue {ptr, i1} %pair, 1
+  ; CHECK: br i1 true,
+  br i1 %p, label %cont, label %trap
+
+cont:
+  ; CHECK: call void @vf(
+  call void %fptr(ptr %obj) [ "ptrauth"(i32 5, i64 120) ]
+  ret void
+
+trap:
+  call void @llvm.trap()
+  unreachable
+}
+
+declare {ptr, i1} @llvm.type.checked.load(ptr, i32, metadata)
+declare void @llvm.trap()
+
+!0 = !{i32 0, !"typeid"}