Improving the usability of the type manager. The type manager hashes
authorAlan Baker <alanbaker@google.com>
Fri, 8 Dec 2017 20:33:19 +0000 (15:33 -0500)
committerAlan Baker <alanbaker@google.com>
Mon, 18 Dec 2017 13:20:56 +0000 (08:20 -0500)
types. This allows the lookup of type declaration ids from arbitrarily
constructed types. Users should be cautious when dealing with non-unique
types (structs and potentially pointers) to get the exact id if
necessary.

* Changed the spec composite constant folder to handle ambiguous composites
* Added functionality to create necessary instructions for a type
* Added ability to remove ids from the type manager

21 files changed:
source/opt/constants.cpp
source/opt/constants.h
source/opt/fold_spec_constant_op_and_composite_pass.cpp
source/opt/fold_spec_constant_op_and_composite_pass.h
source/opt/inline_pass.cpp
source/opt/instruction.cpp
source/opt/instruction.h
source/opt/ir_context.cpp
source/opt/ir_context.h
source/opt/optimizer.cpp
source/opt/passes.h
source/opt/scalar_replacement_pass.cpp
source/opt/set_spec_constant_default_value_pass.cpp
source/opt/strength_reduction_pass.cpp
source/opt/strength_reduction_pass.h
source/opt/type_manager.cpp
source/opt/type_manager.h
source/opt/types.cpp
source/opt/types.h
test/opt/type_manager_test.cpp
test/opt/types_test.cpp

index 4e6ba68..c2b1a61 100644 (file)
@@ -50,12 +50,13 @@ std::vector<const analysis::Constant*> ConstantManager::GetConstantsFromIds(
 }
 
 ir::Instruction* ConstantManager::BuildInstructionAndAddToModule(
-    std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
+    std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
+    uint32_t type_id) {
   analysis::Constant* new_const = c.get();
   uint32_t new_id = context()->TakeNextId();
   const_val_to_id_[new_const] = new_id;
   id_to_const_val_[new_id] = std::move(c);
-  auto new_inst = CreateInstruction(new_id, new_const);
+  auto new_inst = CreateInstruction(new_id, new_const, type_id);
   if (!new_inst) return nullptr;
   auto* new_inst_ptr = new_inst.get();
   *pos = pos->InsertBefore(std::move(new_inst));
@@ -157,41 +158,40 @@ std::unique_ptr<analysis::Constant> ConstantManager::CreateConstantFromInst(
 }
 
 std::unique_ptr<ir::Instruction> ConstantManager::CreateInstruction(
-    uint32_t id, analysis::Constant* c) const {
+    uint32_t id, analysis::Constant* c, uint32_t type_id) const {
+  uint32_t type =
+      (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
   if (c->AsNullConstant()) {
-    return MakeUnique<ir::Instruction>(
-        context(), SpvOp::SpvOpConstantNull,
-        context()->get_type_mgr()->GetId(c->type()), id,
-        std::initializer_list<ir::Operand>{});
+    return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantNull,
+                                       type, id,
+                                       std::initializer_list<ir::Operand>{});
   } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
     return MakeUnique<ir::Instruction>(
         context(),
         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
-        context()->get_type_mgr()->GetId(c->type()), id,
-        std::initializer_list<ir::Operand>{});
+        type, id, std::initializer_list<ir::Operand>{});
   } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
     return MakeUnique<ir::Instruction>(
-        context(), SpvOp::SpvOpConstant,
-        context()->get_type_mgr()->GetId(c->type()), id,
+        context(), SpvOp::SpvOpConstant, type, id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             ic->words())});
   } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
     return MakeUnique<ir::Instruction>(
-        context(), SpvOp::SpvOpConstant,
-        context()->get_type_mgr()->GetId(c->type()), id,
+        context(), SpvOp::SpvOpConstant, type, id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             fc->words())});
   } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
-    return CreateCompositeInstruction(id, cc);
+    return CreateCompositeInstruction(id, cc, type_id);
   } else {
     return nullptr;
   }
 }
 
 std::unique_ptr<ir::Instruction> ConstantManager::CreateCompositeInstruction(
-    uint32_t result_id, analysis::CompositeConstant* cc) const {
+    uint32_t result_id, analysis::CompositeConstant* cc,
+    uint32_t type_id) const {
   std::vector<ir::Operand> operands;
   for (const analysis::Constant* component_const : cc->GetComponents()) {
     uint32_t id = FindRecordedConstant(component_const);
@@ -204,10 +204,10 @@ std::unique_ptr<ir::Instruction> ConstantManager::CreateCompositeInstruction(
     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                           std::initializer_list<uint32_t>{id});
   }
-  return MakeUnique<ir::Instruction>(
-      context(), SpvOp::SpvOpConstantComposite,
-      context()->get_type_mgr()->GetId(cc->type()), result_id,
-      std::move(operands));
+  uint32_t type =
+      (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
+  return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantComposite,
+                                     type, result_id, std::move(operands));
 }
 
 }  // namespace analysis
index d41eed6..eda26b4 100644 (file)
@@ -325,22 +325,39 @@ class ConstantManager {
   // points to the same instruction before and after the insertion. This is the
   // only method that actually manages id creation/assignment and instruction
   // creation/insertion for a new Constant instance.
+  //
+  // |type_id| is an optional argument for disambiguating equivalent types. If
+  // |type_id| is specified, it is used as the type of the constant. Otherwise
+  // the type of the constant is derived by getting an id from the type manager
+  // for |c|.
   ir::Instruction* BuildInstructionAndAddToModule(
-      std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos);
+      std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
+      uint32_t type_id = 0);
 
   // Creates an instruction with the given result id to declare a constant
   // represented by the given Constant instance. Returns an unique pointer to
   // the created instruction if the instruction can be created successfully.
   // Otherwise, returns a null pointer.
+  //
+  // |type_id| is an optional argument for disambiguating equivalent types. If
+  // |type_id| is specified, it is used as the type of the constant. Otherwise
+  // the type of the constant is derived by getting an id from the type manager
+  // for |c|.
   std::unique_ptr<ir::Instruction> CreateInstruction(
-      uint32_t result_id, analysis::Constant* c) const;
+      uint32_t result_id, analysis::Constant* c, uint32_t type_id = 0) const;
 
   // Creates an OpConstantComposite instruction with the given result id and
   // the CompositeConst instance which represents a composite constant. Returns
   // an unique pointer to the created instruction if succeeded. Otherwise
   // returns a null pointer.
+  //
+  // |type_id| is an optional argument for disambiguating equivalent types. If
+  // |type_id| is specified, it is used as the type of the constant. Otherwise
+  // the type of the constant is derived by getting an id from the type manager
+  // for |c|.
   std::unique_ptr<ir::Instruction> CreateCompositeInstruction(
-      uint32_t result_id, analysis::CompositeConstant* cc) const;
+      uint32_t result_id, analysis::CompositeConstant* cc,
+      uint32_t type_id = 0) const;
 
   // A helper function to get the result type of the given instruction. Returns
   // nullptr if the instruction does not have a type id (type id is 0).
index 2cdefa3..831906f 100644 (file)
@@ -160,6 +160,15 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
   return true;
 }
 
+uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent(
+    uint32_t typeId, uint32_t element) const {
+  ir::Instruction* type = context()->get_def_use_mgr()->GetDef(typeId);
+  uint32_t subtype = type->GetTypeComponent(element);
+  assert(subtype != 0);
+
+  return subtype;
+}
+
 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
     ir::Module::inst_iterator* pos) {
   ir::Instruction* inst = &**pos;
@@ -167,21 +176,26 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
          "OpSpecConstantOp CompositeExtract requires at least two non-type "
          "non-opcode operands.");
   assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
-         "The vector operand must have a SPV_OPERAND_TYPE_ID type");
+         "The composite operand must have a SPV_OPERAND_TYPE_ID type");
   assert(
       inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
       "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
 
   // Note that for OpSpecConstantOp, the second in-operand is the first id
   // operand. The first in-operand is the spec opcode.
+  uint32_t source = inst->GetSingleWordInOperand(1);
+  uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id();
   analysis::Constant* first_operand_const =
-      context()->get_constant_mgr()->FindRecordedConstant(
-          inst->GetSingleWordInOperand(1));
+      context()->get_constant_mgr()->FindRecordedConstant(source);
   if (!first_operand_const) return nullptr;
 
   const analysis::Constant* current_const = first_operand_const;
   for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
     uint32_t literal = inst->GetSingleWordInOperand(i);
+    type = GetTypeComponent(type, literal);
+  }
+  for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
+    uint32_t literal = inst->GetSingleWordInOperand(i);
     if (const analysis::CompositeConstant* composite_const =
             current_const->AsCompositeConstant()) {
       // Case 1: current constant is a non-null composite type constant.
@@ -195,14 +209,14 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
       return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
           context()->get_constant_mgr()->CreateConstant(
               context()->get_constant_mgr()->GetType(inst), {}),
-          pos);
+          pos, type);
     } else {
       // Dereferencing a non-composite constant. Invalid case.
       return nullptr;
     }
   }
   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-      current_const->Copy(), pos);
+      current_const->Copy(), pos, type);
 }
 
 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
index 4253789..5f901ee 100644 (file)
@@ -79,6 +79,11 @@ class FoldSpecConstantOpAndCompositePass : public Pass {
   // if succeeded, otherwise return nullptr.
   ir::Instruction* DoComponentWiseOperation(
       ir::Module::inst_iterator* inst_iter_ptr);
+
+  // Returns the |element|'th subtype of |type|.
+  //
+  // |type| must be a composite type.
+  uint32_t GetTypeComponent(uint32_t type, uint32_t element) const;
 };
 
 }  // namespace opt
index a050642..8bd67f1 100644 (file)
@@ -33,6 +33,17 @@ namespace opt {
 
 uint32_t InlinePass::FindPointerToType(uint32_t type_id,
                                        SpvStorageClass storage_class) {
+  analysis::Type* pointeeTy;
+  std::unique_ptr<analysis::Pointer> pointerTy;
+  std::tie(pointeeTy, pointerTy) =
+      context()->get_type_mgr()->GetTypeAndPointerType(type_id,
+                                                       SpvStorageClassFunction);
+  if (type_id == context()->get_type_mgr()->GetId(pointeeTy)) {
+    // Non-ambiguous type. Get the pointer type through the type manager.
+    return context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
+  }
+
+  // Ambiguous type, do a linear search.
   ir::Module::inst_iterator type_itr = get_module()->types_values_begin();
   for (; type_itr != get_module()->types_values_end(); ++type_itr) {
     const ir::Instruction* type_inst = &*type_itr;
@@ -54,6 +65,12 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id,
         {uint32_t(storage_class)}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}}));
   context()->AddType(std::move(type_inst));
