Adding an unique id to Instruction generated by IRContext
authorAlan Baker <alanbaker@google.com>
Tue, 14 Nov 2017 19:11:50 +0000 (14:11 -0500)
committerDavid Neto <dneto@google.com>
Mon, 20 Nov 2017 22:49:10 +0000 (17:49 -0500)
Each instruction is given an unique id that can be used for ordering
purposes. The ids are generated via the IRContext.

Major changes:
* Instructions now contain a uint32_t for unique id and a cached context
pointer
 * Most constructors have been modified to take a context as input
 * unfortunately I cannot remove the default and copy constructors, but
 developers should avoid these
* Added accessors to parents of basic block and function
* Removed the copy constructors for BasicBlock and Function and replaced
them with Clone functions
* Reworked BuildModule to return an IRContext owning the built module
 * Since all instructions require a context, the context now becomes the
basic unit for IR
* Added a constructor to context to create an owned module internally
* Replaced uses of Instruction's copy constructor with Clone whereever I
found them
* Reworked the linker functionality to perform clones into a different
context instead of moves
* Updated many tests to be consistent with the above changes
 * Still need to add new tests to cover added functionality
* Added comparison operators to Instruction
* Added an internal option to LinkerOptions to verify merged ids are
unique
* Added a test for the linker to verify merged ids are unique

* Updated MergeReturnPass to supply a context
* Updated DecorationManager to supply a context for cloned decorations

* Reworked several portions of the def use tests in anticipation of next
set of changes

39 files changed:
include/spirv-tools/linker.hpp
source/link/linker.cpp
source/opt/aggressive_dead_code_elim_pass.cpp
source/opt/basic_block.cpp
source/opt/basic_block.h
source/opt/build_module.cpp
source/opt/build_module.h
source/opt/cfg.cpp
source/opt/common_uniform_elim_pass.cpp
source/opt/dead_branch_elim_pass.cpp
source/opt/decoration_manager.cpp
source/opt/flatten_decoration_pass.cpp
source/opt/fold_spec_constant_op_and_composite_pass.cpp
source/opt/function.cpp
source/opt/function.h
source/opt/inline_pass.cpp
source/opt/instruction.cpp
source/opt/instruction.h
source/opt/ir_context.h
source/opt/ir_loader.cpp
source/opt/ir_loader.h
source/opt/local_access_chain_convert_pass.cpp
source/opt/mem_pass.cpp
source/opt/merge_return_pass.cpp
source/opt/module.cpp
source/opt/module.h
source/opt/optimizer.cpp
source/opt/strength_reduction_pass.cpp
test/link/CMakeLists.txt
test/link/unique_ids_test.cpp [new file with mode: 0644]
test/opt/def_use_test.cpp
test/opt/instruction_test.cpp
test/opt/ir_context_test.cpp
test/opt/ir_loader_test.cpp
test/opt/module_test.cpp
test/opt/pass_fixture.h
test/opt/pass_manager_test.cpp
test/opt/pass_test.cpp
test/opt/type_manager_test.cpp

