[BOLT] Add pass to fix ambiguous memory references
authorRafael Auler <rafaelauler@fb.com>
Sat, 27 Aug 2022 01:39:36 +0000 (18:39 -0700)
committerRafael Auler <rafaelauler@fb.com>
Thu, 13 Oct 2022 01:39:50 +0000 (18:39 -0700)
This adds a round of checks to memory references, looking for
incorrect references to jump table objects. Fix them by replacing the
jump table reference with another object reference + offset.

This solves bugs related to regular data references in code
accidentally being bound to a jump table, and this reference being
updated to a new (incorrect) location because we moved this jump
table.

Fixes #55004

Reviewed By: #bolt, maksfb

Differential Revision: https://reviews.llvm.org/D134098

bolt/include/bolt/Passes/ValidateMemRefs.h [new file with mode: 0644]
bolt/lib/Core/BinaryEmitter.cpp
bolt/lib/Passes/CMakeLists.txt
bolt/lib/Passes/ValidateMemRefs.cpp [new file with mode: 0644]
bolt/lib/Rewrite/BinaryPassManager.cpp
bolt/test/runtime/X86/jt-symbol-disambiguation-2.s [new file with mode: 0644]
bolt/test/runtime/X86/jt-symbol-disambiguation.s [new file with mode: 0644]

diff --git a/bolt/include/bolt/Passes/ValidateMemRefs.h b/bolt/include/bolt/Passes/ValidateMemRefs.h
new file mode 100644 (file)
index 0000000..d33862c
--- /dev/null
@@ -0,0 +1,41 @@
+//===- bolt/Passes/ValidateMemRefs.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BOLT_PASSES_VALIDATEMEMREFS_H
+#define BOLT_PASSES_VALIDATEMEMREFS_H
+
+#include "bolt/Passes/BinaryPasses.h"
+
+namespace llvm::bolt {
+
+/// Post processing to check for memory references that cause a symbol
+/// in data section to be ambiguous, requiring us to avoid moving that
+/// object or disambiguating such references. This is currently
+/// limited to fixing false references to the location of jump tables.
+///
+class ValidateMemRefs : public BinaryFunctionPass {
+public:
+  explicit ValidateMemRefs(const cl::opt<bool> &PrintPass)
+      : BinaryFunctionPass(PrintPass) {}
+
+  const char *getName() const override { return "validate-mem-refs"; }
+
+  void runOnFunctions(BinaryContext &BC) override;
+
+private:
+  bool checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst,
+                              uint32_t OperandNum, const MCSymbol *Sym,
+                              uint64_t Offset);
+  void runOnFunction(BinaryFunction &BF);
+
+  static std::atomic<std::uint64_t> ReplacedReferences;
+};
+
+} // namespace llvm::bolt
+
+#endif
index c4f887d..f618619 100644 (file)
@@ -815,7 +815,25 @@ void BinaryEmitter::emitJumpTable(const JumpTable &JT, MCSection *HotSection,
           Streamer.switchSection(ColdSection);
         Streamer.emitValueToAlignment(JT.EntrySize);
       }
