From a771713e4250358f2248eba95a16e362c8524b41 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Tue, 14 Nov 2017 14:11:50 -0500 Subject: [PATCH] Adding an unique id to Instruction generated by IRContext Each instruction is given an unique id that can be used for ordering purposes. The ids are generated via the IRContext. Major changes: * Instructions now contain a uint32_t for unique id and a cached context pointer * Most constructors have been modified to take a context as input * unfortunately I cannot remove the default and copy constructors, but developers should avoid these * Added accessors to parents of basic block and function * Removed the copy constructors for BasicBlock and Function and replaced them with Clone functions * Reworked BuildModule to return an IRContext owning the built module * Since all instructions require a context, the context now becomes the basic unit for IR * Added a constructor to context to create an owned module internally * Replaced uses of Instruction's copy constructor with Clone whereever I found them * Reworked the linker functionality to perform clones into a different context instead of moves * Updated many tests to be consistent with the above changes * Still need to add new tests to cover added functionality * Added comparison operators to Instruction * Added an internal option to LinkerOptions to verify merged ids are unique * Added a test for the linker to verify merged ids are unique * Updated MergeReturnPass to supply a context * Updated DecorationManager to supply a context for cloned decorations * Reworked several portions of the def use tests in anticipation of next set of changes --- include/spirv-tools/linker.hpp | 16 +- source/link/linker.cpp | 113 ++++++--- source/opt/aggressive_dead_code_elim_pass.cpp | 2 +- source/opt/basic_block.cpp | 15 +- source/opt/basic_block.h | 10 +- source/opt/build_module.cpp | 17 +- source/opt/build_module.h | 20 +- source/opt/cfg.cpp | 4 +- source/opt/common_uniform_elim_pass.cpp | 8 +- source/opt/dead_branch_elim_pass.cpp | 8 +- source/opt/decoration_manager.cpp | 6 +- source/opt/flatten_decoration_pass.cpp | 6 +- .../fold_spec_constant_op_and_composite_pass.cpp | 10 +- source/opt/function.cpp | 30 ++- source/opt/function.h | 10 +- source/opt/inline_pass.cpp | 30 +-- source/opt/instruction.cpp | 34 ++- source/opt/instruction.h | 42 +++- source/opt/ir_context.h | 35 ++- source/opt/ir_loader.cpp | 8 +- source/opt/ir_loader.h | 4 +- source/opt/local_access_chain_convert_pass.cpp | 2 +- source/opt/mem_pass.cpp | 6 +- source/opt/merge_return_pass.cpp | 13 +- source/opt/module.cpp | 2 +- source/opt/module.h | 9 + source/opt/optimizer.cpp | 9 +- source/opt/strength_reduction_pass.cpp | 6 +- test/link/CMakeLists.txt | 5 + test/link/unique_ids_test.cpp | 137 +++++++++++ test/opt/def_use_test.cpp | 261 +++++++++------------ test/opt/instruction_test.cpp | 82 ++++++- test/opt/ir_context_test.cpp | 64 ++--- test/opt/ir_loader_test.cpp | 85 ++++++- test/opt/module_test.cpp | 17 +- test/opt/pass_fixture.h | 44 ++-- test/opt/pass_manager_test.cpp | 9 +- test/opt/pass_test.cpp | 35 ++- test/opt/type_manager_test.cpp | 24 +- 39 files changed, 818 insertions(+), 420 deletions(-) create mode 100644 test/link/unique_ids_test.cpp diff --git a/include/spirv-tools/linker.hpp b/include/spirv-tools/linker.hpp index 43c725d..a36aa75 100644 --- a/include/spirv-tools/linker.hpp +++ b/include/spirv-tools/linker.hpp @@ -26,7 +26,9 @@ namespace spvtools { class LinkerOptions { public: - LinkerOptions() : createLibrary_(false) {} + LinkerOptions() + : createLibrary_(false), + verifyIds_(false) {} // Returns whether a library or an executable should be produced by the // linking phase. @@ -36,13 +38,25 @@ class LinkerOptions { // The returned value will be true if creating a library, and false if // creating an executable. bool GetCreateLibrary() const { return createLibrary_; } + // Sets whether a library or an executable should be produced. void SetCreateLibrary(bool create_library) { createLibrary_ = create_library; } + // Returns whether to verify the uniqueness of the unique ids in the merged + // context. + bool GetVerifyIds() const { return verifyIds_; } + + // Sets whether to verify the uniqueness of the unique ids in the merged + // context. + void SetVerifyIds(bool verifyIds) { + verifyIds_ = verifyIds; + } + private: bool createLibrary_; + bool verifyIds_; }; class Linker { diff --git a/source/link/linker.cpp b/source/link/linker.cpp index 59ea36c..7f1b5cd 100644 --- a/source/link/linker.cpp +++ b/source/link/linker.cpp @@ -38,6 +38,7 @@ namespace spvtools { using ir::Instruction; +using ir::IRContext; using ir::Module; using ir::Operand; using opt::PassManager; @@ -69,31 +70,34 @@ using LinkageTable = std::vector; // is returned in |max_id_bound|. // // Both |modules| and |max_id_bound| should not be null, and |modules| should -// not be empty either. +// not be empty either. Furthermore |modules| should not contain any null +// pointers. static spv_result_t ShiftIdsInModules( const MessageConsumer& consumer, - std::vector>* modules, uint32_t* max_id_bound); + std::vector* modules, uint32_t* max_id_bound); // Generates the header for the linked module and returns it in |header|. // -// |header| should not be null, |modules| should not be empty and -// |max_id_bound| should be strictly greater than 0. +// |header| should not be null, |modules| should not be empty and pointers +// should be non-null. |max_id_bound| should be strictly greater than 0. // // TODO(pierremoreau): What to do when binaries use different versions of // SPIR-V? For now, use the max of all versions found in // the input modules. static spv_result_t GenerateHeader( const MessageConsumer& consumer, - const std::vector>& modules, + const std::vector& modules, uint32_t max_id_bound, ir::ModuleHeader* header); -// Merge all the modules from |inModules| into |linked_module|. +// Merge all the modules from |inModules| into a single module owned by +// |linked_context|. // -// |linked_module| should not be null. +// |linked_context| should not be null. static spv_result_t MergeModules( const MessageConsumer& consumer, - const std::vector>& inModules, - const libspirv::AssemblyGrammar& grammar, Module* linked_module); + const std::vector& inModules, + const libspirv::AssemblyGrammar& grammar, + IRContext* linked_context); // Compute all pairs of import and export and return it in |linkings_to_do|. // @@ -123,7 +127,7 @@ static spv_result_t CheckImportExportCompatibility( // functions, declarations of imported variables, import (and export if // necessary) linkage attribtes. // -// |linked_module| and |decoration_manager| should not be null, and the +// |linked_context| and |decoration_manager| should not be null, and the // 'RemoveDuplicatePass' should be run first. // // TODO(pierremoreau): Linkage attributes applied by a group decoration are @@ -136,6 +140,11 @@ static spv_result_t RemoveLinkageSpecificInstructions( const LinkageTable& linkings_to_do, DecorationManager* decoration_manager, ir::IRContext* linked_context); +// Verify that the unique ids of each instruction in |linked_context| (i.e. the +// merged module) are truly unique. Does not check the validity of other ids +static spv_result_t VerifyIds(const MessageConsumer& consumer, + ir::IRContext* linked_context); + // Structs for holding the data members for SpvLinker. struct Linker::Impl { explicit Impl(spv_target_env env) : context(spvContextCreate(env)) { @@ -186,7 +195,8 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, SPV_ERROR_INVALID_BINARY) << "No modules were given."; - std::vector> modules; + std::vector> contexts; + std::vector modules; modules.reserve(num_binaries); for (size_t i = 0u; i < num_binaries; ++i) { const uint32_t schema = binaries[i][4u]; @@ -197,13 +207,14 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, << "Schema is non-zero for module " << i << "."; } - std::unique_ptr module = BuildModule( + std::unique_ptr context = BuildModule( impl_->context->target_env, consumer, binaries[i], binary_sizes[i]); - if (module == nullptr) + if (context == nullptr) return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INVALID_BINARY) - << "Failed to build a module out of " << modules.size() << "."; - modules.push_back(std::move(module)); + << "Failed to build a module out of " << contexts.size() << "."; + modules.push_back(context->module()); + contexts.push_back(std::move(context)); } // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint @@ -216,14 +227,18 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, ir::ModuleHeader header; res = GenerateHeader(consumer, modules, max_id_bound, &header); if (res != SPV_SUCCESS) return res; - auto linked_module = MakeUnique(); - linked_module->SetHeader(header); + IRContext linked_context(consumer); + linked_context.module()->SetHeader(header); // Phase 3: Merge all the binaries into a single one. libspirv::AssemblyGrammar grammar(impl_->context); - res = MergeModules(consumer, modules, grammar, linked_module.get()); + res = MergeModules(consumer, modules, grammar, &linked_context); if (res != SPV_SUCCESS) return res; - ir::IRContext linked_context(std::move(linked_module), consumer); + + if (options.GetVerifyIds()) { + res = VerifyIds(consumer, &linked_context); + if (res != SPV_SUCCESS) return res; + } // Phase 4: Find the import/export pairs LinkageTable linkings_to_do; @@ -270,7 +285,7 @@ spv_result_t Linker::Link(const uint32_t* const* binaries, static spv_result_t ShiftIdsInModules( const MessageConsumer& consumer, - std::vector>* modules, uint32_t* max_id_bound) { + std::vector* modules, uint32_t* max_id_bound) { spv_position_t position = {}; if (modules == nullptr) @@ -289,7 +304,7 @@ static spv_result_t ShiftIdsInModules( uint32_t id_bound = modules->front()->IdBound() - 1u; for (auto module_iter = modules->begin() + 1; module_iter != modules->end(); ++module_iter) { - Module* module = module_iter->get(); + Module* module = *module_iter; module->ForEachInst([&id_bound](Instruction* insn) { insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; }); }); @@ -313,7 +328,7 @@ static spv_result_t ShiftIdsInModules( static spv_result_t GenerateHeader( const MessageConsumer& consumer, - const std::vector>& modules, + const std::vector& modules, uint32_t max_id_bound, ir::ModuleHeader* header) { spv_position_t position = {}; @@ -341,28 +356,32 @@ static spv_result_t GenerateHeader( static spv_result_t MergeModules( const MessageConsumer& consumer, - const std::vector>& input_modules, - const libspirv::AssemblyGrammar& grammar, Module* linked_module) { + const std::vector& input_modules, + const libspirv::AssemblyGrammar& grammar, IRContext* linked_context) { spv_position_t position = {}; - if (linked_module == nullptr) + if (linked_context == nullptr) return libspirv::DiagnosticStream(position, consumer, SPV_ERROR_INVALID_DATA) << "|linked_module| of MergeModules should not be null."; + Module* linked_module = linked_context->module(); if (input_modules.empty()) return SPV_SUCCESS; for (const auto& module : input_modules) for (const auto& inst : module->capabilities()) - linked_module->AddCapability(MakeUnique(inst)); + linked_module->AddCapability( + std::unique_ptr(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->extensions()) - linked_module->AddExtension(MakeUnique(inst)); + linked_module->AddExtension( + std::unique_ptr(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->ext_inst_imports()) - linked_module->AddExtInstImport(MakeUnique(inst)); + linked_module->AddExtInstImport( + std::unique_ptr(inst.Clone(linked_context))); do { const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel(); @@ -402,7 +421,7 @@ static spv_result_t MergeModules( if (memory_model_inst != nullptr) linked_module->SetMemoryModel( - MakeUnique(*memory_model_inst)); + std::unique_ptr(memory_model_inst->Clone(linked_context))); } while (false); std::vector> entry_points; @@ -424,25 +443,30 @@ static spv_result_t MergeModules( << "The entry point \"" << name << "\", with execution model " << desc->name << ", was already defined."; } - linked_module->AddEntryPoint(MakeUnique(inst)); + linked_module->AddEntryPoint( + std::unique_ptr(inst.Clone(linked_context))); entry_points.emplace_back(model, name); } for (const auto& module : input_modules) for (const auto& inst : module->execution_modes()) - linked_module->AddExecutionMode(MakeUnique(inst)); + linked_module->AddExecutionMode( + std::unique_ptr(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->debugs1()) - linked_module->AddDebug1Inst(MakeUnique(inst)); + linked_module->AddDebug1Inst( + std::unique_ptr(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->debugs2()) - linked_module->AddDebug2Inst(MakeUnique(inst)); + linked_module->AddDebug2Inst( + std::unique_ptr(inst.Clone(linked_context))); for (const auto& module : input_modules) for (const auto& inst : module->annotations()) - linked_module->AddAnnotationInst(MakeUnique(inst)); + linked_module->AddAnnotationInst( + std::unique_ptr(inst.Clone(linked_context))); // TODO(pierremoreau): Since the modules have not been validate, should we // expect SpvStorageClassFunction variables outside @@ -450,7 +474,8 @@ static spv_result_t MergeModules( uint32_t num_global_values = 0u; for (const auto& module : input_modules) { for (const auto& inst : module->types_values()) { - linked_module->AddType(MakeUnique(inst)); + linked_module->AddType( + std::unique_ptr(inst.Clone(linked_context))); num_global_values += inst.opcode() == SpvOpVariable; } } @@ -462,8 +487,7 @@ static spv_result_t MergeModules( // Process functions and their basic blocks for (const auto& module : input_modules) { for (const auto& func : *module) { - std::unique_ptr cloned_func = - MakeUnique(func); + std::unique_ptr cloned_func(func.Clone(linked_context)); cloned_func->SetParent(linked_module); linked_module->AddFunction(std::move(cloned_func)); } @@ -711,4 +735,19 @@ static spv_result_t RemoveLinkageSpecificInstructions( return SPV_SUCCESS; } +spv_result_t VerifyIds(const MessageConsumer& consumer, ir::IRContext* linked_context) { + std::unordered_set ids; + bool ok = true; + linked_context->module()->ForEachInst([&ids,&ok](const ir::Instruction* inst) { + ok &= ids.insert(inst->unique_id()).second; + }); + + if (!ok) { + consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module"); + return SPV_ERROR_INVALID_ID; + } + + return SPV_SUCCESS; +} + } // namespace spvtools diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 7be3f9a..6071b94 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -176,7 +176,7 @@ void AggressiveDCEPass::ComputeInst2BlockMap(ir::Function* func) { void AggressiveDCEPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); bp->AddInstruction(std::move(newBranch)); diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp index 7e0f421..fccd396 100644 --- a/source/opt/basic_block.cpp +++ b/source/opt/basic_block.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include "basic_block.h" +#include "function.h" +#include "module.h" #include "make_unique.h" @@ -27,12 +29,13 @@ const uint32_t kSelectionMergeMergeBlockIdInIdx = 0; } // namespace -BasicBlock::BasicBlock(const BasicBlock& bb) - : function_(nullptr), - label_(MakeUnique(bb.GetLabelInst())), - insts_() { - for (auto& inst : bb.insts_) - AddInstruction(std::unique_ptr(inst.Clone())); +BasicBlock* BasicBlock::Clone(IRContext* context) const { + BasicBlock* clone = + new BasicBlock(std::unique_ptr(GetLabelInst().Clone(context))); + for (const auto& inst : insts_) + // Use the incoming context + clone->AddInstruction(std::unique_ptr(inst.Clone(context))); + return clone; } const Instruction* BasicBlock::GetMergeInst() const { diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index 32550e7..f4405f2 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -31,6 +31,7 @@ namespace spvtools { namespace ir { class Function; +class IRContext; // A SPIR-V basic block. class BasicBlock { @@ -41,15 +42,20 @@ class BasicBlock { // Creates a basic block with the given starting |label|. inline explicit BasicBlock(std::unique_ptr label); - // Creates a basic block from the given basic block |bb|. + explicit BasicBlock(const BasicBlock& bb) = delete; + + // Creates a clone of the basic block in the given |context| // // The parent function will default to null and needs to be explicitly set by // the user. - explicit BasicBlock(const BasicBlock& bb); + BasicBlock* Clone(IRContext*) const; // Sets the enclosing function for this basic block. void SetParent(Function* function) { function_ = function; } + // Return the enclosing function + inline Function* GetParent() const { return function_; } + // Appends an instruction to this basic block. inline void AddInstruction(std::unique_ptr i); diff --git a/source/opt/build_module.cpp b/source/opt/build_module.cpp index e3439f3..42dbdd7 100644 --- a/source/opt/build_module.cpp +++ b/source/opt/build_module.cpp @@ -14,6 +14,7 @@ #include "build_module.h" +#include"ir_context.h" #include "ir_loader.h" #include "make_unique.h" #include "table.h" @@ -43,15 +44,15 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { } // annoymous namespace -std::unique_ptr BuildModule(spv_target_env env, +std::unique_ptr BuildModule(spv_target_env env, MessageConsumer consumer, const uint32_t* binary, const size_t size) { auto context = spvContextCreate(env); SetContextMessageConsumer(context, consumer); - auto module = MakeUnique(); - ir::IrLoader loader(context->consumer, module.get()); + auto irContext = MakeUnique(consumer); + ir::IrLoader loader(consumer, irContext->module()); spv_result_t status = spvBinaryParse(context, &loader, binary, size, SetSpvHeader, SetSpvInst, nullptr); @@ -59,13 +60,13 @@ std::unique_ptr BuildModule(spv_target_env env, spvContextDestroy(context); - return status == SPV_SUCCESS ? std::move(module) : nullptr; + return status == SPV_SUCCESS ? std::move(irContext) : nullptr; } -std::unique_ptr BuildModule(spv_target_env env, - MessageConsumer consumer, - const std::string& text, - uint32_t assemble_options) { +std::unique_ptr BuildModule(spv_target_env env, + MessageConsumer consumer, + const std::string& text, + uint32_t assemble_options) { SpirvTools t(env); t.SetMessageConsumer(consumer); std::vector binary; diff --git a/source/opt/build_module.h b/source/opt/build_module.h index 36ea74f..3ee6607 100644 --- a/source/opt/build_module.h +++ b/source/opt/build_module.h @@ -18,23 +18,25 @@ #include #include +#include "ir_context.h" #include "module.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { -// Builds and returns an ir::Module from the given SPIR-V |binary|. |size| -// specifies number of words in |binary|. The |binary| will be decoded -// according to the given target |env|. Returns nullptr if erors occur and -// sends the errors to |consumer|. -std::unique_ptr BuildModule(spv_target_env env, +// Builds an ir::Module returns the owning ir::IRContext from the given SPIR-V +// |binary|. |size| specifies number of words in |binary|. The |binary| will be +// decoded according to the given target |env|. Returns nullptr if errors occur +// and sends the errors to |consumer|. +std::unique_ptr BuildModule(spv_target_env env, MessageConsumer consumer, const uint32_t* binary, size_t size); -// Builds and returns an ir::Module from the given SPIR-V assembly |text|. -// The |text| will be encoded according to the given target |env|. Returns -// nullptr if erors occur and sends the errors to |consumer|. -std::unique_ptr BuildModule( +// Builds an ir::Module and returns the owning ir::IRContext from the given +// SPIR-V assembly |text|. The |text| will be encoded according to the given +// target |env|. Returns nullptr if errors occur and sends the errors to +// |consumer|. +std::unique_ptr BuildModule( spv_target_env env, MessageConsumer consumer, const std::string& text, uint32_t assemble_options = SpirvTools::kDefaultAssembleOption); diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp index 6adc110..a0b78c7 100644 --- a/source/opt/cfg.cpp +++ b/source/opt/cfg.cpp @@ -29,9 +29,9 @@ const int kInvalidId = 0x400000; CFG::CFG(ir::Module* module) : module_(module), pseudo_entry_block_(std::unique_ptr( - new ir::Instruction(SpvOpLabel, 0, 0, {}))), + new ir::Instruction(module->context(), SpvOpLabel, 0, 0, {}))), pseudo_exit_block_(std::unique_ptr( - new ir::Instruction(SpvOpLabel, 0, kInvalidId, {}))) { + new ir::Instruction(module->context(), SpvOpLabel, 0, kInvalidId, {}))) { for (auto& fn : *module) { for (auto& blk : fn) { uint32_t blkId = blk.id(); diff --git a/source/opt/common_uniform_elim_pass.cpp b/source/opt/common_uniform_elim_pass.cpp index d68ed71..3339c4a 100644 --- a/source/opt/common_uniform_elim_pass.cpp +++ b/source/opt/common_uniform_elim_pass.cpp @@ -231,7 +231,7 @@ void CommonUniformElimPass::GenACLoadRepl( ir::Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list{varId})); std::unique_ptr newLoad(new ir::Instruction( - SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); + context(), SpvOpLoad, varPteTypeId, ldResultId, load_in_operands)); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); newInsts->emplace_back(std::move(newLoad)); @@ -254,7 +254,7 @@ void CommonUniformElimPass::GenACLoadRepl( ++iidIdx; }); std::unique_ptr newExt(new ir::Instruction( - SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds)); + context(), SpvOpCompositeExtract, ptrPteTypeId, extResultId, ext_in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newExt); newInsts->emplace_back(std::move(newExt)); *resultId = extResultId; @@ -388,7 +388,7 @@ bool CommonUniformElimPass::CommonUniformLoadElimination(ir::Function* func) { // Copy load into most recent dominating block and remember it replId = TakeNextId(); std::unique_ptr newLoad(new ir::Instruction( - SpvOpLoad, ii->type_id(), replId, + context(), SpvOpLoad, ii->type_id(), replId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newLoad); insertItr = insertItr.InsertBefore(std::move(newLoad)); @@ -460,7 +460,7 @@ bool CommonUniformElimPass::CommonExtractElimination(ir::Function* func) { if (idxItr.second.size() < 2) continue; uint32_t replId = TakeNextId(); std::unique_ptr newExtract( - new ir::Instruction(*idxItr.second.front())); + idxItr.second.front()->Clone(context())); newExtract->SetResultId(replId); get_def_use_mgr()->AnalyzeInstDefUse(&*newExtract); ++ii; diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index e3bf25f..f1c9bf1 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -74,7 +74,7 @@ bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) { void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newBranch); bp->AddInstruction(std::move(newBranch)); @@ -83,7 +83,7 @@ void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) { void DeadBranchElimPass::AddSelectionMerge(uint32_t labelId, ir::BasicBlock* bp) { std::unique_ptr newMerge(new ir::Instruction( - SpvOpSelectionMerge, 0, 0, + context(), SpvOpSelectionMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}})); get_def_use_mgr()->AnalyzeInstDefUse(&*newMerge); @@ -95,7 +95,7 @@ void DeadBranchElimPass::AddBranchConditional(uint32_t condId, uint32_t falseLabId, ir::BasicBlock* bp) { std::unique_ptr newBranchCond(new ir::Instruction( - SpvOpBranchConditional, 0, 0, + context(), SpvOpBranchConditional, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {condId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {trueLabId}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {falseLabId}}})); @@ -302,7 +302,7 @@ bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) { ++icnt; }); std::unique_ptr newPhi(new ir::Instruction( - SpvOpPhi, pii->type_id(), replId, phi_in_opnds)); + context(), SpvOpPhi, pii->type_id(), replId, phi_in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi); pii = pii.InsertBefore(std::move(newPhi)); ++pii; diff --git a/source/opt/decoration_manager.cpp b/source/opt/decoration_manager.cpp index aa926db..b25c20f 100644 --- a/source/opt/decoration_manager.cpp +++ b/source/opt/decoration_manager.cpp @@ -70,11 +70,11 @@ bool DecorationManager::AreDecorationsTheSame( // for (uint32_t i = 2u; i < inst.NumInOperands(); ++i) { // const auto& j = constants.find(inst.GetSingleWordInOperand(i)); // if (j == constants.end()) - // return Instruction(); + // return Instruction(inst.context()); // const auto operand = j->second->GetOperand(0u); // operands.emplace_back(operand.type, operand.words); // } - // return Instruction(SpvOpDecorate, 0u, 0u, operands); + // return Instruction(inst.context(), SpvOpDecorate, 0u, 0u, operands); // }; // Instruction tmpA = (deco1.opcode() == SpvOpDecorateId) ? // decorateIdToDecorate(deco1) : deco1; @@ -261,7 +261,7 @@ void DecorationManager::CloneDecorations( case SpvOpMemberDecorate: case SpvOpDecorateId: { // simply clone decoration and change |target-id| to |to| - std::unique_ptr new_inst(inst->Clone()); + std::unique_ptr new_inst(inst->Clone(module_->context())); new_inst->SetInOperand(0, {to}); id_to_decoration_insts_[to].push_back(new_inst.get()); f(*new_inst, true); diff --git a/source/opt/flatten_decoration_pass.cpp b/source/opt/flatten_decoration_pass.cpp index e92935d..eac8297 100644 --- a/source/opt/flatten_decoration_pass.cpp +++ b/source/opt/flatten_decoration_pass.cpp @@ -91,7 +91,7 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) { const auto normal_uses_iter = normal_uses.find(group); if (normal_uses_iter != normal_uses.end()) { for (auto target : normal_uses[group]) { - std::unique_ptr new_inst(new Instruction(*inst_iter)); + std::unique_ptr new_inst(inst_iter->Clone(context())); new_inst->SetInOperand(0, Words{target}); inst_iter = inst_iter.InsertBefore(std::move(new_inst)); ++inst_iter; @@ -116,8 +116,8 @@ Pass::Status FlattenDecorationPass::Process(ir::IRContext* c) { decoration_operands_iter++; // Skip the group target. operands.insert(operands.end(), decoration_operands_iter, inst_iter->end()); - std::unique_ptr new_inst( - new Instruction(SpvOp::SpvOpMemberDecorate, 0, 0, operands)); + std::unique_ptr new_inst(new Instruction( + context(), SpvOp::SpvOpMemberDecorate, 0, 0, operands)); inst_iter = inst_iter.InsertBefore(std::move(new_inst)); ++inst_iter; replace = true; diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index a630d8a..e91d1fb 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -724,22 +724,23 @@ std::unique_ptr FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id, analysis::Constant* c) { if (c->AsNullConstant()) { - return MakeUnique(SpvOp::SpvOpConstantNull, + return MakeUnique(context(), SpvOp::SpvOpConstantNull, type_mgr_->GetId(c->type()), id, std::initializer_list{}); } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) { return MakeUnique( + context(), bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, type_mgr_->GetId(c->type()), id, std::initializer_list{}); } else if (analysis::IntConstant* ic = c->AsIntConstant()) { return MakeUnique( - SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, + context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, std::initializer_list{ir::Operand( spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, ic->words())}); } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) { return MakeUnique( - SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, + context(), SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id, std::initializer_list{ir::Operand( spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, fc->words())}); @@ -765,7 +766,8 @@ FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction( operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, std::initializer_list{id}); } - return MakeUnique(SpvOp::SpvOpConstantComposite, + return MakeUnique(context(), + SpvOp::SpvOpConstantComposite, type_mgr_->GetId(cc->type()), result_id, std::move(operands)); } diff --git a/source/opt/function.cpp b/source/opt/function.cpp index 4ad2dce..dc5320f 100644 --- a/source/opt/function.cpp +++ b/source/opt/function.cpp @@ -19,27 +19,25 @@ namespace spvtools { namespace ir { -Function::Function(const Function& f) - : module_(nullptr), - def_inst_(MakeUnique(f.DefInst())), - params_(), - blocks_(), - end_inst_() { - params_.reserve(f.params_.size()); - f.ForEachParam( - [this](const Instruction* insn) { - AddParameter(MakeUnique(*insn)); +Function* Function::Clone(IRContext* context) const { + Function* clone = + new Function(std::unique_ptr(DefInst().Clone(context))); + clone->params_.reserve(params_.size()); + ForEachParam( + [clone,context](const Instruction* inst) { + clone->AddParameter(std::unique_ptr(inst->Clone(context))); }, true); - blocks_.reserve(f.blocks_.size()); - for (const auto& b : f.blocks_) { - std::unique_ptr bb = MakeUnique(*b); - bb->SetParent(this); - AddBasicBlock(std::move(bb)); + clone->blocks_.reserve(blocks_.size()); + for (const auto& b : blocks_) { + std::unique_ptr bb(b->Clone(context)); + bb->SetParent(clone); + clone->AddBasicBlock(std::move(bb)); } - SetFunctionEnd(MakeUnique(f.function_end())); + clone->SetFunctionEnd(std::unique_ptr(function_end().Clone(context))); + return clone; } void Function::ForEachInst(const std::function& f, diff --git a/source/opt/function.h b/source/opt/function.h index 618eb7d..9cd7209 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -27,6 +27,7 @@ namespace spvtools { namespace ir { +class IRContext; class Module; // A SPIR-V function. @@ -38,17 +39,22 @@ class Function { // Creates a function instance declared by the given OpFunction instruction // |def_inst|. inline explicit Function(std::unique_ptr def_inst); - // Creates a function instance based on the given function |f|. + + explicit Function(const Function& f) = delete; + + // Creates a clone of the instruction in the given |context| // // The parent module will default to null and needs to be explicitly set by // the user. - explicit Function(const Function& f); + Function* Clone(IRContext*) const; // The OpFunction instruction that begins the definition of this function. Instruction& DefInst() { return *def_inst_; } const Instruction& DefInst() const { return *def_inst_; } // Sets the enclosing module for this function. void SetParent(Module* module) { module_ = module; } + // Gets the enclosing module for this function + Module* GetParent() const { return module_; } // Appends a parameter to this function. inline void AddParameter(std::unique_ptr p); // Appends a basic block to this function. diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index f52277b..5c6e3fb 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -49,7 +49,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id, SpvStorageClass storage_class) { uint32_t resultId = TakeNextId(); std::unique_ptr type_inst(new ir::Instruction( - SpvOpTypePointer, 0, resultId, + context(), SpvOpTypePointer, 0, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, {uint32_t(storage_class)}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); @@ -60,7 +60,7 @@ uint32_t InlinePass::AddPointerToType(uint32_t type_id, void InlinePass::AddBranch(uint32_t label_id, std::unique_ptr* block_ptr) { std::unique_ptr newBranch(new ir::Instruction( - SpvOpBranch, 0, 0, + context(), SpvOpBranch, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); (*block_ptr)->AddInstruction(std::move(newBranch)); } @@ -69,7 +69,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, uint32_t false_id, std::unique_ptr* block_ptr) { std::unique_ptr newBranch(new ir::Instruction( - SpvOpBranchConditional, 0, 0, + context(), SpvOpBranchConditional, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}})); @@ -79,7 +79,7 @@ void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id, void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, std::unique_ptr* block_ptr) { std::unique_ptr newLoopMerge(new ir::Instruction( - SpvOpLoopMerge, 0, 0, + context(), SpvOpLoopMerge, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}})); @@ -89,7 +89,7 @@ void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id, void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, std::unique_ptr* block_ptr) { std::unique_ptr newStore(new ir::Instruction( - SpvOpStore, 0, 0, + context(), SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); (*block_ptr)->AddInstruction(std::move(newStore)); @@ -98,14 +98,14 @@ void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, std::unique_ptr* block_ptr) { std::unique_ptr newLoad(new ir::Instruction( - SpvOpLoad, type_id, resultId, + context(), SpvOpLoad, type_id, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); (*block_ptr)->AddInstruction(std::move(newLoad)); } std::unique_ptr InlinePass::NewLabel(uint32_t label_id) { std::unique_ptr newLabel( - new ir::Instruction(SpvOpLabel, 0, label_id, {})); + new ir::Instruction(context(), SpvOpLabel, 0, label_id, {})); return newLabel; } @@ -143,7 +143,8 @@ void InlinePass::CloneAndMapLocals( auto callee_block_itr = calleeFn->begin(); auto callee_var_itr = callee_block_itr->begin(); while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { - std::unique_ptr var_inst(callee_var_itr->Clone()); + std::unique_ptr var_inst( + callee_var_itr->Clone(callee_var_itr->context())); uint32_t newId = TakeNextId(); get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId, update_def_use_mgr_); var_inst->SetResultId(newId); @@ -169,7 +170,7 @@ uint32_t InlinePass::CreateReturnVar( // Add return var to new function scope variables. returnVarId = TakeNextId(); std::unique_ptr var_inst(new ir::Instruction( - SpvOpVariable, returnVarTypeId, returnVarId, + context(), SpvOpVariable, returnVarTypeId, returnVarId, {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); new_vars->push_back(std::move(var_inst)); @@ -195,7 +196,8 @@ void InlinePass::CloneSameBlockOps( if (mapItr2 != (*preCallSB).end()) { // Clone pre-call same-block ops, map result id. const ir::Instruction* inInst = mapItr2->second; - std::unique_ptr sb_inst(inInst->Clone()); + std::unique_ptr sb_inst( + inInst->Clone(inInst->context())); CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr); const uint32_t rid = sb_inst->result_id(); const uint32_t nid = this->TakeNextId(); @@ -325,7 +327,7 @@ void InlinePass::GenInlineCode( // Copy contents of original caller block up to call instruction. for (auto cii = call_block_itr->begin(); cii != call_inst_itr; ++cii) { - std::unique_ptr cp_inst(cii->Clone()); + std::unique_ptr cp_inst(cii->Clone(context())); // Remember same-block ops for possible regeneration. if (IsSameBlockOp(&*cp_inst)) { auto* sb_inst_ptr = cp_inst.get(); @@ -434,7 +436,7 @@ void InlinePass::GenInlineCode( // Copy remaining instructions from caller block. auto cii = call_inst_itr; for (++cii; cii != call_block_itr->end(); ++cii) { - std::unique_ptr cp_inst(cii->Clone()); + std::unique_ptr cp_inst(cii->Clone(context())); // If multiple blocks generated, regenerate any same-block // instruction that has not been seen in this last block. if (multiBlocks) { @@ -452,7 +454,7 @@ void InlinePass::GenInlineCode( } break; default: { // Copy callee instruction and remap all input Ids. - std::unique_ptr cp_inst(cpi->Clone()); + std::unique_ptr cp_inst(cpi->Clone(context())); cp_inst->ForEachInId([&callee2caller, &callee_result_ids, this](uint32_t* iid) { const auto mapItr = callee2caller.find(*iid); @@ -497,7 +499,7 @@ void InlinePass::GenInlineCode( auto loop_merge_itr = last->tail(); --loop_merge_itr; assert(loop_merge_itr->opcode() == SpvOpLoopMerge); - std::unique_ptr cp_inst(loop_merge_itr->Clone()); + std::unique_ptr cp_inst(loop_merge_itr->Clone(context())); if (caller_is_single_block_loop) { // Also, update its continue target to point to the last block. cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()}); diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp index f26fb1d..df2dcb7 100644 --- a/source/opt/instruction.cpp +++ b/source/opt/instruction.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "instruction.h" +#include "ir_context.h" #include @@ -21,11 +22,29 @@ namespace spvtools { namespace ir { -Instruction::Instruction(const spv_parsed_instruction_t& inst, +Instruction::Instruction(IRContext* c) + : utils::IntrusiveNodeBase(), + context_(c), + opcode_(SpvOpNop), + type_id_(0), + result_id_(0), + unique_id_(c->TakeNextUniqueId()) {} + +Instruction::Instruction(IRContext* c, SpvOp op) + : utils::IntrusiveNodeBase(), + context_(c), + opcode_(op), + type_id_(0), + result_id_(0), + unique_id_(c->TakeNextUniqueId()) {} + +Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, std::vector&& dbg_line) - : opcode_(static_cast(inst.opcode)), + : context_(c), + opcode_(static_cast(inst.opcode)), type_id_(inst.type_id), result_id_(inst.result_id), + unique_id_(c->TakeNextUniqueId()), dbg_line_insts_(std::move(dbg_line)) { assert((!IsDebugLineInst(opcode_) || dbg_line.empty()) && "Op(No)Line attaching to Op(No)Line found"); @@ -38,12 +57,14 @@ Instruction::Instruction(const spv_parsed_instruction_t& inst, } } -Instruction::Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id, +Instruction::Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, const std::vector& in_operands) : utils::IntrusiveNodeBase(), + context_(c), opcode_(op), type_id_(ty_id), result_id_(res_id), + unique_id_(c->TakeNextUniqueId()), operands_() { if (type_id_ != 0) { operands_.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_TYPE_ID, @@ -61,6 +82,7 @@ Instruction::Instruction(Instruction&& that) opcode_(that.opcode_), type_id_(that.type_id_), result_id_(that.result_id_), + unique_id_(that.unique_id_), operands_(std::move(that.operands_)), dbg_line_insts_(std::move(that.dbg_line_insts_)) {} @@ -68,16 +90,18 @@ Instruction& Instruction::operator=(Instruction&& that) { opcode_ = that.opcode_; type_id_ = that.type_id_; result_id_ = that.result_id_; + unique_id_ = that.unique_id_; operands_ = std::move(that.operands_); dbg_line_insts_ = std::move(that.dbg_line_insts_); return *this; } -Instruction* Instruction::Clone() const { - Instruction* clone = new Instruction(); +Instruction* Instruction::Clone(IRContext *c) const { + Instruction* clone = new Instruction(c); clone->opcode_ = opcode_; clone->type_id_ = type_id_; clone->result_id_ = result_id_; + clone->unique_id_ = c->TakeNextUniqueId(); clone->operands_ = operands_; clone->dbg_line_insts_ = dbg_line_insts_; return clone; diff --git a/source/opt/instruction.h b/source/opt/instruction.h index ff0acdb..4c96474 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h @@ -31,6 +31,7 @@ namespace spvtools { namespace ir { class Function; +class IRContext; class Module; class InstructionList; @@ -84,28 +85,30 @@ class Instruction : public utils::IntrusiveNodeBase { using const_iterator = std::vector::const_iterator; // Creates a default OpNop instruction. + // This exists solely for containers that can't do without. Should be removed. Instruction() : utils::IntrusiveNodeBase(), + context_(nullptr), opcode_(SpvOpNop), type_id_(0), - result_id_(0) {} + result_id_(0), + unique_id_(0) {} + + // Creates a default OpNop instruction. + Instruction(IRContext*); // Creates an instruction with the given opcode |op| and no additional logical // operands. - Instruction(SpvOp op) - : utils::IntrusiveNodeBase(), - opcode_(op), - type_id_(0), - result_id_(0) {} + Instruction(IRContext*, SpvOp); // Creates an instruction using the given spv_parsed_instruction_t |inst|. All // the data inside |inst| will be copied and owned in this instance. And keep // record of line-related debug instructions |dbg_line| ahead of this // instruction, if any. - Instruction(const spv_parsed_instruction_t& inst, + Instruction(IRContext* c, const spv_parsed_instruction_t& inst, std::vector&& dbg_line = {}); // Creates an instruction with the given opcode |op|, type id: |ty_id|, // result id: |res_id| and input operands: |in_operands|. - Instruction(SpvOp op, uint32_t ty_id, uint32_t res_id, + Instruction(IRContext* c, SpvOp op, uint32_t ty_id, uint32_t res_id, const std::vector& in_operands); // TODO: I will want to remove these, but will first have to remove the use of @@ -123,7 +126,9 @@ class Instruction : public utils::IntrusiveNodeBase { // It is the responsibility of the caller to make sure that the storage is // removed. It is the caller's responsibility to make sure that there is only // one instruction for each result id. - Instruction* Clone() const; + Instruction* Clone(IRContext *c) const; + + IRContext* context() const { return context_; } SpvOp opcode() const { return opcode_; } // Sets the opcode of this instruction to a specific opcode. Note this may @@ -133,6 +138,7 @@ class Instruction : public utils::IntrusiveNodeBase { void SetOpcode(SpvOp op) { opcode_ = op; } uint32_t type_id() const { return type_id_; } uint32_t result_id() const { return result_id_; } + uint32_t unique_id() const { assert(unique_id_ != 0); return unique_id_; } // Returns the vector of line-related debug instructions attached to this // instruction and the caller can directly modify them. std::vector& dbg_line_insts() { return dbg_line_insts_; } @@ -241,15 +247,21 @@ class Instruction : public utils::IntrusiveNodeBase { // Returns true if the instruction annotates an id with a decoration. inline bool IsDecoration(); + inline bool operator==(const Instruction&) const; + inline bool operator!=(const Instruction&) const; + inline bool operator<(const Instruction&) const; + private: // Returns the total count of result type id and result id. uint32_t TypeResultIdCount() const { return (type_id_ != 0) + (result_id_ != 0); } + IRContext* context_; // IR Context SpvOp opcode_; // Opcode uint32_t type_id_; // Result type id. A value of 0 means no result type id. uint32_t result_id_; // Result id. A value of 0 means no result id. + uint32_t unique_id_; // Unique instruction id // All logical operands, including result type id and result id. std::vector operands_; // Opline and OpNoLine instructions preceding this instruction. Note that for @@ -260,6 +272,18 @@ class Instruction : public utils::IntrusiveNodeBase { friend InstructionList; }; +inline bool Instruction::operator==(const Instruction& other) const { + return unique_id() == other.unique_id(); +} + +inline bool Instruction::operator!=(const Instruction& other) const { + return !(*this == other); +} + +inline bool Instruction::operator<(const Instruction& other) const { + return unique_id() < other.unique_id(); +} + inline const Operand& Instruction::GetOperand(uint32_t index) const { assert(index < operands_.size() && "operand index out of bound"); return operands_[index]; diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 23f59ed..eedee1e 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -21,6 +21,7 @@ #include #include +#include namespace spvtools { namespace ir { @@ -53,11 +54,26 @@ class IRContext { friend inline Analysis operator<<(Analysis a, int shift); friend inline Analysis& operator<<=(Analysis& a, int shift); + // Create an |IRContext| that contains an owned |Module| + IRContext(spvtools::MessageConsumer c) + : unique_id_(0), + module_(new Module()), + consumer_(std::move(c)), + def_use_mgr_(nullptr), + valid_analyses_(kAnalysisNone) + { + module_->SetContext(this); + } + IRContext(std::unique_ptr&& m, spvtools::MessageConsumer c) - : module_(std::move(m)), + : unique_id_(0), + module_(std::move(m)), consumer_(std::move(c)), def_use_mgr_(nullptr), - valid_analyses_(kAnalysisNone) {} + valid_analyses_(kAnalysisNone) + { + module_->SetContext(this); + } Module* module() const { return module_.get(); } inline void SetIdBound(uint32_t i); @@ -239,6 +255,14 @@ class IRContext { // Kill all name and decorate ops targeting the result id of |inst|. void KillNamesAndDecorates(ir::Instruction* inst); + // Returns the next unique id for use by an instruction. + inline uint32_t TakeNextUniqueId() { + assert(unique_id_ != std::numeric_limits::max()); + + // Skip zero. + return ++unique_id_; + } + private: // Builds the def-use manager from scratch, even if it was already valid. void BuildDefUseManager() { @@ -264,6 +288,13 @@ class IRContext { valid_analyses_ = valid_analyses_ | kAnalysisDecorations; } + // An unique identifier for this instruction. Can be used to order + // instructions in a container. + // + // This member is initialized to 0, but always issues this value plus one. + // Therefore, 0 is not a valid unique id for an instruction. + uint32_t unique_id_; + std::unique_ptr module_; spvtools::MessageConsumer consumer_; std::unique_ptr def_use_mgr_; diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp index e3d8484..b705343 100644 --- a/source/opt/ir_loader.cpp +++ b/source/opt/ir_loader.cpp @@ -20,9 +20,9 @@ namespace spvtools { namespace ir { -IrLoader::IrLoader(const MessageConsumer& consumer, Module* module) +IrLoader::IrLoader(const MessageConsumer& consumer, Module* m) : consumer_(consumer), - module_(module), + module_(m), source_(""), inst_index_(0) {} @@ -30,12 +30,12 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { ++inst_index_; const auto opcode = static_cast(inst->opcode); if (IsDebugLineInst(opcode)) { - dbg_line_info_.push_back(Instruction(*inst)); + dbg_line_info_.push_back(Instruction(module()->context(), *inst)); return true; } std::unique_ptr spv_inst( - new Instruction(*inst, std::move(dbg_line_info_))); + new Instruction(module()->context(), *inst, std::move(dbg_line_info_))); dbg_line_info_.clear(); const char* src = source_.c_str(); diff --git a/source/opt/ir_loader.h b/source/opt/ir_loader.h index bcb55f1..2f0ca8b 100644 --- a/source/opt/ir_loader.h +++ b/source/opt/ir_loader.h @@ -39,11 +39,13 @@ class IrLoader { // All internal messages will be communicated to the outside via the given // message |consumer|. This instance only keeps a reference to the |consumer|, // so the |consumer| should outlive this instance. - IrLoader(const MessageConsumer& consumer, Module* module); + IrLoader(const MessageConsumer& consumer, Module* m); // Sets the source name of the module. void SetSource(const std::string& src) { source_ = src; } + Module* module() const { return module_; } + // Sets the fields in the module's header to the given parameters. void SetModuleHeader(uint32_t magic, uint32_t version, uint32_t generator, uint32_t bound, uint32_t reserved) { diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 9663e88..ff5f912 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -44,7 +44,7 @@ void LocalAccessChainConvertPass::BuildAndAppendInst( const std::vector& in_opnds, std::vector>* newInsts) { std::unique_ptr newInst( - new ir::Instruction(opcode, typeId, resultId, in_opnds)); + new ir::Instruction(context(), opcode, typeId, resultId, in_opnds)); get_def_use_mgr()->AnalyzeInstDefUse(&*newInst); newInsts->emplace_back(std::move(newInst)); } diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 72e4f73..ae5baa8 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -287,7 +287,7 @@ uint32_t MemPass::Type2Undef(uint32_t type_id) { if (uitr != type2undefs_.end()) return uitr->second; const uint32_t undefId = TakeNextId(); std::unique_ptr undef_inst( - new ir::Instruction(SpvOpUndef, type_id, undefId, {})); + new ir::Instruction(context(), SpvOpUndef, type_id, undefId, {})); get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst); get_module()->AddGlobalValue(std::move(undef_inst)); type2undefs_[type_id] = undefId; @@ -402,7 +402,7 @@ void MemPass::SSABlockInitLoopHeader( } const uint32_t phiId = TakeNextId(); std::unique_ptr newPhi( - new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands)); + new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands)); // The only phis requiring patching are the ones we create. phis_to_patch_.insert(phiId); // Only analyze the phi define now; analyze the phi uses after the @@ -470,7 +470,7 @@ void MemPass::SSABlockInitMultiPred(ir::BasicBlock* block_ptr) { } const uint32_t phiId = TakeNextId(); std::unique_ptr newPhi( - new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands)); + new ir::Instruction(context(), SpvOpPhi, typeId, phiId, phi_in_operands)); get_def_use_mgr()->AnalyzeInstDefUse(&*newPhi); insertItr = insertItr.InsertBefore(std::move(newPhi)); ++insertItr; diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index 9374a91..e822885 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp @@ -60,7 +60,7 @@ bool MergeReturnPass::MergeReturnBlocks( // Create a label for the new return block std::unique_ptr returnLabel( - new ir::Instruction(SpvOpLabel, 0u, TakeNextId(), {})); + new ir::Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {})); uint32_t returnId = returnLabel->result_id(); // Create the new basic block @@ -84,13 +84,14 @@ bool MergeReturnPass::MergeReturnBlocks( // Need a PHI node to select the correct return value. uint32_t phiResultId = TakeNextId(); uint32_t phiTypeId = function->type_id(); - std::unique_ptr phiInst( - new ir::Instruction(SpvOpPhi, phiTypeId, phiResultId, phiOps)); + std::unique_ptr phiInst(new ir::Instruction( + context(), SpvOpPhi, phiTypeId, phiResultId, phiOps)); retBlockIter->AddInstruction(std::move(phiInst)); ir::BasicBlock::iterator phiIter = retBlockIter->tail(); - std::unique_ptr returnInst(new ir::Instruction( - SpvOpReturnValue, 0u, 0u, {{SPV_OPERAND_TYPE_ID, {phiResultId}}})); + std::unique_ptr returnInst( + new ir::Instruction(context(), SpvOpReturnValue, 0u, 0u, + {{SPV_OPERAND_TYPE_ID, {phiResultId}}})); retBlockIter->AddInstruction(std::move(returnInst)); ir::BasicBlock::iterator ret = retBlockIter->tail(); @@ -98,7 +99,7 @@ bool MergeReturnPass::MergeReturnBlocks( get_def_use_mgr()->AnalyzeInstDef(&*ret); } else { std::unique_ptr returnInst( - new ir::Instruction(SpvOpReturn)); + new ir::Instruction(context(), SpvOpReturn)); retBlockIter->AddInstruction(std::move(returnInst)); } diff --git a/source/opt/module.cpp b/source/opt/module.cpp index e329b3c..9d46a1b 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -65,7 +65,7 @@ uint32_t Module::GetGlobalValue(SpvOp opcode) const { void Module::AddGlobalValue(SpvOp opcode, uint32_t result_id, uint32_t type_id) { std::unique_ptr newGlobal( - new ir::Instruction(opcode, type_id, result_id, {})); + new ir::Instruction(context(), opcode, type_id, result_id, {})); AddGlobalValue(std::move(newGlobal)); } diff --git a/source/opt/module.h b/source/opt/module.h index e4c03e2..d3fe2b5 100644 --- a/source/opt/module.h +++ b/source/opt/module.h @@ -27,6 +27,8 @@ namespace spvtools { namespace ir { +class IRContext; + // A struct for containing the module header information. struct ModuleHeader { uint32_t magic_number; @@ -223,11 +225,18 @@ class Module { // Returns 0 if not found. uint32_t GetExtInstImportId(const char* extstr); + // Sets the associated context for this module + void SetContext(IRContext* c) { context_ = c; } + + // Gets the associated context for this module + IRContext* context() const { return context_; } + private: ModuleHeader header_; // Module header // The following fields respect the "Logical Layout of a Module" in // Section 2.4 of the SPIR-V specification. + IRContext* context_; InstructionList capabilities_; InstructionList extensions_; InstructionList ext_inst_imports_; diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index 3527c1f..ac913dd 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -105,19 +105,18 @@ Optimizer& Optimizer::RegisterSizePasses() { bool Optimizer::Run(const uint32_t* original_binary, const size_t original_binary_size, std::vector* optimized_binary) const { - std::unique_ptr module = + std::unique_ptr context = BuildModule(impl_->target_env, impl_->pass_manager.consumer(), original_binary, original_binary_size); - if (module == nullptr) return false; - ir::IRContext context(std::move(module), impl_->pass_manager.consumer()); + if (context == nullptr) return false; - auto status = impl_->pass_manager.Run(&context); + auto status = impl_->pass_manager.Run(context.get()); if (status == opt::Pass::Status::SuccessWithChange || (status == opt::Pass::Status::SuccessWithoutChange && (optimized_binary->data() != original_binary || optimized_binary->size() != original_binary_size))) { optimized_binary->clear(); - context.module()->ToBinary(optimized_binary, /* skip_nop = */ true); + context->module()->ToBinary(optimized_binary, /* skip_nop = */ true); } return status != opt::Pass::Status::Failure; diff --git a/source/opt/strength_reduction_pass.cpp b/source/opt/strength_reduction_pass.cpp index 5c08f5e..f2aee91 100644 --- a/source/opt/strength_reduction_pass.cpp +++ b/source/opt/strength_reduction_pass.cpp @@ -100,7 +100,7 @@ bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( {shiftConstResultId}); newOperands.push_back(shiftOperand); std::unique_ptr newInstruction( - new ir::Instruction(SpvOp::SpvOpShiftLeftLogical, inst->type_id(), + new ir::Instruction(context(), SpvOp::SpvOpShiftLeftLogical, inst->type_id(), newResultId, newOperands)); // Insert the new instruction and update the data structures. @@ -161,7 +161,7 @@ uint32_t StrengthReductionPass::GetConstantId(uint32_t val) { ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}); std::unique_ptr newConstant(new ir::Instruction( - SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant})); + context(), SpvOp::SpvOpConstant, uint32_type_id_, resultId, {constant})); get_module()->AddGlobalValue(std::move(newConstant)); // Store the result id for next time. @@ -199,7 +199,7 @@ uint32_t StrengthReductionPass::CreateUint32Type() { ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}); std::unique_ptr newType(new ir::Instruction( - SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand})); + context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand})); context()->AddType(std::move(newType)); return type_id; } diff --git a/test/link/CMakeLists.txt b/test/link/CMakeLists.txt index 9768ab3..f2ced24 100644 --- a/test/link/CMakeLists.txt +++ b/test/link/CMakeLists.txt @@ -41,3 +41,8 @@ add_spvtools_unittest(TARGET link_matching_imports_to_exports SRCS matching_imports_to_exports_test.cpp LIBS SPIRV-Tools-opt SPIRV-Tools-link ) + +add_spvtools_unittest(TARGET link_unique_ids + SRCS unique_ids_test.cpp + LIBS SPIRV-Tools-opt SPIRV-Tools-link +) diff --git a/test/link/unique_ids_test.cpp b/test/link/unique_ids_test.cpp new file mode 100644 index 0000000..8b67d34 --- /dev/null +++ b/test/link/unique_ids_test.cpp @@ -0,0 +1,137 @@ +// Copyright (c) 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "linker_fixture.h" + +namespace { + +using UniqueIds = spvtest::LinkerTest; + +TEST_F(UniqueIds, UniquelyMerged) { + std::vector bodies(2); + bodies[0] = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + bodies[1] = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpSource ESSL 310\n" + "OpName %main \"main2\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv12\"\n" + "OpName %gv2 \"gv22\"\n" + "OpName %lv1 \"lv12\"\n" + "OpName %lv2 \"lv22\"\n" + "OpName %lv1_0 \"lv12\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + spvtest::Binary linked_binary; + spvtools::LinkerOptions options; + options.SetVerifyIds(true); + spv_result_t res = AssembleAndLink(bodies, &linked_binary, options); + EXPECT_EQ(SPV_SUCCESS, res); +} + +} // anonymous namespace diff --git a/test/opt/def_use_test.cpp b/test/opt/def_use_test.cpp index aa88978..bd1ac7e 100644 --- a/test/opt/def_use_test.cpp +++ b/test/opt/def_use_test.cpp @@ -21,6 +21,7 @@ #include "opt/build_module.h" #include "opt/def_use_manager.h" #include "opt/ir_context.h" +#include "opt/module.h" #include "pass_utils.h" #include "spirv-tools/libspirv.hpp" @@ -131,12 +132,12 @@ TEST_P(ParseDefUseTest, Case) { // Build module. const std::vector text = {tc.text}; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); CheckDef(tc.du, manager.id_to_defs()); CheckUse(tc.du, manager.id_to_uses()); @@ -512,23 +513,22 @@ TEST_P(ReplaceUseTest, Case) { // Build module. const std::vector text = {tc.before}; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // Do the substitution. for (const auto& candidate : tc.candidates) { - context.ReplaceAllUsesWith(candidate.first, candidate.second); + context->ReplaceAllUsesWith(candidate.first, candidate.second); } - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -816,20 +816,19 @@ TEST_P(KillDefTest, Case) { // Build module. const std::vector text = {tc.before}; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text)); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Analyze def and use. - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); // Do the substitution. - for (const auto id : tc.ids_to_kill) context.KillDef(id); + for (const auto id : tc.ids_to_kill) context->KillDef(id); - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.du, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.du, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.du, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1067,19 +1066,18 @@ TEST(DefUseTest, OpSwitch) { " OpReturnValue %6 " " OpFunctionEnd"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // Do a bunch replacements. - context.ReplaceAllUsesWith(9, 900); // to unused id - context.ReplaceAllUsesWith(10, 1000); // to unused id - context.ReplaceAllUsesWith(11, 7); // to existing id + context->ReplaceAllUsesWith(9, 900); // to unused id + context->ReplaceAllUsesWith(10, 1000); // to unused id + context->ReplaceAllUsesWith(11, 7); // to existing id // clang-format off const char modified_text[] = @@ -1103,7 +1101,7 @@ TEST(DefUseTest, OpSwitch) { "OpFunctionEnd"; // clang-format on - EXPECT_EQ(modified_text, DisassembleModule(context.module())); + EXPECT_EQ(modified_text, DisassembleModule(context->module())); InstDefUse def_uses = {}; def_uses.defs = { @@ -1118,10 +1116,10 @@ TEST(DefUseTest, OpSwitch) { {10, "%10 = OpLabel"}, {11, "%11 = OpLabel"}, }; - CheckDef(def_uses, context.get_def_use_mgr()->id_to_defs()); + CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs()); { - auto* use_list = context.get_def_use_mgr()->GetUses(6); + auto* use_list = context->get_def_use_mgr()->GetUses(6); ASSERT_NE(nullptr, use_list); EXPECT_EQ(2u, use_list->size()); std::vector opcodes = {use_list->front().inst->opcode(), @@ -1129,7 +1127,7 @@ TEST(DefUseTest, OpSwitch) { EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue)); } { - auto* use_list = context.get_def_use_mgr()->GetUses(7); + auto* use_list = context->get_def_use_mgr()->GetUses(7); ASSERT_NE(nullptr, use_list); EXPECT_EQ(6u, use_list->size()); std::vector opcodes; @@ -1143,44 +1141,15 @@ TEST(DefUseTest, OpSwitch) { } // Check all ids only used by OpSwitch after replacement. for (const auto id : {8, 900, 1000}) { - auto* use_list = context.get_def_use_mgr()->GetUses(id); + auto* use_list = context->get_def_use_mgr()->GetUses(id); ASSERT_NE(nullptr, use_list); EXPECT_EQ(1u, use_list->size()); EXPECT_EQ(SpvOpSwitch, use_list->front().inst->opcode()); } } -// Creates an |result_id| = OpTypeInt 32 1 instruction. -ir::Instruction Int32TypeInstruction(uint32_t result_id) { - return ir::Instruction(SpvOp::SpvOpTypeInt, 0, result_id, - {ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {32}), - ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {1})}); -} - -// Creates an |result_id| = OpConstantTrue/Flase |type_id| instruction. -ir::Instruction ConstantBoolInstruction(bool value, uint32_t type_id, - uint32_t result_id) { - return ir::Instruction( - value ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, type_id, - result_id, {}); -} - -// Creates an |result_id| = OpLabel instruction. -ir::Instruction LabelInstruction(uint32_t result_id) { - return ir::Instruction(SpvOp::SpvOpLabel, 0, result_id, {}); -} - -// Creates an OpBranch |target_id| instruction. -ir::Instruction BranchInstruction(uint32_t target_id) { - return ir::Instruction(SpvOp::SpvOpBranch, 0, 0, - { - ir::Operand(SPV_OPERAND_TYPE_ID, {target_id}), - }); -} - // Test case for analyzing individual instructions. struct AnalyzeInstDefUseTestCase { - std::vector insts; // instrutions to be analyzed in order. const char* module_text; InstDefUse expected_define_use; }; @@ -1193,15 +1162,12 @@ TEST_P(AnalyzeInstDefUseTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Analyze the instructions. - opt::analysis::DefUseManager manager(module.get()); - for (ir::Instruction& inst : tc.insts) { - manager.AnalyzeInstDefUse(&inst); - } + opt::analysis::DefUseManager manager(context->module()); CheckDef(tc.expected_define_use, manager.id_to_defs()); CheckUse(tc.expected_define_use, manager.id_to_uses()); @@ -1212,8 +1178,7 @@ INSTANTIATE_TEST_CASE_P( TestCase, AnalyzeInstDefUseTest, ::testing::ValuesIn(std::vector{ { // A type declaring instruction. - {Int32TypeInstruction(1)}, - "", + "%1 = OpTypeInt 32 1", { // defs {{1, "%1 = OpTypeInt 32 1"}}, @@ -1221,88 +1186,79 @@ INSTANTIATE_TEST_CASE_P( }, }, { // A type declaring instruction and a constant value. - { - Int32TypeInstruction(1), - ConstantBoolInstruction(true, 1, 2), - }, - "", - { - { // defs - {1, "%1 = OpTypeInt 32 1"}, - {2, "%2 = OpConstantTrue %1"}, // It is fine the SPIR-V code here is invalid. - }, - { // uses - {1, {"%2 = OpConstantTrue %1"}}, - }, - }, - }, - { // Analyze two instrutions that have same result id. The def use info - // of the result id from the first instruction should be overwritten by - // the second instruction. - { - ConstantBoolInstruction(true, 1, 2), - // The def-use info of the following instruction should overwrite the - // records of the above one. - ConstantBoolInstruction(false, 3, 2), - }, - "", - { - // defs - {{2, "%2 = OpConstantFalse %3"}}, - // uses - {{3, {"%2 = OpConstantFalse %3"}}} - } - }, - { // Analyze forward reference instruction, also instruction that does - // not have result id. - { - BranchInstruction(2), - LabelInstruction(2), - }, - "", - { - // defs - {{2, "%2 = OpLabel"}}, - // uses - {{2, {"OpBranch %2"}}}, - } - }, - { // Analyzing an additional instruction with new result id to an - // existing module. - { - ConstantBoolInstruction(true, 1, 2), - }, - "%1 = OpTypeInt 32 1 ", + "%1 = OpTypeBool " + "%2 = OpConstantTrue %1", { { // defs - {1, "%1 = OpTypeInt 32 1"}, + {1, "%1 = OpTypeBool"}, {2, "%2 = OpConstantTrue %1"}, }, { // uses {1, {"%2 = OpConstantTrue %1"}}, }, - } - }, - { // Analyzing an additional instruction with existing result id to an - // existing module. - { - ConstantBoolInstruction(true, 1, 2), }, - "%1 = OpTypeInt 32 1 " - "%2 = OpTypeBool ", - { - { // defs - {1, "%1 = OpTypeInt 32 1"}, - {2, "%2 = OpConstantTrue %1"}, - }, - { // uses - {1, {"%2 = OpConstantTrue %1"}}, - }, - } }, })); // clang-format on +using AnalyzeInstDefUse = ::testing::Test; + +TEST(AnalyzeInstDefUse, UseWithNoResultId) { + ir::IRContext context(nullptr); + + // Analyze the instructions. + opt::analysis::DefUseManager manager(context.module()); + + ir::Instruction label(&context, SpvOpLabel, 0, 2, {}); + manager.AnalyzeInstDefUse(&label); + + ir::Instruction branch(&context, SpvOpBranch, 0, 0, + {{SPV_OPERAND_TYPE_ID, {2}}}); + manager.AnalyzeInstDefUse(&branch); + + InstDefUse expected = + { + // defs + { + {2, "%2 = OpLabel"}, + }, + // uses + {{2, {"OpBranch %2"}}}, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, manager.id_to_uses()); +} + +TEST(AnalyzeInstDefUse, AddNewInstruction) { + const std::string input = "%1 = OpTypeBool"; + + // Build module. + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); + ASSERT_NE(nullptr, context); + + // Analyze the instructions. + opt::analysis::DefUseManager manager(context->module()); + + ir::Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {}); + manager.AnalyzeInstDefUse(&newInst); + + InstDefUse expected = + { + { // defs + {1, "%1 = OpTypeBool"}, + {2, "%2 = OpConstantTrue %1"}, + }, + { // uses + {1, {"%2 = OpConstantTrue %1"}}, + }, + }; + + CheckDef(expected, manager.id_to_defs()); + CheckUse(expected, manager.id_to_uses()); +} + struct KillInstTestCase { const char* before; std::unordered_set indices_for_inst_to_kill; @@ -1316,27 +1272,26 @@ TEST_P(KillInstTest, Case) { auto tc = GetParam(); // Build module. - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); + ASSERT_NE(nullptr, context); // Force a re-build of the def-use manager. - context.InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); - (void)context.get_def_use_mgr(); + context->InvalidateAnalyses(ir::IRContext::Analysis::kAnalysisDefUse); + (void)context->get_def_use_mgr(); // KillInst uint32_t index = 0; - context.module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) { + context->module()->ForEachInst([&index, &tc, &context](ir::Instruction* inst) { if (tc.indices_for_inst_to_kill.count(index) != 0) { - context.KillInst(inst); + context->KillInst(inst); } index++; }); - EXPECT_EQ(tc.after, DisassembleModule(context.module())); - CheckDef(tc.expected_define_use, context.get_def_use_mgr()->id_to_defs()); - CheckUse(tc.expected_define_use, context.get_def_use_mgr()->id_to_uses()); + EXPECT_EQ(tc.after, DisassembleModule(context->module())); + CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs()); + CheckUse(tc.expected_define_use, context->get_def_use_mgr()->id_to_uses()); } // clang-format off @@ -1428,12 +1383,12 @@ TEST_P(GetAnnotationsTest, Case) { const GetAnnotationsTestCase& tc = GetParam(); // Build module. - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); // Get annotations - opt::analysis::DefUseManager manager(module.get()); + opt::analysis::DefUseManager manager(context->module()); auto insts = manager.GetAnnotations(tc.id); // Check diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp index 2db4ed2..f930d46 100644 --- a/test/opt/instruction_test.cpp +++ b/test/opt/instruction_test.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "opt/instruction.h" +#include "opt/ir_context.h" #include "gmock/gmock.h" @@ -23,6 +24,7 @@ namespace { using spvtest::MakeInstruction; using spvtools::ir::Instruction; +using spvtools::ir::IRContext; using spvtools::ir::Operand; using ::testing::Eq; @@ -39,7 +41,8 @@ TEST(InstructionTest, CreateTrivial) { } TEST(InstructionTest, CreateWithOpcodeAndNoOperands) { - Instruction inst(SpvOpReturn); + IRContext context(nullptr); + Instruction inst(&context, SpvOpReturn); EXPECT_EQ(SpvOpReturn, inst.opcode()); EXPECT_EQ(0u, inst.type_id()); EXPECT_EQ(0u, inst.result_id()); @@ -119,7 +122,8 @@ spv_parsed_instruction_t kSampleControlBarrierInstruction = { 3}; TEST(InstructionTest, CreateWithOpcodeAndOperands) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_EQ(SpvOpTypeInt, inst.opcode()); EXPECT_EQ(0u, inst.type_id()); EXPECT_EQ(44u, inst.result_id()); @@ -129,20 +133,23 @@ TEST(InstructionTest, CreateWithOpcodeAndOperands) { } TEST(InstructionTest, GetOperand) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_THAT(inst.GetOperand(0).words, Eq(std::vector{44})); EXPECT_THAT(inst.GetOperand(1).words, Eq(std::vector{32})); EXPECT_THAT(inst.GetOperand(2).words, Eq(std::vector{1})); } TEST(InstructionTest, GetInOperand) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); EXPECT_THAT(inst.GetInOperand(0).words, Eq(std::vector{32})); EXPECT_THAT(inst.GetInOperand(1).words, Eq(std::vector{1})); } TEST(InstructionTest, OperandConstIterators) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); // Spot check iteration across operands. auto cbegin = inst.cbegin(); auto cend = inst.cend(); @@ -168,7 +175,8 @@ TEST(InstructionTest, OperandConstIterators) { } TEST(InstructionTest, OperandIterators) { - Instruction inst(kSampleParsedInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleParsedInstruction); // Spot check iteration across operands, with mutable iterators. auto begin = inst.begin(); auto end = inst.end(); @@ -198,7 +206,8 @@ TEST(InstructionTest, OperandIterators) { } TEST(InstructionTest, ForInIdStandardIdTypes) { - Instruction inst(kSampleAccessChainInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleAccessChainInstruction); std::vector ids; inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); @@ -210,7 +219,8 @@ TEST(InstructionTest, ForInIdStandardIdTypes) { } TEST(InstructionTest, ForInIdNonstandardIdTypes) { - Instruction inst(kSampleControlBarrierInstruction); + IRContext context(nullptr); + Instruction inst(&context, kSampleControlBarrierInstruction); std::vector ids; inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); }); @@ -221,4 +231,60 @@ TEST(InstructionTest, ForInIdNonstandardIdTypes) { EXPECT_THAT(ids, Eq(std::vector{100, 101, 102})); } +TEST(InstructionTest, UniqueIds) { + IRContext context(nullptr); + Instruction inst1(&context); + Instruction inst2(&context); + EXPECT_NE(inst1.unique_id(), inst2.unique_id()); +} + +TEST(InstructionTest, CloneUniqueIdDifferent) { + IRContext context(nullptr); + Instruction inst(&context); + std::unique_ptr clone(inst.Clone(&context)); + EXPECT_EQ(inst.context(), clone->context()); + EXPECT_NE(inst.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, CloneDifferentContext) { + IRContext c1(nullptr); + IRContext c2(nullptr); + Instruction inst(&c1); + std::unique_ptr clone(inst.Clone(&c2)); + EXPECT_EQ(&c1, inst.context()); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(&c1, &c2); +} + +TEST(InstructionTest, CloneDifferentContextDifferentUniqueId) { + IRContext c1(nullptr); + IRContext c2(nullptr); + Instruction inst(&c1); + Instruction other(&c2); + std::unique_ptr clone(inst.Clone(&c2)); + EXPECT_EQ(&c2, clone->context()); + EXPECT_NE(other.unique_id(), clone->unique_id()); +} + +TEST(InstructionTest, EqualsEqualsOperator) { + IRContext context(nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr clone(i1.Clone(&context)); + EXPECT_TRUE(i1 == i1); + EXPECT_FALSE(i1 == i2); + EXPECT_FALSE(i1 == *clone); + EXPECT_FALSE(i2 == *clone); +} + +TEST(InstructionTest, LessThanOperator) { + IRContext context(nullptr); + Instruction i1(&context); + Instruction i2(&context); + std::unique_ptr clone(i1.Clone(&context)); + EXPECT_TRUE(i1 < i2); + EXPECT_TRUE(i1 < *clone); + EXPECT_TRUE(i2 < *clone); +} + } // anonymous namespace diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp index e2ace77..770c306 100644 --- a/test/opt/ir_context_test.cpp +++ b/test/opt/ir_context_test.cpp @@ -62,97 +62,97 @@ using IRContextTest = PassTest<::testing::Test>; TEST_F(IRContextTest, IndividualValidAfterBuild) { std::unique_ptr module(new ir::Module()); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); - EXPECT_TRUE(context.AreAnalysesValid(i)); + localContext.BuildInvalidAnalyses(i); + EXPECT_TRUE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, AllValidAfterBuild) { std::unique_ptr module = MakeUnique(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); Analysis built_analyses = IRContext::kAnalysisNone; for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); built_analyses |= i; } - EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); } TEST_F(IRContextTest, AllValidAfterPassNoChange) { std::unique_ptr module = MakeUnique(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); Analysis built_analyses = IRContext::kAnalysisNone; for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); built_analyses |= i; } DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithoutChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithoutChange); - EXPECT_TRUE(context.AreAnalysesValid(built_analyses)); + EXPECT_TRUE(localContext.AreAnalysesValid(built_analyses)); } TEST_F(IRContextTest, NoneValidAfterPassWithChange) { std::unique_ptr module = MakeUnique(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesNothing pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_FALSE(context.AreAnalysesValid(i)); + EXPECT_FALSE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, AllPreservedAfterPassWithChange) { std::unique_ptr module = MakeUnique(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesAll pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_TRUE(context.AreAnalysesValid(i)); + EXPECT_TRUE(localContext.AreAnalysesValid(i)); } } TEST_F(IRContextTest, PreserveFirstOnlyAfterPassWithChange) { std::unique_ptr module = MakeUnique(); - IRContext context(std::move(module), spvtools::MessageConsumer()); + IRContext localContext(std::move(module), spvtools::MessageConsumer()); for (Analysis i = IRContext::kAnalysisBegin; i < IRContext::kAnalysisEnd; i <<= 1) { - context.BuildInvalidAnalyses(i); + localContext.BuildInvalidAnalyses(i); } DummyPassPreservesFirst pass(opt::Pass::Status::SuccessWithChange); - opt::Pass::Status s = pass.Run(&context); + opt::Pass::Status s = pass.Run(&localContext); EXPECT_EQ(s, opt::Pass::Status::SuccessWithChange); - EXPECT_TRUE(context.AreAnalysesValid(IRContext::kAnalysisBegin)); + EXPECT_TRUE(localContext.AreAnalysesValid(IRContext::kAnalysisBegin)); for (Analysis i = IRContext::kAnalysisBegin << 1; i < IRContext::kAnalysisEnd; i <<= 1) { - EXPECT_FALSE(context.AreAnalysesValid(i)); + EXPECT_FALSE(localContext.AreAnalysesValid(i)); } } @@ -178,25 +178,31 @@ TEST_F(IRContextTest, KillMemberName) { OpFunctionEnd )"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text); - ir::IRContext context(std::move(module), spvtools::MessageConsumer()); // Build the decoration manager. - context.get_decoration_mgr(); + context->get_decoration_mgr(); // Delete the OpTypeStruct. Should delete the OpName, OpMemberName, and // OpMemberDecorate associated with it. - context.KillDef(3); + context->KillDef(3); // Make sure all of the name are removed. - for (auto& inst : context.debugs2()) { + for (auto& inst : context->debugs2()) { EXPECT_EQ(inst.opcode(), SpvOpNop); } // Make sure all of the decorations are removed. - for (auto& inst : context.annotations()) { + for (auto& inst : context->annotations()) { EXPECT_EQ(inst.opcode(), SpvOpNop); } } + +TEST_F(IRContextTest, TakeNextUniqueIdIncrementing) { + const uint32_t NUM_TESTS = 1000; + IRContext localContext(nullptr); + for (uint32_t i = 1; i < NUM_TESTS; ++i) + EXPECT_EQ(i, localContext.TakeNextUniqueId()); +} } // anonymous namespace diff --git a/test/opt/ir_loader_test.cpp b/test/opt/ir_loader_test.cpp index b61f7cb..ae46df9 100644 --- a/test/opt/ir_loader_test.cpp +++ b/test/opt/ir_loader_test.cpp @@ -14,9 +14,11 @@ #include #include +#include #include "message.h" #include "opt/build_module.h" +#include "opt/ir_context.h" #include "spirv-tools/libspirv.hpp" namespace { @@ -25,12 +27,12 @@ using namespace spvtools; void DoRoundTripCheck(const std::string& text) { SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - ASSERT_NE(nullptr, module) << "Failed to assemble\n" << text; + ASSERT_NE(nullptr, context) << "Failed to assemble\n" << text; std::vector binary; - module->ToBinary(&binary, /* skip_nop = */ false); + context->module()->ToBinary(&binary, /* skip_nop = */ false); std::string disassembled_text; EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); @@ -212,17 +214,17 @@ TEST(IrBuilder, OpUndefOutsideFunction) { // clang-format on SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - ASSERT_NE(nullptr, module); + ASSERT_NE(nullptr, context); const auto opundef_count = std::count_if( - module->types_values_begin(), module->types_values_end(), + context->module()->types_values_begin(), context->module()->types_values_end(), [](const ir::Instruction& inst) { return inst.opcode() == SpvOpUndef; }); EXPECT_EQ(3, opundef_count); std::vector binary; - module->ToBinary(&binary, /* skip_nop = */ false); + context->module()->ToBinary(&binary, /* skip_nop = */ false); std::string disassembled_text; EXPECT_TRUE(t.Disassemble(binary, &disassembled_text)); @@ -322,9 +324,9 @@ void DoErrorMessageCheck(const std::string& assembly, }; SpirvTools t(SPV_ENV_UNIVERSAL_1_1); - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, std::move(consumer), assembly); - EXPECT_EQ(nullptr, module); + EXPECT_EQ(nullptr, context); } TEST(IrBuilder, FunctionInsideFunction) { @@ -378,4 +380,69 @@ TEST(IrBuilder, NotAllowedInstAppearingInFunction) { "block"); } +TEST(IrBuilder, UniqueIds) { + const std::string text = + // clang-format off + "OpCapability Shader\n" + "%1 = OpExtInstImport \"GLSL.std.450\"\n" + "OpMemoryModel Logical GLSL450\n" + "OpEntryPoint Vertex %main \"main\"\n" + "OpSource ESSL 310\n" + "OpName %main \"main\"\n" + "OpName %f_ \"f(\"\n" + "OpName %gv1 \"gv1\"\n" + "OpName %gv2 \"gv2\"\n" + "OpName %lv1 \"lv1\"\n" + "OpName %lv2 \"lv2\"\n" + "OpName %lv1_0 \"lv1\"\n" + "%void = OpTypeVoid\n" + "%10 = OpTypeFunction %void\n" + "%float = OpTypeFloat 32\n" + "%12 = OpTypeFunction %float\n" + "%_ptr_Private_float = OpTypePointer Private %float\n" + "%gv1 = OpVariable %_ptr_Private_float Private\n" + "%float_10 = OpConstant %float 10\n" + "%gv2 = OpVariable %_ptr_Private_float Private\n" + "%float_100 = OpConstant %float 100\n" + "%_ptr_Function_float = OpTypePointer Function %float\n" + "%main = OpFunction %void None %10\n" + "%17 = OpLabel\n" + "%lv1_0 = OpVariable %_ptr_Function_float Function\n" + "OpStore %gv1 %float_10\n" + "OpStore %gv2 %float_100\n" + "%18 = OpLoad %float %gv1\n" + "%19 = OpLoad %float %gv2\n" + "%20 = OpFSub %float %18 %19\n" + "OpStore %lv1_0 %20\n" + "OpReturn\n" + "OpFunctionEnd\n" + "%f_ = OpFunction %float None %12\n" + "%21 = OpLabel\n" + "%lv1 = OpVariable %_ptr_Function_float Function\n" + "%lv2 = OpVariable %_ptr_Function_float Function\n" + "%22 = OpLoad %float %gv1\n" + "%23 = OpLoad %float %gv2\n" + "%24 = OpFAdd %float %22 %23\n" + "OpStore %lv1 %24\n" + "%25 = OpLoad %float %gv1\n" + "%26 = OpLoad %float %gv2\n" + "%27 = OpFMul %float %25 %26\n" + "OpStore %lv2 %27\n" + "%28 = OpLoad %float %lv1\n" + "%29 = OpLoad %float %lv2\n" + "%30 = OpFDiv %float %28 %29\n" + "OpReturnValue %30\n" + "OpFunctionEnd\n"; + // clang-format on + + std::unique_ptr context = + BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); + ASSERT_NE(nullptr, context); + + std::unordered_set ids; + context->module()->ForEachInst([&ids](const ir::Instruction* inst) { + EXPECT_TRUE(ids.insert(inst->unique_id()).second); + }); +} + } // anonymous namespace diff --git a/test/opt/module_test.cpp b/test/opt/module_test.cpp index 622d920..4a434ed 100644 --- a/test/opt/module_test.cpp +++ b/test/opt/module_test.cpp @@ -26,6 +26,7 @@ namespace { +using spvtools::ir::IRContext; using spvtools::ir::Module; using spvtest::GetIdBound; using ::testing::Eq; @@ -42,31 +43,31 @@ TEST(ModuleTest, SetIdBound) { EXPECT_EQ(102u, GetIdBound(m)); } -// Returns a module formed by assembling the given text, +// Returns an IRContext owning the module formed by assembling the given text, // then loading the result. -inline std::unique_ptr BuildModule(std::string text) { +inline std::unique_ptr BuildModule(std::string text) { return spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); } TEST(ModuleTest, ComputeIdBound) { // Emtpy module case. - EXPECT_EQ(1u, BuildModule("")->ComputeIdBound()); + EXPECT_EQ(1u, BuildModule("")->module()->ComputeIdBound()); // Sensitive to result id - EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->ComputeIdBound()); + EXPECT_EQ(2u, BuildModule("%void = OpTypeVoid")->module()->ComputeIdBound()); // Sensitive to type id - EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->ComputeIdBound()); + EXPECT_EQ(1000u, BuildModule("%a = OpTypeArray !999 3")->module()->ComputeIdBound()); // Sensitive to a regular Id parameter - EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->ComputeIdBound()); + EXPECT_EQ(2000u, BuildModule("OpDecorate !1999 0")->module()->ComputeIdBound()); // Sensitive to a scope Id parameter. EXPECT_EQ(3000u, BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " "OpMemoryBarrier !2999 %b\n") - ->ComputeIdBound()); + ->module()->ComputeIdBound()); // Sensitive to a semantics Id parameter EXPECT_EQ(4000u, BuildModule("%f = OpFunction %void None %fntype %a = OpLabel " "OpMemoryBarrier %b !3999\n") - ->ComputeIdBound()); + ->module()->ComputeIdBound()); } } // anonymous namespace diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h index 7ad7181..fdc4398 100644 --- a/test/opt/pass_fixture.h +++ b/test/opt/pass_fixture.h @@ -46,36 +46,35 @@ class PassTest : public TestT { public: PassTest() : consumer_(nullptr), + context_(nullptr), tools_(SPV_ENV_UNIVERSAL_1_1), manager_(new opt::PassManager()), assemble_options_(SpirvTools::kDefaultAssembleOption), disassemble_options_(SpirvTools::kDefaultDisassembleOption) {} // Runs the given |pass| on the binary assembled from the |original|. - // Returns a tuple of the optimized binary and the boolean value returned + // Returns a tuple of the optimized binary and the boolean value returned // from pass Process() function. std::tuple, opt::Pass::Status> OptimizeToBinary( opt::Pass* pass, const std::string& original, bool skip_nop) { - std::unique_ptr module = BuildModule( - SPV_ENV_UNIVERSAL_1_1, consumer_, original, assemble_options_); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << original << std::endl; - if (!module) { - return std::make_tuple(std::vector(), + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original, + assemble_options_)); + EXPECT_NE(nullptr, context()) << "Assembling failed for shader:\n" + << original << std::endl; + if (!context()) { + return std::make_tuple(std::vector(), opt::Pass::Status::Failure); } - ir::IRContext context(std::move(module), consumer()); - - const auto status = pass->Run(&context); + const auto status = pass->Run(context()); std::vector binary; - context.module()->ToBinary(&binary, skip_nop); + context()->module()->ToBinary(&binary, skip_nop); return std::make_tuple(binary, status); } // Runs a single pass of class |PassT| on the binary assembled from the - // |assembly|. Returns a tuple of the optimized binary and the boolean value + // |assembly|. Returns a tuple of the optimized binary and the boolean value // from the pass Process() function. template std::tuple, opt::Pass::Status> SinglePassRunToBinary( @@ -106,7 +105,7 @@ class PassTest : public TestT { // Runs a single pass of class |PassT| on the binary assembled from the // |original| assembly, and checks whether the optimized binary can be // disassembled to the |expected| assembly. Optionally will also validate - // the optimized binary. This does *not* involve pass manager. Callers + // the optimized binary. This does *not* involve pass manager. Callers // are suggested to use SCOPED_TRACE() for better messages. template void SinglePassRunAndCheck(const std::string& original, @@ -122,16 +121,16 @@ class PassTest : public TestT { status == opt::Pass::Status::SuccessWithoutChange); if (do_validation) { spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1; - spv_context context = spvContextCreate(target_env); + spv_context spvContext = spvContextCreate(target_env); spv_diagnostic diagnostic = nullptr; spv_const_binary_t binary = {optimized_bin.data(), optimized_bin.size()}; - spv_result_t error = spvValidate(context, &binary, &diagnostic); + spv_result_t error = spvValidate(spvContext, &binary, &diagnostic); EXPECT_EQ(error, 0); if (error != 0) spvDiagnosticPrint(diagnostic); spvDiagnosticDestroy(diagnostic); - spvContextDestroy(context); + spvContextDestroy(spvContext); } std::string optimized_asm; EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm, @@ -191,15 +190,14 @@ class PassTest : public TestT { void RunAndCheck(const std::string& original, const std::string& expected) { assert(manager_->NumPasses()); - std::unique_ptr module = BuildModule( - SPV_ENV_UNIVERSAL_1_1, nullptr, original, assemble_options_); - ASSERT_NE(nullptr, module); - ir::IRContext context(std::move(module), consumer()); + context_ = std::move(BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original, + assemble_options_)); + ASSERT_NE(nullptr, context()); - manager_->Run(&context); + manager_->Run(context()); std::vector binary; - context.module()->ToBinary(&binary, /* skip_nop = */ false); + context()->module()->ToBinary(&binary, /* skip_nop = */ false); std::string optimized; EXPECT_TRUE(tools_.Disassemble(binary, &optimized, @@ -216,8 +214,10 @@ class PassTest : public TestT { } MessageConsumer consumer() { return consumer_;} + ir::IRContext* context() { return context_.get(); } private: MessageConsumer consumer_; // Message consumer. + std::unique_ptr context_; // IR context SpirvTools tools_; // An instance for calling SPIRV-Tools functionalities. std::unique_ptr manager_; // The pass manager. uint32_t assemble_options_; diff --git a/test/opt/pass_manager_test.cpp b/test/opt/pass_manager_test.cpp index 43d7005..77ed38b 100644 --- a/test/opt/pass_manager_test.cpp +++ b/test/opt/pass_manager_test.cpp @@ -75,7 +75,7 @@ class AppendOpNopPass : public opt::Pass { public: const char* name() const override { return "AppendOpNop"; } Status Process(ir::IRContext* irContext) override { - irContext->AddDebug1Inst(MakeUnique()); + irContext->AddDebug1Inst(MakeUnique(irContext)); return Status::SuccessWithChange; } }; @@ -89,7 +89,7 @@ class AppendMultipleOpNopPass : public opt::Pass { const char* name() const override { return "AppendOpNop"; } Status Process(ir::IRContext* irContext) override { for (uint32_t i = 0; i < num_nop_; i++) { - irContext->AddDebug1Inst(MakeUnique()); + irContext->AddDebug1Inst(MakeUnique(irContext)); } return Status::SuccessWithChange; } @@ -103,7 +103,8 @@ class DuplicateInstPass : public opt::Pass { public: const char* name() const override { return "DuplicateInst"; } Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique(*(--irContext->debug1_end())); + auto inst = MakeUnique( + *(--irContext->debug1_end())->Clone(irContext)); irContext->AddDebug1Inst(std::move(inst)); return Status::SuccessWithChange; } @@ -140,7 +141,7 @@ class AppendTypeVoidInstPass : public opt::Pass { const char* name() const override { return "AppendTypeVoidInstPass"; } Status Process(ir::IRContext* irContext) override { - auto inst = MakeUnique(SpvOpTypeVoid, 0, result_id_, + auto inst = MakeUnique(irContext, SpvOpTypeVoid, 0, result_id_, std::vector{}); irContext->AddType(std::move(inst)); return Status::SuccessWithChange; diff --git a/test/opt/pass_test.cpp b/test/opt/pass_test.cpp index 8c62b28..5ff1a12 100644 --- a/test/opt/pass_test.cpp +++ b/test/opt/pass_test.cpp @@ -76,18 +76,18 @@ TEST_F(PassClassTest, BasicVisitFromEntryPoint) { )"; // clang-format on - std::unique_ptr module = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector processed; opt::Pass::ProcessFunction mark_visited = [&processed](ir::Function* fp) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessEntryPointCallTree(mark_visited, module.get()); + testPass.ProcessEntryPointCallTree(mark_visited, localContext->module()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11)); } @@ -132,12 +132,11 @@ TEST_F(PassClassTest, BasicVisitReachable) { )"; // clang-format on - std::unique_ptr module = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector processed; @@ -145,7 +144,7 @@ TEST_F(PassClassTest, BasicVisitReachable) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13)); } @@ -185,12 +184,11 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { )"; // clang-format on - std::unique_ptr module = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector processed; @@ -198,7 +196,7 @@ TEST_F(PassClassTest, BasicVisitOnlyOnce) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12)); } @@ -228,12 +226,11 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { )"; // clang-format on - std::unique_ptr module = + std::unique_ptr localContext = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); - EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" - << text << std::endl; - ir::IRContext context(std::move(module), consumer()); + EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n" + << text << std::endl; DummyPass testPass; std::vector processed; @@ -241,7 +238,7 @@ TEST_F(PassClassTest, BasicDontVisitExportedVariable) { processed.push_back(fp->result_id()); return false; }; - testPass.ProcessReachableCallTree(mark_visited, &context); + testPass.ProcessReachableCallTree(mark_visited, localContext.get()); EXPECT_THAT(processed, UnorderedElementsAre(10)); } } // namespace diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp index 17eb2a4..1c8f2db 100644 --- a/test/opt/type_manager_test.cpp +++ b/test/opt/type_manager_test.cpp @@ -88,9 +88,9 @@ TEST(TypeManager, TypeStrings) { {28, "named_barrier"}, }; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); EXPECT_EQ(type_id_strs.size(), manager.NumTypes()); EXPECT_EQ(2u, manager.NumForwardPointers()); @@ -118,9 +118,9 @@ TEST(TypeManager, DecorationOnStruct) { %struct4 = OpTypeStruct %u32 %f32 ; the same %struct7 = OpTypeStruct %f32 ; no decoration )"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(7u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -168,9 +168,9 @@ TEST(TypeManager, DecorationOnMember) { %struct7 = OpTypeStruct %u32 %f32 ; extra decoration on the struct %struct10 = OpTypeStruct %u32 %f32 ; no member decoration )"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(10u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -206,9 +206,9 @@ TEST(TypeManager, DecorationEmpty) { %struct2 = OpTypeStruct %f32 %u32 %struct5 = OpTypeStruct %f32 )"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(5u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -228,9 +228,9 @@ TEST(TypeManager, DecorationEmpty) { TEST(TypeManager, BeginEndForEmptyModule) { const std::string text = ""; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(0u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); @@ -245,9 +245,9 @@ TEST(TypeManager, BeginEnd) { %u32 = OpTypeInt 32 0 %f64 = OpTypeFloat 64 )"; - std::unique_ptr module = + std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text); - opt::analysis::TypeManager manager(nullptr, *module); + opt::analysis::TypeManager manager(nullptr, *context->module()); ASSERT_EQ(5u, manager.NumTypes()); ASSERT_EQ(0u, manager.NumForwardPointers()); -- 2.7.4