index 43c725d..a36aa75 100644 (file)
@@ -26,7 +26,9 @@ namespace spvtools {
 
 class LinkerOptions {
  public:
-  LinkerOptions() : createLibrary_(false) {}
+  LinkerOptions()
+      : createLibrary_(false),
+        verifyIds_(false) {}
 
   // Returns whether a library or an executable should be produced by the
   // linking phase.
@@ -36,13 +38,25 @@ class LinkerOptions {
   // The returned value will be true if creating a library, and false if
   // creating an executable.
   bool GetCreateLibrary() const { return createLibrary_; }
+
   // Sets whether a library or an executable should be produced.
   void SetCreateLibrary(bool create_library) {
     createLibrary_ = create_library;
   }
 
+  // Returns whether to verify the uniqueness of the unique ids in the merged
+  // context.
+  bool GetVerifyIds() const { return verifyIds_; }
+
+  // Sets whether to verify the uniqueness of the unique ids in the merged
+  // context.
+  void SetVerifyIds(bool verifyIds) {
+    verifyIds_ = verifyIds;
+  }
+
  private:
   bool createLibrary_;
+  bool verifyIds_;
 };
 
 class Linker {
index 59ea36c..7f1b5cd 100644 (file)
@@ -38,6 +38,7 @@
 namespace spvtools {
 
 using ir::Instruction;
+using ir::IRContext;
 using ir::Module;
 using ir::Operand;
 using opt::PassManager;
@@ -69,31 +70,34 @@ using LinkageTable = std::vector<LinkageEntry>;
 // is returned in |max_id_bound|.
 //
 // Both |modules| and |max_id_bound| should not be null, and |modules| should
-// not be empty either.
+// not be empty either. Furthermore |modules| should not contain any null
+// pointers.
 static spv_result_t ShiftIdsInModules(
     const MessageConsumer& consumer,
-    std::vector<std::unique_ptr<ir::Module>>* modules, uint32_t* max_id_bound);
+    std::vector<ir::Module*>* modules, uint32_t* max_id_bound);
 
 // Generates the header for the linked module and returns it in |header|.
 //
-// |header| should not be null, |modules| should not be empty and
-// |max_id_bound| should be strictly greater than 0.
+// |header| should not be null, |modules| should not be empty and pointers
+// should be non-null. |max_id_bound| should be strictly greater than 0.
 //
 // TODO(pierremoreau): What to do when binaries use different versions of
 //                     SPIR-V? For now, use the max of all versions found in
 //                     the input modules.
 static spv_result_t GenerateHeader(
     const MessageConsumer& consumer,
-    const std::vector<std::unique_ptr<ir::Module>>& modules,
+    const std::vector<ir::Module*>& modules,
     uint32_t max_id_bound, ir::ModuleHeader* header);
 
-// Merge all the modules from |inModules| into |linked_module|.
+// Merge all the modules from |inModules| into a single module owned by
+// |linked_context|.
 //
-// |linked_module| should not be null.
+// |linked_context| should not be null.
 static spv_result_t MergeModules(
     const MessageConsumer& consumer,
-    const std::vector<std::unique_ptr<Module>>& inModules,
-    const libspirv::AssemblyGrammar& grammar, Module* linked_module);
+    const std::vector<Module*>& inModules,
+    const libspirv::AssemblyGrammar& grammar,
+    IRContext* linked_context);
 
 // Compute all pairs of import and export and return it in |linkings_to_do|.
 //
@@ -123,7 +127,7 @@ static spv_result_t CheckImportExportCompatibility(
 // functions, declarations of imported variables, import (and export if
 // necessary) linkage attribtes.
 //
-// |linked_module| and |decoration_manager| should not be null, and the
+// |linked_context| and |decoration_manager| should not be null, and the
 // 'RemoveDuplicatePass' should be run first.
 //
 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
@@ -136,6 +140,11 @@ static spv_result_t RemoveLinkageSpecificInstructions(
     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
     ir::IRContext* linked_context);
 
+// Verify that the unique ids of each instruction in |linked_context| (i.e. the
+// merged module) are truly unique. Does not check the validity of other ids
+static spv_result_t VerifyIds(const MessageConsumer& consumer,
+                              ir::IRContext* linked_context);
+
 // Structs for holding the data members for SpvLinker.
 struct Linker::Impl {
   explicit Impl(spv_target_env env) : context(spvContextCreate(env)) {
@@ -186,7 +195,8 @@ spv_result_t Linker::Link(const uint32_t* const* binaries,
                                       SPV_ERROR_INVALID_BINARY)
            << "No modules were given.";
 
-  std::vector<std::unique_ptr<Module>> modules;
+  std::vector<std::unique_ptr<IRContext>> contexts;
+  std::vector<Module*> modules;
   modules.reserve(num_binaries);
   for (size_t i = 0u; i < num_binaries; ++i) {
     const uint32_t schema = binaries[i][4u];
@@ -197,13 +207,14 @@ spv_result_t Linker::Link(const uint32_t* const* binaries,
              << "Schema is non-zero for module " << i << ".";
     }
 
-    std::unique_ptr<Module> module = BuildModule(
+    std::unique_ptr<IRContext> context = BuildModule(
         impl_->context->target_env, consumer, binaries[i], binary_sizes[i]);
-    if (module == nullptr)
+    if (context == nullptr)
       return libspirv::DiagnosticStream(position, consumer,
                                         SPV_ERROR_INVALID_BINARY)
-             << "Failed to build a module out of " << modules.size() << ".";
-    modules.push_back(std::move(module));
+             << "Failed to build a module out of " << contexts.size() << ".";
+    modules.push_back(context->module());
+    contexts.push_back(std::move(context));
   }
 
   // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
@@ -216,14 +227,18 @@ spv_result_t Linker::Link(const uint32_t* const* binaries,
   ir::ModuleHeader header;
   res = GenerateHeader(consumer, modules, max_id_bound, &header);
   if (res != SPV_SUCCESS) return res;
-  auto linked_module = MakeUnique<Module>();
-  linked_module->SetHeader(header);
+  IRContext linked_context(consumer);
+  linked_context.module()->SetHeader(header);
 
   // Phase 3: Merge all the binaries into a single one.
   libspirv::AssemblyGrammar grammar(impl_->context);
-  res = MergeModules(consumer, modules, grammar, linked_module.get());
+  res = MergeModules(consumer, modules, grammar, &linked_context);
   if (res != SPV_SUCCESS) return res;
-  ir::IRContext linked_context(std::move(linked_module), consumer);
+
+  if (options.GetVerifyIds()) {
+    res = VerifyIds(consumer, &linked_context);
+    if (res != SPV_SUCCESS) return res;
+  }
 
   // Phase 4: Find the import/export pairs
   LinkageTable linkings_to_do;
@@ -270,7 +285,7 @@ spv_result_t Linker::Link(const uint32_t* const* binaries,
 
 static spv_result_t ShiftIdsInModules(
     const MessageConsumer& consumer,
-    std::vector<std::unique_ptr<ir::Module>>* modules, uint32_t* max_id_bound) {
+    std::vector<ir::Module*>* modules, uint32_t* max_id_bound) {
   spv_position_t position = {};
 
   if (modules == nullptr)
@@ -289,7 +304,7 @@ static spv_result_t ShiftIdsInModules(
   uint32_t id_bound = modules->front()->IdBound() - 1u;
   for (auto module_iter = modules->begin() + 1; module_iter != modules->end();
        ++module_iter) {
-    Module* module = module_iter->get();
+    Module* module = *module_iter;
     module->ForEachInst([&id_bound](Instruction* insn) {
       insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; });
     });
@@ -313,7 +328,7 @@ static spv_result_t ShiftIdsInModules(
 
 static spv_result_t GenerateHeader(
     const MessageConsumer& consumer,
-    const std::vector<std::unique_ptr<ir::Module>>& modules,
+    const std::vector<ir::Module*>& modules,
     uint32_t max_id_bound, ir::ModuleHeader* header) {
   spv_position_t position = {};
 
@@ -341,28 +356,32 @@ static spv_result_t GenerateHeader(
 
 static spv_result_t MergeModules(
     const MessageConsumer& consumer,
-    const std::vector<std::unique_ptr<Module>>& input_modules,
-    const libspirv::AssemblyGrammar& grammar, Module* linked_module) {
+    const std::vector<Module*>& input_modules,
+    const libspirv::AssemblyGrammar& grammar, IRContext* linked_context) {
   spv_position_t position = {};
 
-  if (linked_module == nullptr)
+  if (linked_context == nullptr)
     return libspirv::DiagnosticStream(position, consumer,
                                       SPV_ERROR_INVALID_DATA)
            << "|linked_module| of MergeModules should not be null.";
+  Module* linked_module = linked_context->module();
 
   if (input_modules.empty()) return SPV_SUCCESS;
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->capabilities())
-      linked_module->AddCapability(MakeUnique<Instruction>(inst));
+      linked_module->AddCapability(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->extensions())
-      linked_module->AddExtension(MakeUnique<Instruction>(inst));
+      linked_module->AddExtension(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->ext_inst_imports())
-      linked_module->AddExtInstImport(MakeUnique<Instruction>(inst));
+      linked_module->AddExtInstImport(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   do {
     const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel();
@@ -402,7 +421,7 @@ static spv_result_t MergeModules(
 
     if (memory_model_inst != nullptr)
       linked_module->SetMemoryModel(
-          MakeUnique<Instruction>(*memory_model_inst));
+          std::unique_ptr<Instruction>(memory_model_inst->Clone(linked_context)));
   } while (false);
 
   std::vector<std::pair<uint32_t, const char*>> entry_points;
@@ -424,25 +443,30 @@ static spv_result_t MergeModules(
                << "The entry point \"" << name << "\", with execution model "
                << desc->name << ", was already defined.";
       }
-      linked_module->AddEntryPoint(MakeUnique<Instruction>(inst));
+      linked_module->AddEntryPoint(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
       entry_points.emplace_back(model, name);
     }
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->execution_modes())
-      linked_module->AddExecutionMode(MakeUnique<Instruction>(inst));
+      linked_module->AddExecutionMode(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->debugs1())
-      linked_module->AddDebug1Inst(MakeUnique<Instruction>(inst));
+      linked_module->AddDebug1Inst(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->debugs2())
-      linked_module->AddDebug2Inst(MakeUnique<Instruction>(inst));
+      linked_module->AddDebug2Inst(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   for (const auto& module : input_modules)
     for (const auto& inst : module->annotations())
-      linked_module->AddAnnotationInst(MakeUnique<Instruction>(inst));
+      linked_module->AddAnnotationInst(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
 
   // TODO(pierremoreau): Since the modules have not been validate, should we
   //                     expect SpvStorageClassFunction variables outside
@@ -450,7 +474,8 @@ static spv_result_t MergeModules(
   uint32_t num_global_values = 0u;
   for (const auto& module : input_modules) {
     for (const auto& inst : module->types_values()) {
-      linked_module->AddType(MakeUnique<Instruction>(inst));
+      linked_module->AddType(
+          std::unique_ptr<Instruction>(inst.Clone(linked_context)));
       num_global_values += inst.opcode() == SpvOpVariable;
     }
   }
@@ -462,8 +487,7 @@ static spv_result_t MergeModules(
   // Process functions and their basic blocks
   for (const auto& module : input_modules) {
     for (const auto& func : *module) {
-      std::unique_ptr<ir::Function> cloned_func =
-          MakeUnique<ir::Function>(func);
+      std::unique_ptr<ir::Function> cloned_func(func.Clone(linked_context));
       cloned_func->SetParent(linked_module);
       linked_module->AddFunction(std::move(cloned_func));
     }
@@ -711,4 +735,19 @@ static spv_result_t RemoveLinkageSpecificInstructions(
   return SPV_SUCCESS;
 }
 
+spv_result_t VerifyIds(const MessageConsumer& consumer, ir::IRContext* linked_context) {
+  std::unordered_set<uint32_t> ids;
+  bool ok = true;
+  linked_context->module()->ForEachInst([&ids,&ok](const ir::Instruction* inst) {
+    ok &= ids.insert(inst->unique_id()).second;
+  });
+
+  if (!ok) {
+    consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module");
+    return SPV_ERROR_INVALID_ID;
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace spvtools
index 7be3f9a..6071b94 100644 (file)
@@ -176,7 +176,7 @@ void AggressiveDCEPass::ComputeInst2BlockMap(ir::Function* func) {
 
 void AggressiveDCEPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) {
   std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
-      SpvOpBranch, 0, 0,
+      context(), SpvOpBranch, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch);
   bp->AddInstruction(std::move(newBranch));
index 7e0f421..fccd396 100644 (file)
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "basic_block.h"
+#include "function.h"
+#include "module.h"
 
 #include "make_unique.h"
 
@@ -27,12 +29,13 @@ const uint32_t kSelectionMergeMergeBlockIdInIdx = 0;
 
 }  // namespace
 
-BasicBlock::BasicBlock(const BasicBlock& bb)
-    : function_(nullptr),
-      label_(MakeUnique<Instruction>(bb.GetLabelInst())),
-      insts_() {
-  for (auto& inst : bb.insts_)
-    AddInstruction(std::unique_ptr<Instruction>(inst.Clone()));
+BasicBlock* BasicBlock::Clone(IRContext* context) const {
+  BasicBlock* clone =
+      new BasicBlock(std::unique_ptr<Instruction>(GetLabelInst().Clone(context)));
+  for (const auto& inst : insts_)
+    // Use the incoming context
+    clone->AddInstruction(std::unique_ptr<Instruction>(inst.Clone(context)));
+  return clone;
 }
 
 const Instruction* BasicBlock::GetMergeInst() const {
index 32550e7..f4405f2 100644 (file)
@@ -31,6 +31,7 @@ namespace spvtools {
 namespace ir {
 
 class Function;
+class IRContext;
 
 // A SPIR-V basic block.
 class BasicBlock {
@@ -41,15 +42,20 @@ class BasicBlock {
   // Creates a basic block with the given starting |label|.
   inline explicit BasicBlock(std::unique_ptr<Instruction> label);
 
-  // Creates a basic block from the given basic block |bb|.
+  explicit BasicBlock(const BasicBlock& bb) = delete;
+
+  // Creates a clone of the basic block in the given |context|
   //
   // The parent function will default to null and needs to be explicitly set by
   // the user.
-  explicit BasicBlock(const BasicBlock& bb);
+  BasicBlock* Clone(IRContext*) const;
 
   // Sets the enclosing function for this basic block.
   void SetParent(Function* function) { function_ = function; }
 
+  // Return the enclosing function
+  inline Function* GetParent() const { return function_; }
+
   // Appends an instruction to this basic block.
   inline void AddInstruction(std::unique_ptr<Instruction> i);
 
index e3439f3..42dbdd7 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "build_module.h"
 
+#include"ir_context.h"
 #include "ir_loader.h"
 #include "make_unique.h"
 #include "table.h"
@@ -43,15 +44,15 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) {
 
 }  // annoymous namespace
 
-std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
+std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env,
                                         MessageConsumer consumer,
                                         const uint32_t* binary,
                                         const size_t size) {
   auto context = spvContextCreate(env);
   SetContextMessageConsumer(context, consumer);
 
-  auto module = MakeUnique<ir::Module>();
-  ir::IrLoader loader(context->consumer, module.get());
+  auto irContext = MakeUnique<ir::IRContext>(consumer);
+  ir::IrLoader loader(consumer, irContext->module());
 
   spv_result_t status = spvBinaryParse(context, &loader, binary, size,
                                        SetSpvHeader, SetSpvInst, nullptr);
@@ -59,13 +60,13 @@ std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
 
   spvContextDestroy(context);
 
-  return status == SPV_SUCCESS ? std::move(module) : nullptr;
+  return status == SPV_SUCCESS ? std::move(irContext) : nullptr;
 }
 
-std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
-                                        MessageConsumer consumer,
-                                        const std::string& text,
-                                        uint32_t assemble_options) {
+std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env,
+                                           MessageConsumer consumer,
+                                           const std::string& text,
+                                           uint32_t assemble_options) {
   SpirvTools t(env);
   t.SetMessageConsumer(consumer);
   std::vector<uint32_t> binary;
index 36ea74f..3ee6607 100644 (file)
 #include <memory>
 #include <string>
 
+#include "ir_context.h"
 #include "module.h"
 #include "spirv-tools/libspirv.hpp"
 
 namespace spvtools {
 
-// Builds and returns an ir::Module from the given SPIR-V |binary|. |size|
-// specifies number of words in |binary|. The |binary| will be decoded
-// according to the given target |env|. Returns nullptr if erors occur and
-// sends the errors to |consumer|.
-std::unique_ptr<ir::Module> BuildModule(spv_target_env env,
+// Builds an ir::Module returns the owning ir::IRContext from the given SPIR-V
+// |binary|. |size| specifies number of words in |binary|. The |binary| will be
+// decoded according to the given target |env|. Returns nullptr if errors occur
+// and sends the errors to |consumer|.
+std::unique_ptr<ir::IRContext> BuildModule(spv_target_env env,
                                         MessageConsumer consumer,
                                         const uint32_t* binary, size_t size);
 
-// Builds and returns an ir::Module from the given SPIR-V assembly |text|.
-// The |text| will be encoded according to the given target |env|. Returns
-// nullptr if erors occur and sends the errors to |consumer|.
-std::unique_ptr<ir::Module> BuildModule(
+// Builds an ir::Module and returns the owning ir::IRContext from the given
+// SPIR-V assembly |text|.  The |text| will be encoded according to the given
+// target |env|. Returns nullptr if errors occur and sends the errors to
+// |consumer|.
+std::unique_ptr<ir::IRContext> BuildModule(
     spv_target_env env, MessageConsumer consumer, const std::string& text,
     uint32_t assemble_options = SpirvTools::kDefaultAssembleOption);
 
index 6adc110..a0b78c7 100644 (file)
@@ -29,9 +29,9 @@ const int kInvalidId = 0x400000;
 CFG::CFG(ir::Module* module)
     : module_(module),
       pseudo_entry_block_(std::unique_ptr<ir::Instruction>(
-          new ir::Instruction(SpvOpLabel, 0, 0, {}))),
+          new ir::Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
       pseudo_exit_block_(std::unique_ptr<ir::Instruction>(
-          new ir::Instruction(SpvOpLabel, 0, kInvalidId, {}))) {
+          new ir::Instruction(module->context(), SpvOpLabel, 0, kInvalidId, {}))) {
   for (auto& fn : *module) {
     for (auto& blk : fn) {
       uint32_t blkId = blk.id();
index d68ed71..3339c4a 100644 (file)
@@ -231,7 +231,7 @@ void CommonUniformElimPass::GenACLoadRepl(
       ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                   std::initializer_list<uint32_t>{varId}));
   std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction(
-      SpvOpLoad, varPteTypeId, ldResultId, load_in_operands));
+      context(), SpvOpLoad, varPteTypeId, ldResultId, load_in_operands));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad);
   newInsts->emplace_back(std::move(newLoad));
 
@@ -254,7 +254,7 @@ void CommonUniformElimPass::GenACLoadRepl(
     ++iidIdx;
   });
   std::unique_ptr<ir::Instruction> newExt(new ir::Instruction(
-      SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds));
+      context(), SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newExt);
   newInsts->emplace_back(std::move(newExt));
   *resultId = extResultId;
@@ -388,7 +388,7 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) {
           // Copy load into most recent dominating block and remember it
           replId = TakeNextId();
           std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction(
-              SpvOpLoad, ii->type_id(), replId,
+              context(), SpvOpLoad, ii->type_id(), replId,
               {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}}));
           get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad);
           insertItr = insertItr.InsertBefore(std::move(newLoad));
@@ -460,7 +460,7 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) {
         if (idxItr.second.size() < 2) continue;
         uint32_t replId = TakeNextId();
         std::unique_ptr<ir::Instruction> newExtract(
-            new ir::Instruction(*idxItr.second.front()));
+            idxItr.second.front()->Clone(context()));
         newExtract->SetResultId(replId);
         get_def_use_mgr()->AnalyzeInstDefUse(&*newExtract);
         ++ii;
index e3bf25f..f1c9bf1 100644 (file)
@@ -74,7 +74,7 @@ bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) {
 
 void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) {
   std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
-      SpvOpBranch, 0, 0,
+      context(), SpvOpBranch, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch);
   bp->AddInstruction(std::move(newBranch));
@@ -83,7 +83,7 @@ void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) {
 void DeadBranchElimPass::AddSelectionMerge(uint32_t labelId,
                                            ir::BasicBlock* bp) {
   std::unique_ptr<ir::Instruction> newMerge(new ir::Instruction(
-      SpvOpSelectionMerge, 0, 0,
+      context(), SpvOpSelectionMerge, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}}));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newMerge);
@@ -95,7 +95,7 @@ void DeadBranchElimPass::AddBranchConditional(uint32_t condId,
                                               uint32_t falseLabId,
                                               ir::BasicBlock* bp) {
   std::unique_ptr<ir::Instruction> newBranchCond(new ir::Instruction(
-      SpvOpBranchConditional, 0, 0,
+      context(), SpvOpBranchConditional, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {condId}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {trueLabId}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {falseLabId}}}));
@@ -302,7 +302,7 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) {
               ++icnt;
             });
         std::unique_ptr<ir::Instruction> newPhi(new ir::Instruction(
-            SpvOpPhi, pii->type_id(), replId, phi_in_opnds));
+            context(), SpvOpPhi, pii->type_id(), replId, phi_in_opnds));
         get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi);
         pii = pii.InsertBefore(std::move(newPhi));
         ++pii;
index aa926db..b25c20f 100644 (file)
@@ -70,11 +70,11 @@ bool DecorationManager::AreDecorationsTheSame(
   //    for (uint32_t i = 2u; i < inst.NumInOperands(); ++i) {
   //      const auto& j = constants.find(inst.GetSingleWordInOperand(i));
   //      if (j == constants.end())
-  //        return Instruction();
+  //        return Instruction(inst.context());
   //      const auto operand = j->second->GetOperand(0u);
   //      operands.emplace_back(operand.type, operand.words);
   //    }
-  //    return Instruction(SpvOpDecorate, 0u, 0u, operands);
+  //    return Instruction(inst.context(), SpvOpDecorate, 0u, 0u, operands);
   //  };
   //  Instruction tmpA = (deco1.opcode() == SpvOpDecorateId) ?
   //  decorateIdToDecorate(deco1) : deco1;
@@ -261,7 +261,7 @@ void DecorationManager::CloneDecorations(
       case SpvOpMemberDecorate:
       case SpvOpDecorateId: {
         // simply clone decoration and change |target-id| to |to|
-        std::unique_ptr<ir::Instruction> new_inst(inst->Clone());
+        std::unique_ptr<ir::Instruction> new_inst(inst->Clone(module_->context()));
         new_inst->SetInOperand(0, {to});
         id_to_decoration_insts_[to].push_back(new_inst.get());
         f(*new_inst, true);
index e92935d..eac8297 100644 (file)
@@ -91,7 +91,7 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) {
         const auto normal_uses_iter = normal_uses.find(group);
         if (normal_uses_iter != normal_uses.end()) {
           for (auto target : normal_uses[group]) {
-            std::unique_ptr<Instruction> new_inst(new Instruction(*inst_iter));
+            std::unique_ptr<Instruction> new_inst(inst_iter->Clone(context()));
             new_inst->SetInOperand(0, Words{target});
             inst_iter = inst_iter.InsertBefore(std::move(new_inst));
             ++inst_iter;
@@ -116,8 +116,8 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) {
             decoration_operands_iter++;  // Skip the group target.
             operands.insert(operands.end(), decoration_operands_iter,
                             inst_iter->end());
-            std::unique_ptr<Instruction> new_inst(
-                new Instruction(SpvOp::SpvOpMemberDecorate, 0, 0, operands));
+            std::unique_ptr<Instruction> new_inst(new Instruction(
+                context(), SpvOp::SpvOpMemberDecorate, 0, 0, operands));
             inst_iter = inst_iter.InsertBefore(std::move(new_inst));
             ++inst_iter;
             replace = true;
index a630d8a..e91d1fb 100644 (file)
@@ -724,22 +724,23 @@ std::unique_ptr<ir::Instruction>
 FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id,
                                                       analysis::Constant* c) {
   if (c->AsNullConstant()) {
-    return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull,
+    return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantNull,
                                        type_mgr_->GetId(c->type()), id,
                                        std::initializer_list<ir::Operand>{});
   } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
     return MakeUnique<ir::Instruction>(
+        context(),
         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
         type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{});
   } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
     return MakeUnique<ir::Instruction>(
-        SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
+        context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             ic->words())});
   } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
     return MakeUnique<ir::Instruction>(
-        SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
+        context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             fc->words())});
@@ -765,7 +766,8 @@ FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction(
     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                           std::initializer_list<uint32_t>{id});
   }
-  return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite,
+  return MakeUnique<ir::Instruction>(context(),
+                                     SpvOp::SpvOpConstantComposite,
                                      type_mgr_->GetId(cc->type()), result_id,
                                      std::move(operands));
 }
index 4ad2dce..dc5320f 100644 (file)
 namespace spvtools {
 namespace ir {
 
-Function::Function(const Function& f)
-    : module_(nullptr),
-      def_inst_(MakeUnique<Instruction>(f.DefInst())),
-      params_(),
-      blocks_(),
-      end_inst_() {
-  params_.reserve(f.params_.size());
-  f.ForEachParam(
-      [this](const Instruction* insn) {
-        AddParameter(MakeUnique<Instruction>(*insn));
+Function* Function::Clone(IRContext* context) const {
+  Function* clone =
+      new Function(std::unique_ptr<Instruction>(DefInst().Clone(context)));
+  clone->params_.reserve(params_.size());
+  ForEachParam(
+      [clone,context](const Instruction* inst) {
+        clone->AddParameter(std::unique_ptr<Instruction>(inst->Clone(context)));
       },
       true);
 
-  blocks_.reserve(f.blocks_.size());
-  for (const auto& b : f.blocks_) {
-    std::unique_ptr<BasicBlock> bb = MakeUnique<BasicBlock>(*b);
-    bb->SetParent(this);
-    AddBasicBlock(std::move(bb));
+  clone->blocks_.reserve(blocks_.size());
+  for (const auto& b : blocks_) {
+    std::unique_ptr<BasicBlock> bb(b->Clone(context));
+    bb->SetParent(clone);
+    clone->AddBasicBlock(std::move(bb));
   }
 
-  SetFunctionEnd(MakeUnique<Instruction>(f.function_end()));
+  clone->SetFunctionEnd(std::unique_ptr<Instruction>(function_end().Clone(context)));
+  return clone;
 }
 
 void Function::ForEachInst(const std::function<void(Instruction*)>& f,
index 618eb7d..9cd7209 100644 (file)
@@ -27,6 +27,7 @@
 namespace spvtools {
 namespace ir {
 
+class IRContext;
 class Module;
 
 // A SPIR-V function.
@@ -38,17 +39,22 @@ class Function {
   // Creates a function instance declared by the given OpFunction instruction
   // |def_inst|.
   inline explicit Function(std::unique_ptr<Instruction> def_inst);
-  // Creates a function instance based on the given function |f|.
+
+  explicit Function(const Function& f) = delete;
+
+  // Creates a clone of the instruction in the given |context|
   //
   // The parent module will default to null and needs to be explicitly set by
   // the user.
-  explicit Function(const Function& f);
+  Function* Clone(IRContext*) const;
   // The OpFunction instruction that begins the definition of this function.
   Instruction& DefInst() { return *def_inst_; }
   const Instruction& DefInst() const { return *def_inst_; }
 
   // Sets the enclosing module for this function.
   void SetParent(Module* module) { module_ = module; }
+  // Gets the enclosing module for this function
+  Module* GetParent() const { return module_; }
   // Appends a parameter to this function.
   inline void AddParameter(std::unique_ptr<Instruction> p);
   // Appends a basic block to this function.
index f52277b..5c6e3fb 100644 (file)
@@ -49,7 +49,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id,
                                       SpvStorageClass storage_class) {
   uint32_t resultId = TakeNextId();
   std::unique_ptr<ir::Instruction> type_inst(new ir::Instruction(
-      SpvOpTypePointer, 0, resultId,
+      context(), SpvOpTypePointer, 0, resultId,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
         {uint32_t(storage_class)}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}}));
@@ -60,7 +60,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id,
 void InlinePass::AddBranch(uint32_t label_id,
                            std::unique_ptr<ir::BasicBlock>* block_ptr) {
   std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
-      SpvOpBranch, 0, 0,
+      context(), SpvOpBranch, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}}));
   (*block_ptr)->AddInstruction(std::move(newBranch));
 }
@@ -69,7 +69,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id,
                                uint32_t false_id,
                                std::unique_ptr<ir::BasicBlock>* block_ptr) {
   std::unique_ptr<ir::Instruction> newBranch(new ir::Instruction(
-      SpvOpBranchConditional, 0, 0,
+      context(), SpvOpBranchConditional, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}}));
@@ -79,7 +79,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id,
 void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id,
                               std::unique_ptr<ir::BasicBlock>* block_ptr) {
   std::unique_ptr<ir::Instruction> newLoopMerge(new ir::Instruction(
-      SpvOpLoopMerge, 0, 0,
+      context(), SpvOpLoopMerge, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}}));
@@ -89,7 +89,7 @@ void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id,
 void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id,
                           std::unique_ptr<ir::BasicBlock>* block_ptr) {
   std::unique_ptr<ir::Instruction> newStore(new ir::Instruction(
-      SpvOpStore, 0, 0,
+      context(), SpvOpStore, 0, 0,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}},
        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}}));
   (*block_ptr)->AddInstruction(std::move(newStore));
