2 // Copyright (C) 2015 LunarG, Inc.
4 // All rights reserved.
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
10 // Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
13 // Redistributions in binary form must reproduce the above
14 // copyright notice, this list of conditions and the following
15 // disclaimer in the documentation and/or other materials provided
16 // with the distribution.
18 // Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19 // contributors may be used to endorse or promote products derived
20 // from this software without specific prior written permission.
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
36 #include "SPVRemapper.h"
39 #if !defined (use_cpp11)
40 // ... not supported before C++11
41 #else // defined (use_cpp11)
45 #include "../glslang/Include/Common.h"
49 // By default, just abort on error. Can be overridden via RegisterErrorHandler
50 spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };
51 // By default, eat log messages. Can be overridden via RegisterLogHandler
52 spirvbin_t::logfn_t spirvbin_t::logHandler = [](const std::string&) { };
54 // This can be overridden to provide other message behavior if needed
55 void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const
57 if (verbose >= minVerbosity)
58 logHandler(std::string(indent, ' ') + txt);
61 // hash opcode, with special handling for OpExtInst
62 std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)
64 const spv::Op opCode = asOpCode(word);
66 std::uint32_t offset = 0;
70 offset += asId(word + 4); break;
75 return opCode * 19 + offset; // 19 = small prime
78 spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const
80 static const int maxCount = 1<<30;
83 case spv::OpTypeFloat: // fall through...
84 case spv::OpTypePointer: return range_t(2, 3);
85 case spv::OpTypeInt: return range_t(2, 4);
86 // TODO: case spv::OpTypeImage:
87 // TODO: case spv::OpTypeSampledImage:
88 case spv::OpTypeSampler: return range_t(3, 8);
89 case spv::OpTypeVector: // fall through
90 case spv::OpTypeMatrix: // ...
91 case spv::OpTypePipe: return range_t(3, 4);
92 case spv::OpConstant: return range_t(3, maxCount);
93 default: return range_t(0, 0);
97 spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const
99 static const int maxCount = 1<<30;
101 if (isConstOp(opCode))
102 return range_t(1, 2);
105 case spv::OpTypeVector: // fall through
106 case spv::OpTypeMatrix: // ...
107 case spv::OpTypeSampler: // ...
108 case spv::OpTypeArray: // ...
109 case spv::OpTypeRuntimeArray: // ...
110 case spv::OpTypePipe: return range_t(2, 3);
111 case spv::OpTypeStruct: // fall through
112 case spv::OpTypeFunction: return range_t(2, maxCount);
113 case spv::OpTypePointer: return range_t(3, 4);
114 default: return range_t(0, 0);
118 spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const
120 static const int maxCount = 1<<30;
123 case spv::OpTypeArray: // fall through...
124 case spv::OpTypeRuntimeArray: return range_t(3, 4);
125 case spv::OpConstantComposite: return range_t(3, maxCount);
126 default: return range_t(0, 0);
130 // Return the size of a type in 32-bit words. This currently only
131 // handles ints and floats, and is only invoked by queries which must be
132 // integer types. If ever needed, it can be generalized.
// id is resolved through idPos() to the start of its defining instruction;
// that instruction's opcode selects how the size is computed.
133 unsigned spirvbin_t::typeSizeInWords(spv::Id id) const
135 const unsigned typeStart = idPos(id);
136 const spv::Op opCode = asOpCode(typeStart);
142 case spv::OpTypeInt: // fall through...
// The value in word typeStart+2 is rounded up to a whole number of 32-bit units.
143 case spv::OpTypeFloat: return (spv[typeStart+2]+31)/32;
149 // Looks up the type of a given const or variable ID, and
150 // returns its size in 32-bit words.
// The size comes from idTypeSizeMap; an ID with no entry there is reported
// through error().
151 unsigned spirvbin_t::idTypeSizeInWords(spv::Id id) const
153 const auto tid_it = idTypeSizeMap.find(id);
154 if (tid_it == idTypeSizeMap.end()) {
155 error("type size for ID not found");
// Found: the mapped value is the size.
159 return tid_it->second;
162 // Is this an opcode we should remove when using --strip?
// The opcodes listed in the cases below yield true; any other opcode yields false.
163 bool spirvbin_t::isStripOp(spv::Op opCode) const
167 case spv::OpSourceExtension:
169 case spv::OpMemberName:
170 case spv::OpLine: return true;
171 default: return false;
175 // Return true if this opcode is flow control
// Besides branches and merges, function begin/end are also listed as flow
// control here. Any opcode not listed yields false.
176 bool spirvbin_t::isFlowCtrl(spv::Op opCode) const
179 case spv::OpBranchConditional:
182 case spv::OpLoopMerge:
183 case spv::OpSelectionMerge:
185 case spv::OpFunction:
186 case spv::OpFunctionEnd: return true;
187 default: return false;
191 // Return true if this opcode defines a type
// The OpType* opcodes listed in the cases below yield true; any other opcode
// yields false.
192 bool spirvbin_t::isTypeOp(spv::Op opCode) const
195 case spv::OpTypeVoid:
196 case spv::OpTypeBool:
198 case spv::OpTypeFloat:
199 case spv::OpTypeVector:
200 case spv::OpTypeMatrix:
201 case spv::OpTypeImage:
202 case spv::OpTypeSampler:
203 case spv::OpTypeArray:
204 case spv::OpTypeRuntimeArray:
205 case spv::OpTypeStruct:
206 case spv::OpTypeOpaque:
207 case spv::OpTypePointer:
208 case spv::OpTypeFunction:
209 case spv::OpTypeEvent:
210 case spv::OpTypeDeviceEvent:
211 case spv::OpTypeReserveId:
212 case spv::OpTypeQueue:
213 case spv::OpTypeSampledImage:
214 case spv::OpTypePipe: return true;
215 default: return false;
219 // Return true if this opcode defines a constant
220 bool spirvbin_t::isConstOp(spv::Op opCode) const
// Sampler constants are not implemented and are reported through error().
223 case spv::OpConstantSampler:
224 error("unimplemented constant type");
// The remaining constant-defining opcodes this function lists.
227 case spv::OpConstantNull:
228 case spv::OpConstantTrue:
229 case spv::OpConstantFalse:
230 case spv::OpConstantComposite:
231 case spv::OpConstant:
239 const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };
240 const auto op_fn_nop = [](spv::Id&) { };
242 // g++ doesn't like these defined in the class proper in an anonymous namespace.
243 // Dunno why. Also MSVC doesn't like the constexpr keyword. Also dunno why.
244 // Defining them externally seems to please both compilers, so, here they are.
245 const spv::Id spirvbin_t::unmapped = spv::Id(-10000);
246 const spv::Id spirvbin_t::unused = spv::Id(-10001);
247 const int spirvbin_t::header_size = 5;
// Starting from id, look for an ID that isNewIdMapped() does not report as
// already taken.
249 spv::Id spirvbin_t::nextUnusedId(spv::Id id)
251 while (isNewIdMapped(id)) // search for an unused ID
// Record newId as the mapping of old ID `id` in idMapL and return newId.
// Every failed check below reports through error() and returns
// spirvbin_t::unused instead.
257 spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)
259 //assert(id != spv::NoResult && newId != spv::NoResult);
262 error(std::string("ID out of range: ") + std::to_string(id));
263 return spirvbin_t::unused;
// Grow the table on demand; new slots start out as `unused`.
266 if (id >= idMapL.size())
267 idMapL.resize(id+1, unused);
// The consistency checks apply only when newId is a real ID rather than one
// of the two sentinel values.
269 if (newId != unmapped && newId != unused) {
270 if (isOldIdUnused(id)) {
271 error(std::string("ID unused in module: ") + std::to_string(id));
272 return spirvbin_t::unused;
// An old ID may be given a real mapping only once.
275 if (!isOldIdUnmapped(id)) {
276 error(std::string("ID already mapped: ") + std::to_string(id) + " -> "
277 + std::to_string(localId(id)));
279 return spirvbin_t::unused;
// Two old IDs may not share the same new ID.
282 if (isNewIdMapped(newId)) {
283 error(std::string("ID already used in module: ") + std::to_string(newId));
284 return spirvbin_t::unused;
287 msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));
// Track the largest new ID handed out so far.
289 largestNewId = std::max(largestNewId, newId);
292 return idMapL[id] = newId;
295 // Parse a literal string from the SPIR binary and return it as an std::string
296 // Due to C++11 RValue references, this doesn't copy the result string.
// word: index into spv of the first word of the string.
297 std::string spirvbin_t::literalString(unsigned word) const
300 const spirword_t * pos = spv.data() + word;
// NOTE(review): this local `word` shadows the `word` parameter; from here on
// it holds the contents of the word at pos, not an index.
305 spirword_t word = *pos;
// Each 32-bit word is examined as 4 characters, taking the low 8 bits each time.
306 for (int i = 0; i < 4; i++) {
307 char c = word & 0xff;
// Run the ID callback over the module via process(); instructions themselves
// are ignored (inst_fn_nop).
317 void spirvbin_t::applyMap()
319 msg(3, 2, std::string("Applying map: "));
321 // Map local IDs through the ID map
322 process(inst_fn_nop, // ignore instructions
323 [this](spv::Id& id) {
// After mapping, no ID may still hold one of the two sentinel values.
329 assert(id != unused && id != unmapped);
334 // Find free IDs for anything we haven't mapped
// Walks every slot of idMapL, assigns the next free new ID to each old ID that
// is still unmapped, and finally writes the resulting bound via bound().
335 void spirvbin_t::mapRemainder()
337 msg(3, 2, std::string("Remapping remainder: "));
339 spv::Id unusedId = 1; // can't use 0: that's NoResult
340 spirword_t maxBound = 0;
342 for (spv::Id id = 0; id < idMapL.size(); ++id) {
343 if (isOldIdUnused(id))
346 // Find a new mapping for any used but unmapped IDs
347 if (isOldIdUnmapped(id)) {
// unusedId is carried forward so each search resumes where the last one ended.
348 localId(id, unusedId = nextUnusedId(unusedId));
// Still unmapped after the attempt above: report it.
353 if (isOldIdUnmapped(id)) {
354 error(std::string("old ID not mapped: ") + std::to_string(id));
// The bound must be one larger than the largest new ID seen.
359 maxBound = std::max(maxBound, localId(id) + 1);
365 bound(maxBound); // reset header ID bound to as big as it now needs to be
368 // Mark debug instructions for stripping
// Which opcodes count as debug information is decided by isStripOp().
369 void spirvbin_t::stripDebug()
371 // Strip instructions in the stripOp set: debug info.
373 [&](spv::Op opCode, unsigned start) {
374 // remember opcodes we want to strip later
375 if (isStripOp(opCode))
382 // Mark instructions that refer to now-removed IDs for stripping
383 void spirvbin_t::stripDeadRefs()
386 [&](spv::Op opCode, unsigned start) {
387 // strip opcodes pointing to removed data
390 case spv::OpMemberName:
391 case spv::OpDecorate:
392 case spv::OpMemberDecorate:
// For these opcodes the target ID is the first operand (start+1). It counts as
// removed when idPosR has no entry for it.
393 if (idPosR.find(asId(start+1)) == idPosR.end())
397 break; // leave it alone
407 // Update local maps of ID, type, etc positions
// Makes one pass over the module with process(), recording for each instruction
// whatever the maps below need: result ID positions, type sizes, names,
// function call counts, the entry point, function ranges, and type/constant
// positions.
408 void spirvbin_t::buildLocalMaps()
410 msg(2, 2, std::string("build local maps: "));
414 // preserve nameMap, so we don't clear that.
417 typeConstPos.clear();
419 entryPoint = spv::NoResult;
// One slot per ID below the header bound; all start out as `unused`.
422 idMapL.resize(bound(), unused);
// Result ID of the function currently being scanned.
425 spv::Id fnRes = spv::NoResult;
427 // build local Id and name maps
429 [&](spv::Op opCode, unsigned start) {
// `word` walks the operands: an optional type ID first, then an optional result ID.
430 unsigned word = start+1;
431 spv::Id typeId = spv::NoResult;
433 if (spv::InstructionDesc[opCode].hasType())
434 typeId = asId(word++);
436 // If there's a result ID, remember the size of its type
437 if (spv::InstructionDesc[opCode].hasResult()) {
438 const spv::Id resultId = asId(word++);
439 idPosR[resultId] = start;
441 if (typeId != spv::NoResult) {
442 const unsigned idTypeSize = typeSizeInWords(typeId);
448 idTypeSizeMap[resultId] = idTypeSize;
// OpName: remember which ID each debug name refers to.
452 if (opCode == spv::Op::OpName) {
453 const spv::Id target = asId(start+1);
454 const std::string name = literalString(start+2);
455 nameMap[name] = target;
// OpFunctionCall: count calls per callee (the callee ID is at start+3).
457 } else if (opCode == spv::Op::OpFunctionCall) {
458 ++fnCalls[asId(start + 3)];
459 } else if (opCode == spv::Op::OpEntryPoint) {
460 entryPoint = asId(start + 2);
461 } else if (opCode == spv::Op::OpFunction) {
463 error("nested function found");
468 fnRes = asId(start + 2);
469 } else if (opCode == spv::Op::OpFunctionEnd) {
470 assert(fnRes != spv::NoResult);
472 error("function end without function start");
// The recorded function range ends just past the OpFunctionEnd instruction.
476 fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));
// Constants and types are both collected in typeConstPos. Their result IDs
// sit at start+2 and start+1 respectively.
478 } else if (isConstOp(opCode)) {
482 assert(asId(start + 2) != spv::NoResult);
483 typeConstPos.insert(start);
484 } else if (isTypeOp(opCode)) {
485 assert(asId(start + 1) != spv::NoResult);
486 typeConstPos.insert(start);
// ID callback: pass every ID seen in the module to localId() with the
// `unmapped` sentinel as its new ID.
492 [this](spv::Id& id) { localId(id, unmapped); }
496 // Validate the SPIR header
497 void spirvbin_t::validate() const
499 msg(2, 2, std::string("validating: "));
501 if (spv.size() < header_size) {
502 error("file too short: ");
506 if (magic() != spv::MagicNumber) {
507 error("bad magic number");
512 // field 2 = generator magic
513 // field 3 = result <id> bound
515 if (schemaNum() != 0) {
516 error("bad schema, must be 0");
// Process the single instruction that starts at index `word` of spv.
// instFn is called with the opcode and the instruction's start index.
// idFn is called on each operand word this function classifies as an ID.
// nextInst is computed as the start index plus the instruction's word count.
521 int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)
523 const auto instructionStart = word;
524 const unsigned wordCount = asWordCount(instructionStart);
// word++ also steps past the leading word, so `word` now indexes the first operand.
525 const int nextInst = word++ + wordCount;
526 spv::Op opCode = asOpCode(instructionStart);
// An instruction may not extend past the end of the binary.
528 if (nextInst > int(spv.size())) {
529 error("spir instruction terminated too early");
533 // Base for computing number of operands; will be updated as more is learned
534 unsigned numOperands = wordCount - 1;
536 if (instFn(opCode, instructionStart))
539 // Read type and result ID from instruction desc table
540 if (spv::InstructionDesc[opCode].hasType()) {
545 if (spv::InstructionDesc[opCode].hasResult()) {
550 // Extended instructions: currently, assume everything is an ID.
551 // TODO: add whatever data we need for exceptions to that
552 if (opCode == spv::OpExtInst) {
554 idFn(asId(word)); // Instruction set is an ID that also needs to be mapped
556 word += 2; // instruction set, and instruction from set
559 for (unsigned op=0; op < numOperands; ++op)
560 idFn(asId(word++)); // ID
565 // Circular buffer so we can look back at previous unmapped values during the mapping pass.
566 static const unsigned idBufferSize = 4;
567 spv::Id idBuffer[idBufferSize];
568 unsigned idBufferPos = 0;
570 // Store IDs from instruction in our map
// `op` indexes the operand class table for opCode. numOperands counts down
// as words are consumed.
571 for (int op = 0; numOperands > 0; ++op, --numOperands) {
572 // SpecConstantOp is special: it includes the operands of another opcode which is
573 // given as a literal in the 3rd word. We will switch over to pretending that the
574 // opcode being processed is the literal opcode value of the SpecConstantOp. See the
575 // SPIRV spec for details. This way we will handle IDs and literals as appropriate for
577 if (opCode == spv::OpSpecConstantOp) {
579 opCode = asOpCode(word++); // this is the opcode embedded in the SpecConstantOp.
584 switch (spv::InstructionDesc[opCode].operands.getClass(op)) {
586 case spv::OperandScope:
587 case spv::OperandMemorySemantics:
// Remember the ID as it was read, so later operands can look back at it.
588 idBuffer[idBufferPos] = asId(word);
589 idBufferPos = (idBufferPos + 1) % idBufferSize;
593 case spv::OperandVariableIds:
594 for (unsigned i = 0; i < numOperands; ++i)
598 case spv::OperandVariableLiterals:
600 // if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {
604 // word += numOperands;
607 case spv::OperandVariableLiteralId: {
608 if (opCode == OpSwitch) {
609 // word-2 is the position of the selector ID. OpSwitch Literals match its type.
610 // In case the IDs are currently being remapped, we get the word[-2] ID from
611 // the circular idBuffer.
612 const unsigned literalSizePos = (idBufferPos+idBufferSize-2) % idBufferSize;
613 const unsigned literalSize = idTypeSizeInWords(idBuffer[literalSizePos]);
// Each case is a literal of literalSize words followed by one label ID.
614 const unsigned numLiteralIdPairs = (nextInst-word) / (1+literalSize);
619 for (unsigned arg=0; arg<numLiteralIdPairs; ++arg) {
620 word += literalSize; // literal
621 idFn(asId(word++)); // label
624 assert(0); // currently, only OpSwitch uses OperandVariableLiteralId
630 case spv::OperandLiteralString: {
// A string occupies several words. Skip them all and adjust the remaining
// operand count to match.
631 const int stringWordCount = literalStringWords(literalString(word));
632 word += stringWordCount;
633 numOperands -= (stringWordCount-1); // -1 because for() header post-decrements
637 case spv::OperandVariableLiteralStrings:
640 // Execution mode might have extra literal operands. Skip them.
641 case spv::OperandExecutionMode:
644 // Single word operands we simply ignore, as they hold no IDs
645 case spv::OperandLiteralNumber:
646 case spv::OperandSource:
647 case spv::OperandExecutionModel:
648 case spv::OperandAddressing:
649 case spv::OperandMemory:
650 case spv::OperandStorage:
651 case spv::OperandDimensionality:
652 case spv::OperandSamplerAddressingMode:
653 case spv::OperandSamplerFilterMode:
654 case spv::OperandSamplerImageFormat:
655 case spv::OperandImageChannelOrder:
656 case spv::OperandImageChannelDataType:
657 case spv::OperandImageOperands:
658 case spv::OperandFPFastMath:
659 case spv::OperandFPRoundingMode:
660 case spv::OperandLinkageType:
661 case spv::OperandAccessQualifier:
662 case spv::OperandFuncParamAttr:
663 case spv::OperandDecoration:
664 case spv::OperandBuiltIn:
665 case spv::OperandSelect:
666 case spv::OperandLoop:
667 case spv::OperandFunction:
668 case spv::OperandMemoryAccess:
669 case spv::OperandGroupOperation:
670 case spv::OperandKernelEnqueueFlags:
671 case spv::OperandKernelProfilingInfo:
672 case spv::OperandCapability:
677 assert(0 && "Unhandled Operand Class");
685 // Make a pass over all the instructions and process them given appropriate functions
// instFn and idFn are forwarded to processInstruction() for every instruction
// in [begin, end). Passing 0 for begin or end selects the default: the first
// word after the header, or the end of the binary.
686 spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)
688 // For efficiency, reserve name map space. It can grow if needed.
691 // If begin or end == 0, use defaults
692 begin = (begin == 0 ? header_size : begin);
693 end = (end == 0 ? unsigned(spv.size()) : end);
695 // basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...
696 unsigned nextInst = unsigned(spv.size());
// processInstruction() returns the start of the following instruction, which
// becomes the next loop position.
698 for (unsigned word = begin; word < end; word = nextInst) {
699 nextInst = processInstruction(word, instFn, idFn);
708 // Apply global name mapping to a single module
// Each entry of nameMap is hashed by its name. The hash, reduced modulo
// softTypeIdLimit and offset by firstMappedID, seeds the search for the new ID
// given to the named old ID.
709 void spirvbin_t::mapNames()
711 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
712 static const std::uint32_t firstMappedID = 3019; // offset into ID space
714 for (const auto& name : nameMap) {
715 std::uint32_t hashval = 1911;
// NOTE(review): c is a plain char, so where char is signed, bytes above 0x7f
// are sign-extended before being added. The hash of such names would then
// differ between platforms - confirm whether that matters.
716 for (const char c : name.first)
717 hashval = hashval * 1009 + c;
// Only IDs that have no mapping yet are assigned one.
719 if (isOldIdUnmapped(name.second)) {
720 localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
727 // Map fn contents to IDs of similar functions in other modules
// Two passes are visible below. The first hashes each result-producing
// instruction inside a function together with a small window of neighbouring
// opcodes. The second hashes the IDs used by a selected set of opcodes, using
// per-function opcode counters.
728 void spirvbin_t::mapFnBodies()
730 static const std::uint32_t softTypeIdLimit = 19071; // small prime. TODO: get from options
731 static const std::uint32_t firstMappedID = 6203; // offset into ID space
733 // Initial approach: go through some high priority opcodes first and assign them
736 spv::Id fnId = spv::NoResult;
737 std::vector<unsigned> instPos;
738 instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.
740 // Build local table of instruction start positions
742 [&](spv::Op, unsigned start) { instPos.push_back(start); return true; },
748 // Window size for context-sensitive canonicalization values
749 // Empirical best size from a single data set. TODO: Would be a good tunable.
750 // We essentially perform a little convolution around each instruction,
751 // to capture the flavor of nearby code, to hopefully match to similar
752 // code in other modules.
753 static const unsigned windowSize = 2;
755 for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {
756 const unsigned start = instPos[entry];
757 const spv::Op opCode = asOpCode(start);
// fnId tracks the enclosing function: set at OpFunction, cleared at OpFunctionEnd.
759 if (opCode == spv::OpFunction)
760 fnId = asId(start + 2);
762 if (opCode == spv::OpFunctionEnd)
763 fnId = spv::NoResult;
765 if (fnId != spv::NoResult) { // if inside a function
766 if (spv::InstructionDesc[opCode].hasResult()) {
// The result ID follows the type ID when the instruction has one.
767 const unsigned word = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);
768 const spv::Id resId = asId(word);
769 std::uint32_t hashval = fnId * 17; // small prime
// Fold in the opcodes of the preceding instructions, stopping at the function start.
// NOTE(review): entry and windowSize are unsigned, so entry-1 and
// entry-windowSize wrap around when entry < windowSize - confirm that cannot
// occur for an instruction inside a function.
771 for (unsigned i = entry-1; i >= entry-windowSize; --i) {
772 if (asOpCode(instPos[i]) == spv::OpFunction)
774 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
// Fold in this instruction and the following ones, stopping at the function end.
// NOTE(review): i can reach entry+windowSize and nothing visible here bounds
// it by instPos.size(); this relies on reaching OpFunctionEnd first - confirm.
777 for (unsigned i = entry; i <= entry + windowSize; ++i) {
778 if (asOpCode(instPos[i]) == spv::OpFunctionEnd)
780 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
783 if (isOldIdUnmapped(resId)) {
784 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
// Second pass state: the opcode currently being hashed (OpNop = none) and how
// many times each opcode has been seen.
793 spv::Op thisOpCode(spv::OpNop);
794 std::unordered_map<int, int> opCounter;
796 fnId = spv::NoResult;
799 [&](spv::Op opCode, unsigned start) {
801 case spv::OpFunction:
802 // Reset counters at each function
805 fnId = asId(start + 2);
808 case spv::OpImageSampleImplicitLod:
809 case spv::OpImageSampleExplicitLod:
810 case spv::OpImageSampleDrefImplicitLod:
811 case spv::OpImageSampleDrefExplicitLod:
812 case spv::OpImageSampleProjImplicitLod:
813 case spv::OpImageSampleProjExplicitLod:
814 case spv::OpImageSampleProjDrefImplicitLod:
815 case spv::OpImageSampleProjDrefExplicitLod:
817 case spv::OpCompositeExtract:
818 case spv::OpCompositeInsert:
819 case spv::OpVectorShuffle:
821 case spv::OpVariable:
823 case spv::OpAccessChain:
826 case spv::OpCompositeConstruct:
827 case spv::OpFunctionCall:
833 thisOpCode = spv::OpNop;
// IDs are hashed only while one of the selected opcodes is current.
840 if (thisOpCode != spv::OpNop) {
842 const std::uint32_t hashval =
843 // Explicitly cast operands to unsigned int to avoid integer
844 // promotion to signed int followed by integer overflow,
845 // which would result in undefined behavior.
846 static_cast<unsigned int>(opCounter[thisOpCode])
850 + static_cast<unsigned int>(fnId) * 117;
852 if (isOldIdUnmapped(id))
853 localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
858 // EXPERIMENTAL: forward IO and uniform load/stores into operands
859 // This produces invalid Schema-0 SPIRV
// Two phases are visible: loads from input/uniform variables are replaced by
// the loaded pointer, then stores to output variables are handled the same
// way. strip() is called at the end.
860 void spirvbin_t::forwardLoadStores()
862 idset_t fnLocalVars; // set of function local vars
863 idmap_t idMap; // Map of load result IDs to what they load
865 // EXPERIMENTAL: Forward input and access chain loads into consumptions
867 [&](spv::Op opCode, unsigned start) {
868 // Add inputs and uniforms to the map
// Only 4-word OpVariable instructions are considered. The storage class is read
// from start+3 and the variable's result ID from start+2.
869 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
870 (spv[start+3] == spv::StorageClassUniform ||
871 spv[start+3] == spv::StorageClassUniformConstant ||
872 spv[start+3] == spv::StorageClassInput))
873 fnLocalVars.insert(asId(start+2));
// An access chain whose base is already tracked is tracked as well.
875 if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)
876 fnLocalVars.insert(asId(start+2));
// A load from a tracked pointer: map the load's result ID to that pointer.
878 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
879 idMap[asId(start+2)] = asId(start+3);
// Rewrite every ID that has a replacement recorded in idMap.
886 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
892 // EXPERIMENTAL: Implicit output stores
897 [&](spv::Op opCode, unsigned start) {
898 // Add inputs and uniforms to the map
// In this phase the variables collected are those in the Output storage class.
899 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
900 (spv[start+3] == spv::StorageClassOutput))
901 fnLocalVars.insert(asId(start+2));
// A store to a tracked pointer: map the stored object ID (start+2) to the
// pointer (start+1).
903 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
904 idMap[asId(start+2)] = asId(start+1);
917 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
923 strip(); // strip out data we decided to eliminate
926 // optimize loads and stores
// Collects function-local variables as removal candidates and drops any
// candidate that fails one of the checks below. Loads of the survivors are
// then replaced by the stored value, and the loads, stores and variables
// themselves are removed.
927 void spirvbin_t::optLoadStore()
929 idset_t fnLocalVars; // candidates for removal (only locals)
930 idmap_t idMap; // Map of load result IDs to what they load
931 blockmap_t blockMap; // Map of IDs to blocks they first appear in
932 int blockNum = 0; // block count, to avoid crossing flow control
934 // Find all the function local pointers stored at most once, and not via access chains
936 [&](spv::Op opCode, unsigned start) {
937 const int wordCount = asWordCount(start);
939 // Count blocks, so we can avoid crossing flow control
940 if (isFlowCtrl(opCode))
943 // Add local variables to the map
// Only 4-word OpVariable instructions in the Function storage class qualify.
944 if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4)) {
945 fnLocalVars.insert(asId(start+2));
949 // Ignore process vars referenced via access chain
950 if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {
951 fnLocalVars.erase(asId(start+3));
952 idMap.erase(asId(start+3));
// Loads: the pointer being loaded from is at start+3.
956 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
957 const spv::Id varId = asId(start+3);
959 // Avoid loads before stores
960 if (idMap.find(varId) == idMap.end()) {
961 fnLocalVars.erase(varId);
965 // don't do for volatile references
// For OpLoad an optional memory-access mask is the fifth word (start+4).
966 if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {
967 fnLocalVars.erase(varId);
971 // Handle flow control
972 if (blockMap.find(varId) == blockMap.end()) {
973 blockMap[varId] = blockNum; // track block we found it in.
974 } else if (blockMap[varId] != blockNum) {
975 fnLocalVars.erase(varId); // Ignore if crosses flow control
// Stores: the pointer being stored to is at start+1, the stored object at start+2.
982 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
983 const spv::Id varId = asId(start+1);
985 if (idMap.find(varId) == idMap.end()) {
986 idMap[varId] = asId(start+2);
988 // Remove if it has more than one store to the same pointer
989 fnLocalVars.erase(varId);
993 // don't do for volatile references
// NOTE(review): for OpStore, start+3 is the word just tested as a memory-access
// mask, and the variable is varId (start+1). The two erases below use
// asId(start+3) as the key, which looks like the wrong word - confirm whether
// varId was intended.
994 if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {
995 fnLocalVars.erase(asId(start+3));
996 idMap.erase(asId(start+3));
999 // Handle flow control
1000 if (blockMap.find(varId) == blockMap.end()) {
1001 blockMap[varId] = blockNum; // track block we found it in.
1002 } else if (blockMap[varId] != blockNum) {
1003 fnLocalVars.erase(varId); // Ignore if crosses flow control
1013 // If local var id used anywhere else, don't eliminate
1015 if (fnLocalVars.count(id) > 0) {
1016 fnLocalVars.erase(id);
1026 [&](spv::Op opCode, unsigned start) {
// For each load of a surviving variable, map the load result to the value
// recorded for that variable.
1027 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)
1028 idMap[asId(start+2)] = idMap[asId(start+3)];
1036 // Chase replacements to their origins, in case there is a chain such as:
1041 // We want to replace uses of 5 with 1.
1042 for (const auto& idPair : idMap) {
1043 spv::Id id = idPair.first;
1044 while (idMap.find(id) != idMap.end()) // Chase to end of chain
1047 idMap[idPair.first] = id; // replace with final result
1050 // Remove the load/store/variables for the ones we've discovered
1052 [&](spv::Op opCode, unsigned start) {
1053 if ((opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) ||
1054 (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||
1055 (opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {
// Rewrite every remaining ID that has a replacement recorded in idMap.
1065 if (idMap.find(id) != idMap.end()) id = idMap[id];
1072 strip(); // strip out data we decided to eliminate
1075 // remove bodies of uncalled functions
// Walks fnPos and, for every function other than the entry point that has no
// recorded calls, queues its word range for stripping and lowers the call
// counts of the functions it calls.
1076 void spirvbin_t::dceFuncs()
1078 msg(3, 2, std::string("Removing Dead Functions: "));
1080 // TODO: There are more efficient ways to do this.
1081 bool changed = true;
// The iterator is advanced inside the loop body because erase() returns the
// next position.
1086 for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {
1087 if (fn->first == entryPoint) { // don't DCE away the entry point!
1092 const auto call_it = fnCalls.find(fn->first);
// No entry in fnCalls, or a count of zero, both mean the function is uncalled.
1094 if (call_it == fnCalls.end() || call_it->second == 0) {
1096 stripRange.push_back(fn->second);
1098 // decrease counts of called functions
1100 [&](spv::Op opCode, unsigned start) {
1101 if (opCode == spv::Op::OpFunctionCall) {
1102 const auto call_it = fnCalls.find(asId(start + 3));
1103 if (call_it != fnCalls.end()) {
// Drop the entry entirely once its count reaches zero.
1104 if (--call_it->second <= 0)
1105 fnCalls.erase(call_it);
1118 fn = fnPos.erase(fn);
1124 // remove unused function variables + decorations
// First counts how often each variable ID appears, then selects variables,
// decorations and names whose ID has a count of exactly 1.
1125 void spirvbin_t::dceVars()
1127 msg(3, 2, std::string("DCE Vars: "));
1129 std::unordered_map<spv::Id, int> varUseCount;
1131 // Count function variable use
1133 [&](spv::Op opCode, unsigned start) {
// The declaration itself counts as one use.
1134 if (opCode == spv::OpVariable) {
1135 ++varUseCount[asId(start+2)];
// Words from start+4 up to the end of an OpEntryPoint are each counted as an ID use.
1137 } else if (opCode == spv::OpEntryPoint) {
1138 const int wordCount = asWordCount(start);
1139 for (int i = 4; i < wordCount; i++) {
1140 ++varUseCount[asId(start+i)];
// Only IDs already counted above are incremented again. operator[] also inserts
// a zero entry for any ID not yet present in the map.
1147 [&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }
1153 // Remove single-use function variables + associated decorations and names
1155 [&](spv::Op opCode, unsigned start) {
1156 spv::Id id = spv::NoResult;
1157 if (opCode == spv::OpVariable)
1159 if (opCode == spv::OpDecorate || opCode == spv::OpName)
// A count of exactly 1 selects the ID.
1162 if (id != spv::NoResult && varUseCount[id] == 1)
1170 // remove unused types
// Counts how often each type/constant ID occurs in the module and strips the
// instructions whose ID occurs exactly once.
1171 void spirvbin_t::dceTypes()
1173 std::vector<bool> isType(bound(), false);
1175 // for speed, make O(1) way to get to type query (map is log(n))
1176 for (const auto typeStart : typeConstPos)
1177 isType[asTypeConstId(typeStart)] = true;
1179 std::unordered_map<spv::Id, int> typeUseCount;
1181 // This is not the most efficient algorithm, but this is an offline tool, and
1182 // it's easy to write this way. Can be improved opportunistically if needed.
1183 bool changed = true;
// Counts are rebuilt from scratch before each counting pass.
1187 typeUseCount.clear();
1189 // Count total type usage
1190 process(inst_fn_nop,
1191 [&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }
1197 // Remove single reference types
1198 for (const auto typeStart : typeConstPos) {
1199 const spv::Id typeId = asTypeConstId(typeStart);
1200 if (typeUseCount[typeId] == 1) {
1202 --typeUseCount[typeId];
1203 stripInst(typeStart);
// Compare local type `lt` against entry `gt` of globalTypes.
// The opcode and word count must match first. After that the literal,
// constant-ID and sub-type-ID ranges of the two instructions are compared,
// recursing through matchType() for IDs.
1213 bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const
1215 // Find the local type id "lt" and global type id "gt"
1216 const auto lt_it = typeConstPosR.find(lt);
1217 if (lt_it == typeConstPosR.end())
1220 const auto typeStart = lt_it->second;
1222 // Search for entry in global table
1223 const auto gtype = globalTypes.find(gt);
1224 if (gtype == globalTypes.end())
1227 const auto& gdata = gtype->second;
1229 // local wordcount and opcode
1230 const int wordCount = asWordCount(typeStart);
1231 const spv::Op opCode = asOpCode(typeStart);
1233 // no type match if opcodes don't match, or operand count doesn't match
1234 if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))
1237 const unsigned numOperands = wordCount - 2; // all types have a result
// Compare the IDs found at word offsets [range.first, range.second), clamped
// to the instruction's length.
1239 const auto cmpIdRange = [&](range_t range) {
1240 for (int x=range.first; x<std::min(range.second, wordCount); ++x)
1241 if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))
1246 const auto cmpConst = [&]() { return cmpIdRange(constRange(opCode)); };
1247 const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode)); };
1249 // Compare literals in range [start,end)
1250 const auto cmpLiteral = [&]() {
1251 const auto range = literalRange(opCode);
// NOTE(review): this reads from `spir`, while the rest of the code shown
// indexes `spv` - confirm that `spir` is declared where this is compiled.
1252 return std::equal(spir.begin() + typeStart + range.first,
1253 spir.begin() + typeStart + std::min(range.second, wordCount),
1254 gdata.begin() + range.first);
1257 assert(isTypeOp(opCode) || isConstOp(opCode));
1260 case spv::OpTypeOpaque: // TODO: disable until we compare the literal strings.
1261 case spv::OpTypeQueue: return false;
1262 case spv::OpTypeEvent: // fall through...
1263 case spv::OpTypeDeviceEvent: // ...
1264 case spv::OpTypeReserveId: return false;
1265 // for samplers, we don't handle the optional parameters yet
1266 case spv::OpTypeSampler: return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;
1267 default: return cmpLiteral() && cmpConst() && cmpSubType();
1271 // Look for an equivalent type in the globalTypes map
// lt is the local type ID. Each key of globalTypes is tried in turn with
// matchType().
1272 spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const
1274 // Try a recursive type match on each in turn, and return a match if we find one
1275 for (const auto& gt : globalTypes)
1276 if (matchType(globalTypes, lt, gt.first))
1283 // Return start position in SPV of given Id. error if not found.
// Positions come from idPosR.
1284 unsigned spirvbin_t::idPos(spv::Id id) const
1286 const auto tid_it = idPosR.find(id);
1287 if (tid_it == idPosR.end()) {
1288 error("ID not found");
1292 return tid_it->second;
1295 // Hash types to canonical values. This can return ID collisions (it's a bit
1296 // inevitable): it's up to the caller to handle that gracefully.
// typeStart: index in spv of the type or constant instruction to hash.
// Each opcode starts from its own base value. Instructions that refer to
// other types fold in the hash of the referenced instruction, which is found
// through idPos().
1297 std::uint32_t spirvbin_t::hashType(unsigned typeStart) const
1299 const unsigned wordCount = asWordCount(typeStart);
1300 const spv::Op opCode = asOpCode(typeStart);
1303 case spv::OpTypeVoid: return 0;
1304 case spv::OpTypeBool: return 1;
// The int hash uses only the word at typeStart+3, and every float type hashes to 5.
1305 case spv::OpTypeInt: return 3 + (spv[typeStart+3]);
1306 case spv::OpTypeFloat: return 5;
// Vectors and matrices: hash of the type named in word 2, scaled by (word 3 - 1).
1307 case spv::OpTypeVector:
1308 return 6 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1309 case spv::OpTypeMatrix:
1310 return 30 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1311 case spv::OpTypeImage:
1312 return 120 + hashType(idPos(spv[typeStart+2])) +
1313 spv[typeStart+3] + // dimensionality
1314 spv[typeStart+4] * 8 * 16 + // depth
1315 spv[typeStart+5] * 4 * 16 + // arrayed
1316 spv[typeStart+6] * 2 * 16 + // multisampled
1317 spv[typeStart+7] * 1 * 16; // format
1318 case spv::OpTypeSampler:
1320 case spv::OpTypeSampledImage:
1322 case spv::OpTypeArray:
1323 return 501 + hashType(idPos(spv[typeStart+2])) * spv[typeStart+3];
1324 case spv::OpTypeRuntimeArray:
1325 return 5000 + hashType(idPos(spv[typeStart+2]));
1326 case spv::OpTypeStruct:
// Each ID from word 2 onward is weighted by its word position.
1328 std::uint32_t hash = 10000;
1329 for (unsigned w=2; w < wordCount; ++w)
1330 hash += w * hashType(idPos(spv[typeStart+w]));
1334 case spv::OpTypeOpaque: return 6000 + spv[typeStart+2];
1335 case spv::OpTypePointer: return 100000 + hashType(idPos(spv[typeStart+3]));
1336 case spv::OpTypeFunction:
// Same position-weighted scheme as for structs, with a different base.
1338 std::uint32_t hash = 200000;
1339 for (unsigned w=2; w < wordCount; ++w)
1340 hash += w * hashType(idPos(spv[typeStart+w]));
1344 case spv::OpTypeEvent: return 300000;
1345 case spv::OpTypeDeviceEvent: return 300001;
1346 case spv::OpTypeReserveId: return 300002;
1347 case spv::OpTypeQueue: return 300003;
1348 case spv::OpTypePipe: return 300004;
1349 case spv::OpConstantTrue: return 300007;
1350 case spv::OpConstantFalse: return 300008;
1351 case spv::OpConstantComposite:
// Constants start from the hash of their type (word 1). Composite constants
// then add the hashes of the IDs from word 3 onward.
1353 std::uint32_t hash = 300011 + hashType(idPos(spv[typeStart+1]));
1354 for (unsigned w=3; w < wordCount; ++w)
1355 hash += w * hashType(idPos(spv[typeStart+w]));
1358 case spv::OpConstant:
// Scalar constants add the raw words from word 3 onward, weighted by position.
1360 std::uint32_t hash = 400011 + hashType(idPos(spv[typeStart+1]));
1361 for (unsigned w=3; w < wordCount; ++w)
1362 hash += w * spv[typeStart+w];
1365 case spv::OpConstantNull:
1367 std::uint32_t hash = 500009 + hashType(idPos(spv[typeStart+1]));
1370 case spv::OpConstantSampler:
1372 std::uint32_t hash = 600011 + hashType(idPos(spv[typeStart+1]));
1373 for (unsigned w=3; w < wordCount; ++w)
1374 hash += w * spv[typeStart+w];
1379 error("unknown type opcode");
// Assign new Ids to type and constant instructions based on their hash.
// For every position in typeConstPos: read the result Id (asTypeConstId),
// compute hashType() of the instruction, and, if that Id is still unmapped,
// map it with localId() to nextUnusedId(hash % softTypeIdLimit + firstMappedID).
1384 void spirvbin_t::mapTypeConst()
// NOTE(review): globalTypeMap is not referenced anywhere in the visible body -- confirm whether it is still needed.
1386 globaltypes_t globalTypeMap;
1388 msg(3, 2, std::string("Remapping Consts & Types: "));
1390 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
1391 static const std::uint32_t firstMappedID = 8; // offset into ID space
1393 for (auto& typeStart : typeConstPos) {
1394 const spv::Id resId = asTypeConstId(typeStart);
1395 const std::uint32_t hashval = hashType(typeStart);
// Only Ids that have not been mapped yet get a new value; the hash is reduced
// modulo softTypeIdLimit and shifted by firstMappedID to form the starting
// candidate handed to nextUnusedId().
1400 if (isOldIdUnmapped(resId)) {
1401 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
1408 // Strip a single binary by removing ranges given in stripRange
// Each stripRange entry is used as a half-open word range [first, second):
// a word is kept when it lies before .first or at/after .second of the range
// currently under consideration. Kept words are copied toward the front of
// spv, and spv is then resized to the number of words kept.
1409 void spirvbin_t::strip()
1411 if (stripRange.empty()) // nothing to do
1414 // Sort strip ranges in order of traversal
1415 std::sort(stripRange.begin(), stripRange.end());
1417 // Compaction is done in place inside spv; the visible code creates no second buffer.
1418 // We'll step this iterator through the strip ranges as we go through the binary
1419 auto strip_it = stripRange.begin();
// Write cursor: index of the next slot in spv that receives a kept word.
1421 int strippedPos = 0;
1422 for (unsigned word = 0; word < unsigned(spv.size()); ++word) {
// Skip every range that ends at or before the current word.
// NOTE(review): the loop body (presumably advancing strip_it) is not visible here -- confirm.
1423 while (strip_it != stripRange.end() && word >= strip_it->second)
// Keep the word if no range remains or it falls outside the current range.
1426 if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)
1427 spv[strippedPos++] = spv[word];
// Drop the tail left over after compaction.
1430 spv.resize(strippedPos);
1436 // Remap a single binary: run the passes selected by the option bits, in the fixed order below
//   opts: option bit mask for this run.
// NOTE(review): the visible body tests the member `options`, not `opts`; the
// assignment that connects them is not visible here -- confirm.
// From the first stripDebug() stage onward, each stage is followed by a check
// of errorLatch and the function returns early as soon as it is set.
1437 void spirvbin_t::remap(std::uint32_t opts)
1441 // Set up opcode tables from SpvDoc
1442 spv::Parameterize();
1444 validate(); // validate header
1445 buildLocalMaps(); // build ID maps
1447 msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));
// Stage 1: optional debug stripping, then apply the recorded strip ranges.
1449 if (options & STRIP) stripDebug();
1450 if (errorLatch) return;
1452 strip(); // strip out data we decided to eliminate
1453 if (errorLatch) return;
// Stage 2: optional load/store optimizations and dead-code elimination passes.
1455 if (options & OPT_LOADSTORE) optLoadStore();
1456 if (errorLatch) return;
1458 if (options & OPT_FWD_LS) forwardLoadStores();
1459 if (errorLatch) return;
1461 if (options & DCE_FUNCS) dceFuncs();
1462 if (errorLatch) return;
1464 if (options & DCE_VARS) dceVars();
1465 if (errorLatch) return;
1467 if (options & DCE_TYPES) dceTypes();
1468 if (errorLatch) return;
// Stage 3: second strip pass, run unconditionally after the passes above.
1470 strip(); // strip out data we decided to eliminate
1471 if (errorLatch) return;
1473 stripDeadRefs(); // remove references to things we DCEed
1474 if (errorLatch) return;
1476 // after the last strip, we must clean any debug info referring to now-deleted data
// Stage 4: optional Id mapping passes (types/constants, names, function bodies).
1478 if (options & MAP_TYPES) mapTypeConst();
1479 if (errorLatch) return;
1481 if (options & MAP_NAMES) mapNames();
1482 if (errorLatch) return;
1484 if (options & MAP_FUNCS) mapFnBodies();
1485 if (errorLatch) return;
// Stage 5: entered when any bit of MAP_ALL is set in options.
1487 if (options & MAP_ALL) {
1488 mapRemainder(); // map any unmapped IDs
1489 if (errorLatch) return;
1491 applyMap(); // Now remap each shader to the new IDs we've come up with
1492 if (errorLatch) return;
1496 // remap from a memory image
//   in_spv: the binary as a vector of 32-bit words, taken by non-const reference.
//   opts:   option bit mask.
// NOTE(review): the body is not visible here; presumably it operates on in_spv
// and delegates to remap(opts) -- confirm.
1497 void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)
1506 #endif // defined (use_cpp11)