[COFF] Store import symbol pointers as pointers to the base class
authorMartin Storsjo <martin@martin.st>
Tue, 10 Jul 2018 10:40:11 +0000 (10:40 +0000)
committerMartin Storsjo <martin@martin.st>
Tue, 10 Jul 2018 10:40:11 +0000 (10:40 +0000)
Future symbol insertions can potentially change the type of these
symbols - keep pointers to the base class to reflect this, and
use dynamic casts to inspect them before using as the subclass
type.

This fixes crashes that were possible before, by touching these
symbols that now are populated as e.g. a DefinedRegular, via
the old pointers with DefinedImportThunk type.

Differential Revision: https://reviews.llvm.org/D48953

llvm-svn: 336652

lld/COFF/InputFiles.cpp
lld/COFF/InputFiles.h
lld/COFF/SymbolTable.cpp
lld/COFF/SymbolTable.h
lld/COFF/Writer.cpp
lld/test/COFF/Inputs/otherFunc.s [new file with mode: 0644]
lld/test/COFF/thunk-replace.s [new file with mode: 0644]

index 9e2345b..8684321 100644 (file)
@@ -431,7 +431,8 @@ void ImportFile::parse() {
   // address pointed by the __imp_ symbol. (This allows you to call
   // DLL functions just like regular non-DLL functions.)
   if (Hdr->getType() == llvm::COFF::IMPORT_CODE)
-    ThunkSym = Symtab->addImportThunk(Name, ImpSym, Hdr->Machine);
+    ThunkSym = Symtab->addImportThunk(
+        Name, cast_or_null<DefinedImportData>(ImpSym), Hdr->Machine);
 }
 
 void BitcodeFile::parse() {
index 9f4db45..4ee4b36 100644 (file)
@@ -207,8 +207,8 @@ public:
 
   static std::vector<ImportFile *> Instances;
 
-  DefinedImportData *ImpSym = nullptr;
-  DefinedImportThunk *ThunkSym = nullptr;
+  Symbol *ImpSym = nullptr;
+  Symbol *ThunkSym = nullptr;
   std::string DLLName;
 
 private:
index a3d1fcd..b286d86 100644 (file)
@@ -342,30 +342,29 @@ Symbol *SymbolTable::addCommon(InputFile *F, StringRef N, uint64_t Size,
   return S;
 }
 
-DefinedImportData *SymbolTable::addImportData(StringRef N, ImportFile *F) {
+Symbol *SymbolTable::addImportData(StringRef N, ImportFile *F) {
   Symbol *S;
   bool WasInserted;
   std::tie(S, WasInserted) = insert(N);
   S->IsUsedInRegularObj = true;
   if (WasInserted || isa<Undefined>(S) || isa<Lazy>(S)) {
     replaceSymbol<DefinedImportData>(S, N, F);
-    return cast<DefinedImportData>(S);
+    return S;
   }
 
   reportDuplicate(S, F);
   return nullptr;
 }
 
-DefinedImportThunk *SymbolTable::addImportThunk(StringRef Name,
-                                               DefinedImportData *ID,
-                                               uint16_t Machine) {
+Symbol *SymbolTable::addImportThunk(StringRef Name, DefinedImportData *ID,
+                                    uint16_t Machine) {
   Symbol *S;
   bool WasInserted;
   std::tie(S, WasInserted) = insert(Name);
   S->IsUsedInRegularObj = true;
   if (WasInserted || isa<Undefined>(S) || isa<Lazy>(S)) {
     replaceSymbol<DefinedImportThunk>(S, Name, ID, Machine);
-    return cast<DefinedImportThunk>(S);
+    return S;
   }
 
   reportDuplicate(S, ID->File);
index 55481e6..30cb1a5 100644 (file)
@@ -92,9 +92,9 @@ public:
   Symbol *addCommon(InputFile *F, StringRef N, uint64_t Size,
                     const llvm::object::coff_symbol_generic *S = nullptr,
                     CommonChunk *C = nullptr);
-  DefinedImportData *addImportData(StringRef N, ImportFile *F);
-  DefinedImportThunk *addImportThunk(StringRef Name, DefinedImportData *S,
-                                     uint16_t Machine);
+  Symbol *addImportData(StringRef N, ImportFile *F);
+  Symbol *addImportThunk(StringRef Name, DefinedImportData *S,
+                         uint16_t Machine);
 
   void reportDuplicate(Symbol *Existing, InputFile *NewFile);
 
index c6e17ee..e9b21df 100644 (file)
@@ -544,17 +544,24 @@ void Writer::createImportTables() {
     if (Config->DLLOrder.count(DLL) == 0)
       Config->DLLOrder[DLL] = Config->DLLOrder.size();
 
-    if (DefinedImportThunk *Thunk = File->ThunkSym)
+    if (File->ThunkSym) {
+      if (!isa<DefinedImportThunk>(File->ThunkSym))
+        fatal(toString(*File->ThunkSym) + " was replaced");
+      DefinedImportThunk *Thunk = cast<DefinedImportThunk>(File->ThunkSym);
       if (File->ThunkLive)
         TextSec->addChunk(Thunk->getChunk());
+    }
 
+    if (File->ImpSym && !isa<DefinedImportData>(File->ImpSym))
+      fatal(toString(*File->ImpSym) + " was replaced");
+    DefinedImportData *ImpSym = cast_or_null<DefinedImportData>(File->ImpSym);
     if (Config->DelayLoads.count(StringRef(File->DLLName).lower())) {
       if (!File->ThunkSym)
         fatal("cannot delay-load " + toString(File) +
-              " due to import of data: " + toString(*File->ImpSym));
-      DelayIdata.add(File->ImpSym);
+              " due to import of data: " + toString(*ImpSym));
+      DelayIdata.add(ImpSym);
     } else {
-      Idata.add(File->ImpSym);
+      Idata.add(ImpSym);
     }
   }
 
diff --git a/lld/test/COFF/Inputs/otherFunc.s b/lld/test/COFF/Inputs/otherFunc.s
new file mode 100644 (file)
index 0000000..ae8b922
--- /dev/null
@@ -0,0 +1,7 @@
+.global otherFunc
+.global MessageBoxA
+.text
+otherFunc:
+  ret
+MessageBoxA:
+  ret
diff --git a/lld/test/COFF/thunk-replace.s b/lld/test/COFF/thunk-replace.s
new file mode 100644 (file)
index 0000000..2d47fcc
--- /dev/null
@@ -0,0 +1,15 @@
+# REQUIRES: x86
+
+# RUN: llvm-mc -triple=x86_64-win32 %s -filetype=obj -o %t.main.obj
+# RUN: llvm-mc -triple=x86_64-win32 %p/Inputs/otherFunc.s -filetype=obj -o %t.other.obj
+# RUN: llvm-ar rcs %t.other.lib %t.other.obj
+# RUN: not lld-link -out:%t.exe -entry:main %t.main.obj %p/Inputs/std64.lib %t.other.lib -opt:noref 2>&1 | FileCheck %s
+# CHECK: MessageBoxA was replaced
+
+.global main
+.text
+main:
+  callq MessageBoxA
+  callq ExitProcess
+  callq otherFunc
+  ret