From 04fcc6674333914e9b8da9e4c103de638ef4c609 Mon Sep 17 00:00:00 2001 From: Greg Fischer Date: Thu, 10 Nov 2016 10:11:50 -0700 Subject: [PATCH] Add exhaustive function call inlining to spirv-opt Inlining is done for all functions designated as entry points. Add optional validation to test fixture method SinglePassRunAndCheck. --- README.md | 1 + include/spirv-tools/optimizer.hpp | 9 + source/opt/CMakeLists.txt | 3 + source/opt/basic_block.cpp | 42 ++ source/opt/basic_block.h | 20 + source/opt/build_module.cpp | 4 +- source/opt/function.cpp | 11 +- source/opt/function.h | 11 + source/opt/inline_pass.cpp | 440 ++++++++++++ source/opt/inline_pass.h | 143 ++++ source/opt/instruction.h | 45 ++ source/opt/iterator.h | 42 ++ source/opt/module.cpp | 26 +- source/opt/module.h | 14 + source/opt/optimizer.cpp | 4 + source/opt/passes.h | 1 + source/opt/types.cpp | 10 +- test/opt/CMakeLists.txt | 5 + test/opt/inline_test.cpp | 1371 +++++++++++++++++++++++++++++++++++++ test/opt/pass_fixture.h | 82 ++- tools/opt/opt.cpp | 4 + 21 files changed, 2246 insertions(+), 42 deletions(-) create mode 100644 source/opt/basic_block.cpp create mode 100644 source/opt/inline_pass.cpp create mode 100644 source/opt/inline_pass.h create mode 100644 test/opt/inline_test.cpp diff --git a/README.md b/README.md index 73c198a..d55f28f 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ Currently supported optimizations: * Fold `OpSpecConstantOp` and `OpSpecConstantComposite` * Unify constants * Eliminate dead constant +* Inline all function calls in entry points For the latest list with detailed documentation, please refer to [`include/spirv-tools/optimizer.hpp`](include/spirv-tools/optimizer.hpp). diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index f856e3b..68bb695 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -167,6 +167,15 @@ Optimizer::PassToken CreateUnifyConstantPass(); // OpSpecConstantOp. Optimizer::PassToken CreateEliminateDeadConstantPass(); +// Creates an inline pass. +// An inline pass exhaustively inlines all function calls in all functions +// designated as an entry point. The intent is to enable, albeit through +// brute force, analysis and optimization across function calls by subsequent +// passes. As the inlining is exhaustive, there is no attempt to optimize for +// size or runtime performance. Functions that are not designated as entry +// points are not changed. +Optimizer::PassToken CreateInlinePass(); + } // namespace spvtools #endif // SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index f1420e8..55949c8 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(SPIRV-Tools-opt function.h fold_spec_constant_op_and_composite_pass.h freeze_spec_constant_value_pass.h + inline_pass.h instruction.h ir_loader.h log.h @@ -35,12 +36,14 @@ add_library(SPIRV-Tools-opt type_manager.h unify_const_pass.h + basic_block.cpp build_module.cpp def_use_manager.cpp eliminate_dead_constant_pass.cpp function.cpp fold_spec_constant_op_and_composite_pass.cpp freeze_spec_constant_value_pass.cpp + inline_pass.cpp instruction.cpp ir_loader.cpp module.cpp diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp new file mode 100644 index 0000000..8ab2ec1 --- /dev/null +++ b/source/opt/basic_block.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "basic_block.h" + +namespace spvtools { +namespace ir { + +void BasicBlock::ForEachSuccessorLabel( + const std::function& f) { + const auto br = &*insts_.back(); + switch (br->opcode()) { + case SpvOpBranch: { + f(br->GetOperand(0).words[0]); + } break; + case SpvOpBranchConditional: + case SpvOpSwitch: { + bool is_first = true; + br->ForEachInId([&is_first, &f](const uint32_t* idp) { + if (!is_first) f(*idp); + is_first = false; + }); + } break; + default: + break; + } +} + +} // namespace ir +} // namespace spvtools + diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index 895ea44..05258a9 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -47,6 +47,9 @@ class BasicBlock { // The label starting this basic block. Instruction& Label() { return *label_; } + // Returns the id of the label at the top of this block + inline uint32_t label_id() const { return label_->result_id(); } + iterator begin() { return iterator(&insts_, insts_.begin()); } iterator end() { return iterator(&insts_, insts_.end()); } const_iterator cbegin() { return const_iterator(&insts_, insts_.cbegin()); } @@ -59,6 +62,15 @@ class BasicBlock { inline void ForEachInst(const std::function& f, bool run_on_debug_line_insts = false) const; + // Runs the given function |f| on each Phi instruction in this basic block, + // and optionally on the debug line instructions that might precede them. + inline void ForEachPhiInst(const std::function& f, + bool run_on_debug_line_insts = false); + + // Runs the given function |f| on each label id of each successor block + void ForEachSuccessorLabel( + const std::function& f); + private: // The enclosing function. Function* function_; @@ -92,6 +104,14 @@ inline void BasicBlock::ForEachInst( ->ForEachInst(f, run_on_debug_line_insts); } +inline void BasicBlock::ForEachPhiInst( + const std::function& f, bool run_on_debug_line_insts) { + for (auto& inst : insts_) { + if (inst->opcode() != SpvOpPhi) break; + inst->ForEachInst(f, run_on_debug_line_insts); + } +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/build_module.cpp b/source/opt/build_module.cpp index 2699a46..c1daea5 100644 --- a/source/opt/build_module.cpp +++ b/source/opt/build_module.cpp @@ -27,8 +27,8 @@ namespace { spv_result_t SetSpvHeader(void* builder, spv_endianness_t, uint32_t magic, uint32_t version, uint32_t generator, uint32_t id_bound, uint32_t reserved) { - reinterpret_cast(builder)->SetModuleHeader( - magic, version, generator, id_bound, reserved); + reinterpret_cast(builder) + ->SetModuleHeader(magic, version, generator, id_bound, reserved); return SPV_SUCCESS; }; diff --git a/source/opt/function.cpp b/source/opt/function.cpp index 9fc476b..7f7952c 100644 --- a/source/opt/function.cpp +++ b/source/opt/function.cpp @@ -36,13 +36,20 @@ void Function::ForEachInst(const std::function& f, ->ForEachInst(f, run_on_debug_line_insts); for (const auto& bb : blocks_) - static_cast(bb.get())->ForEachInst( - f, run_on_debug_line_insts); + static_cast(bb.get()) + ->ForEachInst(f, run_on_debug_line_insts); if (end_inst_) static_cast(end_inst_.get()) ->ForEachInst(f, run_on_debug_line_insts); } +void Function::ForEachParam(const std::function& f, + bool run_on_debug_line_insts) const { + for (const auto& param : params_) + static_cast(param.get()) + ->ForEachInst(f, run_on_debug_line_insts); +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/function.h b/source/opt/function.h index 12166fa..2e0674e 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -51,6 +51,12 @@ class Function { // Saves the given function end instruction. inline void SetFunctionEnd(std::unique_ptr end_inst); + // Returns function's id + inline uint32_t result_id() const { return def_inst_->result_id(); } + + // Returns function's type id + inline uint32_t type_id() const { return def_inst_->type_id(); } + iterator begin() { return iterator(&blocks_, blocks_.begin()); } iterator end() { return iterator(&blocks_, blocks_.end()); } const_iterator cbegin() { return const_iterator(&blocks_, blocks_.cbegin()); } @@ -63,6 +69,11 @@ class Function { void ForEachInst(const std::function& f, bool run_on_debug_line_insts = false) const; + // Runs the given function |f| on each parameter instruction in this function, + // and optionally on debug line instructions that might precede them. + void ForEachParam(const std::function& f, + bool run_on_debug_line_insts = false) const; + private: // The enclosing module. Module* module_; diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp new file mode 100644 index 0000000..26dd4a3 --- /dev/null +++ b/source/opt/inline_pass.cpp @@ -0,0 +1,440 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "inline_pass.h" + +// Indices of operands in SPIR-V instructions + +static const int kSpvEntryPointFunctionId = 1; +static const int kSpvFunctionCallFunctionId = 2; +static const int kSpvFunctionCallArgumentId = 3; +static const int kSpvReturnValueId = 0; +static const int kSpvTypePointerStorageClass = 1; +static const int kSpvTypePointerTypeId = 2; + +namespace spvtools { +namespace opt { + +uint32_t InlinePass::FindPointerToType(uint32_t type_id, + SpvStorageClass storage_class) { + ir::Module::inst_iterator type_itr = module_->types_values_begin(); + for (; type_itr != module_->types_values_end(); ++type_itr) { + const ir::Instruction* type_inst = &*type_itr; + if (type_inst->opcode() == SpvOpTypePointer && + type_inst->GetSingleWordOperand(kSpvTypePointerTypeId) == type_id && + type_inst->GetSingleWordOperand(kSpvTypePointerStorageClass) == + storage_class) + return type_inst->result_id(); + } + return 0; +} + +uint32_t InlinePass::AddPointerToType(uint32_t type_id, + SpvStorageClass storage_class) { + uint32_t resultId = TakeNextId(); + std::unique_ptr type_inst(new ir::Instruction( + SpvOpTypePointer, 0, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {uint32_t(storage_class)}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); + module_->AddType(std::move(type_inst)); + return resultId; +} + +void InlinePass::AddBranch(uint32_t label_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newBranch(new ir::Instruction( + SpvOpBranch, 0, 0, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}})); + (*block_ptr)->AddInstruction(std::move(newBranch)); +} + +void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newStore(new ir::Instruction( + SpvOpStore, 0, 0, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}})); + (*block_ptr)->AddInstruction(std::move(newStore)); +} + +void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id, + std::unique_ptr* block_ptr) { + std::unique_ptr newLoad(new ir::Instruction( + SpvOpLoad, type_id, resultId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}})); + (*block_ptr)->AddInstruction(std::move(newLoad)); +} + +std::unique_ptr InlinePass::NewLabel(uint32_t label_id) { + std::unique_ptr newLabel( + new ir::Instruction(SpvOpLabel, 0, label_id, {})); + return newLabel; +} + +void InlinePass::MapParams( + ir::Function* calleeFn, + ir::UptrVectorIterator call_inst_itr, + std::unordered_map* callee2caller) { + int param_idx = 0; + calleeFn->ForEachParam( + [&call_inst_itr, ¶m_idx, &callee2caller](const ir::Instruction* cpi) { + const uint32_t pid = cpi->result_id(); + (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand( + kSpvFunctionCallArgumentId + param_idx); + param_idx++; + }); +} + +void InlinePass::CloneAndMapLocals( + ir::Function* calleeFn, + std::vector>* new_vars, + std::unordered_map* callee2caller) { + auto callee_block_itr = calleeFn->begin(); + auto callee_var_itr = callee_block_itr->begin(); + while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) { + std::unique_ptr var_inst( + new ir::Instruction(*callee_var_itr)); + uint32_t newId = TakeNextId(); + var_inst->SetResultId(newId); + (*callee2caller)[callee_var_itr->result_id()] = newId; + new_vars->push_back(std::move(var_inst)); + callee_var_itr++; + } +} + +uint32_t InlinePass::CreateReturnVar( + ir::Function* calleeFn, + std::vector>* new_vars) { + uint32_t returnVarId = 0; + const uint32_t calleeTypeId = calleeFn->type_id(); + const ir::Instruction* calleeType = + def_use_mgr_->id_to_defs().find(calleeTypeId)->second; + if (calleeType->opcode() != SpvOpTypeVoid) { + // Find or create ptr to callee return type. + uint32_t returnVarTypeId = + FindPointerToType(calleeTypeId, SpvStorageClassFunction); + if (returnVarTypeId == 0) + returnVarTypeId = AddPointerToType(calleeTypeId, SpvStorageClassFunction); + // Add return var to new function scope variables. + returnVarId = TakeNextId(); + std::unique_ptr var_inst(new ir::Instruction( + SpvOpVariable, returnVarTypeId, returnVarId, + {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, + {SpvStorageClassFunction}}})); + new_vars->push_back(std::move(var_inst)); + } + return returnVarId; +} + +bool InlinePass::IsSameBlockOp(const ir::Instruction* inst) const { + return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage; +} + +void InlinePass::CloneSameBlockOps( + std::unique_ptr* inst, + std::unordered_map* postCallSB, + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr) { + (*inst) + ->ForEachInId([&postCallSB, &preCallSB, &block_ptr, this](uint32_t* iid) { + const auto mapItr = (*postCallSB).find(*iid); + if (mapItr == (*postCallSB).end()) { + const auto mapItr2 = (*preCallSB).find(*iid); + if (mapItr2 != (*preCallSB).end()) { + // Clone pre-call same-block ops, map result id. + const ir::Instruction* inInst = mapItr2->second; + std::unique_ptr sb_inst( + new ir::Instruction(*inInst)); + CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr); + const uint32_t rid = sb_inst->result_id(); + const uint32_t nid = this->TakeNextId(); + sb_inst->SetResultId(nid); + (*postCallSB)[rid] = nid; + *iid = nid; + (*block_ptr)->AddInstruction(std::move(sb_inst)); + } + } else { + // Reset same-block op operand. + *iid = mapItr->second; + } + }); +} + +void InlinePass::GenInlineCode( + std::vector>* new_blocks, + std::vector>* new_vars, + ir::UptrVectorIterator call_inst_itr, + ir::UptrVectorIterator call_block_itr) { + // Map from all ids in the callee to their equivalent id in the caller + // as callee instructions are copied into caller. + std::unordered_map callee2caller; + // Pre-call same-block insts + std::unordered_map preCallSB; + // Post-call same-block op ids + std::unordered_map postCallSB; + + ir::Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand( + kSpvFunctionCallFunctionId)]; + + // Map parameters to actual arguments. + MapParams(calleeFn, call_inst_itr, &callee2caller); + + // Define caller local variables for all callee variables and create map to + // them. + CloneAndMapLocals(calleeFn, new_vars, &callee2caller); + + // Create return var if needed. + uint32_t returnVarId = CreateReturnVar(calleeFn, new_vars); + + // Clone and map callee code. Copy caller block code to beginning of + // first block and end of last block. + bool prevInstWasReturn = false; + uint32_t returnLabelId = 0; + bool multiBlocks = false; + const uint32_t calleeTypeId = calleeFn->type_id(); + std::unique_ptr new_blk_ptr; + calleeFn->ForEachInst([&new_blocks, &callee2caller, &call_block_itr, + &call_inst_itr, &new_blk_ptr, &prevInstWasReturn, + &returnLabelId, &returnVarId, &calleeTypeId, + &multiBlocks, &postCallSB, &preCallSB, this]( + const ir::Instruction* cpi) { + switch (cpi->opcode()) { + case SpvOpFunction: + case SpvOpFunctionParameter: + case SpvOpVariable: + // Already processed + break; + case SpvOpLabel: { + // If previous instruction was early return, insert branch + // instruction to return block. + if (prevInstWasReturn) { + if (returnLabelId == 0) returnLabelId = this->TakeNextId(); + AddBranch(returnLabelId, &new_blk_ptr); + prevInstWasReturn = false; + } + // Finish current block (if it exists) and get label for next block. + uint32_t labelId; + bool firstBlock = false; + if (new_blk_ptr != nullptr) { + new_blocks->push_back(std::move(new_blk_ptr)); + // If result id is already mapped, use it, otherwise get a new + // one. + const uint32_t rid = cpi->result_id(); + const auto mapItr = callee2caller.find(rid); + labelId = (mapItr != callee2caller.end()) ? mapItr->second + : this->TakeNextId(); + } else { + // First block needs to use label of original block + // but map callee label in case of phi reference. + labelId = call_block_itr->label_id(); + callee2caller[cpi->result_id()] = labelId; + firstBlock = true; + } + // Create first/next block. + new_blk_ptr.reset(new ir::BasicBlock(NewLabel(labelId))); + if (firstBlock) { + // Copy contents of original caller block up to call instruction. + for (auto cii = call_block_itr->begin(); cii != call_inst_itr; + cii++) { + std::unique_ptr cp_inst(new ir::Instruction(*cii)); + // Remember same-block ops for possible regeneration. + if (IsSameBlockOp(&*cp_inst)) { + auto* sb_inst_ptr = cp_inst.get(); + preCallSB[cp_inst->result_id()] = sb_inst_ptr; + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } + } else { + multiBlocks = true; + } + } break; + case SpvOpReturnValue: { + // Store return value to return variable. + assert(returnVarId != 0); + uint32_t valId = cpi->GetInOperand(kSpvReturnValueId).words[0]; + const auto mapItr = callee2caller.find(valId); + if (mapItr != callee2caller.end()) { + valId = mapItr->second; + } + AddStore(returnVarId, valId, &new_blk_ptr); + + // Remember we saw a return; if followed by a label, will need to + // insert branch. + prevInstWasReturn = true; + } break; + case SpvOpReturn: { + // Remember we saw a return; if followed by a label, will need to + // insert branch. + prevInstWasReturn = true; + } break; + case SpvOpFunctionEnd: { + // If there was an early return, create return label/block. + // If previous instruction was return, insert branch instruction + // to return block. + if (returnLabelId != 0) { + if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr); + new_blocks->push_back(std::move(new_blk_ptr)); + new_blk_ptr.reset(new ir::BasicBlock(NewLabel(returnLabelId))); + multiBlocks = true; + } + // Load return value into result id of call, if it exists. + if (returnVarId != 0) { + const uint32_t resId = call_inst_itr->result_id(); + assert(resId != 0); + AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr); + } + // Copy remaining instructions from caller block. + auto cii = call_inst_itr; + for (cii++; cii != call_block_itr->end(); cii++) { + std::unique_ptr cp_inst(new ir::Instruction(*cii)); + // If multiple blocks generated, regenerate any same-block + // instruction that has not been seen in this last block. + if (multiBlocks) { + CloneSameBlockOps(&cp_inst, &postCallSB, &preCallSB, &new_blk_ptr); + // Remember same-block ops in this block. + if (IsSameBlockOp(&*cp_inst)) { + const uint32_t rid = cp_inst->result_id(); + postCallSB[rid] = rid; + } + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } + // Finalize inline code. + new_blocks->push_back(std::move(new_blk_ptr)); + } break; + default: { + // Copy callee instruction and remap all input Ids. + std::unique_ptr cp_inst(new ir::Instruction(*cpi)); + cp_inst->ForEachInId([&callee2caller, &cpi, this](uint32_t* iid) { + const auto mapItr = callee2caller.find(*iid); + if (mapItr != callee2caller.end()) { + *iid = mapItr->second; + } else if (cpi->has_labels()) { + const ir::Instruction* inst = + def_use_mgr_->id_to_defs().find(*iid)->second; + if (inst->opcode() == SpvOpLabel) { + // Forward label reference. Allocate a new label id, map it, + // use it and check for it at each label. + const uint32_t nid = this->TakeNextId(); + callee2caller[*iid] = nid; + *iid = nid; + } + } + }); + // Map and reset result id. + const uint32_t rid = cp_inst->result_id(); + if (rid != 0) { + const uint32_t nid = this->TakeNextId(); + callee2caller[rid] = nid; + cp_inst->SetResultId(nid); + } + new_blk_ptr->AddInstruction(std::move(cp_inst)); + } break; + } + }); + // Update block map given replacement blocks. + for (auto& blk : *new_blocks) { + id2block_[blk->label_id()] = &*blk; + } +} + +bool InlinePass::Inline(ir::Function* func) { + bool modified = false; + // Using block iterators here because of block erasures and insertions. + for (auto bi = func->begin(); bi != func->end(); bi++) { + for (auto ii = bi->begin(); ii != bi->end();) { + if (ii->opcode() == SpvOp::SpvOpFunctionCall) { + // Inline call. + std::vector> newBlocks; + std::vector> newVars; + GenInlineCode(&newBlocks, &newVars, ii, bi); + // Update phi functions in successor blocks if call block + // is replaced with more than one block. + if (newBlocks.size() > 1) { + const auto firstBlk = newBlocks.begin(); + const auto lastBlk = newBlocks.end() - 1; + const uint32_t firstId = (*firstBlk)->label_id(); + const uint32_t lastId = (*lastBlk)->label_id(); + (*lastBlk) + ->ForEachSuccessorLabel([&firstId, &lastId, this](uint32_t succ) { + ir::BasicBlock* sbp = this->id2block_[succ]; + sbp->ForEachPhiInst([&firstId, &lastId](ir::Instruction* phi) { + phi->ForEachInId([&firstId, &lastId](uint32_t* id) { + if (*id == firstId) *id = lastId; + }); + }); + }); + } + // Replace old calling block with new block(s). + bi = bi.Erase(); + bi = bi.InsertBefore(&newBlocks); + // Insert new function variables. + if (newVars.size() > 0) func->begin()->begin().InsertBefore(&newVars); + // Restart inlining at beginning of calling block. + ii = bi->begin(); + modified = true; + } else { + ii++; + } + } + } + return modified; +} + +void InlinePass::Initialize(ir::Module* module) { + def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module)); + + // Initialize next unused Id. + next_id_ = module->id_bound(); + + // Save module. + module_ = module; + + // Initialize function and block maps. + id2function_.clear(); + id2block_.clear(); + for (auto& fn : *module_) { + id2function_[fn.result_id()] = &fn; + for (auto& blk : fn) { + id2block_[blk.label_id()] = &blk; + } + } +}; + +Pass::Status InlinePass::ProcessImpl() { + // Do exhaustive inlining on each entry point function in module + bool modified = false; + for (auto& e : module_->entry_points()) { + ir::Function* fn = + id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)]; + modified = modified || Inline(fn); + } + + FinalizeNextId(module_); + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +InlinePass::InlinePass() + : module_(nullptr), def_use_mgr_(nullptr), next_id_(0) {} + +Pass::Status InlinePass::Process(ir::Module* module) { + Initialize(module); + return ProcessImpl(); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h new file mode 100644 index 0000000..523e193 --- /dev/null +++ b/source/opt/inline_pass.h @@ -0,0 +1,143 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef LIBSPIRV_OPT_INLINE_PASS_H_ +#define LIBSPIRV_OPT_INLINE_PASS_H_ + +#include +#include +#include +#include + +#include "def_use_manager.h" +#include "module.h" +#include "pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class InlinePass : public Pass { + public: + InlinePass(); + Status Process(ir::Module*) override; + + const char* name() const override { return "inline"; } + + private: + // Return the next available Id and increment it. + inline uint32_t TakeNextId() { return next_id_++; } + + // Write the next available Id back to the module. + inline void FinalizeNextId(ir::Module* module) { + module->SetIdBound(next_id_); + } + + // Find pointer to type and storage in module, return its resultId, + // 0 if not found. TODO(greg-lunarg): Move this into type manager. + uint32_t FindPointerToType(uint32_t type_id, SpvStorageClass storage_class); + + // Add pointer to type to module and return resultId. + uint32_t AddPointerToType(uint32_t type_id, SpvStorageClass storage_class); + + // Add unconditional branch to labelId to end of block block_ptr. + void AddBranch(uint32_t labelId, std::unique_ptr* block_ptr); + + // Add store of valId to ptrId to end of block block_ptr. + void AddStore(uint32_t ptrId, uint32_t valId, + std::unique_ptr* block_ptr); + + // Add load of ptrId into resultId to end of block block_ptr. + void AddLoad(uint32_t typeId, uint32_t resultId, uint32_t ptrId, + std::unique_ptr* block_ptr); + + // Return new label. + std::unique_ptr NewLabel(uint32_t label_id); + + // Map callee params to caller args + void MapParams(ir::Function* calleeFn, + ir::UptrVectorIterator call_inst_itr, + std::unordered_map* callee2caller); + + // Clone and map callee locals + void CloneAndMapLocals( + ir::Function* calleeFn, + std::vector>* new_vars, + std::unordered_map* callee2caller); + + // Create return variable for callee clone code if needed. Return id + // if created, otherwise 0. + uint32_t CreateReturnVar( + ir::Function* calleeFn, + std::vector>* new_vars); + + // Return true if instruction must be in the same block that its result + // is used. + bool IsSameBlockOp(const ir::Instruction* inst) const; + + // Clone operands which must be in same block as consumer instructions. + // Look in preCallSB for instructions that need cloning. Look in + // postCallSB for instructions already cloned. Add cloned instruction + // to postCallSB. + void CloneSameBlockOps( + std::unique_ptr* inst, + std::unordered_map* postCallSB, + std::unordered_map* preCallSB, + std::unique_ptr* block_ptr); + + // Return in new_blocks the result of inlining the call at call_inst_itr + // within its block at call_block_itr. The block at call_block_itr can + // just be replaced with the blocks in new_blocks. Any additional branches + // are avoided. Debug instructions are cloned along with their callee + // instructions. Early returns are replaced by a store to a local return + // variable and a branch to a (created) exit block where the local variable + // is returned. Formal parameters are trivially mapped to their actual + // parameters. Note that the first block in new_blocks retains the label + // of the original calling block. Also note that if an exit block is + // created, it is the last block of new_blocks. + // + // Also return in new_vars additional OpVariable instructions required by + // and to be inserted into the caller function after the block at + // call_block_itr is replaced with new_blocks. + void GenInlineCode(std::vector>* new_blocks, + std::vector>* new_vars, + ir::UptrVectorIterator call_inst_itr, + ir::UptrVectorIterator call_block_itr); + + // Exhaustively inline all function calls in func as well as in + // all code that is inlined into func. Return true if func is modified. + bool Inline(ir::Function* func); + + void Initialize(ir::Module* module); + Pass::Status ProcessImpl(); + + ir::Module* module_; + std::unique_ptr def_use_mgr_; + + // Map from function's result id to function. + std::unordered_map id2function_; + + // Map from block's label id to block. + std::unordered_map id2block_; + + // Next unused ID + uint32_t next_id_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // LIBSPIRV_OPT_INLINE_PASS_H_ diff --git a/source/opt/instruction.h b/source/opt/instruction.h index aaeef7d..2c3189c 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h @@ -20,6 +20,8 @@ #include #include +#include "operand.h" + #include "spirv-tools/libspirv.h" #include "spirv/1.1/spirv.h" @@ -135,6 +137,8 @@ class Instruction { inline void SetInOperand(uint32_t index, std::vector&& data); // Sets the result type id. inline void SetResultType(uint32_t ty_id); + // Sets the result id + inline void SetResultId(uint32_t res_id); // The following methods are similar to the above, but are for in operands. uint32_t NumInOperands() const { @@ -162,6 +166,13 @@ class Instruction { inline void ForEachInst(const std::function& f, bool run_on_debug_line_insts = false) const; + // Runs the given function |f| on all "in" operand ids + inline void ForEachInId(const std::function& f); + inline void ForEachInId(const std::function& f) const; + + // Returns true if any operands can be labels + inline bool has_labels() const; + // Pushes the binary segments for this instruction into the back of *|binary|. void ToBinaryWithoutAttachedDebugInsts(std::vector* binary) const; @@ -194,6 +205,13 @@ inline void Instruction::SetInOperand(uint32_t index, operands_[index + TypeResultIdCount()].words = std::move(data); } +inline void Instruction::SetResultId(uint32_t res_id) { + result_id_ = res_id; + auto ridx = (type_id_ != 0) ? 1 : 0; + assert(operands_[ridx].type == SPV_OPERAND_TYPE_RESULT_ID); + operands_[ridx].words = {res_id}; +} + inline void Instruction::SetResultType(uint32_t ty_id) { if (type_id_ != 0) { type_id_ = ty_id; @@ -228,6 +246,33 @@ inline void Instruction::ForEachInst( f(this); } +inline void Instruction::ForEachInId(const std::function& f) { + for (auto& opnd : operands_) + if (opnd.type == SPV_OPERAND_TYPE_ID) f(&opnd.words[0]); +} + +inline void Instruction::ForEachInId( + const std::function& f) const { + for (const auto& opnd : operands_) + if (opnd.type == SPV_OPERAND_TYPE_ID) f(&opnd.words[0]); +} + +inline bool Instruction::has_labels() const { + switch (opcode_) { + case SpvOpSelectionMerge: + case SpvOpBranch: + case SpvOpLoopMerge: + case SpvOpBranchConditional: + case SpvOpSwitch: + case SpvOpPhi: + return true; + break; + default: + break; + } + return false; +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/iterator.h b/source/opt/iterator.h index 0430638..d82c954 100644 --- a/source/opt/iterator.h +++ b/source/opt/iterator.h @@ -81,6 +81,24 @@ class UptrVectorIterator inline typename std::enable_if::type InsertBefore(Uptr value); + // Inserts the given |valueVector| to the position pointed to by this iterator + // and returns an iterator to the first newly inserted value. + // If the underlying vector changes capacity, all previous iterators will be + // invalidated. Otherwise, those previous iterators pointing to after the + // insertion point will be invalidated. + template + inline typename std::enable_if::type + InsertBefore(UptrVector* valueVector); + + // Erases the value at the position pointed to by this iterator + // and returns an iterator to the following value. + // If the underlying vector changes capacity, all previous iterators will be + // invalidated. Otherwise, those previous iterators pointing to after the + // erasure point will be invalidated. + template + inline typename std::enable_if::type + Erase(); + private: UptrVector* container_; // The container we are manipulating. UnderlyingIterator iterator_; // The raw iterator from the container. @@ -183,6 +201,30 @@ inline return UptrVectorIterator(container_, container_->begin() + index); } +template +template +inline + typename std::enable_if>::type + UptrVectorIterator::InsertBefore(UptrVector* values) { + const auto pos = iterator_ - container_->begin(); + const auto origsz = container_->size(); + container_->resize(origsz + values->size()); + std::move_backward(container_->begin() + pos, container_->begin() + origsz, + container_->end()); + std::move(values->begin(), values->end(), container_->begin() + pos); + return UptrVectorIterator(container_, container_->begin() + pos); +} + +template +template +inline + typename std::enable_if>::type + UptrVectorIterator::Erase() { + auto index = iterator_ - container_->begin(); + (void)container_->erase(iterator_); + return UptrVectorIterator(container_, container_->begin() + index); +} + } // namespace ir } // namespace spvtools diff --git a/source/opt/module.cpp b/source/opt/module.cpp index 9acb3a4..372b70c 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -76,9 +76,9 @@ void Module::ForEachInst(const std::function& f, void Module::ForEachInst(const std::function& f, bool run_on_debug_line_insts) const { -#define DELEGATE(i) \ - static_cast(i.get())->ForEachInst( \ - f, run_on_debug_line_insts) +#define DELEGATE(i) \ + static_cast(i.get()) \ + ->ForEachInst(f, run_on_debug_line_insts) for (auto& i : capabilities_) DELEGATE(i); for (auto& i : extensions_) DELEGATE(i); for (auto& i : ext_inst_imports_) DELEGATE(i); @@ -89,8 +89,8 @@ void Module::ForEachInst(const std::function& f, for (auto& i : annotations_) DELEGATE(i); for (auto& i : types_values_) DELEGATE(i); for (auto& i : functions_) { - static_cast(i.get())->ForEachInst(f, - run_on_debug_line_insts); + static_cast(i.get()) + ->ForEachInst(f, run_on_debug_line_insts); } #undef DELEGATE } @@ -112,15 +112,13 @@ void Module::ToBinary(std::vector* binary, bool skip_nop) const { uint32_t Module::ComputeIdBound() const { uint32_t highest = 0; - ForEachInst( - [&highest](const Instruction* inst) { - for (const auto& operand : *inst) { - if (spvIsIdType(operand.type)) { - highest = std::max(highest, operand.words[0]); - } - } - }, - true /* scan debug line insts as well */); + ForEachInst([&highest](const Instruction* inst) { + for (const auto& operand : *inst) { + if (spvIsIdType(operand.type)) { + highest = std::max(highest, operand.words[0]); + } + } + }, true /* scan debug line insts as well */); return highest + 1; } diff --git a/source/opt/module.h b/source/opt/module.h index 87410e1..5d98ddc 100644 --- a/source/opt/module.h +++ b/source/opt/module.h @@ -86,6 +86,8 @@ class Module { std::vector GetConstants(); std::vector GetConstants() const; + inline uint32_t id_bound() const { return header_.bound; } + // Iterators for debug instructions (excluding OpLine & OpNoLine) contained in // this module. inline inst_iterator debug_begin(); @@ -93,6 +95,10 @@ class Module { inline IteratorRange debugs(); inline IteratorRange debugs() const; + // Iterators for entry point instructions contained in this module + inline IteratorRange entry_points(); + inline IteratorRange entry_points() const; + // Clears all debug instructions (excluding OpLine & OpNoLine). void debug_clear() { debugs_.clear(); } @@ -204,6 +210,14 @@ inline IteratorRange Module::debugs() const { return make_const_range(debugs_); } +inline IteratorRange Module::entry_points() { + return make_range(entry_points_); +} + +inline IteratorRange Module::entry_points() const { + return make_const_range(entry_points_); +} + inline IteratorRange Module::annotations() { return make_range(annotations_); } diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index abe5b93..f6dc1cd 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -118,4 +118,8 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() { MakeUnique()); } +Optimizer::PassToken CreateInlinePass() { + return MakeUnique(MakeUnique()); +} + } // namespace spvtools diff --git a/source/opt/passes.h b/source/opt/passes.h index ba1b233..26c55f9 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -19,6 +19,7 @@ #include "eliminate_dead_constant_pass.h" #include "fold_spec_constant_op_and_composite_pass.h" +#include "inline_pass.h" #include "freeze_spec_constant_value_pass.h" #include "null_pass.h" #include "set_spec_constant_default_value_pass.h" diff --git a/source/opt/types.cpp b/source/opt/types.cpp index 8e6d32a..285c148 100644 --- a/source/opt/types.cpp +++ b/source/opt/types.cpp @@ -34,7 +34,7 @@ bool CompareTwoVectors(const U32VecVec a, const U32VecVec b) { if (size == 0) return true; if (size == 1) return a.front() == b.front(); - std::vector *> a_ptrs, b_ptrs; + std::vector*> a_ptrs, b_ptrs; a_ptrs.reserve(size); a_ptrs.reserve(size); for (uint32_t i = 0; i < size; ++i) { @@ -42,10 +42,10 @@ bool CompareTwoVectors(const U32VecVec a, const U32VecVec b) { b_ptrs.push_back(&b[i]); } - const auto cmp = [](const std::vector* m, - const std::vector* n) { - return m->front() < n->front(); - }; + const auto cmp = + [](const std::vector* m, const std::vector* n) { + return m->front() < n->front(); + }; std::sort(a_ptrs.begin(), a_ptrs.end(), cmp); std::sort(b_ptrs.begin(), b_ptrs.end(), cmp); diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 8e98793..2f1f482 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -38,6 +38,11 @@ add_spvtools_unittest(TARGET pass_freeze_spec_const LIBS SPIRV-Tools-opt ) +add_spvtools_unittest(TARGET pass_inline + SRCS inline_test.cpp pass_utils.cpp + LIBS SPIRV-Tools-opt +) + add_spvtools_unittest(TARGET pass_eliminate_dead_const SRCS eliminate_dead_const_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp new file mode 100644 index 0000000..a1a89fa --- /dev/null +++ b/test/opt/inline_test.cpp @@ -0,0 +1,1371 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pass_fixture.h" +#include "pass_utils.h" + +template std::vector concat(const std::vector &a, const std::vector &b) { + std::vector ret = std::vector(); + std::copy(a.begin(), a.end(), back_inserter(ret)); + std::copy(b.begin(), b.end(), back_inserter(ret)); + return ret; +} + +namespace { + +using namespace spvtools; + +using InlineTest = PassTest<::testing::Test>; + +TEST_F(InlineTest, Simple) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // return bar.x + bar.y; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%10 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%14 = OpTypeFunction %float %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint_1 = OpConstant %uint 1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %14", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%26 = OpLabel", + "%27 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%28 = OpLoad %float %27", + "%29 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%30 = OpLoad %float %29", + "%31 = OpFAdd %float %28 %30", + "OpReturnValue %31", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %10", + "%21 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%22 = OpLoad %v4float %BaseColor", + "OpStore %param %22", + "%23 = OpFunctionCall %float %foo_vf4_ %param", + "%24 = OpCompositeConstruct %v4float %23 %23 %23 %23", + "OpStore %color %24", + "%25 = OpLoad %v4float %color", + "OpStore %gl_FragColor %25", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %10", + "%21 = OpLabel", + "%32 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%22 = OpLoad %v4float %BaseColor", + "OpStore %param %22", + "%33 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%34 = OpLoad %float %33", + "%35 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%36 = OpLoad %float %35", + "%37 = OpFAdd %float %34 %36", + "OpStore %32 %37", + "%23 = OpLoad %float %32", + "%24 = OpCompositeConstruct %v4float %23 %23 %23 %23", + "OpStore %color %24", + "%25 = OpLoad %v4float %color", + "OpStore %gl_FragColor %25", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, Nested) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo2(float f, float f2) + // { + // return f * f2; + // } + // + // float foo(vec4 bar) + // { + // return foo2(bar.x + bar.y, bar.z); + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo2_f1_f1_ \"foo2(f1;f1;\"", + "OpName %f \"f\"", + "OpName %f2 \"f2\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %param \"param\"", + "OpName %param_0 \"param\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param_1 \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%15 = OpTypeFunction %void", + "%float = OpTypeFloat 32", +"%_ptr_Function_float = OpTypePointer Function %float", + "%18 = OpTypeFunction %float %_ptr_Function_float %_ptr_Function_float", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%21 = OpTypeFunction %float %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off +"%foo2_f1_f1_ = OpFunction %float None %18", + "%f = OpFunctionParameter %_ptr_Function_float", + "%f2 = OpFunctionParameter %_ptr_Function_float", + "%33 = OpLabel", + "%34 = OpLoad %float %f", + "%35 = OpLoad %float %f2", + "%36 = OpFMul %float %34 %35", + "OpReturnValue %36", + "OpFunctionEnd", + "%foo_vf4_ = OpFunction %float None %21", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%37 = OpLabel", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%38 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%39 = OpLoad %float %38", + "%40 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%41 = OpLoad %float %40", + "%42 = OpFAdd %float %39 %41", + "OpStore %param %42", + "%43 = OpAccessChain %_ptr_Function_float %bar %uint_2", + "%44 = OpLoad %float %43", + "OpStore %param_0 %44", + "%45 = OpFunctionCall %float %foo2_f1_f1_ %param %param_0", + "OpReturnValue %45", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %15", + "%28 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param_1 = OpVariable %_ptr_Function_v4float Function", + "%29 = OpLoad %v4float %BaseColor", + "OpStore %param_1 %29", + "%30 = OpFunctionCall %float %foo_vf4_ %param_1", + "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", + "OpStore %color %31", + "%32 = OpLoad %v4float %color", + "OpStore %gl_FragColor %32", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %15", + "%28 = OpLabel", + "%57 = OpVariable %_ptr_Function_float Function", + "%46 = OpVariable %_ptr_Function_float Function", + "%47 = OpVariable %_ptr_Function_float Function", + "%48 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param_1 = OpVariable %_ptr_Function_v4float Function", + "%29 = OpLoad %v4float %BaseColor", + "OpStore %param_1 %29", + "%49 = OpAccessChain %_ptr_Function_float %param_1 %uint_0", + "%50 = OpLoad %float %49", + "%51 = OpAccessChain %_ptr_Function_float %param_1 %uint_1", + "%52 = OpLoad %float %51", + "%53 = OpFAdd %float %50 %52", + "OpStore %46 %53", + "%54 = OpAccessChain %_ptr_Function_float %param_1 %uint_2", + "%55 = OpLoad %float %54", + "OpStore %47 %55", + "%58 = OpLoad %float %46", + "%59 = OpLoad %float %47", + "%60 = OpFMul %float %58 %59", + "OpStore %57 %60", + "%56 = OpLoad %float %57", + "OpStore %48 %56", + "%30 = OpLoad %float %48", + "%31 = OpCompositeConstruct %v4float %30 %30 %30 %30", + "OpStore %color %31", + "%32 = OpLoad %v4float %color", + "OpStore %gl_FragColor %32", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, InOutParameter) { + // #version 400 + // + // in vec4 Basecolor; + // + // void foo(inout vec4 bar) + // { + // bar.z = bar.x + bar.y; + // } + // + // void main() + // { + // vec4 b = Basecolor; + // foo(b); + // vec4 color = vec4(b.z); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %Basecolor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 400", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %b \"b\"", + "OpName %Basecolor \"Basecolor\"", + "OpName %param \"param\"", + "OpName %color \"color\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%11 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%15 = OpTypeFunction %void %_ptr_Function_v4float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint_1 = OpConstant %uint 1", + "%uint_2 = OpConstant %uint 2", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%Basecolor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %void None %15", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%32 = OpLabel", + "%33 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%34 = OpLoad %float %33", + "%35 = OpAccessChain %_ptr_Function_float %bar %uint_1", + "%36 = OpLoad %float %35", + "%37 = OpFAdd %float %34 %36", + "%38 = OpAccessChain %_ptr_Function_float %bar %uint_2", + "OpStore %38 %37", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%b = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %Basecolor", + "OpStore %b %24", + "%25 = OpLoad %v4float %b", + "OpStore %param %25", + "%26 = OpFunctionCall %void %foo_vf4_ %param", + "%27 = OpLoad %v4float %param", + "OpStore %b %27", + "%28 = OpAccessChain %_ptr_Function_float %b %uint_2", + "%29 = OpLoad %float %28", + "%30 = OpCompositeConstruct %v4float %29 %29 %29 %29", + "OpStore %color %30", + "%31 = OpLoad %v4float %color", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%b = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %Basecolor", + "OpStore %b %24", + "%25 = OpLoad %v4float %b", + "OpStore %param %25", + "%39 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%40 = OpLoad %float %39", + "%41 = OpAccessChain %_ptr_Function_float %param %uint_1", + "%42 = OpLoad %float %41", + "%43 = OpFAdd %float %40 %42", + "%44 = OpAccessChain %_ptr_Function_float %param %uint_2", + "OpStore %44 %43", + "%27 = OpLoad %v4float %param", + "OpStore %b %27", + "%28 = OpAccessChain %_ptr_Function_float %b %uint_2", + "%29 = OpLoad %float %28", + "%30 = OpCompositeConstruct %v4float %29 %29 %29 %29", + "OpStore %color %30", + "%31 = OpLoad %v4float %color", + "OpStore %gl_FragColor %31", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, BranchInCallee) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%11 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%15 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %15", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%28 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%29 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%30 = OpLoad %float %29", + "OpStore %r %30", + "%31 = OpLoad %float %r", + "%32 = OpFOrdLessThan %bool %31 %float_0", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpLoad %float %r", + "%36 = OpFNegate %float %35", + "OpStore %r %36", + "OpBranch %33", + "%33 = OpLabel", + "%37 = OpLoad %float %r", + "OpReturnValue %37", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %BaseColor", + "OpStore %param %24", + "%25 = OpFunctionCall %float %foo_vf4_ %param", + "%26 = OpCompositeConstruct %v4float %25 %25 %25 %25", + "OpStore %color %26", + "%27 = OpLoad %v4float %color", + "OpStore %gl_FragColor %27", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %11", + "%23 = OpLabel", + "%38 = OpVariable %_ptr_Function_float Function", + "%39 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%24 = OpLoad %v4float %BaseColor", + "OpStore %param %24", + "%40 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%41 = OpLoad %float %40", + "OpStore %38 %41", + "%42 = OpLoad %float %38", + "%43 = OpFOrdLessThan %bool %42 %float_0", + "OpSelectionMerge %44 None", + "OpBranchConditional %43 %45 %44", + "%45 = OpLabel", + "%46 = OpLoad %float %38", + "%47 = OpFNegate %float %46", + "OpStore %38 %47", + "OpBranch %44", + "%44 = OpLabel", + "%48 = OpLoad %float %38", + "OpStore %39 %48", + "%25 = OpLoad %float %39", + "%26 = OpCompositeConstruct %v4float %25 %25 %25 %25", + "OpStore %color %26", + "%27 = OpLoad %v4float %color", + "OpStore %gl_FragColor %27", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, PhiAfterCall) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(float bar) + // { + // float r = bar; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color = BaseColor; + // if (foo(color.x) > 2.0 && foo(color.y) > 2.0) + // color = vec4(0.0); + // gl_FragColor = color; + // } + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %gl_FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 140", + "OpName %main \"main\"", + "OpName %foo_f1_ \"foo(f1;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color \"color\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %param_0 \"param\"", + "OpName %gl_FragColor \"gl_FragColor\"", + "%void = OpTypeVoid", + "%12 = OpTypeFunction %void", + "%float = OpTypeFloat 32", +"%_ptr_Function_float = OpTypePointer Function %float", + "%15 = OpTypeFunction %float %_ptr_Function_float", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_2 = OpConstant %float 2", + "%uint_1 = OpConstant %uint 1", + "%25 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", +"%gl_FragColor = OpVariable %_ptr_Output_v4float Output", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_f1_ = OpFunction %float None %15", + "%bar = OpFunctionParameter %_ptr_Function_float", + "%43 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%44 = OpLoad %float %bar", + "OpStore %r %44", + "%45 = OpLoad %float %r", + "%46 = OpFOrdLessThan %bool %45 %float_0", + "OpSelectionMerge %47 None", + "OpBranchConditional %46 %48 %47", + "%48 = OpLabel", + "%49 = OpLoad %float %r", + "%50 = OpFNegate %float %49", + "OpStore %r %50", + "OpBranch %47", + "%47 = OpLabel", + "%51 = OpLoad %float %r", + "OpReturnValue %51", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %12", + "%27 = OpLabel", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%28 = OpLoad %v4float %BaseColor", + "OpStore %color %28", + "%29 = OpAccessChain %_ptr_Function_float %color %uint_0", + "%30 = OpLoad %float %29", + "OpStore %param %30", + "%31 = OpFunctionCall %float %foo_f1_ %param", + "%32 = OpFOrdGreaterThan %bool %31 %float_2", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", + "%36 = OpLoad %float %35", + "OpStore %param_0 %36", + "%37 = OpFunctionCall %float %foo_f1_ %param_0", + "%38 = OpFOrdGreaterThan %bool %37 %float_2", + "OpBranch %33", + "%33 = OpLabel", + "%39 = OpPhi %bool %32 %27 %38 %34", + "OpSelectionMerge %40 None", + "OpBranchConditional %39 %41 %40", + "%41 = OpLabel", + "OpStore %color %25", + "OpBranch %40", + "%40 = OpLabel", + "%42 = OpLoad %v4float %color", + "OpStore %gl_FragColor %42", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %12", + "%27 = OpLabel", + "%62 = OpVariable %_ptr_Function_float Function", + "%63 = OpVariable %_ptr_Function_float Function", + "%52 = OpVariable %_ptr_Function_float Function", + "%53 = OpVariable %_ptr_Function_float Function", + "%color = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_float Function", + "%param_0 = OpVariable %_ptr_Function_float Function", + "%28 = OpLoad %v4float %BaseColor", + "OpStore %color %28", + "%29 = OpAccessChain %_ptr_Function_float %color %uint_0", + "%30 = OpLoad %float %29", + "OpStore %param %30", + "%54 = OpLoad %float %param", + "OpStore %52 %54", + "%55 = OpLoad %float %52", + "%56 = OpFOrdLessThan %bool %55 %float_0", + "OpSelectionMerge %57 None", + "OpBranchConditional %56 %58 %57", + "%58 = OpLabel", + "%59 = OpLoad %float %52", + "%60 = OpFNegate %float %59", + "OpStore %52 %60", + "OpBranch %57", + "%57 = OpLabel", + "%61 = OpLoad %float %52", + "OpStore %53 %61", + "%31 = OpLoad %float %53", + "%32 = OpFOrdGreaterThan %bool %31 %float_2", + "OpSelectionMerge %33 None", + "OpBranchConditional %32 %34 %33", + "%34 = OpLabel", + "%35 = OpAccessChain %_ptr_Function_float %color %uint_1", + "%36 = OpLoad %float %35", + "OpStore %param_0 %36", + "%64 = OpLoad %float %param_0", + "OpStore %62 %64", + "%65 = OpLoad %float %62", + "%66 = OpFOrdLessThan %bool %65 %float_0", + "OpSelectionMerge %67 None", + "OpBranchConditional %66 %68 %67", + "%68 = OpLabel", + "%69 = OpLoad %float %62", + "%70 = OpFNegate %float %69", + "OpStore %62 %70", + "OpBranch %67", + "%67 = OpLabel", + "%71 = OpLoad %float %62", + "OpStore %63 %71", + "%37 = OpLoad %float %63", + "%38 = OpFOrdGreaterThan %bool %37 %float_2", + "OpBranch %33", + "%33 = OpLabel", + "%39 = OpPhi %bool %32 %57 %38 %67", + "OpSelectionMerge %40 None", + "OpBranchConditional %39 %41 %40", + "%41 = OpLabel", + "OpStore %color %25", + "OpBranch %40", + "%40 = OpLabel", + "%42 = OpLoad %v4float %color", + "OpStore %gl_FragColor %42", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpSampledImageOutOfBlock) { + // #version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // out vec4 FragColor; + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // + // Note: the before SPIR-V will need to be edited to create a use of + // the OpSampledImage across the function call. + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "%void = OpTypeVoid", + "%15 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%19 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%25 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_25 = OpTypePointer UniformConstant %25", + "%t2D = OpVariable %_ptr_UniformConstant_25 UniformConstant", + "%27 = OpTypeSampler", +"%_ptr_UniformConstant_27 = OpTypePointer UniformConstant %27", + "%samp = OpVariable %_ptr_UniformConstant_27 UniformConstant", + "%29 = OpTypeSampledImage %25", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%32 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%float_0_5 = OpConstant %float 0.5", + "%35 = OpConstantComposite %v2float %float_0_5 %float_0_5", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %19", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%56 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%57 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%58 = OpLoad %float %57", + "OpStore %r %58", + "%59 = OpLoad %float %r", + "%60 = OpFOrdLessThan %bool %59 %float_0", + "OpSelectionMerge %61 None", + "OpBranchConditional %60 %62 %61", + "%62 = OpLabel", + "%63 = OpLoad %float %r", + "%64 = OpFNegate %float %63", + "OpStore %r %64", + "OpBranch %61", + "%61 = OpLabel", + "%65 = OpLoad %float %r", + "OpReturnValue %65", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %15", + "%38 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%39 = OpLoad %25 %t2D", + "%40 = OpLoad %27 %samp", + "%41 = OpSampledImage %29 %39 %40", + "%42 = OpImageSampleImplicitLod %v4float %41 %32", + "OpStore %color1 %42", + "%43 = OpLoad %v4float %BaseColor", + "OpStore %param %43", + "%44 = OpFunctionCall %float %foo_vf4_ %param", + "%45 = OpCompositeConstruct %v4float %44 %44 %44 %44", + "OpStore %color2 %45", + "%46 = OpLoad %25 %t2D", + "%47 = OpLoad %27 %samp", + "%48 = OpImageSampleImplicitLod %v4float %41 %35", + "OpStore %color3 %48", + "%49 = OpLoad %v4float %color1", + "%50 = OpLoad %v4float %color2", + "%51 = OpFAdd %v4float %49 %50", + "%52 = OpLoad %v4float %color3", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%55 = OpFDiv %v4float %53 %54", + "OpStore %FragColor %55", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %15", + "%38 = OpLabel", + "%66 = OpVariable %_ptr_Function_float Function", + "%67 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%39 = OpLoad %25 %t2D", + "%40 = OpLoad %27 %samp", + "%41 = OpSampledImage %29 %39 %40", + "%42 = OpImageSampleImplicitLod %v4float %41 %32", + "OpStore %color1 %42", + "%43 = OpLoad %v4float %BaseColor", + "OpStore %param %43", + "%68 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%69 = OpLoad %float %68", + "OpStore %66 %69", + "%70 = OpLoad %float %66", + "%71 = OpFOrdLessThan %bool %70 %float_0", + "OpSelectionMerge %72 None", + "OpBranchConditional %71 %73 %72", + "%73 = OpLabel", + "%74 = OpLoad %float %66", + "%75 = OpFNegate %float %74", + "OpStore %66 %75", + "OpBranch %72", + "%72 = OpLabel", + "%76 = OpLoad %float %66", + "OpStore %67 %76", + "%44 = OpLoad %float %67", + "%45 = OpCompositeConstruct %v4float %44 %44 %44 %44", + "OpStore %color2 %45", + "%46 = OpLoad %25 %t2D", + "%47 = OpLoad %27 %samp", + "%77 = OpSampledImage %29 %39 %40", + "%48 = OpImageSampleImplicitLod %v4float %77 %35", + "OpStore %color3 %48", + "%49 = OpLoad %v4float %color1", + "%50 = OpLoad %v4float %color2", + "%51 = OpFAdd %v4float %49 %50", + "%52 = OpLoad %v4float %color3", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%55 = OpFDiv %v4float %53 %54", + "OpStore %FragColor %55", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpImageOutOfBlock) { + // #version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // uniform sampler samp2; + // + // out vec4 FragColor; + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp2), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // Note: the before SPIR-V will need to be edited to create an OpImage + // from the first OpSampledImage, place it before the call and use it + // in the second OpSampledImage following the call. + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %samp2 \"samp2\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "OpDecorate %samp2 DescriptorSet 0", + "%void = OpTypeVoid", + "%16 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%20 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%26 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_26 = OpTypePointer UniformConstant %26", + "%t2D = OpVariable %_ptr_UniformConstant_26 UniformConstant", + "%28 = OpTypeSampler", +"%_ptr_UniformConstant_28 = OpTypePointer UniformConstant %28", + "%samp = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%30 = OpTypeSampledImage %26", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%33 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%samp2 = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%float_0_5 = OpConstant %float 0.5", + "%36 = OpConstantComposite %v2float %float_0_5 %float_0_5", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %20", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%58 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%59 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%60 = OpLoad %float %59", + "OpStore %r %60", + "%61 = OpLoad %float %r", + "%62 = OpFOrdLessThan %bool %61 %float_0", + "OpSelectionMerge %63 None", + "OpBranchConditional %62 %64 %63", + "%64 = OpLabel", + "%65 = OpLoad %float %r", + "%66 = OpFNegate %float %65", + "OpStore %r %66", + "OpBranch %63", + "%63 = OpLabel", + "%67 = OpLoad %float %r", + "OpReturnValue %67", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "OpStore %color1 %43", + "%46 = OpLoad %v4float %BaseColor", + "OpStore %param %46", + "%47 = OpFunctionCall %float %foo_vf4_ %param", + "%48 = OpCompositeConstruct %v4float %47 %47 %47 %47", + "OpStore %color2 %48", + "%49 = OpSampledImage %30 %44 %45", + "%50 = OpImageSampleImplicitLod %v4float %49 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%68 = OpVariable %_ptr_Function_float Function", + "%69 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "OpStore %color1 %43", + "%46 = OpLoad %v4float %BaseColor", + "OpStore %param %46", + "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%71 = OpLoad %float %70", + "OpStore %68 %71", + "%72 = OpLoad %float %68", + "%73 = OpFOrdLessThan %bool %72 %float_0", + "OpSelectionMerge %74 None", + "OpBranchConditional %73 %75 %74", + "%75 = OpLabel", + "%76 = OpLoad %float %68", + "%77 = OpFNegate %float %76", + "OpStore %68 %77", + "OpBranch %74", + "%74 = OpLabel", + "%78 = OpLoad %float %68", + "OpStore %69 %78", + "%47 = OpLoad %float %69", + "%48 = OpCompositeConstruct %v4float %47 %47 %47 %47", + "OpStore %color2 %48", + "%79 = OpSampledImage %30 %40 %41", + "%80 = OpImage %26 %79", + "%49 = OpSampledImage %30 %80 %45", + "%50 = OpImageSampleImplicitLod %v4float %49 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { + // #version 450 + // + // uniform texture2D t2D; + // uniform sampler samp; + // uniform sampler samp2; + // + // out vec4 FragColor; + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // float r = bar.x; + // if (r < 0.0) + // r = -r; + // return r; + // } + // + // void main() + // { + // vec4 color1 = texture(sampler2D(t2D, samp), vec2(1.0)); + // vec4 color2 = vec4(foo(BaseColor)); + // vec4 color3 = texture(sampler2D(t2D, samp2), vec2(0.5)); + // FragColor = (color1 + color2 + color3)/3; + // } + // Note: the before SPIR-V will need to be edited to create an OpImage + // and subsequent OpSampledImage that is used across the function call. + const std::vector predefs = { + // clang-format off + "OpCapability Shader", + "%1 = OpExtInstImport \"GLSL.std.450\"", + "OpMemoryModel Logical GLSL450", + "OpEntryPoint Fragment %main \"main\" %BaseColor %FragColor", + "OpExecutionMode %main OriginUpperLeft", + "OpSource GLSL 450", + "OpName %main \"main\"", + "OpName %foo_vf4_ \"foo(vf4;\"", + "OpName %bar \"bar\"", + "OpName %r \"r\"", + "OpName %color1 \"color1\"", + "OpName %t2D \"t2D\"", + "OpName %samp \"samp\"", + "OpName %color2 \"color2\"", + "OpName %BaseColor \"BaseColor\"", + "OpName %param \"param\"", + "OpName %color3 \"color3\"", + "OpName %samp2 \"samp2\"", + "OpName %FragColor \"FragColor\"", + "OpDecorate %t2D DescriptorSet 0", + "OpDecorate %samp DescriptorSet 0", + "OpDecorate %samp2 DescriptorSet 0", + "%void = OpTypeVoid", + "%16 = OpTypeFunction %void", + "%float = OpTypeFloat 32", + "%v4float = OpTypeVector %float 4", +"%_ptr_Function_v4float = OpTypePointer Function %v4float", + "%20 = OpTypeFunction %float %_ptr_Function_v4float", +"%_ptr_Function_float = OpTypePointer Function %float", + "%uint = OpTypeInt 32 0", + "%uint_0 = OpConstant %uint 0", + "%float_0 = OpConstant %float 0", + "%bool = OpTypeBool", + "%26 = OpTypeImage %float 2D 0 0 0 1 Unknown", +"%_ptr_UniformConstant_26 = OpTypePointer UniformConstant %26", + "%t2D = OpVariable %_ptr_UniformConstant_26 UniformConstant", + "%28 = OpTypeSampler", +"%_ptr_UniformConstant_28 = OpTypePointer UniformConstant %28", + "%samp = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%30 = OpTypeSampledImage %26", + "%v2float = OpTypeVector %float 2", + "%float_1 = OpConstant %float 1", + "%33 = OpConstantComposite %v2float %float_1 %float_1", +"%_ptr_Input_v4float = OpTypePointer Input %v4float", + "%BaseColor = OpVariable %_ptr_Input_v4float Input", + "%samp2 = OpVariable %_ptr_UniformConstant_28 UniformConstant", + "%float_0_5 = OpConstant %float 0.5", + "%36 = OpConstantComposite %v2float %float_0_5 %float_0_5", +"%_ptr_Output_v4float = OpTypePointer Output %v4float", + "%FragColor = OpVariable %_ptr_Output_v4float Output", + "%float_3 = OpConstant %float 3", + // clang-format on + }; + + const std::vector nonEntryFuncs = { + // clang-format off + "%foo_vf4_ = OpFunction %float None %20", + "%bar = OpFunctionParameter %_ptr_Function_v4float", + "%58 = OpLabel", + "%r = OpVariable %_ptr_Function_float Function", + "%59 = OpAccessChain %_ptr_Function_float %bar %uint_0", + "%60 = OpLoad %float %59", + "OpStore %r %60", + "%61 = OpLoad %float %r", + "%62 = OpFOrdLessThan %bool %61 %float_0", + "OpSelectionMerge %63 None", + "OpBranchConditional %62 %64 %63", + "%64 = OpLabel", + "%65 = OpLoad %float %r", + "%66 = OpFNegate %float %65", + "OpStore %r %66", + "OpBranch %63", + "%63 = OpLabel", + "%67 = OpLoad %float %r", + "OpReturnValue %67", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector before = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "%46 = OpSampledImage %30 %44 %45", + "OpStore %color1 %43", + "%47 = OpLoad %v4float %BaseColor", + "OpStore %param %47", + "%48 = OpFunctionCall %float %foo_vf4_ %param", + "%49 = OpCompositeConstruct %v4float %48 %48 %48 %48", + "OpStore %color2 %49", + "%50 = OpImageSampleImplicitLod %v4float %46 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + + const std::vector after = { + // clang-format off + "%main = OpFunction %void None %16", + "%39 = OpLabel", + "%68 = OpVariable %_ptr_Function_float Function", + "%69 = OpVariable %_ptr_Function_float Function", + "%color1 = OpVariable %_ptr_Function_v4float Function", + "%color2 = OpVariable %_ptr_Function_v4float Function", + "%param = OpVariable %_ptr_Function_v4float Function", + "%color3 = OpVariable %_ptr_Function_v4float Function", + "%40 = OpLoad %26 %t2D", + "%41 = OpLoad %28 %samp", + "%42 = OpSampledImage %30 %40 %41", + "%43 = OpImageSampleImplicitLod %v4float %42 %33", + "%44 = OpImage %26 %42", + "%45 = OpLoad %28 %samp2", + "%46 = OpSampledImage %30 %44 %45", + "OpStore %color1 %43", + "%47 = OpLoad %v4float %BaseColor", + "OpStore %param %47", + "%70 = OpAccessChain %_ptr_Function_float %param %uint_0", + "%71 = OpLoad %float %70", + "OpStore %68 %71", + "%72 = OpLoad %float %68", + "%73 = OpFOrdLessThan %bool %72 %float_0", + "OpSelectionMerge %74 None", + "OpBranchConditional %73 %75 %74", + "%75 = OpLabel", + "%76 = OpLoad %float %68", + "%77 = OpFNegate %float %76", + "OpStore %68 %77", + "OpBranch %74", + "%74 = OpLabel", + "%78 = OpLoad %float %68", + "OpStore %69 %78", + "%48 = OpLoad %float %69", + "%49 = OpCompositeConstruct %v4float %48 %48 %48 %48", + "OpStore %color2 %49", + "%79 = OpSampledImage %30 %40 %41", + "%80 = OpImage %26 %79", + "%81 = OpSampledImage %30 %80 %45", + "%50 = OpImageSampleImplicitLod %v4float %81 %36", + "OpStore %color3 %50", + "%51 = OpLoad %v4float %color1", + "%52 = OpLoad %v4float %color2", + "%53 = OpFAdd %v4float %51 %52", + "%54 = OpLoad %v4float %color3", + "%55 = OpFAdd %v4float %53 %54", + "%56 = OpCompositeConstruct %v4float %float_3 %float_3 %float_3 %float_3", + "%57 = OpFDiv %v4float %55 %56", + "OpStore %FragColor %57", + "OpReturn", + "OpFunctionEnd", + // clang-format on + }; + SinglePassRunAndCheck( + JoinAllInsts(concat(concat(predefs, before), nonEntryFuncs)), + JoinAllInsts(concat(concat(predefs, after), nonEntryFuncs)), + /* skip_nop = */ false, /* do_validate = */ true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Empty modules +// Modules without function definitions +// Modules in which all functions do not call other functions +// Recursive functions (calling self & calling each other) +// Caller and callee both accessing the same global variable +// Functions with OpLine & OpNoLine +// Others? + +} // anonymous namespace diff --git a/test/opt/pass_fixture.h b/test/opt/pass_fixture.h index 97a9611..1b257a6 100644 --- a/test/opt/pass_fixture.h +++ b/test/opt/pass_fixture.h @@ -45,28 +45,36 @@ class PassTest : public TestT { tools_(SPV_ENV_UNIVERSAL_1_1), manager_(new opt::PassManager()) {} - // Runs the given |pass| on the binary assembled from the |assembly|, and - // disassebles the optimized binary. Returns a tuple of disassembly string - // and the boolean value returned from pass Process() function. - std::tuple OptimizeAndDisassemble( + // Runs the given |pass| on the binary assembled from the |original|. + // Returns a tuple of the optimized binary and the boolean value returned + // from pass Process() function. + std::tuple, opt::Pass::Status> OptimizeToBinary( opt::Pass* pass, const std::string& original, bool skip_nop) { std::unique_ptr module = BuildModule(SPV_ENV_UNIVERSAL_1_1, consumer_, original); EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" << original << std::endl; if (!module) { - return std::make_tuple(std::string(), opt::Pass::Status::Failure); + return std::make_tuple(std::vector(), + opt::Pass::Status::Failure); } const auto status = pass->Process(module.get()); std::vector binary; module->ToBinary(&binary, skip_nop); - std::string optimized; - EXPECT_TRUE(tools_.Disassemble(binary, &optimized)) - << "Disassembling failed for shader:\n" - << original << std::endl; - return std::make_tuple(optimized, status); + return std::make_tuple(binary, status); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |assembly|. Returns a tuple of the optimized binary and the boolean value + // from the pass Process() function. + template + std::tuple, opt::Pass::Status> SinglePassRunToBinary( + const std::string& assembly, bool skip_nop, Args&&... args) { + auto pass = MakeUnique(std::forward(args)...); + pass->SetMessageConsumer(consumer_); + return OptimizeToBinary(pass.get(), assembly, skip_nop); } // Runs a single pass of class |PassT| on the binary assembled from the @@ -75,28 +83,64 @@ class PassTest : public TestT { template std::tuple SinglePassRunAndDisassemble( const std::string& assembly, bool skip_nop, Args&&... args) { - auto pass = MakeUnique(std::forward(args)...); - pass->SetMessageConsumer(consumer_); - return OptimizeAndDisassemble(pass.get(), assembly, skip_nop); + std::vector optimized_bin; + auto status = opt::Pass::Status::SuccessWithoutChange; + std::tie(optimized_bin, status) = SinglePassRunToBinary( + assembly, skip_nop, std::forward(args)...); + std::string optimized_asm; + EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm)) + << "Disassembling failed for shader:\n" + << assembly << std::endl; + return std::make_tuple(optimized_asm, status); } // Runs a single pass of class |PassT| on the binary assembled from the // |original| assembly, and checks whether the optimized binary can be - // disassembled to the |expected| assembly. This does *not* involve pass - // manager. Callers are suggested to use SCOPED_TRACE() for better messages. + // disassembled to the |expected| assembly. Optionally will also validate + // the optimized binary. This does *not* involve pass manager. Callers + // are suggested to use SCOPED_TRACE() for better messages. template void SinglePassRunAndCheck(const std::string& original, const std::string& expected, bool skip_nop, - Args&&... args) { - std::string optimized; + bool do_validation, Args&&... args) { + std::vector optimized_bin; auto status = opt::Pass::Status::SuccessWithoutChange; - std::tie(optimized, status) = SinglePassRunAndDisassemble( + std::tie(optimized_bin, status) = SinglePassRunToBinary( original, skip_nop, std::forward(args)...); // Check whether the pass returns the correct modification indication. EXPECT_NE(opt::Pass::Status::Failure, status); EXPECT_EQ(original == expected, status == opt::Pass::Status::SuccessWithoutChange); - EXPECT_EQ(expected, optimized); + if (do_validation) { + spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1; + spv_context context = spvContextCreate(target_env); + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t binary = {optimized_bin.data(), + optimized_bin.size()}; + spv_result_t error = spvValidate(context, &binary, &diagnostic); + EXPECT_EQ(error, 0); + if (error != 0) + spvDiagnosticPrint(diagnostic); + spvDiagnosticDestroy(diagnostic); + spvContextDestroy(context); + } + std::string optimized_asm; + EXPECT_TRUE(tools_.Disassemble(optimized_bin, &optimized_asm)) + << "Disassembling failed for shader:\n" + << original << std::endl; + EXPECT_EQ(expected, optimized_asm); + } + + // Runs a single pass of class |PassT| on the binary assembled from the + // |original| assembly, and checks whether the optimized binary can be + // disassembled to the |expected| assembly. This does *not* involve pass + // manager. Callers are suggested to use SCOPED_TRACE() for better messages. + template + void SinglePassRunAndCheck(const std::string& original, + const std::string& expected, bool skip_nop, + Args&&... args) { + SinglePassRunAndCheck(original, expected, skip_nop, false, + std::forward(args)...); } // Adds a pass to be run. diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index 1732818..d9dea33 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -61,6 +61,8 @@ Options: e.g.: --set-spec-const-default-value "1:100 2:400" --unify-const Remove the duplicated constants. + --inline-entry-points-all + Exhaustively inline all function calls in entry points -h, --help Print this help. --version Display optimizer version information. )", @@ -121,6 +123,8 @@ int main(int argc, char** argv) { } } else if (0 == strcmp(cur_arg, "--freeze-spec-const")) { pass_manager.AddPass(); + } else if (0 == strcmp(cur_arg, "--inline-entry-points-all")) { + pass_manager.AddPass(); } else if (0 == strcmp(cur_arg, "--eliminate-dead-const")) { pass_manager.AddPass(); } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) { -- 2.7.4