From a107d349e3fc124b43f04488a991c5cadf7abac4 Mon Sep 17 00:00:00 2001 From: GregF Date: Tue, 25 Apr 2017 13:57:20 -0600 Subject: [PATCH] Inline: Do not inline functions with multiple returns (for now) --- source/opt/basic_block.h | 8 +++-- source/opt/inline_pass.cpp | 30 ++++++++++++++--- source/opt/inline_pass.h | 6 ++++ test/opt/inline_test.cpp | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index 05258a9..6fb3a23 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -52,8 +52,12 @@ class BasicBlock { iterator begin() { return iterator(&insts_, insts_.begin()); } iterator end() { return iterator(&insts_, insts_.end()); } - const_iterator cbegin() { return const_iterator(&insts_, insts_.cbegin()); } - const_iterator cend() { return const_iterator(&insts_, insts_.cend()); } + const_iterator cbegin() const { + return const_iterator(&insts_, insts_.cbegin()); + } + const_iterator cend() const { + return const_iterator(&insts_, insts_.cend()); + } // Runs the given function |f| on each instruction in this basic block, and // optionally on the debug line instructions that might precede them. diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index 8be3ef9..6211202 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -353,10 +353,10 @@ void InlinePass::GenInlineCode( bool InlinePass::IsInlinableFunctionCall(const ir::Instruction* inst) { if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false; - const ir::Function* calleeFn = - id2function_[inst->GetSingleWordOperand(kSpvFunctionCallFunctionId)]; - // We can only inline a function if it has blocks. - return calleeFn->cbegin() != calleeFn->cend(); + const uint32_t calleeFnId = + inst->GetSingleWordOperand(kSpvFunctionCallFunctionId); + const auto ci = inlinable_.find(calleeFnId); + return ci != inlinable_.cend(); } bool InlinePass::Inline(ir::Function* func) { @@ -402,6 +402,25 @@ bool InlinePass::Inline(ir::Function* func) { return modified; } +bool InlinePass::IsInlinableFunction(const ir::Function* func) { + // We can only inline a function if it has blocks. + if (func->cbegin() == func->cend()) + return false; + // Do not inline functions with multiple returns + // TODO(greg-lunarg): Enable inlining if no return is in loop + int returnCnt = 0; + for (auto bi = func->cbegin(); bi != func->cend(); bi++) { + auto li = bi->cend(); + li--; + if (li->opcode() == SpvOpReturn || li->opcode() == SpvOpReturnValue) { + if (returnCnt > 0) + return false; + returnCnt++; + } + } + return true; +} + void InlinePass::Initialize(ir::Module* module) { def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module)); @@ -414,11 +433,14 @@ void InlinePass::Initialize(ir::Module* module) { // Initialize function and block maps. id2function_.clear(); id2block_.clear(); + inlinable_.clear(); for (auto& fn : *module_) { id2function_[fn.result_id()] = &fn; for (auto& blk : fn) { id2block_[blk.label_id()] = &blk; } + if (IsInlinableFunction(&fn)) + inlinable_.insert(fn.result_id()); } }; diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h index 541695b..2f549cd 100644 --- a/source/opt/inline_pass.h +++ b/source/opt/inline_pass.h @@ -120,6 +120,9 @@ class InlinePass : public Pass { // Returns true if |inst| is a function call that can be inlined. bool IsInlinableFunctionCall(const ir::Instruction* inst); + // Returns true if |func| is a function that can be inlined. + bool IsInlinableFunction(const ir::Function* func); + // Exhaustively inline all function calls in func as well as in // all code that is inlined into func. Return true if func is modified. bool Inline(ir::Function* func); @@ -136,6 +139,9 @@ class InlinePass : public Pass { // Map from block's label id to block. std::unordered_map id2block_; + // Set of ids of inlinable function + std::set inlinable_; + // Next unused ID uint32_t next_id_; }; diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp index ea578ac..598bc38 100644 --- a/test/opt/inline_test.cpp +++ b/test/opt/inline_test.cpp @@ -1358,6 +1358,86 @@ TEST_F(InlineTest, OpImageAndOpSampledImageOutOfBlock) { /* skip_nop = */ false, /* do_validate = */ true); } +TEST_F(InlineTest, EarlyReturnFunctionIsNotInlined) { + // #version 140 + // + // in vec4 BaseColor; + // + // float foo(vec4 bar) + // { + // if (bar.x < 0.0) + // return 0.0; + // return bar.x; + // } + // + // void main() + // { + // vec4 color = vec4(foo(BaseColor)); + // gl_FragColor = color; + // } + + const std::string assembly = + R"(OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor +OpExecutionMode %main OriginUpperLeft +OpSource GLSL 140 +OpName %main "main" +OpName %foo_vf4_ "foo(vf4;" +OpName %bar "bar" +OpName %color "color" +OpName %BaseColor "BaseColor" +OpName %param "param" +OpName %gl_FragColor "gl_FragColor" +%void = OpTypeVoid +%10 = OpTypeFunction %void +%float = OpTypeFloat 32 +%v4float = OpTypeVector %float 4 +%_ptr_Function_v4float = OpTypePointer Function %v4float +%14 = OpTypeFunction %float %_ptr_Function_v4float +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%_ptr_Function_float = OpTypePointer Function %float +%float_0 = OpConstant %float 0 +%bool = OpTypeBool +%_ptr_Input_v4float = OpTypePointer Input %v4float +%BaseColor = OpVariable %_ptr_Input_v4float Input +%_ptr_Output_v4float = OpTypePointer Output %v4float +%gl_FragColor = OpVariable %_ptr_Output_v4float Output +%main = OpFunction %void None %10 +%22 = OpLabel +%color = OpVariable %_ptr_Function_v4float Function +%param = OpVariable %_ptr_Function_v4float Function +%23 = OpLoad %v4float %BaseColor +OpStore %param %23 +%24 = OpFunctionCall %float %foo_vf4_ %param +%25 = OpCompositeConstruct %v4float %24 %24 %24 %24 +OpStore %color %25 +%26 = OpLoad %v4float %color +OpStore %gl_FragColor %26 +OpReturn +OpFunctionEnd +%foo_vf4_ = OpFunction %float None %14 +%bar = OpFunctionParameter %_ptr_Function_v4float +%27 = OpLabel +%28 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%29 = OpLoad %float %28 +%30 = OpFOrdLessThan %bool %29 %float_0 +OpSelectionMerge %31 None +OpBranchConditional %30 %32 %31 +%32 = OpLabel +OpReturnValue %float_0 +%31 = OpLabel +%33 = OpAccessChain %_ptr_Function_float %bar %uint_0 +%34 = OpLoad %float %33 +OpReturnValue %34 +OpFunctionEnd +)"; + + SinglePassRunAndCheck(assembly, assembly, false, true); +} + TEST_F(InlineTest, ExternalFunctionIsNotInlined) { // In particular, don't crash. // See report https://github.com/KhronosGroup/SPIRV-Tools/issues/605 -- 2.7.4