From 8fb7bb24827b3ae1285d4f67c11c90e7d748e889 Mon Sep 17 00:00:00 2001 From: Sam McCall Date: Mon, 24 Sep 2018 14:51:15 +0000 Subject: [PATCH] [clangd] Do bounds checks while reading data, otherwise var-length records are too painful. NFC llvm-svn: 342888 --- clang-tools-extra/clangd/index/Serialization.cpp | 277 +++++++++++------------ 1 file changed, 138 insertions(+), 139 deletions(-) diff --git a/clang-tools-extra/clangd/index/Serialization.cpp b/clang-tools-extra/clangd/index/Serialization.cpp index e1bf322..919d4fc 100644 --- a/clang-tools-extra/clangd/index/Serialization.cpp +++ b/clang-tools-extra/clangd/index/Serialization.cpp @@ -23,24 +23,83 @@ Error makeError(const Twine &Msg) { // IO PRIMITIVES // We use little-endian 32 bit ints, sometimes with variable-length encoding. +// +// Variable-length int encoding (varint) uses the bottom 7 bits of each byte +// to encode the number, and the top bit to indicate whether more bytes follow. +// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop]. +// This represents 0x1a | 0x2f<<7 = 6042. +// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact. -StringRef consume(StringRef &Data, int N) { - StringRef Ret = Data.take_front(N); - Data = Data.drop_front(N); - return Ret; -} +// Reads binary data from a StringRef, and keeps track of position. +class Reader { + const char *Begin, *End; + bool Err; -uint8_t consume8(StringRef &Data) { - uint8_t Ret = Data.front(); - Data = Data.drop_front(); - return Ret; -} +public: + Reader(StringRef Data) : Begin(Data.begin()), End(Data.end()) {} + // The "error" bit is set by reading past EOF or reading invalid data. + // When in an error state, reads may return zero values: callers should check. + bool err() const { return Err; } + // Did we read all the data, or encounter an error? + bool eof() const { return Begin == End || Err; } + // All the data we didn't read yet. + StringRef rest() const { return StringRef(Begin, End - Begin); } + + uint8_t consume8() { + if (LLVM_UNLIKELY(Begin == End)) { + Err = true; + return 0; + } + return *Begin++; + } -uint32_t consume32(StringRef &Data) { - auto Ret = support::endian::read32le(Data.bytes_begin()); - Data = Data.drop_front(4); - return Ret; -} + uint32_t consume32() { + if (LLVM_UNLIKELY(Begin + 4 > End)) { + Err = true; + return 0; + } + auto Ret = support::endian::read32le(Begin); + Begin += 4; + return Ret; + } + + StringRef consume(int N) { + if (LLVM_UNLIKELY(Begin + N > End)) { + Err = true; + return StringRef(); + } + StringRef Ret(Begin, N); + Begin += N; + return Ret; + } + + uint32_t consumeVar() { + constexpr static uint8_t More = 1 << 7; + uint8_t B = consume8(); + if (LLVM_LIKELY(!(B & More))) + return B; + uint32_t Val = B & ~More; + for (int Shift = 7; B & More && Shift < 32; Shift += 7) { + B = consume8(); + Val |= (B & ~More) << Shift; + } + return Val; + } + + StringRef consumeString(ArrayRef Strings) { + auto StringIndex = consumeVar(); + if (LLVM_UNLIKELY(StringIndex >= Strings.size())) { + Err = true; + return StringRef(); + } + return Strings[StringIndex]; + } + + SymbolID consumeID() { + StringRef Raw = consume(SymbolID::RawSize); // short if truncated. + return LLVM_UNLIKELY(err()) ? SymbolID() : SymbolID::fromRaw(Raw); + } +}; void write32(uint32_t I, raw_ostream &OS) { char buf[4]; @@ -48,11 +107,6 @@ void write32(uint32_t I, raw_ostream &OS) { OS.write(buf, sizeof(buf)); } -// Variable-length int encoding (varint) uses the bottom 7 bits of each byte -// to encode the number, and the top bit to indicate whether more bytes follow. -// e.g. 9a 2f means [0x1a and keep reading, 0x2f and stop]. -// This represents 0x1a | 0x2f<<7 = 6042. -// A 32-bit integer takes 1-5 bytes to encode; small numbers are more compact. void writeVar(uint32_t I, raw_ostream &OS) { constexpr static uint8_t More = 1 << 7; if (LLVM_LIKELY(I < 1 << 7)) { @@ -69,19 +123,6 @@ void writeVar(uint32_t I, raw_ostream &OS) { } } -uint32_t consumeVar(StringRef &Data) { - constexpr static uint8_t More = 1 << 7; - uint8_t B = consume8(Data); - if (LLVM_LIKELY(!(B & More))) - return B; - uint32_t Val = B & ~More; - for (int Shift = 7; B & More && Shift < 32; Shift += 7) { - B = consume8(Data); - Val |= (B & ~More) << Shift; - } - return Val; -} - // STRING TABLE ENCODING // Index data has many string fields, and many strings are identical. // We store each string once, and refer to them by index. @@ -146,30 +187,34 @@ struct StringTableIn { }; Expected readStringTable(StringRef Data) { - if (Data.size() < 4) - return makeError("Bad string table: not enough metadata"); - size_t UncompressedSize = consume32(Data); + Reader R(Data); + size_t UncompressedSize = R.consume32(); + if (R.err()) + return makeError("Truncated string table"); StringRef Uncompressed; SmallString<1> UncompressedStorage; if (UncompressedSize == 0) // No compression - Uncompressed = Data; + Uncompressed = R.rest(); else { - if (Error E = - llvm::zlib::uncompress(Data, UncompressedStorage, UncompressedSize)) + if (Error E = llvm::zlib::uncompress(R.rest(), UncompressedStorage, + UncompressedSize)) return std::move(E); Uncompressed = UncompressedStorage; } StringTableIn Table; StringSaver Saver(Table.Arena); - for (StringRef Rest = Uncompressed; !Rest.empty();) { - auto Len = Rest.find(0); + R = Reader(Uncompressed); + for (Reader R(Uncompressed); !R.eof();) { + auto Len = R.rest().find(0); if (Len == StringRef::npos) return makeError("Bad string table: not null terminated"); - Table.Strings.push_back(Saver.save(consume(Rest, Len))); - Rest = Rest.drop_front(); + Table.Strings.push_back(Saver.save(R.consume(Len))); + R.consume8(); } + if (R.err()) + return makeError("Truncated string table"); return std::move(Table); } @@ -179,27 +224,35 @@ Expected readStringTable(StringRef Data) { // - enums encode as the underlying type // - most numbers encode as varint -// It's useful to the implementation to assume symbols have a bounded size. -constexpr size_t SymbolSizeBound = 512; -// To ensure the bounded size, restrict the number of include headers stored. -constexpr unsigned MaxIncludes = 50; +void writeLocation(const SymbolLocation &Loc, const StringTableOut &Strings, + raw_ostream &OS) { + writeVar(Strings.index(Loc.FileURI), OS); + for (const auto &Endpoint : {Loc.Start, Loc.End}) { + writeVar(Endpoint.Line, OS); + writeVar(Endpoint.Column, OS); + } +} + +SymbolLocation readLocation(Reader &Data, ArrayRef Strings) { + SymbolLocation Loc; + Loc.FileURI = Data.consumeString(Strings); + for (auto *Endpoint : {&Loc.Start, &Loc.End}) { + Endpoint->Line = Data.consumeVar(); + Endpoint->Column = Data.consumeVar(); + } + return Loc; +} void writeSymbol(const Symbol &Sym, const StringTableOut &Strings, raw_ostream &OS) { - auto StartOffset = OS.tell(); OS << Sym.ID.raw(); // TODO: once we start writing xrefs and posting lists, // symbol IDs should probably be in a string table. OS.write(static_cast(Sym.SymInfo.Kind)); OS.write(static_cast(Sym.SymInfo.Lang)); writeVar(Strings.index(Sym.Name), OS); writeVar(Strings.index(Sym.Scope), OS); - for (const auto &Loc : {Sym.Definition, Sym.CanonicalDeclaration}) { - writeVar(Strings.index(Loc.FileURI), OS); - for (const auto &Endpoint : {Loc.Start, Loc.End}) { - writeVar(Endpoint.Line, OS); - writeVar(Endpoint.Column, OS); - } - } + writeLocation(Sym.Definition, Strings, OS); + writeLocation(Sym.CanonicalDeclaration, Strings, OS); writeVar(Sym.References, OS); OS.write(static_cast(Sym.Flags)); OS.write(static_cast(Sym.Origin)); @@ -212,86 +265,33 @@ void writeSymbol(const Symbol &Sym, const StringTableOut &Strings, writeVar(Strings.index(Include.IncludeHeader), OS); writeVar(Include.References, OS); }; - // There are almost certainly few includes, so we can just write them. - if (LLVM_LIKELY(Sym.IncludeHeaders.size() <= MaxIncludes)) { - writeVar(Sym.IncludeHeaders.size(), OS); - for (const auto &Include : Sym.IncludeHeaders) - WriteInclude(Include); - } else { - // If there are too many, make sure we truncate the least important. - using Pointer = const Symbol::IncludeHeaderWithReferences *; - std::vector Pointers; - for (const auto &Include : Sym.IncludeHeaders) - Pointers.push_back(&Include); - std::sort(Pointers.begin(), Pointers.end(), [](Pointer L, Pointer R) { - return L->References > R->References; - }); - Pointers.resize(MaxIncludes); - - writeVar(MaxIncludes, OS); - for (Pointer P : Pointers) - WriteInclude(*P); - } - - assert(OS.tell() - StartOffset < SymbolSizeBound && "Symbol length unsafe!"); - (void)StartOffset; // Unused in NDEBUG; + writeVar(Sym.IncludeHeaders.size(), OS); + for (const auto &Include : Sym.IncludeHeaders) + WriteInclude(Include); } -Expected readSymbol(StringRef &Data, const StringTableIn &Strings) { - // Usually we can skip bounds checks because the buffer is huge. - // Near the end of the buffer, this would be unsafe. In this rare case, copy - // the data into a bigger buffer so we can again skip the checks. - if (LLVM_UNLIKELY(Data.size() < SymbolSizeBound)) { - std::string Buf(Data); - Buf.resize(SymbolSizeBound); - StringRef ExtendedData = Buf; - auto Ret = readSymbol(ExtendedData, Strings); - unsigned BytesRead = Buf.size() - ExtendedData.size(); - if (BytesRead > Data.size()) - return makeError("read past end of data"); - Data = Data.drop_front(BytesRead); - return Ret; - } - -#define READ_STRING(Field) \ - do { \ - auto StringIndex = consumeVar(Data); \ - if (LLVM_UNLIKELY(StringIndex >= Strings.Strings.size())) \ - return makeError("Bad string index"); \ - Field = Strings.Strings[StringIndex]; \ - } while (0) - +Symbol readSymbol(Reader &Data, ArrayRef Strings) { Symbol Sym; - Sym.ID = SymbolID::fromRaw(consume(Data, 20)); - Sym.SymInfo.Kind = static_cast(consume8(Data)); - Sym.SymInfo.Lang = static_cast(consume8(Data)); - READ_STRING(Sym.Name); - READ_STRING(Sym.Scope); - for (SymbolLocation *Loc : {&Sym.Definition, &Sym.CanonicalDeclaration}) { - READ_STRING(Loc->FileURI); - for (auto &Endpoint : {&Loc->Start, &Loc->End}) { - Endpoint->Line = consumeVar(Data); - Endpoint->Column = consumeVar(Data); - } - } - Sym.References = consumeVar(Data); - Sym.Flags = static_cast(consume8(Data)); - Sym.Origin = static_cast(consume8(Data)); - READ_STRING(Sym.Signature); - READ_STRING(Sym.CompletionSnippetSuffix); - READ_STRING(Sym.Documentation); - READ_STRING(Sym.ReturnType); - unsigned IncludeHeaderN = consumeVar(Data); - if (IncludeHeaderN > MaxIncludes) - return makeError("too many IncludeHeaders"); - Sym.IncludeHeaders.resize(IncludeHeaderN); + Sym.ID = Data.consumeID(); + Sym.SymInfo.Kind = static_cast(Data.consume8()); + Sym.SymInfo.Lang = static_cast(Data.consume8()); + Sym.Name = Data.consumeString(Strings); + Sym.Scope = Data.consumeString(Strings); + Sym.Definition = readLocation(Data, Strings); + Sym.CanonicalDeclaration = readLocation(Data, Strings); + Sym.References = Data.consumeVar(); + Sym.Flags = static_cast(Data.consumeVar()); + Sym.Origin = static_cast(Data.consumeVar()); + Sym.Signature = Data.consumeString(Strings); + Sym.CompletionSnippetSuffix = Data.consumeString(Strings); + Sym.Documentation = Data.consumeString(Strings); + Sym.ReturnType = Data.consumeString(Strings); + Sym.IncludeHeaders.resize(Data.consumeVar()); for (auto &I : Sym.IncludeHeaders) { - READ_STRING(I.IncludeHeader); - I.References = consumeVar(Data); + I.IncludeHeader = Data.consumeString(Strings); + I.References = Data.consumeVar(); } - -#undef READ_STRING - return std::move(Sym); + return Sym; } } // namespace @@ -306,7 +306,7 @@ Expected readSymbol(StringRef &Data, const StringTableIn &Strings) { // The current versioning scheme is simple - non-current versions are rejected. // If you make a breaking change, bump this version number to invalidate stored // data. Later we may want to support some backward compatibility. -constexpr static uint32_t Version = 3; +constexpr static uint32_t Version = 4; Expected readIndexFile(StringRef Data) { auto RIFF = riff::readFile(Data); @@ -322,8 +322,8 @@ Expected readIndexFile(StringRef Data) { if (!Chunks.count(RequiredChunk)) return makeError("missing required chunk " + RequiredChunk); - StringRef Meta = Chunks.lookup("meta"); - if (Meta.size() < 4 || consume32(Meta) != Version) + Reader Meta(Chunks.lookup("meta")); + if (Meta.consume32() != Version) return makeError("wrong version"); auto Strings = readStringTable(Chunks.lookup("stri")); @@ -332,13 +332,12 @@ Expected readIndexFile(StringRef Data) { IndexFileIn Result; if (Chunks.count("symb")) { - StringRef SymbolData = Chunks.lookup("symb"); + Reader SymbolReader(Chunks.lookup("symb")); SymbolSlab::Builder Symbols; - while (!SymbolData.empty()) - if (auto Sym = readSymbol(SymbolData, *Strings)) - Symbols.insert(*Sym); - else - return Sym.takeError(); + while (!SymbolReader.eof()) + Symbols.insert(readSymbol(SymbolReader, Strings->Strings)); + if (SymbolReader.err()) + return makeError("malformed or truncated symbol"); Result.Symbols = std::move(Symbols).build(); } return std::move(Result); -- 2.7.4