Made FlexBuffers reuse tracker track types
authorWouter van Oortmerssen <aardappel@gmail.com>
Tue, 14 Dec 2021 18:00:56 +0000 (10:00 -0800)
committerWouter van Oortmerssen <aardappel@gmail.com>
Tue, 14 Dec 2021 19:20:23 +0000 (11:20 -0800)
include/flatbuffers/flexbuffers.h
include/flatbuffers/verifier.h
src/flatc.cpp
tests/test.cpp

index 09e4d77..0e15c6f 100644 (file)
@@ -373,7 +373,7 @@ class Reference {
   Reference()
       : data_(nullptr),
         parent_width_(0),
-        byte_width_(BIT_WIDTH_8),
+        byte_width_(0),
         type_(FBT_NULL) {}
 
   Reference(const uint8_t *data, uint8_t parent_width, uint8_t byte_width,
@@ -1632,6 +1632,8 @@ class Builder FLATBUFFERS_FINAL_CLASS {
 
   KeyOffsetMap key_pool;
   StringOffsetMap string_pool;
+
+  friend class Verifier;
 };
 
 // Helper class to verify the integrity of a FlexBuffer
@@ -1640,19 +1642,14 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
   Verifier(const uint8_t *buf, size_t buf_len,
            // Supplying this vector likely results in faster verification
            // of larger buffers with many shared keys/strings, but
-           // comes at the cost of using additional memory 1/8th the size of
-           // the buffer being verified, so it is allowed to be null
-           // for special situations (memory constrained devices or
-           // really small buffers etc). Do note that when not supplying
-           // this buffer, you are not protected against buffers crafted
-           // specifically to DoS you, i.e. recursive sharing that causes
-           // exponential amounts of verification CPU time.
-           std::vector<bool> *reuse_tracker)
+           // comes at the cost of using additional memory the same size of
+           // the buffer being verified, so it is by default off.
+           std::vector<uint8_t> *reuse_tracker = nullptr)
       : buf_(buf), size_(buf_len), reuse_tracker_(reuse_tracker) {
     FLATBUFFERS_ASSERT(size_ < FLATBUFFERS_MAX_BUFFER_SIZE);
     if (reuse_tracker_) {
       reuse_tracker_->clear();
-      reuse_tracker_->resize(size_);
+      reuse_tracker_->resize(size_, PackedType(BIT_WIDTH_8, FBT_NULL));
     }
   }
 
@@ -1697,23 +1694,28 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
                  off <= static_cast<uint64_t>(p - buf_);
   }
 
-  bool CheckVerified(const uint8_t *p) {
-    if (!reuse_tracker_) return false;
-    if ((*reuse_tracker_)[p - buf_]) return true;
-    (*reuse_tracker_)[p - buf_] = true;
-    return false;
-  }
+  // Macro, since we want to escape from parent function & use lazy args.
+  #define FLEX_CHECK_VERIFIED(P, PACKED_TYPE) \
+    if (reuse_tracker_) { \
+      auto packed_type = PACKED_TYPE; \
+      auto existing = (*reuse_tracker_)[P - buf_]; \
+      if (existing == packed_type) return true; \
+      /* Fail verification if already set with different type! */ \
+      if (!Check(existing == 0)) return false; \
+      (*reuse_tracker_)[P - buf_] = packed_type; \
+    }
 
