2 // Copyright (C) 2015 LunarG, Inc.
4 // All rights reserved.
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
10 // Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
13 // Redistributions in binary form must reproduce the above
14 // copyright notice, this list of conditions and the following
15 // disclaimer in the documentation and/or other materials provided
16 // with the distribution.
18 // Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19 // contributors may be used to endorse or promote products derived
20 // from this software without specific prior written permission.
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
36 #include "SPVRemapper.h"
39 #if !defined (use_cpp11)
40 // ... not supported before C++11
41 #else // defined (use_cpp11)
45 #include "../glslang/Include/Common.h"
49 // By default, just abort on error. Can be overridden via RegisterErrorHandler
// Default error handler: ignores the message it is given and terminates the
// process via exit(5).
50 spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };
51 // By default, eat log messages. Can be overridden via RegisterLogHandler
// Default log handler: an empty lambda, so every message is silently discarded.
52 spirvbin_t::logfn_t spirvbin_t::logHandler = [](const std::string&) { };
54 // This can be overridden to provide other message behavior if needed
// Passes `txt`, prefixed with `indent` space characters, to logHandler, but only
// when `verbose` is at least `minVerbosity`. Otherwise nothing is logged.
55 void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const
57 if (verbose >= minVerbosity)
58 logHandler(std::string(indent, ' ') + txt);
61 // hash opcode, with special handling for OpExtInst
// Returns a hash for the instruction that starts at word index `word`:
// the decoded opcode multiplied by 19, plus `offset`.
62 std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)
64 const spv::Op opCode = asOpCode(word);
// Extra hash input; it stays 0 unless the case below adds to it.
66 std::uint32_t offset = 0;
// NOTE(review): the case label guarding this line is not visible here; judging by
// the comment above the function it is presumably OpExtInst -- confirm.
70 offset += asId(word + 4); break;
75 return opCode * 19 + offset; // 19 = small prime
// Returns a range_t(first, second) for `opCode`. matchType() treats the pair as a
// half-open span of word offsets from the start of the instruction and compares
// those words as literals. Opcodes without a case yield range_t(0, 0).
78 spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const
// Open-ended upper bound; matchType() clamps it to the instruction's word count.
80 static const int maxCount = 1<<30;
83 case spv::OpTypeFloat: // fall through...
84 case spv::OpTypePointer: return range_t(2, 3);
85 case spv::OpTypeInt: return range_t(2, 4);
86 // TODO: case spv::OpTypeImage:
87 // TODO: case spv::OpTypeSampledImage:
88 case spv::OpTypeSampler: return range_t(3, 8);
89 case spv::OpTypeVector: // fall through
90 case spv::OpTypeMatrix: // ...
91 case spv::OpTypePipe: return range_t(3, 4);
92 case spv::OpConstant: return range_t(3, maxCount);
93 default: return range_t(0, 0);
// Returns a range_t(first, second) for `opCode`. matchType() walks the word
// offsets in [first, second) and recursively matches each word as a type ID.
// Opcodes without a case yield range_t(0, 0).
97 spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const
// Open-ended upper bound; matchType() clamps it to the instruction's word count.
99 static const int maxCount = 1<<30;
// Every constant opcode (per isConstOp) uses word offset 1 only.
101 if (isConstOp(opCode))
102 return range_t(1, 2);
105 case spv::OpTypeVector: // fall through
106 case spv::OpTypeMatrix: // ...
107 case spv::OpTypeSampler: // ...
108 case spv::OpTypeArray: // ...
109 case spv::OpTypeRuntimeArray: // ...
110 case spv::OpTypePipe: return range_t(2, 3);
111 case spv::OpTypeStruct: // fall through
112 case spv::OpTypeFunction: return range_t(2, maxCount);
113 case spv::OpTypePointer: return range_t(3, 4);
114 default: return range_t(0, 0);
// Returns a range_t(first, second) for `opCode`. matchType() runs the same
// recursive ID comparison over these word offsets as it does for typeRange().
// Opcodes without a case yield range_t(0, 0).
118 spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const
// Open-ended upper bound; matchType() clamps it to the instruction's word count.
120 static const int maxCount = 1<<30;
123 case spv::OpTypeArray: // fall through...
124 case spv::OpTypeRuntimeArray: return range_t(3, 4);
125 case spv::OpConstantComposite: return range_t(3, maxCount);
126 default: return range_t(0, 0);
130 // Return the size of a type in 32-bit words. This currently only
131 // handles ints and floats, and is only invoked by queries which must be
132 // integer types. If ever needed, it can be generalized.
// Locates the instruction that defines `id` via idPos() and, for OpTypeInt and
// OpTypeFloat, returns the value of its word at offset 2 divided by 32, rounded up.
133 unsigned spirvbin_t::typeSizeInWords(spv::Id id) const
135 const unsigned typeStart = idPos(id);
136 const spv::Op opCode = asOpCode(typeStart);
142 case spv::OpTypeInt: // fall through...
// Presumably the word at offset 2 is the width in bits; +31 makes the division round up.
143 case spv::OpTypeFloat: return (spv[typeStart+2]+31)/32;
149 // Looks up the type of a given const or variable ID, and
150 // returns its size in 32-bit words.
// Returns the size recorded for `id` in idTypeSizeMap (filled in by
// buildLocalMaps()). Calls error() when `id` has no entry.
151 unsigned spirvbin_t::idTypeSizeInWords(spv::Id id) const
153 const auto tid_it = idTypeSizeMap.find(id);
154 if (tid_it == idTypeSizeMap.end()) {
155 error("type size for ID not found");
159 return tid_it->second;
162 // Is this an opcode we should remove when using --strip?
// Returns true for the opcodes in the case labels below and false for every
// other opcode. stripDebug() uses this to pick the instructions to strip.
163 bool spirvbin_t::isStripOp(spv::Op opCode) const
167 case spv::OpSourceExtension:
169 case spv::OpMemberName:
170 case spv::OpLine: return true;
171 default: return false;
175 // Return true if this opcode is flow control
// Returns true for the opcodes in the case labels below and false for every
// other opcode. Note that OpFunction and OpFunctionEnd are treated as flow
// control here. optLoadStore() calls this to decide where a new block starts.
176 bool spirvbin_t::isFlowCtrl(spv::Op opCode) const
179 case spv::OpBranchConditional:
182 case spv::OpLoopMerge:
183 case spv::OpSelectionMerge:
185 case spv::OpFunction:
186 case spv::OpFunctionEnd: return true;
187 default: return false;
191 // Return true if this opcode defines a type
// Returns true for the OpType* opcodes in the case labels below and false for
// every other opcode. buildLocalMaps() records the position of each such
// instruction in typeConstPos.
192 bool spirvbin_t::isTypeOp(spv::Op opCode) const
195 case spv::OpTypeVoid:
196 case spv::OpTypeBool:
198 case spv::OpTypeFloat:
199 case spv::OpTypeVector:
200 case spv::OpTypeMatrix:
201 case spv::OpTypeImage:
202 case spv::OpTypeSampler:
203 case spv::OpTypeArray:
204 case spv::OpTypeRuntimeArray:
205 case spv::OpTypeStruct:
206 case spv::OpTypeOpaque:
207 case spv::OpTypePointer:
208 case spv::OpTypeFunction:
209 case spv::OpTypeEvent:
210 case spv::OpTypeDeviceEvent:
211 case spv::OpTypeReserveId:
212 case spv::OpTypeQueue:
213 case spv::OpTypeSampledImage:
214 case spv::OpTypePipe: return true;
215 default: return false;
219 // Return true if this opcode defines a constant
// Classifies `opCode` as a constant-defining opcode.
// NOTE(review): the return statements are not visible in this view; presumably the
// OpConstant* labels below return true and everything else false -- confirm.
220 bool spirvbin_t::isConstOp(spv::Op opCode) const
// Sampler constants are not supported: meeting one reports an error.
223 case spv::OpConstantSampler:
224 error("unimplemented constant type");
227 case spv::OpConstantNull:
228 case spv::OpConstantTrue:
229 case spv::OpConstantFalse:
230 case spv::OpConstantComposite:
231 case spv::OpConstant:
// Do-nothing callbacks for process(), for passes that need only one of the two hooks.
// Instruction callback that ignores its arguments and always returns false.
239 const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };
// ID callback that leaves the ID it is given untouched.
240 const auto op_fn_nop = [](spv::Id&) { };
242 // g++ doesn't like these defined in the class proper in an anonymous namespace.
243 // Dunno why. Also MSVC doesn't like the constexpr keyword. Also dunno why.
244 // Defining them externally seems to please both compilers, so, here they are.
// Sentinel stored in the ID map for an ID that is referenced but has no new
// value yet (buildLocalMaps() passes it to localId() for every ID operand).
245 const spv::Id spirvbin_t::unmapped = spv::Id(-10000);
// Sentinel used to fill ID-map slots when idMapL is resized, and returned by
// localId() on its error paths.
246 const spv::Id spirvbin_t::unused = spv::Id(-10001);
// Minimum module length accepted by validate(), and the default start index
// used by process() when `begin` is 0.
247 const int spirvbin_t::header_size = 5;
// Searches, starting from `id`, for a new ID for which isNewIdMapped() is false.
// NOTE(review): the loop body and the return statement are not visible in this
// view; presumably `id` is incremented and then returned -- confirm.
249 spv::Id spirvbin_t::nextUnusedId(spv::Id id)
251 while (isNewIdMapped(id)) // search for an unused ID
// Records `newId` as the mapping for old ID `id` in idMapL and returns it.
// Every error path reports through error() and returns `unused`.
// When `newId` is a real ID (neither `unmapped` nor `unused`) the mapping is
// rejected if `id` is unused in the module, if `id` already has a mapping, or if
// `newId` is already taken.
257 spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)
259 //assert(id != spv::NoResult && newId != spv::NoResult);
// NOTE(review): the condition guarding this range error is not visible in this view.
262 error(std::string("ID out of range: ") + std::to_string(id));
263 return spirvbin_t::unused;
// Grow the map on demand; new slots start out as `unused`.
266 if (id >= idMapL.size())
267 idMapL.resize(id+1, unused);
269 if (newId != unmapped && newId != unused) {
270 if (isOldIdUnused(id)) {
271 error(std::string("ID unused in module: ") + std::to_string(id));
272 return spirvbin_t::unused;
275 if (!isOldIdUnmapped(id)) {
276 error(std::string("ID already mapped: ") + std::to_string(id) + " -> "
277 + std::to_string(localId(id)));
279 return spirvbin_t::unused;
282 if (isNewIdMapped(newId)) {
283 error(std::string("ID already used in module: ") + std::to_string(newId));
284 return spirvbin_t::unused;
287 msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));
// Track the largest new ID handed out so far.
289 largestNewId = std::max(largestNewId, newId);
292 return idMapL[id] = newId;
295 // Parse a literal string from the SPIR binary and return it as an std::string
296 // Due to C++11 RValue references, this doesn't copy the result string.
// Reads a string literal stored in the word stream starting at word index `word`.
297 std::string spirvbin_t::literalString(unsigned word) const
// View the words from `word` onward as raw bytes.
303 const char* bytes = reinterpret_cast<const char*>(spv.data() + word);
// Continue until a zero byte is reached (the loop body is not visible in this view).
305 while (bytes && *bytes)
// Runs process() over the module with a no-op instruction callback and an ID
// callback, so that only ID operands are visited.
311 void spirvbin_t::applyMap()
313 msg(3, 2, std::string("Applying map: "));
315 // Map local IDs through the ID map
316 process(inst_fn_nop, // ignore instructions
317 [this](spv::Id& id) {
// After the callback has done its work, no ID may still be a sentinel value.
323 assert(id != unused && id != unmapped);
328 // Find free IDs for anything we haven't mapped
// Gives every old ID that is used but still unmapped the next free new ID, then
// writes the resulting upper bound (largest new ID + 1) back through bound().
329 void spirvbin_t::mapRemainder()
331 msg(3, 2, std::string("Remapping remainder: "));
333 spv::Id unusedId = 1; // can't use 0: that's NoResult
334 spirword_t maxBound = 0;
336 for (spv::Id id = 0; id < idMapL.size(); ++id) {
// Old IDs that never appear in the module get special handling (statement not visible here).
337 if (isOldIdUnused(id))
340 // Find a new mapping for any used but unmapped IDs
341 if (isOldIdUnmapped(id)) {
// The search resumes from the last ID handed out, so earlier IDs are not rescanned.
342 localId(id, unusedId = nextUnusedId(unusedId));
// Sanity check: the mapping attempt above must have succeeded.
347 if (isOldIdUnmapped(id)) {
348 error(std::string("old ID not mapped: ") + std::to_string(id));
353 maxBound = std::max(maxBound, localId(id) + 1);
359 bound(maxBound); // reset header ID bound to as big as it now needs to be
362 // Mark debug instructions for stripping
// Visits every instruction and singles out those whose opcode satisfies
// isStripOp(). The action taken for a match is not visible in this view.
363 void spirvbin_t::stripDebug()
365 // Strip instructions in the stripOp set: debug info.
367 [&](spv::Op opCode, unsigned start) {
368 // remember opcodes we want to strip later
369 if (isStripOp(opCode))
376 // Mark instructions that refer to now-removed IDs for stripping
// Visits every instruction and, for the opcodes listed below, checks whether
// the ID in word 1 still has an entry in idPosR. The action taken when it does
// not is not visible in this view.
377 void spirvbin_t::stripDeadRefs()
380 [&](spv::Op opCode, unsigned start) {
381 // strip opcodes pointing to removed data
384 case spv::OpMemberName:
385 case spv::OpDecorate:
386 case spv::OpMemberDecorate:
// Word 1 is the ID being named or decorated; a missing idPosR entry means it has no known position.
387 if (idPosR.find(asId(start+1)) == idPosR.end())
391 break; // leave it alone
401 // Update local maps of ID, type, etc positions
// Rebuilds the per-module lookup tables in a single pass over the instructions:
// idPosR (result ID -> instruction start), idTypeSizeMap (result ID -> size of
// its type), nameMap, fnCalls, fnPos, entryPoint and typeConstPos. Every ID
// operand is also registered in the ID map as `unmapped`.
402 void spirvbin_t::buildLocalMaps()
404 msg(2, 2, std::string("build local maps: "));
408 // preserve nameMap, so we don't clear that.
411 typeConstPos.clear();
413 entryPoint = spv::NoResult;
// One slot per ID below the header bound, initially marked `unused`.
416 idMapL.resize(bound(), unused);
// Result ID of the function currently being scanned; NoResult outside a function.
419 spv::Id fnRes = spv::NoResult;
421 // build local Id and name maps
423 [&](spv::Op opCode, unsigned start) {
// `word` walks over the optional type ID and result ID that follow the opcode word.
424 unsigned word = start+1;
425 spv::Id typeId = spv::NoResult;
427 if (spv::InstructionDesc[opCode].hasType())
428 typeId = asId(word++);
430 // If there's a result ID, remember the size of its type
431 if (spv::InstructionDesc[opCode].hasResult()) {
432 const spv::Id resultId = asId(word++);
433 idPosR[resultId] = start;
435 if (typeId != spv::NoResult) {
436 const unsigned idTypeSize = typeSizeInWords(typeId);
442 idTypeSizeMap[resultId] = idTypeSize;
446 if (opCode == spv::Op::OpName) {
447 const spv::Id target = asId(start+1);
448 const std::string name = literalString(start+2);
// A later OpName with the same string overwrites the earlier entry.
449 nameMap[name] = target;
451 } else if (opCode == spv::Op::OpFunctionCall) {
// Count calls per callee; dceFuncs() treats functions with no count as dead.
452 ++fnCalls[asId(start + 3)];
453 } else if (opCode == spv::Op::OpEntryPoint) {
454 entryPoint = asId(start + 2);
455 } else if (opCode == spv::Op::OpFunction) {
457 error("nested function found");
462 fnRes = asId(start + 2);
463 } else if (opCode == spv::Op::OpFunctionEnd) {
464 assert(fnRes != spv::NoResult);
466 error("function end without function start");
// The recorded range ends just past the OpFunctionEnd instruction.
470 fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));
472 } else if (isConstOp(opCode)) {
476 assert(asId(start + 2) != spv::NoResult);
477 typeConstPos.insert(start);
478 } else if (isTypeOp(opCode)) {
479 assert(asId(start + 1) != spv::NoResult);
480 typeConstPos.insert(start);
// ID callback: register every referenced ID as used but not yet mapped.
486 [this](spv::Id& id) { localId(id, unmapped); }
490 // Validate the SPIR header
// Checks the module header and reports problems through error(): the module
// must hold at least header_size words, start with spv::MagicNumber, and have
// a schema number of 0.
491 void spirvbin_t::validate() const
493 msg(2, 2, std::string("validating: "));
495 if (spv.size() < header_size) {
// NOTE(review): the message ends in ": " but nothing is appended after it.
496 error("file too short: ");
500 if (magic() != spv::MagicNumber) {
501 error("bad magic number");
506 // field 2 = generator magic
507 // field 3 = result <id> bound
509 if (schemaNum() != 0) {
510 error("bad schema, must be 0");
// Decodes the single instruction that starts at word index `word`. instFn is
// called with the opcode and start index, and idFn is called on each operand
// word that holds an ID. The operand layout comes from spv::InstructionDesc,
// with special cases for OpExtInst, OpSpecConstantOp and OpSwitch.
// NOTE(review): the return statements are not visible in this view; process()
// uses the result as the index of the next instruction, presumably `nextInst`.
515 int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)
517 const auto instructionStart = word;
518 const unsigned wordCount = asWordCount(instructionStart);
// Index just past this instruction; `word` then moves on to the first word after the opcode word.
519 const int nextInst = word++ + wordCount;
520 spv::Op opCode = asOpCode(instructionStart);
522 if (nextInst > int(spv.size())) {
523 error("spir instruction terminated too early");
527 // Base for computing number of operands; will be updated as more is learned
528 unsigned numOperands = wordCount - 1;
// The instruction callback is consulted before any operand is decoded.
530 if (instFn(opCode, instructionStart))
533 // Read type and result ID from instruction desc table
534 if (spv::InstructionDesc[opCode].hasType()) {
539 if (spv::InstructionDesc[opCode].hasResult()) {
544 // Extended instructions: currently, assume everything is an ID.
545 // TODO: add whatever data we need for exceptions to that
546 if (opCode == spv::OpExtInst) {
548 idFn(asId(word)); // Instruction set is an ID that also needs to be mapped
550 word += 2; // instruction set, and instruction from set
553 for (unsigned op=0; op < numOperands; ++op)
554 idFn(asId(word++)); // ID
559 // Circular buffer so we can look back at previous unmapped values during the mapping pass.
560 static const unsigned idBufferSize = 4;
561 spv::Id idBuffer[idBufferSize];
562 unsigned idBufferPos = 0;
564 // Store IDs from instruction in our map
565 for (int op = 0; numOperands > 0; ++op, --numOperands) {
566 // SpecConstantOp is special: it includes the operands of another opcode which is
567 // given as a literal in the 3rd word. We will switch over to pretending that the
568 // opcode being processed is the literal opcode value of the SpecConstantOp. See the
569 // SPIRV spec for details. This way we will handle IDs and literals as appropriate for
571 if (opCode == spv::OpSpecConstantOp) {
573 opCode = asOpCode(word++); // this is the opcode embedded in the SpecConstantOp.
// Dispatch on the class of the op-th operand as described by the instruction table.
578 switch (spv::InstructionDesc[opCode].operands.getClass(op)) {
580 case spv::OperandScope:
581 case spv::OperandMemorySemantics:
// Remember the ID before it is passed on, so OpSwitch handling can look back at it.
582 idBuffer[idBufferPos] = asId(word);
583 idBufferPos = (idBufferPos + 1) % idBufferSize;
587 case spv::OperandVariableIds:
588 for (unsigned i = 0; i < numOperands; ++i)
592 case spv::OperandVariableLiterals:
594 // if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {
598 // word += numOperands;
601 case spv::OperandVariableLiteralId: {
// NOTE(review): `OpSwitch` is unqualified here while other opcodes in this file are
// written spv::Op...; confirm a using-declaration brings it into scope.
602 if (opCode == OpSwitch) {
603 // word-2 is the position of the selector ID. OpSwitch Literals match its type.
604 // In case the IDs are currently being remapped, we get the word[-2] ID from
605 // the circular idBuffer.
606 const unsigned literalSizePos = (idBufferPos+idBufferSize-2) % idBufferSize;
607 const unsigned literalSize = idTypeSizeInWords(idBuffer[literalSizePos]);
// Each (literal, label) pair occupies literalSize words plus one word for the label ID.
608 const unsigned numLiteralIdPairs = (nextInst-word) / (1+literalSize);
613 for (unsigned arg=0; arg<numLiteralIdPairs; ++arg) {
614 word += literalSize; // literal
615 idFn(asId(word++)); // label
618 assert(0); // currently, only OpSwitch uses OperandVariableLiteralId
624 case spv::OperandLiteralString: {
625 const int stringWordCount = literalStringWords(literalString(word));
626 word += stringWordCount;
627 numOperands -= (stringWordCount-1); // -1 because for() header post-decrements
631 case spv::OperandVariableLiteralStrings:
634 // Execution mode might have extra literal operands. Skip them.
635 case spv::OperandExecutionMode:
638 // Single word operands we simply ignore, as they hold no IDs
639 case spv::OperandLiteralNumber:
640 case spv::OperandSource:
641 case spv::OperandExecutionModel:
642 case spv::OperandAddressing:
643 case spv::OperandMemory:
644 case spv::OperandStorage:
645 case spv::OperandDimensionality:
646 case spv::OperandSamplerAddressingMode:
647 case spv::OperandSamplerFilterMode:
648 case spv::OperandSamplerImageFormat:
649 case spv::OperandImageChannelOrder:
650 case spv::OperandImageChannelDataType:
651 case spv::OperandImageOperands:
652 case spv::OperandFPFastMath:
653 case spv::OperandFPRoundingMode:
654 case spv::OperandLinkageType:
655 case spv::OperandAccessQualifier:
656 case spv::OperandFuncParamAttr:
657 case spv::OperandDecoration:
658 case spv::OperandBuiltIn:
659 case spv::OperandSelect:
660 case spv::OperandLoop:
661 case spv::OperandFunction:
662 case spv::OperandMemoryAccess:
663 case spv::OperandGroupOperation:
664 case spv::OperandKernelEnqueueFlags:
665 case spv::OperandKernelProfilingInfo:
666 case spv::OperandCapability:
671 assert(0 && "Unhandled Operand Class");
679 // Make a pass over all the instructions and process them given appropriate functions
// Runs processInstruction() on every instruction in the word range [begin, end),
// passing instFn and idFn through. A value of 0 for `begin` means "just after
// the header" (header_size) and 0 for `end` means "end of the module".
680 spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)
682 // For efficiency, reserve name map space. It can grow if needed.
685 // If begin or end == 0, use defaults
686 begin = (begin == 0 ? header_size : begin);
687 end = (end == 0 ? unsigned(spv.size()) : end);
689 // basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...
690 unsigned nextInst = unsigned(spv.size());
// Each call returns the index used as the start of the next iteration.
692 for (unsigned word = begin; word < end; word = nextInst) {
693 nextInst = processInstruction(word, instFn, idFn);
702 // Apply global name mapping to a single module
// Assigns new IDs to named IDs based on a hash of the name string, so the same
// name hashes to the same starting point for the ID search.
703 void spirvbin_t::mapNames()
705 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
706 static const std::uint32_t firstMappedID = 3019; // offset into ID space
708 for (const auto& name : nameMap) {
// Polynomial hash over the characters of the name (seed 1911, multiplier 1009).
709 std::uint32_t hashval = 1911;
710 for (const char c : name.first)
711 hashval = hashval * 1009 + c;
// Only IDs without a mapping are touched; the hash is folded into
// [firstMappedID, firstMappedID + softTypeIdLimit) and collisions move to the next free ID.
713 if (isOldIdUnmapped(name.second)) {
714 localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
721 // Map fn contents to IDs of similar functions in other modules
// Assigns new IDs to result IDs inside function bodies. The first pass hashes
// the opcodes in a small window around each instruction together with the
// function's ID; the second pass hashes a per-function counter for a chosen set
// of opcodes. In both passes only IDs that are still unmapped are assigned.
722 void spirvbin_t::mapFnBodies()
724 static const std::uint32_t softTypeIdLimit = 19071; // small prime. TODO: get from options
725 static const std::uint32_t firstMappedID = 6203; // offset into ID space
727 // Initial approach: go through some high priority opcodes first and assign them
730 spv::Id fnId = spv::NoResult;
731 std::vector<unsigned> instPos;
732 instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.
734 // Build local table of instruction start positions
// Returning true from the instruction callback; only the start positions are wanted here.
736 [&](spv::Op, unsigned start) { instPos.push_back(start); return true; },
742 // Window size for context-sensitive canonicalization values
743 // Empirical best size from a single data set. TODO: Would be a good tunable.
744 // We essentially perform a little convolution around each instruction,
745 // to capture the flavor of nearby code, to hopefully match to similar
746 // code in other modules.
747 static const unsigned windowSize = 2;
749 for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {
750 const unsigned start = instPos[entry];
751 const spv::Op opCode = asOpCode(start);
// Track which function (if any) the current instruction belongs to.
753 if (opCode == spv::OpFunction)
754 fnId = asId(start + 2);
756 if (opCode == spv::OpFunctionEnd)
757 fnId = spv::NoResult;
759 if (fnId != spv::NoResult) { // if inside a function
760 if (spv::InstructionDesc[opCode].hasResult()) {
// The result ID sits after the opcode word, and after the type ID when there is one.
761 const unsigned word = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);
762 const spv::Id resId = asId(word);
763 std::uint32_t hashval = fnId * 17; // small prime
// Backward half of the window.
// NOTE(review): `entry-1` and `entry-windowSize` are unsigned, so they wrap around when
// entry < windowSize; confirm the OpFunction test below always ends the walk before
// `i` can index outside instPos.
765 for (unsigned i = entry-1; i >= entry-windowSize; --i) {
766 if (asOpCode(instPos[i]) == spv::OpFunction)
768 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
// Forward half of the window, starting at the instruction itself.
// NOTE(review): no comparison against instPos.size() is visible; the loop appears to
// rely on reaching OpFunctionEnd first -- confirm.
771 for (unsigned i = entry; i <= entry + windowSize; ++i) {
772 if (asOpCode(instPos[i]) == spv::OpFunctionEnd)
774 hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
777 if (isOldIdUnmapped(resId)) {
778 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
// Second pass state: the opcode selected for counting and a per-opcode counter.
787 spv::Op thisOpCode(spv::OpNop);
788 std::unordered_map<int, int> opCounter;
790 fnId = spv::NoResult;
793 [&](spv::Op opCode, unsigned start) {
795 case spv::OpFunction:
796 // Reset counters at each function
799 fnId = asId(start + 2);
802 case spv::OpImageSampleImplicitLod:
803 case spv::OpImageSampleExplicitLod:
804 case spv::OpImageSampleDrefImplicitLod:
805 case spv::OpImageSampleDrefExplicitLod:
806 case spv::OpImageSampleProjImplicitLod:
807 case spv::OpImageSampleProjExplicitLod:
808 case spv::OpImageSampleProjDrefImplicitLod:
809 case spv::OpImageSampleProjDrefExplicitLod:
811 case spv::OpCompositeExtract:
812 case spv::OpCompositeInsert:
813 case spv::OpVectorShuffle:
815 case spv::OpVariable:
817 case spv::OpAccessChain:
820 case spv::OpCompositeConstruct:
821 case spv::OpFunctionCall:
// OpNop marks "nothing selected"; the hashing below is skipped in that case.
827 thisOpCode = spv::OpNop;
834 if (thisOpCode != spv::OpNop) {
836 const std::uint32_t hashval =
837 // Explicitly cast operands to unsigned int to avoid integer
838 // promotion to signed int followed by integer overflow,
839 // which would result in undefined behavior.
840 static_cast<unsigned int>(opCounter[thisOpCode])
844 + static_cast<unsigned int>(fnId) * 117;
846 if (isOldIdUnmapped(id))
847 localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
852 // EXPERIMENTAL: forward IO and uniform load/stores into operands
853 // This produces invalid Schema-0 SPIRV
// Experimental pass in two stages. Stage one collects Uniform, UniformConstant
// and Input variables (and access chains based on them) and replaces uses of
// values loaded from them with the pointer that was loaded. Stage two does the
// same for values stored to Output variables. strip() is called at the end.
854 void spirvbin_t::forwardLoadStores()
// Despite the name, this set holds the tracked variables/pointers, not only function locals.
856 idset_t fnLocalVars; // set of function local vars
857 idmap_t idMap; // Map of load result IDs to what they load
859 // EXPERIMENTAL: Forward input and access chain loads into consumptions
861 [&](spv::Op opCode, unsigned start) {
862 // Add inputs and uniforms to the map
// Only 4-word OpVariable instructions are considered; word 3 is compared against the storage class.
863 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
864 (spv[start+3] == spv::StorageClassUniform ||
865 spv[start+3] == spv::StorageClassUniformConstant ||
866 spv[start+3] == spv::StorageClassInput))
867 fnLocalVars.insert(asId(start+2));
// An access chain whose base (word 3) is tracked becomes tracked itself.
869 if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)
870 fnLocalVars.insert(asId(start+2));
// Map the load's result ID (word 2) to the pointer it loads from (word 3).
872 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
873 idMap[asId(start+2)] = asId(start+3);
// ID callback: rewrite any ID that has a replacement.
880 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
886 // EXPERIMENTAL: Implicit output stores
891 [&](spv::Op opCode, unsigned start) {
892 // Add inputs and uniforms to the map
// This stage tracks 4-word OpVariable instructions whose word 3 is StorageClassOutput.
893 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
894 (spv[start+3] == spv::StorageClassOutput))
895 fnLocalVars.insert(asId(start+2));
// Map the stored value (word 2) to the tracked pointer it is stored to (word 1).
897 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
898 idMap[asId(start+2)] = asId(start+1);
911 [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
917 strip(); // strip out data we decided to eliminate
920 // optimize loads and stores
// Removes function-local variables that are stored once and only used within a
// single block. Candidates are gathered in fnLocalVars and then disqualified
// when they are reached through an access chain, loaded before being stored,
// accessed as volatile, stored more than once, or touched in more than one
// block. Loads from the surviving variables are replaced with the stored value,
// and strip() is called at the end.
921 void spirvbin_t::optLoadStore()
923 idset_t fnLocalVars; // candidates for removal (only locals)
924 idmap_t idMap; // Map of load result IDs to what they load
925 blockmap_t blockMap; // Map of IDs to blocks they first appear in
926 int blockNum = 0; // block count, to avoid crossing flow control
928 // Find all the function local pointers stored at most once, and not via access chains
930 [&](spv::Op opCode, unsigned start) {
931 const int wordCount = asWordCount(start);
933 // Count blocks, so we can avoid crossing flow control
934 if (isFlowCtrl(opCode))
937 // Add local variables to the map
938 if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4)) {
939 fnLocalVars.insert(asId(start+2));
943 // Ignore process vars referenced via access chain
944 if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {
945 fnLocalVars.erase(asId(start+3));
946 idMap.erase(asId(start+3));
950 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
951 const spv::Id varId = asId(start+3);
953 // Avoid loads before stores
954 if (idMap.find(varId) == idMap.end()) {
955 fnLocalVars.erase(varId);
959 // don't do for volatile references
// Word 4 of an OpLoad, when present, is tested as a memory-access mask.
960 if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {
961 fnLocalVars.erase(varId);
965 // Handle flow control
966 if (blockMap.find(varId) == blockMap.end()) {
967 blockMap[varId] = blockNum; // track block we found it in.
968 } else if (blockMap[varId] != blockNum) {
969 fnLocalVars.erase(varId); // Ignore if crosses flow control
976 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
977 const spv::Id varId = asId(start+1);
// The first store records the stored value (word 2) for this variable.
979 if (idMap.find(varId) == idMap.end()) {
980 idMap[varId] = asId(start+2);
982 // Remove if it has more than one store to the same pointer
983 fnLocalVars.erase(varId);
987 // don't do for volatile references
// NOTE(review): word 3 is tested as the memory-access mask on the next line, yet the
// two erase calls then treat that same word as an ID. The variable here is `varId`
// (word 1), so this looks like it should erase `varId` instead -- confirm.
988 if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {
989 fnLocalVars.erase(asId(start+3));
990 idMap.erase(asId(start+3));
993 // Handle flow control
994 if (blockMap.find(varId) == blockMap.end()) {
995 blockMap[varId] = blockNum; // track block we found it in.
996 } else if (blockMap[varId] != blockNum) {
997 fnLocalVars.erase(varId); // Ignore if crosses flow control
1007 // If local var id used anywhere else, don't eliminate
1009 if (fnLocalVars.count(id) > 0) {
1010 fnLocalVars.erase(id);
1020 [&](spv::Op opCode, unsigned start) {
// For each load from a surviving variable, map the load's result to the value that was stored.
1021 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)
1022 idMap[asId(start+2)] = idMap[asId(start+3)];
1030 // Chase replacements to their origins, in case there is a chain such as:
1035 // We want to replace uses of 5 with 1.
1036 for (const auto& idPair : idMap) {
1037 spv::Id id = idPair.first;
1038 while (idMap.find(id) != idMap.end()) // Chase to end of chain
// Assigning to an existing key does not add elements, so iterating idMap stays valid.
1041 idMap[idPair.first] = id; // replace with final result
1044 // Remove the load/store/variables for the ones we've discovered
1046 [&](spv::Op opCode, unsigned start) {
1047 if ((opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) ||
1048 (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||
1049 (opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {
1059 if (idMap.find(id) != idMap.end()) id = idMap[id];
1066 strip(); // strip out data we decided to eliminate
1069 // remove bodies of uncalled functions
// Removes functions that are never called. Any function other than the entry
// point with no entry, or a zero count, in fnCalls has its word range queued in
// stripRange. The call counts of its callees are then decremented so that
// functions reachable only from dead code can be removed as well.
1070 void spirvbin_t::dceFuncs()
1072 msg(3, 2, std::string("Removing Dead Functions: "));
1074 // TODO: There are more efficient ways to do this.
// Presumably drives a repeat-until-stable loop; the loop header is not visible in this view.
1075 bool changed = true;
1080 for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {
1081 if (fn->first == entryPoint) { // don't DCE away the entry point!
1086 const auto call_it = fnCalls.find(fn->first);
1088 if (call_it == fnCalls.end() || call_it->second == 0) {
1090 stripRange.push_back(fn->second);
1092 // decrease counts of called functions
1094 [&](spv::Op opCode, unsigned start) {
1095 if (opCode == spv::Op::OpFunctionCall) {
1096 const auto call_it = fnCalls.find(asId(start + 3));
1097 if (call_it != fnCalls.end()) {
// Drop the entry entirely once the last call is gone.
1098 if (--call_it->second <= 0)
1099 fnCalls.erase(call_it);
// erase() returns the iterator to the next element, which keeps the loop valid.
1112 fn = fnPos.erase(fn);
1118 // remove unused function variables + decorations
// Removes variables that are never referenced. The first pass counts the
// defining OpVariable, every ID listed from word 4 onward of an OpEntryPoint,
// and every other use of an ID already being counted. The second pass selects
// OpVariable, OpDecorate and OpName instructions whose ID has a count of
// exactly 1, i.e. only the definition.
1119 void spirvbin_t::dceVars()
1121 msg(3, 2, std::string("DCE Vars: "));
1123 std::unordered_map<spv::Id, int> varUseCount;
1125 // Count function variable use
1127 [&](spv::Op opCode, unsigned start) {
1128 if (opCode == spv::OpVariable) {
1129 ++varUseCount[asId(start+2)];
1131 } else if (opCode == spv::OpEntryPoint) {
1132 const int wordCount = asWordCount(start);
1133 for (int i = 4; i < wordCount; i++) {
1134 ++varUseCount[asId(start+i)];
// NOTE(review): operator[] inserts a zero entry for every ID it is asked about, so the
// map grows to cover every visited ID. find() would avoid that -- confirm whether it matters.
1141 [&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }
1147 // Remove single-use function variables + associated decorations and names
1149 [&](spv::Op opCode, unsigned start) {
1150 spv::Id id = spv::NoResult;
1151 if (opCode == spv::OpVariable)
1153 if (opCode == spv::OpDecorate || opCode == spv::OpName)
1156 if (id != spv::NoResult && varUseCount[id] == 1)
1164 // remove unused types
// Removes types/constants that nothing else references: an ID whose total use
// count is exactly 1 has its defining instruction passed to stripInst().
1165 void spirvbin_t::dceTypes()
// Indexed by ID; true for every ID defined by an instruction recorded in typeConstPos.
1167 std::vector<bool> isType(bound(), false);
1169 // for speed, make O(1) way to get to type query (map is log(n))
1170 for (const auto typeStart : typeConstPos)
1171 isType[asTypeConstId(typeStart)] = true;
1173 std::unordered_map<spv::Id, int> typeUseCount;
1175 // This is not the most efficient algorithm, but this is an offline tool, and
1176 // it's easy to write this way. Can be improved opportunistically if needed.
// Presumably drives a repeat-until-stable loop; the loop header is not visible in this view.
1177 bool changed = true;
1181 typeUseCount.clear();
1183 // Count total type usage
1184 process(inst_fn_nop,
1185 [&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }
1191 // Remove single reference types
1192 for (const auto typeStart : typeConstPos) {
1193 const spv::Id typeId = asTypeConstId(typeStart);
1194 if (typeUseCount[typeId] == 1) {
1196 --typeUseCount[typeId];
1197 stripInst(typeStart);
// Tests whether local type/constant `lt` is structurally the same as entry `gt`
// of `globalTypes`. Opcode and word count must agree, then the literal words,
// constant IDs and sub-type IDs given by literalRange(), constRange() and
// typeRange() are compared; the ID ranges are matched recursively.
1207 bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const
1209 // Find the local type id "lt" and global type id "gt"
1210 const auto lt_it = typeConstPosR.find(lt);
1211 if (lt_it == typeConstPosR.end())
1214 const auto typeStart = lt_it->second;
1216 // Search for entry in global table
1217 const auto gtype = globalTypes.find(gt);
1218 if (gtype == globalTypes.end())
1221 const auto& gdata = gtype->second;
1223 // local wordcount and opcode
1224 const int wordCount = asWordCount(typeStart);
1225 const spv::Op opCode = asOpCode(typeStart);
1227 // no type match if opcodes don't match, or operand count doesn't match
1228 if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))
// NOTE(review): not referenced in the lines visible here -- confirm whether it is still needed.
1231 const unsigned numOperands = wordCount - 2; // all types have a result
// Recursively matches each local ID in [range.first, min(range.second, wordCount))
// against the word at the same offset in the global entry.
1233 const auto cmpIdRange = [&](range_t range) {
1234 for (int x=range.first; x<std::min(range.second, wordCount); ++x)
1235 if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))
1240 const auto cmpConst = [&]() { return cmpIdRange(constRange(opCode)); };
1241 const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode)); };
1243 // Compare literals in range [start,end)
1244 const auto cmpLiteral = [&]() {
1245 const auto range = literalRange(opCode);
// NOTE(review): the word buffer is called `spv` everywhere else in this file; confirm
// that `spir` is a valid name here.
1246 return std::equal(spir.begin() + typeStart + range.first,
1247 spir.begin() + typeStart + std::min(range.second, wordCount),
1248 gdata.begin() + range.first);
1251 assert(isTypeOp(opCode) || isConstOp(opCode));
1254 case spv::OpTypeOpaque: // TODO: disable until we compare the literal strings.
1255 case spv::OpTypeQueue: return false;
1256 case spv::OpTypeEvent: // fall through...
1257 case spv::OpTypeDeviceEvent: // ...
1258 case spv::OpTypeReserveId: return false;
1259 // for samplers, we don't handle the optional parameters yet
1260 case spv::OpTypeSampler: return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;
1261 default: return cmpLiteral() && cmpConst() && cmpSubType();
1265 // Look for an equivalent type in the globalTypes map
// Scans every entry of `globalTypes` in turn and tests it against local type
// `lt` with matchType(). The return statements are not visible in this view.
1266 spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const
1268 // Try a recursive type match on each in turn, and return a match if we find one
1269 for (const auto& gt : globalTypes)
1270 if (matchType(globalTypes, lt, gt.first))
1277 // Return start position in SPV of given Id. error if not found.
// Returns the instruction start index recorded for `id` in idPosR (filled in
// by buildLocalMaps()). Calls error() when `id` has no entry.
1278 unsigned spirvbin_t::idPos(spv::Id id) const
1280 const auto tid_it = idPosR.find(id);
1281 if (tid_it == idPosR.end()) {
1282 error("ID not found");
1286 return tid_it->second;
1289 // Hash types to canonical values. This can return ID collisions (it's a bit
1290 // inevitable): it's up to the caller to handle that gracefully.
// Computes a hash for the type or constant instruction that starts at word
// index `typeStart`. Each opcode gets its own base constant, and composite
// kinds fold in the hashes of the types they reference by recursing through
// idPos(). Opcodes without a case report "unknown type opcode".
1291 std::uint32_t spirvbin_t::hashType(unsigned typeStart) const
1293 const unsigned wordCount = asWordCount(typeStart);
1294 const spv::Op opCode = asOpCode(typeStart);
1297 case spv::OpTypeVoid: return 0;
1298 case spv::OpTypeBool: return 1;
// Only word 3 feeds the integer hash; word 2 is not used here.
1299 case spv::OpTypeInt: return 3 + (spv[typeStart+3]);
// Every float type hashes to the same value; no operand word is used.
1300 case spv::OpTypeFloat: return 5;
1301 case spv::OpTypeVector:
1302 return 6 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1303 case spv::OpTypeMatrix:
1304 return 30 + hashType(idPos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1305 case spv::OpTypeImage:
1306 return 120 + hashType(idPos(spv[typeStart+2])) +
1307 spv[typeStart+3] + // dimensionality
1308 spv[typeStart+4] * 8 * 16 + // depth
1309 spv[typeStart+5] * 4 * 16 + // arrayed
1310 spv[typeStart+6] * 2 * 16 + // multisampled
1311 spv[typeStart+7] * 1 * 16; // format
1312 case spv::OpTypeSampler:
1314 case spv::OpTypeSampledImage:
1316 case spv::OpTypeArray:
1317 return 501 + hashType(idPos(spv[typeStart+2])) * spv[typeStart+3];
1318 case spv::OpTypeRuntimeArray:
1319 return 5000 + hashType(idPos(spv[typeStart+2]));
1320 case spv::OpTypeStruct:
// Member hashes are weighted by their word position, so member order changes the result.
1322 std::uint32_t hash = 10000;
1323 for (unsigned w=2; w < wordCount; ++w)
1324 hash += w * hashType(idPos(spv[typeStart+w]));
1328 case spv::OpTypeOpaque: return 6000 + spv[typeStart+2];
1329 case spv::OpTypePointer: return 100000 + hashType(idPos(spv[typeStart+3]));
1330 case spv::OpTypeFunction:
1332 std::uint32_t hash = 200000;
1333 for (unsigned w=2; w < wordCount; ++w)
1334 hash += w * hashType(idPos(spv[typeStart+w]));
1338 case spv::OpTypeEvent: return 300000;
1339 case spv::OpTypeDeviceEvent: return 300001;
1340 case spv::OpTypeReserveId: return 300002;
1341 case spv::OpTypeQueue: return 300003;
1342 case spv::OpTypePipe: return 300004;
1343 case spv::OpConstantTrue: return 300007;
1344 case spv::OpConstantFalse: return 300008;
1345 case spv::OpConstantComposite:
// For constants, word 1 is hashed as the type; words from 3 onward are the constituents.
1347 std::uint32_t hash = 300011 + hashType(idPos(spv[typeStart+1]));
1348 for (unsigned w=3; w < wordCount; ++w)
1349 hash += w * hashType(idPos(spv[typeStart+w]));
1352 case spv::OpConstant:
// Here the words from 3 onward are mixed in as raw values, not looked up as IDs.
1354 std::uint32_t hash = 400011 + hashType(idPos(spv[typeStart+1]));
1355 for (unsigned w=3; w < wordCount; ++w)
1356 hash += w * spv[typeStart+w];
1359 case spv::OpConstantNull:
1361 std::uint32_t hash = 500009 + hashType(idPos(spv[typeStart+1]));
1364 case spv::OpConstantSampler:
1366 std::uint32_t hash = 600011 + hashType(idPos(spv[typeStart+1]));
1367 for (unsigned w=3; w < wordCount; ++w)
1368 hash += w * spv[typeStart+w];
1373 error("unknown type opcode");
// Assign new Ids to the instructions listed in typeConstPos, derived from hashType().
// For each position, the result Id is read with asTypeConstId(); if that Id is still
// unmapped, it is mapped via localId() to the Id returned by nextUnusedId(), which is
// called with (hash % softTypeIdLimit + firstMappedID).
1378 void spirvbin_t::mapTypeConst()
// NOTE(review): globalTypeMap is not referenced in any line visible in this listing.
1380 globaltypes_t globalTypeMap;
1382 msg(3, 2, std::string("Remapping Consts & Types: "));
1384 static const std::uint32_t softTypeIdLimit = 3011; // small prime. TODO: get from options
1385 static const std::uint32_t firstMappedID = 8; // offset into ID space
1387 for (auto& typeStart : typeConstPos) {
1388 const spv::Id resId = asTypeConstId(typeStart);
1389 const std::uint32_t hashval = hashType(typeStart);
// Only Ids that have not been mapped yet are assigned here.
1394 if (isOldIdUnmapped(resId)) {
// The modulo keeps the requested Id below softTypeIdLimit + firstMappedID.
1395 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
1402 // Strip a single binary by removing ranges given in stripRange.
// Each range is treated as half-open: words with first <= index < second are dropped.
// The kept words are compacted towards the front of spv, which is then shrunk.
1403 void spirvbin_t::strip()
1405 if (stripRange.empty()) // nothing to do
1408 // Sort strip ranges in order of traversal
1409 std::sort(stripRange.begin(), stripRange.end());
1411 // Compact spv in place: kept words are copied down to strippedPos (no second buffer is used).
1412 // We'll step this iterator through the strip ranges as we go through the binary
1413 auto strip_it = stripRange.begin();
1415 int strippedPos = 0; // next write position for a kept word
1416 for (unsigned word = 0; word < unsigned(spv.size()); ++word) {
// Move on from ranges that end at or before the current word.
// NOTE(review): the loop body is not shown in this listing; presumably it advances strip_it - confirm.
1417 while (strip_it != stripRange.end() && word >= strip_it->second)
// Keep the word if no range remains or the word lies outside the current range.
1420 if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)
1421 spv[strippedPos++] = spv[word];
// Drop the now-unused tail.
1424 spv.resize(strippedPos);
1430 // Remap a single binary: validate the header, build the ID maps, run the passes
// selected by the option bits, then apply the resulting ID mapping.
// After each step the function returns early if errorLatch is set.
// NOTE(review): the visible lines test `options`, not the `opts` parameter; presumably
// `options` is assigned from `opts` on a line not shown in this listing - confirm.
1431 void spirvbin_t::remap(std::uint32_t opts)
1435 // Set up opcode tables from SpvDoc
1436 spv::Parameterize();
1438 validate(); // validate header
1439 buildLocalMaps(); // build ID maps
1441 msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));
// Optional debug stripping, followed by an unconditional strip() pass.
1443 if (options & STRIP) stripDebug();
1444 if (errorLatch) return;
1446 strip(); // strip out data we decided to eliminate
1447 if (errorLatch) return;
// Optional optimization and dead-code passes, each gated by its own option bit.
1449 if (options & OPT_LOADSTORE) optLoadStore();
1450 if (errorLatch) return;
1452 if (options & OPT_FWD_LS) forwardLoadStores();
1453 if (errorLatch) return;
1455 if (options & DCE_FUNCS) dceFuncs();
1456 if (errorLatch) return;
1458 if (options & DCE_VARS) dceVars();
1459 if (errorLatch) return;
1461 if (options & DCE_TYPES) dceTypes();
1462 if (errorLatch) return;
// Second strip() pass, run after the passes above.
1464 strip(); // strip out data we decided to eliminate
1465 if (errorLatch) return;
1467 stripDeadRefs(); // remove references to things we DCEed
1468 if (errorLatch) return;
1470 // after the last strip, we must clean any debug info referring to now-deleted data
// ID mapping passes, each gated by its own option bit.
1472 if (options & MAP_TYPES) mapTypeConst();
1473 if (errorLatch) return;
1475 if (options & MAP_NAMES) mapNames();
1476 if (errorLatch) return;
1478 if (options & MAP_FUNCS) mapFnBodies();
1479 if (errorLatch) return;
// The final mapping is only completed and applied when a MAP_ALL bit is set.
1481 if (options & MAP_ALL) {
1482 mapRemainder(); // map any unmapped IDs
1483 if (errorLatch) return;
1485 applyMap(); // Now remap each shader to the new IDs we've come up with
1486 if (errorLatch) return;
1490 // remap from a memory image
1491 void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)
1500 #endif // defined (use_cpp11)