@@ -98,14 +98,14 @@ void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id,
 void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id,
                          std::unique_ptr<ir::BasicBlock>* block_ptr) {
   std::unique_ptr<ir::Instruction> newLoad(new ir::Instruction(
-      SpvOpLoad, type_id, resultId,
+      context(), SpvOpLoad, type_id, resultId,
       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}}));
   (*block_ptr)->AddInstruction(std::move(newLoad));
 }
 
 std::unique_ptr<ir::Instruction> InlinePass::NewLabel(uint32_t label_id) {
   std::unique_ptr<ir::Instruction> newLabel(
-      new ir::Instruction(SpvOpLabel, 0, label_id, {}));
+      new ir::Instruction(context(), SpvOpLabel, 0, label_id, {}));
   return newLabel;
 }
 
@@ -143,7 +143,8 @@ void InlinePass::CloneAndMapLocals(
   auto callee_block_itr = calleeFn->begin();
   auto callee_var_itr = callee_block_itr->begin();
   while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) {
-    std::unique_ptr<ir::Instruction> var_inst(callee_var_itr->Clone());
+    std::unique_ptr<ir::Instruction> var_inst(
+        callee_var_itr->Clone(callee_var_itr->context()));
     uint32_t newId = TakeNextId();
     get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId, update_def_use_mgr_);
     var_inst->SetResultId(newId);
@@ -169,7 +170,7 @@ uint32_t InlinePass::CreateReturnVar(
     // Add return var to new function scope variables.
     returnVarId = TakeNextId();
     std::unique_ptr<ir::Instruction> var_inst(new ir::Instruction(
-        SpvOpVariable, returnVarTypeId, returnVarId,
+        context(), SpvOpVariable, returnVarTypeId, returnVarId,
         {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
           {SpvStorageClassFunction}}}));
     new_vars->push_back(std::move(var_inst));
