From aa7e687ef0a14066a12d0eb3825a2301f32363cb Mon Sep 17 00:00:00 2001 From: GregF Date: Fri, 12 May 2017 17:27:21 -0600 Subject: [PATCH] Mem2Reg: Add Local Access Chain Convert pass - Supports OpAccessChain and OpInBoundsAccessChain - Does not process modules with non-32-bit integer types. --- include/spirv-tools/optimizer.hpp | 17 + source/opt/CMakeLists.txt | 2 + source/opt/local_access_chain_convert_pass.cpp | 369 +++++++++++++++++++++ source/opt/local_access_chain_convert_pass.h | 167 ++++++++++ source/opt/optimizer.cpp | 5 + source/opt/passes.h | 1 + test/opt/CMakeLists.txt | 5 + test/opt/local_access_chain_convert_test.cpp | 422 +++++++++++++++++++++++++ tools/opt/opt.cpp | 2 + 9 files changed, 990 insertions(+) create mode 100644 source/opt/local_access_chain_convert_pass.cpp create mode 100644 source/opt/local_access_chain_convert_pass.h create mode 100644 test/opt/local_access_chain_convert_test.cpp diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index 3c510e0..6695573 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -194,6 +194,23 @@ Optimizer::PassToken CreateEliminateDeadConstantPass(); // points are not changed. Optimizer::PassToken CreateInlinePass(); +// Creates a local access chain conversion pass. +// A local access chain conversion pass identifies all function scope +// variables which are accessed only with loads, stores and access chains +// with constant indices. It then converts all loads and stores of such +// variables into equivalent sequences of loads, stores, extracts and inserts. +// +// This pass only processes entry point functions. It currently only converts +// non-nested, non-ptr access chains. It does not process modules with +// non-32-bit integer types present. Optional memory access options on loads +// and stores are ignored as we are only processing function scope variables. 
+// +// This pass unifies access to these variables to a single mode and simplifies +// subsequent analysis and elimination of these variables along with their +// loads and stores allowing values to propagate to their points of use where +// possible. +Optimizer::PassToken CreateLocalAccessChainConvertPass(); + // Creates a compact ids pass. // The pass remaps result ids to a compact and gapless range starting from %1. Optimizer::PassToken CreateCompactIdsPass(); diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index a665526..2a9a61b 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(SPIRV-Tools-opt inline_pass.h instruction.h ir_loader.h + local_access_chain_convert_pass.h log.h module.h null_pass.h @@ -50,6 +51,7 @@ add_library(SPIRV-Tools-opt inline_pass.cpp instruction.cpp ir_loader.cpp + local_access_chain_convert_pass.cpp module.cpp set_spec_constant_default_value_pass.cpp optimizer.cpp diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp new file mode 100644 index 0000000..187494f --- /dev/null +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -0,0 +1,369 @@ +// Copyright (c) 2017 The Khronos Group Inc. +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "iterator.h"
+#include "local_access_chain_convert_pass.h"
+
+static const int kSpvEntryPointFunctionId = 1;
+static const int kSpvStorePtrId = 0;
+static const int kSpvStoreValId = 1;
+static const int kSpvLoadPtrId = 0;
+static const int kSpvAccessChainPtrId = 0;
+static const int kSpvTypePointerStorageClass = 0;
+static const int kSpvTypePointerTypeId = 1;
+static const int kSpvConstantValue = 0;
+static const int kSpvTypeIntWidth = 0;
+
+namespace spvtools {
+namespace opt {
+
+bool LocalAccessChainConvertPass::IsNonPtrAccessChain(
+    const SpvOp opcode) const {
+  return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+bool LocalAccessChainConvertPass::IsMathType(
+    const ir::Instruction* typeInst) const {
+  switch (typeInst->opcode()) {
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+    case SpvOpTypeBool:
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+bool LocalAccessChainConvertPass::IsTargetType(
+    const ir::Instruction* typeInst) const {
+  if (IsMathType(typeInst))
+    return true;
+  if (typeInst->opcode() == SpvOpTypeArray)
+    return IsMathType(def_use_mgr_->GetDef(typeInst->GetSingleWordOperand(1)));
+  if (typeInst->opcode() != SpvOpTypeStruct)
+    return false;
+  // All struct members must be math type
+  int nonMathComp = 0;
+  typeInst->ForEachInId([&nonMathComp,this](const uint32_t* tid) {
+    ir::Instruction* compTypeInst = def_use_mgr_->GetDef(*tid);
+    if (!IsMathType(compTypeInst)) ++nonMathComp;
+  });
+  return nonMathComp == 0;
+}
+
+ir::Instruction* LocalAccessChainConvertPass::GetPtr(
+    ir::Instruction* ip,
+    uint32_t* varId) {
+  const uint32_t ptrId = ip->GetSingleWordInOperand(
+      ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
+  ir::Instruction* ptrInst = def_use_mgr_->GetDef(ptrId);
+  *varId = IsNonPtrAccessChain(ptrInst->opcode()) ?
+      ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId) :
+      ptrId;
+  return ptrInst;
+}
+
+bool LocalAccessChainConvertPass::IsTargetVar(uint32_t varId) {
+  if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
+    return false;
+  if (seen_target_vars_.find(varId) != seen_target_vars_.end())
+    return true;
+  const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+  if (varInst->opcode() != SpvOpVariable)
+    return false;
+  const uint32_t varTypeId = varInst->type_id();
+  const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+  if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+      SpvStorageClassFunction) {
+    seen_non_target_vars_.insert(varId);
+    return false;
+  }
+  const uint32_t varPteTypeId =
+      varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+  ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
+  if (!IsTargetType(varPteTypeInst)) {
+    seen_non_target_vars_.insert(varId);
+    return false;
+  }
+  seen_target_vars_.insert(varId);
+  return true;
+}
+
+void LocalAccessChainConvertPass::DeleteIfUseless(ir::Instruction* inst) {
+  const uint32_t resId = inst->result_id();
+  assert(resId != 0);
+  analysis::UseList* uses = def_use_mgr_->GetUses(resId);
+  if (uses == nullptr)
+    def_use_mgr_->KillInst(inst);
+}
+
+void LocalAccessChainConvertPass::ReplaceAndDeleteLoad(
+    ir::Instruction* loadInst,
+    uint32_t replId,
+    ir::Instruction* ptrInst) {
+  const uint32_t loadId = loadInst->result_id();
+  (void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
+  // remove load instruction
+  def_use_mgr_->KillInst(loadInst);
+  // if access chain, see if it can be removed as well
+  if (IsNonPtrAccessChain(ptrInst->opcode())) {
+    DeleteIfUseless(ptrInst);
+  }
+}
+
+uint32_t LocalAccessChainConvertPass::GetPointeeTypeId(
+    const ir::Instruction* ptrInst) const {
+  const uint32_t ptrTypeId = ptrInst->type_id();
+  const ir::Instruction* ptrTypeInst = def_use_mgr_->GetDef(ptrTypeId);
+  return ptrTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+}
+
+void LocalAccessChainConvertPass::BuildAndAppendInst(
+    SpvOp opcode,
+    uint32_t typeId,
+    uint32_t resultId,
+    const std::vector<ir::Operand>& in_opnds,
+    std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+  std::unique_ptr<ir::Instruction> newInst(new ir::Instruction(
+      opcode, typeId, resultId, in_opnds));
+  def_use_mgr_->AnalyzeInstDefUse(&*newInst);
+  newInsts->emplace_back(std::move(newInst));
+}
+
+uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
+    const ir::Instruction* ptrInst,
+    uint32_t* varId,
+    uint32_t* varPteTypeId,
+    std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+  const uint32_t ldResultId = TakeNextId();
+  *varId = ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
+  const ir::Instruction* varInst = def_use_mgr_->GetDef(*varId);
+  assert(varInst->opcode() == SpvOpVariable);
+  *varPteTypeId = GetPointeeTypeId(varInst);
+  BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}}, newInsts);
+  return ldResultId;
+}
+
+void LocalAccessChainConvertPass::AppendConstantOperands(
+    const ir::Instruction* ptrInst,
+    std::vector<ir::Operand>* in_opnds) {
+  uint32_t iidIdx = 0;
+  ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t *iid) {
+    if (iidIdx > 0) {
+      const ir::Instruction* cInst = def_use_mgr_->GetDef(*iid);
+      uint32_t val = cInst->GetSingleWordInOperand(kSpvConstantValue);
+      in_opnds->push_back(
+          {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
+    }
+    ++iidIdx;
+  });
+}
+
+uint32_t LocalAccessChainConvertPass::GenAccessChainLoadReplacement(
+    const ir::Instruction* ptrInst,
+    std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+
+  // Build and append load of variable in ptrInst
+  uint32_t varId;
+  uint32_t varPteTypeId;
+  const uint32_t ldResultId = BuildAndAppendVarLoad(ptrInst, &varId,
+      &varPteTypeId, newInsts);
+
+  // Build and append Extract
+  const uint32_t extResultId = TakeNextId();
+  const uint32_t ptrPteTypeId = GetPointeeTypeId(ptrInst);
+  std::vector<ir::Operand> ext_in_opnds =
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
+  AppendConstantOperands(ptrInst, &ext_in_opnds);
+  BuildAndAppendInst(SpvOpCompositeExtract, ptrPteTypeId, extResultId,
+      ext_in_opnds, newInsts);
+  return extResultId;
+}
+
+void LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
+    const ir::Instruction* ptrInst,
+    uint32_t valId,
+    std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
+
+  // Build and append load of variable in ptrInst
+  uint32_t varId;
+  uint32_t varPteTypeId;
+  const uint32_t ldResultId = BuildAndAppendVarLoad(ptrInst, &varId,
+      &varPteTypeId, newInsts);
+
+  // Build and append Insert
+  const uint32_t insResultId = TakeNextId();
+  std::vector<ir::Operand> ins_in_opnds =
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
+  AppendConstantOperands(ptrInst, &ins_in_opnds);
+  BuildAndAppendInst(
+      SpvOpCompositeInsert, varPteTypeId, insResultId, ins_in_opnds, newInsts);
+
+  // Build and append Store
+  BuildAndAppendInst(SpvOpStore, 0, 0,
+      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
+       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
+      newInsts);
+}
+
+bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
+    const ir::Instruction* acp) const {
+  uint32_t inIdx = 0;
+  uint32_t nonConstCnt = 0;
+  acp->ForEachInId([&inIdx, &nonConstCnt, this](const uint32_t* tid) {
+    if (inIdx > 0) {
+      ir::Instruction* opInst = def_use_mgr_->GetDef(*tid);
+      if (opInst->opcode() != SpvOpConstant) ++nonConstCnt;
+    }
+    ++inIdx;
+  });
+  return nonConstCnt == 0;
+}
+
+void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
+  for (auto bi = func->begin(); bi != func->end(); ++bi) {
+    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+      switch (ii->opcode()) {
+        case SpvOpStore:
+        case SpvOpLoad: {
+          uint32_t varId;
+          ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+          // For now, only convert non-ptr access chains
+          if (!IsNonPtrAccessChain(ptrInst->opcode()))
+            break;
+          // For now, only convert non-nested access chains
+          // TODO(): Convert nested access chains
+          if (!IsTargetVar(varId))
+            break;
+          // Rule out variables accessed with non-constant indices
+          if (!IsConstantIndexAccessChain(ptrInst)) {
+            seen_non_target_vars_.insert(varId);
+            seen_target_vars_.erase(varId);
+            break;
+          }
+        } break;
+        default:
+          break;
+      }
+    }
+  }
+}
+
+bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) {
+  FindTargetVars(func);
+  // Replace access chains of all targeted variables with equivalent
+  // extract and insert sequences
+  bool modified = false;
+  for (auto bi = func->begin(); bi != func->end(); ++bi) {
+    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+      switch (ii->opcode()) {
+        case SpvOpLoad: {
+          uint32_t varId;
+          ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+          if (!IsNonPtrAccessChain(ptrInst->opcode()))
+            break;
+          if (!IsTargetVar(varId))
+            break;
+          std::vector<std::unique_ptr<ir::Instruction>> newInsts;
+          uint32_t replId =
+              GenAccessChainLoadReplacement(ptrInst, &newInsts);
+          ReplaceAndDeleteLoad(&*ii, replId, ptrInst);
+          ++ii;
+          ii = ii.InsertBefore(&newInsts);
+          ++ii;
+          modified = true;
+        } break;
+        case SpvOpStore: {
+          uint32_t varId;
+          ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+          if (!IsNonPtrAccessChain(ptrInst->opcode()))
+            break;
+          if (!IsTargetVar(varId))
+            break;
+          std::vector<std::unique_ptr<ir::Instruction>> newInsts;
+          uint32_t valId = ii->GetSingleWordInOperand(kSpvStoreValId);
+          GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
+          def_use_mgr_->KillInst(&*ii);
+          DeleteIfUseless(ptrInst);
+          ++ii;
+          ii = ii.InsertBefore(&newInsts);
+          ++ii;
+          ++ii;
+          modified = true;
+        } break;
+        default:
+          break;
+      }
+    }
+  }
+  return modified;
+}
+
+void LocalAccessChainConvertPass::Initialize(ir::Module* module) {
+
+  module_ = module;
+
+  // Initialize function and block maps
+  id2function_.clear();
+  for (auto& fn : *module_)
+    id2function_[fn.result_id()] = &fn;
+
+  // Initialize Target Variable Caches
+  seen_target_vars_.clear();
+  seen_non_target_vars_.clear();
+
+  def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+  // Initialize next unused Id.
+  next_id_ = module->id_bound();
+}
+
+Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
+  // If non-32-bit integer type in module, terminate processing
+  // TODO(): Handle non-32-bit integer constants in access chains
+  for (const ir::Instruction& inst : module_->types_values())
+    if (inst.opcode() == SpvOpTypeInt &&
+        inst.GetSingleWordInOperand(kSpvTypeIntWidth) != 32)
+      return Status::SuccessWithoutChange;
+  // Process every entry point; convert before OR-ing so none is skipped.
+  bool modified = false;
+  for (auto& e : module_->entry_points()) {
+    ir::Function* fn =
+        id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
+    modified = ConvertLocalAccessChains(fn) || modified;
+  }
+
+  FinalizeNextId(module_);
+
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+LocalAccessChainConvertPass::LocalAccessChainConvertPass()
+    : module_(nullptr), def_use_mgr_(nullptr), next_id_(0) {}
+
+Pass::Status LocalAccessChainConvertPass::Process(ir::Module* module) {
+  Initialize(module);
+  return ProcessImpl();
+}
+
+}  // namespace opt
+}  // namespace spvtools
+
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
new file mode 100644
index 0000000..3a2d605
--- /dev/null
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+#define LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class LocalAccessChainConvertPass : public Pass {
+ public:
+  LocalAccessChainConvertPass();
+  const char* name() const override { return "convert-local-access-chains"; }
+  Status Process(ir::Module*) override;
+
+ private:
+  // Returns true if |opcode| is a non-pointer access chain op
+  // TODO(): Support conversion of pointer access chains.
+  bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+  // Returns true if |typeInst| is a scalar type
+  // or a vector or matrix
+  bool IsMathType(const ir::Instruction* typeInst) const;
+
+  // Returns true if |typeInst| is a math type or a struct or array
+  // of a math type.
+  // TODO(): Add more complex types to convert
+  bool IsTargetType(const ir::Instruction* typeInst) const;
+
+  // Given a load or store |ip|, return the pointer instruction.
+  // If the pointer is an access chain, |*varId| is its base id.
+  // Otherwise it is the id of the pointer of the load/store.
+  ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+  // Search |func| and cache function scope variables of target type that are
+  // not accessed with non-constant-index access chains. Also cache non-target
+  // variables.
+  void FindTargetVars(ir::Function* func);
+
+  // Return true if |varId| is a previously identified target variable.
+  // Return false if |varId| is a previously identified non-target variable.
+  // See FindTargetVars() for definition of target variable. If variable is
+  // not cached, return true if variable is a function scope variable of
+  // target type, false otherwise. Updates caches of target and non-target
+  // variables.
+  bool IsTargetVar(uint32_t varId);
+
+  // Delete |inst| if it has no uses. Assumes |inst| has a non-zero resultId.
+  void DeleteIfUseless(ir::Instruction* inst);
+
+  // Replace all instances of |loadInst|'s id with |replId| and delete
+  // |loadInst| and its pointer |ptrInst| if it is a useless access chain.
+  void ReplaceAndDeleteLoad(ir::Instruction* loadInst,
+      uint32_t replId,
+      ir::Instruction* ptrInst);
+
+  // Return type id for |ptrInst|'s pointee
+  uint32_t GetPointeeTypeId(const ir::Instruction* ptrInst) const;
+
+  // Build instruction from |opcode|, |typeId|, |resultId|, and |in_opnds|.
+  // Append to |newInsts|.
+  void BuildAndAppendInst(SpvOp opcode, uint32_t typeId, uint32_t resultId,
+      const std::vector<ir::Operand>& in_opnds,
+      std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+  // Build load of variable in |ptrInst| and append to |newInsts|.
+  // Return var in |varId| and its pointee type in |varPteTypeId|.
+  uint32_t BuildAndAppendVarLoad(const ir::Instruction* ptrInst,
+      uint32_t* varId, uint32_t* varPteTypeId,
+      std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+  // Append literal integer operands to |in_opnds| corresponding to constant
+  // integer operands from access chain |ptrInst|. Assumes all indices in
+  // access chains are OpConstant.
+  void AppendConstantOperands(const ir::Instruction* ptrInst,
+      std::vector<ir::Operand>* in_opnds);
+
+  // Create a load/insert/store equivalent to a store of
+  // |valId| through (constant index) access chain |ptrInst|.
+  // Append to |newInsts|.
+  void GenAccessChainStoreReplacement(const ir::Instruction* ptrInst,
+      uint32_t valId,
+      std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+  // For the (constant index) access chain |ptrInst|, create an
+  // equivalent load and extract. Append to |newInsts|.
+  uint32_t GenAccessChainLoadReplacement(const ir::Instruction* ptrInst,
+      std::vector<std::unique_ptr<ir::Instruction>>* newInsts);
+
+  // Return true if all indices of access chain |acp| are OpConstant integers
+  bool IsConstantIndexAccessChain(const ir::Instruction* acp) const;
+
+  // Identify all function scope variables of target type which are
+  // accessed only with loads, stores and access chains with constant
+  // indices. Convert all loads and stores of such variables into equivalent
+  // loads, stores, extracts and inserts. This unifies access to these
+  // variables to a single mode and simplifies analysis and optimization.
+  // See IsTargetType() for targeted types.
+  //
+  // Nested access chains and pointer access chains are not currently
+  // converted.
+  bool ConvertLocalAccessChains(ir::Function* func);
+
+  // Save next available id into |module|.
+  inline void FinalizeNextId(ir::Module* module) {
+    module->SetIdBound(next_id_);
+  }
+
+  // Return next available id and calculate next.
+  inline uint32_t TakeNextId() {
+    return next_id_++;
+  }
+
+  void Initialize(ir::Module* module);
+  Pass::Status ProcessImpl();
+
+  // Module this pass is processing
+  ir::Module* module_;
+
+  // Def-Uses for the module we are processing
+  std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+  // Map from function's result id to function
+  std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+  // Cache of verified target vars
+  std::unordered_set<uint32_t> seen_target_vars_;
+
+  // Cache of verified non-target vars
+  std::unordered_set<uint32_t> seen_non_target_vars_;
+
+  // Next unused ID
+  uint32_t next_id_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_LOCAL_ACCESS_CHAIN_CONVERT_PASS_H_
+
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 80643e9..9fde8d3 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -136,6 +136,11 @@ Optimizer::PassToken CreateInlinePass() {
   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::InlinePass>());
 }
 
+Optimizer::PassToken CreateLocalAccessChainConvertPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::LocalAccessChainConvertPass>());
+}
+
 Optimizer::PassToken CreateCompactIdsPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::CompactIdsPass>());
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 9f7668b..3d19753 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -23,6 +23,7 @@
 #include "fold_spec_constant_op_and_composite_pass.h"
 #include "inline_pass.h"
 #include "freeze_spec_constant_value_pass.h"
