//
//===----------------------------------------------------------------------===//
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPGridValues.h"
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
#include <algorithm>
#include <assert.h>
#include <cstdio>
#include <unordered_map>
#include <vector>
+#include "ELFSymbols.h"
#include "impl_runtime.h"
#include "interop_hsa.h"
#include "omptargetplugin.h"
#include "print_tracing.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
-#include "llvm/Frontend/OpenMP/OMPGridValues.h"
-
using namespace llvm;
+using namespace llvm::object;
// hostrpc interface, FIXME: consider moving to its own include these are
// statically linked into amdgpu/plugin if present from hostrpc_services.a,
return Changed;
}
-Elf64_Shdr *findOnlyShtHash(Elf *Elf) {
- size_t N;
- int Rc = elf_getshdrnum(Elf, &N);
- if (Rc != 0) {
- return nullptr;
- }
-
- Elf64_Shdr *Result = nullptr;
- for (size_t I = 0; I < N; I++) {
- Elf_Scn *Scn = elf_getscn(Elf, I);
- if (Scn) {
- Elf64_Shdr *Shdr = elf64_getshdr(Scn);
- if (Shdr) {
- if (Shdr->sh_type == SHT_HASH) {
- if (Result == nullptr) {
- Result = Shdr;
- } else {
- // multiple SHT_HASH sections not handled
- return nullptr;
- }
- }
- }
- }
- }
- return Result;
-}
-
-const Elf64_Sym *elfLookup(Elf *Elf, char *Base, Elf64_Shdr *SectionHash,
- const char *Symname) {
-
- assert(SectionHash);
- size_t SectionSymtabIndex = SectionHash->sh_link;
- Elf64_Shdr *SectionSymtab =
- elf64_getshdr(elf_getscn(Elf, SectionSymtabIndex));
- size_t SectionStrtabIndex = SectionSymtab->sh_link;
-
- const Elf64_Sym *Symtab =
- reinterpret_cast<const Elf64_Sym *>(Base + SectionSymtab->sh_offset);
-
- const uint32_t *Hashtab =
- reinterpret_cast<const uint32_t *>(Base + SectionHash->sh_offset);
-
- // Layout:
- // nbucket
- // nchain
- // bucket[nbucket]
- // chain[nchain]
- uint32_t Nbucket = Hashtab[0];
- const uint32_t *Bucket = &Hashtab[2];
- const uint32_t *Chain = &Hashtab[Nbucket + 2];
-
- const size_t Max = strlen(Symname) + 1;
- const uint32_t Hash = elf_hash(Symname);
- for (uint32_t I = Bucket[Hash % Nbucket]; I != 0; I = Chain[I]) {
- char *N = elf_strptr(Elf, SectionStrtabIndex, Symtab[I].st_name);
- if (strncmp(Symname, N, Max) == 0) {
- return &Symtab[I];
- }
- }
-
- return nullptr;
-}
-
struct SymbolInfo {
- void *Addr = nullptr;
+ const void *Addr = nullptr;
uint32_t Size = UINT32_MAX;
uint32_t ShType = SHT_NULL;
};
-int getSymbolInfoWithoutLoading(Elf *Elf, char *Base, const char *Symname,
- SymbolInfo *Res) {
- if (elf_kind(Elf) != ELF_K_ELF) {
- return 1;
- }
-
- Elf64_Shdr *SectionHash = findOnlyShtHash(Elf);
- if (!SectionHash) {
- return 1;
- }
-
- const Elf64_Sym *Sym = elfLookup(Elf, Base, SectionHash, Symname);
- if (!Sym) {
+int getSymbolInfoWithoutLoading(const ELFObjectFile<ELF64LE> &ELFObj,
+ StringRef SymName, SymbolInfo *Res) {
+ auto SymOrErr = getELFSymbol(ELFObj, SymName);
+ if (!SymOrErr) {
+ std::string ErrorString = toString(SymOrErr.takeError());
+ DP("Failed ELF lookup: %s\n", ErrorString.c_str());
return 1;
}
-
- if (Sym->st_size > UINT32_MAX) {
- return 1;
- }
-
- if (Sym->st_shndx == SHN_UNDEF) {
- return 1;
- }
-
- Elf_Scn *Section = elf_getscn(Elf, Sym->st_shndx);
- if (!Section) {
+ if (!*SymOrErr)
return 1;
- }
- Elf64_Shdr *Header = elf64_getshdr(Section);
- if (!Header) {
+ auto SymSecOrErr = ELFObj.getELFFile().getSection((*SymOrErr)->st_shndx);
+ if (!SymSecOrErr) {
+ std::string ErrorString = toString(SymOrErr.takeError());
+ DP("Failed ELF lookup: %s\n", ErrorString.c_str());
return 1;
}
- Res->Addr = Sym->st_value + Base;
- Res->Size = static_cast<uint32_t>(Sym->st_size);
- Res->ShType = Header->sh_type;
+ Res->Addr = (*SymOrErr)->st_value + ELFObj.getELFFile().base();
+ Res->Size = static_cast<uint32_t>((*SymOrErr)->st_size);
+ Res->ShType = static_cast<uint32_t>((*SymSecOrErr)->sh_type);
return 0;
}
-int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *Symname,
+int getSymbolInfoWithoutLoading(char *Base, size_t ImgSize, const char *SymName,
SymbolInfo *Res) {
- Elf *Elf = elf_memory(Base, ImgSize);
- if (Elf) {
- int Rc = getSymbolInfoWithoutLoading(Elf, Base, Symname, Res);
- elf_end(Elf);
- return Rc;
+ StringRef Buffer = StringRef(Base, ImgSize);
+ auto ElfOrErr = ObjectFile::createELFObjectFile(MemoryBufferRef(Buffer, ""),
+ /*InitContent=*/false);
+ if (!ElfOrErr) {
+ REPORT("Failed to load ELF: %s\n", toString(ElfOrErr.takeError()).c_str());
+ return 1;
}
+
+ if (const auto *ELFObj = dyn_cast<ELF64LEObjectFile>(ElfOrErr->get()))
+ return getSymbolInfoWithoutLoading(*ELFObj, SymName, Res);
return 1;
}
hsa_status_t interopGetSymbolInfo(char *Base, size_t ImgSize,
- const char *SymName, void **VarAddr,
+ const char *SymName, const void **VarAddr,
uint32_t *VarSize) {
SymbolInfo SI;
int Rc = getSymbolInfoWithoutLoading(Base, ImgSize, SymName, &SI);
KernDescNameStr += "_kern_desc";
const char *KernDescName = KernDescNameStr.c_str();
- void *KernDescPtr;
+ const void *KernDescPtr;
uint32_t KernDescSize;
void *CallStackAddr = nullptr;
Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, KernDescName,
WGSizeNameStr += "_wg_size";
const char *WGSizeName = WGSizeNameStr.c_str();
- void *WGSizePtr;
+ const void *WGSizePtr;
uint32_t WGSize;
Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, WGSizeName,
&WGSizePtr, &WGSize);
ExecModeNameStr += "_exec_mode";
const char *ExecModeName = ExecModeNameStr.c_str();
- void *ExecModePtr;
+ const void *ExecModePtr;
uint32_t VarSize;
Err = interopGetSymbolInfo((char *)Image->ImageStart, ImgSize, ExecModeName,
&ExecModePtr, &VarSize);
#
##===----------------------------------------------------------------------===##
-add_library(elf_common OBJECT elf_common.cpp)
+add_library(elf_common OBJECT elf_common.cpp ELFSymbols.cpp)
# Build elf_common with PIC to be able to link it with plugin shared libraries.
set_property(TARGET elf_common PROPERTY POSITION_INDEPENDENT_CODE ON)
--- /dev/null
+#include "ELFSymbols.h"
+
+using namespace llvm;
+using namespace llvm::object;
+using namespace llvm::ELF;
+
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromGnuHashTable(StringRef Name, const typename ELFT::GnuHash &HashTab,
+ ArrayRef<typename ELFT::Sym> SymTab,
+ StringRef StrTab) {
+ const uint32_t NameHash = hashGnu(Name);
+ const typename ELFT::Word NBucket = HashTab.nbuckets;
+ const typename ELFT::Word SymOffset = HashTab.symndx;
+ ArrayRef<typename ELFT::Off> Filter = HashTab.filter();
+ ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+ ArrayRef<typename ELFT::Word> Chain = HashTab.values(SymTab.size());
+
+ // Check the bloom filter and exit early if the symbol is not present.
+ uint64_t ElfClassBits = ELFT::Is64Bits ? 64 : 32;
+ typename ELFT::Off Word =
+ Filter[(NameHash / ElfClassBits) % HashTab.maskwords];
+ uint64_t Mask = (0x1ull << (NameHash % ElfClassBits)) |
+ (0x1ull << ((NameHash >> HashTab.shift2) % ElfClassBits));
+ if ((Word & Mask) != Mask)
+ return nullptr;
+
+ // The symbol may or may not be present, check the hash values.
+ for (typename ELFT::Word I = Bucket[NameHash % NBucket];
+ I >= SymOffset && I < SymTab.size(); I = I + 1) {
+ const uint32_t ChainHash = Chain[I - SymOffset];
+
+ if ((NameHash | 0x1) != (ChainHash | 0x1))
+ continue;
+
+ if (SymTab[I].st_name >= StrTab.size())
+ return createError("symbol [index " + Twine(I) +
+ "] has invalid st_name: " + Twine(SymTab[I].st_name));
+ if (StrTab.drop_front(SymTab[I].st_name).data() == Name)
+ return &SymTab[I];
+
+ if (ChainHash & 0x1)
+ return nullptr;
+ }
+ return nullptr;
+}
+
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymbolFromSysVHashTable(StringRef Name, const typename ELFT::Hash &HashTab,
+ ArrayRef<typename ELFT::Sym> SymTab,
+ StringRef StrTab) {
+ const uint32_t Hash = hashSysV(Name);
+ const typename ELFT::Word NBucket = HashTab.nbucket;
+ ArrayRef<typename ELFT::Word> Bucket = HashTab.buckets();
+ ArrayRef<typename ELFT::Word> Chain = HashTab.chains();
+ for (typename ELFT::Word I = Bucket[Hash % NBucket]; I != ELF::STN_UNDEF;
+ I = Chain[I]) {
+ if (I >= SymTab.size())
+ return createError(
+ "symbol [index " + Twine(I) +
+ "] is greater than the number of symbols: " + Twine(SymTab.size()));
+ if (SymTab[I].st_name >= StrTab.size())
+ return createError("symbol [index " + Twine(I) +
+ "] has invalid st_name: " + Twine(SymTab[I].st_name));
+
+ if (StrTab.drop_front(SymTab[I].st_name).data() == Name)
+ return &SymTab[I];
+ }
+ return nullptr;
+}
+
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getHashTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+ StringRef Name) {
+ if (Sec.sh_type != ELF::SHT_HASH && Sec.sh_type != ELF::SHT_GNU_HASH)
+ return createError(
+ "invalid sh_type for hash table, expected SHT_HASH or SHT_GNU_HASH");
+ Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+ if (!SectionsOrError)
+ return SectionsOrError.takeError();
+
+ auto SymTabOrErr = getSection<ELFT>(*SectionsOrError, Sec.sh_link);
+ if (!SymTabOrErr)
+ return SymTabOrErr.takeError();
+
+ auto StrTabOrErr =
+ Elf.getStringTableForSymtab(**SymTabOrErr, *SectionsOrError);
+ if (!StrTabOrErr)
+ return StrTabOrErr.takeError();
+ StringRef StrTab = *StrTabOrErr;
+
+ auto SymsOrErr = Elf.symbols(*SymTabOrErr);
+ if (!SymsOrErr)
+ return SymsOrErr.takeError();
+ ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+ // If this is a GNU hash table we verify its size and search the symbol
+ // table using the GNU hash table format.
+ if (Sec.sh_type == ELF::SHT_GNU_HASH) {
+ const typename ELFT::GnuHash *HashTab =
+ reinterpret_cast<const typename ELFT::GnuHash *>(Elf.base() +
+ Sec.sh_offset);
+ if (Sec.sh_offset + Sec.sh_size >= Elf.getBufSize())
+ return createError("section has invalid sh_offset: " +
+ Twine(Sec.sh_offset));
+ if (Sec.sh_size < sizeof(typename ELFT::GnuHash) ||
+ Sec.sh_size <
+ sizeof(typename ELFT::GnuHash) +
+ sizeof(typename ELFT::Word) * HashTab->maskwords +
+ sizeof(typename ELFT::Word) * HashTab->nbuckets +
+ sizeof(typename ELFT::Word) * (SymTab.size() - HashTab->symndx))
+ return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+ return getSymbolFromGnuHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+ }
+
+ // If this is a Sys-V hash table we verify its size and search the symbol
+ // table using the Sys-V hash table format.
+ if (Sec.sh_type == ELF::SHT_HASH) {
+ const typename ELFT::Hash *HashTab =
+ reinterpret_cast<const typename ELFT::Hash *>(Elf.base() +
+ Sec.sh_offset);
+ if (Sec.sh_offset + Sec.sh_size >= Elf.getBufSize())
+ return createError("section has invalid sh_offset: " +
+ Twine(Sec.sh_offset));
+ if (Sec.sh_size < sizeof(typename ELFT::Hash) ||
+ Sec.sh_size < sizeof(typename ELFT::Hash) +
+ sizeof(typename ELFT::Word) * HashTab->nbucket +
+ sizeof(typename ELFT::Word) * HashTab->nchain)
+ return createError("section has invalid sh_size: " + Twine(Sec.sh_size));
+
+ return getSymbolFromSysVHashTable<ELFT>(Name, *HashTab, SymTab, StrTab);
+ }
+
+ return nullptr;
+}
+
+template <class ELFT>
+static Expected<const typename ELFT::Sym *>
+getSymTableSymbol(const ELFFile<ELFT> &Elf, const typename ELFT::Shdr &Sec,
+ StringRef Name) {
+ if (Sec.sh_type != ELF::SHT_SYMTAB && Sec.sh_type != ELF::SHT_DYNSYM)
+ return createError(
+ "invalid sh_type for hash table, expected SHT_SYMTAB or SHT_DYNSYM");
+ Expected<typename ELFT::ShdrRange> SectionsOrError = Elf.sections();
+ if (!SectionsOrError)
+ return SectionsOrError.takeError();
+
+ auto StrTabOrErr = Elf.getStringTableForSymtab(Sec, *SectionsOrError);
+ if (!StrTabOrErr)
+ return StrTabOrErr.takeError();
+ StringRef StrTab = *StrTabOrErr;
+
+ auto SymsOrErr = Elf.symbols(&Sec);
+ if (!SymsOrErr)
+ return SymsOrErr.takeError();
+ ArrayRef<typename ELFT::Sym> SymTab = *SymsOrErr;
+
+ for (const typename ELFT::Sym &Sym : SymTab)
+ if (StrTab.drop_front(Sym.st_name).data() == Name)
+ return &Sym;
+
+ return nullptr;
+}
+
+Expected<const typename ELF64LE::Sym *>
+getELFSymbol(const ELFObjectFile<ELF64LE> &ELFObj, StringRef Name) {
+ // First try to look up the symbol via the hash table.
+ for (ELFSectionRef Sec : ELFObj.sections()) {
+ if (Sec.getType() != SHT_HASH && Sec.getType() != SHT_GNU_HASH)
+ continue;
+
+ auto HashTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+ if (!HashTabOrErr)
+ return HashTabOrErr.takeError();
+ return getHashTableSymbol<ELF64LE>(ELFObj.getELFFile(), **HashTabOrErr,
+ Name);
+ }
+
+ // If this is an executable file check the entire standard symbol table.
+ for (ELFSectionRef Sec : ELFObj.sections()) {
+ if (Sec.getType() != SHT_SYMTAB)
+ continue;
+
+ auto SymTabOrErr = ELFObj.getELFFile().getSection(Sec.getIndex());
+ if (!SymTabOrErr)
+ return SymTabOrErr.takeError();
+ return getSymTableSymbol<ELF64LE>(ELFObj.getELFFile(), **SymTabOrErr, Name);
+ }
+
+ return nullptr;
+}
--- /dev/null
+//===-- ELFSymbols.h - ELF Symbol look-up functionality ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ELF routines for obtaining a symbol from an Elf file without loading it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+#define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_ELF_COMMON_ELF_SYMBOLS_H
+
+#include "llvm/Object/ELF.h"
+#include "llvm/Object/ELFObjectFile.h"
+
+/// Returns the symbol associated with the \p Name in the \p ELFObj. It will
+/// first search for the hash sections to identify symbols from the hash table.
+/// If that fails it will fall back to a linear search in the case of an
+/// executable file without a hash table.
+llvm::Expected<const typename llvm::object::ELF64LE::Sym *>
+getELFSymbol(const llvm::object::ELFObjectFile<llvm::object::ELF64LE> &ELFObj,
+ llvm::StringRef Name);
+
+#endif