@@ -195,7 +196,8 @@ void InlinePass::CloneSameBlockOps(
           if (mapItr2 != (*preCallSB).end()) {
             // Clone pre-call same-block ops, map result id.
             const ir::Instruction* inInst = mapItr2->second;
-            std::unique_ptr<ir::Instruction> sb_inst(inInst->Clone());
+            std::unique_ptr<ir::Instruction> sb_inst(
+                inInst->Clone(inInst->context()));
             CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr);
             const uint32_t rid = sb_inst->result_id();
             const uint32_t nid = this->TakeNextId();
@@ -325,7 +327,7 @@ void InlinePass::GenInlineCode(
           // Copy contents of original caller block up to call instruction.
           for (auto cii = call_block_itr->begin(); cii != call_inst_itr;
                ++cii) {
-            std::unique_ptr<ir::Instruction> cp_inst(cii->Clone());
+            std::unique_ptr<ir::Instruction> cp_inst(cii->Clone(context()));
             // Remember same-block ops for possible regeneration.
             if (IsSameBlockOp(&*cp_inst)) {
               auto* sb_inst_ptr = cp_inst.get();
@@ -434,7 +436,7 @@ void InlinePass::GenInlineCode(
         // Copy remaining instructions from caller block.
         auto cii = call_inst_itr;
         for (++cii; cii != call_block_itr->end(); ++cii) {
-          std::unique_ptr<ir::Instruction> cp_inst(cii->Clone());
+          std::unique_ptr<ir::Instruction> cp_inst(cii->Clone(context()));
           // If multiple blocks generated, regenerate any same-block
           // instruction that has not been seen in this last block.
           if (multiBlocks) {
@@ -452,7 +454,7 @@ void InlinePass::GenInlineCode(
       } break;
       default: {
         // Copy callee instruction and remap all input Ids.
-        std::unique_ptr<ir::Instruction> cp_inst(cpi->Clone());
+        std::unique_ptr<ir::Instruction> cp_inst(cpi->Clone(context()));
         cp_inst->ForEachInId([&callee2caller, &callee_result_ids,
                               this](uint32_t* iid) {
           const auto mapItr = callee2caller.find(*iid);
@@ -497,7 +499,7 @@ void InlinePass::GenInlineCode(
     auto loop_merge_itr = last->tail();
     --loop_merge_itr;
     assert(loop_merge_itr->opcode() == SpvOpLoopMerge);
-    std::unique_ptr<ir::Instruction> cp_inst(loop_merge_itr->Clone());
+    std::unique_ptr<ir::Instruction> cp_inst(loop_merge_itr->Clone(context()));
     if (caller_is_single_block_loop) {
       // Also, update its continue target to point to the last block.
       cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()});
index f26fb1d..df2dcb7 100644 (file)
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "instruction.h"
+#include "ir_context.h"
 
 #include <initializer_list>
 
 namespace spvtools {
 namespace ir {
 
-Instruction::Instruction(const spv_parsed_instruction_t& inst,
+Instruction::Instruction(IRContext* c)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
+      opcode_(SpvOpNop),
+      type_id_(0),
+      result_id_(0),
+      unique_id_(c->TakeNextUniqueId()) {}
+
+Instruction::Instruction(IRContext* c, SpvOp op)
+    : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
+      opcode_(op),
+      type_id_(0),
+      result_id_(0),
+      unique_id_(c->TakeNextUniqueId()) {}
+
+Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst,
                          std::vector<Instruction>&& dbg_line)
-    : opcode_(static_cast<SpvOp>(inst.opcode)),
+    : context_(c),
+      opcode_(static_cast<SpvOp>(inst.opcode)),
       type_id_(inst.type_id),
       result_id_(inst.result_id),
+      unique_id_(c->TakeNextUniqueId()),
       dbg_line_insts_(std::move(dbg_line)) {
   assert((!IsDebugLineInst(opcode_) || dbg_line.empty()) &&
          "Op(No)Line attaching to Op(No)Line found");
@@ -38,12 +57,14 @@ Instruction::Instruction(const spv_parsed_instruction_t& inst,
   }
 }
 
-Instruction::Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id,
+Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id,
                          const std::vector<Operand>& in_operands)
     : utils::IntrusiveNodeBase<Instruction>(),
+      context_(c),
       opcode_(op),
       type_id_(ty_id),
       result_id_(res_id),
+      unique_id_(c->TakeNextUniqueId()),
       operands_() {
   if (type_id_ != 0) {
     operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_TYPE_ID,
@@ -61,6 +82,7 @@ Instruction::Instruction(Instruction&& that)
       opcode_(that.opcode_),
       type_id_(that.type_id_),
       result_id_(that.result_id_),
+      unique_id_(that.unique_id_),
       operands_(std::move(that.operands_)),
       dbg_line_insts_(std::move(that.dbg_line_insts_)) {}
 
@@ -68,16 +90,18 @@ Instruction& Instruction::operator=(Instruction&& that) {
   opcode_ = that.opcode_;
   type_id_ = that.type_id_;
   result_id_ = that.result_id_;
+  unique_id_ = that.unique_id_;
   operands_ = std::move(that.operands_);
   dbg_line_insts_ = std::move(that.dbg_line_insts_);
   return *this;
 }
 
-Instruction* Instruction::Clone() const {
-  Instruction* clone = new Instruction();
+Instruction* Instruction::Clone(IRContext *c) const {
+  Instruction* clone = new Instruction(c);
   clone->opcode_ = opcode_;
   clone->type_id_ = type_id_;
   clone->result_id_ = result_id_;
+  clone->unique_id_ = c->TakeNextUniqueId();
   clone->operands_ = operands_;
   clone->dbg_line_insts_ = dbg_line_insts_;
   return clone;
index ff0acdb..4c96474 100644 (file)
@@ -31,6 +31,7 @@ namespace spvtools {
 namespace ir {
 
 class Function;
+class IRContext;
 class Module;
 class InstructionList;
 
@@ -84,28 +85,30 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   using const_iterator = std::vector<Operand>::const_iterator;
 
   // Creates a default OpNop instruction.
+  // This exists solely for containers that can't do without. Should be removed.
   Instruction()
       : utils::IntrusiveNodeBase<Instruction>(),
+        context_(nullptr),
         opcode_(SpvOpNop),
         type_id_(0),
-        result_id_(0) {}
+        result_id_(0),
+        unique_id_(0) {}
+
+  // Creates a default OpNop instruction.
+  Instruction(IRContext*);
   // Creates an instruction with the given opcode |op| and no additional logical
   // operands.
-  Instruction(SpvOp op)
-      : utils::IntrusiveNodeBase<Instruction>(),
-        opcode_(op),
-        type_id_(0),
-        result_id_(0) {}
+  Instruction(IRContext*, SpvOp);
   // Creates an instruction using the given spv_parsed_instruction_t |inst|. All
   // the data inside |inst| will be copied and owned in this instance. And keep
   // record of line-related debug instructions |dbg_line| ahead of this
   // instruction, if any.
-  Instruction(const spv_parsed_instruction_t& inst,
+  Instruction(IRContext* c, const spv_parsed_instruction_t& inst,
               std::vector<Instruction>&& dbg_line = {});
 
   // Creates an instruction with the given opcode |op|, type id: |ty_id|,
   // result id: |res_id| and input operands: |in_operands|.
-  Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id,
+  Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id,
               const std::vector<Operand>& in_operands);
 
   // TODO: I will want to remove these, but will first have to remove the use of
@@ -123,7 +126,9 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // It is the responsibility of the caller to make sure that the storage is
   // removed. It is the caller's responsibility to make sure that there is only
   // one instruction for each result id.
-  Instruction* Clone() const;
+  Instruction* Clone(IRContext *c) const;
+
+  IRContext* context() const { return context_; }
 
   SpvOp opcode() const { return opcode_; }
   // Sets the opcode of this instruction to a specific opcode. Note this may
@@ -133,6 +138,7 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   void SetOpcode(SpvOp op) { opcode_ = op; }
   uint32_t type_id() const { return type_id_; }
   uint32_t result_id() const { return result_id_; }
+  uint32_t unique_id() const { assert(unique_id_ != 0); return unique_id_; }
   // Returns the vector of line-related debug instructions attached to this
   // instruction and the caller can directly modify them.
   std::vector<Instruction>& dbg_line_insts() { return dbg_line_insts_; }
@@ -241,15 +247,21 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   // Returns true if the instruction annotates an id with a decoration.
   inline bool IsDecoration();
 
+  inline bool operator==(const Instruction&) const;
+  inline bool operator!=(const Instruction&) const;
+  inline bool operator<(const Instruction&) const;
+
  private:
   // Returns the total count of result type id and result id.
   uint32_t TypeResultIdCount() const {
     return (type_id_ != 0) + (result_id_ != 0);
   }
 
+  IRContext* context_;  // IR Context
   SpvOp opcode_;        // Opcode
   uint32_t type_id_;    // Result type id. A value of 0 means no result type id.
   uint32_t result_id_;  // Result id. A value of 0 means no result id.
+  uint32_t unique_id_;  // Unique instruction id
   // All logical operands, including result type id and result id.
   std::vector<Operand> operands_;
   // Opline and OpNoLine instructions preceding this instruction. Note that for
@@ -260,6 +272,18 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
   friend InstructionList;
 };
 
+inline bool Instruction::operator==(const Instruction& other) const {
+  return unique_id() == other.unique_id();
+}
+
+inline bool Instruction::operator!=(const Instruction& other) const {
+  return !(*this == other);
+}
+
+inline bool Instruction::operator<(const Instruction& other) const {
+  return unique_id() < other.unique_id();
+}
+
 inline const Operand& Instruction::GetOperand(uint32_t index) const {
   assert(index < operands_.size() && "operand index out of bound");
   return operands_[index];
index 23f59ed..eedee1e 100644 (file)
@@ -21,6 +21,7 @@
 
 #include <algorithm>
 #include <iostream>
+#include <limits>
 
 namespace spvtools {
 namespace ir {
@@ -53,11 +54,26 @@ class IRContext {
   friend inline Analysis operator<<(Analysis a, int shift);
   friend inline Analysis& operator<<=(Analysis& a, int shift);
 
+  // Create an |IRContext| that contains an owned |Module|
+  IRContext(spvtools::MessageConsumer c)
+      : unique_id_(0),
+        module_(new Module()),
+        consumer_(std::move(c)),
+        def_use_mgr_(nullptr),
+        valid_analyses_(kAnalysisNone)
+  {
+    module_->SetContext(this);
+  }
+
   IRContext(std::unique_ptr<Module>&& m, spvtools::MessageConsumer c)
-      : module_(std::move(m)),
+      : unique_id_(0),
+        module_(std::move(m)),
         consumer_(std::move(c)),
         def_use_mgr_(nullptr),
-        valid_analyses_(kAnalysisNone) {}
+        valid_analyses_(kAnalysisNone)
+  {
+    module_->SetContext(this);
+  }
   Module* module() const { return module_.get(); }
 
   inline void SetIdBound(uint32_t i);
@@ -239,6 +255,14 @@ class IRContext {
   // Kill all name and decorate ops targeting the result id of |inst|.
   void KillNamesAndDecorates(ir::Instruction* inst);
 
+  // Returns the next unique id for use by an instruction.
+  inline uint32_t TakeNextUniqueId() {
+    assert(unique_id_ != std::numeric_limits<uint32_t>::max());
+
+    // Skip zero.
+    return ++unique_id_;
+  }
+
  private:
   // Builds the def-use manager from scratch, even if it was already valid.
   void BuildDefUseManager() {
@@ -264,6 +288,13 @@ class IRContext {
     valid_analyses_ = valid_analyses_ | kAnalysisDecorations;
   }
 
+  // An unique identifier for this instruction. Can be used to order
+  // instructions in a container.
+  //
+  // This member is initialized to 0, but always issues this value plus one.
+  // Therefore, 0 is not a valid unique id for an instruction.
+  uint32_t unique_id_;
+
   std::unique_ptr<Module> module_;
   spvtools::MessageConsumer consumer_;
   std::unique_ptr<opt::analysis::DefUseManager> def_use_mgr_;
index e3d8484..b705343 100644 (file)
@@ -20,9 +20,9 @@
 namespace spvtools {
 namespace ir {
 
-IrLoader::IrLoader(const MessageConsumer& consumer, Module* module)
+IrLoader::IrLoader(const MessageConsumer& consumer, Module* m)
     : consumer_(consumer),
-      module_(module),
+      module_(m),
       source_("<instruction>"),
       inst_index_(0) {}
 
@@ -30,12 +30,12 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) {
   ++inst_index_;
   const auto opcode = static_cast<SpvOp>(inst->opcode);
   if (IsDebugLineInst(opcode)) {
-    dbg_line_info_.push_back(Instruction(*inst));
+    dbg_line_info_.push_back(Instruction(module()->context(), *inst));
     return true;
   }
 
   std::unique_ptr<Instruction> spv_inst(
-      new Instruction(*inst, std::move(dbg_line_info_)));
+      new Instruction(module()->context(), *inst, std::move(dbg_line_info_)));
   dbg_line_info_.clear();
 
   const char* src = source_.c_str();
index bcb55f1..2f0ca8b 100644 (file)
@@ -39,11 +39,13 @@ class IrLoader {
   // All internal messages will be communicated to the outside via the given
   // message |consumer|. This instance only keeps a reference to the |consumer|,
   // so the |consumer| should outlive this instance.
-  IrLoader(const MessageConsumer& consumer, Module* module);
+  IrLoader(const MessageConsumer& consumer, Module* m);
 
   // Sets the source name of the module.
   void SetSource(const std::string& src) { source_ = src; }
 
+  Module* module() const { return module_; }
+
   // Sets the fields in the module's header to the given parameters.
   void SetModuleHeader(uint32_t magic, uint32_t version, uint32_t generator,
                        uint32_t bound, uint32_t reserved) {
index 9663e88..ff5f912 100644 (file)
@@ -44,7 +44,7 @@ void LocalAccessChainConvertPass::BuildAndAppendInst(
     const std::vector<ir::Operand>& in_opnds,
     std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
   std::unique_ptr<ir::Instruction> newInst(
-      new ir::Instruction(opcode, typeId, resultId, in_opnds));
+      new ir::Instruction(context(), opcode, typeId, resultId, in_opnds));
   get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
   newInsts->emplace_back(std::move(newInst));
 }
index 72e4f73..ae5baa8 100644 (file)
@@ -287,7 +287,7 @@ uint32_t MemPass::Type2Undef(uint32_t type_id) {
   if (uitr != type2undefs_.end()) return uitr->second;
   const uint32_t undefId = TakeNextId();
   std::unique_ptr<ir::Instruction> undef_inst(
-      new ir::Instruction(SpvOpUndef, type_id, undefId, {}));
+      new ir::Instruction(context(), SpvOpUndef, type_id, undefId, {}));
   get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst);
   get_module()->AddGlobalValue(std::move(undef_inst));
   type2undefs_[type_id] = undefId;
@@ -402,7 +402,7 @@ void MemPass::SSABlockInitLoopHeader(
     }
     const uint32_t phiId = TakeNextId();
     std::unique_ptr<ir::Instruction> newPhi(
-        new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands));
+        new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands));
     // The only phis requiring patching are the ones we create.
     phis_to_patch_.insert(phiId);
     // Only analyze the phi define now; analyze the phi uses after the
@@ -470,7 +470,7 @@ void MemPass::SSABlockInitMultiPred(ir::BasicBlock* block_ptr) {
     }
     const uint32_t phiId = TakeNextId();
     std::unique_ptr<ir::Instruction> newPhi(
-        new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands));
+        new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands));
     get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi);
     insertItr = insertItr.InsertBefore(std::move(newPhi));
     ++insertItr;
index 9374a91..e822885 100644 (file)
@@ -60,7 +60,7 @@ bool MergeReturnPass::MergeReturnBlocks(
 
   // Create a label for the new return block
   std::unique_ptr<ir::Instruction> returnLabel(
-      new ir::Instruction(SpvOpLabel, 0u, TakeNextId(), {}));
+      new ir::Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {}));
   uint32_t returnId = returnLabel->result_id();
 
   // Create the new basic block
@@ -84,13 +84,14 @@ bool MergeReturnPass::MergeReturnBlocks(
     // Need a PHI node to select the correct return value.
     uint32_t phiResultId = TakeNextId();
     uint32_t phiTypeId = function->type_id();
-    std::unique_ptr<ir::Instruction> phiInst(
-        new ir::Instruction(SpvOpPhi, phiTypeId, phiResultId, phiOps));
+    std::unique_ptr<ir::Instruction> phiInst(new ir::Instruction(
+        context(), SpvOpPhi, phiTypeId, phiResultId, phiOps));
     retBlockIter->AddInstruction(std::move(phiInst));
     ir::BasicBlock::iterator phiIter = retBlockIter->tail();
 
-    std::unique_ptr<ir::Instruction> returnInst(new ir::Instruction(
-        SpvOpReturnValue, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {phiResultId}}}));
+    std::unique_ptr<ir::Instruction> returnInst(
+        new ir::Instruction(context(), SpvOpReturnValue, 0u, 0u,
+                            {{SPV_OPERAND_TYPE_ID, {phiResultId}}}));
     retBlockIter->AddInstruction(std::move(returnInst));
     ir::BasicBlock::iterator ret = retBlockIter->tail();
 
@@ -98,7 +99,7 @@ bool MergeReturnPass::MergeReturnBlocks(
     get_def_use_mgr()->AnalyzeInstDef(&*ret);
   } else {
     std::unique_ptr<ir::Instruction> returnInst(
-        new ir::Instruction(SpvOpReturn));
+        new ir::Instruction(context(), SpvOpReturn));
     retBlockIter->AddInstruction(std::move(returnInst));
   }
 
index e329b3c..9d46a1b 100644 (file)
@@ -65,7 +65,7 @@ uint32_t Module::GetGlobalValue(SpvOp opcode) const {
 void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id,
                             uint32_t type_id) {
   std::unique_ptr<ir::Instruction> newGlobal(
-      new ir::Instruction(opcode, type_id, result_id, {}));
+      new ir::Instruction(context(), opcode, type_id, result_id, {}));
   AddGlobalValue(std::move(newGlobal));
 }
 
index e4c03e2..d3fe2b5 100644 (file)
@@ -27,6 +27,8 @@
 namespace spvtools {
 namespace ir {
 
+class IRContext;
+
 // A struct for containing the module header information.
 struct ModuleHeader {
   uint32_t magic_number;
@@ -223,11 +225,18 @@ class Module {
   // Returns 0 if not found.
   uint32_t GetExtInstImportId(const char* extstr);
 
+  // Sets the associated context for this module
+  void SetContext(IRContext* c) { context_ = c; }
+
+  // Gets the associated context for this module
+  IRContext* context() const { return context_; }
+
  private:
   ModuleHeader header_;  // Module header
 
   // The following fields respect the "Logical Layout of a Module" in
   // Section 2.4 of the SPIR-V specification.
+  IRContext* context_;
   InstructionList capabilities_;
   InstructionList extensions_;
   InstructionList ext_inst_imports_;
index 3527c1f..ac913dd 100644 (file)
@@ -105,19 +105,18 @@ Optimizer& Optimizer::RegisterSizePasses() {
 bool Optimizer::Run(const uint32_t* original_binary,
                     const size_t original_binary_size,
                     std::vector<uint32_t>* optimized_binary) const {
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(impl_->target_env, impl_->pass_manager.consumer(),
                   original_binary, original_binary_size);
-  if (module == nullptr) return false;
-  ir::IRContext context(std::move(module), impl_->pass_manager.consumer());
+  if (context == nullptr) return false;
 
-  auto status = impl_->pass_manager.Run(&context);
+  auto status = impl_->pass_manager.Run(context.get());
   if (status == opt::Pass::Status::SuccessWithChange ||
       (status == opt::Pass::Status::SuccessWithoutChange &&
        (optimized_binary->data() != original_binary ||
         optimized_binary->size() != original_binary_size))) {
     optimized_binary->clear();
-    context.module()->ToBinary(optimized_binary, /* skip_nop = */ true);
+    context->module()->ToBinary(optimized_binary, /* skip_nop = */ true);
   }
 
   return status != opt::Pass::Status::Failure;
index 5c08f5e..f2aee91 100644 (file)
@@ -100,7 +100,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
                                  {shiftConstResultId});
         newOperands.push_back(shiftOperand);
         std::unique_ptr<ir::Instruction> newInstruction(
-            new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, inst->type_id(),
+            new ir::Instruction(context(), SpvOp::SpvOpShiftLeftLogical, inst->type_id(),
                                 newResultId, newOperands));
 
         // Insert the new instruction and update the data structures.
@@ -161,7 +161,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
     ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                          {val});
     std::unique_ptr<ir::Instruction> newConstant(new ir::Instruction(
-        SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant}));
+        context(), SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant}));
     get_module()->AddGlobalValue(std::move(newConstant));
 
     // Store the result id for next time.
@@ -199,7 +199,7 @@ uint32_t StrengthReductionPass::CreateUint32Type() {
   ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                           {0});
   std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
-      SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
+      context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
   context()->AddType(std::move(newType));
   return type_id;
 }
index 9768ab3..f2ced24 100644 (file)
@@ -41,3 +41,8 @@ add_spvtools_unittest(TARGET link_matching_imports_to_exports
   SRCS matching_imports_to_exports_test.cpp
   LIBS SPIRV-Tools-opt SPIRV-Tools-link
 )
+
+add_spvtools_unittest(TARGET link_unique_ids
+  SRCS unique_ids_test.cpp
+  LIBS SPIRV-Tools-opt SPIRV-Tools-link
+)
diff --git a/test/link/unique_ids_test.cpp b/test/link/unique_ids_test.cpp
new file mode 100644 (file)
index 0000000..8b67d34
--- /dev/null
@@ -0,0 +1,137 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "linker_fixture.h"
+
+namespace {
+
+using UniqueIds = spvtest::LinkerTest;
+
+TEST_F(UniqueIds, UniquelyMerged) {
+  std::vector<std::string> bodies(2);
+  bodies[0] =
+      // clang-format off
+               "OpCapability Shader\n"
+          "%1 = OpExtInstImport \"GLSL.std.450\"\n"
+               "OpMemoryModel Logical GLSL450\n"
+               "OpEntryPoint Vertex %main \"main\"\n"
+               "OpSource ESSL 310\n"
+               "OpName %main \"main\"\n"
+               "OpName %f_ \"f(\"\n"
+               "OpName %gv1 \"gv1\"\n"
+               "OpName %gv2 \"gv2\"\n"
+               "OpName %lv1 \"lv1\"\n"
+               "OpName %lv2 \"lv2\"\n"
+               "OpName %lv1_0 \"lv1\"\n"
+       "%void = OpTypeVoid\n"
+         "%10 = OpTypeFunction %void\n"
+      "%float = OpTypeFloat 32\n"
+         "%12 = OpTypeFunction %float\n"
+ "%_ptr_Private_float = OpTypePointer Private %float\n"
+        "%gv1 = OpVariable %_ptr_Private_float Private\n"
+   "%float_10 = OpConstant %float 10\n"
+        "%gv2 = OpVariable %_ptr_Private_float Private\n"
+  "%float_100 = OpConstant %float 100\n"
+ "%_ptr_Function_float = OpTypePointer Function %float\n"
+       "%main = OpFunction %void None %10\n"
+         "%17 = OpLabel\n"
+      "%lv1_0 = OpVariable %_ptr_Function_float Function\n"
+               "OpStore %gv1 %float_10\n"
+               "OpStore %gv2 %float_100\n"
+         "%18 = OpLoad %float %gv1\n"
+         "%19 = OpLoad %float %gv2\n"
+         "%20 = OpFSub %float %18 %19\n"
+               "OpStore %lv1_0 %20\n"
+               "OpReturn\n"
+               "OpFunctionEnd\n"
+         "%f_ = OpFunction %float None %12\n"
+         "%21 = OpLabel\n"
+        "%lv1 = OpVariable %_ptr_Function_float Function\n"
+        "%lv2 = OpVariable %_ptr_Function_float Function\n"
+         "%22 = OpLoad %float %gv1\n"
+         "%23 = OpLoad %float %gv2\n"
+         "%24 = OpFAdd %float %22 %23\n"
+               "OpStore %lv1 %24\n"
+         "%25 = OpLoad %float %gv1\n"
+         "%26 = OpLoad %float %gv2\n"
+         "%27 = OpFMul %float %25 %26\n"
+               "OpStore %lv2 %27\n"
+         "%28 = OpLoad %float %lv1\n"
+         "%29 = OpLoad %float %lv2\n"
+         "%30 = OpFDiv %float %28 %29\n"
+               "OpReturnValue %30\n"
+               "OpFunctionEnd\n";
+  // clang-format on
+  bodies[1] =
+      // clang-format off
+               "OpCapability Shader\n"
+          "%1 = OpExtInstImport \"GLSL.std.450\"\n"
+               "OpMemoryModel Logical GLSL450\n"
+               "OpSource ESSL 310\n"
+               "OpName %main \"main2\"\n"
+               "OpName %f_ \"f(\"\n"
+               "OpName %gv1 \"gv12\"\n"
+               "OpName %gv2 \"gv22\"\n"
+               "OpName %lv1 \"lv12\"\n"
+               "OpName %lv2 \"lv22\"\n"
+               "OpName %lv1_0 \"lv12\"\n"
+       "%void = OpTypeVoid\n"
+         "%10 = OpTypeFunction %void\n"
+      "%float = OpTypeFloat 32\n"
+         "%12 = OpTypeFunction %float\n"
+ "%_ptr_Private_float = OpTypePointer Private %float\n"
+        "%gv1 = OpVariable %_ptr_Private_float Private\n"
+   "%float_10 = OpConstant %float 10\n"
+        "%gv2 = OpVariable %_ptr_Private_float Private\n"
+  "%float_100 = OpConstant %float 100\n"
+ "%_ptr_Function_float = OpTypePointer Function %float\n"
+       "%main = OpFunction %void None %10\n"
+         "%17 = OpLabel\n"
+      "%lv1_0 = OpVariable %_ptr_Function_float Function\n"
+               "OpStore %gv1 %float_10\n"
+               "OpStore %gv2 %float_100\n"
+         "%18 = OpLoad %float %gv1\n"
+         "%19 = OpLoad %float %gv2\n"
+         "%20 = OpFSub %float %18 %19\n"
+               "OpStore %lv1_0 %20\n"
+               "OpReturn\n"
+               "OpFunctionEnd\n"
+         "%f_ = OpFunction %float None %12\n"
+         "%21 = OpLabel\n"
+        "%lv1 = OpVariable %_ptr_Function_float Function\n"
+        "%lv2 = OpVariable %_ptr_Function_float Function\n"
+         "%22 = OpLoad %float %gv1\n"
+         "%23 = OpLoad %float %gv2\n"
+         "%24 = OpFAdd %float %22 %23\n"
+               "OpStore %lv1 %24\n"
+         "%25 = OpLoad %float %gv1\n"
+         "%26 = OpLoad %float %gv2\n"
+         "%27 = OpFMul %float %25 %26\n"
+               "OpStore %lv2 %27\n"
+         "%28 = OpLoad %float %lv1\n"
+         "%29 = OpLoad %float %lv2\n"
+         "%30 = OpFDiv %float %28 %29\n"
+               "OpReturnValue %30\n"
+               "OpFunctionEnd\n";
+  // clang-format on
+
+  spvtest::Binary linked_binary;
+  spvtools::LinkerOptions options;
+  options.SetVerifyIds(true);
+  spv_result_t res = AssembleAndLink(bodies, &linked_binary, options);
+  EXPECT_EQ(SPV_SUCCESS, res);
+}
+
+} // anonymous namespace
index aa88978..bd1ac7e 100644 (file)
@@ -21,6 +21,7 @@
 #include "opt/build_module.h"
 #include "opt/def_use_manager.h"
 #include "opt/ir_context.h"
+#include "opt/module.h"
 #include "pass_utils.h"
 #include "spirv-tools/libspirv.hpp"
 
@@ -131,12 +132,12 @@ TEST_P(ParseDefUseTest, Case) {
 
   // Build module.
   const std::vector<const char*> text = {tc.text};
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text));
-  ASSERT_NE(nullptr, module);
+  ASSERT_NE(nullptr, context);
 
   // Analyze def and use.
-  opt::analysis::DefUseManager manager(module.get());
+  opt::analysis::DefUseManager manager(context->module());
 
   CheckDef(tc.du, manager.id_to_defs());
   CheckUse(tc.du, manager.id_to_uses());
@@ -512,23 +513,22 @@ TEST_P(ReplaceUseTest, Case) {
 
   // Build module.
   const std::vector<const char*> text = {tc.before};
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text));
-  ASSERT_NE(nullptr, module);
-  ir::IRContext context(std::move(module), spvtools::MessageConsumer());
+  ASSERT_NE(nullptr, context);
 
   // Force a re-build of def-use manager.
-  context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
-  (void)context.get_def_use_mgr();
+  context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
+  (void)context->get_def_use_mgr();
 
   // Do the substitution.
   for (const auto& candidate : tc.candidates) {
-    context.ReplaceAllUsesWith(candidate.first, candidate.second);
+    context->ReplaceAllUsesWith(candidate.first, candidate.second);
   }
 
-  EXPECT_EQ(tc.after, DisassembleModule(context.module()));
-  CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs());
-  CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses());
+  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
+  CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
+  CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses());
 }
 
 // clang-format off
@@ -816,20 +816,19 @@ TEST_P(KillDefTest, Case) {
 
   // Build module.
   const std::vector<const char*> text = {tc.before};
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text));
-  ASSERT_NE(nullptr, module);
-  ir::IRContext context(std::move(module), spvtools::MessageConsumer());
+  ASSERT_NE(nullptr, context);
 
   // Analyze def and use.
-  opt::analysis::DefUseManager manager(module.get());
+  opt::analysis::DefUseManager manager(context->module());
 
   // Do the substitution.
-  for (const auto id : tc.ids_to_kill) context.KillDef(id);
+  for (const auto id : tc.ids_to_kill) context->KillDef(id);
 
-  EXPECT_EQ(tc.after, DisassembleModule(context.module()));
-  CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs());
-  CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses());
+  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
+  CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
+  CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses());
 }
 
 // clang-format off
