--- /dev/null
+//===- bolt/Passes/ValidateMemRefs.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BOLT_PASSES_VALIDATEMEMREFS_H
+#define BOLT_PASSES_VALIDATEMEMREFS_H
+
+#include "bolt/Passes/BinaryPasses.h"
+
+namespace llvm::bolt {
+
+/// Post-processing pass that looks for memory references which make a
+/// symbol in a data section ambiguous. Such a reference forces us either
+/// to avoid moving the referenced object or to disambiguate the reference.
+/// At the moment the pass only fixes false references to the location of
+/// jump tables.
+class ValidateMemRefs : public BinaryFunctionPass {
+  /// Number of operand references rewritten by checkAndFixJTReference().
+  static std::atomic<std::uint64_t> ReplacedReferences;
+
+  /// Visit every instruction of \p BF and pass each operand that resolves
+  /// to a symbol to checkAndFixJTReference().
+  void runOnFunction(BinaryFunction &BF);
+
+  /// Inspect operand \p OperandNum of \p Inst, which refers to \p Sym plus
+  /// \p Offset. If that address lies inside a jump table that does not list
+  /// \p BF among its parents, the operand is rewritten to use a non-jump
+  /// table symbol instead.
+  ///
+  /// \returns false when \p Sym has no associated binary data or the
+  /// address is not inside any jump table, true otherwise.
+  bool checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst,
+                              uint32_t OperandNum, const MCSymbol *Sym,
+                              uint64_t Offset);
+
+public:
+  explicit ValidateMemRefs(const cl::opt<bool> &PrintPass)
+      : BinaryFunctionPass(PrintPass) {}
+
+  const char *getName() const override { return "validate-mem-refs"; }
+
+  /// Pass entry point: validates the functions of \p BC and reports how
+  /// many references were updated.
+  void runOnFunctions(BinaryContext &BC) override;
+};
+
+} // namespace llvm::bolt
+
+#endif
Streamer.switchSection(ColdSection);
Streamer.emitValueToAlignment(JT.EntrySize);
}
- Streamer.emitLabel(LI->second);
+ // Emit all labels registered at the address of this jump table
+ // to sync with our global symbol table. We may have two labels
+ // registered at this address if one label was created via
+ // getOrCreateGlobalSymbol() (e.g. LEA instructions referencing
+ // this location) and another via getOrCreateJumpTable(). This
+ // creates a race where the symbols created by these two
+ // functions may or may not be the same, but they are both
+ // registered in our symbol table at the same address. By
+ // emitting them all here we make sure there is no ambiguity
+ // that depends on the order that these symbols were created, so
+ // whenever this address is referenced in the binary, it is
+ // certain to point to the jump table identified at this
+ // address.
+ if (BinaryData *BD = BC.getBinaryDataByName(LI->second->getName())) {
+ for (MCSymbol *S : BD->getSymbols())
+ Streamer.emitLabel(S);
+ } else {
+ Streamer.emitLabel(LI->second);
+ }
LastLabel = LI->second;
}
if (JT.Type == JumpTable::JTT_NORMAL) {
TailDuplication.cpp
ThreeWayBranch.cpp
ValidateInternalCalls.cpp
+ ValidateMemRefs.cpp
VeneerElimination.cpp
RetpolineInsertion.cpp
--- /dev/null
+//===- bolt/Passes/ValidateMemRefs.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "bolt/Passes/ValidateMemRefs.h"
+#include "bolt/Core/ParallelUtilities.h"
+
+#define DEBUG_TYPE "bolt-memrefs"
+
+namespace opts {
+extern llvm::cl::opt<llvm::bolt::JumpTableSupportLevel> JumpTables;
+}
+
+namespace llvm::bolt {
+
+/// Total number of references rewritten by checkAndFixJTReference(); it is
+/// reported once at the end of runOnFunctions().
+std::atomic<std::uint64_t> ValidateMemRefs::ReplacedReferences(0);
+
+/// Check whether operand \p OperandNum of \p Inst, which refers to \p Sym
+/// plus \p Offset, points into a jump table, and rewrite the operand when
+/// that jump table does not list \p BF among its parents.
+///
+/// \returns false if \p Sym has no binary data or the referenced address
+/// is outside every jump table; true if the address is inside a jump table
+/// (whether or not the operand had to be rewritten).
+bool ValidateMemRefs::checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst,
+                                             uint32_t OperandNum,
+                                             const MCSymbol *Sym,
+                                             uint64_t Offset) {
+  BinaryContext &BinCtx = BF.getBinaryContext();
+  // Keep the lock returned by scopeLock() until this function returns.
+  auto Guard = BinCtx.scopeLock();
+
+  BinaryData *Data = BinCtx.getBinaryDataByName(Sym->getName());
+  if (!Data)
+    return false;
+
+  const uint64_t RefAddress = Data->getAddress() + Offset;
+  JumpTable *Table = BinCtx.getJumpTableContainingAddress(RefAddress);
+  if (!Table)
+    return false;
+
+  // A reference coming from one of the parents of the jump table is a
+  // legitimate access and is left untouched.
+  for (const BinaryFunction *Owner : Table->Parents)
+    if (Owner == &BF)
+      return true;
+
+  // The jump table belongs to some other function, so this is not a real
+  // jump table access. Swap the jump table label for a regular rodata
+  // reference: take the symbol located 1 byte before the referenced
+  // address and compensate with an addend of 1.
+  MCSymbol *DataSym =
+      BinCtx.getOrCreateGlobalSymbol(RefAddress - 1, "DATAat");
+  BinCtx.MIB->setOperandToSymbolRef(Inst, OperandNum, DataSym, 1,
+                                    &*BinCtx.Ctx, 0);
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: replaced reference @" << BF.getPrintName()
+                    << " from " << Data->getName() << " to "
+                    << DataSym->getName() << " + 1\n");
+  ++ReplacedReferences;
+  return true;
+}
+
+/// Walk all instructions of \p BF and hand every prime operand that is an
+/// expression resolving to a symbol over to checkAndFixJTReference().
+void ValidateMemRefs::runOnFunction(BinaryFunction &BF) {
+  MCPlusBuilder *Builder = BF.getBinaryContext().MIB.get();
+
+  for (BinaryBasicBlock &Block : BF) {
+    for (MCInst &Inst : Block) {
+      // The operand count is read once per instruction, before any operand
+      // is inspected.
+      const int NumOperands = MCPlus::getNumPrimeOperands(Inst);
+      for (int OpIdx = 0; OpIdx != NumOperands; ++OpIdx) {
+        const MCOperand &Op = Inst.getOperand(OpIdx);
+        if (Op.isExpr()) {
+          const auto [TargetSym, TargetOffset] =
+              Builder->getTargetSymbolInfo(Op.getExpr());
+          // Expressions without a target symbol are of no interest here.
+          if (TargetSym)
+            checkAndFixJTReference(BF, Inst, OpIdx, TargetSym, TargetOffset);
+        }
+      }
+    }
+  }
+}
+
+/// Pass entry point. Runs runOnFunction() on every function of \p BC that
+/// has a CFG and prints how many references were updated, if any. The pass
+/// does nothing on non-x86 targets or when jump tables are not being moved.
+void ValidateMemRefs::runOnFunctions(BinaryContext &BC) {
+  if (!BC.isX86())
+    return;
+
+  // Nothing to validate unless jump tables are going to be moved.
+  const bool JumpTablesStayInPlace =
+      opts::JumpTables == JTS_NONE || opts::JumpTables == JTS_BASIC;
+  if (JumpTablesStayInPlace)
+    return;
+
+  // The allocator id handed over by the scheduler is not needed here.
+  ParallelUtilities::WorkFuncWithAllocTy ProcessFunction =
+      [this](BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy) {
+        runOnFunction(BF);
+      };
+  // Functions without a CFG are skipped.
+  ParallelUtilities::PredicateTy SkipPredicate = [](const BinaryFunction &BF) {
+    return !BF.hasCFG();
+  };
+
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: starting memrefs validation pass\n");
+  ParallelUtilities::runOnEachFunctionWithUniqueAllocId(
+      BC, ParallelUtilities::SchedulingPolicy::SP_INST_LINEAR, ProcessFunction,
+      SkipPredicate, "validate-mem-refs", /*ForceSequential=*/true);
+  LLVM_DEBUG(dbgs() << "BOLT-DEBUG: memrefs validation is concluded\n");
+
+  // Stay silent when no reference had to be rewritten.
+  const std::uint64_t NumReplaced = ReplacedReferences.load();
+  if (NumReplaced != 0)
+    outs() << "BOLT-INFO: validate-mem-refs updated " << NumReplaced
+           << " object references\n";
+}
+
+} // namespace llvm::bolt
#include "bolt/Passes/TailDuplication.h"
#include "bolt/Passes/ThreeWayBranch.h"
#include "bolt/Passes/ValidateInternalCalls.h"
+#include "bolt/Passes/ValidateMemRefs.h"
#include "bolt/Passes/VeneerElimination.h"
#include "bolt/Utils/CommandLineOpts.h"
#include "llvm/Support/FormatVariadic.h"
Manager.registerPass(
std::make_unique<VeneerElimination>(PrintVeneerElimination));
- if (opts::Instrument)
- Manager.registerPass(std::make_unique<Instrumentation>(NeverPrint));
-
// Here we manage dependencies/order manually, since passes are run in the
// order they're registered.
Manager.registerPass(std::make_unique<ValidateInternalCalls>(NeverPrint));
+ Manager.registerPass(std::make_unique<ValidateMemRefs>(NeverPrint));
+
+ if (opts::Instrument)
+ Manager.registerPass(std::make_unique<Instrumentation>(NeverPrint));
+
Manager.registerPass(std::make_unique<ShortenInstructions>(NeverPrint));
Manager.registerPass(std::make_unique<RemoveNops>(NeverPrint));
--- /dev/null
+# In this test case, we reproduce the behavior seen in gcc where the
+# base address of a data object is decremented by some number and lands
+# inside a jump table from another function.
+
+# REQUIRES: system-linux
+
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: llvm-strip --strip-unneeded %t.o
+# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q
+
+# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \
+# RUN: --reorder-blocks=reverse -jump-tables=move
+
+# RUN: %t.exe.bolt 1 2 3
+
+ .file "jt-symbol-disambiguation-2.s"
+ .text
+
+# ----
+# Func foo contains a jump table that is immediately followed in
+# .rodata by the data object (.object) that _start reads
+# ----
+ .globl foo
+ .type foo, @function
+foo:
+ .cfi_startproc
+ xor %rax,%rax # clear %rax (overwritten by the leaq below)
+ and $0x3,%rdi # reduce the index in %rdi to 0..3, one per table entry
+ leaq .JT1(%rip), %rax # %rax = address of the jump table
+ movslq (%rax, %rdi, 4), %rdi # load the sign-extended 32-bit table entry
+ addq %rax, %rdi # entries are relative to .JT1: make the target absolute
+ jmpq *%rdi # dispatch to one of .LBB1-.LBB4
+.LBB1:
+ movl $0x1,%eax
+ jmp .LBB5
+.LBB2:
+ movl $0x2,%eax
+ jmp .LBB5
+.LBB3:
+ movl $0x3,%eax
+ jmp .LBB5
+.LBB4:
+ movl $0x4,%eax
+.LBB5:
+ retq # common exit; %eax holds 1..4 depending on the case taken
+ .cfi_endproc
+ .size foo, .-foo
+
+# ----
+# Func _start scans an object with indexed access using %rax * 8 as an
+# index. However, %rax is known to be at least one, so the compiler
+# loads the pointer for the base address as object - 8 instead of just
+# object.
+# ----
+ .globl _start
+ .type _start, @function
+_start:
+ .cfi_startproc
+ movq (%rsp), %rdi # first stack slot at entry (presumably argc - confirm against the ABI)
+ callq foo # exercise the jump table in foo
+ xorq %rbx, %rbx # %rbx accumulates the sum of the .object entries
+ leaq .object-8(%rip), %rsi # indexed access base address; lands inside .JT1
+ movq $1, %rax # start index
+.LBB6:
+ cmpq $4, %rax # indices 1..3 are visited
+ je .LBB7
+ addq (%rsi,%rax,8), %rbx # reads .object[%rax - 1] through the biased base
+ incq %rax # ++iterator
+ jmp .LBB6
+.LBB7:
+ cmpq $1368, %rbx # check .object contents integrity: 123 + 456 + 789
+ jne .LBB_BAD
+ xor %rdi, %rdi # success path: exit status 0
+ callq exit@PLT
+ retq
+.LBB_BAD:
+ leaq .message, %rdi # failure path: print the diagnostic message
+ callq puts@PLT
+ movq $1, %rdi # exit status 1
+ callq exit@PLT
+ retq
+ .cfi_endproc
+ .size _start, .-_start
+
+# ----
+# Data section
+# ----
+ .section .rodata,"a",@progbits
+ .p2align 3
+.JT1: # jump table of foo: four 32-bit offsets relative to .JT1
+ .long .LBB1 - .JT1
+ .long .LBB2 - .JT1
+ .long .LBB3 - .JT1
+ .long .LBB4 - .JT1
+.object: # starts 16 bytes after .JT1, so .object-8 falls inside the jump table
+ .quad 123
+ .quad 456
+ .quad 789
+.message: # printed by _start when the checksum of .object is wrong
+ .asciz "RUNTIME ASSERTION FAILURE: references in test binary are corrupt after BOLT"
--- /dev/null
+# In this test case, the symbol that represents the end of a table
+# in .rodata is being colocated with the start of a jump table from
+# another function, and BOLT moves that jump table. This should not
+# cause the symbol representing the end of the table to be moved as
+# well.
+# Bug reported in https://github.com/llvm/llvm-project/issues/55004
+
+# REQUIRES: system-linux
+
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: llvm-strip --strip-unneeded %t.o
+# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q
+
+# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \
+# RUN: --reorder-blocks=reverse -jump-tables=move
+
+# RUN: %t.exe.bolt 1 2 3
+
+ .file "jt-symbol-disambiguation.s"
+ .text
+
+# ----
+# Func foo contains a jump table whose start is colocated with a
+# symbol marking the end of a data table
+# ----
+ .globl foo
+ .type foo, @function
+foo:
+ .cfi_startproc
+ xor %rax,%rax # clear %rax (overwritten by the leaq below)
+ and $0x3,%rdi # reduce the index in %rdi to 0..3, one per table entry
+ leaq .JT1(%rip), %rax # %rax = address of the jump table
+ movslq (%rax, %rdi, 4), %rdi # load the sign-extended 32-bit table entry
+ addq %rax, %rdi # entries are relative to .JT1: make the target absolute
+ jmpq *%rdi # dispatch to one of .LBB1-.LBB4
+.LBB1:
+ movl $0x1,%eax
+ jmp .LBB5
+.LBB2:
+ movl $0x2,%eax
+ jmp .LBB5
+.LBB3:
+ movl $0x3,%eax
+ jmp .LBB5
+.LBB4:
+ movl $0x4,%eax
+.LBB5:
+ retq # common exit; %eax holds 1..4 depending on the case taken
+ .cfi_endproc
+ .size foo, .-foo
+
+# ----
+# Func _start scans a table using begin/end pointers. End pointer is colocated
+# with the start of a jump table of function foo. When that jump
+# table moves, end pointer in _start should not be affected.
+# ----
+ .globl _start
+ .type _start, @function
+_start:
+ .cfi_startproc
+ movq (%rsp), %rdi # first stack slot at entry (presumably argc - confirm against the ABI)
+ callq foo # exercise the jump table in foo
+ leaq .start_of_table(%rip), %rsi # iterator
+ leaq .end_of_table(%rip), %rdi # iterator end; same address as .JT1
+.LBB6:
+ cmpq %rsi, %rdi # stop once the iterator reaches the end pointer
+ je .LBB7
+ movq (%rsi), %rbx # read the current entry (the value itself is not used)
+ leaq 8(%rsi), %rsi # ++iterator
+ jmp .LBB6
+.LBB7:
+ xor %rdi, %rdi # exit status 0
+ callq exit@PLT
+ .cfi_endproc
+ .size _start, .-_start
+
+# ----
+# Data section
+# ----
+ .section .rodata,"a",@progbits
+ .p2align 3
+.start_of_table: # three 8-byte entries scanned by _start
+ .quad 123
+ .quad 456
+ .quad 789
+.end_of_table: # end marker of the table; shares its address with .JT1
+.JT1: # jump table of foo: four 32-bit offsets relative to .JT1
+ .long .LBB1 - .JT1
+ .long .LBB2 - .JT1
+ .long .LBB3 - .JT1
+ .long .LBB4 - .JT1