Make WholeProgramDevirt understand ConstStruct vtables.
authorPeter Collingbourne <peter@pcc.me.uk>
Fri, 9 Dec 2016 00:33:27 +0000 (00:33 +0000)
committerPeter Collingbourne <peter@pcc.me.uk>
Fri, 9 Dec 2016 00:33:27 +0000 (00:33 +0000)
Based on a patch by LemonBoy!

Differential Revision: https://reviews.llvm.org/D26581

llvm-svn: 289162

llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
llvm/test/Transforms/WholeProgramDevirt/non-aggregate-vtable.ll [moved from llvm/test/Transforms/WholeProgramDevirt/non-array-vtable.ll with 100% similarity]
llvm/test/Transforms/WholeProgramDevirt/struct-vtable.ll [new file with mode: 0644]

index 7ef5f24..9c80a2a 100644 (file)
@@ -293,6 +293,7 @@ struct DevirtModule {
   void buildTypeIdentifierMap(
       std::vector<VTableBits> &Bits,
       DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
+  Constant *getValueAtOffset(Constant *I, uint64_t Offset);
   bool
   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                             const std::set<TypeMemberInfo> &TypeMemberInfos,
@@ -382,6 +383,38 @@ void DevirtModule::buildTypeIdentifierMap(
   }
 }
 
+Constant *DevirtModule::getValueAtOffset(Constant *I, uint64_t Offset) {
+  const DataLayout &DL = M.getDataLayout();
+  unsigned Op;
+
+  if (auto *C = dyn_cast<ConstantStruct>(I)) {
+    const StructLayout *SL = DL.getStructLayout(C->getType());
+
+    if (Offset >= SL->getSizeInBytes())
+      return nullptr;
+
+    Op = SL->getElementContainingOffset(Offset);
+
+    if (Offset != SL->getElementOffset(Op))
+      return nullptr;
+
+  } else if (auto *C = dyn_cast<ConstantArray>(I)) {
+    ArrayType *VTableTy = C->getType();
+    uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
+
+    if (Offset % ElemSize != 0)
+      return nullptr;
+
+    Op = Offset / ElemSize;
+
+    if (Op >= C->getNumOperands())
+      return nullptr;
+  } else
+    return nullptr;
+
+  return cast<Constant>(I->getOperand(Op));
+}
+
 bool DevirtModule::tryFindVirtualCallTargets(
     std::vector<VirtualCallTarget> &TargetsForSlot,
     const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
@@ -389,22 +422,13 @@ bool DevirtModule::tryFindVirtualCallTargets(
     if (!TM.Bits->GV->isConstant())
       return false;
 
-    auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer());
-    if (!Init)
-      return false;
-    ArrayType *VTableTy = Init->getType();
-
-    uint64_t ElemSize =
-        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
-    uint64_t GlobalSlotOffset = TM.Offset + ByteOffset;
-    if (GlobalSlotOffset % ElemSize != 0)
-      return false;
+    Constant *I = TM.Bits->GV->getInitializer();
+    Value *V = getValueAtOffset(I, TM.Offset + ByteOffset);
 
-    unsigned Op = GlobalSlotOffset / ElemSize;
-    if (Op >= Init->getNumOperands())
+    if (!V)
       return false;
 
-    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
+    auto Fn = dyn_cast<Function>(V->stripPointerCasts());
     if (!Fn)
       return false;
 
diff --git a/llvm/test/Transforms/WholeProgramDevirt/struct-vtable.ll b/llvm/test/Transforms/WholeProgramDevirt/struct-vtable.ll
new file mode 100644 (file)
index 0000000..81e41d4
--- /dev/null
@@ -0,0 +1,63 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+%vtTy = type { void (i8*)* }
+
+@vt = constant %vtTy { void (i8*)* @vf }, !type !0
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0
+  %fptr = load i8*, i8** %fptrptr
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void @vf(
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+; CHECK: define void @call_oob
+define void @call_oob(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 4
+  %fptr = load i8*, i8** %fptrptr
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+; CHECK: define void @call_unaligned
+define void @call_unaligned(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr i8, i8* %vtablei8, i32 1
+  %fptrptr_casted = bitcast i8* %fptrptr to i8**
+  %fptr = load i8*, i8** %fptrptr_casted
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.type.test(i8*, metadata)
+declare void @llvm.assume(i1)
+
+!0 = !{i32 0, !"typeid"}