@@ -1067,19 +1066,18 @@ TEST(DefUseTest, OpSwitch) {
       "      OpReturnValue %6 "
       "      OpFunctionEnd";
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text);
-  ASSERT_NE(nullptr, module);
-  ir::IRContext context(std::move(module), spvtools::MessageConsumer());
+  ASSERT_NE(nullptr, context);
 
   // Force a re-build of def-use manager.
-  context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
-  (void)context.get_def_use_mgr();
+  context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
+  (void)context->get_def_use_mgr();
 
   // Do a bunch replacements.
-  context.ReplaceAllUsesWith(9, 900);    // to unused id
-  context.ReplaceAllUsesWith(10, 1000);  // to unused id
-  context.ReplaceAllUsesWith(11, 7);     // to existing id
+  context->ReplaceAllUsesWith(9, 900);    // to unused id
+  context->ReplaceAllUsesWith(10, 1000);  // to unused id
+  context->ReplaceAllUsesWith(11, 7);     // to existing id
 
   // clang-format off
   const char modified_text[] =
@@ -1103,7 +1101,7 @@ TEST(DefUseTest, OpSwitch) {
             "OpFunctionEnd";
   // clang-format on
 
-  EXPECT_EQ(modified_text, DisassembleModule(context.module()));
+  EXPECT_EQ(modified_text, DisassembleModule(context->module()));
 
   InstDefUse def_uses = {};
   def_uses.defs = {
@@ -1118,10 +1116,10 @@ TEST(DefUseTest, OpSwitch) {
       {10, "%10 = OpLabel"},
       {11, "%11 = OpLabel"},
   };
-  CheckDef(def_uses, context.get_def_use_mgr()->id_to_defs());
+  CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs());
 
   {
-    auto* use_list = context.get_def_use_mgr()->GetUses(6);
+    auto* use_list = context->get_def_use_mgr()->GetUses(6);
     ASSERT_NE(nullptr, use_list);
     EXPECT_EQ(2u, use_list->size());
     std::vector<SpvOp> opcodes = {use_list->front().inst->opcode(),
@@ -1129,7 +1127,7 @@ TEST(DefUseTest, OpSwitch) {
     EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue));
   }
   {
-    auto* use_list = context.get_def_use_mgr()->GetUses(7);
+    auto* use_list = context->get_def_use_mgr()->GetUses(7);
     ASSERT_NE(nullptr, use_list);
     EXPECT_EQ(6u, use_list->size());
     std::vector<SpvOp> opcodes;
@@ -1143,44 +1141,15 @@ TEST(DefUseTest, OpSwitch) {
   }
   // Check all ids only used by OpSwitch after replacement.
   for (const auto id : {8, 900, 1000}) {
-    auto* use_list = context.get_def_use_mgr()->GetUses(id);
+    auto* use_list = context->get_def_use_mgr()->GetUses(id);
     ASSERT_NE(nullptr, use_list);
     EXPECT_EQ(1u, use_list->size());
     EXPECT_EQ(SpvOpSwitch, use_list->front().inst->opcode());
   }
 }
 
-// Creates an |result_id| = OpTypeInt 32 1 instruction.
-ir::Instruction Int32TypeInstruction(uint32_t result_id) {
-  return ir::Instruction(SpvOp::SpvOpTypeInt, 0, result_id,
-                         {ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {32}),
-                          ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {1})});
-}
-
-// Creates an |result_id| = OpConstantTrue/Flase |type_id| instruction.
-ir::Instruction ConstantBoolInstruction(bool value, uint32_t type_id,
-                                        uint32_t result_id) {
-  return ir::Instruction(
-      value ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, type_id,
-      result_id, {});
-}
-
-// Creates an |result_id| = OpLabel instruction.
-ir::Instruction LabelInstruction(uint32_t result_id) {
-  return ir::Instruction(SpvOp::SpvOpLabel, 0, result_id, {});
-}
-
-// Creates an OpBranch |target_id| instruction.
-ir::Instruction BranchInstruction(uint32_t target_id) {
-  return ir::Instruction(SpvOp::SpvOpBranch, 0, 0,
-                         {
-                             ir::Operand(SPV_OPERAND_TYPE_ID, {target_id}),
-                         });
-}
-
 // Test case for analyzing individual instructions.
 struct AnalyzeInstDefUseTestCase {
-  std::vector<ir::Instruction> insts;  // instrutions to be analyzed in order.
   const char* module_text;
   InstDefUse expected_define_use;
 };