-      Streamer.emitLabel(LI->second);
+      // Emit all labels registered at the address of this jump table
+      // to sync with our global symbol table.  We may have two labels
+      // registered at this address if one label was created via
+      // getOrCreateGlobalSymbol() (e.g. LEA instructions referencing
+      // this location) and another via getOrCreateJumpTable().  This
+      // creates a race where the symbols created by these two
+      // functions may or may not be the same, but they are both
+      // registered in our symbol table at the same address. By
+      // emitting them all here we make sure there is no ambiguity
+      // that depends on the order that these symbols were created, so
+      // whenever this address is referenced in the binary, it is
+      // certain to point to the jump table identified at this
+      // address.
+      if (BinaryData *BD = BC.getBinaryDataByName(LI->second->getName())) {
+        for (MCSymbol *S : BD->getSymbols())
+          Streamer.emitLabel(S);
+      } else {
+        Streamer.emitLabel(LI->second);
+      }
       LastLabel = LI->second;
     }
     if (JT.Type == JumpTable::JTT_NORMAL) {
index 544b6df..7f67261 100644 (file)
@@ -42,6 +42,7 @@ add_llvm_library(LLVMBOLTPasses
   TailDuplication.cpp
   ThreeWayBranch.cpp
   ValidateInternalCalls.cpp
+  ValidateMemRefs.cpp
   VeneerElimination.cpp
   RetpolineInsertion.cpp
 
diff --git a/bolt/lib/Passes/ValidateMemRefs.cpp b/bolt/lib/Passes/ValidateMemRefs.cpp
new file mode 100644 (file)
index 0000000..69cd82a
--- /dev/null
@@ -0,0 +1,104 @@
+//===- bolt/Passes/ValidateMemRefs.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "bolt/Passes/ValidateMemRefs.h"
+#include "bolt/Core/ParallelUtilities.h"
+
+#define DEBUG_TYPE "bolt-memrefs"
+
+namespace opts {
+extern llvm::cl::opt<llvm::bolt::JumpTableSupportLevel> JumpTables;
+}
+
+namespace llvm::bolt {
+
+std::atomic<std::uint64_t> ValidateMemRefs::ReplacedReferences{0};
+
+bool ValidateMemRefs::checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst,
+                                             uint32_t OperandNum,
+                                             const MCSymbol *Sym,
+                                             uint64_t Offset) {
+  BinaryContext &BC = BF.getBinaryContext();
+  auto L = BC.scopeLock();
+  BinaryData *BD = BC.getBinaryDataByName(Sym->getName());
+  if (!BD)
+    return false;
+
+  const uint64_t TargetAddress = BD->getAddress() + Offset;
+  JumpTable *JT = BC.getJumpTableContainingAddress(TargetAddress);
+  if (!JT)
+    return false;
+
+  const bool IsLegitAccess = llvm::any_of(
+      JT->Parents, [&](const BinaryFunction *Parent) { return Parent == &BF; });
+  if (IsLegitAccess)
+    return true;
+
+  // Accessing a jump table in another function. This is not a
+  // legitimate jump table access, we need to replace the reference to
+  // the jump table label with a regular rodata reference. Get a
+  // non-JT reference by fetching the symbol 1 byte before the JT
+  // label.
+  MCSymbol *NewSym = BC.getOrCreateGlobalSymbol(TargetAddress - 1, "DATAat");
+  BC.MIB->setOperandToSymbolRef(Inst, OperandNum, NewSym, 1, &*BC.Ctx, 0);
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: replaced reference @" << BF.getPrintName()
+                    << " from " << BD->getName() << " to " << NewSym->getName()
+                    << " + 1\n");
+  ++ReplacedReferences;
+  return true;
+}
+
+void ValidateMemRefs::runOnFunction(BinaryFunction &BF) {
+  MCPlusBuilder *MIB = BF.getBinaryContext().MIB.get();
+
+  for (BinaryBasicBlock &BB : BF) {
+    for (MCInst &Inst : BB) {
+      for (int I = 0, E = MCPlus::getNumPrimeOperands(Inst); I != E; ++I) {
+        const MCOperand &Operand = Inst.getOperand(I);
+        if (!Operand.isExpr())
+          continue;
+
+        const auto [Sym, Offset] = MIB->getTargetSymbolInfo(Operand.getExpr());
+        if (!Sym)
+          continue;
+
+        checkAndFixJTReference(BF, Inst, I, Sym, Offset);
+      }
+    }
+  }
+}
+
+void ValidateMemRefs::runOnFunctions(BinaryContext &BC) {
+  if (!BC.isX86())
+    return;
+
+  // Skip validation if not moving JT
+  if (opts::JumpTables == JTS_NONE || opts::JumpTables == JTS_BASIC)
+    return;
+
+  ParallelUtilities::WorkFuncWithAllocTy ProcessFunction =
+      [&](BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId) {
+        runOnFunction(BF);
+      };
+  ParallelUtilities::PredicateTy SkipPredicate = [&](const BinaryFunction &BF) {
+    return !BF.hasCFG();
+  };
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: starting memrefs validation pass\n");
+  ParallelUtilities::runOnEachFunctionWithUniqueAllocId(
+      BC, ParallelUtilities::SchedulingPolicy::SP_INST_LINEAR, ProcessFunction,
+      SkipPredicate, "validate-mem-refs", /*ForceSequential=*/true);
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: memrefs validation is concluded\n");
+
+  if (!ReplacedReferences)
+    return;
+
+  outs() << "BOLT-INFO: validate-mem-refs updated " << ReplacedReferences
+         << " object references\n";
+}
+
+} // namespace llvm::bolt
index dff93b9..85b49a2 100644 (file)
@@ -31,6 +31,7 @@
 #include "bolt/Passes/TailDuplication.h"
 #include "bolt/Passes/ThreeWayBranch.h"
 #include "bolt/Passes/ValidateInternalCalls.h"