+  analysis::Type* pointeeTy;
+  std::unique_ptr<analysis::Pointer> pointerTy;
+  std::tie(pointeeTy, pointerTy) =
+      context()->get_type_mgr()->GetTypeAndPointerType(type_id,
+                                                       SpvStorageClassFunction);
+  context()->get_type_mgr()->RegisterType(resultId, *pointerTy);
   return resultId;
 }
 
index da52c95..914eb3f 100644 (file)
@@ -355,6 +355,26 @@ bool Instruction::IsReadOnlyVariableKernel() const {
   return storage_class == SpvStorageClassUniformConstant;
 }
 
+uint32_t Instruction::GetTypeComponent(uint32_t element) const {
+  uint32_t subtype = 0;
+  switch (opcode()) {
+    case SpvOpTypeStruct:
+      subtype = GetSingleWordInOperand(element);
+      break;
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+      // These types all have uniform subtypes.
+      subtype = GetSingleWordInOperand(0u);
+      break;
+    default:
+      break;
+  }
+
+  return subtype;
+}
+
 Instruction* Instruction::InsertBefore(
     std::vector<std::unique_ptr<Instruction>>&& list) {
   Instruction* first_node = list.front().get();
index 3ddf741..edb0419 100644 (file)
@@ -303,6 +303,10 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // and return to its caller
   bool IsReturn() const { return spvOpcodeIsReturn(opcode()); }
 
+  // Returns the id for the |element|'th subtype. If the |this| is not a
+  // composite type, this function returns 0.
+  uint32_t GetTypeComponent(uint32_t element) const;
+
   // Returns true if this instruction is a basic block terminator.
   bool IsBlockTerminator() const {
     return spvOpcodeIsBlockTerminator(opcode());
index c899591..38038d7 100644 (file)
@@ -16,6 +16,8 @@
 #include "latest_version_glsl_std_450_header.h"
 #include "log.h"
 #include "mem_pass.h"
+#include "reflect.h"
+#include "spirv/1.0/GLSL.std.450.h"
 
 #include <cstring>
 
@@ -95,6 +97,10 @@ Instruction* IRContext::KillInst(ir::Instruction* inst) {
     }
   }
 
+  if (type_mgr_ && ir::IsTypeInst(inst->opcode())) {
+    type_mgr_->RemoveId(inst->result_id());
+  }
+
   Instruction* next_instruction = nullptr;
   if (inst->IsInAList()) {
     next_instruction = inst->NextNode();
index c0d21f1..10c42da 100644 (file)
@@ -241,7 +241,7 @@ class IRContext {
   // is never re-built.
   opt::analysis::TypeManager* get_type_mgr() {
     if (!type_mgr_)
-      type_mgr_.reset(new opt::analysis::TypeManager(consumer(), *module()));
+      type_mgr_.reset(new opt::analysis::TypeManager(consumer(), this));
     return type_mgr_.get();
   }
 
index 82e57f8..3109aed 100644 (file)
@@ -82,7 +82,8 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
 }
 
 Optimizer& Optimizer::RegisterPerformancePasses() {
-  return RegisterPass(CreateMergeReturnPass())
+  return RegisterPass(CreateRemoveDuplicatesPass())
+      .RegisterPass(CreateMergeReturnPass())
       .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateEliminateDeadFunctionsPass())
       .RegisterPass(CreateScalarReplacementPass())
@@ -102,7 +103,8 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
 }
 
 Optimizer& Optimizer::RegisterSizePasses() {
-  return RegisterPass(CreateMergeReturnPass())
+  return RegisterPass(CreateRemoveDuplicatesPass())
+      .RegisterPass(CreateMergeReturnPass())
       .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateEliminateDeadFunctionsPass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
@@ -289,6 +291,11 @@ Optimizer::PassToken CreateRedundancyEliminationPass() {
       MakeUnique<opt::RedundancyEliminationPass>());
 }
 
+Optimizer::PassToken CreateRemoveDuplicatesPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::RemoveDuplicatesPass>());
+}
+
 Optimizer::PassToken CreateScalarReplacementPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::ScalarReplacementPass>());
index ba3b270..daa0948 100644 (file)
@@ -40,6 +40,7 @@
 #include "merge_return_pass.h"
 #include "null_pass.h"
 #include "redundancy_elimination.h"
+#include "remove_duplicates_pass.h"
 #include "scalar_replacement_pass.h"
 #include "set_spec_constant_default_value_pass.h"
 #include "strength_reduction_pass.h"
index fd146b2..23b0ce5 100644 (file)
@@ -21,6 +21,7 @@
 #include "types.h"
 
 #include <queue>
