Registering a type now rebuilds it out of memory owned by the manager.
authorAlan Baker <alanbaker@google.com>
Mon, 5 Feb 2018 14:47:05 +0000 (09:47 -0500)
committerAlan Baker <alanbaker@google.com>
Tue, 6 Feb 2018 15:17:56 +0000 (10:17 -0500)
* Added TypeManager::RebuildType
 * rebuilds the type and its constituent types in terms of memory owned
 by the manager.
 * Used by TypeManager::RegisterType to properly allocate memory
* Adding an unit test to expose the issue
* Added some tests to provide coverage of RebuildType
* Added an accessor to the target pointer for a forward pointer

source/opt/type_manager.cpp
source/opt/type_manager.h
source/opt/types.h
test/opt/type_manager_test.cpp

index b0e6114..4d201ef 100644 (file)
@@ -349,16 +349,145 @@ void TypeManager::CreateDecoration(uint32_t target,
   context()->get_def_use_mgr()->AnalyzeInstUse(inst);
 }
 
+Type* TypeManager::RebuildType(const Type& type) {
+  // The comparison and hash on the type pool will avoid inserting the rebuilt
+  // type if an equivalent type already exists. The rebuilt type will be deleted
+  // when it goes out of scope at the end of the function in that case. Repeated
+  // insertions of the same Type will, at most, keep one corresponding object in
+  // the type pool.
+  std::unique_ptr<Type> rebuilt_ty;
+  switch (type.kind()) {
+#define DefineNoSubtypeCase(kind)             \
+  case Type::k##kind:                         \
+    rebuilt_ty.reset(type.Clone().release()); \
+    break;
+    DefineNoSubtypeCase(Void);
+    DefineNoSubtypeCase(Bool);
+    DefineNoSubtypeCase(Integer);
+    DefineNoSubtypeCase(Float);
+    DefineNoSubtypeCase(Sampler);
+    DefineNoSubtypeCase(Opaque);
+    DefineNoSubtypeCase(Event);
+    DefineNoSubtypeCase(DeviceEvent);
+    DefineNoSubtypeCase(ReserveId);
+    DefineNoSubtypeCase(Queue);
+    DefineNoSubtypeCase(Pipe);
+    DefineNoSubtypeCase(PipeStorage);
+    DefineNoSubtypeCase(NamedBarrier);
+#undef DefineNoSubtypeCase
+    case Type::kVector: {
+      const Vector* vec_ty = type.AsVector();
+      const Type* ele_ty = vec_ty->element_type();
+      rebuilt_ty.reset(
+          new Vector(RebuildType(*ele_ty), vec_ty->element_count()));
+      break;
+    }
+    case Type::kMatrix: {
+      const Matrix* mat_ty = type.AsMatrix();
+      const Type* ele_ty = mat_ty->element_type();
+      rebuilt_ty.reset(
+          new Matrix(RebuildType(*ele_ty), mat_ty->element_count()));
+      break;
+    }
+    case Type::kImage: {
+      const Image* image_ty = type.AsImage();
+      const Type* ele_ty = image_ty->sampled_type();
+      rebuilt_ty.reset(new Image(RebuildType(*ele_ty), image_ty->dim(),
+                                 image_ty->depth(), image_ty->is_arrayed(),
+                                 image_ty->is_multisampled(),
+                                 image_ty->sampled(), image_ty->format(),
+                                 image_ty->access_qualifier()));
+      break;
+    }
+    case Type::kSampledImage: {
+      const SampledImage* image_ty = type.AsSampledImage();
+      const Type* ele_ty = image_ty->image_type();
+      rebuilt_ty.reset(
+
+          new SampledImage(RebuildType(*ele_ty)));
+      break;
+    }
+    case Type::kArray: {
+      const Array* array_ty = type.AsArray();
+      const Type* ele_ty = array_ty->element_type();
+      rebuilt_ty.reset(new Array(RebuildType(*ele_ty), array_ty->LengthId()));
+      break;
+    }
+    case Type::kRuntimeArray: {
+      const RuntimeArray* array_ty = type.AsRuntimeArray();
+      const Type* ele_ty = array_ty->element_type();
+      rebuilt_ty.reset(new RuntimeArray(RebuildType(*ele_ty)));
+      break;
+    }
+    case Type::kStruct: {
+      const Struct* struct_ty = type.AsStruct();
+      std::vector<Type*> subtypes;
+      subtypes.reserve(struct_ty->element_types().size());
+      for (const auto* ele_ty : struct_ty->element_types()) {
+        subtypes.push_back(RebuildType(*ele_ty));
+      }
+      rebuilt_ty.reset(new Struct(subtypes));
+      Struct* rebuilt_struct = rebuilt_ty->AsStruct();
+      for (auto pair : struct_ty->element_decorations()) {
+        uint32_t index = pair.first;
+        for (const auto& dec : pair.second) {
+          // Explicit copy intended.
+          std::vector<uint32_t> copy(dec);
+          rebuilt_struct->AddMemberDecoration(index, std::move(copy));
+        }
+      }
+      break;
+    }
+    case Type::kPointer: {
+      const Pointer* pointer_ty = type.AsPointer();
+      const Type* ele_ty = pointer_ty->pointee_type();
+      rebuilt_ty.reset(
+          new Pointer(RebuildType(*ele_ty), pointer_ty->storage_class()));
+      break;
+    }
+    case Type::kFunction: {
+      const Function* function_ty = type.AsFunction();
+      const Type* ret_ty = function_ty->return_type();
+      std::vector<Type*> param_types;
+      param_types.reserve(function_ty->param_types().size());
+      for (const auto* param_ty : function_ty->param_types()) {
+        param_types.push_back(RebuildType(*param_ty));
+      }
+      rebuilt_ty.reset(new Function(RebuildType(*ret_ty), param_types));
+      break;
+    }
+    case Type::kForwardPointer: {
+      const ForwardPointer* forward_ptr_ty = type.AsForwardPointer();
+      rebuilt_ty.reset(new ForwardPointer(forward_ptr_ty->target_id(),
+                                          forward_ptr_ty->storage_class()));
+      const Pointer* target_ptr = forward_ptr_ty->target_pointer();
+      if (target_ptr) {
+        rebuilt_ty->AsForwardPointer()->SetTargetPointer(
+            RebuildType(*target_ptr)->AsPointer());
+      }
+      break;
+    }
+    default:
+      assert(false && "Unhandled type");
+      return nullptr;
+  }
+  for (const auto& dec : type.decorations()) {
+    // Explicit copy intended.
+    std::vector<uint32_t> copy(dec);
+    rebuilt_ty->AddDecoration(std::move(copy));
+  }
+
+  return type_pool_.insert(std::move(rebuilt_ty)).first->get();
+}
+
 void TypeManager::RegisterType(uint32_t id, const Type& type) {
-  // The comparison and hash on the type pool will avoid inserting the clone if
-  // an equivalent type already exists. The clone will be deleted when it goes
-  // out of scope at the end of the function in that case. Repeated insertions
-  // of the same Type will atmost keep one corresponding object in the type
+  // Rebuild |type| so it and all its constituent types are owned by the type
   // pool.
-  auto pair = type_pool_.insert(type.Clone());
-  id_to_type_[id] = pair.first->get();
-  if (GetId(pair.first->get()) == 0) {
-    type_to_id_[pair.first->get()] = id;
+  Type* rebuilt = RebuildType(type);
+  assert(rebuilt->IsSame(&type));
+  id_to_type_[id] = rebuilt;
+  if (GetId(rebuilt) == 0) {
+    type_to_id_[rebuilt] = id;
   }
 }
 
index ded313e..0550a57 100644 (file)
@@ -162,6 +162,11 @@ class TypeManager {
   // |type| (e.g. should be called in loop of |type|'s decorations).
   void AttachDecoration(const spvtools::ir::Instruction& inst, Type* type);
 
+  // Returns an equivalent pointer to |type| built in terms of pointers owned by
+  // |type_pool_|. For example, if |type| is a vec3 of bool, it will be rebuilt
+  // replacing the bool subtype with one owned by |type_pool_|.
+  Type* RebuildType(const Type& type);
+
   const MessageConsumer& consumer_;  // Message consumer.
   spvtools::ir::IRContext* context_;
   IdToTypeMap id_to_type_;  // Mapping from ids to their type representations.
index 4b1085d..34ee3ac 100644 (file)
@@ -485,6 +485,7 @@ class ForwardPointer : public Type {
   uint32_t target_id() const { return target_id_; }
   void SetTargetPointer(Pointer* pointer) { pointer_ = pointer; }
   SpvStorageClass storage_class() const { return storage_class_; }
+  const Pointer* target_pointer() const { return pointer_; }
 
   bool IsSame(const Type* that) const override;
   std::string str() const override;
index d1e8b50..b235db4 100644 (file)
@@ -644,6 +644,84 @@ OpMemoryModel Logical GLSL450
   EXPECT_TRUE(type->IsSame(&st));
 }
 
+TEST(TypeManager, RegisterAndRemoveId) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+)";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  uint32_t id = 2u;
+  {
+    // Ensure that u32 goes out of scope.
+    Integer u32(32, false);
+    Struct st({&u32});
+    context->get_type_mgr()->RegisterType(id, st);
+  }
+
+  context->get_type_mgr()->RemoveId(id);
+  EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
+}
+
+TEST(TypeManager, RegisterAndRemoveIdAllTypes) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+)";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+  uint32_t id = 1u;
+  for (auto& t : types) {
+    context->get_type_mgr()->RegisterType(id, *t);
+    EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
+  }
+  types.clear();
+
+  for (; id > 0; --id) {
+    context->get_type_mgr()->RemoveId(id);
+    EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
+  }
+}
+
+TEST(TypeManager, RegisterAndRemoveIdWithDecorations) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+)";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  uint32_t id = 2u;
+  {
+    Integer u32(32, false);
+    Struct st({&u32, &u32});
+    st.AddDecoration({10});
+    st.AddDecoration({11});
+    st.AddMemberDecoration(0, {{35, 4}});
+    st.AddMemberDecoration(1, {{35, 4}});
+    st.AddMemberDecoration(1, {{36, 5}});
+    context->get_type_mgr()->RegisterType(id, st);
+    EXPECT_EQ(st, *context->get_type_mgr()->GetType(id));
+  }
+
+  context->get_type_mgr()->RemoveId(id);
+  EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
+}
+
 #ifdef SPIRV_EFFCEE
 TEST(TypeManager, GetTypeInstructionInt) {
   const std::string text = R"(