@@ -1193,15 +1162,12 @@ TEST_P(AnalyzeInstDefUseTest, Case) {
   auto tc = GetParam();
 
   // Build module.
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text);
-  ASSERT_NE(nullptr, module);
+  ASSERT_NE(nullptr, context);
 
   // Analyze the instructions.
-  opt::analysis::DefUseManager manager(module.get());
-  for (ir::Instruction& inst : tc.insts) {
-    manager.AnalyzeInstDefUse(&inst);
-  }
+  opt::analysis::DefUseManager manager(context->module());
 
   CheckDef(tc.expected_define_use, manager.id_to_defs());
   CheckUse(tc.expected_define_use, manager.id_to_uses());
@@ -1212,8 +1178,7 @@ INSTANTIATE_TEST_CASE_P(
     TestCase, AnalyzeInstDefUseTest,
     ::testing::ValuesIn(std::vector<AnalyzeInstDefUseTestCase>{
       { // A type declaring instruction.
-        {Int32TypeInstruction(1)},
-        "",
+        "%1 = OpTypeInt 32 1",
         {
           // defs
           {{1, "%1 = OpTypeInt 32 1"}},
@@ -1221,88 +1186,79 @@ INSTANTIATE_TEST_CASE_P(
         },
       },
       { // A type declaring instruction and a constant value.
-        {
-          Int32TypeInstruction(1),
-          ConstantBoolInstruction(true, 1, 2),
-        },
-        "",
-        {
-          { // defs
-            {1, "%1 = OpTypeInt 32 1"},
-            {2, "%2 = OpConstantTrue %1"}, // It is fine the SPIR-V code here is invalid.
-          },
-          { // uses
-            {1, {"%2 = OpConstantTrue %1"}},
-          },
-        },
-      },
-      { // Analyze two instrutions that have same result id. The def use info
-        // of the result id from the first instruction should be overwritten by
-        // the second instruction.
-        {
-          ConstantBoolInstruction(true, 1, 2),
-          // The def-use info of the following instruction should overwrite the
-          // records of the above one.
-          ConstantBoolInstruction(false, 3, 2),
-        },
-        "",
-        {
-          // defs
-          {{2, "%2 = OpConstantFalse %3"}},
-          // uses
-          {{3, {"%2 = OpConstantFalse %3"}}}
-        }
-      },
-      { // Analyze forward reference instruction, also instruction that does
-        // not have result id.
-        {
-          BranchInstruction(2),
-          LabelInstruction(2),
-        },
-        "",
-        {
-          // defs
-          {{2, "%2 = OpLabel"}},
-          // uses
-          {{2, {"OpBranch %2"}}},
-        }
-      },
-      { // Analyzing an additional instruction with new result id to an
-        // existing module.
-        {
-          ConstantBoolInstruction(true, 1, 2),
-        },
-        "%1 = OpTypeInt 32 1 ",
+        "%1 = OpTypeBool "
+        "%2 = OpConstantTrue %1",
         {
           { // defs
-            {1, "%1 = OpTypeInt 32 1"},
+            {1, "%1 = OpTypeBool"},
             {2, "%2 = OpConstantTrue %1"},
           },
           { // uses
             {1, {"%2 = OpConstantTrue %1"}},
           },
-        }
-      },
-      { // Analyzing an additional instruction with existing result id to an
-        // existing module.
-        {
-          ConstantBoolInstruction(true, 1, 2),
         },
-        "%1 = OpTypeInt 32 1 "
-        "%2 = OpTypeBool ",
-        {
-          { // defs
-            {1, "%1 = OpTypeInt 32 1"},
-            {2, "%2 = OpConstantTrue %1"},
-          },
-          { // uses
-            {1, {"%2 = OpConstantTrue %1"}},
-          },
-        }
       },
       }));
 // clang-format on
 
+using AnalyzeInstDefUse = ::testing::Test;
+
+TEST(AnalyzeInstDefUse, UseWithNoResultId) {
+  ir::IRContext context(nullptr);
+
+  // Analyze the instructions.
+  opt::analysis::DefUseManager manager(context.module());
+
+  ir::Instruction label(&context, SpvOpLabel, 0, 2, {});
+  manager.AnalyzeInstDefUse(&label);
+
+  ir::Instruction branch(&context, SpvOpBranch, 0, 0,
+                         {{SPV_OPERAND_TYPE_ID, {2}}});
+  manager.AnalyzeInstDefUse(&branch);
+
+  InstDefUse expected =
+  {
+    // defs
+    {
+      {2, "%2 = OpLabel"},
+    },
+    // uses
+    {{2, {"OpBranch %2"}}},
+  };
+
+  CheckDef(expected, manager.id_to_defs());
+  CheckUse(expected, manager.id_to_uses());
+}
+
+TEST(AnalyzeInstDefUse, AddNewInstruction) {
+  const std::string input = "%1 = OpTypeBool";
+
+  // Build module.
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
+  ASSERT_NE(nullptr, context);
+
+  // Analyze the instructions.
+  opt::analysis::DefUseManager manager(context->module());
+
+  ir::Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {});
+  manager.AnalyzeInstDefUse(&newInst);
+
+  InstDefUse expected =
+  {
+    { // defs
+      {1, "%1 = OpTypeBool"},
+      {2, "%2 = OpConstantTrue %1"},
+    },
+    { // uses
+      {1, {"%2 = OpConstantTrue %1"}},
+    },
+  };
+
+  CheckDef(expected, manager.id_to_defs());
+  CheckUse(expected, manager.id_to_uses());
+}
+
 struct KillInstTestCase {
   const char* before;
   std::unordered_set<uint32_t> indices_for_inst_to_kill;
@@ -1316,27 +1272,26 @@ TEST_P(KillInstTest, Case) {
   auto tc = GetParam();
 
   // Build module.
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before);
-  ASSERT_NE(nullptr, module);
-  ir::IRContext context(std::move(module), spvtools::MessageConsumer());
+  ASSERT_NE(nullptr, context);
 
   // Force a re-build of the def-use manager.
-  context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
-  (void)context.get_def_use_mgr();
+  context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse);
+  (void)context->get_def_use_mgr();
 
   // KillInst
   uint32_t index = 0;
-  context.module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) {
+  context->module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) {
     if (tc.indices_for_inst_to_kill.count(index) != 0) {
-      context.KillInst(inst);
+      context->KillInst(inst);
     }
     index++;
   });
 
-  EXPECT_EQ(tc.after, DisassembleModule(context.module()));
-  CheckDef(tc.expected_define_use, context.get_def_use_mgr()->id_to_defs());
-  CheckUse(tc.expected_define_use, context.get_def_use_mgr()->id_to_uses());
+  EXPECT_EQ(tc.after, DisassembleModule(context->module()));
+  CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs());
+  CheckUse(tc.expected_define_use, context->get_def_use_mgr()->id_to_uses());
 }
 
 // clang-format off
@@ -1428,12 +1383,12 @@ TEST_P(GetAnnotationsTest, Case) {
   const GetAnnotationsTestCase& tc = GetParam();
 
   // Build module.
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code);
-  ASSERT_NE(nullptr, module);
+  ASSERT_NE(nullptr, context);
 
   // Get annotations
-  opt::analysis::DefUseManager manager(module.get());
+  opt::analysis::DefUseManager manager(context->module());
   auto insts = manager.GetAnnotations(tc.id);
 
   // Check
index 2db4ed2..f930d46 100644 (file)
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "opt/instruction.h"
+#include "opt/ir_context.h"
 
 #include "gmock/gmock.h"
 