+#include <tuple>
 
 namespace spvtools {
 namespace opt {
@@ -335,8 +336,21 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
   auto iter = pointee_to_pointer_.find(id);
   if (iter != pointee_to_pointer_.end()) return iter->second;
 
-  // TODO(alanbaker): Make the type manager useful and then replace this code.
+  analysis::Type* pointeeTy;
+  std::unique_ptr<analysis::Pointer> pointerTy;
+  std::tie(pointeeTy, pointerTy) =
+      context()->get_type_mgr()->GetTypeAndPointerType(id,
+                                                       SpvStorageClassFunction);
   uint32_t ptrId = 0;
+  if (id == context()->get_type_mgr()->GetId(pointeeTy)) {
+    // Non-ambiguous type, just ask the type manager for an id.
+    ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
+    pointee_to_pointer_[id] = ptrId;
+    return ptrId;
+  }
+
+  // Ambiguous type. We must perform a linear search to try and find the right
+  // type.
   for (auto global : context()->types_values()) {
     if (global.opcode() == SpvOpTypePointer &&
         global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
@@ -345,7 +359,7 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
               libspirv::Extension::kSPV_KHR_variable_pointers) ||
           get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
         // If variable pointers is enabled, only reuse a decoration-less
-        // pointer of the correct type
+        // pointer of the correct type.
         ptrId = global.result_id();
         break;
       }
@@ -366,6 +380,8 @@ uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
   ir::Instruction* ptr = &*--context()->types_values_end();
   get_def_use_mgr()->AnalyzeInstDefUse(ptr);
   pointee_to_pointer_[id] = ptrId;
+  // Register with the type manager if necessary.
+  context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
 
   return ptrId;
 }
index 3ec786f..1b0cccd 100644 (file)
@@ -189,6 +189,8 @@ ir::Instruction* GetSpecIdTargetFromDecorationGroup(
 
 Pass::Status SetSpecConstantDefaultValuePass::Process(
     ir::IRContext* irContext) {
+  InitializeProcessing(irContext);
+
   // The operand index of decoration target in an OpDecorate instruction.
   const uint32_t kTargetIdOperandIndex = 0;
   // The operand index of the decoration literal in an OpDecorate instruction.
@@ -202,8 +204,6 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(
   const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
 
   bool modified = false;
-  analysis::DefUseManager def_use_mgr(irContext->module());
-  analysis::TypeManager type_mgr(consumer(), *irContext->module());
   // Scan through all the annotation instructions to find 'OpDecorate SpecId'
   // instructions. Then extract the decoration target of those instructions.
   // The decoration targets should be spec constant defining instructions with
@@ -229,10 +229,10 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(
     // Find the spec constant defining instruction. Note that the
     // target_id might be a decoration group id.
     ir::Instruction* spec_inst = nullptr;
-    if (ir::Instruction* target_inst = def_use_mgr.GetDef(target_id)) {
+    if (ir::Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
       if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
         spec_inst =
-            GetSpecIdTargetFromDecorationGroup(*target_inst, &def_use_mgr);
+            GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
       } else {
         spec_inst = target_inst;
       }
@@ -255,7 +255,8 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(
       // with the type of the spec constant.
       const std::string& default_value_str = iter->second;
       bit_pattern = ParseDefaultValueStr(
-          default_value_str.c_str(), type_mgr.GetType(spec_inst->type_id()));
+          default_value_str.c_str(),
+          context()->get_type_mgr()->GetType(spec_inst->type_id()));
 
     } else {
       // Search for the new bit-pattern-form default value for this spec id.
@@ -266,7 +267,8 @@ Pass::Status SetSpecConstantDefaultValuePass::Process(
 
       // Gets the bit-pattern of the default value from the map directly.
       bit_pattern = ParseDefaultValueBitPattern(
-          iter->second, type_mgr.GetType(spec_inst->type_id()));
+          iter->second,
+          context()->get_type_mgr()->GetType(spec_inst->type_id()));
     }
 
     if (bit_pattern.empty()) continue;
index 6edb6f0..fd8ccf9 100644 (file)
@@ -125,18 +125,13 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
 }
 
 void StrengthReductionPass::FindIntTypesAndConstants() {
+  analysis::Integer int32(32, true);
+  int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
+  analysis::Integer uint32(32, false);
+  uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
   for (auto iter = get_module()->types_values_begin();
        iter != get_module()->types_values_end(); ++iter) {
     switch (iter->opcode()) {
-      case SpvOp::SpvOpTypeInt:
-        if (iter->GetSingleWordOperand(1) == 32) {
-          if (iter->GetSingleWordOperand(2) == 1) {
-            int32_type_id_ = iter->result_id();
-          } else {
-            uint32_type_id_ = iter->result_id();
-          }
-        }
-        break;
       case SpvOp::SpvOpConstant:
         if (iter->type_id() == uint32_type_id_) {
           uint32_t value = iter->GetSingleWordOperand(2);
@@ -155,7 +150,8 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
 
   if (constant_ids_[val] == 0) {
     if (uint32_type_id_ == 0) {
-      uint32_type_id_ = CreateUint32Type();
+      analysis::Integer uint(32, false);
+      uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
     }
 
     // Construct the constant.
@@ -199,17 +195,5 @@ bool StrengthReductionPass::ScanFunctions() {
   return modified;
 }
 
-uint32_t StrengthReductionPass::CreateUint32Type() {
-  uint32_t type_id = TakeNextId();
-  ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
-                           {32});
-  ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
-                          {0});
-  std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
-      context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
-  context()->AddType(std::move(newType));
-  return type_id;
-}
-
 }  // namespace opt
 }  // namespace spvtools
index 425bb32..6c233e1 100644 (file)
@@ -49,10 +49,6 @@ class StrengthReductionPass : public Pass {
   // ones. Returns true if something changed.
   bool ScanFunctions();
 
-  // Will create the type for an unsigned 32-bit integer and return the id.
-  // This functions assumes one does not already exist.
-  uint32_t CreateUint32Type();
-
   // Type ids for the types of interest, or 0 if they do not exist.
   uint32_t int32_type_id_;
   uint32_t uint32_type_id_;
index 8a12527..971eb84 100644 (file)
 
 #include "type_manager.h"
 
+#include <cassert>
+#include <cstring>
 #include <utility>
 
+#include "ir_context.h"
 #include "log.h"
+#include "make_unique.h"
 #include "reflect.h"
 
 namespace spvtools {
 namespace opt {
 namespace analysis {
 
+TypeManager::TypeManager(const MessageConsumer& consumer,
+                         spvtools::ir::IRContext* c)
+    : consumer_(consumer), context_(c) {
+  AnalyzeTypes(*c->module());
+}
+
 Type* TypeManager::GetType(uint32_t id) const {
   auto iter = id_to_type_.find(id);
   if (iter != id_to_type_.end()) return (*iter).second.get();
   return nullptr;
 }
 
+std::pair<Type*, std::unique_ptr<Pointer>> TypeManager::GetTypeAndPointerType(
+    uint32_t id, SpvStorageClass sc) const {
+  Type* type = GetType(id);
+  if (type) {
+    return std::make_pair(type, MakeUnique<analysis::Pointer>(type, sc));
+  } else {
+    return std::make_pair(type, std::unique_ptr<analysis::Pointer>());
+  }
+}
+
 uint32_t TypeManager::GetId(const Type* type) const {
   auto iter = type_to_id_.find(type);
   if (iter != type_to_id_.end()) return (*iter).second;
@@ -45,6 +65,257 @@ void TypeManager::AnalyzeTypes(const spvtools::ir::Module& module) {
   for (const auto& inst : module.annotations()) AttachIfTypeDecoration(inst);
 }
 
+void TypeManager::RemoveId(uint32_t id) {
+  auto iter = id_to_type_.find(id);
+  if (iter == id_to_type_.end()) return;
+
+  auto& type = iter->second;
+  if (!type->IsUniqueType(true)) {
+    // Search for an equivalent type to re-map.
+    bool found = false;
+    for (auto& pair : id_to_type_) {
+      if (pair.first != id && *pair.second == *type) {
+        // Equivalent ambiguous type, re-map type.
+        type_to_id_.erase(type.get());
+        type_to_id_[pair.second.get()] = pair.first;
+        found = true;
+        break;
+      }
+      // No equivalent ambiguous type, remove mapping.
+      if (!found) type_to_id_.erase(type.get());
+    }
+  } else {
+    // Unique type, so just erase the entry.
+    type_to_id_.erase(type.get());
+  }
+
+  // Erase the entry for |id|.
+  id_to_type_.erase(iter);
+}
+
+uint32_t TypeManager::GetTypeInstruction(const Type* type) {
+  uint32_t id = GetId(type);
+  if (id != 0) return id;
+
+  std::unique_ptr<ir::Instruction> typeInst;
+  id = context()->TakeNextId();
+  RegisterType(id, *type);
+  switch (type->kind()) {
+#define DefineParameterlessCase(kind)                                          \
+  case Type::k##kind:                                                          \
+    typeInst.reset(new ir::Instruction(context(), SpvOpType##kind, 0, id,      \
+                                       std::initializer_list<ir::Operand>{})); \
+    break;
+    DefineParameterlessCase(Void);
+    DefineParameterlessCase(Bool);
+    DefineParameterlessCase(Sampler);
+    DefineParameterlessCase(Event);
+    DefineParameterlessCase(DeviceEvent);
+    DefineParameterlessCase(ReserveId);
+    DefineParameterlessCase(Queue);
+    DefineParameterlessCase(PipeStorage);
+    DefineParameterlessCase(NamedBarrier);
+#undef DefineParameterlessCase
+    case Type::kInteger:
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypeInt, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsInteger()->width()}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+               {(type->AsInteger()->IsSigned() ? 1u : 0u)}}}));
+      break;
+    case Type::kFloat:
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypeFloat, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}}));
+      break;
+    case Type::kVector: {
+      uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type());
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeVector, 0, id,
+                              std::initializer_list<ir::Operand>{
+                                  {SPV_OPERAND_TYPE_ID, {subtype}},
+                                  {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                                   {type->AsVector()->element_count()}}}));
+      break;
+    }
+    case Type::kMatrix: {
+      uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type());
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeMatrix, 0, id,
+                              std::initializer_list<ir::Operand>{
+                                  {SPV_OPERAND_TYPE_ID, {subtype}},
+                                  {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+                                   {type->AsMatrix()->element_count()}}}));
+      break;
+    }
+    case Type::kImage: {
+      const Image* image = type->AsImage();
+      uint32_t subtype = GetTypeInstruction(image->sampled_type());
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypeImage, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_ID, {subtype}},
+              {SPV_OPERAND_TYPE_DIMENSIONALITY,
+               {static_cast<uint32_t>(image->dim())}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->depth()}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+               {(image->is_arrayed() ? 1u : 0u)}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+               {(image->is_multisampled() ? 1u : 0u)}},
+              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->sampled()}},
+              {SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT,
+               {static_cast<uint32_t>(image->format())}},
+              {SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
+               {static_cast<uint32_t>(image->access_qualifier())}}}));
+      break;
+    }
+    case Type::kSampledImage: {
+      uint32_t subtype =
+          GetTypeInstruction(type->AsSampledImage()->image_type());
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeSampledImage, 0, id,
+                              std::initializer_list<ir::Operand>{
+                                  {SPV_OPERAND_TYPE_ID, {subtype}}}));
+      break;
+    }
+    case Type::kArray: {
+      uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type());
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypeArray, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_ID, {subtype}},
+              {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}}));
+      break;
+    }
+    case Type::kRuntimeArray: {
+      uint32_t subtype =
+          GetTypeInstruction(type->AsRuntimeArray()->element_type());
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeRuntimeArray, 0, id,
+                              std::initializer_list<ir::Operand>{
+                                  {SPV_OPERAND_TYPE_ID, {subtype}}}));
+      break;
+    }
+    case Type::kStruct: {
+      std::vector<ir::Operand> ops;
+      const Struct* structTy = type->AsStruct();
+      for (auto ty : structTy->element_types()) {
+        ops.push_back(
+            ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+      }
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeStruct, 0, id, ops));
+      break;
+    }
+    case Type::kOpaque: {
+      const Opaque* opaque = type->AsOpaque();
+      size_t size = opaque->name().size();
+      // Convert to null-terminated packed UTF-8 string.
+      std::vector<uint32_t> words(size / 4 + 1, 0);
+      char* dst = reinterpret_cast<char*>(words.data());
+      strncpy(dst, opaque->name().c_str(), size);
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeOpaque, 0, id,
+                              std::initializer_list<ir::Operand>{
+                                  {SPV_OPERAND_TYPE_LITERAL_STRING, words}}));
+      break;
+    }
+    case Type::kPointer: {
+      const Pointer* pointer = type->AsPointer();
+      uint32_t subtype = GetTypeInstruction(pointer->pointee_type());
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypePointer, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_STORAGE_CLASS,
+               {static_cast<uint32_t>(pointer->storage_class())}},
+              {SPV_OPERAND_TYPE_ID, {subtype}}}));
+      break;
+    }
+    case Type::kFunction: {
+      std::vector<ir::Operand> ops;
+      const Function* function = type->AsFunction();
+      ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID,
+                                {GetTypeInstruction(function->return_type())}));
+      for (auto ty : function->param_types()) {
+        ops.push_back(
+            ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+      }
+      typeInst.reset(
+          new ir::Instruction(context(), SpvOpTypeFunction, 0, id, ops));
+      break;
+    }
+    case Type::kPipe:
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypePipe, 0, id,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
+               {static_cast<uint32_t>(type->AsPipe()->access_qualifier())}}}));
+      break;
+    case Type::kForwardPointer:
+      typeInst.reset(new ir::Instruction(
+          context(), SpvOpTypeForwardPointer, 0, 0,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_ID, {type->AsForwardPointer()->target_id()}},
+              {SPV_OPERAND_TYPE_STORAGE_CLASS,
+               {static_cast<uint32_t>(
+                   type->AsForwardPointer()->storage_class())}}}));
+      break;
+    default:
+      assert(false && "Unexpected type");
+      break;
+  }
+  context()->AddType(std::move(typeInst));
+  context()->get_def_use_mgr()->AnalyzeInstDefUse(
+      &*--context()->types_values_end());
+  AttachDecorations(id, type);
+
+  return id;
+}
+
+void TypeManager::AttachDecorations(uint32_t id, const Type* type) {
+  for (auto vec : type->decorations()) {
+    CreateDecoration(id, vec);
+  }
+  if (const Struct* structTy = type->AsStruct()) {
+    for (auto pair : structTy->element_decorations()) {
+      uint32_t element = pair.first;
+      for (auto vec : pair.second) {
+        CreateDecoration(id, vec, element);
+      }
+    }
+  }
+}
+
+void TypeManager::CreateDecoration(uint32_t target,
+                                   const std::vector<uint32_t>& decoration,
+                                   uint32_t element) {
+  std::vector<ir::Operand> ops;
+  ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID, {target}));
+  if (element != 0) {
+    ops.push_back(ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element}));
+  }
+  ops.push_back(ir::Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]}));
+  for (size_t i = 1; i < decoration.size(); ++i) {
+    ops.push_back(
+        ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]}));
+  }
+  context()->AddAnnotationInst(MakeUnique<ir::Instruction>(
+      context(), (element == 0 ? SpvOpDecorate : SpvOpMemberDecorate), 0, 0,
+      ops));
+  ir::Instruction* inst = &*--context()->annotation_end();
+  context()->get_def_use_mgr()->AnalyzeInstUse(inst);
+}
+
+void TypeManager::RegisterType(uint32_t id, const Type& type) {
+  auto& t = id_to_type_[id];
+  t.reset(type.Clone().release());
+  if (GetId(t.get()) == 0) {
+    type_to_id_[t.get()] = id;
+  }
+}
+
 Type* TypeManager::RecordIfTypeDefinition(
     const spvtools::ir::Instruction& inst) {
   if (!spvtools::ir::IsTypeInst(inst.opcode())) return nullptr;
@@ -80,8 +351,8 @@ Type* TypeManager::RecordIfTypeDefinition(
       type = new Image(
           GetType(inst.GetSingleWordInOperand(0)),
           static_cast<SpvDim>(inst.GetSingleWordInOperand(1)),
-          inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3),
-          inst.GetSingleWordInOperand(4), inst.GetSingleWordInOperand(5),
+          inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3) == 1,
+          inst.GetSingleWordInOperand(4) == 1, inst.GetSingleWordInOperand(5),
           static_cast<SpvImageFormat>(inst.GetSingleWordInOperand(6)), access);
     } break;
     case SpvOpTypeSampler:
