Fixed Reflection Verifier not handling vectors of unions.
authorWouter van Oortmerssen <aardappel@gmail.com>
Tue, 17 Sep 2019 00:48:54 +0000 (17:48 -0700)
committerWouter van Oortmerssen <aardappel@gmail.com>
Tue, 17 Sep 2019 00:48:54 +0000 (17:48 -0700)
Change-Id: Ie94386ff8e10fd2a964bd9155139b50953746a37

include/flatbuffers/flatbuffers.h
src/reflection.cpp
tests/test.cpp

index 00593d6..46457bd 100644 (file)
@@ -2109,6 +2109,11 @@ class Verifier FLATBUFFERS_FINAL_CLASS {
     return VerifyAlignment<T>(elem) && Verify(elem, sizeof(T));
   }
 
+  bool VerifyFromPointer(const uint8_t *p, size_t len) {
+    auto o = static_cast<size_t>(p - buf_);
+    return Verify(o, len);
+  }
+
   // Verify relative to a known-good base pointer.
   bool Verify(const uint8_t *base, voffset_t elem_off, size_t elem_len) const {
     return Verify(static_cast<size_t>(base - buf_) + elem_off, elem_len);
index 5055959..fc211c5 100644 (file)
@@ -515,6 +515,32 @@ bool VerifyObject(flatbuffers::Verifier &v, const reflection::Schema &schema,
                   const reflection::Object &obj,
                   const flatbuffers::Table *table, bool required);
 
+bool VerifyUnion(flatbuffers::Verifier &v, const reflection::Schema &schema,
+                 uint8_t utype, const uint8_t *elem,
+                 const reflection::Field &union_field) {
+  if (!utype) return true;  // Not present.
+  auto fb_enum = schema.enums()->Get(union_field.type()->index());
+  if (utype >= fb_enum->values()->size()) return false;
+  auto elem_type = fb_enum->values()->Get(utype)->union_type();
+  switch (elem_type->base_type()) {
+    case reflection::Obj: {
+      auto elem_obj = schema.objects()->Get(elem_type->index());
+      if (elem_obj->is_struct()) {
+        return v.VerifyFromPointer(elem, elem_obj->bytesize());
+      } else {
+        return VerifyObject(v, schema, *elem_obj,
+                            reinterpret_cast<const flatbuffers::Table *>(elem),
+                            true);
+      }
+    }
+    case reflection::String:
+      return v.VerifyString(
+            reinterpret_cast<const flatbuffers::String *>(elem));
+    default:
+      return false;
+  }
+}
+
 bool VerifyVector(flatbuffers::Verifier &v, const reflection::Schema &schema,
                   const flatbuffers::Table &table,
                   const reflection::Field &vec_field) {
@@ -522,7 +548,6 @@ bool VerifyVector(flatbuffers::Verifier &v, const reflection::Schema &schema,
   if (!table.VerifyField<uoffset_t>(v, vec_field.offset())) return false;
 
   switch (vec_field.type()->element()) {
-    case reflection::None: FLATBUFFERS_ASSERT(false); break;
     case reflection::UType:
       return v.VerifyVector(flatbuffers::GetFieldV<uint8_t>(table, vec_field));
     case reflection::Bool:
@@ -552,48 +577,55 @@ bool VerifyVector(flatbuffers::Verifier &v, const reflection::Schema &schema,
         return false;
       }
     }
-    case reflection::Vector: FLATBUFFERS_ASSERT(false); break;
     case reflection::Obj: {
       auto obj = schema.objects()->Get(vec_field.type()->index());
       if (obj->is_struct()) {
-        if (!VerifyVectorOfStructs(v, table, vec_field.offset(), *obj,
-                                   vec_field.required())) {
-          return false;
-        }
+        return VerifyVectorOfStructs(v, table, vec_field.offset(), *obj,
+                                     vec_field.required());
       } else {
         auto vec =
             flatbuffers::GetFieldV<flatbuffers::Offset<flatbuffers::Table>>(
                 table, vec_field);
         if (!v.VerifyVector(vec)) return false;
-        if (vec) {
-          for (uoffset_t j = 0; j < vec->size(); j++) {
-            if (!VerifyObject(v, schema, *obj, vec->Get(j), true)) {
-              return false;
-            }
+        if (!vec) return true;
+        for (uoffset_t j = 0; j < vec->size(); j++) {
+          if (!VerifyObject(v, schema, *obj, vec->Get(j), true)) {
+            return false;
           }
         }
+        return true;
+      }
+    }
+    case reflection::Union: {
+      auto vec = flatbuffers::GetFieldV<flatbuffers::Offset<uint8_t>>(table,
+                                                                     vec_field);
+      if (!v.VerifyVector(vec)) return false;
+      if (!vec) return true;
+      auto type_vec = table.GetPointer<Vector<uint8_t> *>
+                          (vec_field.offset() - sizeof(voffset_t));
+      if (!v.VerifyVector(type_vec)) return false;
+      for (uoffset_t j = 0; j < vec->size(); j++) {
+        //  get union type from the prev field
+        auto utype = type_vec->Get(j);
+        auto elem = vec->Get(j);
+        if (!VerifyUnion(v, schema, utype, elem, vec_field))
+          return false;
       }
       return true;
     }
-    case reflection::Union: FLATBUFFERS_ASSERT(false); break;
-    default: FLATBUFFERS_ASSERT(false); break;
+    case reflection::Vector:
+    case reflection::None:
+    default:
+      FLATBUFFERS_ASSERT(false);
+      return false;
   }
-
-  return false;
 }
 
 bool VerifyObject(flatbuffers::Verifier &v, const reflection::Schema &schema,
                   const reflection::Object &obj,
                   const flatbuffers::Table *table, bool required) {
-  if (!table) {
-    if (!required)
-      return true;
-    else
-      return false;
-  }
-
+  if (!table) return !required;
   if (!table->VerifyTableStart(v)) return false;
-
   for (uoffset_t i = 0; i < obj.fields()->size(); i++) {
     auto field_def = obj.fields()->Get(i);
     switch (field_def->type()->base_type()) {
@@ -631,7 +663,8 @@ bool VerifyObject(flatbuffers::Verifier &v, const reflection::Schema &schema,
         }
         break;
       case reflection::Vector:
-        if (!VerifyVector(v, schema, *table, *field_def)) return false;
+        if (!VerifyVector(v, schema, *table, *field_def))
+          return false;
         break;
       case reflection::Obj: {
         auto child_obj = schema.objects()->Get(field_def->type()->index());
@@ -653,20 +686,16 @@ bool VerifyObject(flatbuffers::Verifier &v, const reflection::Schema &schema,
         //  get union type from the prev field
         voffset_t utype_offset = field_def->offset() - sizeof(voffset_t);
         auto utype = table->GetField<uint8_t>(utype_offset, 0);
-        if (utype != 0) {
-          // Means we have this union field present
-          auto fb_enum = schema.enums()->Get(field_def->type()->index());
-          if (utype >= fb_enum->values()->size()) return false;
-          auto child_obj = fb_enum->values()->Get(utype)->object();
-          if (!VerifyObject(v, schema, *child_obj,
-                            flatbuffers::GetFieldT(*table, *field_def),
-                            field_def->required())) {
-            return false;
-          }
+        auto uval = reinterpret_cast<const uint8_t *>(
+                      flatbuffers::GetFieldT(*table, *field_def));
+        if (!VerifyUnion(v, schema, utype, uval, *field_def)) {
+          return false;
         }
         break;
       }
-      default: FLATBUFFERS_ASSERT(false); break;
+      default:
+        FLATBUFFERS_ASSERT(false);
+        break;
     }
   }
 
index dfa0d6a..939550d 100644 (file)
@@ -2348,12 +2348,11 @@ void UnionVectorTest() {
                   fbb.CreateStruct(Rapunzel(/*hair_length=*/6)).Union(),
                   fbb.CreateVector(types), fbb.CreateVector(characters));
   FinishMovieBuffer(fbb, movie_offset);
-  auto buf = fbb.GetBufferPointer();
 
-  flatbuffers::Verifier verifier(buf, fbb.GetSize());
+  flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
   TEST_EQ(VerifyMovieBuffer(verifier), true);
 
-  auto flat_movie = GetMovie(buf);
+  auto flat_movie = GetMovie(fbb.GetBufferPointer());
 
   auto TestMovie = [](const Movie *movie) {
     TEST_EQ(movie->main_character_type() == Character_Rapunzel, true);
@@ -2485,6 +2484,13 @@ void UnionVectorTest() {
       "  ]\n"
       "}\n");
 
+  // Simple test with reflection.
+  parser.Serialize();
+  auto schema = reflection::GetSchema(parser.builder_.GetBufferPointer());
+  auto ok = flatbuffers::Verify(*schema, *schema->root_table(),
+                                fbb.GetBufferPointer(), fbb.GetSize());
+  TEST_EQ(ok, true);
+
   flatbuffers::Parser parser2(idl_opts);
   TEST_EQ(parser2.Parse("struct Bool { b:bool; }"
                         "union Any { Bool }"