-  bool VerifyVector(const uint8_t *p, Type elem_type, uint8_t size_byte_width,
-                    uint8_t elem_byte_width) {
+  bool VerifyVector(Reference r, const uint8_t *p, Type elem_type) {
     // Any kind of nesting goes thru this function, so guard against that
     // here.
-    if (CheckVerified(p))
-      return true;
+    auto size_byte_width = r.byte_width_;
+    FLEX_CHECK_VERIFIED(p, PackedType(Builder::WidthB(size_byte_width), r.type_));
     if (!VerifyBeforePointer(p, size_byte_width))
       return false;
     auto sized = Sized(p, size_byte_width);
     auto num_elems = sized.size();
+    auto elem_byte_width =
+        r.type_ == FBT_STRING || r.type_ == FBT_BLOB ? uint8_t(1) : r.byte_width_;
     auto max_elems = SIZE_MAX / elem_byte_width;
     if (!Check(num_elems < max_elems))
       return false;  // Protect against byte_size overflowing.
@@ -1749,17 +1751,19 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
       static_cast<uint8_t>(ReadUInt64(p + byte_width, byte_width));
     if (!VerifyByteWidth(key_byte_with))
       return false;
-    return VerifyVector(p - off, FBT_KEY, key_byte_with, key_byte_with);
+    return VerifyVector(Reference(p, byte_width, key_byte_with, FBT_VECTOR_KEY),
+                        p - off, FBT_KEY);
   }
 
   bool VerifyKey(const uint8_t* p) {
-    if (CheckVerified(p))
-      return true;
+    FLEX_CHECK_VERIFIED(p, PackedType(BIT_WIDTH_8, FBT_KEY));
     while (p < buf_ + size_)
       if (*p++) return true;
     return false;
   }
 
+  #undef FLEX_CHECK_VERIFIED
+
   bool VerifyTerminator(const String &s) {
     return VerifyFromPointer(reinterpret_cast<const uint8_t *>(s.c_str()),
                              s.size() + 1);
@@ -1787,26 +1791,26 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
       case FBT_KEY:
         return VerifyKey(p);
       case FBT_MAP:
-        return VerifyVector(p, FBT_NULL, r.byte_width_, r.byte_width_) &&
+        return VerifyVector(r, p, FBT_NULL) &&
                VerifyKeys(p, r.byte_width_);
       case FBT_VECTOR:
-        return VerifyVector(p, FBT_NULL, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_NULL);
       case FBT_VECTOR_INT:
-        return VerifyVector(p, FBT_INT, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_INT);
       case FBT_VECTOR_BOOL:
       case FBT_VECTOR_UINT:
-        return VerifyVector(p, FBT_UINT, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_UINT);
       case FBT_VECTOR_FLOAT:
-        return VerifyVector(p, FBT_FLOAT, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_FLOAT);
       case FBT_VECTOR_KEY:
-        return VerifyVector(p, FBT_KEY, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_KEY);
       case FBT_VECTOR_STRING_DEPRECATED:
         // Use of FBT_KEY here intentional, see elsewhere.
-        return VerifyVector(p, FBT_KEY, r.byte_width_, r.byte_width_);
+        return VerifyVector(r, p, FBT_KEY);
       case FBT_BLOB:
-        return VerifyVector(p, FBT_UINT, r.byte_width_, 1);
+        return VerifyVector(r, p, FBT_UINT);
       case FBT_STRING:
-        return VerifyVector(p, FBT_UINT, r.byte_width_, 1) &&
+        return VerifyVector(r, p, FBT_UINT) &&
                VerifyTerminator(String(p, r.byte_width_));
       case FBT_VECTOR_INT2:
       case FBT_VECTOR_UINT2:
@@ -1842,11 +1846,12 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
  private:
   const uint8_t *buf_;
   size_t size_;
-  std::vector<bool> *reuse_tracker_;
+  std::vector<uint8_t> *reuse_tracker_;
 };
 
 // Utility function that contructs the Verifier for you, see above for parameters.
-inline bool VerifyBuffer(const uint8_t *buf, size_t buf_len, std::vector<bool> *reuse_tracker) {
+inline bool VerifyBuffer(const uint8_t *buf, size_t buf_len,
+                         std::vector<uint8_t> *reuse_tracker = nullptr) {
   Verifier verifier(buf, buf_len, reuse_tracker);
   return verifier.VerifyBuffer();
 }
@@ -1861,7 +1866,7 @@ inline bool VerifyNestedFlexBuffer(const flatbuffers::Vector<uint8_t> *nv,
   if (!nv) return true;
   return verifier.Check(
     flexbuffers::VerifyBuffer(nv->data(), nv->size(),
-                              &verifier.GetReuseVector()));
+                              verifier.GetFlexReuseTracker()));
 }
 #endif
 
index 5198dcc..dfa3da8 100644 (file)
@@ -35,7 +35,8 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
         num_tables_(0),
         max_tables_(_max_tables),
         upper_bound_(0),
-        check_alignment_(_check_alignment) {
+        check_alignment_(_check_alignment),
+        flex_reuse_tracker_(nullptr) {
     FLATBUFFERS_ASSERT(size_ < FLATBUFFERS_MAX_BUFFER_SIZE);
   }
 
@@ -254,7 +255,13 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
     // clang-format on
   }
 
-  std::vector<bool> &GetReuseVector() { return reuse_tracker_; }
+  std::vector<uint8_t> *GetFlexReuseTracker() {
+    return flex_reuse_tracker_;
+  }
+
+  void SetFlexReuseTracker(std::vector<uint8_t> *rt) {
+    flex_reuse_tracker_ = rt;
+  }
 
  private:
   const uint8_t *buf_;
@@ -265,9 +272,7 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
   uoffset_t max_tables_;
   mutable size_t upper_bound_;
   bool check_alignment_;
-  // This is here for nested FlexBuffers, cheap if not touched.
-  // TODO: allow user to supply memory for this.
-  std::vector<bool> reuse_tracker_;
+  std::vector<uint8_t> *flex_reuse_tracker_;
 };
 
 }  // namespace flatbuffers
index 91220b4..2398b38 100644 (file)
@@ -618,7 +618,7 @@ int FlatCompiler::Compile(int argc, const char **argv) {
         if (opts.lang_to_generate == IDLOptions::kJson) {
           auto data = reinterpret_cast<const uint8_t *>(contents.c_str());
           auto size = contents.size();
-          std::vector<bool> reuse_tracker;
+          std::vector<uint8_t> reuse_tracker;
           if (!flexbuffers::VerifyBuffer(data, size, &reuse_tracker))
             Error("flexbuffers file failed to verify: " + filename, false);
           parser->flex_root_ = flexbuffers::GetRoot(data, size);
index 1e6e9d3..2c490a8 100644 (file)
@@ -233,6 +233,8 @@ void AccessFlatBufferTest(const uint8_t *flatbuf, size_t length,
                           bool pooled = true) {
   // First, verify the buffers integrity (optional)
   flatbuffers::Verifier verifier(flatbuf, length);
+  std::vector<uint8_t> flex_reuse_tracker;
+  verifier.SetFlexReuseTracker(&flex_reuse_tracker);
   TEST_EQ(VerifyMonsterBuffer(verifier), true);
 
   // clang-format off
@@ -3022,7 +3024,7 @@ void FlexBuffersTest() {
   #endif
   // clang-format on
 
-  std::vector<bool> reuse_tracker;
+  std::vector<uint8_t> reuse_tracker;
   TEST_EQ(flexbuffers::VerifyBuffer(slb.GetBuffer().data(), slb.GetBuffer().size(),
                            &reuse_tracker), true);