From 25262b3fd95da58de45197e9b06418bdcdeee57c Mon Sep 17 00:00:00 2001 From: qining Date: Fri, 6 May 2016 17:25:16 -0400 Subject: [PATCH] Resolve comments 1. Sink adding noContraction decoration to createBinaryOperation() and createUnaryOperation(). 2. Fix comments. 3. Remove the #define of my delimiter, use global constant char. --- SPIRV/GlslangToSpv.cpp | 108 +++++---- .../propagateNoContraction.cpp | 218 +++++++++--------- 2 files changed, 163 insertions(+), 163 deletions(-) diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index d160e36e..22575b5a 100755 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -33,8 +33,6 @@ //ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE //POSSIBILITY OF SUCH DAMAGE. -// -// Author: John Kessenich, LunarG // // Visit the nodes in the glslang intermediate tree representation to // translate them to SPIR-V. @@ -135,10 +133,10 @@ protected: spv::Id createImageTextureFunctionCall(glslang::TIntermOperator* node); spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*); - spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true); - spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right); - spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); - spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); + spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true); + spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right); + spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); + spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand); spv::Id makeSmearedConstant(spv::Id constant, int vectorSize); spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector& operands, glslang::TBasicType typeProxy); @@ -244,7 +242,7 @@ spv::StorageClass TranslateStorageClass(const glslang::TType& type) case glslang::EvqGlobal: return spv::StorageClassPrivate; case glslang::EvqConstReadOnly: return spv::StorageClassFunction; case glslang::EvqTemporary: return spv::StorageClassFunction; - default: + default: assert(0); return spv::StorageClassFunction; } @@ -567,7 +565,7 @@ spv::ImageFormat TGlslangToSpvTraverser::TranslateImageFormat(const glslang::TTy } } -// Return whether or not the given type is something that should be tied to a +// Return whether or not the given type is something that should be tied to a // descriptor set. bool IsDescriptorResource(const glslang::TType& type) { @@ -621,8 +619,7 @@ bool HasNonLayoutQualifiers(const glslang::TQualifier& qualifier) // - struct members can inherit from a struct declaration // - effect decorations on the struct members (note smooth does not, and expecting something like volatile to effect the whole object) // - are not part of the offset/st430/etc or row/column-major layout - return qualifier.invariant || qualifier.nopersp || qualifier.flat || qualifier.centroid || qualifier.patch || qualifier.sample || qualifier.hasLocation() || - qualifier.noContraction; + return qualifier.invariant || qualifier.nopersp || qualifier.flat || qualifier.centroid || qualifier.patch || qualifier.sample || qualifier.hasLocation(); } // @@ -792,7 +789,7 @@ TGlslangToSpvTraverser::~TGlslangToSpvTraverser() // // -// Symbols can turn into +// Symbols can turn into // - uniform/input reads // - output writes // - complex lvalue base setups: foo.bar[3].... , where we see foo and start up an access chain @@ -883,13 +880,11 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T spv::Id leftRValue = accessChainLoad(node->getLeft()->getType()); // do the operation - rValue = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()), + rValue = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()), + TranslateNoContractionDecoration(node->getType().getQualifier()), convertGlslangToSpvType(node->getType()), leftRValue, rValue, node->getType().getBasicType()); - // Decorate this instruction, if this node has 'noContraction' qualifier. - addDecoration(rValue, TranslateNoContractionDecoration(node->getType().getQualifier())); - // these all need their counterparts in createBinaryOperation() assert(rValue != spv::NoResult); } @@ -1005,6 +1000,7 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T // get result spv::Id result = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()), + TranslateNoContractionDecoration(node->getType().getQualifier()), convertGlslangToSpvType(node->getType()), left, right, node->getLeft()->getType().getBasicType()); @@ -1013,8 +1009,6 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T logger->missingFunctionality("unknown glslang binary operation"); return true; // pick up a child as the place-holder result } else { - // Decorate this instruction, if this node has 'noContraction' qualifier. - addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier())); builder.setAccessChainRValue(result); return false; } @@ -1073,6 +1067,7 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI operand = accessChainLoad(node->getOperand()->getType()); spv::Decoration precision = TranslatePrecisionDecoration(node->getType()); + spv::Decoration noContraction = TranslateNoContractionDecoration(node->getType().getQualifier()); // it could be a conversion if (! result) @@ -1080,11 +1075,9 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI // if not, then possibly an operation if (! result) - result = createUnaryOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operand, node->getOperand()->getBasicType()); + result = createUnaryOperation(node->getOp(), precision, noContraction, convertGlslangToSpvType(node->getType()), operand, node->getOperand()->getBasicType()); if (result) { - // Decorate this instruction, if this node has 'noContraction' qualifier. - addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier())); builder.clearAccessChain(); builder.setAccessChainRValue(result); @@ -1113,12 +1106,11 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI else op = glslang::EOpSub; - spv::Id result = createBinaryOperation(op, TranslatePrecisionDecoration(node->getType()), + spv::Id result = createBinaryOperation(op, TranslatePrecisionDecoration(node->getType()), + TranslateNoContractionDecoration(node->getType().getQualifier()), convertGlslangToSpvType(node->getType()), operand, one, node->getType().getBasicType()); assert(result != spv::NoResult); - // Decorate this instruction, if this node has 'noContraction' qualifier. - addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier())); // The result of operation is always stored, but conditionally the // consumed result. The consumed result is always an r-value. @@ -1350,7 +1342,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt break; } case glslang::EOpMul: - // compontent-wise matrix multiply + // compontent-wise matrix multiply binOp = glslang::EOpMul; break; case glslang::EOpOuterProduct: @@ -1359,7 +1351,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt break; case glslang::EOpDot: { - // for scalar dot product, use multiply + // for scalar dot product, use multiply glslang::TIntermSequence& glslangOperands = node->getSequence(); if (! glslangOperands[0]->getAsTyped()->isVector()) binOp = glslang::EOpMul; @@ -1414,8 +1406,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt right->traverse(this); spv::Id rightId = accessChainLoad(right->getType()); - result = createBinaryOperation(binOp, precision, - convertGlslangToSpvType(node->getType()), leftId, rightId, + result = createBinaryOperation(binOp, precision, TranslateNoContractionDecoration(node->getType().getQualifier()), + convertGlslangToSpvType(node->getType()), leftId, rightId, left->getType().getBasicType(), reduceComparison); // code above should only make binOp that exists in createBinaryOperation @@ -1488,7 +1480,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt result = createNoArgOperation(node->getOp()); break; case 1: - result = createUnaryOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operands.front(), glslangOperands[0]->getAsTyped()->getBasicType()); + result = createUnaryOperation( + node->getOp(), precision, + TranslateNoContractionDecoration(node->getType().getQualifier()), + convertGlslangToSpvType(node->getType()), operands.front(), + glslangOperands[0]->getAsTyped()->getBasicType()); break; default: result = createMiscOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operands, node->getBasicType()); @@ -1579,7 +1575,7 @@ bool TGlslangToSpvTraverser::visitSwitch(glslang::TVisit /* visit */, glslang::T codeSegments.push_back(child); } - // handle the case where the last code segment is missing, due to no code + // handle the case where the last code segment is missing, due to no code // statements between the last case and the end of the switch statement if ((caseValues.size() && (int)codeSegments.size() == valueIndexToSegment[caseValues.size() - 1]) || (int)codeSegments.size() == defaultSegment) @@ -1714,7 +1710,7 @@ bool TGlslangToSpvTraverser::visitBranch(glslang::TVisit /* visit */, glslang::T spv::Id TGlslangToSpvTraverser::createSpvVariable(const glslang::TIntermSymbol* node) { - // First, steer off constants, which are not SPIR-V variables, but + // First, steer off constants, which are not SPIR-V variables, but // can still have a mapping to a SPIR-V Id. // This includes specialization constants. if (node->getQualifier().isConstant()) { @@ -2018,7 +2014,7 @@ spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arra specNode->traverse(this); return accessChainLoad(specNode->getAsTyped()->getType()); } - + // Otherwise, need a compile-time (front end) size, get it: int size = arraySizes.getDimSize(dim); assert(size > 0); @@ -2165,7 +2161,7 @@ void TGlslangToSpvTraverser::updateMemberOffset(const glslang::TType& /*structTy // Getting this far means we need explicit offsets if (currentOffset < 0) currentOffset = 0; - + // Now, currentOffset is valid (either 0, or from a previous nextOffset), // but possibly not yet correctly aligned. @@ -2195,7 +2191,7 @@ void TGlslangToSpvTraverser::makeFunctions(const glslang::TIntermSequence& glslF // so that it's available to call. // Translating the body will happen later. // - // Typically (except for a "const in" parameter), an address will be passed to the + // Typically (except for a "const in" parameter), an address will be passed to the // function. What it is an address of varies: // // - "in" parameters not marked as "const" can be written to without modifying the argument, @@ -2265,7 +2261,7 @@ void TGlslangToSpvTraverser::visitFunctions(const glslang::TIntermSequence& glsl void TGlslangToSpvTraverser::handleFunctionEntry(const glslang::TIntermAggregate* node) { - // SPIR-V functions should already be in the functionMap from the prepass + // SPIR-V functions should already be in the functionMap from the prepass // that called makeFunctions(). spv::Function* function = functionMap[node->getName().c_str()]; spv::Block* functionBlock = function->getEntryBlock(); @@ -2679,7 +2675,8 @@ spv::Id TGlslangToSpvTraverser::handleUserFunctionCall(const glslang::TIntermAgg } // Translate AST operation to SPV operation, already having SPV-based operands/types. -spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv::Decoration precision, +spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv::Decoration precision, + spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison) { @@ -2816,13 +2813,15 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv if (binOp != spv::OpNop) { assert(comparison == false); if (builder.isMatrix(left) || builder.isMatrix(right)) - return createBinaryMatrixOperation(binOp, precision, typeId, left, right); + return createBinaryMatrixOperation(binOp, precision, noContraction, typeId, left, right); // No matrix involved; make both operands be the same number of components, if needed if (needMatchingVectors) builder.promoteScalar(precision, left, right); - return builder.setPrecision(builder.createBinOp(binOp, typeId, left, right), precision); + spv::Id result = builder.createBinOp(binOp, typeId, left, right); + addDecoration(result, noContraction); + return builder.setPrecision(result, precision); } if (! comparison) @@ -2891,8 +2890,11 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv break; } - if (binOp != spv::OpNop) - return builder.setPrecision(builder.createBinOp(binOp, typeId, left, right), precision); + if (binOp != spv::OpNop) { + spv::Id result = builder.createBinOp(binOp, typeId, left, right); + addDecoration(result, noContraction); + return builder.setPrecision(result, precision); + } return 0; } @@ -2911,7 +2913,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv // matrix op scalar op in {+, -, /} // scalar op matrix op in {+, -, /} // -spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right) +spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right) { bool firstClass = true; @@ -2947,8 +2949,11 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec break; } - if (firstClass) - return builder.setPrecision(builder.createBinOp(op, typeId, left, right), precision); + if (firstClass) { + spv::Id result = builder.createBinOp(op, typeId, left, right); + addDecoration(result, noContraction); + return builder.setPrecision(result, precision); + } // Handle component-wise +, -, *, and / for all combinations of type. // The result type of all of them is the same type as the (a) matrix operand. @@ -2983,8 +2988,9 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec indexes.push_back(c); spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec; spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec; - results.push_back(builder.createBinOp(op, vecType, leftVec, rightVec)); - builder.setPrecision(results.back(), precision); + spv::Id result = builder.createBinOp(op, vecType, leftVec, rightVec); + addDecoration(result, noContraction); + results.push_back(builder.setPrecision(result, precision)); } // put the pieces together @@ -2996,7 +3002,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec } } -spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy) +spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy) { spv::Op unaryOp = spv::OpNop; int libCall = -1; @@ -3008,7 +3014,7 @@ spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv: if (isFloat) { unaryOp = spv::OpFNegate; if (builder.isMatrixType(typeId)) - return createUnaryMatrixOperation(unaryOp, precision, typeId, operand, typeProxy); + return createUnaryMatrixOperation(unaryOp, precision, noContraction, typeId, operand, typeProxy); } else unaryOp = spv::OpSNegate; break; @@ -3290,11 +3296,12 @@ spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv: id = builder.createUnaryOp(unaryOp, typeId, operand); } + addDecoration(id, noContraction); return builder.setPrecision(id, precision); } // Create a unary operation on a matrix -spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType /* typeProxy */) +spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand, glslang::TBasicType /* typeProxy */) { // Handle unary operations vector by vector. // The result type is the same type as the original type. @@ -3315,8 +3322,9 @@ spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Deco std::vector indexes; indexes.push_back(c); spv::Id vec = builder.createCompositeExtract(operand, vecType, indexes); - results.push_back(builder.createUnaryOp(op, vecType, vec)); - builder.setPrecision(results.back(), precision); + spv::Id vec_result = builder.createUnaryOp(op, vecType, vec); + addDecoration(vec_result, noContraction); + results.push_back(builder.setPrecision(vec_result, precision)); } // put the pieces together @@ -4144,7 +4152,7 @@ bool TGlslangToSpvTraverser::isTrivialLeaf(const glslang::TIntermTyped* node) default: return false; } -} +} // A node is trivial if it is a single operation with no side effects. // Error on the side of saying non-trivial. diff --git a/glslang/MachineIndependent/propagateNoContraction.cpp b/glslang/MachineIndependent/propagateNoContraction.cpp index fa2b62a9..f4b88d16 100644 --- a/glslang/MachineIndependent/propagateNoContraction.cpp +++ b/glslang/MachineIndependent/propagateNoContraction.cpp @@ -47,18 +47,27 @@ #include "localintermediate.h" namespace { -// Use string to hold the accesschain information, as in most cases we the -// accesschain is short and may contain only one element, which is the symbol ID. +// Use string to hold the accesschain information, as in most cases the +// accesschain is short and may contain only one element, which is the symbol +// ID. +// Example: struct {float a; float b;} s; +// Object s.a will be represented with: /0 +// Object s.b will be represented with: /1 +// Object s will be representend with: +// For members of vector, matrix and arrays, they will be represented with the +// same symbol ID of their container symbol objects. This is because their +// precise'ness is always the same as their container symbol objects. using ObjectAccessChain = std::string; -#ifndef StructAccessChainDelimiter -#define StructAccessChainDelimiter '/' -#endif + +// The delimiter used in the ObjectAccessChain string to separate symbol ID and +// different level of struct indices. +const char OBJECT_ACCESSCHAIN_DELIMITER = '/'; // Mapping from Symbol IDs of symbol nodes, to their defining operation // nodes. -using NodeMapping = std::unordered_multimap; +using NodeMapping = std::unordered_multimap; // Mapping from object nodes to their accesschain info string. -using AccessChainMapping = std::unordered_map; +using AccessChainMapping = std::unordered_map; // Set of object IDs. using ObjectAccesschainSet = std::unordered_set; @@ -67,7 +76,7 @@ using ReturnBranchNodeSet = std::unordered_set; // A helper function to tell whether a node is 'noContraction'. Returns true if // the node has 'noContraction' qualifier, otherwise false. -bool isPreciseObjectNode(glslang::TIntermTyped *node) +bool isPreciseObjectNode(glslang::TIntermTyped* node) { return node->getType().getQualifier().noContraction; } @@ -118,7 +127,7 @@ bool isAssignOperation(glslang::TOperator op) // A helper function to get the unsigned int from a given constant union node. // Note the node should only holds a uint scalar. -unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped *node) +unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped* node) { assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar()); unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst(); @@ -126,9 +135,10 @@ unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped *node) } // A helper function to generate symbol_label. -ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol *node) +ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol* node) { - ObjectAccessChain symbol_id = std::to_string(node->getId()) + "(" + node->getName().c_str() + ")"; + ObjectAccessChain symbol_id = + std::to_string(node->getId()) + "(" + node->getName().c_str() + ")"; return symbol_id; } @@ -180,43 +190,41 @@ bool isArithmeticOperation(glslang::TOperator op) // A helper class to help managing populating_initial_no_contraction_ flag. template class StateSettingGuard { public: - StateSettingGuard(T *state_ptr, T new_state_value) + StateSettingGuard(T* state_ptr, T new_state_value) : state_ptr_(state_ptr), previous_state_(*state_ptr) { *state_ptr = new_state_value; } - StateSettingGuard(T *state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {} - void setState(T new_state_value) - { - *state_ptr_ = new_state_value; - } + StateSettingGuard(T* state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {} + void setState(T new_state_value) { *state_ptr_ = new_state_value; } ~StateSettingGuard() { *state_ptr_ = previous_state_; } private: - T *state_ptr_; + T* state_ptr_; T previous_state_; }; // A helper function to get the front element from a given ObjectAccessChain -ObjectAccessChain getFrontElement(const ObjectAccessChain &chain) +ObjectAccessChain getFrontElement(const ObjectAccessChain& chain) { - size_t pos_delimiter = chain.find(StructAccessChainDelimiter); + size_t pos_delimiter = chain.find(OBJECT_ACCESSCHAIN_DELIMITER); return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter); } // A helper function to get the accesschain starting from the second element. ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain) { - size_t pos_delimiter = chain.find(StructAccessChainDelimiter); + size_t pos_delimiter = chain.find(OBJECT_ACCESSCHAIN_DELIMITER); return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1); } // A helper function to get the accesschain after removing a given prefix. -ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain &chain, const ObjectAccessChain &prefix) +ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain& chain, + const ObjectAccessChain& prefix) { size_t pos = chain.find(prefix); if (pos != 0) return chain; - return chain.substr(prefix.length() + sizeof(StructAccessChainDelimiter)); + return chain.substr(prefix.length() + sizeof(OBJECT_ACCESSCHAIN_DELIMITER)); } // @@ -226,34 +234,33 @@ ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain &chain, c // class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser { public: - TSymbolDefinitionCollectingTraverser( - NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping, - ObjectAccesschainSet *precise_objects, - ReturnBranchNodeSet *precise_return_nodes); - - // bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *) override; - bool visitUnary(glslang::TVisit, glslang::TIntermUnary *) override; - bool visitBinary(glslang::TVisit, glslang::TIntermBinary *) override; - void visitSymbol(glslang::TIntermSymbol *) override; - bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *) override; - bool visitBranch(glslang::TVisit, glslang::TIntermBranch *) override; + TSymbolDefinitionCollectingTraverser(NodeMapping* symbol_definition_mapping, + AccessChainMapping* accesschain_mapping, + ObjectAccesschainSet* precise_objects, + ReturnBranchNodeSet* precise_return_nodes); + + bool visitUnary(glslang::TVisit, glslang::TIntermUnary*) override; + bool visitBinary(glslang::TVisit, glslang::TIntermBinary*) override; + void visitSymbol(glslang::TIntermSymbol*) override; + bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate*) override; + bool visitBranch(glslang::TVisit, glslang::TIntermBranch*) override; protected: // The mapping from symbol node IDs to their defining nodes. This should be // populated along traversing the AST. - NodeMapping &symbol_definition_mapping_; + NodeMapping& symbol_definition_mapping_; // The set of symbol node IDs for precise symbol nodes, the ones marked as // 'noContraction'. - ObjectAccesschainSet &precise_objects_; + ObjectAccesschainSet& precise_objects_; // The set of precise return nodes. - ReturnBranchNodeSet &precise_return_nodes_; + ReturnBranchNodeSet& precise_return_nodes_; // A temporary cache of the symbol node whose defining node is to be found // currently along traversing the AST. ObjectAccessChain object_to_be_defined_; // A map from object node to its accesschain. This traverser stores // the built accesschains into this map for each object node it has // visited. - AccessChainMapping &accesschain_mapping_; + AccessChainMapping& accesschain_mapping_; // The pointer to the Function Definition node, so we can get the // precise'ness of the return expression from it when we traverse the // return branch node. @@ -261,9 +268,9 @@ protected: }; TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser( - NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping, - ObjectAccesschainSet *precise_objects, - std::unordered_set *precise_return_nodes) + NodeMapping* symbol_definition_mapping, AccessChainMapping* accesschain_mapping, + ObjectAccesschainSet* precise_objects, + std::unordered_set* precise_return_nodes) : TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping), precise_objects_(*precise_objects), object_to_be_defined_(), accesschain_mapping_(*accesschain_mapping), current_function_definition_node_(nullptr), @@ -273,7 +280,7 @@ TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser( // current node symbol ID, and record a mapping from this node to the current // object_to_be_defined_, which is the just obtained symbol // ID. -void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol *node) +void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node) { object_to_be_defined_ = generateSymbolLabel(node); accesschain_mapping_[node] = object_to_be_defined_; @@ -281,12 +288,12 @@ void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol *n // Visits an aggregate node, traverses all of its children. bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit, - glslang::TIntermAggregate *node) + glslang::TIntermAggregate* node) { // This aggreagate node might be a function definition node, in which case we need to // cache this node, so we can get the precise'ness information of the return value // of this function later. - StateSettingGuard current_function_definition_node_setting_guard( + StateSettingGuard current_function_definition_node_setting_guard( ¤t_function_definition_node_); if (node->getOp() == glslang::EOpFunction) { // This is function definition node, we need to cache this node so that we can @@ -294,7 +301,7 @@ bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit, current_function_definition_node_setting_guard.setState(node); } // Traverse the items in the sequence. - glslang::TIntermSequence &seq = node->getSequence(); + glslang::TIntermSequence& seq = node->getSequence(); for (int i = 0; i < (int)seq.size(); ++i) { object_to_be_defined_.clear(); seq[i]->traverse(this); @@ -303,7 +310,7 @@ bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit, } bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit, - glslang::TIntermBranch *node) + glslang::TIntermBranch* node) { if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() && current_function_definition_node_ && @@ -319,7 +326,7 @@ bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit, // Visits an unary node. This might be an implicit assignment like i++, i--. etc. bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */, - glslang::TIntermUnary *node) + glslang::TIntermUnary* node) { object_to_be_defined_.clear(); node->getOperand()->traverse(this); @@ -351,7 +358,7 @@ bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit * // Visits a binary node and updates the mapping from symbol IDs to the definition // nodes. Also collects the accesschains for the initial precise objects. bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */, - glslang::TIntermBinary *node) + glslang::TIntermBinary* node) { // Traverses the left node to build the accesschain info for the object. object_to_be_defined_.clear(); @@ -408,7 +415,7 @@ bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit // object. We need to record the accesschain information of the current // node into its object id. unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight()); - object_to_be_defined_.push_back(StructAccessChainDelimiter); + object_to_be_defined_.push_back(OBJECT_ACCESSCHAIN_DELIMITER); object_to_be_defined_.append(std::to_string(struct_dereference_index)); accesschain_mapping_[node] = object_to_be_defined_; @@ -428,17 +435,18 @@ bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit // 2) a mapping from object nodes in the AST to the accesschains of these objects. // 3) a set of accesschains of precise objects. std::tuple -getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate &intermediate) +getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate& intermediate) { - auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(), ReturnBranchNodeSet()); + auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(), + ReturnBranchNodeSet()); - TIntermNode *root = intermediate.getTreeRoot(); + TIntermNode* root = intermediate.getTreeRoot(); if (root == 0) return result_tuple; - NodeMapping &symbol_definition_mapping = std::get<0>(result_tuple); - AccessChainMapping &accesschain_mapping = std::get<1>(result_tuple); - ObjectAccesschainSet &precise_objects = std::get<2>(result_tuple); - ReturnBranchNodeSet &precise_return_nodes = std::get<3>(result_tuple); + NodeMapping& symbol_definition_mapping = std::get<0>(result_tuple); + AccessChainMapping& accesschain_mapping = std::get<1>(result_tuple); + ObjectAccesschainSet& precise_objects = std::get<2>(result_tuple); + ReturnBranchNodeSet& precise_return_nodes = std::get<3>(result_tuple); // Traverses the AST and populate the results. TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping, @@ -474,7 +482,7 @@ class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser }; public: - TNoContractionAssigneeCheckingTraverser(const AccessChainMapping &accesschain_mapping) + TNoContractionAssigneeCheckingTraverser(const AccessChainMapping& accesschain_mapping) : TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping), precise_object_(nullptr) {} @@ -494,7 +502,7 @@ public: // precise object. std::tuple getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node, - const ObjectAccessChain &precise_object) + const ObjectAccessChain& precise_object) { assert(isAssignOperation(node->getOp())); precise_object_ = &precise_object; @@ -570,23 +578,23 @@ public: } protected: - bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override; - void visitSymbol(glslang::TIntermSymbol *node) override; + bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override; + void visitSymbol(glslang::TIntermSymbol* node) override; // A map from object nodes to their accesschain string (used as object ID). - const AccessChainMapping &accesschain_mapping_; + const AccessChainMapping& accesschain_mapping_; // A given precise object, represented in it accesschain string. This // precise object is used to be compared with the assignee node to tell if // the assignee node is 'precise', contains 'precise' object or not // 'precise'. - const ObjectAccessChain *precise_object_; + const ObjectAccessChain* precise_object_; }; // Visit a binary node. If the node is an object node, it must be a dereference // node. In such cases, if the left node is 'precise', this node should also be // 'precise'. bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit, - glslang::TIntermBinary *node) + glslang::TIntermBinary* node) { // Traverses the left so that we transfer the 'precise' from nesting object // to its nested object. @@ -602,7 +610,7 @@ bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit, // this node should be marked as 'precise'. if (isPreciseObjectNode(node->getLeft())) { node->getWritableType().getQualifier().noContraction = true; - } else if (accesschain_mapping_.at(node) == *precise_object_){ + } else if (accesschain_mapping_.at(node) == *precise_object_) { node->getWritableType().getQualifier().noContraction = true; } } @@ -611,7 +619,7 @@ bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit, // Visit a symbol node, if the symbol node ID (its accesschain string) matches // with the given precise object, this node should be 'precise'. -void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol *node) +void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol* node) { // A symbol node should always be an object node, and should have been added // to the map from object nodes to their accesschain strings. @@ -632,27 +640,27 @@ void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol // class TNoContractionPropagator : public glslang::TIntermTraverser { public: - TNoContractionPropagator(ObjectAccesschainSet *precise_objects, - const AccessChainMapping &accesschain_mapping) + TNoContractionPropagator(ObjectAccesschainSet* precise_objects, + const AccessChainMapping& accesschain_mapping) : TIntermTraverser(true, false, false), remained_accesschain_(), - precise_objects_(*precise_objects), - accesschain_mapping_(accesschain_mapping), added_precise_object_ids_() {} + precise_objects_(*precise_objects), accesschain_mapping_(accesschain_mapping), + added_precise_object_ids_() {} // Propagates 'precise' in the right nodes of a given assignment node with // accesschain record from the assignee node to a 'precise' object it // contains. void - propagateNoContractionInOneExpression(glslang::TIntermTyped *defining_node, - const ObjectAccessChain &assignee_remained_accesschain) + propagateNoContractionInOneExpression(glslang::TIntermTyped* defining_node, + const ObjectAccessChain& assignee_remained_accesschain) { remained_accesschain_ = assignee_remained_accesschain; - if (glslang::TIntermBinary *BN = defining_node->getAsBinaryNode()) { + if (glslang::TIntermBinary* BN = defining_node->getAsBinaryNode()) { assert(isAssignOperation(BN->getOp())); BN->getRight()->traverse(this); if (isArithmeticOperation(BN->getOp())) { BN->getWritableType().getQualifier().noContraction = true; } - } else if (glslang::TIntermUnary *UN = defining_node->getAsUnaryNode()) { + } else if (glslang::TIntermUnary* UN = defining_node->getAsUnaryNode()) { assert(isAssignOperation(UN->getOp())); UN->getOperand()->traverse(this); if (isArithmeticOperation(UN->getOp())) { @@ -662,8 +670,7 @@ public: } // Propagates 'precise' in a given precise return node. - void - propagateNoContractionInReturnNode(glslang::TIntermBranch *return_node) + void propagateNoContractionInReturnNode(glslang::TIntermBranch* return_node) { remained_accesschain_ = ""; assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression()); @@ -675,7 +682,7 @@ protected: // case we need to find the 'precise' or 'precise' containing object node // with the accesschain record. In other cases, just need to traverse all // the children nodes. - bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *node) override + bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override { if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) { // This is a struct initializer node, and the remained @@ -689,7 +696,7 @@ protected: getFrontElement(remained_accesschain_); unsigned precise_accesschain_index = std::stoul(precise_accesschain_index_str); // Gets the node pointed by the accesschain index extracted before. - glslang::TIntermTyped *potential_precise_node = + glslang::TIntermTyped* potential_precise_node = node->getSequence()[precise_accesschain_index]->getAsTyped(); assert(potential_precise_node); // Pop the front accesschain index from the path, and visit the nested node. @@ -700,16 +707,9 @@ protected: &remained_accesschain_, next_level_accesschain); potential_precise_node->traverse(this); } - - } else { - // If this is not a struct constructor, just visit each nested node. - glslang::TIntermSequence &seq = node->getSequence(); - for (int i = 0; i < (int)seq.size(); ++i) { - seq[i]->traverse(this); - } + return false; } - - return false; + return true; } // Visit a binary node. A binary node can be an object node, e.g. a dereference node. @@ -718,7 +718,7 @@ protected: // an object node. If the binary node does not represent an object node, it should // go on to traverse its children nodes and if it is an arithmetic operation node, this // operation should be marked as 'noContraction'. - bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override + bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override { if (isDereferenceOperation(node->getOp())) { // This binary node is an object node. Need to update the precise @@ -728,8 +728,7 @@ protected: if (remained_accesschain_.empty()) { node->getWritableType().getQualifier().noContraction = true; } else { - new_precise_accesschain += - StructAccessChainDelimiter + remained_accesschain_; + new_precise_accesschain += OBJECT_ACCESSCHAIN_DELIMITER + remained_accesschain_; } // Cache the accesschain as added precise object, so we won't add the // same object to the worklist again. @@ -746,21 +745,18 @@ protected: node->getWritableType().getQualifier().noContraction = true; } // As this node is not an object node, need to traverse the children nodes. - node->getLeft()->traverse(this); - node->getRight()->traverse(this); - return false; + return true; } // Visits an unary node. An unary node can not be an object node. If the operation // is an arithmetic operation, need to mark this node as 'noContraction'. - bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary *node) override + bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node) override { // If this is an arithmetic operation, marks this with 'noContraction' if (isArithmeticOperation(node->getOp())) { node->getWritableType().getQualifier().noContraction = true; } - node->getOperand()->traverse(this); - return false; + return true; } // Visits a symbol node. A symbol node is always an object node. So we @@ -768,7 +764,7 @@ protected: // nodes to accesschains. As an object node, a symbol node can be either // 'precise' or containing 'precise' objects according to unused // accesschain information we have when we visit this node. - void visitSymbol(glslang::TIntermSymbol *node) override + void visitSymbol(glslang::TIntermSymbol* node) override { // Symbol nodes are object nodes and should always have an // accesschain collected before matches with it. @@ -781,7 +777,7 @@ protected: if (remained_accesschain_.empty()) { node->getWritableType().getQualifier().noContraction = true; } else { - new_precise_accesschain += StructAccessChainDelimiter + remained_accesschain_; + new_precise_accesschain += OBJECT_ACCESSCHAIN_DELIMITER + remained_accesschain_; } // Add the new 'precise' accesschain to the worklist and make sure we // don't visit it again. @@ -792,7 +788,7 @@ protected: } // A set of precise objects, represented as accesschains. - ObjectAccesschainSet &precise_objects_; + ObjectAccesschainSet& precise_objects_; // Visited symbol nodes, should not revisit these nodes. ObjectAccesschainSet added_precise_object_ids_; // The left node of an assignment operation might be an parent of 'precise' objects. @@ -802,15 +798,13 @@ protected: // tell us how to find the corresponding 'precise' node in the right. ObjectAccessChain remained_accesschain_; // A map from node pointers to their accesschains. - const AccessChainMapping &accesschain_mapping_; + const AccessChainMapping& accesschain_mapping_; }; - -#undef StructAccessChainDelimiter } namespace glslang { -void PropagateNoContraction(const glslang::TIntermediate &intermediate) +void PropagateNoContraction(const glslang::TIntermediate& intermediate) { // First, traverses the AST, records symbols with their defining operations // and collects the initial set of precise symbols (symbol nodes that marked @@ -821,18 +815,17 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate) // The mapping of symbol node IDs to their defining nodes. This enables us // to get the defining node directly from a given symbol ID without // traversing the tree again. - NodeMapping &symbol_definition_mapping = std::get<0>(mappings_and_precise_objects); + NodeMapping& symbol_definition_mapping = std::get<0>(mappings_and_precise_objects); // The mapping of object nodes to their accesschains recorded. - AccessChainMapping &accesschain_mapping = std::get<1>(mappings_and_precise_objects); + AccessChainMapping& accesschain_mapping = std::get<1>(mappings_and_precise_objects); // The initial set of 'precise' objects which are represented as the // accesschain toward them. - ObjectAccesschainSet &precise_object_accesschains = - std::get<2>(mappings_and_precise_objects); + ObjectAccesschainSet& precise_object_accesschains = std::get<2>(mappings_and_precise_objects); // The set of 'precise' return nodes. - ReturnBranchNodeSet &precise_return_nodes = std::get<3>(mappings_and_precise_objects); + ReturnBranchNodeSet& precise_return_nodes = std::get<3>(mappings_and_precise_objects); // Second, uses the initial set of precise objects as a worklist, pops an // accesschain, extract the symbol ID from it. Then: @@ -845,10 +838,9 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate) // 'precise' accesschain worklist with new found object nodes. // Repeat above steps until the worklist is empty. TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping); - TNoContractionPropagator propagator(&precise_object_accesschains, - accesschain_mapping); + TNoContractionPropagator propagator(&precise_object_accesschains, accesschain_mapping); - // We have to initial precise worklist to handle: + // We have two initial precise worklists to handle: // 1) precise return nodes // 2) precise object accesschains // We should process the precise return nodes first and the involved @@ -877,12 +869,12 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate) // objects, and mark arithmetic operations as 'noContraction'. for (NodeMapping::iterator defining_node_iter = range.first; defining_node_iter != range.second; defining_node_iter++) { - TIntermOperator *defining_node = defining_node_iter->second; + TIntermOperator* defining_node = defining_node_iter->second; // Check the assignee node. auto checker_result = checker.getPrecisenessAndRemainedAccessChain( defining_node, precise_object_accesschain); - bool &contain_precise = std::get<0>(checker_result); - ObjectAccessChain &remained_accesschain = std::get<1>(checker_result); + bool& contain_precise = std::get<0>(checker_result); + ObjectAccessChain& remained_accesschain = std::get<1>(checker_result); // If the assignee node is 'precise' or contains 'precise', propagate the // 'precise' to the right. Otherwise just skip this assignment node. if (contain_precise) { -- 2.34.1