@@ -23,6 +24,7 @@ namespace {
 
 using spvtest::MakeInstruction;
 using spvtools::ir::Instruction;
+using spvtools::ir::IRContext;
 using spvtools::ir::Operand;
 using ::testing::Eq;
 
@@ -39,7 +41,8 @@ TEST(InstructionTest, CreateTrivial) {
 }
 
 TEST(InstructionTest, CreateWithOpcodeAndNoOperands) {
-  Instruction inst(SpvOpReturn);
+  IRContext context(nullptr);
+  Instruction inst(&context, SpvOpReturn);
   EXPECT_EQ(SpvOpReturn, inst.opcode());
   EXPECT_EQ(0u, inst.type_id());
   EXPECT_EQ(0u, inst.result_id());
@@ -119,7 +122,8 @@ spv_parsed_instruction_t kSampleControlBarrierInstruction = {
     3};
 
 TEST(InstructionTest, CreateWithOpcodeAndOperands) {
-  Instruction inst(kSampleParsedInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleParsedInstruction);
   EXPECT_EQ(SpvOpTypeInt, inst.opcode());
   EXPECT_EQ(0u, inst.type_id());
   EXPECT_EQ(44u, inst.result_id());
@@ -129,20 +133,23 @@ TEST(InstructionTest, CreateWithOpcodeAndOperands) {
 }
 
 TEST(InstructionTest, GetOperand) {
-  Instruction inst(kSampleParsedInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleParsedInstruction);
   EXPECT_THAT(inst.GetOperand(0).words, Eq(std::vector<uint32_t>{44}));
   EXPECT_THAT(inst.GetOperand(1).words, Eq(std::vector<uint32_t>{32}));
   EXPECT_THAT(inst.GetOperand(2).words, Eq(std::vector<uint32_t>{1}));
 }
 
 TEST(InstructionTest, GetInOperand) {
-  Instruction inst(kSampleParsedInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleParsedInstruction);
   EXPECT_THAT(inst.GetInOperand(0).words, Eq(std::vector<uint32_t>{32}));
   EXPECT_THAT(inst.GetInOperand(1).words, Eq(std::vector<uint32_t>{1}));
 }
 
 TEST(InstructionTest, OperandConstIterators) {
-  Instruction inst(kSampleParsedInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleParsedInstruction);
   // Spot check iteration across operands.
   auto cbegin = inst.cbegin();
   auto cend = inst.cend();
@@ -168,7 +175,8 @@ TEST(InstructionTest, OperandConstIterators) {
 }
 
 TEST(InstructionTest, OperandIterators) {
-  Instruction inst(kSampleParsedInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleParsedInstruction);
   // Spot check iteration across operands, with mutable iterators.
   auto begin = inst.begin();
   auto end = inst.end();
@@ -198,7 +206,8 @@ TEST(InstructionTest, OperandIterators) {
 }
 
 TEST(InstructionTest, ForInIdStandardIdTypes) {
-  Instruction inst(kSampleAccessChainInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleAccessChainInstruction);
 
   std::vector<uint32_t> ids;
   inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); });
@@ -210,7 +219,8 @@ TEST(InstructionTest, ForInIdStandardIdTypes) {
 }
 
 TEST(InstructionTest, ForInIdNonstandardIdTypes) {
-  Instruction inst(kSampleControlBarrierInstruction);
+  IRContext context(nullptr);
+  Instruction inst(&context, kSampleControlBarrierInstruction);
 
   std::vector<uint32_t> ids;
   inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); });
@@ -221,4 +231,60 @@ TEST(InstructionTest, ForInIdNonstandardIdTypes) {
   EXPECT_THAT(ids, Eq(std::vector<uint32_t>{100, 101, 102}));
 }
 
+TEST(InstructionTest, UniqueIds) {
+  IRContext context(nullptr);
+  Instruction inst1(&context);
+  Instruction inst2(&context);
+  EXPECT_NE(inst1.unique_id(), inst2.unique_id());
+}
+
+TEST(InstructionTest, CloneUniqueIdDifferent) {
+  IRContext context(nullptr);
+  Instruction inst(&context);
+  std::unique_ptr<Instruction> clone(inst.Clone(&context));
+  EXPECT_EQ(inst.context(), clone->context());
+  EXPECT_NE(inst.unique_id(), clone->unique_id());
+}
+
+TEST(InstructionTest, CloneDifferentContext) {
+  IRContext c1(nullptr);
+  IRContext c2(nullptr);
+  Instruction inst(&c1);
+  std::unique_ptr<Instruction> clone(inst.Clone(&c2));
+  EXPECT_EQ(&c1, inst.context());
+  EXPECT_EQ(&c2, clone->context());
+  EXPECT_NE(&c1, &c2);
+}
+
+TEST(InstructionTest, CloneDifferentContextDifferentUniqueId) {
+  IRContext c1(nullptr);
+  IRContext c2(nullptr);
+  Instruction inst(&c1);
+  Instruction other(&c2);
+  std::unique_ptr<Instruction> clone(inst.Clone(&c2));
+  EXPECT_EQ(&c2, clone->context());
+  EXPECT_NE(other.unique_id(), clone->unique_id());
+}
+
+TEST(InstructionTest, EqualsEqualsOperator) {
+  IRContext context(nullptr);
+  Instruction i1(&context);
+  Instruction i2(&context);
+  std::unique_ptr<Instruction> clone(i1.Clone(&context));
+  EXPECT_TRUE(i1 == i1);
+  EXPECT_FALSE(i1 == i2);
+  EXPECT_FALSE(i1 == *clone);
+  EXPECT_FALSE(i2 == *clone);
+}
+
+TEST(InstructionTest, LessThanOperator) {
+  IRContext context(nullptr);
+  Instruction i1(&context);
+  Instruction i2(&context);
+  std::unique_ptr<Instruction> clone(i1.Clone(&context));
+  EXPECT_TRUE(i1 < i2);
+  EXPECT_TRUE(i1 < *clone);
+  EXPECT_TRUE(i2 < *clone);
+}
+
 }  // anonymous namespace
index e2ace77..770c306 100644 (file)
@@ -62,97 +62,97 @@ using IRContextTest = PassTest<::testing::Test>;
 
 TEST_F(IRContextTest, IndividualValidAfterBuild) {
   std::unique_ptr<ir::Module> module(new ir::Module());
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
-    EXPECT_TRUE(context.AreAnalysesValid(i));
+    localContext.BuildInvalidAnalyses(i);
+    EXPECT_TRUE(localContext.AreAnalysesValid(i));
   }
 }
 
 TEST_F(IRContextTest, AllValidAfterBuild) {
   std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>();
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   Analysis built_analyses = IRContext::kAnalysisNone;
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
+    localContext.BuildInvalidAnalyses(i);
     built_analyses |= i;
   }
-  EXPECT_TRUE(context.AreAnalysesValid(built_analyses));
+  EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses));
 }
 
 TEST_F(IRContextTest, AllValidAfterPassNoChange) {
   std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>();
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   Analysis built_analyses = IRContext::kAnalysisNone;
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
+    localContext.BuildInvalidAnalyses(i);
     built_analyses |= i;
   }
 
   DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithoutChange);
-  opt::Pass::Status s = pass.Run(&context);
+  opt::Pass::Status s = pass.Run(&localContext);
   EXPECT_EQ(s, opt::Pass::Status::SuccessWithoutChange);
-  EXPECT_TRUE(context.AreAnalysesValid(built_analyses));
+  EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses));
 }
 
 TEST_F(IRContextTest, NoneValidAfterPassWithChange) {
   std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>();
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
+    localContext.BuildInvalidAnalyses(i);
   }
 
   DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithChange);
-  opt::Pass::Status s = pass.Run(&context);
+  opt::Pass::Status s = pass.Run(&localContext);
   EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange);
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    EXPECT_FALSE(context.AreAnalysesValid(i));
+    EXPECT_FALSE(localContext.AreAnalysesValid(i));
   }
 }
 
 TEST_F(IRContextTest, AllPreservedAfterPassWithChange) {
   std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>();
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
+    localContext.BuildInvalidAnalyses(i);
   }
 
   DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange);
-  opt::Pass::Status s = pass.Run(&context);
+  opt::Pass::Status s = pass.Run(&localContext);
   EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange);
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    EXPECT_TRUE(context.AreAnalysesValid(i));
+    EXPECT_TRUE(localContext.AreAnalysesValid(i));
   }
 }
 
 TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) {
   std::unique_ptr<ir::Module> module = MakeUnique<ir::Module>();
-  IRContext context(std::move(module), spvtools::MessageConsumer());
+  IRContext localContext(std::move(module), spvtools::MessageConsumer());
 
   for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    context.BuildInvalidAnalyses(i);
+    localContext.BuildInvalidAnalyses(i);
   }
 
   DummyPassPreservesFirst pass(opt::Pass::Status::SuccessWithChange);
-  opt::Pass::Status s = pass.Run(&context);
+  opt::Pass::Status s = pass.Run(&localContext);
   EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange);
-  EXPECT_TRUE(context.AreAnalysesValid(IRContext::kAnalysisBegin));
+  EXPECT_TRUE(localContext.AreAnalysesValid(IRContext::kAnalysisBegin));
   for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd;
        i <<= 1) {
-    EXPECT_FALSE(context.AreAnalysesValid(i));
+    EXPECT_FALSE(localContext.AreAnalysesValid(i));
   }
 }
 
@@ -178,25 +178,31 @@ TEST_F(IRContextTest, KillMemberName) {
                OpFunctionEnd
 )";
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
-  ir::IRContext context(std::move(module), spvtools::MessageConsumer());
 
   // Build the decoration manager.
-  context.get_decoration_mgr();
+  context->get_decoration_mgr();
 
   // Delete the OpTypeStruct.  Should delete the OpName, OpMemberName, and
   // OpMemberDecorate associated with it.
-  context.KillDef(3);
+  context->KillDef(3);
 
   // Make sure all of the name are removed.
-  for (auto& inst : context.debugs2()) {
+  for (auto& inst : context->debugs2()) {
     EXPECT_EQ(inst.opcode(), SpvOpNop);
   }
 
   // Make sure all of the decorations are removed.
-  for (auto& inst : context.annotations()) {
+  for (auto& inst : context->annotations()) {
     EXPECT_EQ(inst.opcode(), SpvOpNop);
   }
 }
+
+TEST_F(IRContextTest, TakeNextUniqueIdIncrementing) {
+  const uint32_t NUM_TESTS = 1000;
+  IRContext localContext(nullptr);
+  for (uint32_t i = 1; i < NUM_TESTS; ++i)
+    EXPECT_EQ(i, localContext.TakeNextUniqueId());
+}
 }  // anonymous namespace
index b61f7cb..ae46df9 100644 (file)
 
 #include <gtest/gtest.h>
 #include <algorithm>
+#include <unordered_set>
 
 #include "message.h"
 #include "opt/build_module.h"
+#include "opt/ir_context.h"
 #include "spirv-tools/libspirv.hpp"
 
 namespace {
@@ -25,12 +27,12 @@ using namespace spvtools;
 
 void DoRoundTripCheck(const std::string& text) {
   SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  ASSERT_NE(nullptr, module) << "Failed to assemble\n" << text;
+  ASSERT_NE(nullptr, context) << "Failed to assemble\n" << text;
 
   std::vector<uint32_t> binary;
-  module->ToBinary(&binary, /* skip_nop = */ false);
+  context->module()->ToBinary(&binary, /* skip_nop = */ false);
 
   std::string disassembled_text;
   EXPECT_TRUE(t.Disassemble(binary, &disassembled_text));
@@ -212,17 +214,17 @@ TEST(IrBuilder, OpUndefOutsideFunction) {
   // clang-format on
 
   SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  ASSERT_NE(nullptr, module);
+  ASSERT_NE(nullptr, context);
 
   const auto opundef_count = std::count_if(
-      module->types_values_begin(), module->types_values_end(),
+      context->module()->types_values_begin(), context->module()->types_values_end(),
       [](const ir::Instruction& inst) { return inst.opcode() == SpvOpUndef; });
   EXPECT_EQ(3, opundef_count);
 
   std::vector<uint32_t> binary;
-  module->ToBinary(&binary, /* skip_nop = */ false);
+  context->module()->ToBinary(&binary, /* skip_nop = */ false);
 
   std::string disassembled_text;
   EXPECT_TRUE(t.Disassemble(binary, &disassembled_text));
@@ -322,9 +324,9 @@ void DoErrorMessageCheck(const std::string& assembly,
   };
 
   SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, std::move(consumer), assembly);
-  EXPECT_EQ(nullptr, module);
+  EXPECT_EQ(nullptr, context);
 }
 
 TEST(IrBuilder, FunctionInsideFunction) {
@@ -378,4 +380,69 @@ TEST(IrBuilder, NotAllowedInstAppearingInFunction) {
                       "block");
 }
 
+TEST(IrBuilder, UniqueIds) {
+  const std::string text =
+      // clang-format off
+               "OpCapability Shader\n"
+          "%1 = OpExtInstImport \"GLSL.std.450\"\n"
+               "OpMemoryModel Logical GLSL450\n"
+               "OpEntryPoint Vertex %main \"main\"\n"
+               "OpSource ESSL 310\n"
+               "OpName %main \"main\"\n"
+               "OpName %f_ \"f(\"\n"
+               "OpName %gv1 \"gv1\"\n"
+               "OpName %gv2 \"gv2\"\n"
+               "OpName %lv1 \"lv1\"\n"
+               "OpName %lv2 \"lv2\"\n"
+               "OpName %lv1_0 \"lv1\"\n"
+       "%void = OpTypeVoid\n"
+         "%10 = OpTypeFunction %void\n"
+      "%float = OpTypeFloat 32\n"
+         "%12 = OpTypeFunction %float\n"
+ "%_ptr_Private_float = OpTypePointer Private %float\n"
+        "%gv1 = OpVariable %_ptr_Private_float Private\n"
+   "%float_10 = OpConstant %float 10\n"
+        "%gv2 = OpVariable %_ptr_Private_float Private\n"
+  "%float_100 = OpConstant %float 100\n"
+ "%_ptr_Function_float = OpTypePointer Function %float\n"
+       "%main = OpFunction %void None %10\n"
+         "%17 = OpLabel\n"
+      "%lv1_0 = OpVariable %_ptr_Function_float Function\n"
+               "OpStore %gv1 %float_10\n"
+               "OpStore %gv2 %float_100\n"
+         "%18 = OpLoad %float %gv1\n"
+         "%19 = OpLoad %float %gv2\n"
+         "%20 = OpFSub %float %18 %19\n"
+               "OpStore %lv1_0 %20\n"
+               "OpReturn\n"
+               "OpFunctionEnd\n"
+         "%f_ = OpFunction %float None %12\n"
+         "%21 = OpLabel\n"
+        "%lv1 = OpVariable %_ptr_Function_float Function\n"
+        "%lv2 = OpVariable %_ptr_Function_float Function\n"
+         "%22 = OpLoad %float %gv1\n"
+         "%23 = OpLoad %float %gv2\n"
+         "%24 = OpFAdd %float %22 %23\n"
+               "OpStore %lv1 %24\n"
+         "%25 = OpLoad %float %gv1\n"
+         "%26 = OpLoad %float %gv2\n"
+         "%27 = OpFMul %float %25 %26\n"
+               "OpStore %lv2 %27\n"
+         "%28 = OpLoad %float %lv1\n"
+         "%29 = OpLoad %float %lv2\n"
+         "%30 = OpFDiv %float %28 %29\n"
+               "OpReturnValue %30\n"
+               "OpFunctionEnd\n";
+  // clang-format on
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
+  ASSERT_NE(nullptr, context);
+
+  std::unordered_set<uint32_t> ids;
+  context->module()->ForEachInst([&ids](const ir::Instruction* inst) {
+    EXPECT_TRUE(ids.insert(inst->unique_id()).second);
+  });
+}
+
 }  // anonymous namespace