index c8526aa..20519f1 100644 (file)
 #include "types.h"
 
 namespace spvtools {
+namespace ir {
+class IRContext;
+}  // namespace ir
 namespace opt {
 namespace analysis {
 
+// Hashing functor.
+//
+// All type pointers must be non-null.
+struct HashTypePointer {
+  size_t operator()(const Type* type) const {
+    assert(type);
+    return type->HashValue();
+  }
+};
+
+// Equality functor.
+//
+// Checks if two types pointers are the same type.
+//
+// All type pointers must be non-null.
+struct CompareTypePointers {
+  bool operator()(const Type* lhs, const Type* rhs) const {
+    assert(lhs && rhs);
+    return lhs->IsSame(rhs);
+  }
+};
+
 // A class for managing the SPIR-V type hierarchy.
 class TypeManager {
  public:
@@ -37,8 +62,7 @@ class TypeManager {
   // will be communicated to the outside via the given message |consumer|.
   // This instance only keeps a reference to the |consumer|, so the |consumer|
   // should outlive this instance.
-  TypeManager(const MessageConsumer& consumer,
-              const spvtools::ir::Module& module);
+  TypeManager(const MessageConsumer& consumer, spvtools::ir::IRContext* c);
 
   TypeManager(const TypeManager&) = delete;
   TypeManager(TypeManager&&) = delete;
@@ -62,16 +86,54 @@ class TypeManager {
   // Returns the number of forward pointer types hold in this manager.
   size_t NumForwardPointers() const { return forward_pointers_.size(); }
 
-  // Analyzes the types and decorations on types in the given |module|.
-  // TODO(dnovillo): This should be private and the type manager should know how
-  // to update itself when new types are added
-  // (https://github.com/KhronosGroup/SPIRV-Tools/issues/1071).
-  void AnalyzeTypes(const spvtools::ir::Module& module);
+  // Returns a pair of the type and pointer to the type in |sc|.
+  //
+  // |id| must be a registered type.
+  std::pair<Type*, std::unique_ptr<Pointer>> GetTypeAndPointerType(
+      uint32_t id, SpvStorageClass sc) const;
+
+  // Returns an id for a declaration representing |type|.
+  //
+  // If |type| is registered, then the registered id is returned. Otherwise,
+  // this function recursively adds type and annotation instructions as
+  // necessary to fully define |type|.
+  uint32_t GetTypeInstruction(const Type* type);
+
+  // Registers |id| to |type|.
+  //
+  // If GetId(|type|) already returns a non-zero id, the return value will be
+  // unchanged.
+  void RegisterType(uint32_t id, const Type& type);
+
+  // Removes knowledge of |id| from the manager.
+  //
+  // If |id| is an ambiguous type the multiple ids may be registered to |id|'s
+  // type (e.g. %struct1 and %struct1 might hash to the same type). In that
+  // case, calling GetId() with |id|'s type will return another suitable id
+  // defining that type.
+  void RemoveId(uint32_t id);
 
  private:
-  using TypeToIdMap = std::unordered_map<const Type*, uint32_t>;
+  using TypeToIdMap = std::unordered_map<const Type*, uint32_t, HashTypePointer,
+                                         CompareTypePointers>;
   using ForwardPointerVector = std::vector<std::unique_ptr<ForwardPointer>>;
 
+  // Analyzes the types and decorations on types in the given |module|.
+  void AnalyzeTypes(const spvtools::ir::Module& module);
+
+  spvtools::ir::IRContext* context() { return context_; }
+
+  // Attachs the decorations on |type| to |id|.
+  void AttachDecorations(uint32_t id, const Type* type);
+
+  // Create the annotation instruction.
+  //
+  // If |element| is zero, an OpDecorate is created, other an OpMemberDecorate
+  // is created. The annotation is registered with the DefUseManager and the
+  // DecorationManager.
+  void CreateDecoration(uint32_t id, const std::vector<uint32_t>& decoration,
+                        uint32_t element = 0);
+
   // Creates and returns a type from the given SPIR-V |inst|. Returns nullptr if
   // the given instruction is not for defining a type.
   Type* RecordIfTypeDefinition(const spvtools::ir::Instruction& inst);
@@ -80,6 +142,7 @@ class TypeManager {
   void AttachIfTypeDecoration(const spvtools::ir::Instruction& inst);
 
   const MessageConsumer& consumer_;  // Message consumer.
+  spvtools::ir::IRContext* context_;
   IdToTypeMap id_to_type_;  // Mapping from ids to their type representations.
   TypeToIdMap type_to_id_;  // Mapping from types to their defining ids.
   ForwardPointerVector forward_pointers_;  // All forward pointer declarations.
@@ -88,12 +151,6 @@ class TypeManager {
   std::unordered_set<ForwardPointer*> unresolved_forward_pointers_;
 };
 
-inline TypeManager::TypeManager(const spvtools::MessageConsumer& consumer,
-                                const spvtools::ir::Module& module)
-    : consumer_(consumer) {
-  AnalyzeTypes(module);
-}
-
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools
index 39bc907..62f8d7d 100644 (file)
@@ -14,6 +14,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <cstdint>
 #include <sstream>
 
 #include "types.h"
@@ -77,7 +78,154 @@ bool Type::HasSameDecorations(const Type* that) const {
   return CompareTwoVectors(decorations_, that->decorations_);
 }
 
-bool Integer::IsSame(Type* that) const {
+bool Type::IsUniqueType(bool allowVariablePointers) const {
+  switch (kind_) {
+    case kPointer:
+      return !allowVariablePointers;
+    case kStruct:
+    case kArray:
+    case kRuntimeArray:
+      return false;
+    default:
+      return true;
+  }
+}
+
+std::unique_ptr<Type> Type::Clone() const {
+  std::unique_ptr<Type> type;
+  switch (kind_) {
+#define DeclareKindCase(kind)                \
+  case k##kind:                              \
+    type.reset(new kind(*this->As##kind())); \
+    break;
+    DeclareKindCase(Void);
+    DeclareKindCase(Bool);
+    DeclareKindCase(Integer);
+    DeclareKindCase(Float);
+    DeclareKindCase(Vector);
+    DeclareKindCase(Matrix);
+    DeclareKindCase(Image);
+    DeclareKindCase(Sampler);
+    DeclareKindCase(SampledImage);
+    DeclareKindCase(Array);
+    DeclareKindCase(RuntimeArray);
+    DeclareKindCase(Struct);
+    DeclareKindCase(Opaque);
+    DeclareKindCase(Pointer);
+    DeclareKindCase(Function);
+    DeclareKindCase(Event);
+    DeclareKindCase(DeviceEvent);
+    DeclareKindCase(ReserveId);
+    DeclareKindCase(Queue);
+    DeclareKindCase(Pipe);
+    DeclareKindCase(ForwardPointer);
+    DeclareKindCase(PipeStorage);
+    DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+    default:
+      assert(false && "Unhandled type");
+  }
+  return type;
+}
+
+std::unique_ptr<Type> Type::RemoveDecorations() const {
+  std::unique_ptr<Type> type(Clone());
+  type->ClearDecorations();
+  return type;
+}
+
+bool Type::operator==(const Type& other) const {
+  if (kind_ != other.kind_) return false;
+
+  switch (kind_) {
+#define DeclareKindCase(kind) \
+  case k##kind:               \
+    return As##kind()->IsSame(&other);
+    DeclareKindCase(Void);
+    DeclareKindCase(Bool);
+    DeclareKindCase(Integer);
+    DeclareKindCase(Float);
+    DeclareKindCase(Vector);
+    DeclareKindCase(Matrix);
+    DeclareKindCase(Image);
+    DeclareKindCase(Sampler);
+    DeclareKindCase(SampledImage);
+    DeclareKindCase(Array);
+    DeclareKindCase(RuntimeArray);
+    DeclareKindCase(Struct);
+    DeclareKindCase(Opaque);
+    DeclareKindCase(Pointer);
+    DeclareKindCase(Function);
+    DeclareKindCase(Event);
+    DeclareKindCase(DeviceEvent);
+    DeclareKindCase(ReserveId);
+    DeclareKindCase(Queue);
+    DeclareKindCase(Pipe);
+    DeclareKindCase(ForwardPointer);
+    DeclareKindCase(PipeStorage);
+    DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+    default:
+      assert(false && "Unhandled type");
+      return false;
+  }
+}
+
+void Type::GetHashWords(std::vector<uint32_t>* words) const {
+  words->push_back(kind_);
+  for (auto d : decorations_) {
+    for (auto w : d) {
+      words->push_back(w);
+    }
+  }
+
+  switch (kind_) {
+#define DeclareKindCase(type)             \
+  case k##type:                           \
+    As##type()->GetExtraHashWords(words); \
+    break;
+    DeclareKindCase(Void);
+    DeclareKindCase(Bool);
+    DeclareKindCase(Integer);
+    DeclareKindCase(Float);
+    DeclareKindCase(Vector);
+    DeclareKindCase(Matrix);
+    DeclareKindCase(Image);
+    DeclareKindCase(Sampler);
+    DeclareKindCase(SampledImage);
+    DeclareKindCase(Array);
+    DeclareKindCase(RuntimeArray);
+    DeclareKindCase(Struct);
+    DeclareKindCase(Opaque);
+    DeclareKindCase(Pointer);
+    DeclareKindCase(Function);
+    DeclareKindCase(Event);
+    DeclareKindCase(DeviceEvent);
+    DeclareKindCase(ReserveId);
+    DeclareKindCase(Queue);
+    DeclareKindCase(Pipe);
+    DeclareKindCase(ForwardPointer);
+    DeclareKindCase(PipeStorage);
+    DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+    default:
+      assert(false && "Unhandled type");
+      break;
+  }
+}
+
+size_t Type::HashValue() const {
+  std::u32string h;
+  std::vector<uint32_t> words;
+  GetHashWords(&words);
+  for (auto w : words) {
+    h.push_back(w);
+  }
+
+  return std::hash<std::u32string>()(h);
+}
+
+bool Integer::IsSame(const Type* that) const {
   const Integer* it = that->AsInteger();
   return it && width_ == it->width_ && signed_ == it->signed_ &&
          HasSameDecorations(that);
@@ -89,7 +237,12 @@ std::string Integer::str() const {
   return oss.str();
 }
 
-bool Float::IsSame(Type* that) const {
+void Integer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  words->push_back(width_);
+  words->push_back(signed_);
+}
+
+bool Float::IsSame(const Type* that) const {
   const Float* ft = that->AsFloat();
   return ft && width_ == ft->width_ && HasSameDecorations(that);
 }
@@ -100,12 +253,16 @@ std::string Float::str() const {
   return oss.str();
 }
 
+void Float::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  words->push_back(width_);
+}
+
 Vector::Vector(Type* type, uint32_t count)
-    : element_type_(type), count_(count) {
+    : Type(kVector), element_type_(type), count_(count) {
   assert(type->AsBool() || type->AsInteger() || type->AsFloat());
 }
 
-bool Vector::IsSame(Type* that) const {
+bool Vector::IsSame(const Type* that) const {
   const Vector* vt = that->AsVector();
   if (!vt) return false;
   return count_ == vt->count_ && element_type_->IsSame(vt->element_type_) &&
@@ -118,12 +275,17 @@ std::string Vector::str() const {
   return oss.str();
 }
 
+void Vector::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  element_type_->GetHashWords(words);
+  words->push_back(count_);
+}
+
 Matrix::Matrix(Type* type, uint32_t count)
-    : element_type_(type), count_(count) {
+    : Type(kMatrix), element_type_(type), count_(count) {
   assert(type->AsVector());
 }
 
-bool Matrix::IsSame(Type* that) const {
+bool Matrix::IsSame(const Type* that) const {
   const Matrix* mt = that->AsMatrix();
   if (!mt) return false;
   return count_ == mt->count_ && element_type_->IsSame(mt->element_type_) &&
@@ -136,21 +298,26 @@ std::string Matrix::str() const {
   return oss.str();
 }
 
-Image::Image(Type* sampled_type, SpvDim dim, uint32_t depth, uint32_t arrayed,
-             uint32_t ms, uint32_t sampled, SpvImageFormat format,
-             SpvAccessQualifier access_qualifier)
-    : sampled_type_(sampled_type),
-      dim_(dim),
-      depth_(depth),
-      arrayed_(arrayed),
-      ms_(ms),
-      sampled_(sampled),
-      format_(format),
-      access_qualifier_(access_qualifier) {
+void Matrix::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  element_type_->GetHashWords(words);
+  words->push_back(count_);
+}
+
+Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
+             uint32_t sampling, SpvImageFormat f, SpvAccessQualifier qualifier)
+    : Type(kImage),
+      sampled_type_(type),
+      dim_(dimen),
+      depth_(d),
+      arrayed_(array),
+      ms_(multisample),
+      sampled_(sampling),
+      format_(f),
+      access_qualifier_(qualifier) {
   // TODO(antiagainst): check sampled_type
 }
 
-bool Image::IsSame(Type* that) const {
+bool Image::IsSame(const Type* that) const {
   const Image* it = that->AsImage();
   if (!it) return false;
   return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ &&
@@ -167,7 +334,18 @@ std::string Image::str() const {
   return oss.str();
 }
 
-bool SampledImage::IsSame(Type* that) const {
+void Image::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  sampled_type_->GetHashWords(words);
+  words->push_back(dim_);
+  words->push_back(depth_);
+  words->push_back(arrayed_);
+  words->push_back(ms_);
+  words->push_back(sampled_);
+  words->push_back(format_);
+  words->push_back(access_qualifier_);
+}
+
+bool SampledImage::IsSame(const Type* that) const {
   const SampledImage* sit = that->AsSampledImage();
   if (!sit) return false;
   return image_type_->IsSame(sit->image_type_) && HasSameDecorations(that);
@@ -179,12 +357,16 @@ std::string SampledImage::str() const {
   return oss.str();
 }
 
+void SampledImage::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  image_type_->GetHashWords(words);
+}
+
 Array::Array(Type* type, uint32_t length_id)
-    : element_type_(type), length_id_(length_id) {
+    : Type(kArray), element_type_(type), length_id_(length_id) {
   assert(!type->AsVoid());
 }
 
-bool Array::IsSame(Type* that) const {
+bool Array::IsSame(const Type* that) const {
   const Array* at = that->AsArray();
   if (!at) return false;
   return length_id_ == at->length_id_ &&
@@ -197,11 +379,17 @@ std::string Array::str() const {
   return oss.str();
 }
 
-RuntimeArray::RuntimeArray(Type* type) : element_type_(type) {
+void Array::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  element_type_->GetHashWords(words);
+  words->push_back(length_id_);
+}
+
+RuntimeArray::RuntimeArray(Type* type)
+    : Type(kRuntimeArray), element_type_(type) {
   assert(!type->AsVoid());
 }
 
-bool RuntimeArray::IsSame(Type* that) const {
+bool RuntimeArray::IsSame(const Type* that) const {
   const RuntimeArray* rat = that->AsRuntimeArray();
   if (!rat) return false;
   return element_type_->IsSame(rat->element_type_) && HasSameDecorations(that);
@@ -213,7 +401,12 @@ std::string RuntimeArray::str() const {
   return oss.str();
 }
 
-Struct::Struct(const std::vector<Type*>& types) : element_types_(types) {
+void RuntimeArray::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  element_type_->GetHashWords(words);
+}
+
+Struct::Struct(const std::vector<Type*>& types)
+    : Type(kStruct), element_types_(types) {
   for (auto* t : types) {
     (void)t;
     assert(!t->AsVoid());
@@ -230,7 +423,7 @@ void Struct::AddMemberDecoration(uint32_t index,
   element_decorations_[index].push_back(std::move(decoration));
 }
 
-bool Struct::IsSame(Type* that) const {
+bool Struct::IsSame(const Type* that) const {
   const Struct* st = that->AsStruct();
   if (!st) return false;
   if (element_types_.size() != st->element_types_.size()) return false;
@@ -261,7 +454,21 @@ std::string Struct::str() const {
   return oss.str();
 }
 
-bool Opaque::IsSame(Type* that) const {
+void Struct::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  for (auto t : element_types_) {
+    t->GetHashWords(words);
+  }
+  for (auto pair : element_decorations_) {
+    words->push_back(pair.first);
+    for (auto d : pair.second) {
+      for (auto w : d) {
+        words->push_back(w);
+      }
+    }
+  }
+}
+
+bool Opaque::IsSame(const Type* that) const {
   const Opaque* ot = that->AsOpaque();
   if (!ot) return false;
   return name_ == ot->name_ && HasSameDecorations(that);
@@ -273,12 +480,18 @@ std::string Opaque::str() const {
   return oss.str();
 }
 
-Pointer::Pointer(Type* type, SpvStorageClass storage_class)
-    : pointee_type_(type), storage_class_(storage_class) {
+void Opaque::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  for (auto c : name_) {
+    words->push_back(static_cast<char32_t>(c));
+  }
+}
+
+Pointer::Pointer(Type* type, SpvStorageClass sc)
+    : Type(kPointer), pointee_type_(type), storage_class_(sc) {
   assert(!type->AsVoid());
 }
 
-bool Pointer::IsSame(Type* that) const {
+bool Pointer::IsSame(const Type* that) const {
   const Pointer* pt = that->AsPointer();
   if (!pt) return false;
   if (storage_class_ != pt->storage_class_) return false;
@@ -288,15 +501,20 @@ bool Pointer::IsSame(Type* that) const {
 
 std::string Pointer::str() const { return pointee_type_->str() + "*"; }
 
-Function::Function(Type* return_type, const std::vector<Type*>& param_types)
-    : return_type_(return_type), param_types_(param_types) {
-  for (auto* t : param_types) {
+void Pointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  pointee_type_->GetHashWords(words);
+  words->push_back(storage_class_);
+}
+
+Function::Function(Type* ret_type, const std::vector<Type*>& params)
+    : Type(kFunction), return_type_(ret_type), param_types_(params) {
+  for (auto* t : params) {
     (void)t;
     assert(!t->AsVoid());
   }
 }
 
-bool Function::IsSame(Type* that) const {
+bool Function::IsSame(const Type* that) const {
   const Function* ft = that->AsFunction();
   if (!ft) return false;
   if (!return_type_->IsSame(ft->return_type_)) return false;
@@ -319,7 +537,14 @@ std::string Function::str() const {
   return oss.str();
 }
 
-bool Pipe::IsSame(Type* that) const {
+void Function::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  return_type_->GetHashWords(words);
+  for (auto t : param_types_) {
+    t->GetHashWords(words);
+  }
+}
+
+bool Pipe::IsSame(const Type* that) const {
   const Pipe* pt = that->AsPipe();
   if (!pt) return false;
   return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that);
@@ -331,7 +556,11 @@ std::string Pipe::str() const {
   return oss.str();
 }
 
-bool ForwardPointer::IsSame(Type* that) const {
+void Pipe::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  words->push_back(access_qualifier_);
+}
+
+bool ForwardPointer::IsSame(const Type* that) const {
   const ForwardPointer* fpt = that->AsForwardPointer();
   if (!fpt) return false;
   return target_id_ == fpt->target_id_ &&
@@ -350,6 +579,12 @@ std::string ForwardPointer::str() const {
   return oss.str();
 }
 
+void ForwardPointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+  words->push_back(target_id_);
+  words->push_back(storage_class_);
+  if (pointer_) pointer_->GetHashWords(words);
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools
index f1c8e3e..4b1085d 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef LIBSPIRV_OPT_TYPES_H_
 #define LIBSPIRV_OPT_TYPES_H_
 
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -56,6 +57,37 @@ class NamedBarrier;
 // which is used as a way to probe the actual <subclass>.
 class Type {
  public:
+  // Available subtypes.
+  //
+  // When adding a new derived class of Type, please add an entry to the enum.
+  enum Kind {
+    kVoid,
+    kBool,
+    kInteger,
+    kFloat,
+    kVector,
+    kMatrix,
+    kImage,
+    kSampler,
+    kSampledImage,
+    kArray,
+    kRuntimeArray,
+    kStruct,
+    kOpaque,
+    kPointer,
+    kFunction,
+    kEvent,
+    kDeviceEvent,
+    kReserveId,
+    kQueue,
+    kPipe,
+    kForwardPointer,
+    kPipeStorage,
+    kNamedBarrier,
+  };
+
+  Type(Kind k) : kind_(k) {}
+
   virtual ~Type() {}
 
   // Attaches a decoration directly on this type.
@@ -68,15 +100,33 @@ class Type {
   bool HasSameDecorations(const Type* that) const;
   // Returns true if this type is exactly the same as |that| type, including
   // decorations.
-  virtual bool IsSame(Type* that) const = 0;
+  virtual bool IsSame(const Type* that) const = 0;
   // Returns a human-readable string to represent this type.
   virtual std::string str() const = 0;
 
+  Kind kind() const { return kind_; }
+  const std::vector<std::vector<uint32_t>>& decorations() const {
+    return decorations_;
+  }
+
   // Returns true if there is no decoration on this type. For struct types,
   // returns true only when there is no decoration for both the struct type
   // and the struct members.
   virtual bool decoration_empty() const { return decorations_.empty(); }
 
+  // Creates a clone of |this|.
+  std::unique_ptr<Type> Clone() const;
+
+  // Returns a clone of |this| minus any decorations.
+  std::unique_ptr<Type> RemoveDecorations() const;
+
+  // Returns true if this type must be unique.
+  //
+  // If variable pointers are allowed, then pointers are not required to be
+  // unique.
+  // TODO(alanbaker): Update this if variable pointers become a core feature.
+  bool IsUniqueType(bool allowVariablePointers = false) const;
+
 // A bunch of methods for casting this type to a given type. Returns this if the
 // cast can be done, nullptr otherwise.
 #define DeclareCastMethod(target)                  \
@@ -107,19 +157,39 @@ class Type {
   DeclareCastMethod(NamedBarrier);
 #undef DeclareCastMethod
 
+  bool operator==(const Type& other) const;
+
+  // Returns the hash value of this type.
+  size_t HashValue() const;
+
+  // Adds the necessary words to compute a hash value of this type to |words|.
+  void GetHashWords(std::vector<uint32_t>* words) const;
+
+  // Adds necessary extra words for a subtype to calculate a hash value into
+  // |words|.
+  virtual void GetExtraHashWords(std::vector<uint32_t>* words) const = 0;
+
  protected:
   // Decorations attached to this type. Each decoration is encoded as a vector
   // of uint32_t numbers. The first uint32_t number is the decoration value,
   // and the rest are the parameters to the decoration (if exists).
   std::vector<std::vector<uint32_t>> decorations_;
+
+ private:
+  // Removes decorations on this type. For struct types, also removes element
+  // decorations.
+  virtual void ClearDecorations() { decorations_.clear(); }
+
+  Kind kind_;
 };
 
 class Integer : public Type {
  public:
-  Integer(uint32_t w, bool is_signed) : width_(w), signed_(is_signed) {}
+  Integer(uint32_t w, bool is_signed)
+      : Type(kInteger), width_(w), signed_(is_signed) {}
   Integer(const Integer&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Integer* AsInteger() override { return this; }
@@ -127,6 +197,8 @@ class Integer : public Type {
   uint32_t width() const { return width_; }
   bool IsSigned() const { return signed_; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   uint32_t width_;  // bit width
   bool signed_;     // true if this integer is signed
@@ -134,16 +206,18 @@ class Integer : public Type {
 
 class Float : public Type {
  public:
-  Float(uint32_t w) : width_(w) {}
+  Float(uint32_t w) : Type(kFloat), width_(w) {}
   Float(const Float&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Float* AsFloat() override { return this; }
   const Float* AsFloat() const override { return this; }
   uint32_t width() const { return width_; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   uint32_t width_;  // bit width
 };
@@ -153,7 +227,7 @@ class Vector : public Type {
   Vector(Type* element_type, uint32_t count);
   Vector(const Vector&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const Type* element_type() const { return element_type_; }
   uint32_t element_count() const { return count_; }
@@ -161,6 +235,8 @@ class Vector : public Type {
   Vector* AsVector() override { return this; }
   const Vector* AsVector() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* element_type_;
   uint32_t count_;
@@ -171,7 +247,7 @@ class Matrix : public Type {
   Matrix(Type* element_type, uint32_t count);
   Matrix(const Matrix&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const Type* element_type() const { return element_type_; }
   uint32_t element_count() const { return count_; }
@@ -179,6 +255,8 @@ class Matrix : public Type {
   Matrix* AsMatrix() override { return this; }
   const Matrix* AsMatrix() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* element_type_;
   uint32_t count_;
@@ -186,23 +264,34 @@ class Matrix : public Type {
 
 class Image : public Type {
  public:
-  Image(Type* sampled_type, SpvDim dim, uint32_t depth, uint32_t arrayed,
-        uint32_t ms, uint32_t sampled, SpvImageFormat format,
-        SpvAccessQualifier access_qualifier = SpvAccessQualifierReadOnly);
+  Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
+        uint32_t sampling, SpvImageFormat f,
+        SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly);
   Image(const Image&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Image* AsImage() override { return this; }
   const Image* AsImage() const override { return this; }
 
+  const Type* sampled_type() const { return sampled_type_; }
+  SpvDim dim() const { return dim_; }
+  uint32_t depth() const { return depth_; }
+  bool is_arrayed() const { return arrayed_; }
+  bool is_multisampled() const { return ms_; }
+  uint32_t sampled() const { return sampled_; }
+  SpvImageFormat format() const { return format_; }
+  SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
+
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* sampled_type_;
   SpvDim dim_;
   uint32_t depth_;
-  uint32_t arrayed_;
-  uint32_t ms_;
+  bool arrayed_;
+  bool ms_;
   uint32_t sampled_;
   SpvImageFormat format_;
   SpvAccessQualifier access_qualifier_;
@@ -210,15 +299,19 @@ class Image : public Type {
 
 class SampledImage : public Type {
  public:
-  SampledImage(Type* image_type) : image_type_(image_type) {}
+  SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {}
   SampledImage(const SampledImage&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   SampledImage* AsSampledImage() override { return this; }
   const SampledImage* AsSampledImage() const override { return this; }
 
+  const Type* image_type() const { return image_type_; }
+
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* image_type_;
 };
@@ -228,7 +321,7 @@ class Array : public Type {
   Array(Type* element_type, uint32_t length_id);
   Array(const Array&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const Type* element_type() const { return element_type_; }
   uint32_t LengthId() const { return length_id_; }
@@ -236,6 +329,8 @@ class Array : public Type {
   Array* AsArray() override { return this; }
   const Array* AsArray() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* element_type_;
   uint32_t length_id_;
@@ -246,13 +341,15 @@ class RuntimeArray : public Type {
   RuntimeArray(Type* element_type);
   RuntimeArray(const RuntimeArray&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const Type* element_type() const { return element_type_; }
 
   RuntimeArray* AsRuntimeArray() override { return this; }
   const RuntimeArray* AsRuntimeArray() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* element_type_;
 };
@@ -266,17 +363,28 @@ class Struct : public Type {
   // decoration enum, and the remaining words, if any, are its operands.
   void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const std::vector<Type*>& element_types() const { return element_types_; }
   bool decoration_empty() const override {
     return decorations_.empty() && element_decorations_.empty();
   }
+  const std::unordered_map<uint32_t, std::vector<std::vector<uint32_t>>>&
+  element_decorations() const {
+    return element_decorations_;
+  }
 
   Struct* AsStruct() override { return this; }
   const Struct* AsStruct() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
+  void ClearDecorations() override {
+    decorations_.clear();
+    element_decorations_.clear();
+  }
+
   std::vector<Type*> element_types_;
   // We can attach decorations to struct members and that should not affect the
   // underlying element type. So we need an extra data structure here to keep
@@ -287,31 +395,38 @@ class Struct : public Type {
 
 class Opaque : public Type {
  public:
-  Opaque(std::string name) : name_(std::move(name)) {}
+  Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {}
   Opaque(const Opaque&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Opaque* AsOpaque() override { return this; }
   const Opaque* AsOpaque() const override { return this; }
 
+  const std::string& name() const { return name_; }
+
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   std::string name_;
 };
 
 class Pointer : public Type {
  public:
-  Pointer(Type* pointee_type, SpvStorageClass storage_class);
+  Pointer(Type* pointee, SpvStorageClass sc);
   Pointer(const Pointer&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
   const Type* pointee_type() const { return pointee_type_; }
+  SpvStorageClass storage_class() const { return storage_class_; }
 
   Pointer* AsPointer() override { return this; }
   const Pointer* AsPointer() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* pointee_type_;
   SpvStorageClass storage_class_;
@@ -319,15 +434,20 @@ class Pointer : public Type {
 
 class Function : public Type {
  public:
-  Function(Type* return_type, const std::vector<Type*>& param_types);
+  Function(Type* ret_type, const std::vector<Type*>& params);
   Function(const Function&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Function* AsFunction() override { return this; }
   const Function* AsFunction() const override { return this; }
 
+  const Type* return_type() const { return return_type_; }
+  const std::vector<Type*>& param_types() const { return param_types_; }
+
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   Type* return_type_;
   std::vector<Type*> param_types_;
@@ -335,55 +455,67 @@ class Function : public Type {
 
 class Pipe : public Type {
  public:
-  Pipe(SpvAccessQualifier access_qualifier)
-      : access_qualifier_(access_qualifier) {}
+  Pipe(SpvAccessQualifier qualifier)
+      : Type(kPipe), access_qualifier_(qualifier) {}
   Pipe(const Pipe&) = default;
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   Pipe* AsPipe() override { return this; }
   const Pipe* AsPipe() const override { return this; }
 
+  SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
+
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   SpvAccessQualifier access_qualifier_;
 };
 
 class ForwardPointer : public Type {
  public:
-  ForwardPointer(uint32_t id, SpvStorageClass storage_class)
-      : target_id_(id), storage_class_(storage_class), pointer_(nullptr) {}
+  ForwardPointer(uint32_t id, SpvStorageClass sc)
+      : Type(kForwardPointer),
+        target_id_(id),
+        storage_class_(sc),
+        pointer_(nullptr) {}
   ForwardPointer(const ForwardPointer&) = default;
 
   uint32_t target_id() const { return target_id_; }
   void SetTargetPointer(Pointer* pointer) { pointer_ = pointer; }
+  SpvStorageClass storage_class() const { return storage_class_; }
 
-  bool IsSame(Type* that) const override;
+  bool IsSame(const Type* that) const override;
   std::string str() const override;
 
   ForwardPointer* AsForwardPointer() override { return this; }
   const ForwardPointer* AsForwardPointer() const override { return this; }
 
+  void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
  private:
   uint32_t target_id_;
   SpvStorageClass storage_class_;
   Pointer* pointer_;
 };
 
-#define DefineParameterlessType(type, name)                \
-  class type : public Type {                               \
-   public:                                                 \
-    type() = default;                                      \
-    type(const type&) = default;                           \
-                                                           \
-    bool IsSame(Type* that) const override {               \
-      return that->As##type() && HasSameDecorations(that); \
-    }                                                      \
-    std::string str() const override { return #name; }     \
-                                                           \
-    type* As##type() override { return this; }             \
-    const type* As##type() const override { return this; } \
-  };
+#define DefineParameterlessType(type, name)                          \
+  class type : public Type {                                         \
+   public:                                                           \
+    type() : Type(k##type) {}                                        \
+    type(const type&) = default;                                     \
+                                                                     \
+    bool IsSame(const Type* that) const override {                   \
+      return that->As##type() && HasSameDecorations(that);           \
+    }                                                                \
+    std::string str() const override { return #name; }               \
+                                                                     \
+    type* As##type() override { return this; }                       \
+    const type* As##type() const override { return this; }           \
+                                                                     \
+    void GetExtraHashWords(std::vector<uint32_t>*) const override {} \
+  };  // namespace analysis
 DefineParameterlessType(Void, void);
 DefineParameterlessType(Bool, bool);
 DefineParameterlessType(Sampler, sampler);
index 1c8f2db..544cc47 100644 (file)
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
 #include "opt/build_module.h"
 #include "opt/instruction.h"
 #include "opt/type_manager.h"
+#include "spirv-tools/libspirv.hpp"
 
 namespace {
 
 using namespace spvtools;
+using namespace spvtools::opt::analysis;
+
+bool Validate(const std::vector<uint32_t>& bin) {
+  spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+  spv_context spvContext = spvContextCreate(target_env);
+  spv_diagnostic diagnostic = nullptr;
+  spv_const_binary_t binary = {bin.data(), bin.size()};
+  spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
+  if (error != 0) spvDiagnosticPrint(diagnostic);
+  spvDiagnosticDestroy(diagnostic);
+  spvContextDestroy(spvContext);
+  return error == 0;
+}
+
+#ifdef SPIRV_EFFCEE
+void Match(const std::string& original, ir::IRContext* context,
+           bool do_validation = true) {
+  std::vector<uint32_t> bin;
+  context->module()->ToBinary(&bin, true);
+  if (do_validation) {
+    EXPECT_TRUE(Validate(bin));
+  }
+  std::string assembly;
+  SpirvTools tools(SPV_ENV_UNIVERSAL_1_2);
+  EXPECT_TRUE(
+      tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption))
+      << "Disassembling failed for shader:\n"
+      << assembly << std::endl;
+  auto match_result = effcee::Match(assembly, original);
+  EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+      << match_result.message() << "\nChecking result:\n"
+      << assembly;
+}
+#endif
+
+std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
+  // Types in this test case are only equal to themselves, nothing else.
+  std::vector<std::unique_ptr<Type>> types;
+
+  // Void, Bool
+  types.emplace_back(new Void());
+  auto* voidt = types.back().get();
+  types.emplace_back(new Bool());
+  auto* boolt = types.back().get();
+
+  // Integer
+  types.emplace_back(new Integer(32, true));
+  auto* s32 = types.back().get();
+  types.emplace_back(new Integer(32, false));
+  types.emplace_back(new Integer(64, true));
+  types.emplace_back(new Integer(64, false));
+  auto* u64 = types.back().get();
+
+  // Float
+  types.emplace_back(new Float(32));
+  auto* f32 = types.back().get();
+  types.emplace_back(new Float(64));
+
+  // Vector
+  types.emplace_back(new Vector(s32, 2));
+  types.emplace_back(new Vector(s32, 3));
+  auto* v3s32 = types.back().get();
+  types.emplace_back(new Vector(u64, 4));
+  types.emplace_back(new Vector(f32, 3));
+  auto* v3f32 = types.back().get();
+
+  // Matrix
+  types.emplace_back(new Matrix(v3s32, 3));
+  types.emplace_back(new Matrix(v3s32, 4));
+  types.emplace_back(new Matrix(v3f32, 4));
+
+  // Images
+  types.emplace_back(new Image(s32, SpvDim2D, 0, 0, 0, 0, SpvImageFormatRg8,
+                               SpvAccessQualifierReadOnly));
+  auto* image1 = types.back().get();
+  types.emplace_back(new Image(s32, SpvDim2D, 0, 1, 0, 0, SpvImageFormatRg8,
+                               SpvAccessQualifierReadOnly));
+  types.emplace_back(new Image(s32, SpvDim3D, 0, 1, 0, 0, SpvImageFormatRg8,
+                               SpvAccessQualifierReadOnly));
+  types.emplace_back(new Image(voidt, SpvDim3D, 0, 1, 0, 1, SpvImageFormatRg8,
+                               SpvAccessQualifierReadWrite));
+  auto* image2 = types.back().get();
+
+  // Sampler
+  types.emplace_back(new Sampler());
+
+  // Sampled Image
+  types.emplace_back(new SampledImage(image1));
+  types.emplace_back(new SampledImage(image2));
+
+  // Array
+  types.emplace_back(new Array(f32, 100));
+  types.emplace_back(new Array(f32, 42));
+  auto* a42f32 = types.back().get();
+  types.emplace_back(new Array(u64, 24));
+
+  // RuntimeArray
+  types.emplace_back(new RuntimeArray(v3f32));
+  types.emplace_back(new RuntimeArray(v3s32));
+  auto* rav3s32 = types.back().get();
+
+  // Struct
+  types.emplace_back(new Struct(std::vector<Type*>{s32}));
+  types.emplace_back(new Struct(std::vector<Type*>{s32, f32}));
+  auto* sts32f32 = types.back().get();
+  types.emplace_back(new Struct(std::vector<Type*>{u64, a42f32, rav3s32}));
+
+  // Opaque
+  types.emplace_back(new Opaque(""));
+  types.emplace_back(new Opaque("hello"));
+  types.emplace_back(new Opaque("world"));
+
+  // Pointer
+  types.emplace_back(new Pointer(f32, SpvStorageClassInput));
+  types.emplace_back(new Pointer(sts32f32, SpvStorageClassFunction));
+  types.emplace_back(new Pointer(a42f32, SpvStorageClassFunction));
+
+  // Function
+  types.emplace_back(new Function(voidt, {}));
+  types.emplace_back(new Function(voidt, {boolt}));
+  types.emplace_back(new Function(voidt, {boolt, s32}));
+  types.emplace_back(new Function(s32, {boolt, s32}));
+
+  // Event, Device Event, Reserve Id, Queue,
+  types.emplace_back(new Event());
+  types.emplace_back(new DeviceEvent());
+  types.emplace_back(new ReserveId());
+  types.emplace_back(new Queue());
+
+  // Pipe, Forward Pointer, PipeStorage, NamedBarrier
+  types.emplace_back(new Pipe(SpvAccessQualifierReadWrite));
+  types.emplace_back(new Pipe(SpvAccessQualifierReadOnly));
+  types.emplace_back(new ForwardPointer(1, SpvStorageClassInput));
+  types.emplace_back(new ForwardPointer(2, SpvStorageClassInput));
+  types.emplace_back(new ForwardPointer(2, SpvStorageClassUniform));
+  types.emplace_back(new PipeStorage());
+  types.emplace_back(new NamedBarrier());
+
+  return types;
+}
 
 TEST(TypeManager, TypeStrings) {
   const std::string text = R"(
@@ -90,7 +235,7 @@ TEST(TypeManager, TypeStrings) {
 
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
 
   EXPECT_EQ(type_id_strs.size(), manager.NumTypes());
   EXPECT_EQ(2u, manager.NumForwardPointers());
@@ -120,7 +265,7 @@ TEST(TypeManager, DecorationOnStruct) {
   )";
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
 
   ASSERT_EQ(7u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -170,7 +315,7 @@ TEST(TypeManager, DecorationOnMember) {
   )";
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
 
   ASSERT_EQ(10u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -208,7 +353,7 @@ TEST(TypeManager, DecorationEmpty) {
   )";
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
 
   ASSERT_EQ(5u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -230,7 +375,7 @@ TEST(TypeManager, BeginEndForEmptyModule) {
   const std::string text = "";
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
   ASSERT_EQ(0u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
 
@@ -247,7 +392,7 @@ TEST(TypeManager, BeginEnd) {
   )";
   std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *context->module());
+  opt::analysis::TypeManager manager(nullptr, context.get());
   ASSERT_EQ(5u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
 
@@ -274,4 +419,295 @@ TEST(TypeManager, BeginEnd) {
   }
 }
 
+TEST(TypeManager, LookupType) {
+  const std::string text = R"(
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%int  = OpTypeInt 32 1
+%vec2 = OpTypeVector %int 2
+)";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  EXPECT_NE(context, nullptr);
+  TypeManager manager(nullptr, context.get());
+
+  Void voidTy;
+  EXPECT_EQ(manager.GetId(&voidTy), 1u);
+
+  Integer uintTy(32, false);
+  EXPECT_EQ(manager.GetId(&uintTy), 2u);
+
+  Integer intTy(32, true);
+  EXPECT_EQ(manager.GetId(&intTy), 3u);
+
+  Integer intTy2(32, true);
+  Vector vecTy(&intTy2, 2u);
+  EXPECT_EQ(manager.GetId(&vecTy), 4u);
+}
+
+TEST(TypeManager, RemoveId) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 32 1
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  context->get_type_mgr()->RemoveId(1u);
+  ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr);
+  ASSERT_NE(context->get_type_mgr()->GetType(2u), nullptr);
+
+  context->get_type_mgr()->RemoveId(2u);
+  ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr);
+  ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr);
+}
+
+TEST(TypeManager, RemoveIdNonDuplicateAmbiguousType) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  Integer u32(32, false);
+  Struct st({&u32});
+  ASSERT_EQ(context->get_type_mgr()->GetId(&st), 2u);
+  context->get_type_mgr()->RemoveId(2u);
+  ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr);
+  ASSERT_EQ(context->get_type_mgr()->GetId(&st), 0u);
+}
+
+TEST(TypeManager, RemoveIdDuplicateAmbiguousType) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+%3 = OpTypeStruct %1
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  Integer u32(32, false);
+  Struct st({&u32});
+  uint32_t id = context->get_type_mgr()->GetId(&st);
+  ASSERT_NE(id, 0u);
+  uint32_t toRemove = id == 2u ? 2u : 3u;
+  uint32_t toStay = id == 2u ? 3u : 2u;
+  context->get_type_mgr()->RemoveId(toRemove);
+  ASSERT_EQ(context->get_type_mgr()->GetType(toRemove), nullptr);
+  ASSERT_EQ(context->get_type_mgr()->GetId(&st), toStay);
+}
+
+TEST(TypeManager, GetTypeAndPointerType) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  Integer u32(32, false);
+  Pointer u32Ptr(&u32, SpvStorageClassFunction);
+  Struct st({&u32});
+  Pointer stPtr(&st, SpvStorageClassInput);
+
+  auto pair = context->get_type_mgr()->GetTypeAndPointerType(
+      3u, SpvStorageClassFunction);
+  ASSERT_EQ(nullptr, pair.first);
+  ASSERT_EQ(nullptr, pair.second);
+
+  pair = context->get_type_mgr()->GetTypeAndPointerType(
+      1u, SpvStorageClassFunction);
+  ASSERT_TRUE(pair.first->IsSame(&u32));
+  ASSERT_TRUE(pair.second->IsSame(&u32Ptr));
+
+  pair =
+      context->get_type_mgr()->GetTypeAndPointerType(2u, SpvStorageClassInput);
+  ASSERT_TRUE(pair.first->IsSame(&st));
+  ASSERT_TRUE(pair.second->IsSame(&stPtr));
+}
+
+#ifdef SPIRV_EFFCEE
+TEST(TypeManager, GetTypeInstructionInt) {
+  const std::string text = R"(
+; CHECK: OpTypeInt 32 0
+; CHECK: OpTypeInt 16 1
+OpCapability Shader
+OpCapability Int16
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  EXPECT_NE(context, nullptr);
+
+  Integer uint_32(32, false);
+  context->get_type_mgr()->GetTypeInstruction(&uint_32);
+
+  Integer int_16(16, true);
+  context->get_type_mgr()->GetTypeInstruction(&int_16);
+
+  Match(text, context.get());
+}
+
+TEST(TypeManager, GetTypeInstructionDuplicateInts) {
+  const std::string text = R"(
+; CHECK: OpTypeInt 32 0
+; CHECK-NOT: OpType
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  EXPECT_NE(context, nullptr);
+
+  Integer uint_32(32, false);
+  uint32_t id = context->get_type_mgr()->GetTypeInstruction(&uint_32);
+
+  Integer other(32, false);
+  EXPECT_EQ(context->get_type_mgr()->GetTypeInstruction(&other), id);
+
+  Match(text, context.get());
+}
+
+TEST(TypeManager, GetTypeInstructionAllTypes) {
+  const std::string text = R"(
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
+; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
+; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
+; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
+; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK: [[bool:%\w+]] = OpTypeBool
+; CHECK: [[s32:%\w+]] = OpTypeInt 32 1
+; CHECK: OpTypeInt 64 1
+; CHECK: [[u64:%\w+]] = OpTypeInt 64 0
+; CHECK: [[f32:%\w+]] = OpTypeFloat 32
+; CHECK: OpTypeFloat 64
+; CHECK: OpTypeVector [[s32]] 2
+; CHECK: [[v3s32:%\w+]] = OpTypeVector [[s32]] 3
+; CHECK: OpTypeVector [[u64]] 4
+; CHECK: [[v3f32:%\w+]] = OpTypeVector [[f32]] 3
+; CHECK: OpTypeMatrix [[v3s32]] 3
+; CHECK: OpTypeMatrix [[v3s32]] 4
+; CHECK: OpTypeMatrix [[v3f32]] 4
+; CHECK: [[image1:%\w+]] = OpTypeImage [[s32]] 2D 0 0 0 0 Rg8 ReadOnly
+; CHECK: OpTypeImage [[s32]] 2D 0 1 0 0 Rg8 ReadOnly
+; CHECK: OpTypeImage [[s32]] 3D 0 1 0 0 Rg8 ReadOnly
+; CHECK: [[image2:%\w+]] = OpTypeImage [[void]] 3D 0 1 0 1 Rg8 ReadWrite
+; CHECK: OpTypeSampler
+; CHECK: OpTypeSampledImage [[image1]]
+; CHECK: OpTypeSampledImage [[image2]]
+; CHECK: OpTypeArray [[f32]] [[uint100]]
+; CHECK: [[a42f32:%\w+]] = OpTypeArray [[f32]] [[uint42]]
+; CHECK: OpTypeArray [[u64]] [[uint24]]
+; CHECK: OpTypeRuntimeArray [[v3f32]]
+; CHECK: [[rav3s32:%\w+]] = OpTypeRuntimeArray [[v3s32]]
+; CHECK: OpTypeStruct [[s32]]
+; CHECK: [[sts32f32:%\w+]] = OpTypeStruct [[s32]] [[f32]]
+; CHECK: OpTypeStruct [[u64]] [[a42f32]] [[rav3s32]]
+; CHECK: OpTypeOpaque ""
+; CHECK: OpTypeOpaque "hello"
+; CHECK: OpTypeOpaque "world"
+; CHECK: OpTypePointer Input [[f32]]
+; CHECK: OpTypePointer Function [[sts32f32]]
+; CHECK: OpTypePointer Function [[a42f32]]
+; CHECK: OpTypeFunction [[void]]
+; CHECK: OpTypeFunction [[void]] [[bool]]
+; CHECK: OpTypeFunction [[void]] [[bool]] [[s32]]
+; CHECK: OpTypeFunction [[s32]] [[bool]] [[s32]]
+; CHECK: OpTypeEvent
+; CHECK: OpTypeDeviceEvent
+; CHECK: OpTypeReserveId
+; CHECK: OpTypeQueue
+; CHECK: OpTypePipe ReadWrite
+; CHECK: OpTypePipe ReadOnly
+; CHECK: OpTypeForwardPointer [[input_ptr]] Input
+; CHECK: OpTypeForwardPointer [[uniform_ptr]] Input
+; CHECK: OpTypeForwardPointer [[uniform_ptr]] Uniform
+; CHECK: OpTypePipeStorage
+; CHECK: OpTypeNamedBarrier
+OpCapability Shader
+OpCapability Int64
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%uint = OpTypeInt 32 0
+%1 = OpTypePointer Input %uint
+%2 = OpTypePointer Uniform %uint
+%24 = OpConstant %uint 24
+%42 = OpConstant %uint 42
+%100 = OpConstant %uint 100
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+  for (auto& t : types) {
+    context->get_type_mgr()->GetTypeInstruction(t.get());
+  }
+
+  Match(text, context.get(), false);
+}
+
+TEST(TypeManager, GetTypeInstructionWithDecorations) {
+  const std::string text = R"(
+; CHECK: OpDecorate [[struct:%\w+]] CPacked
+; CHECK: OpMemberDecorate [[struct]] 1 Offset 4
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct]] = OpTypeStruct [[uint]] [[uint]]
+OpCapability Shader
+OpCapability Kernel
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%uint = OpTypeInt 32 0
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  EXPECT_NE(context, nullptr);
+
+  Integer u32(32, false);
+  Struct st({&u32, &u32});
+  st.AddDecoration({10});
+  st.AddMemberDecoration(1, {{35, 4}});
+  (void)context->get_def_use_mgr();
+  context->get_type_mgr()->GetTypeInstruction(&st);
+
+  Match(text, context.get());
+}
+#endif  // SPIRV_EFFCEE
+
 }  // anonymous namespace
index adbc870..c1156af 100644 (file)
@@ -54,6 +54,9 @@ class SameTypeTest : public ::testing::Test {
         EXPECT_TRUE(types[i]->IsSame(types[j].get()))                     \
             << "expected '" << types[i]->str() << "' is the same as '"    \
             << types[j]->str() << "'";                                    \
+        EXPECT_TRUE(*types[i] == *types[j])                               \
+            << "expected '" << types[i]->str() << "' is the same as '"    \
+            << types[j]->str() << "'";                                    \
       }                                                                   \
     }                                                                     \
   }
@@ -86,7 +89,7 @@ TestMultipleInstancesOfTheSameType(PipeStorage);
 TestMultipleInstancesOfTheSameType(NamedBarrier);
 #undef TestMultipleInstanceOfTheSameType
 
-TEST(Types, AllTypes) {
+std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
   // Types in this test case are only equal to themselves, nothing else.
   std::vector<std::unique_ptr<Type>> types;
 
@@ -193,6 +196,13 @@ TEST(Types, AllTypes) {
   types.emplace_back(new PipeStorage());
   types.emplace_back(new NamedBarrier());
 
+  return types;
+}
+
+TEST(Types, AllTypes) {
+  // Types in this test case are only equal to themselves, nothing else.
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+
   for (size_t i = 0; i < types.size(); ++i) {
     for (size_t j = 0; j < types.size(); ++j) {
       if (i == j) {
@@ -258,4 +268,71 @@ TEST(Types, MatrixElementCount) {
   }
 }
 
+TEST(Types, IsUniqueType) {
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+
+  for (auto& t : types) {
+    bool expectation = true;
+    // Disallowing variable pointers.
+    switch (t->kind()) {
+      case Type::kArray:
+      case Type::kRuntimeArray:
+      case Type::kStruct:
+        expectation = false;
+        break;
+      default:
+        break;
+    }
+    EXPECT_EQ(t->IsUniqueType(false), expectation)
+        << "expected '" << t->str() << "' to be a "
+        << (expectation ? "" : "non-") << "unique type";
+
+    // Allowing variables pointers.
+    if (t->AsPointer()) expectation = false;
+    EXPECT_EQ(t->IsUniqueType(true), expectation)
+        << "expected '" << t->str() << "' to be a "
+        << (expectation ? "" : "non-") << "unique type";
+  }
+}
+
+std::vector<std::unique_ptr<Type>> GenerateAllTypesWithDecorations() {
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+  uint32_t elems = 1;
+  uint32_t decs = 1;
+  for (auto& t : types) {
+    for (uint32_t i = 0; i < (decs % 10); ++i) {
+      std::vector<uint32_t> decoration;
+      for (uint32_t j = 0; j < (elems % 4) + 1; ++j) {
+        decoration.push_back(j);
+      }
+      t->AddDecoration(std::move(decoration));
+      ++elems;
+      ++decs;
+    }
+  }
+
+  return types;
+}
+
+TEST(Types, Clone) {
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypesWithDecorations();
+  for (auto& t : types) {
+    auto clone = t->Clone();
+    EXPECT_TRUE(*t == *clone);
+    EXPECT_TRUE(t->HasSameDecorations(clone.get()));
+    EXPECT_NE(clone.get(), t.get());
+  }
+}
+
+TEST(Types, RemoveDecorations) {
+  std::vector<std::unique_ptr<Type>> types = GenerateAllTypesWithDecorations();
+  for (auto& t : types) {
+    auto decorationless = t->RemoveDecorations();
+    EXPECT_EQ(*t == *decorationless, t->decoration_empty());
+    EXPECT_EQ(t->HasSameDecorations(decorationless.get()),
+              t->decoration_empty());
+    EXPECT_NE(t.get(), decorationless.get());
+  }
+}
+
 }  // anonymous namespace