+#include "local_access_chain_convert_pass.h"
 #include "null_pass.h"
 #include "set_spec_constant_default_value_pass.h"
 #include "strip_debug_info_pass.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index bcfd90f..97eadb8 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -58,6 +58,11 @@ add_spvtools_unittest(TARGET pass_inline
   LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET pass_local_access_chain_convert
+  SRCS local_access_chain_convert_test.cpp pass_utils.cpp
+  LIBS SPIRV-Tools-opt
+)
+
 add_spvtools_unittest(TARGET pass_eliminate_dead_const
   SRCS
eliminate_dead_const_test.cpp pass_utils.cpp LIBS SPIRV-Tools-opt diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp new file mode 100644 index 0000000..ad37622 --- /dev/null +++ b/test/opt/local_access_chain_convert_test.cpp @@ -0,0 +1,422 @@ +// Copyright (c) 2017 Valve Corporation +// Copyright (c) 2017 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "pass_fixture.h" +#include "pass_utils.h" + +namespace { + +using namespace spvtools; + +using LocalAccessChainConvertTest = PassTest<::testing::Test>; + +TEST_F(LocalAccessChainConvertTest, StructOfVecsOfFloatConverted) { + + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t 
+%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%22 = OpLoad %S_t %s0 +%23 = OpCompositeInsert %S_t %18 %22 1 +OpStore %s0 %23 +%24 = OpLoad %S_t %s0 +%25 = OpCompositeExtract %v4float %24 1 +OpStore %gl_FragColor %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalAccessChainConvertTest, InBoundsAccessChainsConverted) { + + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float 
+%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpInBoundsAccessChain %_ptr_Function_v4float %s0 %int_1 +%21 = OpLoad %v4float %20 +OpStore %gl_FragColor %21 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%22 = OpLoad %S_t %s0 +%23 = OpCompositeInsert %S_t %18 %22 1 +OpStore %s0 %23 +%24 = OpLoad %S_t %s0 +%25 = OpCompositeExtract %v4float %24 1 +OpStore %gl_FragColor %25 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalAccessChainConvertTest, TwoUsesofSingleChainConverted) { + + // #version 140 + // + // in vec4 BaseColor; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // gl_FragColor = s0.v1; + // } + + const std::string predefs = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%8 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = 
OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float +%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +)"; + + const std::string before = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%19 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %19 %18 +%20 = OpLoad %v4float %19 +OpStore %gl_FragColor %20 +OpReturn +OpFunctionEnd +)"; + + const std::string after = + R"(%main = OpFunction %void None %8 +%17 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%18 = OpLoad %v4float %BaseColor +%21 = OpLoad %S_t %s0 +%22 = OpCompositeInsert %S_t %18 %21 1 +OpStore %s0 %22 +%23 = OpLoad %S_t %s0 +%24 = OpCompositeExtract %v4float %23 1 +OpStore %gl_FragColor %24 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + predefs + before, predefs + after, true, true); +} + +TEST_F(LocalAccessChainConvertTest, + UntargetedTypeNotConverted) { + + // #version 140 + // + // in vec4 BaseColor; + // + // struct S1_t { + // vec4 v1; + // }; + // + // struct S2_t { + // vec4 v2; + // S1_t s1; + // }; + // + // void main() + // { + // S2_t s2; + // s2.s1.v1 = BaseColor; + // gl_FragColor = s2.s1.v1; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S1_t "S1_t" +OpMemberName %S1_t 0 "v1" +OpName %S2_t "S2_t" +OpMemberName %S2_t 0 "v2" +OpMemberName %S2_t 1 "s1" +OpName %s2 "s2" +OpName %BaseColor "BaseColor" +OpName %gl_FragColor 
"gl_FragColor" +%void = OpTypeVoid +%9 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S1_t = OpTypeStruct %v4float +%S2_t = OpTypeStruct %v4float %S1_t +%_ptr_Function_S2_t = OpTypePointer Function %S2_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%int_0 = OpConstant %int 0 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %9 +%19 = OpLabel +%s2 = OpVariable %_ptr_Function_S2_t Function +%20 = OpLoad %v4float %BaseColor +%21 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0 +OpStore %21 %20 +%22 = OpAccessChain %_ptr_Function_v4float %s2 %int_1 %int_0 +%23 = OpLoad %v4float %22 +OpStore %gl_FragColor %23 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + assembly, assembly, false, true); +} + +TEST_F(LocalAccessChainConvertTest, + DynamicallyIndexedVarNotConverted) { + + // #version 140 + // + // in vec4 BaseColor; + // flat in int Idx; + // in float Bi; + // + // struct S_t { + // vec4 v0; + // vec4 v1; + // }; + // + // void main() + // { + // S_t s0; + // s0.v1 = BaseColor; + // s0.v1[Idx] = Bi; + // gl_FragColor = s0.v1; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %Idx %Bi %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %S_t "S_t" +OpMemberName %S_t 0 "v0" +OpMemberName %S_t 1 "v1" +OpName %s0 "s0" +OpName %BaseColor "BaseColor" +OpName %Idx "Idx" +OpName %Bi "Bi" +OpName %gl_FragColor "gl_FragColor" +OpDecorate %Idx Flat +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%S_t = OpTypeStruct %v4float %v4float 
+%_ptr_Function_S_t = OpTypePointer Function %S_t +%int = OpTypeInt 32 1 +%int_1 = OpConstant %int 1 +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Function_v4float = OpTypePointer Function %v4float +%_ptr_Input_int = OpTypePointer Input %int +%Idx = OpVariable %_ptr_Input_int Input +%_ptr_Input_float = OpTypePointer Input %float +%Bi = OpVariable %_ptr_Input_float Input +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%22 = OpLabel +%s0 = OpVariable %_ptr_Function_S_t Function +%23 = OpLoad %v4float %BaseColor +%24 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +OpStore %24 %23 +%25 = OpLoad %int %Idx +%26 = OpLoad %float %Bi +%27 = OpAccessChain %_ptr_Function_float %s0 %int_1 %25 +OpStore %27 %26 +%28 = OpAccessChain %_ptr_Function_v4float %s0 %int_1 +%29 = OpLoad %v4float %28 +OpStore %gl_FragColor %29 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck( + assembly, assembly, false, true); +} + +// TODO(greg-lunarg): Add tests to verify handling of these cases: +// +// Assorted vector and matrix types +// Assorted struct array types +// Assorted scalar types +// Assorted non-target types +// OpInBoundsAccessChain +// Others? 
+ +} // anonymous namespace diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp index a27d514..234cdc5 100644 --- a/tools/opt/opt.cpp +++ b/tools/opt/opt.cpp @@ -133,6 +133,8 @@ int main(int argc, char** argv) { optimizer.RegisterPass(CreateFreezeSpecConstantValuePass()); } else if (0 == strcmp(cur_arg, "--inline-entry-points-exhaustive")) { optimizer.RegisterPass(CreateInlinePass()); + } else if (0 == strcmp(cur_arg, "--convert-local-access-chains")) { + optimizer.RegisterPass(CreateLocalAccessChainConvertPass()); } else if (0 == strcmp(cur_arg, "--eliminate-dead-const")) { optimizer.RegisterPass(CreateEliminateDeadConstantPass()); } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) { -- 2.7.4