index 622d920..4a434ed 100644 (file)
@@ -26,6 +26,7 @@
 
 namespace {
 
+using spvtools::ir::IRContext;
 using spvtools::ir::Module;
 using spvtest::GetIdBound;
 using ::testing::Eq;
@@ -42,31 +43,31 @@ TEST(ModuleTest, SetIdBound) {
   EXPECT_EQ(102u, GetIdBound(m));
 }
 
-// Returns a module formed by assembling the given text,
+// Returns an IRContext owning the module formed by assembling the given text,
 // then loading the result.
-inline std::unique_ptr<Module> BuildModule(std::string text) {
+inline std::unique_ptr<IRContext> BuildModule(std::string text) {
   return spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
 }
 
 TEST(ModuleTest, ComputeIdBound) {
   // Emtpy module case.
-  EXPECT_EQ(1u, BuildModule("")->ComputeIdBound());
+  EXPECT_EQ(1u, BuildModule("")->module()->ComputeIdBound());
   // Sensitive to result id
-  EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->ComputeIdBound());
+  EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->module()->ComputeIdBound());
   // Sensitive to type id
-  EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->ComputeIdBound());
+  EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->module()->ComputeIdBound());
   // Sensitive to a regular Id parameter
-  EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->ComputeIdBound());
+  EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->module()->ComputeIdBound());
   // Sensitive to a scope Id parameter.
   EXPECT_EQ(3000u,
             BuildModule("%f = OpFunction %void None %fntype %a = OpLabel "
                         "OpMemoryBarrier !2999 %b\n")
-                ->ComputeIdBound());
+                ->module()->ComputeIdBound());
   // Sensitive to a semantics Id parameter
   EXPECT_EQ(4000u,
             BuildModule("%f = OpFunction %void None %fntype %a = OpLabel "
                         "OpMemoryBarrier %b !3999\n")
-                ->ComputeIdBound());
+                ->module()->ComputeIdBound());
 }
 
 }  // anonymous namespace
index 7ad7181..fdc4398 100644 (file)
@@ -46,36 +46,35 @@ class PassTest : public TestT {
  public:
   PassTest()
       : consumer_(nullptr),
+        context_(nullptr),
         tools_(SPV_ENV_UNIVERSAL_1_1),
         manager_(new opt::PassManager()),
         assemble_options_(SpirvTools::kDefaultAssembleOption),
         disassemble_options_(SpirvTools::kDefaultDisassembleOption) {}
 
   // Runs the given |pass| on the binary assembled from the |original|.
-  // Returns a tuple of the optimized binary and the boolean value returned 
+  // Returns a tuple of the optimized binary and the boolean value returned
   // from pass Process() function.
   std::tuple<std::vector<uint32_t>, opt::Pass::Status> OptimizeToBinary(
       opt::Pass* pass, const std::string& original, bool skip_nop) {
-    std::unique_ptr<ir::Module> module = BuildModule(
-        SPV_ENV_UNIVERSAL_1_1, consumer_, original, assemble_options_);
-    EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                               << original << std::endl;
-    if (!module) {
-      return std::make_tuple(std::vector<uint32_t>(), 
+    context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original,
+                                     assemble_options_));
+    EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n"
+                                  << original << std::endl;
+    if (!context()) {
+      return std::make_tuple(std::vector<uint32_t>(),
           opt::Pass::Status::Failure);
     }
 
-    ir::IRContext context(std::move(module), consumer());
-
-    const auto status = pass->Run(&context);
+    const auto status = pass->Run(context());
 
     std::vector<uint32_t> binary;
-    context.module()->ToBinary(&binary, skip_nop);
+    context()->module()->ToBinary(&binary, skip_nop);
     return std::make_tuple(binary, status);
   }
 
   // Runs a single pass of class |PassT| on the binary assembled from the
-  // |assembly|. Returns a tuple of the optimized binary and the boolean value 
+  // |assembly|. Returns a tuple of the optimized binary and the boolean value
   // from the pass Process() function.
   template <typename PassT, typename... Args>
   std::tuple<std::vector<uint32_t>, opt::Pass::Status> SinglePassRunToBinary(
@@ -106,7 +105,7 @@ class PassTest : public TestT {
   // Runs a single pass of class |PassT| on the binary assembled from the
   // |original| assembly, and checks whether the optimized binary can be
   // disassembled to the |expected| assembly. Optionally will also validate
-  // the optimized binary. This does *not* involve pass manager. Callers 
+  // the optimized binary. This does *not* involve pass manager. Callers
   // are suggested to use SCOPED_TRACE() for better messages.
   template <typename PassT, typename... Args>
   void SinglePassRunAndCheck(const std::string& original,
@@ -122,16 +121,16 @@ class PassTest : public TestT {
               status == opt::Pass::Status::SuccessWithoutChange);
     if (do_validation) {
       spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1;
-      spv_context context = spvContextCreate(target_env);
+      spv_context spvContext = spvContextCreate(target_env);
       spv_diagnostic diagnostic = nullptr;
       spv_const_binary_t binary = {optimized_bin.data(),
           optimized_bin.size()};
-      spv_result_t error = spvValidate(context, &binary, &diagnostic);
+      spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
       EXPECT_EQ(error, 0);
       if (error != 0)
         spvDiagnosticPrint(diagnostic);
       spvDiagnosticDestroy(diagnostic);
-      spvContextDestroy(context);
+      spvContextDestroy(spvContext);
     }
     std::string optimized_asm;
     EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm,
@@ -191,15 +190,14 @@ class PassTest : public TestT {
   void RunAndCheck(const std::string& original, const std::string& expected) {
     assert(manager_->NumPasses());
 
-    std::unique_ptr<ir::Module> module = BuildModule(
-        SPV_ENV_UNIVERSAL_1_1, nullptr, original, assemble_options_);
-    ASSERT_NE(nullptr, module);
-    ir::IRContext context(std::move(module), consumer());
+    context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original,
+                                     assemble_options_));
+    ASSERT_NE(nullptr, context());
 
-    manager_->Run(&context);
+    manager_->Run(context());
 
     std::vector<uint32_t> binary;
-    context.module()->ToBinary(&binary, /* skip_nop = */ false);
+    context()->module()->ToBinary(&binary, /* skip_nop = */ false);
 
     std::string optimized;
     EXPECT_TRUE(tools_.Disassemble(binary, &optimized,
@@ -216,8 +214,10 @@ class PassTest : public TestT {
   }
 
   MessageConsumer consumer() { return consumer_;}
+  ir::IRContext* context() { return context_.get(); }
  private:
   MessageConsumer consumer_;  // Message consumer.
+  std::unique_ptr<ir::IRContext> context_; // IR context
   SpirvTools tools_;  // An instance for calling SPIRV-Tools functionalities.
   std::unique_ptr<opt::PassManager> manager_;  // The pass manager.
   uint32_t assemble_options_;
index 43d7005..77ed38b 100644 (file)
@@ -75,7 +75,7 @@ class AppendOpNopPass : public opt::Pass {
  public:
   const char* name() const override { return "AppendOpNop"; }
   Status Process(ir::IRContext* irContext) override {
-    irContext->AddDebug1Inst(MakeUnique<ir::Instruction>());
+    irContext->AddDebug1Inst(MakeUnique<ir::Instruction>(irContext));
     return Status::SuccessWithChange;
   }
 };
@@ -89,7 +89,7 @@ class AppendMultipleOpNopPass : public opt::Pass {
   const char* name() const override { return "AppendOpNop"; }
   Status Process(ir::IRContext* irContext) override {
     for (uint32_t i = 0; i < num_nop_; i++) {
-      irContext->AddDebug1Inst(MakeUnique<ir::Instruction>());
+      irContext->AddDebug1Inst(MakeUnique<ir::Instruction>(irContext));
     }
     return Status::SuccessWithChange;
   }
@@ -103,7 +103,8 @@ class DuplicateInstPass : public opt::Pass {
  public:
   const char* name() const override { return "DuplicateInst"; }
   Status Process(ir::IRContext* irContext) override {
-    auto inst = MakeUnique<ir::Instruction>(*(--irContext->debug1_end()));
+    auto inst = MakeUnique<ir::Instruction>(
+        *(--irContext->debug1_end())->Clone(irContext));
     irContext->AddDebug1Inst(std::move(inst));
     return Status::SuccessWithChange;
   }
@@ -140,7 +141,7 @@ class AppendTypeVoidInstPass : public opt::Pass {
 
   const char* name() const override { return "AppendTypeVoidInstPass"; }
   Status Process(ir::IRContext* irContext) override {
-    auto inst = MakeUnique<ir::Instruction>(SpvOpTypeVoid, 0, result_id_,
+    auto inst = MakeUnique<ir::Instruction>(irContext, SpvOpTypeVoid, 0, result_id_,
                                             std::vector<ir::Operand>{});
     irContext->AddType(std::move(inst));
     return Status::SuccessWithChange;
index 8c62b28..5ff1a12 100644 (file)
@@ -76,18 +76,18 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
 )";
   // clang-format on
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> localContext =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
+  EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+                                   << text << std::endl;
   DummyPass testPass;
   std::vector<uint32_t> processed;
   opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) {
     processed.push_back(fp->result_id());
     return false;
   };
-  testPass.ProcessEntryPointCallTree(mark_visited, module.get());
+  testPass.ProcessEntryPointCallTree(mark_visited, localContext->module());
   EXPECT_THAT(processed, UnorderedElementsAre(10, 11));
 }
 
@@ -132,12 +132,11 @@ TEST_F(PassClassTest, BasicVisitReachable) {
 )";
   // clang-format on
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> localContext =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
-  ir::IRContext context(std::move(module), consumer());
+  EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+                                   << text << std::endl;
 
   DummyPass testPass;
   std::vector<uint32_t> processed;
@@ -145,7 +144,7 @@ TEST_F(PassClassTest, BasicVisitReachable) {
     processed.push_back(fp->result_id());
     return false;
   };
-  testPass.ProcessReachableCallTree(mark_visited, &context);
+  testPass.ProcessReachableCallTree(mark_visited, localContext.get());
   EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13));
 }
 
@@ -185,12 +184,11 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) {
 )";
   // clang-format on
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> localContext =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
-  ir::IRContext context(std::move(module), consumer());
+  EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+                                   << text << std::endl;
 
   DummyPass testPass;
   std::vector<uint32_t> processed;
@@ -198,7 +196,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) {
     processed.push_back(fp->result_id());
     return false;
   };
-  testPass.ProcessReachableCallTree(mark_visited, &context);
+  testPass.ProcessReachableCallTree(mark_visited, localContext.get());
   EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12));
 }
 
@@ -228,12 +226,11 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) {
 )";
   // clang-format on
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> localContext =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
-                             << text << std::endl;
-  ir::IRContext context(std::move(module), consumer());
+  EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+                                   << text << std::endl;
 
   DummyPass testPass;
   std::vector<uint32_t> processed;
@@ -241,7 +238,7 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) {
     processed.push_back(fp->result_id());
     return false;
   };
-  testPass.ProcessReachableCallTree(mark_visited, &context);
+  testPass.ProcessReachableCallTree(mark_visited, localContext.get());
   EXPECT_THAT(processed, UnorderedElementsAre(10));
 }
 }  // namespace
index 17eb2a4..1c8f2db 100644 (file)
@@ -88,9 +88,9 @@ TEST(TypeManager, TypeStrings) {
       {28, "named_barrier"},
   };
 
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
 
   EXPECT_EQ(type_id_strs.size(), manager.NumTypes());
   EXPECT_EQ(2u, manager.NumForwardPointers());
@@ -118,9 +118,9 @@ TEST(TypeManager, DecorationOnStruct) {
     %struct4 = OpTypeStruct %u32 %f32 ; the same
     %struct7 = OpTypeStruct %f32      ; no decoration
   )";
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
 
   ASSERT_EQ(7u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -168,9 +168,9 @@ TEST(TypeManager, DecorationOnMember) {
     %struct7  = OpTypeStruct %u32 %f32 ; extra decoration on the struct
     %struct10 = OpTypeStruct %u32 %f32 ; no member decoration
   )";
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
 
   ASSERT_EQ(10u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -206,9 +206,9 @@ TEST(TypeManager, DecorationEmpty) {
     %struct2  = OpTypeStruct %f32 %u32
     %struct5  = OpTypeStruct %f32
   )";
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
 
   ASSERT_EQ(5u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
@@ -228,9 +228,9 @@ TEST(TypeManager, DecorationEmpty) {
 
 TEST(TypeManager, BeginEndForEmptyModule) {
   const std::string text = "";
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
   ASSERT_EQ(0u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());
 
@@ -245,9 +245,9 @@ TEST(TypeManager, BeginEnd) {
     %u32     = OpTypeInt 32 0
     %f64     = OpTypeFloat 64
   )";
-  std::unique_ptr<ir::Module> module =
+  std::unique_ptr<ir::IRContext> context =
       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
-  opt::analysis::TypeManager manager(nullptr, *module);
+  opt::analysis::TypeManager manager(nullptr, *context->module());
   ASSERT_EQ(5u, manager.NumTypes());
   ASSERT_EQ(0u, manager.NumForwardPointers());