From: Rafael Auler Date: Sat, 27 Aug 2022 01:39:36 +0000 (-0700) Subject: [BOLT] Add pass to fix ambiguous memory references X-Git-Tag: upstream/17.0.6~30745 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4f158995b9cddae392bfb5989af8c83101ae0789;p=platform%2Fupstream%2Fllvm.git [BOLT] Add pass to fix ambiguous memory references This adds a round of checks to memory references, looking for incorrect references to jump table objects. Fix them by replacing the jump table reference with another object reference + offset. This solves bugs related to regular data references in code accidentally being bound to a jump table, and this reference being updated to a new (incorrect) location because we moved this jump table. Fixes #55004 Reviewed By: #bolt, maksfb Differential Revision: https://reviews.llvm.org/D134098 --- diff --git a/bolt/include/bolt/Passes/ValidateMemRefs.h b/bolt/include/bolt/Passes/ValidateMemRefs.h new file mode 100644 index 0000000..d33862c --- /dev/null +++ b/bolt/include/bolt/Passes/ValidateMemRefs.h @@ -0,0 +1,41 @@ +//===- bolt/Passes/ValidateMemRefs.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BOLT_PASSES_VALIDATEMEMREFS_H +#define BOLT_PASSES_VALIDATEMEMREFS_H + +#include "bolt/Passes/BinaryPasses.h" + +namespace llvm::bolt { + +/// Post processing to check for memory references that cause a symbol +/// in data section to be ambiguous, requiring us to avoid moving that +/// object or disambiguating such references. This is currently +/// limited to fixing false references to the location of jump tables. +/// +class ValidateMemRefs : public BinaryFunctionPass { +public: + explicit ValidateMemRefs(const cl::opt &PrintPass) + : BinaryFunctionPass(PrintPass) {} + + const char *getName() const override { return "validate-mem-refs"; } + + void runOnFunctions(BinaryContext &BC) override; + +private: + bool checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst, + uint32_t OperandNum, const MCSymbol *Sym, + uint64_t Offset); + void runOnFunction(BinaryFunction &BF); + + static std::atomic ReplacedReferences; +}; + +} // namespace llvm::bolt + +#endif diff --git a/bolt/lib/Core/BinaryEmitter.cpp b/bolt/lib/Core/BinaryEmitter.cpp index c4f887d..f618619 100644 --- a/bolt/lib/Core/BinaryEmitter.cpp +++ b/bolt/lib/Core/BinaryEmitter.cpp @@ -815,7 +815,25 @@ void BinaryEmitter::emitJumpTable(const JumpTable &JT, MCSection *HotSection, Streamer.switchSection(ColdSection); Streamer.emitValueToAlignment(JT.EntrySize); } - Streamer.emitLabel(LI->second); + // Emit all labels registered at the address of this jump table + // to sync with our global symbol table. We may have two labels + // registered at this address if one label was created via + // getOrCreateGlobalSymbol() (e.g. LEA instructions referencing + // this location) and another via getOrCreateJumpTable(). This + // creates a race where the symbols created by these two + // functions may or may not be the same, but they are both + // registered in our symbol table at the same address. By + // emitting them all here we make sure there is no ambiguity + // that depends on the order that these symbols were created, so + // whenever this address is referenced in the binary, it is + // certain to point to the jump table identified at this + // address. + if (BinaryData *BD = BC.getBinaryDataByName(LI->second->getName())) { + for (MCSymbol *S : BD->getSymbols()) + Streamer.emitLabel(S); + } else { + Streamer.emitLabel(LI->second); + } LastLabel = LI->second; } if (JT.Type == JumpTable::JTT_NORMAL) { diff --git a/bolt/lib/Passes/CMakeLists.txt b/bolt/lib/Passes/CMakeLists.txt index 544b6df..7f67261 100644 --- a/bolt/lib/Passes/CMakeLists.txt +++ b/bolt/lib/Passes/CMakeLists.txt @@ -42,6 +42,7 @@ add_llvm_library(LLVMBOLTPasses TailDuplication.cpp ThreeWayBranch.cpp ValidateInternalCalls.cpp + ValidateMemRefs.cpp VeneerElimination.cpp RetpolineInsertion.cpp diff --git a/bolt/lib/Passes/ValidateMemRefs.cpp b/bolt/lib/Passes/ValidateMemRefs.cpp new file mode 100644 index 0000000..69cd82a --- /dev/null +++ b/bolt/lib/Passes/ValidateMemRefs.cpp @@ -0,0 +1,104 @@ +//===- bolt/Passes/ValidateMemRefs.cpp ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "bolt/Passes/ValidateMemRefs.h" +#include "bolt/Core/ParallelUtilities.h" + +#define DEBUG_TYPE "bolt-memrefs" + +namespace opts { +extern llvm::cl::opt JumpTables; +} + +namespace llvm::bolt { + +std::atomic ValidateMemRefs::ReplacedReferences{0}; + +bool ValidateMemRefs::checkAndFixJTReference(BinaryFunction &BF, MCInst &Inst, + uint32_t OperandNum, + const MCSymbol *Sym, + uint64_t Offset) { + BinaryContext &BC = BF.getBinaryContext(); + auto L = BC.scopeLock(); + BinaryData *BD = BC.getBinaryDataByName(Sym->getName()); + if (!BD) + return false; + + const uint64_t TargetAddress = BD->getAddress() + Offset; + JumpTable *JT = BC.getJumpTableContainingAddress(TargetAddress); + if (!JT) + return false; + + const bool IsLegitAccess = llvm::any_of( + JT->Parents, [&](const BinaryFunction *Parent) { return Parent == &BF; }); + if (IsLegitAccess) + return true; + + // Accessing a jump table in another function. This is not a + // legitimate jump table access, we need to replace the reference to + // the jump table label with a regular rodata reference. Get a + // non-JT reference by fetching the symbol 1 byte before the JT + // label. + MCSymbol *NewSym = BC.getOrCreateGlobalSymbol(TargetAddress - 1, "DATAat"); + BC.MIB->setOperandToSymbolRef(Inst, OperandNum, NewSym, 1, &*BC.Ctx, 0); + LLVM_DEBUG(dbgs() << "BOLT-DEBUG: replaced reference @" << BF.getPrintName() + << " from " << BD->getName() << " to " << NewSym->getName() + << " + 1\n"); + ++ReplacedReferences; + return true; +} + +void ValidateMemRefs::runOnFunction(BinaryFunction &BF) { + MCPlusBuilder *MIB = BF.getBinaryContext().MIB.get(); + + for (BinaryBasicBlock &BB : BF) { + for (MCInst &Inst : BB) { + for (int I = 0, E = MCPlus::getNumPrimeOperands(Inst); I != E; ++I) { + const MCOperand &Operand = Inst.getOperand(I); + if (!Operand.isExpr()) + continue; + + const auto [Sym, Offset] = MIB->getTargetSymbolInfo(Operand.getExpr()); + if (!Sym) + continue; + + checkAndFixJTReference(BF, Inst, I, Sym, Offset); + } + } + } +} + +void ValidateMemRefs::runOnFunctions(BinaryContext &BC) { + if (!BC.isX86()) + return; + + // Skip validation if not moving JT + if (opts::JumpTables == JTS_NONE || opts::JumpTables == JTS_BASIC) + return; + + ParallelUtilities::WorkFuncWithAllocTy ProcessFunction = + [&](BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId) { + runOnFunction(BF); + }; + ParallelUtilities::PredicateTy SkipPredicate = [&](const BinaryFunction &BF) { + return !BF.hasCFG(); + }; + LLVM_DEBUG(dbgs() << "BOLT-DEBUG: starting memrefs validation pass\n"); + ParallelUtilities::runOnEachFunctionWithUniqueAllocId( + BC, ParallelUtilities::SchedulingPolicy::SP_INST_LINEAR, ProcessFunction, + SkipPredicate, "validate-mem-refs", /*ForceSequential=*/true); + LLVM_DEBUG(dbgs() << "BOLT-DEBUG: memrefs validation is concluded\n"); + + if (!ReplacedReferences) + return; + + outs() << "BOLT-INFO: validate-mem-refs updated " << ReplacedReferences + << " object references\n"; +} + +} // namespace llvm::bolt diff --git a/bolt/lib/Rewrite/BinaryPassManager.cpp b/bolt/lib/Rewrite/BinaryPassManager.cpp index dff93b9..85b49a2 100644 --- a/bolt/lib/Rewrite/BinaryPassManager.cpp +++ b/bolt/lib/Rewrite/BinaryPassManager.cpp @@ -31,6 +31,7 @@ #include "bolt/Passes/TailDuplication.h" #include "bolt/Passes/ThreeWayBranch.h" #include "bolt/Passes/ValidateInternalCalls.h" +#include "bolt/Passes/ValidateMemRefs.h" #include "bolt/Passes/VeneerElimination.h" #include "bolt/Utils/CommandLineOpts.h" #include "llvm/Support/FormatVariadic.h" @@ -317,9 +318,6 @@ void BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) { Manager.registerPass( std::make_unique(PrintVeneerElimination)); - if (opts::Instrument) - Manager.registerPass(std::make_unique(NeverPrint)); - // Here we manage dependencies/order manually, since passes are run in the // order they're registered. @@ -331,6 +329,11 @@ void BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) { Manager.registerPass(std::make_unique(NeverPrint)); + Manager.registerPass(std::make_unique(NeverPrint)); + + if (opts::Instrument) + Manager.registerPass(std::make_unique(NeverPrint)); + Manager.registerPass(std::make_unique(NeverPrint)); Manager.registerPass(std::make_unique(NeverPrint)); diff --git a/bolt/test/runtime/X86/jt-symbol-disambiguation-2.s b/bolt/test/runtime/X86/jt-symbol-disambiguation-2.s new file mode 100644 index 0000000..d81c233 --- /dev/null +++ b/bolt/test/runtime/X86/jt-symbol-disambiguation-2.s @@ -0,0 +1,100 @@ +# In this test case, we reproduce the behavior seen in gcc where the +# base address of a data object is decremented by some number and lands +# inside a jump table from another function. + +# REQUIRES: system-linux + +# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o +# RUN: llvm-strip --strip-unneeded %t.o +# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q + +# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \ +# RUN: --reorder-blocks=reverse -jump-tables=move + +# RUN: %t.exe.bolt 1 2 3 + + .file "jt-symbol-disambiguation-2.s" + .text + +# ---- +# Func foo contains a jump table whose start is colocated with a +# symbol marking the end of a data table +# ---- + .globl foo + .type foo, @function +foo: + .cfi_startproc + xor %rax,%rax + and $0x3,%rdi + leaq .JT1(%rip), %rax + movslq (%rax, %rdi, 4), %rdi + addq %rax, %rdi + jmpq *%rdi +.LBB1: + movl $0x1,%eax + jmp .LBB5 +.LBB2: + movl $0x2,%eax + jmp .LBB5 +.LBB3: + movl $0x3,%eax + jmp .LBB5 +.LBB4: + movl $0x4,%eax +.LBB5: + retq + .cfi_endproc + .size foo, .-foo + +# ---- +# Func _start scans an object with indexed access using %rax * 8 as an +# index. However, %rax is known to be at least one, so the compiler +# loads the pointer for the base address as object - 8 instead of just +# object. +# ---- + .globl _start + .type _start, @function +_start: + .cfi_startproc + movq (%rsp), %rdi + callq foo + xorq %rbx, %rbx + leaq .object-8(%rip), %rsi # indexed access base address + movq $1, %rax # start index +.LBB6: + cmpq $4, %rax + je .LBB7 + addq (%rsi,%rax,8), %rbx + incq %rax # ++iterator + jmp .LBB6 +.LBB7: + cmpq $1368, %rbx # check .object contents integrity + jne .LBB_BAD + xor %rdi, %rdi + callq exit@PLT + retq +.LBB_BAD: + leaq .message, %rdi + callq puts@PLT + movq $1, %rdi + callq exit@PLT + retq + .cfi_endproc + .size _start, .-_start + +# ---- +# Data section +# ---- + .section .rodata,"a",@progbits + .p2align 3 +.JT1: + .long .LBB1 - .JT1 + .long .LBB2 - .JT1 + .long .LBB3 - .JT1 + .long .LBB4 - .JT1 +.object: + .quad 123 + .quad 456 + .quad 789 +.message: + .asciz "RUNTIME ASSERTION FAILURE: references in test binary are corrupt after BOLT" diff --git a/bolt/test/runtime/X86/jt-symbol-disambiguation.s b/bolt/test/runtime/X86/jt-symbol-disambiguation.s new file mode 100644 index 0000000..ea3b893 --- /dev/null +++ b/bolt/test/runtime/X86/jt-symbol-disambiguation.s @@ -0,0 +1,91 @@ +# In this test case, the symbol that represents the end of a table +# in .rodata is being colocated with the start of a jump table from +# another function, and BOLT moves that jump table. This should not +# cause the symbol representing the end of the table to be moved as +# well. +# Bug reported in https://github.com/llvm/llvm-project/issues/55004 + +# REQUIRES: system-linux + +# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o +# RUN: llvm-strip --strip-unneeded %t.o +# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q + +# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 \ +# RUN: --reorder-blocks=reverse -jump-tables=move + +# RUN: %t.exe.bolt 1 2 3 + + .file "jt-symbol-disambiguation.s" + .text + +# ---- +# Func foo contains a jump table whose start is colocated with a +# symbol marking the end of a data table +# ---- + .globl foo + .type foo, @function +foo: + .cfi_startproc + xor %rax,%rax + and $0x3,%rdi + leaq .JT1(%rip), %rax + movslq (%rax, %rdi, 4), %rdi + addq %rax, %rdi + jmpq *%rdi +.LBB1: + movl $0x1,%eax + jmp .LBB5 +.LBB2: + movl $0x2,%eax + jmp .LBB5 +.LBB3: + movl $0x3,%eax + jmp .LBB5 +.LBB4: + movl $0x4,%eax +.LBB5: + retq + .cfi_endproc + .size foo, .-foo + +# ---- +# Func _start scans a table using begin/end pointers. End pointer is colocated +# with the start of a jump table of function foo. When that jump +# table moves, end pointer in _start should not be affected. +# ---- + .globl _start + .type _start, @function +_start: + .cfi_startproc + movq (%rsp), %rdi + callq foo + leaq .start_of_table(%rip), %rsi # iterator + leaq .end_of_table(%rip), %rdi # iterator end +.LBB6: + cmpq %rsi, %rdi + je .LBB7 + movq (%rsi), %rbx + leaq 8(%rsi), %rsi # ++iterator + jmp .LBB6 +.LBB7: + xor %rdi, %rdi + callq exit@PLT + .cfi_endproc + .size _start, .-_start + +# ---- +# Data section +# ---- + .section .rodata,"a",@progbits + .p2align 3 +.start_of_table: + .quad 123 + .quad 456 + .quad 789 +.end_of_table: +.JT1: + .long .LBB1 - .JT1 + .long .LBB2 - .JT1 + .long .LBB3 - .JT1 + .long .LBB4 - .JT1