2 // Copyright (C) 2018 Google, Inc.
4 // All rights reserved.
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
10 // Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
13 // Redistributions in binary form must reproduce the above
14 // copyright notice, this list of conditions and the following
15 // disclaimer in the documentation and/or other materials provided
16 // with the distribution.
18 // Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19 // contributors may be used to endorse or promote products derived
20 // from this software without specific prior written permission.
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
36 // Post-processing for SPIR-V IR, in internal form, not standard binary form.
42 #include <unordered_set>
45 #include "SpvBuilder.h"
48 #include "GlslangToSpv.h"
49 #include "SpvBuilder.h"
51 #include "GLSL.std.450.h"
52 #include "GLSL.ext.KHR.h"
53 #include "GLSL.ext.EXT.h"
55 #include "GLSL.ext.AMD.h"
58 #include "GLSL.ext.NV.h"
64 // Hook to visit each operand type and result type of an instruction.
65 // Will be called multiple times for one instruction, once for each typed
66 // operand and the result.
//
// Adds the capabilities / extensions implied by `typeId` being used by `inst`:
// the most basic type class of typeId is computed and, for int/float classes,
// its scalar bit width; then the instruction's opcode decides which of
// Int8 / Int16 / Float16 (or the AMD int16 / half-float extensions) are needed.
// NOTE(review): several lines of this function (the declaration of `width`,
// some case labels, break/default statements and closing braces) are not
// visible in this view; comments below that depend on them are assumptions.
67 void Builder::postProcessType(const Instruction& inst, Id typeId)
69 // Characterize the type being questioned
70 Id basicTypeOp = getMostBasicTypeClass(typeId);
// Only int/float classes have a scalar bit width; for any other class `width`
// keeps its initial value (presumably 0 - the declaration is not visible).
72 if (basicTypeOp == OpTypeFloat || basicTypeOp == OpTypeInt)
73 width = getScalarTypeWidth(typeId);
75 // Do opcode-specific checks
76 switch (inst.getOpCode()) {
// Presumably the OpLoad/OpStore case (labels not visible). If the basic class
// is a struct, require the scalar capability of every 8/16-bit member type
// the struct contains.
79 if (basicTypeOp == OpTypeStruct) {
80 if (containsType(typeId, OpTypeInt, 8))
81 addCapability(CapabilityInt8);
82 if (containsType(typeId, OpTypeInt, 16))
83 addCapability(CapabilityInt16);
84 if (containsType(typeId, OpTypeFloat, 16))
85 addCapability(CapabilityFloat16);
// Non-struct types: the decision depends on the storage class of the
// instruction's first id operand (the pointer, for a load/store).
87 StorageClass storageClass = getStorageClass(inst.getIdOperand(0));
// 8-bit scalars (the enclosing width test is not visible; inferred from the
// `else if (width == 16)` below). The storage classes listed first are
// exempted - presumably their case body just breaks, because 8-bit access
// there is covered by a storage-access capability instead. Every other
// storage class requires CapabilityInt8.
89 switch (storageClass) {
90 case StorageClassPhysicalStorageBufferEXT:
91 case StorageClassUniform:
92 case StorageClassStorageBuffer:
93 case StorageClassPushConstant:
96 addCapability(CapabilityInt8);
// 16-bit scalars: same idea, with Input/Output also exempted; any other
// storage class requires Int16 or Float16 according to the basic class.
99 } else if (width == 16) {
100 switch (storageClass) {
101 case StorageClassPhysicalStorageBufferEXT:
102 case StorageClassUniform:
103 case StorageClassStorageBuffer:
104 case StorageClassPushConstant:
105 case StorageClassInput:
106 case StorageClassOutput:
109 if (basicTypeOp == OpTypeInt)
110 addCapability(CapabilityInt16);
111 if (basicTypeOp == OpTypeFloat)
112 addCapability(CapabilityFloat16);
// The body of this case group is not visible; presumably these opcodes add
// no type-driven capability of their own.
119 case OpPtrAccessChain:
// Presumably the OpExtInst case (label not visible): immediate operand 1 is
// the extended-instruction number. Only targets older than SPIR-V 1.3 get the
// AMD extensions, and only when the type involves a 16-bit int / float.
// NOTE(review): no check of the OpExtInst instruction-set operand is visible
// here, so a non-GLSL.std.450 instruction with a colliding number would also
// match - confirm.
127 switch (inst.getImmediateOperand(1)) {
128 case GLSLstd450Frexp:
129 case GLSLstd450FrexpStruct:
130 if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeInt, 16))
131 addExtension(spv::E_SPV_AMD_gpu_shader_int16);
133 case GLSLstd450InterpolateAtCentroid:
134 case GLSLstd450InterpolateAtSample:
135 case GLSLstd450InterpolateAtOffset:
136 if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeFloat, 16))
137 addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
// Presumably the default case (label not visible): any other use of a
// 16-bit float, 16-bit int or 8-bit int scalar requires the matching
// capability.
145 if (basicTypeOp == OpTypeFloat && width == 16)
146 addCapability(CapabilityFloat16);
147 if (basicTypeOp == OpTypeInt && width == 16)
148 addCapability(CapabilityInt16);
149 if (basicTypeOp == OpTypeInt && width == 8)
150 addCapability(CapabilityInt8);
155 // Called for each instruction that resides in a block.
//
// Performs three steps:
// - adds capabilities/extensions implied purely by inst's opcode;
// - for a load/store through an OpAccessChain rooted at a
//   PhysicalStorageBufferEXT pointer, tightens the Aligned memory operand;
// - runs postProcessType() on inst's result type and on the type of every
//   id operand that has one.
// NOTE(review): case labels, break/default statements and closing braces of
// this function are not visible in this view; comments below that depend on
// them are assumptions.
156 void Builder::postProcess(Instruction& inst)
158 // Add capabilities based simply on the opcode.
159 switch (inst.getOpCode()) {
// Presumably the OpExtInst case (label not visible): immediate operand 1 is
// the extended-instruction number; the InterpolateAt* family requires
// CapabilityInterpolationFunction.
161 switch (inst.getImmediateOperand(1)) {
162 case GLSLstd450InterpolateAtCentroid:
163 case GLSLstd450InterpolateAtSample:
164 case GLSLstd450InterpolateAtOffset:
165 addCapability(CapabilityInterpolationFunction);
// The case labels for this statement are not visible; presumably the
// fine/coarse derivative opcodes.
177 addCapability(CapabilityDerivativeControl);
// Image query opcodes require CapabilityImageQuery.
180 case OpImageQueryLod:
181 case OpImageQuerySize:
182 case OpImageQuerySizeLod:
183 case OpImageQuerySamples:
184 case OpImageQueryLevels:
185 addCapability(CapabilityImageQuery);
// Partitioned NV subgroup ops need both the extension and the capability.
189 case OpGroupNonUniformPartitionNV:
190 addExtension(E_SPV_NV_shader_subgroup_partitioned);
191 addCapability(CapabilityGroupNonUniformPartitionedNV);
// Presumably the OpLoad/OpStore case (labels not visible); operand 0 is the
// pointer that is loaded from / stored to.
198 // For any load/store to a PhysicalStorageBufferEXT, walk the accesschain
199 // index list to compute the misalignment. The pre-existing alignment value
200 // (set via Builder::AccessChain::alignment) only accounts for the base of
201 // the reference type and any scalar component selection in the accesschain,
202 // and this function computes the rest from the SPIR-V Offset decorations.
203 Instruction *accessChain = module.getInstruction(inst.getIdOperand(0));
// Only a pointer produced directly by OpAccessChain is walked; pointers from
// OpPtrAccessChain or any other instruction do not match this test.
204 if (accessChain->getOpCode() == OpAccessChain) {
205 Instruction *base = module.getInstruction(accessChain->getIdOperand(0));
206 // Get the type of the base of the access chain. It must be a pointer type.
207 Id typeId = base->getTypeId();
208 Instruction *type = module.getInstruction(typeId);
209 assert(type->getOpCode() == OpTypePointer);
// OpTypePointer: immediate operand 0 is the storage class, id operand 1 the
// pointee type. Bases in other storage classes are skipped (the body of this
// `if` is not visible; presumably it leaves the switch).
210 if (type->getImmediateOperand(0) != StorageClassPhysicalStorageBufferEXT) {
213 // Get the pointee type.
214 typeId = type->getIdOperand(1);
215 type = module.getInstruction(typeId);
216 // Walk the index list for the access chain. For each index, find any
217 // misalignment that can apply when accessing the member/element via
218 // Offset/ArrayStride/MatrixStride decorations, and bitwise OR them all
// together. (`alignment` is declared on a line not visible here -
// presumably initialised to 0.)
// Operand 0 of the access chain is the base, so indices start at operand 1.
221 for (int i = 1; i < accessChain->getNumOperands(); ++i) {
222 Instruction *idx = module.getInstruction(accessChain->getIdOperand(i));
223 if (type->getOpCode() == OpTypeStruct) {
// Struct indices must be OpConstant; its first literal is the member number.
224 assert(idx->getOpCode() == OpConstant);
225 int c = idx->getImmediateOperand(0);
// OpMemberDecorate operands: struct type, member number, decoration, value.
// OR in the member's Offset and (for matrix members) MatrixStride.
227 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
228 if (decoration.get()->getOpCode() == OpMemberDecorate &&
229 decoration.get()->getIdOperand(0) == typeId &&
230 decoration.get()->getImmediateOperand(1) == c &&
231 (decoration.get()->getImmediateOperand(2) == DecorationOffset ||
232 decoration.get()->getImmediateOperand(2) == DecorationMatrixStride)) {
233 alignment |= decoration.get()->getImmediateOperand(3);
236 std::for_each(decorations.begin(), decorations.end(), function);
237 // get the next member type
// (member c's type is id operand c of OpTypeStruct)
238 typeId = type->getIdOperand(c);
239 type = module.getInstruction(typeId);
240 } else if (type->getOpCode() == OpTypeArray ||
241 type->getOpCode() == OpTypeRuntimeArray) {
// OpDecorate operands: target, decoration, value. OR in the ArrayStride.
242 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
243 if (decoration.get()->getOpCode() == OpDecorate &&
244 decoration.get()->getIdOperand(0) == typeId &&
245 decoration.get()->getImmediateOperand(1) == DecorationArrayStride) {
246 alignment |= decoration.get()->getImmediateOperand(2);
249 std::for_each(decorations.begin(), decorations.end(), function);
250 // Get the element type
251 typeId = type->getIdOperand(0);
252 type = module.getInstruction(typeId);
254 // Once we get to any non-aggregate type, we're done.
// The memory-access mask is operand 1 of OpLoad (pointer, mask, ...) and
// operand 2 of OpStore (pointer, object, mask, ...); it must carry Aligned.
258 assert(inst.getNumOperands() >= 3);
259 unsigned int memoryAccess = inst.getImmediateOperand((inst.getOpCode() == OpStore) ? 2 : 1);
260 assert(memoryAccess & MemoryAccessAlignedMask);
261 // Compute the index of the alignment operand.
262 int alignmentIdx = 2;
// NOTE(review): the statement guarded by this `if` is not visible here. In
// SPIR-V, memory-operand literals follow the mask in increasing bit order and
// Volatile (0x1) carries no literal, so the Aligned literal always directly
// follows the mask. If the guarded statement increments alignmentIdx, volatile
// aligned accesses would read/overwrite the wrong operand - confirm.
263 if (memoryAccess & MemoryAccessVolatileMask)
// OpStore has one extra leading operand (the object), which shifts the index
// by one (the increment itself is on a line not visible here).
265 if (inst.getOpCode() == OpStore)
267 // Merge new and old (mis)alignment
268 alignment |= inst.getImmediateOperand(alignmentIdx);
// Keep only the lowest set bit: x & (x-1) clears it, so x & ~(x & (x-1))
// isolates it. That is the largest power of two dividing every merged
// contribution, i.e. the alignment that can be guaranteed for this access.
270 alignment = alignment & ~(alignment & (alignment-1));
271 // update the Aligned operand
272 inst.setImmediateOperand(alignmentIdx, alignment);
281 // Checks based on type
// First the result type (if any), then the type of each id operand.
282 if (inst.getTypeId() != NoType)
283 postProcessType(inst, inst.getTypeId());
284 for (int op = 0; op < inst.getNumOperands(); ++op) {
285 if (inst.isIdOperand(op)) {
286 // In blocks, these are always result ids, but we are relying on
287 // getTypeId() to return NoType for things like OpLabel.
288 if (getTypeId(inst.getIdOperand(op)) != NoType)
289 postProcessType(inst, getTypeId(inst.getIdOperand(op)));
294 // Called for each instruction in a reachable block.
// Currently a deliberate no-op: the parameter is unnamed and, as the note
// below says, the earlier logic was removed because acting on reachable
// instructions was questionable without also deleting instructions.
295 void Builder::postProcessReachable(const Instruction&)
297 // did have code here, but questionable to do so without deleting the instructions
// Whole-module post-processing pass:
//  1. per function, marks the blocks reachable from the entry block and
//     collects the result IDs of instructions in all other blocks, then
//     erases every decoration whose target (id operand 0) is one of those IDs;
//  2. calls postProcessReachable() on each instruction of reachable blocks;
//  3. calls postProcess(Instruction&) on every block-contained instruction and
//     gives PhysicalStorageBufferEXT-pointer local variables a default
//     AliasedPointerEXT decoration;
//  4. adds 8/16-bit storage extensions and capabilities for
//     PhysicalStorageBufferEXT pointer types.
// NOTE(review): some lines (the `f`/`b` loop-variable declarations and closing
// braces) are not visible in this view.
301 void Builder::postProcess()
303 std::unordered_set<const Block*> reachableBlocks;
304 std::unordered_set<Id> unreachableDefinitions;
305 // Collect IDs defined in unreachable blocks. For each function, label the
306 // reachable blocks first. Then for each unreachable block, collect the
307 // result IDs of the instructions in it.
308 for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
310 Block* entry = f->getEntryBlock();
311 inReadableOrder(entry, [&reachableBlocks](const Block* b) { reachableBlocks.insert(b); });
312 for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
314 if (reachableBlocks.count(b) == 0) {
315 for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
316 unreachableDefinitions.insert(ii->get()->getResultId());
321 // Remove unneeded decorations, for unreachable instructions
// (erase-remove idiom keyed on each decoration's first id operand, its target)
322 decorations.erase(std::remove_if(decorations.begin(), decorations.end(),
323 [&unreachableDefinitions](std::unique_ptr<Instruction>& I) -> bool {
324 Id decoration_id = I.get()->getIdOperand(0);
325 return unreachableDefinitions.count(decoration_id) != 0;
329 // Add per-instruction capabilities, extensions, etc.,
331 // process all reachable instructions...
// Iteration order over the unordered_set is unspecified; this only matters if
// postProcessReachable() ever gains order-dependent side effects.
332 for (auto bi = reachableBlocks.cbegin(); bi != reachableBlocks.cend(); ++bi) {
333 const Block* block = *bi;
334 const auto function = [this](const std::unique_ptr<Instruction>& inst) { postProcessReachable(*inst.get()); };
335 std::for_each(block->getInstructions().begin(), block->getInstructions().end(), function);
338 // process all block-contained instructions
// NOTE(review): this walks every block of every function, unreachable ones
// included, so capabilities/extensions may be added for unreachable code -
// confirm that is intended.
339 for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
341 for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
343 for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
344 postProcess(*ii->get());
346 // For all local variables that contain pointers to PhysicalStorageBufferEXT, check whether
347 // there is an existing restrict/aliased decoration. If we don't find one, add Aliased as the
// default (DecorationAliasedPointerEXT).
349 for (auto vi = b->getLocalVariables().cbegin(); vi != b->getLocalVariables().cend(); vi++) {
350 const Instruction& inst = *vi->get();
351 Id resultId = inst.getResultId();
352 if (containsPhysicalStorageBufferOrArray(getDerefTypeId(resultId))) {
353 bool foundDecoration = false;
// Linear scan of all module decorations for an OpDecorate on this variable
// with AliasedPointerEXT or RestrictPointerEXT.
354 const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
355 if (decoration.get()->getIdOperand(0) == resultId &&
356 decoration.get()->getOpCode() == OpDecorate &&
357 (decoration.get()->getImmediateOperand(1) == spv::DecorationAliasedPointerEXT ||
358 decoration.get()->getImmediateOperand(1) == spv::DecorationRestrictPointerEXT)) {
359 foundDecoration = true;
362 std::for_each(decorations.begin(), decorations.end(), function);
363 if (!foundDecoration) {
364 addDecoration(resultId, spv::DecorationAliasedPointerEXT);
371 // Look for any 8/16 bit type in physical storage buffer class, and set the
372 // appropriate capability. This happens in createSpvVariable for other storage
373 // classes, but there isn't always a variable for physical storage buffer.
// OpTypePointer: immediate operand 0 is the storage class, id operand 1 the
// pointee type that is searched for 8-bit ints and 16-bit ints/floats.
374 for (int t = 0; t < (int)groupedTypes[OpTypePointer].size(); ++t) {
375 Instruction* type = groupedTypes[OpTypePointer][t];
376 if (type->getImmediateOperand(0) == (unsigned)StorageClassPhysicalStorageBufferEXT) {
377 if (containsType(type->getIdOperand(1), OpTypeInt, 8)) {
378 addExtension(spv::E_SPV_KHR_8bit_storage);
379 addCapability(spv::CapabilityStorageBuffer8BitAccess);
381 if (containsType(type->getIdOperand(1), OpTypeInt, 16) ||
382 containsType(type->getIdOperand(1), OpTypeFloat, 16)) {
383 addExtension(spv::E_SPV_KHR_16bit_storage);
384 addCapability(spv::CapabilityStorageBuffer16BitAccess);
390 }; // end spv namespace