+#include "bolt/Passes/ValidateMemRefs.h"
 #include "bolt/Passes/VeneerElimination.h"
 #include "bolt/Utils/CommandLineOpts.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -317,9 +318,6 @@ void BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) {
     Manager.registerPass(
         std::make_unique<VeneerElimination>(PrintVeneerElimination));
 
-  if (opts::Instrument)
-    Manager.registerPass(std::make_unique<Instrumentation>(NeverPrint));
-
   // Here we manage dependencies/order manually, since passes are run in the
   // order they're registered.
 
@@ -331,6 +329,11 @@ void BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) {
 
   Manager.registerPass(std::make_unique<ValidateInternalCalls>(NeverPrint));
 
+  Manager.registerPass(std::make_unique<ValidateMemRefs>(NeverPrint));
+
+  if (opts::Instrument)
+    Manager.registerPass(std::make_unique<Instrumentation>(NeverPrint));
+
   Manager.registerPass(std::make_unique<ShortenInstructions>(NeverPrint));
 
   Manager.registerPass(std::make_unique<RemoveNops>(NeverPrint));
diff --git a/bolt/test/runtime/X86/jt-symbol-disambiguation-2.s b/bolt/test/runtime/X86/jt-symbol-disambiguation-2.s
new file mode 100644 (file)
index 0000000..d81c233
--- /dev/null
@@ -0,0 +1,100 @@
+# In this test case, we reproduce the behavior seen in gcc where the
+# base address of a data object is decremented by some number and lands
+# inside a jump table from another function.
+
+# REQUIRES: system-linux
+
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: llvm-strip --strip-unneeded %t.o
+# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q
+
+# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \
+# RUN:   --reorder-blocks=reverse -jump-tables=move
+
+# RUN: %t.exe.bolt 1 2 3
+
+  .file "jt-symbol-disambiguation-2.s"
+  .text
+
+# ----
+# Func foo contains a jump table whose start is colocated with a
+# symbol marking the end of a data table
+# ----
+  .globl foo
+  .type foo, @function
+foo:
+  .cfi_startproc
+  xor    %rax,%rax
+  and    $0x3,%rdi
+  leaq   .JT1(%rip), %rax
+  movslq  (%rax, %rdi, 4), %rdi
+  addq   %rax, %rdi
+  jmpq   *%rdi
+.LBB1:
+  movl   $0x1,%eax
+  jmp    .LBB5
+.LBB2:
+  movl   $0x2,%eax
+  jmp    .LBB5
+.LBB3:
+  movl   $0x3,%eax
+  jmp    .LBB5
+.LBB4:
+  movl   $0x4,%eax
+.LBB5:
+  retq
+  .cfi_endproc
+  .size foo, .-foo
+
+# ----
+# Func _start scans an object with indexed access using %rax * 8 as an
+# index. However, %rax is known to be at least one, so the compiler
+# loads the pointer for the base address as object - 8 instead of just
+# object.
+# ----
+  .globl _start
+  .type _start, @function
+_start:
+  .cfi_startproc
+  movq   (%rsp), %rdi
+  callq foo
+  xorq   %rbx, %rbx
+  leaq   .object-8(%rip), %rsi  # indexed access base address
+  movq   $1, %rax               # start index
+.LBB6:
+  cmpq $4, %rax
+  je .LBB7
+  addq (%rsi,%rax,8), %rbx
+  incq %rax   # ++iterator
+  jmp .LBB6
+.LBB7:
+  cmpq  $1368, %rbx             # check .object contents integrity
+  jne   .LBB_BAD
+  xor   %rdi, %rdi
+  callq exit@PLT
+  retq
+.LBB_BAD:
+  leaq  .message, %rdi
+  callq puts@PLT
+  movq  $1, %rdi
+  callq exit@PLT
+  retq
+  .cfi_endproc
+  .size _start, .-_start
+
+# ----
+# Data section
+# ----
+  .section .rodata,"a",@progbits
+  .p2align 3
+.JT1:
+  .long .LBB1 - .JT1
+  .long .LBB2 - .JT1
+  .long .LBB3 - .JT1
+  .long .LBB4 - .JT1
+.object:
+  .quad 123
+  .quad 456
+  .quad 789
+.message:
+  .asciz "RUNTIME ASSERTION FAILURE: references in test binary are corrupt after BOLT"
diff --git a/bolt/test/runtime/X86/jt-symbol-disambiguation.s b/bolt/test/runtime/X86/jt-symbol-disambiguation.s
new file mode 100644 (file)
index 0000000..ea3b893
--- /dev/null
@@ -0,0 +1,91 @@
+# In this test case, the symbol that represents the end of a table
+# in .rodata is being colocated with the start of a jump table from
+# another function, and BOLT moves that jump table. This should not
+# cause the symbol representing the end of the table to be moved as
+# well.
+# Bug reported in https://github.com/llvm/llvm-project/issues/55004
+
+# REQUIRES: system-linux
+
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: llvm-strip --strip-unneeded %t.o
+# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q
+
+# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \
+# RUN:   --reorder-blocks=reverse -jump-tables=move
+
+# RUN: %t.exe.bolt 1 2 3
+
+  .file "jt-symbol-disambiguation.s"
+  .text
+
+# ----
+# Func foo contains a jump table whose start is colocated with a
+# symbol marking the end of a data table
+# ----
+  .globl foo
+  .type foo, @function
+foo:
+  .cfi_startproc
+  xor    %rax,%rax
+  and    $0x3,%rdi
+  leaq   .JT1(%rip), %rax
+  movslq  (%rax, %rdi, 4), %rdi
+  addq   %rax, %rdi
+  jmpq   *%rdi
+.LBB1:
+  movl   $0x1,%eax
+  jmp    .LBB5
+.LBB2:
+  movl   $0x2,%eax
+  jmp    .LBB5
+.LBB3:
+  movl   $0x3,%eax
+  jmp    .LBB5
+.LBB4:
+  movl   $0x4,%eax
+.LBB5:
+  retq
+  .cfi_endproc
+  .size foo, .-foo
+
+# ----
+# Func _start scans a table using begin/end pointers. End pointer is colocated
+# with the start of a jump table of function foo. When that jump
+# table moves, end pointer in _start should not be affected.
+# ----
+  .globl _start
+  .type _start, @function
+_start:
+  .cfi_startproc
+  movq   (%rsp), %rdi
+  callq foo
+  leaq   .start_of_table(%rip), %rsi  # iterator
+  leaq   .end_of_table(%rip), %rdi    # iterator end
+.LBB6:
+  cmpq %rsi, %rdi
+  je .LBB7
+  movq (%rsi), %rbx
+  leaq 8(%rsi), %rsi            # ++iterator
+  jmp .LBB6
+.LBB7:
+  xor   %rdi, %rdi
+  callq exit@PLT
+  .cfi_endproc
+  .size _start, .-_start
+
+# ----
+# Data section
+# ----
+  .section .rodata,"a",@progbits
+  .p2align 3
+.start_of_table:
+  .quad 123
+  .quad 456
+  .quad 789
+.end_of_table:
+.JT1:
+  .long .LBB1 - .JT1
+  .long .LBB2 - .JT1
+  .long .LBB3 - .JT1
+  .long .LBB4 - .JT1