Merge pull request #1679 from okuoku/fix-c-sample-code
spirv_hlsl.cpp
/*
 * Copyright 2016-2021 Robert Konrad
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 * SPDX-License-Identifier: Apache-2.0 OR MIT.
 */

#include "spirv_hlsl.hpp"
#include "GLSL.std.450.h"
#include <algorithm>
#include <assert.h>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

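// Classifies a typed image format as unsigned-normalized, signed-normalized or
// neither; plain float/int/unknown formats fall through to None.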
enum class ImageFormatNormalizedState
{
	None = 0,
	Unorm = 1,
	Snorm = 2
};

static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
		return ImageFormatNormalizedState::Unorm;

	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
		return ImageFormatNormalizedState::Snorm;

	default:
		break;
	}

	return ImageFormatNormalizedState::None;
}

static unsigned image_format_to_components(ImageFormat fmt)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
	case ImageFormatR16f:
	case ImageFormatR32f:
	case ImageFormatR8i:
	case ImageFormatR16i:
	case ImageFormatR32i:
	case ImageFormatR8ui:
	case ImageFormatR16ui:
	case ImageFormatR32ui:
		return 1;

	case ImageFormatRg8:
	case ImageFormatRg16:
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
	case ImageFormatRg16f:
	case ImageFormatRg32f:
	case ImageFormatRg8i:
	case ImageFormatRg16i:
	case ImageFormatRg32i:
	case ImageFormatRg8ui:
	case ImageFormatRg16ui:
	case ImageFormatRg32ui:
		return 2;

	case ImageFormatR11fG11fB10f:
		return 3;

	case ImageFormatRgba8:
	case ImageFormatRgba16:
	case ImageFormatRgb10A2:
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
	case ImageFormatRgba16f:
	case ImageFormatRgba32f:
	case ImageFormatRgba8i:
	case ImageFormatRgba16i:
	case ImageFormatRgba32i:
	case ImageFormatRgba8ui:
	case ImageFormatRgba16ui:
	case ImageFormatRgba32ui:
	case ImageFormatRgb10a2ui:
		return 4;

	case ImageFormatUnknown:
		return 4; // Assume 4.

	default:
		SPIRV_CROSS_THROW("Unrecognized typed image format.");
	}
}

static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
{
	switch (fmt)
	{
	case ImageFormatR8:
	case ImageFormatR16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float";
	case ImageFormatRg8:
	case ImageFormatRg16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float2";
	case ImageFormatRgba8:
	case ImageFormatRgba16:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float4";
	case ImageFormatRgb10A2:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "unorm float4";

	case ImageFormatR8Snorm:
	case ImageFormatR16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float";
	case ImageFormatRg8Snorm:
	case ImageFormatRg16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float2";
	case ImageFormatRgba8Snorm:
	case ImageFormatRgba16Snorm:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "snorm float4";

	case ImageFormatR16f:
	case ImageFormatR32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float";
	case ImageFormatRg16f:
	case ImageFormatRg32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float2";
	case ImageFormatRgba16f:
	case ImageFormatRgba32f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float4";

	case ImageFormatR11fG11fB10f:
		if (basetype != SPIRType::Float)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "float3";

	case ImageFormatR8i:
	case ImageFormatR16i:
	case ImageFormatR32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int";
	case ImageFormatRg8i:
	case ImageFormatRg16i:
	case ImageFormatRg32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int2";
	case ImageFormatRgba8i:
	case ImageFormatRgba16i:
	case ImageFormatRgba32i:
		if (basetype != SPIRType::Int)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "int4";

	case ImageFormatR8ui:
	case ImageFormatR16ui:
	case ImageFormatR32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint";
	case ImageFormatRg8ui:
	case ImageFormatRg16ui:
	case ImageFormatRg32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint2";
	case ImageFormatRgba8ui:
	case ImageFormatRgba16ui:
	case ImageFormatRgba32ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint4";
	case ImageFormatRgb10a2ui:
		if (basetype != SPIRType::UInt)
			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
		return "uint4";

	case ImageFormatUnknown:
		switch (basetype)
		{
		case SPIRType::Float:
			return "float4";
		case SPIRType::Int:
			return "int4";
		case SPIRType::UInt:
			return "uint4";
		default:
			SPIRV_CROSS_THROW("Unsupported base type for image.");
		}

	default:
		SPIRV_CROSS_THROW("Unrecognized typed image format.");
	}
}

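// Returns the HLSL resource type for an image. For example, a sampled 2D float
// image becomes "Texture2D<float4>", while a 2D storage image declared with the
// Rgba8 format becomes "RWTexture2D<unorm float4>".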
string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	const char *dim = nullptr;
	bool typed_load = false;
	uint32_t components = 4;

	bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);

	switch (type.image.dim)
	{
	case Dim1D:
		typed_load = type.image.sampled == 2;
		dim = "1D";
		break;
	case Dim2D:
		typed_load = type.image.sampled == 2;
		dim = "2D";
		break;
	case Dim3D:
		typed_load = type.image.sampled == 2;
		dim = "3D";
		break;
	case DimCube:
		if (type.image.sampled == 2)
			SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
		dim = "Cube";
		break;
	case DimRect:
		SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
	case DimBuffer:
		if (type.image.sampled == 1)
			return join("Buffer<", type_to_glsl(imagetype), components, ">");
		else if (type.image.sampled == 2)
		{
			if (interlocked_resources.count(id))
				return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
				            ">");

			typed_load = !force_image_srv && type.image.sampled == 2;

			const char *rw = force_image_srv ? "" : "RW";
			return join(rw, "Buffer<",
			            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
			                         join(type_to_glsl(imagetype), components),
			            ">");
		}
		else
			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce this at runtime.");
	case DimSubpassData:
		dim = "2D";
		typed_load = false;
		break;
	default:
		SPIRV_CROSS_THROW("Invalid dimension.");
	}
	const char *arrayed = type.image.arrayed ? "Array" : "";
	const char *ms = type.image.ms ? "MS" : "";
	const char *rw = typed_load && !force_image_srv ? "RW" : "";

	if (force_image_srv)
		typed_load = false;

	if (typed_load && interlocked_resources.count(id))
		rw = "RasterizerOrdered";

	return join(rw, "Texture", dim, ms, arrayed, "<",
	            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
	                         join(type_to_glsl(imagetype), components),
	            ">");
}

string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
{
	auto &imagetype = get<SPIRType>(type.image.type);
	string res;

	switch (imagetype.basetype)
	{
	case SPIRType::Int:
		res = "i";
		break;
	case SPIRType::UInt:
		res = "u";
		break;
	default:
		break;
	}

	if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
		return res + "subpassInput" + (type.image.ms ? "MS" : "");

	// If we're emulating subpassInput with samplers, force sampler2D
	// so we don't have to specify format.
	if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
	{
		// Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
		if (type.image.dim == DimBuffer && type.image.sampled == 1)
			res += "sampler";
		else
			res += type.image.sampled == 2 ? "image" : "texture";
	}
	else
		res += "sampler";

	switch (type.image.dim)
	{
	case Dim1D:
		res += "1D";
		break;
	case Dim2D:
		res += "2D";
		break;
	case Dim3D:
		res += "3D";
		break;
	case DimCube:
		res += "CUBE";
		break;

	case DimBuffer:
		res += "Buffer";
		break;

	case DimSubpassData:
		res += "2D";
		break;
	default:
		SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
	}

	if (type.image.ms)
		res += "MS";
	if (type.image.arrayed)
		res += "Array";

	return res;
}

string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
{
	if (hlsl_options.shader_model <= 30)
		return image_type_hlsl_legacy(type, id);
	else
		return image_type_hlsl_modern(type, id);
}

// The optional id parameter indicates the object whose type we are trying to
// find the description for. Most type descriptions do not depend on a specific
// object's use of that type.
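// For example, a scalar uint maps to "uint", a 3-component float vector to
// "float3", and a 4-column, 4-row float matrix to "float4x4".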
string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
{
	// Ignore the pointer type since GLSL doesn't have pointers.

	switch (type.basetype)
	{
	case SPIRType::Struct:
		// Need OpName lookup here to get a "sensible" name for a struct.
		if (backend.explicit_struct_type)
			return join("struct ", to_name(type.self));
		else
			return to_name(type.self);

	case SPIRType::Image:
	case SPIRType::SampledImage:
		return image_type_hlsl(type, id);

	case SPIRType::Sampler:
		return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";

	case SPIRType::Void:
		return "void";

	default:
		break;
	}

	if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return "bool";
		case SPIRType::Int:
			return backend.basic_int_type;
		case SPIRType::UInt:
			return backend.basic_uint_type;
		case SPIRType::AtomicCounter:
			return "atomic_uint";
		case SPIRType::Half:
			if (hlsl_options.enable_16bit_types)
				return "half";
			else
				return "min16float";
		case SPIRType::Short:
			if (hlsl_options.enable_16bit_types)
				return "int16_t";
			else
				return "min16int";
		case SPIRType::UShort:
			if (hlsl_options.enable_16bit_types)
				return "uint16_t";
			else
				return "min16uint";
		case SPIRType::Float:
			return "float";
		case SPIRType::Double:
			return "double";
		case SPIRType::Int64:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
			return "int64_t";
		case SPIRType::UInt64:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
			return "uint64_t";
		default:
			return "???";
		}
	}
	else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return join("bool", type.vecsize);
		case SPIRType::Int:
			return join("int", type.vecsize);
		case SPIRType::UInt:
			return join("uint", type.vecsize);
		case SPIRType::Half:
			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
		case SPIRType::Short:
			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
		case SPIRType::UShort:
			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
		case SPIRType::Float:
			return join("float", type.vecsize);
		case SPIRType::Double:
			return join("double", type.vecsize);
		case SPIRType::Int64:
			return join("i64vec", type.vecsize);
		case SPIRType::UInt64:
			return join("u64vec", type.vecsize);
		default:
			return "???";
		}
	}
	else
	{
		switch (type.basetype)
		{
		case SPIRType::Boolean:
			return join("bool", type.columns, "x", type.vecsize);
		case SPIRType::Int:
			return join("int", type.columns, "x", type.vecsize);
		case SPIRType::UInt:
			return join("uint", type.columns, "x", type.vecsize);
		case SPIRType::Half:
			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
		case SPIRType::Short:
			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
		case SPIRType::UShort:
			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
		case SPIRType::Float:
			return join("float", type.columns, "x", type.vecsize);
		case SPIRType::Double:
			return join("double", type.columns, "x", type.vecsize);
		// Matrix types not supported for int64/uint64.
		default:
			return "???";
		}
	}
}

void CompilerHLSL::emit_header()
{
	for (auto &header : header_lines)
		statement(header);

	if (header_lines.size() > 0)
	{
		statement("");
	}
}

void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
{
	add_resource_name(var.self);

	// The global copies of I/O variables should not contain interpolation qualifiers.
	// These are emitted inside the interface structs.
	auto &flags = ir.meta[var.self].decoration.decoration_flags;
	auto old_flags = flags;
	flags.reset();
	statement("static ", variable_decl(var), ";");
	flags = old_flags;
}

const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
{
	// Input and output variables are handled specially in the HLSL backend.
	// The variables are declared as global, private variables, and do not need any qualifiers.
	if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
	    var.storage == StorageClassPushConstant)
	{
		return "uniform ";
	}

	return "";
}

void CompilerHLSL::emit_builtin_outputs_in_struct()
{
	auto &execution = get_entry_point();

	bool legacy = hlsl_options.shader_model <= 30;
	active_output_builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		const char *semantic = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		switch (builtin)
		{
		case BuiltInPosition:
			type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
			semantic = legacy ? "POSITION" : "SV_Position";
			break;

		case BuiltInSampleMask:
			if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
			type = "uint";
			semantic = "SV_Coverage";
			break;

		case BuiltInFragDepth:
			type = "float";
			if (legacy)
			{
				semantic = "DEPTH";
			}
			else
			{
				if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
					semantic = "SV_DepthGreaterEqual";
				else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
					semantic = "SV_DepthLessEqual";
				else
					semantic = "SV_Depth";
			}
			break;

		case BuiltInClipDistance:
			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
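			// e.g. 5 clip distances emit "float4 gl_ClipDistance0 : SV_ClipDistance0;"
			// followed by "float gl_ClipDistance1 : SV_ClipDistance1;".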
			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
			{
				uint32_t to_declare = clip_distance_count - clip;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = clip / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
				          " : SV_ClipDistance", semantic_index, ";");
			}
			break;

		case BuiltInCullDistance:
			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
			{
				uint32_t to_declare = cull_distance_count - cull;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = cull / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
				          " : SV_CullDistance", semantic_index, ";");
			}
			break;

		case BuiltInPointSize:
			// If point_size_compat is enabled, just ignore PointSize.
			// PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
			// even if it means working around the missing feature.
			if (hlsl_options.point_size_compat)
				break;
			else
				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");

		default:
			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
			break;
		}

		if (type && semantic)
			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
	});
}

void CompilerHLSL::emit_builtin_inputs_in_struct()
{
	bool legacy = hlsl_options.shader_model <= 30;
	active_input_builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		const char *semantic = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		switch (builtin)
		{
		case BuiltInFragCoord:
			type = "float4";
			semantic = legacy ? "VPOS" : "SV_Position";
			break;

		case BuiltInVertexId:
		case BuiltInVertexIndex:
			if (legacy)
				SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_VertexID";
			break;

		case BuiltInInstanceId:
		case BuiltInInstanceIndex:
			if (legacy)
				SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_InstanceID";
			break;

		case BuiltInSampleId:
			if (legacy)
				SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
			type = "uint";
			semantic = "SV_SampleIndex";
			break;

		case BuiltInSampleMask:
			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
				SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
			type = "uint";
			semantic = "SV_Coverage";
			break;

		case BuiltInGlobalInvocationId:
			type = "uint3";
			semantic = "SV_DispatchThreadID";
			break;

		case BuiltInLocalInvocationId:
			type = "uint3";
			semantic = "SV_GroupThreadID";
			break;

		case BuiltInLocalInvocationIndex:
			type = "uint";
			semantic = "SV_GroupIndex";
			break;

		case BuiltInWorkgroupId:
			type = "uint3";
			semantic = "SV_GroupID";
			break;

		case BuiltInFrontFacing:
			type = "bool";
			semantic = "SV_IsFrontFace";
			break;

		case BuiltInNumWorkgroups:
		case BuiltInSubgroupSize:
		case BuiltInSubgroupLocalInvocationId:
		case BuiltInSubgroupEqMask:
		case BuiltInSubgroupLtMask:
		case BuiltInSubgroupLeMask:
		case BuiltInSubgroupGtMask:
		case BuiltInSubgroupGeMask:
			// Handled specially.
			break;

		case BuiltInClipDistance:
			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
			{
				uint32_t to_declare = clip_distance_count - clip;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = clip / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
				          " : SV_ClipDistance", semantic_index, ";");
			}
			break;

		case BuiltInCullDistance:
			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
			{
				uint32_t to_declare = cull_distance_count - cull;
				if (to_declare > 4)
					to_declare = 4;

				uint32_t semantic_index = cull / 4;

				static const char *types[] = { "float", "float2", "float3", "float4" };
				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
				          " : SV_CullDistance", semantic_index, ";");
			}
			break;

		case BuiltInPointCoord:
			// PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
			if (hlsl_options.point_coord_compat)
				break;
			else
				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");

		default:
			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
			break;
		}

		if (type && semantic)
			statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
	});
}

uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
{
	// TODO: Need to verify correctness.
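	// e.g. a float4x4 consumes 4 locations (one per column), and a float4[3]
	// array consumes 3 (one per element); structs sum over their members.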
	uint32_t elements = 0;

	if (type.basetype == SPIRType::Struct)
	{
		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
			elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
	}
	else
	{
		uint32_t array_multiplier = 1;
		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
		{
			if (type.array_size_literal[i])
				array_multiplier *= type.array[i];
			else
				array_multiplier *= evaluate_constant_u32(type.array[i]);
		}
		elements += array_multiplier * type.columns;
	}
	return elements;
}

string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
{
	string res;
	//if (flags & (1ull << DecorationSmooth))
	//    res += "linear ";
	if (flags.get(DecorationFlat))
		res += "nointerpolation ";
	if (flags.get(DecorationNoPerspective))
		res += "noperspective ";
	if (flags.get(DecorationCentroid))
		res += "centroid ";
	if (flags.get(DecorationPatch))
		res += "patch "; // Seems to be different in actual HLSL.
	if (flags.get(DecorationSample))
		res += "sample ";
	if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
		res += "precise "; // Not supported?

	return res;
}

std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
{
	if (em == ExecutionModelVertex && sc == StorageClassInput)
	{
		// We have a vertex attribute - we should look at remapping it if the user provided
		// vertex attribute hints.
		for (auto &attribute : remap_vertex_attributes)
			if (attribute.location == location)
				return attribute.semantic;
	}

	// Not a vertex attribute, or no remap_vertex_attributes entry.
	return join("TEXCOORD", location);
}

std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
{
	// We cannot emit a static const initializer for block constants for practical reasons,
	// so just inline the initializer.
	// FIXME: There is a theoretical problem here if someone tries to composite extract
	// into this initializer since we don't declare it properly, but that is somewhat non-sensical.
	auto &type = get<SPIRType>(var.basetype);
	bool is_block = has_decoration(type.self, DecorationBlock);
	auto *c = maybe_get<SPIRConstant>(var.initializer);
	if (is_block && c)
		return constant_expression(*c);
	else
		return CompilerGLSL::to_initializer_expression(var);
}

void CompilerHLSL::emit_io_block(const SPIRVariable &var)
{
	auto &execution = get_entry_point();

	auto &type = get<SPIRType>(var.basetype);
	add_resource_name(type.self);

	statement("struct ", to_name(type.self));
	begin_scope();
	type.member_name_cache.clear();

	for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
	{
		uint32_t location = get_accumulated_member_location(var, i, false);
		string semantic = join(" : ", to_semantic(location, execution.model, var.storage));

		add_member_name(type, i);

		auto &membertype = get<SPIRType>(type.member_types[i]);
		statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, i)),
		          variable_decl(membertype, to_member_name(type, i)), semantic, ";");
	}

	end_scope_decl();
	statement("");

	statement("static ", variable_decl(var), ";");
	statement("");
}

void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
{
	auto &execution = get_entry_point();
	auto type = get<SPIRType>(var.basetype);

	string binding;
	bool use_location_number = true;
	bool legacy = hlsl_options.shader_model <= 30;
	if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
	{
		// Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
		uint32_t index = get_decoration(var.self, DecorationIndex);
		uint32_t location = get_decoration(var.self, DecorationLocation);

		if (index != 0 && location != 0)
			SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");

		binding = join(legacy ? "COLOR" : "SV_Target", location + index);
		use_location_number = false;
		if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
			type.vecsize = 4;
	}

	const auto get_vacant_location = [&]() -> uint32_t {
		for (uint32_t i = 0; i < 64; i++)
			if (!active_locations.count(i))
				return i;
		SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
	};

	bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;

	auto &m = ir.meta[var.self].decoration;
	auto name = to_name(var.self);
	if (use_location_number)
	{
		uint32_t location_number;

		// If an explicit location exists, use it with TEXCOORD[N] semantic.
		// Otherwise, pick a vacant location.
		if (m.decoration_flags.get(DecorationLocation))
			location_number = m.location;
		else
			location_number = get_vacant_location();

		// Allow semantic remap if specified.
		auto semantic = to_semantic(location_number, execution.model, var.storage);

		if (need_matrix_unroll && type.columns > 1)
		{
			if (!type.array.empty())
				SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");

			// Unroll matrices.
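			// e.g. with the default TEXCOORD semantics, a float4x4 at location 0
			// becomes four float4 columns with semantics TEXCOORD0_0 .. TEXCOORD0_3,
			// or TEXCOORD0 .. TEXCOORD3 when flatten_matrix_vertex_input_semantics is enabled.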
			for (uint32_t i = 0; i < type.columns; i++)
			{
				SPIRType newtype = type;
				newtype.columns = 1;

				string effective_semantic;
				if (hlsl_options.flatten_matrix_vertex_input_semantics)
					effective_semantic = to_semantic(location_number, execution.model, var.storage);
				else
					effective_semantic = join(semantic, "_", i);

				statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
				          variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
				active_locations.insert(location_number++);
			}
		}
		else
		{
			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
			          semantic, ";");

			// Structs and arrays should consume more locations.
			uint32_t consumed_locations = type_to_consumed_locations(type);
			for (uint32_t i = 0; i < consumed_locations; i++)
				active_locations.insert(location_number + i);
		}
	}
	else
		statement(variable_decl(type, name), " : ", binding, ";");
}

std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
{
	switch (builtin)
	{
	case BuiltInVertexId:
		return "gl_VertexID";
	case BuiltInInstanceId:
		return "gl_InstanceID";
	case BuiltInNumWorkgroups:
	{
		if (!num_workgroups_builtin)
			SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
			                  "Cannot emit code for this builtin.");

		auto &var = get<SPIRVariable>(num_workgroups_builtin);
		auto &type = get<SPIRType>(var.basetype);
		auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
		ParsedIR::sanitize_underscores(ret);
		return ret;
	}
	case BuiltInPointCoord:
		// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
		return "float2(0.5f, 0.5f)";
	case BuiltInSubgroupLocalInvocationId:
		return "WaveGetLaneIndex()";
	case BuiltInSubgroupSize:
		return "WaveGetLaneCount()";

	default:
		return CompilerGLSL::builtin_to_glsl(builtin, storage);
	}
}

void CompilerHLSL::emit_builtin_variables()
{
	Bitset builtins = active_input_builtins;
	builtins.merge_or(active_output_builtins);

	bool need_base_vertex_info = false;

	std::unordered_map<uint32_t, ID> builtin_to_initializer;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
			return;

		auto *c = this->maybe_get<SPIRConstant>(var.initializer);
		if (!c)
			return;

		auto &type = this->get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Struct)
		{
			uint32_t member_count = uint32_t(type.member_types.size());
			for (uint32_t i = 0; i < member_count; i++)
			{
				if (has_member_decoration(type.self, i, DecorationBuiltIn))
				{
					builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
					    c->subconstants[i];
				}
			}
		}
		else if (has_decoration(var.self, DecorationBuiltIn))
			builtin_to_initializer[get_decoration(var.self, DecorationBuiltIn)] = var.initializer;
	});

	// Emit global variables for the interface variables which are statically used by the shader.
	builtins.for_each_bit([&](uint32_t i) {
		const char *type = nullptr;
		auto builtin = static_cast<BuiltIn>(i);
		uint32_t array_size = 0;

		string init_expr;
		auto init_itr = builtin_to_initializer.find(builtin);
		if (init_itr != builtin_to_initializer.end())
			init_expr = join(" = ", to_expression(init_itr->second));

		switch (builtin)
		{
		case BuiltInFragCoord:
		case BuiltInPosition:
			type = "float4";
			break;

		case BuiltInFragDepth:
			type = "float";
			break;

		case BuiltInVertexId:
		case BuiltInVertexIndex:
		case BuiltInInstanceIndex:
			type = "int";
			if (hlsl_options.support_nonzero_base_vertex_base_instance)
				need_base_vertex_info = true;
			break;

		case BuiltInInstanceId:
		case BuiltInSampleId:
			type = "int";
			break;

		case BuiltInPointSize:
			if (hlsl_options.point_size_compat)
			{
				// Just emit the global variable; it will be ignored.
				type = "float";
				break;
			}
			else
				SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));

		case BuiltInGlobalInvocationId:
		case BuiltInLocalInvocationId:
		case BuiltInWorkgroupId:
			type = "uint3";
			break;

		case BuiltInLocalInvocationIndex:
			type = "uint";
			break;

		case BuiltInFrontFacing:
			type = "bool";
			break;

		case BuiltInNumWorkgroups:
		case BuiltInPointCoord:
			// Handled specially.
			break;

		case BuiltInSubgroupLocalInvocationId:
		case BuiltInSubgroupSize:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
			break;

		case BuiltInSubgroupEqMask:
		case BuiltInSubgroupLtMask:
		case BuiltInSubgroupLeMask:
		case BuiltInSubgroupGtMask:
		case BuiltInSubgroupGeMask:
			if (hlsl_options.shader_model < 60)
				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
			type = "uint4";
			break;

		case BuiltInClipDistance:
			array_size = clip_distance_count;
			type = "float";
			break;

		case BuiltInCullDistance:
			array_size = cull_distance_count;
			type = "float";
			break;

		case BuiltInSampleMask:
			type = "int";
			break;

		default:
			SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
		}

		StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;

		if (type)
		{
			if (array_size)
				statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
			else
				statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
		}

		// SampleMask can be both in and out with the sample builtin; in this case we have already
		// declared the input variable and we need to add the output one now.
		if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
		{
			statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
		}
	});

	if (need_base_vertex_info)
	{
		statement("cbuffer SPIRV_Cross_VertexInfo");
		begin_scope();
		statement("int SPIRV_Cross_BaseVertex;");
		statement("int SPIRV_Cross_BaseInstance;");
		end_scope_decl();
		statement("");
	}
}

void CompilerHLSL::emit_composite_constants()
{
	// HLSL cannot declare structs or arrays inline, so we must move them out to
	// global constants directly.
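	// e.g. an array constant is hoisted to something like
	// "static const float _21[2] = { 1.0f, 2.0f };" at global scope.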
	bool emitted = false;

	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		if (c.specialization)
			return;

		auto &type = this->get<SPIRType>(c.constant_type);

		// Cannot declare block type constants here.
		// We do not have the struct type yet.
		bool is_block = has_decoration(type.self, DecorationBlock);
		if (!is_block && (type.basetype == SPIRType::Struct || !type.array.empty()))
		{
			auto name = to_name(c.self);
			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
			emitted = true;
		}
	});

	if (emitted)
		statement("");
}

void CompilerHLSL::emit_specialization_constants_and_structs()
{
	bool emitted = false;
	SpecializationConstant wg_x, wg_y, wg_z;
	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);

	auto loop_lock = ir.create_loop_hard_lock();
	for (auto &id_ : ir.ids_for_constant_or_type)
	{
		auto &id = ir.ids[id_];

		if (id.get_type() == TypeConstant)
		{
			auto &c = id.get<SPIRConstant>();

			if (c.self == workgroup_size_id)
			{
				statement("static const uint3 gl_WorkGroupSize = ",
				          constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
				emitted = true;
			}
			else if (c.specialization)
			{
				auto &type = get<SPIRType>(c.constant_type);
				auto name = to_name(c.self);

				// HLSL does not support specialization constants, so fall back to macros.
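				// e.g. a constant with SpecId 10 and default value 1 is emitted as:
				//   #ifndef SPIRV_CROSS_CONSTANT_ID_10
				//   #define SPIRV_CROSS_CONSTANT_ID_10 1
				//   #endif
				//   static const int MyConstant = SPIRV_CROSS_CONSTANT_ID_10;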
				c.specialization_constant_macro_name =
				    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));

				statement("#ifndef ", c.specialization_constant_macro_name);
				statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
				statement("#endif");
				statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
				emitted = true;
			}
		}
		else if (id.get_type() == TypeConstantOp)
		{
			auto &c = id.get<SPIRConstantOp>();
			auto &type = get<SPIRType>(c.basetype);
			auto name = to_name(c.self);
			statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
			emitted = true;
		}
		else if (id.get_type() == TypeType)
		{
			auto &type = id.get<SPIRType>();
			if (type.basetype == SPIRType::Struct && type.array.empty() && !type.pointer &&
			    (!ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) &&
			     !ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock)))
			{
				if (emitted)
					statement("");
				emitted = false;

				emit_struct(type);
			}
		}
	}

	if (emitted)
		statement("");
}

1256
1257 void CompilerHLSL::replace_illegal_names()
1258 {
1259         static const unordered_set<string> keywords = {
1260                 // Additional HLSL specific keywords.
1261                 "line", "linear", "matrix", "point", "row_major", "sampler", "vector"
1262         };
1263
1264         CompilerGLSL::replace_illegal_names(keywords);
1265         CompilerGLSL::replace_illegal_names();
1266 }
1267
void CompilerHLSL::declare_undefined_values()
{
	bool emitted = false;
	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
		auto &type = this->get<SPIRType>(undef.basetype);
		// OpUndef can be void for some reason ...
		if (type.basetype == SPIRType::Void)
			return;

		string initializer;
		if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
			initializer = join(" = ", to_zero_initialized_expression(undef.basetype));

		statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
		emitted = true;
	});

	if (emitted)
		statement("");
}

1289 void CompilerHLSL::emit_resources()
1290 {
1291         auto &execution = get_entry_point();
1292
1293         replace_illegal_names();
1294
1295         emit_specialization_constants_and_structs();
1296         emit_composite_constants();
1297
1298         bool emitted = false;
1299
1300         // Output UBOs and SSBOs
1301         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1302                 auto &type = this->get<SPIRType>(var.basetype);
1303
1304                 bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1305                 bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
1306                                        ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
1307
1308                 if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1309                     has_block_flags)
1310                 {
1311                         emit_buffer_block(var);
1312                         emitted = true;
1313                 }
1314         });
1315
1316         // Output push constant blocks
1317         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1318                 auto &type = this->get<SPIRType>(var.basetype);
1319                 if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1320                     !is_hidden_variable(var))
1321                 {
1322                         emit_push_constant_block(var);
1323                         emitted = true;
1324                 }
1325         });
1326
1327         if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30)
1328         {
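                // D3D9 addresses pixel centers at integer coordinates; the emitted vertex
                // entry point is expected to subtract this half-pixel offset from gl_Position.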
1329                 statement("uniform float4 gl_HalfPixel;");
1330                 emitted = true;
1331         }
1332
1333         bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1334
1335         // Output Uniform Constants (values, samplers, images, etc).
1336         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1337                 auto &type = this->get<SPIRType>(var.basetype);
1338
1339                 // If we're remapping separate samplers and images, only emit the combined samplers.
1340                 if (skip_separate_image_sampler)
1341                 {
1342                         // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1343                         bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1344                         bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1345                         bool separate_sampler = type.basetype == SPIRType::Sampler;
1346                         if (!sampler_buffer && (separate_image || separate_sampler))
1347                                 return;
1348                 }
1349
1350                 if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1351                     type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1352                     !is_hidden_variable(var))
1353                 {
1354                         emit_uniform(var);
1355                         emitted = true;
1356                 }
1357         });
1358
1359         if (emitted)
1360                 statement("");
1361         emitted = false;
1362
1363         // Emit builtin input and output variables here.
1364         emit_builtin_variables();
1365
1366         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1367                 auto &type = this->get<SPIRType>(var.basetype);
1368                 bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
1369
1370                 // Do not emit I/O blocks here.
1371                 // I/O blocks can be arrayed, so we must deal with them separately to support geometry shaders
1372                 // and tessellation down the line.
1373                 if (!block && var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1374                     (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1375                     interface_variable_exists_in_entry_point(var.self))
1376                 {
1377                         // Only emit non-builtins which are not blocks here. Builtin variables are handled separately.
1378                         emit_interface_block_globally(var);
1379                         emitted = true;
1380                 }
1381         });
1382
1383         if (emitted)
1384                 statement("");
1385         emitted = false;
1386
1387         require_input = false;
1388         require_output = false;
1389         unordered_set<uint32_t> active_inputs;
1390         unordered_set<uint32_t> active_outputs;
1391         SmallVector<SPIRVariable *> input_variables;
1392         SmallVector<SPIRVariable *> output_variables;
1393         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1394                 auto &type = this->get<SPIRType>(var.basetype);
1395                 bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
1396
1397                 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1398                         return;
1399
1400                 // Do not emit I/O blocks here.
1401                 // I/O blocks can be arrayed, so we must deal with them separately to support geometry shaders
1402                 // and tessellation down the line.
1403                 if (!block && !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1404                     interface_variable_exists_in_entry_point(var.self))
1405                 {
1406                         if (var.storage == StorageClassInput)
1407                                 input_variables.push_back(&var);
1408                         else
1409                                 output_variables.push_back(&var);
1410                 }
1411
1412                 // Reserve input and output locations for block variables as necessary.
1413                 if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
1414                 {
1415                         auto &active = var.storage == StorageClassInput ? active_inputs : active_outputs;
1416                         for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1417                         {
1418                                 if (has_member_decoration(type.self, i, DecorationLocation))
1419                                 {
1420                                         uint32_t location = get_member_decoration(type.self, i, DecorationLocation);
1421                                         active.insert(location);
1422                                 }
1423                         }
1424
1425                         // Emit the block struct and a global variable here.
1426                         emit_io_block(var);
1427                 }
1428         });
1429
1430         const auto variable_compare = [&](const SPIRVariable *a, const SPIRVariable *b) -> bool {
1431                 // Sort input and output variables by the following criteria, from most robust to least robust:
1432                 // - Location
1433                 // - Variable has a location
1434                 // - Name comparison
1435                 // - Variable has a name
1436                 // - Fallback: ID
1437                 bool has_location_a = has_decoration(a->self, DecorationLocation);
1438                 bool has_location_b = has_decoration(b->self, DecorationLocation);
1439
1440                 if (has_location_a && has_location_b)
1441                 {
1442                         return get_decoration(a->self, DecorationLocation) < get_decoration(b->self, DecorationLocation);
1443                 }
1444                 else if (has_location_a && !has_location_b)
1445                         return true;
1446                 else if (!has_location_a && has_location_b)
1447                         return false;
1448
1449                 const auto &name1 = to_name(a->self);
1450                 const auto &name2 = to_name(b->self);
1451
1452                 if (name1.empty() && name2.empty())
1453                         return a->self < b->self;
1454                 else if (name1.empty())
1455                         return true;
1456                 else if (name2.empty())
1457                         return false;
1458
1459                 return name1.compare(name2) < 0;
1460         };
1461
1462         auto input_builtins = active_input_builtins;
1463         input_builtins.clear(BuiltInNumWorkgroups);
1464         input_builtins.clear(BuiltInPointCoord);
1465         input_builtins.clear(BuiltInSubgroupSize);
1466         input_builtins.clear(BuiltInSubgroupLocalInvocationId);
1467         input_builtins.clear(BuiltInSubgroupEqMask);
1468         input_builtins.clear(BuiltInSubgroupLtMask);
1469         input_builtins.clear(BuiltInSubgroupLeMask);
1470         input_builtins.clear(BuiltInSubgroupGtMask);
1471         input_builtins.clear(BuiltInSubgroupGeMask);
1472
1473         if (!input_variables.empty() || !input_builtins.empty())
1474         {
1475                 require_input = true;
1476                 statement("struct SPIRV_Cross_Input");
1477
1478                 begin_scope();
1479                 sort(input_variables.begin(), input_variables.end(), variable_compare);
1480                 for (auto var : input_variables)
1481                         emit_interface_block_in_struct(*var, active_inputs);
1482                 emit_builtin_inputs_in_struct();
1483                 end_scope_decl();
1484                 statement("");
1485         }
1486
1487         if (!output_variables.empty() || !active_output_builtins.empty())
1488         {
1489                 require_output = true;
1490                 statement("struct SPIRV_Cross_Output");
1491
1492                 begin_scope();
1493                 // FIXME: Use locations properly if they exist.
1494                 sort(output_variables.begin(), output_variables.end(), variable_compare);
1495                 for (auto var : output_variables)
1496                         emit_interface_block_in_struct(*var, active_outputs);
1497                 emit_builtin_outputs_in_struct();
1498                 end_scope_decl();
1499                 statement("");
1500         }
1501
1502         // Global variables.
1503         for (auto global : global_variables)
1504         {
1505                 auto &var = get<SPIRVariable>(global);
1506                 if (is_hidden_variable(var, true))
1507                         continue;
1508
1509                 if (var.storage != StorageClassOutput)
1510                 {
1511                         if (!variable_is_lut(var))
1512                         {
1513                                 add_resource_name(var.self);
1514
1515                                 const char *storage = nullptr;
1516                                 switch (var.storage)
1517                                 {
1518                                 case StorageClassWorkgroup:
1519                                         storage = "groupshared";
1520                                         break;
1521
1522                                 default:
1523                                         storage = "static";
1524                                         break;
1525                                 }
1526
1527                                 string initializer;
1528                                 if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1529                                     !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var)))
1530                                 {
1531                                         initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var)));
1532                                 }
1533                                 statement(storage, " ", variable_decl(var), initializer, ";");
1534
1535                                 emitted = true;
1536                         }
1537                 }
1538         }
1539
1540         if (emitted)
1541                 statement("");
1542
1543         declare_undefined_values();
1544
1545         if (requires_op_fmod)
1546         {
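                // GLSL mod() is floored (result takes the sign of y), while HLSL's fmod() is
                // truncated (sign of x), so emit the floored form by hand.
                // E.g. mod(-1.0, 3.0) == -1.0 - 3.0 * floor(-1.0 / 3.0) == 2.0, whereas fmod(-1.0, 3.0) == -1.0.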
1547                 static const char *types[] = {
1548                         "float",
1549                         "float2",
1550                         "float3",
1551                         "float4",
1552                 };
1553
1554                 for (auto &type : types)
1555                 {
1556                         statement(type, " mod(", type, " x, ", type, " y)");
1557                         begin_scope();
1558                         statement("return x - y * floor(x / y);");
1559                         end_scope();
1560                         statement("");
1561                 }
1562         }
1563
1564         emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
1565         for (uint32_t norm = 0; norm < 3; norm++)
1566         {
1567                 for (uint32_t comp = 0; comp < 4; comp++)
1568                 {
1569                         static const char *qualifiers[] = { "", "unorm ", "snorm " };
1570                         static const char *vecsizes[] = { "", "2", "3", "4" };
1571                         emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true,
1572                                                    qualifiers[norm]);
1573                 }
1574         }
1575
1576         if (requires_fp16_packing)
1577         {
1578                 // HLSL does not pack into a single word sadly :( f32tof16() returns one uint per component, so combine the halves manually.
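                // For reference: spvPackHalf2x16(float2(1.0, 2.0)) == 0x40003C00u,
                // since f32tof16(1.0) == 0x3C00 and f32tof16(2.0) == 0x4000.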
1579                 statement("uint spvPackHalf2x16(float2 value)");
1580                 begin_scope();
1581                 statement("uint2 Packed = f32tof16(value);");
1582                 statement("return Packed.x | (Packed.y << 16);");
1583                 end_scope();
1584                 statement("");
1585
1586                 statement("float2 spvUnpackHalf2x16(uint value)");
1587                 begin_scope();
1588                 statement("return f16tof32(uint2(value & 0xffff, value >> 16));");
1589                 end_scope();
1590                 statement("");
1591         }
1592
1593         if (requires_uint2_packing)
1594         {
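                // E.g. spvPackUint2x32(uint2(1u, 2u)) == 0x0000000200000001.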
1595                 statement("uint64_t spvPackUint2x32(uint2 value)");
1596                 begin_scope();
1597                 statement("return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1598                 end_scope();
1599                 statement("");
1600
1601                 statement("uint2 spvUnpackUint2x32(uint64_t value)");
1602                 begin_scope();
1603                 statement("uint2 Unpacked;");
1604                 statement("Unpacked.x = uint(value & 0xffffffff);");
1605                 statement("Unpacked.y = uint(value >> 32);");
1606                 statement("return Unpacked;");
1607                 end_scope();
1608                 statement("");
1609         }
1610
1611         if (requires_explicit_fp16_packing)
1612         {
1613                 // HLSL does not pack into a single word sadly :( f32tof16() returns one uint per component, so combine the halves manually.
1614                 statement("uint spvPackFloat2x16(min16float2 value)");
1615                 begin_scope();
1616                 statement("uint2 Packed = f32tof16(value);");
1617                 statement("return Packed.x | (Packed.y << 16);");
1618                 end_scope();
1619                 statement("");
1620
1621                 statement("min16float2 spvUnpackFloat2x16(uint value)");
1622                 begin_scope();
1623                 statement("return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1624                 end_scope();
1625                 statement("");
1626         }
1627
1628         // HLSL does not seem to have builtins for these operations, so roll them by hand ...
1629         if (requires_unorm8_packing)
1630         {
1631                 statement("uint spvPackUnorm4x8(float4 value)");
1632                 begin_scope();
1633                 statement("uint4 Packed = uint4(round(saturate(value) * 255.0));");
1634                 statement("return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1635                 end_scope();
1636                 statement("");
1637
1638                 statement("float4 spvUnpackUnorm4x8(uint value)");
1639                 begin_scope();
1640                 statement("uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1641                 statement("return float4(Packed) / 255.0;");
1642                 end_scope();
1643                 statement("");
1644         }
1645
1646         if (requires_snorm8_packing)
1647         {
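                // E.g. spvPackSnorm4x8(float4(1.0, -1.0, 0.0, 0.0)) == 0x0000817fu:
                // 1.0 maps to 127 (0x7f) and -1.0 maps to -127, which masks to 0x81.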
1648                 statement("uint spvPackSnorm4x8(float4 value)");
1649                 begin_scope();
1650                 statement("int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1651                 statement("return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1652                 end_scope();
1653                 statement("");
1654
1655                 statement("float4 spvUnpackSnorm4x8(uint value)");
1656                 begin_scope();
1657                 statement("int SignedValue = int(value);");
1658                 statement("int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
1659                 statement("return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
1660                 end_scope();
1661                 statement("");
1662         }
1663
1664         if (requires_unorm16_packing)
1665         {
1666                 statement("uint spvPackUnorm2x16(float2 value)");
1667                 begin_scope();
1668                 statement("uint2 Packed = uint2(round(saturate(value) * 65535.0));");
1669                 statement("return Packed.x | (Packed.y << 16);");
1670                 end_scope();
1671                 statement("");
1672
1673                 statement("float2 spvUnpackUnorm2x16(uint value)");
1674                 begin_scope();
1675                 statement("uint2 Packed = uint2(value & 0xffff, value >> 16);");
1676                 statement("return float2(Packed) / 65535.0;");
1677                 end_scope();
1678                 statement("");
1679         }
1680
1681         if (requires_snorm16_packing)
1682         {
1683                 statement("uint spvPackSnorm2x16(float2 value)");
1684                 begin_scope();
1685                 statement("int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
1686                 statement("return uint(Packed.x | (Packed.y << 16));");
1687                 end_scope();
1688                 statement("");
1689
1690                 statement("float2 spvUnpackSnorm2x16(uint value)");
1691                 begin_scope();
1692                 statement("int SignedValue = int(value);");
1693                 statement("int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
1694                 statement("return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
1695                 end_scope();
1696                 statement("");
1697         }
1698
1699         if (requires_bitfield_insert)
1700         {
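                // Emulates SPIR-V OpBitFieldInsert. E.g. Offset = 4, Count = 8 yields
                // Mask = 0xff0, so bits [4, 12) of Base are replaced by the low bits of Insert.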
1701                 static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
1702                 for (auto &type : types)
1703                 {
1704                         statement(type, " spvBitfieldInsert(", type, " Base, ", type, " Insert, uint Offset, uint Count)");
1705                         begin_scope();
1706                         statement("uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
1707                         statement("return (Base & ~Mask) | ((Insert << Offset) & Mask);");
1708                         end_scope();
1709                         statement("");
1710                 }
1711         }
1712
1713         if (requires_bitfield_extract)
1714         {
1715                 static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
1716                 for (auto &type : unsigned_types)
1717                 {
1718                         statement(type, " spvBitfieldUExtract(", type, " Base, uint Offset, uint Count)");
1719                         begin_scope();
1720                         statement("uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
1721                         statement("return (Base >> Offset) & Mask;");
1722                         end_scope();
1723                         statement("");
1724                 }
1725
1726                 // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
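                // E.g. with Count = 8, ExtendShift == 24: shifting up moves bit 7 into the
                // sign bit, and the arithmetic shift back down replicates it through the upper 24 bits.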
1727                 static const char *signed_types[] = { "int", "int2", "int3", "int4" };
1728                 for (auto &type : signed_types)
1729                 {
1730                         statement(type, " spvBitfieldSExtract(", type, " Base, int Offset, int Count)");
1731                         begin_scope();
1732                         statement("int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
1733                         statement(type, " Masked = (Base >> Offset) & Mask;");
1734                         statement("int ExtendShift = (32 - Count) & 31;");
1735                         statement("return (Masked << ExtendShift) >> ExtendShift;");
1736                         end_scope();
1737                         statement("");
1738                 }
1739         }
1740
1741         if (requires_inverse_2x2)
1742         {
1743                 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1744                 statement("// adjoint and dividing by the determinant. The input matrix is left unmodified.");
1745                 statement("float2x2 spvInverse(float2x2 m)");
1746                 begin_scope();
1747                 statement("float2x2 adj;        // The adjoint matrix (inverse after dividing by determinant)");
1748                 statement_no_indent("");
1749                 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1750                 statement("adj[0][0] =  m[1][1];");
1751                 statement("adj[0][1] = -m[0][1];");
1752                 statement_no_indent("");
1753                 statement("adj[1][0] = -m[1][0];");
1754                 statement("adj[1][1] =  m[0][0];");
1755                 statement_no_indent("");
1756                 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1757                 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1758                 statement_no_indent("");
1759                 statement("// Divide the classical adjoint matrix by the determinant.");
1760                 statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1761                 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1762                 end_scope();
1763                 statement("");
1764         }
1765
1766         if (requires_inverse_3x3)
1767         {
1768                 statement("// Returns the determinant of a 2x2 matrix.");
1769                 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1770                 begin_scope();
1771                 statement("return a1 * b2 - b1 * a2;");
1772                 end_scope();
1773                 statement_no_indent("");
1774                 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1775                 statement("// adjoint and dividing by the determinant. The input matrix is left unmodified.");
1776                 statement("float3x3 spvInverse(float3x3 m)");
1777                 begin_scope();
1778                 statement("float3x3 adj;        // The adjoint matrix (inverse after dividing by determinant)");
1779                 statement_no_indent("");
1780                 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1781                 statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1782                 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1783                 statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1784                 statement_no_indent("");
1785                 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1786                 statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1787                 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1788                 statement_no_indent("");
1789                 statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1790                 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1791                 statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1792                 statement_no_indent("");
1793                 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1794                 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1795                 statement_no_indent("");
1796                 statement("// Divide the classical adjoint matrix by the determinant.");
1797                 statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1798                 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1799                 end_scope();
1800                 statement("");
1801         }
1802
1803         if (requires_inverse_4x4)
1804         {
1805                 if (!requires_inverse_3x3)
1806                 {
1807                         statement("// Returns the determinant of a 2x2 matrix.");
1808                         statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1809                         begin_scope();
1810                         statement("return a1 * b2 - b1 * a2;");
1811                         end_scope();
1812                         statement("");
1813                 }
1814
1815                 statement("// Returns the determinant of a 3x3 matrix.");
1816                 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1817                           "float c2, float c3)");
1818                 begin_scope();
1819                 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1820                           "spvDet2x2(a2, a3, "
1821                           "b2, b3);");
1822                 end_scope();
1823                 statement_no_indent("");
1824                 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1825                 statement("// adjoint and dividing by the determinant. The input matrix is left unmodified.");
1826                 statement("float4x4 spvInverse(float4x4 m)");
1827                 begin_scope();
1828                 statement("float4x4 adj;        // The adjoint matrix (inverse after dividing by determinant)");
1829                 statement_no_indent("");
1830                 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1831                 statement(
1832                     "adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1833                     "m[3][3]);");
1834                 statement(
1835                     "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1836                     "m[3][3]);");
1837                 statement(
1838                     "adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1839                     "m[3][3]);");
1840                 statement(
1841                     "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1842                     "m[2][3]);");
1843                 statement_no_indent("");
1844                 statement(
1845                     "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1846                     "m[3][3]);");
1847                 statement(
1848                     "adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1849                     "m[3][3]);");
1850                 statement(
1851                     "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1852                     "m[3][3]);");
1853                 statement(
1854                     "adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1855                     "m[2][3]);");
1856                 statement_no_indent("");
1857                 statement(
1858                     "adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1859                     "m[3][3]);");
1860                 statement(
1861                     "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1862                     "m[3][3]);");
1863                 statement(
1864                     "adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1865                     "m[3][3]);");
1866                 statement(
1867                     "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1868                     "m[2][3]);");
1869                 statement_no_indent("");
1870                 statement(
1871                     "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1872                     "m[3][2]);");
1873                 statement(
1874                     "adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1875                     "m[3][2]);");
1876                 statement(
1877                     "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1878                     "m[3][2]);");
1879                 statement(
1880                     "adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1881                     "m[2][2]);");
1882                 statement_no_indent("");
1883                 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1884                 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1885                           "* m[3][0]);");
1886                 statement_no_indent("");
1887                 statement("// Divide the classical adjoint matrix by the determinant.");
1888                 statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
1889                 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1890                 end_scope();
1891                 statement("");
1892         }
1893
1894         if (requires_scalar_reflect)
1895         {
1896                 // FP16/FP64? No templates in HLSL.
1897                 statement("float spvReflect(float i, float n)");
1898                 begin_scope();
1899                 statement("return i - 2.0 * dot(n, i) * n;");
1900                 end_scope();
1901                 statement("");
1902         }
1903
1904         if (requires_scalar_refract)
1905         {
1906                 // FP16/FP64? No templates in HLSL.
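                // Scalar specialization of GLSL refract(): k < 0.0 means total internal
                // reflection, for which GLSL defines the result to be 0.0.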
1907                 statement("float spvRefract(float i, float n, float eta)");
1908                 begin_scope();
1909                 statement("float NoI = n * i;");
1910                 statement("float NoI2 = NoI * NoI;");
1911                 statement("float k = 1.0 - eta * eta * (1.0 - NoI2);");
1912                 statement("if (k < 0.0)");
1913                 begin_scope();
1914                 statement("return 0.0;");
1915                 end_scope();
1916                 statement("else");
1917                 begin_scope();
1918                 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
1919                 end_scope();
1920                 end_scope();
1921                 statement("");
1922         }
1923
1924         if (requires_scalar_faceforward)
1925         {
1926                 // FP16/FP64? No templates in HLSL.
1927                 statement("float spvFaceForward(float n, float i, float nref)");
1928                 begin_scope();
1929                 statement("return i * nref < 0.0 ? n : -n;");
1930                 end_scope();
1931                 statement("");
1932         }
1933 }
1934
1935 void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
1936                                               const char *type_qualifier)
1937 {
1938         if (variant_mask == 0)
1939                 return;
1940
1941         static const char *types[QueryTypeCount] = { "float", "int", "uint" };
1942         static const char *dims[QueryDimCount] = { "Texture1D",   "Texture1DArray",  "Texture2D",   "Texture2DArray",
1943                                                        "Texture3D",   "Buffer",          "TextureCube", "TextureCubeArray",
1944                                                        "Texture2DMS", "Texture2DMSArray" };
1945
1946         static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
1947
1948         static const char *ret_types[QueryDimCount] = {
1949                 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
1950         };
1951
1952         static const uint32_t return_arguments[QueryDimCount] = {
1953                 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
1954         };
1955
1956         for (uint32_t index = 0; index < QueryDimCount; index++)
1957         {
1958                 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
1959                 {
1960                         uint32_t bit = 16 * type_index + index;
1961                         uint64_t mask = 1ull << bit;
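                        // Each sampled type (float/int/uint) owns a 16-bit slice of the variant mask;
                        // e.g. a uint Texture2D query (type_index 2, index 2) maps to bit 34.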
1962
1963                         if ((variant_mask & mask) == 0)
1964                                 continue;
1965
1966                         statement(ret_types[index], " spv", (uav ? "Image" : "Texture"), "Size(", (uav ? "RW" : ""),
1967                                   dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier, "> Tex, ",
1968                                   (uav ? "" : "uint Level, "), "out uint Param)");
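                        // For the SRV case with an empty type qualifier and vecsize "4", this emits e.g.:
                        //   uint2 spvTextureSize(Texture2D<float4> Tex, uint Level, out uint Param)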
1969                         begin_scope();
1970                         statement(ret_types[index], " ret;");
1971                         switch (return_arguments[index])
1972                         {
1973                         case 1:
1974                                 if (has_lod[index] && !uav)
1975                                         statement("Tex.GetDimensions(Level, ret.x, Param);");
1976                                 else
1977                                 {
1978                                         statement("Tex.GetDimensions(ret.x);");
1979                                         statement("Param = 0u;");
1980                                 }
1981                                 break;
1982                         case 2:
1983                                 if (has_lod[index] && !uav)
1984                                         statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
1985                                 else if (!uav)
1986                                         statement("Tex.GetDimensions(ret.x, ret.y, Param);");
1987                                 else
1988                                 {
1989                                         statement("Tex.GetDimensions(ret.x, ret.y);");
1990                                         statement("Param = 0u;");
1991                                 }
1992                                 break;
1993                         case 3:
1994                                 if (has_lod[index] && !uav)
1995                                         statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
1996                                 else if (!uav)
1997                                         statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
1998                                 else
1999                                 {
2000                                         statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
2001                                         statement("Param = 0u;");
2002                                 }
2003                                 break;
2004                         }
2005
2006                         statement("return ret;");
2007                         end_scope();
2008                         statement("");
2009                 }
2010         }
2011 }
2012
2013 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2014 {
2015         auto &flags = get_member_decoration_bitset(type.self, index);
2016
2017         // HLSL can emit row_major or column_major decoration in any struct.
2018         // Do not try to merge combined decorations for children like in GLSL.
2019
2020         // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2021         // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
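        // E.g. a matrix declared ColMajor in SPIR-V is therefore emitted as "row_major" here.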
2022         if (flags.get(DecorationColMajor))
2023                 return "row_major ";
2024         else if (flags.get(DecorationRowMajor))
2025                 return "column_major ";
2026
2027         return "";
2028 }
2029
2030 void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2031                                       const string &qualifier, uint32_t base_offset)
2032 {
2033         auto &membertype = get<SPIRType>(member_type_id);
2034
2035         Bitset memberflags;
2036         auto &memb = ir.meta[type.self].members;
2037         if (index < memb.size())
2038                 memberflags = memb[index].decoration_flags;
2039
2040         string qualifiers;
2041         bool is_block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
2042                         ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
2043
2044         if (is_block)
2045                 qualifiers = to_interpolation_qualifiers(memberflags);
2046
2047         string packing_offset;
2048         bool is_push_constant = type.storage == StorageClassPushConstant;
2049
2050         if ((has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2051             has_member_decoration(type.self, index, DecorationOffset))
2052         {
2053                 uint32_t offset = memb[index].offset - base_offset;
2054                 if (offset & 3)
2055                         SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2056
2057                 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2058                 packing_offset = join(" : packoffset(c", offset / 16, packing_swizzle[(offset & 15) >> 2], ")");
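                // E.g. a member at byte offset 20 packs as " : packoffset(c1.y)":
                // 20 / 16 == 1 selects register c1, and (20 & 15) >> 2 == 1 selects component .y.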
2059         }
2060
2061         statement(layout_for_member(type, index), qualifiers, qualifier,
2062                   variable_decl(membertype, to_member_name(type, index)), packing_offset, ";");
2063 }
2064
2065 void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2066 {
2067         auto &type = get<SPIRType>(var.basetype);
2068
2069         bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock);
2070
2071         if (is_uav)
2072         {
2073                 Bitset flags = ir.get_buffer_block_flags(var);
2074                 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
2075                 bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
2076                 bool is_interlocked = interlocked_resources.count(var.self) > 0;
2077                 const char *type_name = "ByteAddressBuffer ";
2078                 if (!is_readonly)
2079                         type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2080                 add_resource_name(var.self);
2081                 statement(is_coherent ? "globallycoherent " : "", type_name, to_name(var.self), type_to_array_glsl(type),
2082                           to_resource_binding(var), ";");
2083         }
2084         else
2085         {
2086                 if (type.array.empty())
2087                 {
2088                         // Flatten the top-level struct so we can use packoffset;
2089                         // this restriction is similar to GLSL, where layout(offset) is not possible on sub-structs.
2090                         flattened_structs[var.self] = false;
2091
2092                         // Prefer the block name if possible.
2093                         auto buffer_name = to_name(type.self, false);
2094                         if (ir.meta[type.self].decoration.alias.empty() ||
2095                             resource_names.find(buffer_name) != end(resource_names) ||
2096                             block_names.find(buffer_name) != end(block_names))
2097                         {
2098                                 buffer_name = get_block_fallback_name(var.self);
2099                         }
2100
2101                         add_variable(block_names, resource_names, buffer_name);
2102
2103                         // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2104                         // This cannot conflict with anything else, so we're safe now.
2105                         if (buffer_name.empty())
2106                                 buffer_name = join("_", get<SPIRType>(var.basetype).self, "_", var.self);
2107
2108                         uint32_t failed_index = 0;
2109                         if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index))
2110                                 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2111                         else
2112                         {
2113                                 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2114                                                        failed_index, " (name: ", to_member_name(type, failed_index),
2115                                                        ") cannot be expressed with either HLSL packing layout or packoffset."));
2116                         }
2117
2118                         block_names.insert(buffer_name);
2119
2120                         // Save for post-reflection later.
2121                         declared_block_names[var.self] = buffer_name;
2122
2123                         type.member_name_cache.clear();
2124                         // var.self can be used as a backup name for the block name,
2125                         // so we need to make sure we don't disturb the name here on a recompile.
2126                         // It will need to be reset if we have to recompile.
2127                         preserve_alias_on_reset(var.self);
2128                         add_resource_name(var.self);
2129                         statement("cbuffer ", buffer_name, to_resource_binding(var));
2130                         begin_scope();
2131
2132                         uint32_t i = 0;
2133                         for (auto &member : type.member_types)
2134                         {
2135                                 add_member_name(type, i);
2136                                 auto backup_name = get_member_name(type.self, i);
2137                                 auto member_name = to_member_name(type, i);
2138                                 member_name = join(to_name(var.self), "_", member_name);
2139                                 ParsedIR::sanitize_underscores(member_name);
2140                                 set_member_name(type.self, i, member_name);
2141                                 emit_struct_member(type, member, i, "");
2142                                 set_member_name(type.self, i, backup_name);
2143                                 i++;
2144                         }
2145
2146                         end_scope_decl();
2147                         statement("");
2148                 }
2149                 else
2150                 {
2151                         if (hlsl_options.shader_model < 51)
2152                                 SPIRV_CROSS_THROW(
2153                                     "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2154
2155                         add_resource_name(type.self);
2156                         add_resource_name(var.self);
2157
2158                         // ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2159                         uint32_t failed_index = 0;
2160                         if (!buffer_is_packing_standard(type, BufferPackingHLSLCbuffer, &failed_index))
2161                         {
2162                                 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2163                                                        "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2164                                                        ") cannot be expressed with normal HLSL packing rules."));
2165                         }
2166
2167                         emit_struct(get<SPIRType>(type.self));
2168                         statement("ConstantBuffer<", to_name(type.self), "> ", to_name(var.self), type_to_array_glsl(type),
2169                                   to_resource_binding(var), ";");
2170                 }
2171         }
2172 }
2173
2174 void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2175 {
2176         if (root_constants_layout.empty())
2177         {
2178                 emit_buffer_block(var);
2179         }
2180         else
2181         {
2182                 for (const auto &layout : root_constants_layout)
2183                 {
2184                         auto &type = get<SPIRType>(var.basetype);
2185
2186                         uint32_t failed_index = 0;
2187                         if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index, layout.start,
2188                                                        layout.end))
2189                                 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2190                         else
2191                         {
2192                                 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2193                                                        ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2194                                                        ") cannot be expressed with either HLSL packing layout or packoffset."));
2195                         }
2196
2197                         flattened_structs[var.self] = false;
2198                         type.member_name_cache.clear();
2199                         add_resource_name(var.self);
2200                         auto &memb = ir.meta[type.self].members;
2201
2202                         statement("cbuffer SPIRV_CROSS_RootConstant_", to_name(var.self),
2203                                   to_resource_register(HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, 'b', layout.binding, layout.space));
2204                         begin_scope();
2205
2206                         // Index of the next field in the generated root-constant cbuffer.
2207                         auto constant_index = 0u;
2208
2209                         // Iterate over all members of the push constant block and check which fields
2210                         // fit into the given root constant layout.
2211                         for (auto i = 0u; i < memb.size(); i++)
2212                         {
2213                                 const auto offset = memb[i].offset;
2214                                 if (layout.start <= offset && offset < layout.end)
2215                                 {
2216                                         const auto &member = type.member_types[i];
2217
2218                                         add_member_name(type, constant_index);
2219                                         auto backup_name = get_member_name(type.self, i);
2220                                         auto member_name = to_member_name(type, i);
2221                                         member_name = join(to_name(var.self), "_", member_name);
2222                                         ParsedIR::sanitize_underscores(member_name);
2223                                         set_member_name(type.self, constant_index, member_name);
2224                                         emit_struct_member(type, member, i, "", layout.start);
2225                                         set_member_name(type.self, constant_index, backup_name);
2226
2227                                         constant_index++;
2228                                 }
2229                         }
2230
2231                         end_scope_decl();
2232                 }
2233         }
2234 }
2235
2236 string CompilerHLSL::to_sampler_expression(uint32_t id)
2237 {
2238         auto expr = join("_", to_non_uniform_aware_expression(id));
2239         auto index = expr.find_first_of('[');
2240         if (index == string::npos)
2241         {
2242                 return expr + "_sampler";
2243         }
2244         else
2245         {
2246                 // We have an expression like _ident[array], so we cannot just tack on _sampler; insert it before the subscript instead.
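                // E.g. "_texArray[3]" becomes "_texArray_sampler[3]".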
2247                 return expr.insert(index, "_sampler");
2248         }
2249 }
2250
2251 void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2252 {
2253         if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2254         {
2255                 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
2256         }
2257         else
2258         {
2259                 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2260                 emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true);
2261         }
2262 }
2263
2264 string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2265 {
2266         string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2267
2268         if (hlsl_options.shader_model <= 30)
2269                 return arg_str;
2270
2271         // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2272         auto &type = expression_type(id);
2273
2274         // We don't have to consider combined image samplers here via OpSampledImage because
2275         // those variables cannot be passed as arguments to functions.
2276         // Only global SampledImage variables may be used as arguments.
2277         if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2278                 arg_str += ", " + to_sampler_expression(id);
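        // E.g. an argument expression "myTex" expands to "myTex, _myTex_sampler" (hypothetical name).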
2279
2280         return arg_str;
2281 }
2282
2283 void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2284 {
2285         if (func.self != ir.default_entry_point)
2286                 add_function_overload(func);
2287
2288         auto &execution = get_entry_point();
2289         // Avoid shadow declarations.
2290         local_variable_names = resource_names;
2291
2292         string decl;
2293
2294         auto &type = get<SPIRType>(func.return_type);
2295         if (type.array.empty())
2296         {
2297                 decl += flags_to_qualifiers_glsl(type, return_flags);
2298                 decl += type_to_glsl(type);
2299                 decl += " ";
2300         }
2301         else
2302         {
2303                 // We cannot return arrays in HLSL, so "return" through an out variable.
2304                 decl = "void ";
2305         }
2306
2307         if (func.self == ir.default_entry_point)
2308         {
2309                 if (execution.model == ExecutionModelVertex)
2310                         decl += "vert_main";
2311                 else if (execution.model == ExecutionModelFragment)
2312                         decl += "frag_main";
2313                 else if (execution.model == ExecutionModelGLCompute)
2314                         decl += "comp_main";
2315                 else
2316                         SPIRV_CROSS_THROW("Unsupported execution model.");
2317                 processing_entry_point = true;
2318         }
2319         else
2320                 decl += to_name(func.self);
2321
2322         decl += "(";
2323         SmallVector<string> arglist;
2324
2325         if (!type.array.empty())
2326         {
2327                 // Fake array returns by writing to an out array instead.
2328                 string out_argument;
2329                 out_argument += "out ";
2330                 out_argument += type_to_glsl(type);
2331                 out_argument += " ";
2332                 out_argument += "spvReturnValue";
2333                 out_argument += type_to_array_glsl(type);
2334                 arglist.push_back(std::move(out_argument));
2335         }
2336
2337         for (auto &arg : func.arguments)
2338         {
2339                 // Do not pass in separate images or samplers if we're remapping
2340                 // to combined image samplers.
2341                 if (skip_argument(arg.id))
2342                         continue;
2343
2344                 // Might change the variable name if it already exists in this function.
2345                 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2346                 // to use the same name for different variables.
2347                 // Since we want to make the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2348                 add_local_variable_name(arg.id);
2349
2350                 arglist.push_back(argument_decl(arg));
2351
2352                 // Flatten a combined sampler to two separate arguments in modern HLSL.
2353                 auto &arg_type = get<SPIRType>(arg.type);
2354                 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2355                     arg_type.image.dim != DimBuffer)
2356                 {
2357                         // Manufacture automatic sampler arg for SampledImage texture
2358                         arglist.push_back(join(image_is_comparison(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
2359                                                to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
2360                 }
2361
2362                 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2363                 auto *var = maybe_get<SPIRVariable>(arg.id);
2364                 if (var)
2365                         var->parameter = &arg;
2366         }
2367
2368         for (auto &arg : func.shadow_arguments)
2369         {
2370                 // Might change the variable name if it already exists in this function.
2371                 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
2372                 // to use the same name for multiple variables.
2373                 // Since we want to make the output debuggable and somewhat sane, use fallback names for duplicates.
2374                 add_local_variable_name(arg.id);
2375
2376                 arglist.push_back(argument_decl(arg));
2377
2378                 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2379                 auto *var = maybe_get<SPIRVariable>(arg.id);
2380                 if (var)
2381                         var->parameter = &arg;
2382         }
2383
2384         decl += merge(arglist);
2385         decl += ")";
2386         statement(decl);
2387 }
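// Example (illustrative): for a fragment entry point this emits roughly
// "void frag_main(...)", while a helper returning float4[2] is rewritten as
// "void some_helper(out float4 spvReturnValue[2], ...)".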
2388
2389 void CompilerHLSL::emit_hlsl_entry_point()
2390 {
2391         SmallVector<string> arguments;
2392
2393         if (require_input)
2394                 arguments.push_back("SPIRV_Cross_Input stage_input");
2395
2396         // Add I/O blocks as separate arguments with appropriate storage qualifier.
2397         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2398                 auto &type = this->get<SPIRType>(var.basetype);
2399                 bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2400
2401                 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
2402                         return;
2403
2404                 if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2405                 {
2406                         if (var.storage == StorageClassInput)
2407                         {
2408                                 arguments.push_back(join("in ", variable_decl(type, join("stage_input", to_name(var.self)))));
2409                         }
2410                         else if (var.storage == StorageClassOutput)
2411                         {
2412                                 arguments.push_back(join("out ", variable_decl(type, join("stage_output", to_name(var.self)))));
2413                         }
2414                 }
2415         });
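        // Example (illustrative): with both stage input and output structs required,
        // the signature emitted below is roughly
        // "SPIRV_Cross_Output main(SPIRV_Cross_Input stage_input)".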
2416
2417         auto &execution = get_entry_point();
2418
2419         switch (execution.model)
2420         {
2421         case ExecutionModelGLCompute:
2422         {
2423                 SpecializationConstant wg_x, wg_y, wg_z;
2424                 get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
2425
2426                 uint32_t x = execution.workgroup_size.x;
2427                 uint32_t y = execution.workgroup_size.y;
2428                 uint32_t z = execution.workgroup_size.z;
2429
2430                 auto x_expr = wg_x.id ? get<SPIRConstant>(wg_x.id).specialization_constant_macro_name : to_string(x);
2431                 auto y_expr = wg_y.id ? get<SPIRConstant>(wg_y.id).specialization_constant_macro_name : to_string(y);
2432                 auto z_expr = wg_z.id ? get<SPIRConstant>(wg_z.id).specialization_constant_macro_name : to_string(z);
2433
2434                 statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]");
2435                 break;
2436         }
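        // Example (illustrative): a fixed 8x8x1 workgroup emits "[numthreads(8, 8, 1)]";
        // a specialization constant substitutes its macro name, e.g. "SPIRV_CROSS_CONSTANT_ID_0".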
2437         case ExecutionModelFragment:
2438                 if (execution.flags.get(ExecutionModeEarlyFragmentTests))
2439                         statement("[earlydepthstencil]");
2440                 break;
2441         default:
2442                 break;
2443         }
2444
2445         statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")");
2446         begin_scope();
2447         bool legacy = hlsl_options.shader_model <= 30;
2448
2449         // Copy builtins from entry point arguments to globals.
2450         active_input_builtins.for_each_bit([&](uint32_t i) {
2451                 auto builtin = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassInput);
2452                 switch (static_cast<BuiltIn>(i))
2453                 {
2454                 case BuiltInFragCoord:
2455                         // VPOS in D3D9 is sampled at integer locations, so apply a half-pixel offset to be consistent.
2456                         // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2457                         // on a D3D10+ system with a different rasterization config?
2458                         if (legacy)
2459                                 statement(builtin, " = stage_input.", builtin, " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2460                         else
2461                         {
2462                                 statement(builtin, " = stage_input.", builtin, ";");
2463                                 // ZW are undefined in D3D9 VPOS, so only apply the 1/w fixup in this non-legacy path.
2464                                 statement(builtin, ".w = 1.0 / ", builtin, ".w;");
2465                         }
2466                         break;
2467
2468                 case BuiltInVertexId:
2469                 case BuiltInVertexIndex:
2470                 case BuiltInInstanceIndex:
2471                         // D3D semantics are uint, but shader wants int.
2472                         if (hlsl_options.support_nonzero_base_vertex_base_instance)
2473                         {
2474                                 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2475                                         statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
2476                                 else
2477                                         statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
2478                         }
2479                         else
2480                                 statement(builtin, " = int(stage_input.", builtin, ");");
2481                         break;
2482
2483                 case BuiltInInstanceId:
2484                         // D3D semantics are uint, but shader wants int.
2485                         statement(builtin, " = int(stage_input.", builtin, ");");
2486                         break;
2487
2488                 case BuiltInNumWorkgroups:
2489                 case BuiltInPointCoord:
2490                 case BuiltInSubgroupSize:
2491                 case BuiltInSubgroupLocalInvocationId:
2492                         break;
2493
2494                 case BuiltInSubgroupEqMask:
2495                         // Emulate these ...
2496                         // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2497                         statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2498                         statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2499                         statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2500                         statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2501                         statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2502                         break;
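                        // Worked example (illustrative): for lane 40, the shifts (masked mod 32 in HLSL)
                        // produce 1u << 8 in every component, and the conditionals then zero all but .y,
                        // leaving uint4(0, 1u << 8, 0, 0), i.e. overall bit 40.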
2503
2504                 case BuiltInSubgroupGeMask:
2505                         // Emulate these ...
2506                         // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2507                         statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2508                         statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2509                         statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2510                         statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2511                         statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2512                         statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2513                         statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2514                         break;
2515
2516                 case BuiltInSubgroupGtMask:
2517                         // Emulate these ...
2518                         // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2519                         statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
2520                         statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2521                         statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2522                         statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2523                         statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2524                         statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2525                         statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2526                         statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2527                         statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2528                         break;
2529
2530                 case BuiltInSubgroupLeMask:
2531                         // Emulate these ...
2532                         // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2533                         statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
2534                         statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2535                         statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2536                         statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2537                         statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2538                         statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2539                         statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2540                         statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2541                         statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2542                         break;
2543
2544                 case BuiltInSubgroupLtMask:
2545                         // Emulate these ...
2546                         // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2547                         statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2548                         statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2549                         statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2550                         statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2551                         statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2552                         statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2553                         statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2554                         break;
2555
2556                 case BuiltInClipDistance:
2557                         for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2558                                 statement("gl_ClipDistance[", clip, "] = stage_input.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3],
2559                                           ";");
2560                         break;
2561
2562                 case BuiltInCullDistance:
2563                         for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2564                                 statement("gl_CullDistance[", cull, "] = stage_input.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3],
2565                                           ";");
2566                         break;
2567
2568                 default:
2569                         statement(builtin, " = stage_input.", builtin, ";");
2570                         break;
2571                 }
2572         });
2573
2574         // Copy from stage input struct to globals.
2575         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2576                 auto &type = this->get<SPIRType>(var.basetype);
2577                 bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2578
2579                 if (var.storage != StorageClassInput)
2580                         return;
2581
2582                 bool need_matrix_unroll = execution.model == ExecutionModelVertex; // Storage is always Input here (checked above).
2583
2584                 if (!block && !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2585                     interface_variable_exists_in_entry_point(var.self))
2586                 {
2587                         auto name = to_name(var.self);
2588                         auto &mtype = this->get<SPIRType>(var.basetype);
2589                         if (need_matrix_unroll && mtype.columns > 1)
2590                         {
2591                                 // Unroll matrices.
2592                                 for (uint32_t col = 0; col < mtype.columns; col++)
2593                                         statement(name, "[", col, "] = stage_input.", name, "_", col, ";");
2594                         }
2595                         else
2596                         {
2597                                 statement(name, " = stage_input.", name, ";");
2598                         }
2599                 }
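                // Example (illustrative): a float2x2 input "m" unrolls to
                // "m[0] = stage_input.m_0;" and "m[1] = stage_input.m_1;".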
2600
2601                 // I/O blocks don't use the common stage input/output struct, but separate arguments.
2602                 if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2603                 {
2604                         auto name = to_name(var.self);
2605                         statement(name, " = stage_input", name, ";");
2606                 }
2607         });
2608
2609         // Run the shader.
2610         if (execution.model == ExecutionModelVertex)
2611                 statement("vert_main();");
2612         else if (execution.model == ExecutionModelFragment)
2613                 statement("frag_main();");
2614         else if (execution.model == ExecutionModelGLCompute)
2615                 statement("comp_main();");
2616         else
2617                 SPIRV_CROSS_THROW("Unsupported shader stage.");
2618
2619         // Copy block outputs.
2620         ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2621                 auto &type = this->get<SPIRType>(var.basetype);
2622                 bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2623
2624                 if (var.storage != StorageClassOutput)
2625                         return;
2626
2627                 // I/O blocks don't use the common stage input/output struct, but separate outputs.
2628                 if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2629                 {
2630                         auto name = to_name(var.self);
2631                         statement("stage_output", name, " = ", name, ";");
2632                 }
2633         });
2634
2635         // Copy stage outputs.
2636         if (require_output)
2637         {
2638                 statement("SPIRV_Cross_Output stage_output;");
2639
2640                 // Copy builtins from globals to return struct.
2641                 active_output_builtins.for_each_bit([&](uint32_t i) {
2642                         // PointSize doesn't exist in HLSL.
2643                         if (i == BuiltInPointSize)
2644                                 return;
2645
2646                         switch (static_cast<BuiltIn>(i))
2647                         {
2648                         case BuiltInClipDistance:
2649                                 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2650                                         statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
2651                                                   clip, "];");
2652                                 break;
2653
2654                         case BuiltInCullDistance:
2655                                 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2656                                         statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
2657                                                   cull, "];");
2658                                 break;
2659
2660                         default:
2661                         {
2662                                 auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
2663                                 statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
2664                                 break;
2665                         }
2666                         }
2667                 });
2668
2669                 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2670                         auto &type = this->get<SPIRType>(var.basetype);
2671                         bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2672
2673                         if (var.storage != StorageClassOutput)
2674                                 return;
2675
2676                         if (!block && var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
2677                             !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2678                         {
2679                                 auto name = to_name(var.self);
2680
2681                                 if (legacy && execution.model == ExecutionModelFragment)
2682                                 {
2683                                         string output_filler;
2684                                         for (uint32_t size = type.vecsize; size < 4; ++size)
2685                                                 output_filler += ", 0.0";
2686
2687                                         statement("stage_output.", name, " = float4(", name, output_filler, ");");
2688                                 }
2689                                 else
2690                                 {
2691                                         statement("stage_output.", name, " = ", name, ";");
2692                                 }
2693                         }
2694                 });
2695
2696                 statement("return stage_output;");
2697         }
2698
2699         end_scope();
2700 }
2701
2702 void CompilerHLSL::emit_fixup()
2703 {
2704         if (is_vertex_like_shader())
2705         {
2706                 // Do various mangling on the gl_Position.
2707                 if (hlsl_options.shader_model <= 30)
2708                 {
2709                         statement("gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2710                                   "gl_Position.w;");
2711                         statement("gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2712                                   "gl_Position.w;");
2713                 }
2714
2715                 if (options.vertex.flip_vert_y)
2716                         statement("gl_Position.y = -gl_Position.y;");
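                // fixup_clipspace remaps GL-style clip-space z in [-w, w] to D3D-style [0, w].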
2717                 if (options.vertex.fixup_clipspace)
2718                         statement("gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2719         }
2720 }
2721
2722 void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2723 {
2724         if (sparse)
2725                 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2726
2727         auto *ops = stream(i);
2728         auto op = static_cast<Op>(i.op);
2729         uint32_t length = i.length;
2730
2731         SmallVector<uint32_t> inherited_expressions;
2732
2733         uint32_t result_type = ops[0];
2734         uint32_t id = ops[1];
2735         VariableID img = ops[2];
2736         uint32_t coord = ops[3];
2737         uint32_t dref = 0;
2738         uint32_t comp = 0;
2739         bool gather = false;
2740         bool proj = false;
2741         const uint32_t *opt = nullptr;
2742         auto *combined_image = maybe_get<SPIRCombinedImageSampler>(img);
2743
2744         if (combined_image && has_decoration(img, DecorationNonUniform))
2745         {
2746                 set_decoration(combined_image->image, DecorationNonUniform);
2747                 set_decoration(combined_image->sampler, DecorationNonUniform);
2748         }
2749
2750         auto img_expr = to_non_uniform_aware_expression(combined_image ? combined_image->image : img);
2751
2752         inherited_expressions.push_back(coord);
2753
2754         switch (op)
2755         {
2756         case OpImageSampleDrefImplicitLod:
2757         case OpImageSampleDrefExplicitLod:
2758                 dref = ops[4];
2759                 opt = &ops[5];
2760                 length -= 5;
2761                 break;
2762
2763         case OpImageSampleProjDrefImplicitLod:
2764         case OpImageSampleProjDrefExplicitLod:
2765                 dref = ops[4];
2766                 proj = true;
2767                 opt = &ops[5];
2768                 length -= 5;
2769                 break;
2770
2771         case OpImageDrefGather:
2772                 dref = ops[4];
2773                 opt = &ops[5];
2774                 gather = true;
2775                 length -= 5;
2776                 break;
2777
2778         case OpImageGather:
2779                 comp = ops[4];
2780                 opt = &ops[5];
2781                 gather = true;
2782                 length -= 5;
2783                 break;
2784
2785         case OpImageSampleProjImplicitLod:
2786         case OpImageSampleProjExplicitLod:
2787                 opt = &ops[4];
2788                 length -= 4;
2789                 proj = true;
2790                 break;
2791
2792         case OpImageQueryLod:
2793                 opt = &ops[4];
2794                 length -= 4;
2795                 break;
2796
2797         default:
2798                 opt = &ops[4];
2799                 length -= 4;
2800                 break;
2801         }
2802
2803         auto &imgtype = expression_type(img);
2804         uint32_t coord_components = 0;
2805         switch (imgtype.image.dim)
2806         {
2807         case spv::Dim1D:
2808                 coord_components = 1;
2809                 break;
2810         case spv::Dim2D:
2811                 coord_components = 2;
2812                 break;
2813         case spv::Dim3D:
2814                 coord_components = 3;
2815                 break;
2816         case spv::DimCube:
2817                 coord_components = 3;
2818                 break;
2819         case spv::DimBuffer:
2820                 coord_components = 1;
2821                 break;
2822         default:
2823                 coord_components = 2;
2824                 break;
2825         }
2826
2827         if (dref)
2828                 inherited_expressions.push_back(dref);
2829
2830         if (imgtype.image.arrayed)
2831                 coord_components++;
2832
2833         uint32_t bias = 0;
2834         uint32_t lod = 0;
2835         uint32_t grad_x = 0;
2836         uint32_t grad_y = 0;
2837         uint32_t coffset = 0;
2838         uint32_t offset = 0;
2839         uint32_t coffsets = 0;
2840         uint32_t sample = 0;
2841         uint32_t minlod = 0;
2842         uint32_t flags = 0;
2843
2844         if (length)
2845         {
2846                 flags = opt[0];
2847                 opt++;
2848                 length--;
2849         }
2850
2851         auto test = [&](uint32_t &v, uint32_t flag) {
2852                 if (length && (flags & flag))
2853                 {
2854                         v = *opt++;
2855                         inherited_expressions.push_back(v);
2856                         length--;
2857                 }
2858         };
2859
2860         test(bias, ImageOperandsBiasMask);
2861         test(lod, ImageOperandsLodMask);
2862         test(grad_x, ImageOperandsGradMask);
2863         test(grad_y, ImageOperandsGradMask);
2864         test(coffset, ImageOperandsConstOffsetMask);
2865         test(offset, ImageOperandsOffsetMask);
2866         test(coffsets, ImageOperandsConstOffsetsMask);
2867         test(sample, ImageOperandsSampleMask);
2868         test(minlod, ImageOperandsMinLodMask);
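        // Note: image operands are consumed in ascending mask-bit order, e.g. Bias | ConstOffset
        // yields [bias, offset]. Grad supplies two consecutive operands (ddx then ddy), which is
        // why ImageOperandsGradMask is tested twice above.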
2869
2870         string expr;
2871         string texop;
2872
2873         if (minlod != 0)
2874                 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2875
2876         if (op == OpImageFetch)
2877         {
2878                 if (hlsl_options.shader_model < 40)
2879                 {
2880                         SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
2881                 }
2882                 texop += img_expr;
2883                 texop += ".Load";
2884         }
2885         else if (op == OpImageQueryLod)
2886         {
2887                 texop += img_expr;
2888                 texop += ".CalculateLevelOfDetail";
2889         }
2890         else
2891         {
2892                 auto &imgformat = get<SPIRType>(imgtype.image.type);
2893                 if (imgformat.basetype != SPIRType::Float)
2894                 {
2895                         SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
2896                 }
2897
2898                 if (hlsl_options.shader_model >= 40)
2899                 {
2900                         texop += img_expr;
2901
2902                         if (image_is_comparison(imgtype, img))
2903                         {
2904                                 if (gather)
2905                                 {
2906                                         SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
2907                                 }
2908                                 else if (lod || grad_x || grad_y)
2909                                 {
2910                                         // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
2911                                         texop += ".SampleCmpLevelZero";
2912                                 }
2913                                 else
2914                                         texop += ".SampleCmp";
2915                         }
2916                         else if (gather)
2917                         {
2918                                 uint32_t comp_num = evaluate_constant_u32(comp);
2919                                 if (hlsl_options.shader_model >= 50)
2920                                 {
2921                                         switch (comp_num)
2922                                         {
2923                                         case 0:
2924                                                 texop += ".GatherRed";
2925                                                 break;
2926                                         case 1:
2927                                                 texop += ".GatherGreen";
2928                                                 break;
2929                                         case 2:
2930                                                 texop += ".GatherBlue";
2931                                                 break;
2932                                         case 3:
2933                                                 texop += ".GatherAlpha";
2934                                                 break;
2935                                         default:
2936                                                 SPIRV_CROSS_THROW("Invalid component.");
2937                                         }
2938                                 }
2939                                 else
2940                                 {
2941                                         if (comp_num == 0)
2942                                                 texop += ".Gather";
2943                                         else
2944                                                 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
2945                                 }
2946                         }
2947                         else if (bias)
2948                                 texop += ".SampleBias";
2949                         else if (grad_x || grad_y)
2950                                 texop += ".SampleGrad";
2951                         else if (lod)
2952                                 texop += ".SampleLevel";
2953                         else
2954                                 texop += ".Sample";
2955                 }
2956                 else
2957                 {
2958                         switch (imgtype.image.dim)
2959                         {
2960                         case Dim1D:
2961                                 texop += "tex1D";
2962                                 break;
2963                         case Dim2D:
2964                                 texop += "tex2D";
2965                                 break;
2966                         case Dim3D:
2967                                 texop += "tex3D";
2968                                 break;
2969                         case DimCube:
2970                                 texop += "texCUBE";
2971                                 break;
2972                         case DimRect:
2973                         case DimBuffer:
2974                         case DimSubpassData:
2975                                 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL."); // TODO
2976                         default:
2977                                 SPIRV_CROSS_THROW("Invalid dimension.");
2978                         }
2979
2980                         if (gather)
2981                                 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
2982                         if (offset || coffset)
2983                                 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
2984
2985                         if (grad_x || grad_y)
2986                                 texop += "grad";
2987                         else if (lod)
2988                                 texop += "lod";
2989                         else if (bias)
2990                                 texop += "bias";
2991                         else if (proj || dref)
2992                                 texop += "proj";
2993                 }
2994         }
2995
2996         expr += texop;
2997         expr += "(";
2998         if (hlsl_options.shader_model < 40)
2999         {
3000                 if (combined_image)
3001                         SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3002                 expr += to_expression(img);
3003         }
3004         else if (op != OpImageFetch)
3005         {
3006                 string sampler_expr;
3007                 if (combined_image)
3008                         sampler_expr = to_non_uniform_aware_expression(combined_image->sampler);
3009                 else
3010                         sampler_expr = to_sampler_expression(img);
3011                 expr += sampler_expr;
3012         }
3013
3014         auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3015                 if (comps == in_comps)
3016                         return "";
3017
3018                 switch (comps)
3019                 {
3020                 case 1:
3021                         return ".x";
3022                 case 2:
3023                         return ".xy";
3024                 case 3:
3025                         return ".xyz";
3026                 default:
3027                         return "";
3028                 }
3029         };
3030
3031         bool forward = should_forward(coord);
3032
3033         // The IR can give us more components than we need, so chop them off as needed.
3034         string coord_expr;
3035         auto &coord_type = expression_type(coord);
3036         if (coord_components != coord_type.vecsize)
3037                 coord_expr = to_enclosed_expression(coord) + swizzle(coord_components, coord_type.vecsize);
3038         else
3039                 coord_expr = to_expression(coord);
3040
3041         if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3042                 coord_expr = coord_expr + " / " + to_extract_component_expression(coord, coord_components);
3043
3044         if (hlsl_options.shader_model < 40)
3045         {
3046                 if (dref)
3047                 {
3048                         if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3049                         {
3050                                 SPIRV_CROSS_THROW(
3051                                     "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3052                         }
3053
3054                         if (grad_x || grad_y)
3055                                 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3056
3057                         for (uint32_t size = coord_components; size < 2; ++size)
3058                                 coord_expr += ", 0.0";
3059
3060                         forward = forward && should_forward(dref);
3061                         coord_expr += ", " + to_expression(dref);
3062                 }
3063                 else if (lod || bias || proj)
3064                 {
3065                         for (uint32_t size = coord_components; size < 3; ++size)
3066                                 coord_expr += ", 0.0";
3067                 }
3068
3069                 if (lod)
3070                 {
3071                         coord_expr = "float4(" + coord_expr + ", " + to_expression(lod) + ")";
3072                 }
3073                 else if (bias)
3074                 {
3075                         coord_expr = "float4(" + coord_expr + ", " + to_expression(bias) + ")";
3076                 }
3077                 else if (proj)
3078                 {
3079                         coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(coord, coord_components) + ")";
3080                 }
3081                 else if (dref)
3082                 {
3083                         // A non-projective depth-compare sample is fed into tex2Dproj as well (with q = 1.0),
3084                         // because the regular tex2D accepts only two coordinates.
3085                         coord_expr = "float4(" + coord_expr + ", 1.0)";
3086                 }
3087
3088                 if (!!lod + !!bias + !!proj > 1)
3089                         SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3090         }
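        // Example (illustrative): a legacy 2D lod sample ends up roughly as
        // "tex2Dlod(uTex, float4(uv, 0.0, lod))".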
3091
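        // For OpImageFetch, HLSL Load() folds the mip level into the coordinate,
        // e.g. Texture2D.Load(int3(x, y, lod)).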
3092         if (op == OpImageFetch)
3093         {
3094                 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3095                         coord_expr =
3096                             join("int", coord_components + 1, "(", coord_expr, ", ", lod ? to_expression(lod) : string("0"), ")");
3097         }
3098         else
3099                 expr += ", ";
3100         expr += coord_expr;
3101
3102         if (dref && hlsl_options.shader_model >= 40)
3103         {
3104                 forward = forward && should_forward(dref);
3105                 expr += ", ";
3106
3107                 if (proj)
3108                         expr += to_enclosed_expression(dref) + " / " + to_extract_component_expression(coord, coord_components);
3109                 else
3110                         expr += to_expression(dref);
3111         }
3112
3113         if (!dref && (grad_x || grad_y))
3114         {
3115                 forward = forward && should_forward(grad_x);
3116                 forward = forward && should_forward(grad_y);
3117                 expr += ", ";
3118                 expr += to_expression(grad_x);
3119                 expr += ", ";
3120                 expr += to_expression(grad_y);
3121         }
3122
3123         if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3124         {
3125                 forward = forward && should_forward(lod);
3126                 expr += ", ";
3127                 expr += to_expression(lod);
3128         }
3129
3130         if (!dref && bias && hlsl_options.shader_model >= 40)
3131         {
3132                 forward = forward && should_forward(bias);
3133                 expr += ", ";
3134                 expr += to_expression(bias);
3135         }
3136
3137         if (coffset)
3138         {
3139                 forward = forward && should_forward(coffset);
3140                 expr += ", ";
3141                 expr += to_expression(coffset);
3142         }
3143         else if (offset)
3144         {
3145                 forward = forward && should_forward(offset);
3146                 expr += ", ";
3147                 expr += to_expression(offset);
3148         }
3149
3150         if (sample)
3151         {
3152                 expr += ", ";
3153                 expr += to_expression(sample);
3154         }
3155
3156         expr += ")";
3157
3158         if (dref && hlsl_options.shader_model < 40)
3159                 expr += ".x";
3160
3161         if (op == OpImageQueryLod)
3162         {
3163                 // This is rather awkward.
3164                 // textureQueryLod returns two values, the "accessed level",
3165                 // as well as the actual LOD lambda.
3166                 // As far as I can tell, there is no way to get the .x component
3167                 // according to GLSL spec, and it depends on the sampler itself.
3168                 // Just assume X == Y, so we will need to splat the result to a float2.
3169                 statement("float _", id, "_tmp = ", expr, ";");
3170                 statement("float2 _", id, " = _", id, "_tmp.xx;");
3171                 set<SPIRExpression>(id, join("_", id), result_type, true);
3172         }
3173         else
3174         {
3175                 emit_op(result_type, id, expr, forward, false);
3176         }
3177
3178         for (auto &inherit : inherited_expressions)
3179                 inherit_expression_dependencies(id, inherit);
3180
3181         switch (op)
3182         {
3183         case OpImageSampleDrefImplicitLod:
3184         case OpImageSampleImplicitLod:
3185         case OpImageSampleProjImplicitLod:
3186         case OpImageSampleProjDrefImplicitLod:
3187                 register_control_dependent_expression(id);
3188                 break;
3189
3190         default:
3191                 break;
3192         }
3193 }
3194
3195 string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3196 {
3197         const auto &type = get<SPIRType>(var.basetype);
3198
3199         // We can remap push constant blocks, even if they don't have any binding decoration.
3200         if (type.storage != StorageClassPushConstant && !has_decoration(var.self, DecorationBinding))
3201                 return "";
3202
3203         char space = '\0';
3204
3205         HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3206
3207         switch (type.basetype)
3208         {
3209         case SPIRType::SampledImage:
3210                 space = 't'; // SRV
3211                 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3212                 break;
3213
3214         case SPIRType::Image:
3215                 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3216                 {
3217                         if (has_decoration(var.self, DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3218                         {
3219                                 space = 't'; // SRV
3220                                 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3221                         }
3222                         else
3223                         {
3224                                 space = 'u'; // UAV
3225                                 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3226                         }
3227                 }
3228                 else
3229                 {
3230                         space = 't'; // SRV
3231                         resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3232                 }
3233                 break;
3234
3235         case SPIRType::Sampler:
3236                 space = 's';
3237                 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3238                 break;
3239
3240         case SPIRType::Struct:
3241         {
3242                 auto storage = type.storage;
3243                 if (storage == StorageClassUniform)
3244                 {
3245                         if (has_decoration(type.self, DecorationBufferBlock))
3246                         {
3247                                 Bitset flags = ir.get_buffer_block_flags(var);
3248                                 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3249                                 space = is_readonly ? 't' : 'u'; // SRV or UAV depending on readonly flag.
3250                                 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3251                         }
3252                         else if (has_decoration(type.self, DecorationBlock))
3253                         {
3254                                 space = 'b'; // Constant buffers
3255                                 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3256                         }
3257                 }
3258                 else if (storage == StorageClassPushConstant)
3259                 {
3260                         space = 'b'; // Constant buffers
3261                         resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3262                 }
3263                 else if (storage == StorageClassStorageBuffer)
3264                 {
3265                         // UAV or SRV depending on readonly flag.
3266                         Bitset flags = ir.get_buffer_block_flags(var);
3267                         bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3268                         space = is_readonly ? 't' : 'u';
3269                         resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3270                 }
3271
3272                 break;
3273         }
3274         default:
3275                 break;
3276         }
3277
3278         if (!space)
3279                 return "";
3280
3281         uint32_t desc_set =
3282             resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3283         uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3284
3285         if (has_decoration(var.self, DecorationBinding))
3286                 binding = get_decoration(var.self, DecorationBinding);
3287         if (has_decoration(var.self, DecorationDescriptorSet))
3288                 desc_set = get_decoration(var.self, DecorationDescriptorSet);
3289
3290         return to_resource_register(resource_flags, space, binding, desc_set);
3291 }
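// Example (illustrative): an SRV with binding 3 in descriptor set 1 yields
// " : register(t3)" on SM < 5.1, or " : register(t3, space1)" on SM 5.1+.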
3292
3293 string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3294 {
3295         // For combined image samplers.
3296         if (!has_decoration(var.self, DecorationBinding))
3297                 return "";
3298
3299         return to_resource_register(HLSL_BINDING_AUTO_SAMPLER_BIT, 's', get_decoration(var.self, DecorationBinding),
3300                                     get_decoration(var.self, DecorationDescriptorSet));
3301 }
3302
3303 void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3304 {
3305         auto itr = resource_bindings.find({ get_execution_model(), desc_set, binding });
3306         if (itr != end(resource_bindings))
3307         {
3308                 auto &remap = itr->second;
3309                 remap.second = true;
3310
3311                 switch (type)
3312                 {
3313                 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3314                 case HLSL_BINDING_AUTO_CBV_BIT:
3315                         desc_set = remap.first.cbv.register_space;
3316                         binding = remap.first.cbv.register_binding;
3317                         break;
3318
3319                 case HLSL_BINDING_AUTO_SRV_BIT:
3320                         desc_set = remap.first.srv.register_space;
3321                         binding = remap.first.srv.register_binding;
3322                         break;
3323
3324                 case HLSL_BINDING_AUTO_SAMPLER_BIT:
3325                         desc_set = remap.first.sampler.register_space;
3326                         binding = remap.first.sampler.register_binding;
3327                         break;
3328
3329                 case HLSL_BINDING_AUTO_UAV_BIT:
3330                         desc_set = remap.first.uav.register_space;
3331                         binding = remap.first.uav.register_binding;
3332                         break;
3333
3334                 default:
3335                         break;
3336                 }
3337         }
3338 }
3339
3340 string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3341 {
3342         if ((flag & resource_binding_flags) == 0)
3343         {
3344                 remap_hlsl_resource_binding(flag, space_set, binding);
3345
3346                 // The push constant block did not have a binding, and there was no remap for it,
3347                 // so declare it without a register binding.
3348                 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3349                         return "";
3350
3351                 if (hlsl_options.shader_model >= 51)
3352                         return join(" : register(", space, binding, ", space", space_set, ")");
3353                 else
3354                         return join(" : register(", space, binding, ")");
3355         }
3356         else
3357                 return "";
3358 }
3359
3360 void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3361 {
3362         auto &type = get<SPIRType>(var.basetype);
3363         switch (type.basetype)
3364         {
3365         case SPIRType::SampledImage:
3366         case SPIRType::Image:
3367         {
3368                 bool is_coherent = false;
3369                 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3370                         is_coherent = has_decoration(var.self, DecorationCoherent);
3371
3372                 statement(is_coherent ? "globallycoherent " : "", image_type_hlsl_modern(type, var.self), " ",
3373                           to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3374
3375                 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3376                 {
3377                         // For combined image samplers, also emit a combined image sampler.
3378                         if (image_is_comparison(type, var.self))
3379                                 statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3380                                           to_resource_binding_sampler(var), ";");
3381                         else
3382                                 statement("SamplerState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3383                                           to_resource_binding_sampler(var), ";");
3384                 }
3385                 break;
3386         }
3387
3388         case SPIRType::Sampler:
3389                 if (comparison_ids.count(var.self))
3390                         statement("SamplerComparisonState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var),
3391                                   ";");
3392                 else
3393                         statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3394                 break;
3395
3396         default:
3397                 statement(variable_decl(var), to_resource_binding(var), ";");
3398                 break;
3399         }
3400 }
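// Example (illustrative): a combined Texture2D/sampler "uTex" typically emits roughly:
//   Texture2D<float4> uTex : register(t0);
//   SamplerState _uTex_sampler : register(s0);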
3401
3402 void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3403 {
3404         auto &type = get<SPIRType>(var.basetype);
3405         switch (type.basetype)
3406         {
3407         case SPIRType::Sampler:
3408         case SPIRType::Image:
3409                 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3410
3411         default:
3412                 statement(variable_decl(var), ";");
3413                 break;
3414         }
3415 }
3416
3417 void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3418 {
3419         add_resource_name(var.self);
3420         if (hlsl_options.shader_model >= 40)
3421                 emit_modern_uniform(var);
3422         else
3423                 emit_legacy_uniform(var);
3424 }
3425
3426 bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3427 {
3428         return false;
3429 }
3430
3431 string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3432 {
3433         if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3434                 return type_to_glsl(out_type);
3435         else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3436                 return type_to_glsl(out_type);
3437         else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3438                 return "asuint";
3439         else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3440                 return type_to_glsl(out_type);
3441         else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3442                 return type_to_glsl(out_type);
3443         else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3444                 return "asint";
3445         else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3446                 return "asfloat";
3447         else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3448                 return "asfloat";
3449         else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3450                 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3451         else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3452                 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3453         else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3454                 return "asdouble";
3455         else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3456                 return "asdouble";
3457         else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3458         {
3459                 if (!requires_explicit_fp16_packing)
3460                 {
3461                         requires_explicit_fp16_packing = true;
3462                         force_recompile();
3463                 }
3464                 return "spvUnpackFloat2x16";
3465         }
3466         else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3467         {
3468                 if (!requires_explicit_fp16_packing)
3469                 {
3470                         requires_explicit_fp16_packing = true;
3471                         force_recompile();
3472                 }
3473                 return "spvPackFloat2x16";
3474         }
3475         else
3476                 return "";
3477 }
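// Example: float -> uint bitcasts map to "asuint(expr)", while same-width
// int <-> uint bitcasts return the type name, which the caller emits as a
// constructor-style cast such as "uint4(expr)".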
3478
3479 void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3480 {
3481         auto op = static_cast<GLSLstd450>(eop);
3482
3483         // If we need to do implicit bitcasts, make sure we do it with the correct type.
3484         uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
3485         auto int_type = to_signed_basetype(integer_width);
3486         auto uint_type = to_unsigned_basetype(integer_width);
3487
3488         switch (op)
3489         {
3490         case GLSLstd450InverseSqrt:
3491                 emit_unary_func_op(result_type, id, args[0], "rsqrt");
3492                 break;
3493
3494         case GLSLstd450Fract:
3495                 emit_unary_func_op(result_type, id, args[0], "frac");
3496                 break;
3497
3498         case GLSLstd450RoundEven:
3499                 if (hlsl_options.shader_model < 40)
3500                         SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3501                 emit_unary_func_op(result_type, id, args[0], "round");
3502                 break;
3503
3504         case GLSLstd450Acosh:
3505         case GLSLstd450Asinh:
3506         case GLSLstd450Atanh:
3507                 SPIRV_CROSS_THROW("Inverse hyperbolics are not supported in HLSL.");
3508
3509         case GLSLstd450FMix:
3510         case GLSLstd450IMix:
3511                 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "lerp");
3512                 break;
3513
3514         case GLSLstd450Atan2:
3515                 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
3516                 break;
3517
3518         case GLSLstd450Fma:
3519                 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mad");
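                // Note: mad() is a best-effort mapping; HLSL does not guarantee
                // it is fused the way a NoContraction-decorated fma would be.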
3520                 break;
3521
3522         case GLSLstd450InterpolateAtCentroid:
3523                 emit_unary_func_op(result_type, id, args[0], "EvaluateAttributeAtCentroid");
3524                 break;
3525         case GLSLstd450InterpolateAtSample:
3526                 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeAtSample");
3527                 break;
3528         case GLSLstd450InterpolateAtOffset:
3529                 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeSnapped");
3530                 break;
3531
3532         case GLSLstd450PackHalf2x16:
3533                 if (!requires_fp16_packing)
3534                 {
3535                         requires_fp16_packing = true;
3536                         force_recompile();
3537                 }
3538                 emit_unary_func_op(result_type, id, args[0], "spvPackHalf2x16");
3539                 break;
3540
3541         case GLSLstd450UnpackHalf2x16:
3542                 if (!requires_fp16_packing)
3543                 {
3544                         requires_fp16_packing = true;
3545                         force_recompile();
3546                 }
3547                 emit_unary_func_op(result_type, id, args[0], "spvUnpackHalf2x16");
3548                 break;
3549
3550         case GLSLstd450PackSnorm4x8:
3551                 if (!requires_snorm8_packing)
3552                 {
3553                         requires_snorm8_packing = true;
3554                         force_recompile();
3555                 }
3556                 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm4x8");
3557                 break;
3558
3559         case GLSLstd450UnpackSnorm4x8:
3560                 if (!requires_snorm8_packing)
3561                 {
3562                         requires_snorm8_packing = true;
3563                         force_recompile();
3564                 }
3565                 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm4x8");
3566                 break;
3567
3568         case GLSLstd450PackUnorm4x8:
3569                 if (!requires_unorm8_packing)
3570                 {
3571                         requires_unorm8_packing = true;
3572                         force_recompile();
3573                 }
3574                 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm4x8");
3575                 break;
3576
3577         case GLSLstd450UnpackUnorm4x8:
3578                 if (!requires_unorm8_packing)
3579                 {
3580                         requires_unorm8_packing = true;
3581                         force_recompile();
3582                 }
3583                 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm4x8");
3584                 break;
3585
3586         case GLSLstd450PackSnorm2x16:
3587                 if (!requires_snorm16_packing)
3588                 {
3589                         requires_snorm16_packing = true;
3590                         force_recompile();
3591                 }
3592                 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm2x16");
3593                 break;
3594
3595         case GLSLstd450UnpackSnorm2x16:
3596                 if (!requires_snorm16_packing)
3597                 {
3598                         requires_snorm16_packing = true;
3599                         force_recompile();
3600                 }
3601                 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm2x16");
3602                 break;
3603
3604         case GLSLstd450PackUnorm2x16:
3605                 if (!requires_unorm16_packing)
3606                 {
3607                         requires_unorm16_packing = true;
3608                         force_recompile();
3609                 }
3610                 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm2x16");
3611                 break;
3612
3613         case GLSLstd450UnpackUnorm2x16:
3614                 if (!requires_unorm16_packing)
3615                 {
3616                         requires_unorm16_packing = true;
3617                         force_recompile();
3618                 }
3619                 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm2x16");
3620                 break;
3621
3622         case GLSLstd450PackDouble2x32:
3623         case GLSLstd450UnpackDouble2x32:
3624                 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3625
3626         case GLSLstd450FindILsb:
3627         {
3628                 auto basetype = expression_type(args[0]).basetype;
3629                 emit_unary_func_op_cast(result_type, id, args[0], "firstbitlow", basetype, basetype);
3630                 break;
3631         }
3632
3633         case GLSLstd450FindSMsb:
3634                 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", int_type, int_type);
3635                 break;
3636
3637         case GLSLstd450FindUMsb:
3638                 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", uint_type, uint_type);
3639                 break;
3640
3641         case GLSLstd450MatrixInverse:
3642         {
3643                 auto &type = get<SPIRType>(result_type);
3644                 if (type.vecsize == 2 && type.columns == 2)
3645                 {
3646                         if (!requires_inverse_2x2)
3647                         {
3648                                 requires_inverse_2x2 = true;
3649                                 force_recompile();
3650                         }
3651                 }
3652                 else if (type.vecsize == 3 && type.columns == 3)
3653                 {
3654                         if (!requires_inverse_3x3)
3655                         {
3656                                 requires_inverse_3x3 = true;
3657                                 force_recompile();
3658                         }
3659                 }
3660                 else if (type.vecsize == 4 && type.columns == 4)
3661                 {
3662                         if (!requires_inverse_4x4)
3663                         {
3664                                 requires_inverse_4x4 = true;
3665                                 force_recompile();
3666                         }
3667                 }
3668                 emit_unary_func_op(result_type, id, args[0], "spvInverse");
3669                 break;
3670         }
3671
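        // HLSL's normalize/reflect/refract/faceforward intrinsics only accept vectors,
        // so the scalar cases below are special-cased: scalar normalize reduces to
        // sign(), and the rest go through emitted spvReflect/spvRefract/spvFaceForward
        // helpers, again using the requires_* + force_recompile() pattern.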
3672         case GLSLstd450Normalize:
3673                 // HLSL does not support scalar versions here.
3674                 if (expression_type(args[0]).vecsize == 1)
3675                 {
3676                         // Returns -1 or 1 for valid input; sign() does the job.
3677                         emit_unary_func_op(result_type, id, args[0], "sign");
3678                 }
3679                 else
3680                         CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3681                 break;
3682
3683         case GLSLstd450Reflect:
3684                 if (get<SPIRType>(result_type).vecsize == 1)
3685                 {
3686                         if (!requires_scalar_reflect)
3687                         {
3688                                 requires_scalar_reflect = true;
3689                                 force_recompile();
3690                         }
3691                         emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
3692                 }
3693                 else
3694                         CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3695                 break;
3696
3697         case GLSLstd450Refract:
3698                 if (get<SPIRType>(result_type).vecsize == 1)
3699                 {
3700                         if (!requires_scalar_refract)
3701                         {
3702                                 requires_scalar_refract = true;
3703                                 force_recompile();
3704                         }
3705                         emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
3706                 }
3707                 else
3708                         CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3709                 break;
3710
3711         case GLSLstd450FaceForward:
3712                 if (get<SPIRType>(result_type).vecsize == 1)
3713                 {
3714                         if (!requires_scalar_faceforward)
3715                         {
3716                                 requires_scalar_faceforward = true;
3717                                 force_recompile();
3718                         }
3719                         emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
3720                 }
3721                 else
3722                         CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3723                 break;
3724
3725         default:
3726                 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3727                 break;
3728         }
3729 }
3730
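// Reads an array out of a ByteAddressBuffer by unrolling one array level into a
// loop and recursing via read_access_chain() per element. Illustrative output for
// a "float4 data[4]" member (identifier names are generated):
//   [unroll]
//   for (int _ident = 0; _ident < 4; _ident++)
//   {
//       _lhs[_ident] = asfloat(_buf.Load4(_ident * 16 + offset));
//   }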
3731 void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3732 {
3733         auto &type = get<SPIRType>(chain.basetype);
3734
3735         // Need to use a reserved identifier here so it cannot shadow an identifier in the access chain input or other loops.
3736         auto ident = get_unique_identifier();
3737
3738         statement("[unroll]");
3739         statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
3740                   ident, "++)");
3741         begin_scope();
3742         auto subchain = chain;
3743         subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
3744         subchain.basetype = type.parent_type;
3745         if (!get<SPIRType>(subchain.basetype).array.empty())
3746                 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
3747         read_access_chain(nullptr, join(lhs, "[", ident, "]"), subchain);
3748         end_scope();
3749 }
3750
3751 void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3752 {
3753         auto &type = get<SPIRType>(chain.basetype);
3754         auto subchain = chain;
3755         uint32_t member_count = uint32_t(type.member_types.size());
3756
3757         for (uint32_t i = 0; i < member_count; i++)
3758         {
3759                 uint32_t offset = type_struct_member_offset(type, i);
3760                 subchain.static_index = chain.static_index + offset;
3761                 subchain.basetype = type.member_types[i];
3762
3763                 subchain.matrix_stride = 0;
3764                 subchain.array_stride = 0;
3765                 subchain.row_major_matrix = false;
3766
3767                 auto &member_type = get<SPIRType>(subchain.basetype);
3768                 if (member_type.columns > 1)
3769                 {
3770                         subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
3771                         subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
3772                 }
3773
3774                 if (!member_type.array.empty())
3775                         subchain.array_stride = type_struct_member_array_stride(type, i);
3776
3777                 read_access_chain(nullptr, join(lhs, ".", to_member_name(type, i)), subchain);
3778         }
3779 }
3780
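// Reads a scalar, vector or matrix from a ByteAddressBuffer. Before SM 6.2 such
// buffers can only be read as uints, so results are bitcast back to the real
// type; from SM 6.2 the templated Load<T> is used directly. Illustrative:
//   pre-SM 6.2:  asfloat(_buf.Load3(offset))
//   SM 6.2+:     _buf.Load<float3>(offset)
// Row-major matrix data is reassembled from strided per-scalar loads, since a
// single LoadN call cannot express the stride.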
3781 void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3782 {
3783         auto &type = get<SPIRType>(chain.basetype);
3784
3785         SPIRType target_type;
3786         target_type.basetype = SPIRType::UInt;
3787         target_type.vecsize = type.vecsize;
3788         target_type.columns = type.columns;
3789
3790         if (!type.array.empty())
3791         {
3792                 read_access_chain_array(lhs, chain);
3793                 return;
3794         }
3795         else if (type.basetype == SPIRType::Struct)
3796         {
3797                 read_access_chain_struct(lhs, chain);
3798                 return;
3799         }
3800         else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3801                 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3802                                   "native 16-bit types are enabled.");
3803
3804         string base = chain.base;
3805         if (has_decoration(chain.self, DecorationNonUniform))
3806                 convert_non_uniform_expression(base, chain.self);
3807
3808         bool templated_load = hlsl_options.shader_model >= 62;
3809         string load_expr;
3810
3811         string template_expr;
3812         if (templated_load)
3813                 template_expr = join("<", type_to_glsl(type), ">");
3814
3815         // Load a vector or scalar.
3816         if (type.columns == 1 && !chain.row_major_matrix)
3817         {
3818                 const char *load_op = nullptr;
3819                 switch (type.vecsize)
3820                 {
3821                 case 1:
3822                         load_op = "Load";
3823                         break;
3824                 case 2:
3825                         load_op = "Load2";
3826                         break;
3827                 case 3:
3828                         load_op = "Load3";
3829                         break;
3830                 case 4:
3831                         load_op = "Load4";
3832                         break;
3833                 default:
3834                         SPIRV_CROSS_THROW("Unknown vector size.");
3835                 }
3836
3837                 if (templated_load)
3838                         load_op = "Load";
3839
3840                 load_expr = join(base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
3841         }
3842         else if (type.columns == 1)
3843         {
3844                 // Strided load since we are loading a column from a row-major matrix.
3845                 if (templated_load)
3846                 {
3847                         auto scalar_type = type;
3848                         scalar_type.vecsize = 1;
3849                         scalar_type.columns = 1;
3850                         template_expr = join("<", type_to_glsl(scalar_type), ">");
3851                         if (type.vecsize > 1)
3852                                 load_expr += type_to_glsl(type) + "(";
3853                 }
3854                 else if (type.vecsize > 1)
3855                 {
3856                         load_expr = type_to_glsl(target_type);
3857                         load_expr += "(";
3858                 }
3859
3860                 for (uint32_t r = 0; r < type.vecsize; r++)
3861                 {
3862                         load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3863                                           chain.static_index + r * chain.matrix_stride, ")");
3864                         if (r + 1 < type.vecsize)
3865                                 load_expr += ", ";
3866                 }
3867
3868                 if (type.vecsize > 1)
3869                         load_expr += ")";
3870         }
3871         else if (!chain.row_major_matrix)
3872         {
3873                 // Load a matrix, column-major, the easy case.
3874                 const char *load_op = nullptr;
3875                 switch (type.vecsize)
3876                 {
3877                 case 1:
3878                         load_op = "Load";
3879                         break;
3880                 case 2:
3881                         load_op = "Load2";
3882                         break;
3883                 case 3:
3884                         load_op = "Load3";
3885                         break;
3886                 case 4:
3887                         load_op = "Load4";
3888                         break;
3889                 default:
3890                         SPIRV_CROSS_THROW("Unknown vector size.");
3891                 }
3892
3893                 if (templated_load)
3894                 {
3895                         auto vector_type = type;
3896                         vector_type.columns = 1;
3897                         template_expr = join("<", type_to_glsl(vector_type), ">");
3898                         load_expr = type_to_glsl(type);
3899                         load_op = "Load";
3900                 }
3901                 else
3902                 {
3903                         // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
3904                         // so row-major is technically column-major ...
3905                         load_expr = type_to_glsl(target_type);
3906                 }
3907                 load_expr += "(";
3908
3909                 for (uint32_t c = 0; c < type.columns; c++)
3910                 {
3911                         load_expr += join(base, ".", load_op, template_expr, "(", chain.dynamic_index,
3912                                           chain.static_index + c * chain.matrix_stride, ")");
3913                         if (c + 1 < type.columns)
3914                                 load_expr += ", ";
3915                 }
3916                 load_expr += ")";
3917         }
3918         else
3919         {
3920                 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern,
3921                 // considering HLSL uses "row-major" declarations but "column-major" memory layout (basically an implicit transpose model, ugh) ...
3922
3923                 if (templated_load)
3924                 {
3925                         load_expr = type_to_glsl(type);
3926                         auto scalar_type = type;
3927                         scalar_type.vecsize = 1;
3928                         scalar_type.columns = 1;
3929                         template_expr = join("<", type_to_glsl(scalar_type), ">");
3930                 }
3931                 else
3932                         load_expr = type_to_glsl(target_type);
3933
3934                 load_expr += "(";
3935
3936                 for (uint32_t c = 0; c < type.columns; c++)
3937                 {
3938                         for (uint32_t r = 0; r < type.vecsize; r++)
3939                         {
3940                                 load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3941                                                   chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
3942
3943                                 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
3944                                         load_expr += ", ";
3945                         }
3946                 }
3947                 load_expr += ")";
3948         }
3949
3950         if (!templated_load)
3951         {
3952                 auto bitcast_op = bitcast_glsl_op(type, target_type);
3953                 if (!bitcast_op.empty())
3954                         load_expr = join(bitcast_op, "(", load_expr, ")");
3955         }
3956
3957         if (lhs.empty())
3958         {
3959                 assert(expr);
3960                 *expr = std::move(load_expr);
3961         }
3962         else
3963                 statement(lhs, " = ", load_expr, ";");
3964 }
3965
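// An OpLoad through a SPIRAccessChain means we are reading from a
// ByteAddressBuffer; any other load falls through to the common GLSL path.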
3966 void CompilerHLSL::emit_load(const Instruction &instruction)
3967 {
3968         auto ops = stream(instruction);
3969
3970         auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
3971         if (chain)
3972         {
3973                 uint32_t result_type = ops[0];
3974                 uint32_t id = ops[1];
3975                 uint32_t ptr = ops[2];
3976
3977                 auto &type = get<SPIRType>(result_type);
3978                 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
3979
3980                 if (composite_load)
3981                 {
3982                         // We cannot make this work in one single expression as we might have nested structures and arrays,
3983                         // so unroll the load to an uninitialized temporary.
3984                         emit_uninitialized_temporary_expression(result_type, id);
3985                         read_access_chain(nullptr, to_expression(id), *chain);
3986                         track_expression_read(chain->self);
3987                 }
3988                 else
3989                 {
3990                         string load_expr;
3991                         read_access_chain(&load_expr, "", *chain);
3992
3993                         bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
3994
3995                         // If we are forwarding this load,
3996                         // don't register the read to access chain here, defer that to when we actually use the expression,
3997                         // using the add_implied_read_expression mechanism.
3998                         if (!forward)
3999                                 track_expression_read(chain->self);
4000
4001                         // Do not forward complex load sequences like matrices, structs and arrays.
4002                         if (type.columns > 1)
4003                                 forward = false;
4004
4005                         auto &e = emit_op(result_type, id, load_expr, forward, true);
4006                         e.need_transpose = false;
4007                         register_read(id, ptr, forward);
4008                         inherit_expression_dependencies(id, ptr);
4009                         if (forward)
4010                                 add_implied_read_expression(e, chain->self);
4011                 }
4012         }
4013         else
4014                 CompilerGLSL::emit_instruction(instruction);
4015 }
4016
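// Mirror of read_access_chain_array() for stores. The generated loop counter must
// also be usable on the value side to extract per-element data, so it is
// registered as a real SPIRExpression; write_access_chain_value() then references
// it through an ID with the MSB set (ACCESS_CHAIN_LITERAL_MSB_FORCE_ID).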
4017 void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4018                                             const SmallVector<uint32_t> &composite_chain)
4019 {
4020         auto &type = get<SPIRType>(chain.basetype);
4021
4022         // Need to use a reserved identifier here so it cannot shadow an identifier in the access chain input or other loops.
4023         auto ident = get_unique_identifier();
4024
4025         uint32_t id = ir.increase_bound_by(2);
4026         uint32_t int_type_id = id + 1;
4027         SPIRType int_type;
4028         int_type.basetype = SPIRType::Int;
4029         int_type.width = 32;
4030         set<SPIRType>(int_type_id, int_type);
4031         set<SPIRExpression>(id, ident, int_type_id, true);
4032         set_name(id, ident);
4033         suppressed_usage_tracking.insert(id);
4034
4035         statement("[unroll]");
4036         statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
4037                   ident, "++)");
4038         begin_scope();
4039         auto subchain = chain;
4040         subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
4041         subchain.basetype = type.parent_type;
4042
4043         // Forcefully allow us to use an ID here by setting the MSB.
4044         auto subcomposite_chain = composite_chain;
4045         subcomposite_chain.push_back(0x80000000u | id);
4046
4047         if (!get<SPIRType>(subchain.basetype).array.empty())
4048                 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
4049
4050         write_access_chain(subchain, value, subcomposite_chain);
4051         end_scope();
4052 }
4053
4054 void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4055                                              const SmallVector<uint32_t> &composite_chain)
4056 {
4057         auto &type = get<SPIRType>(chain.basetype);
4058         uint32_t member_count = uint32_t(type.member_types.size());
4059         auto subchain = chain;
4060
4061         auto subcomposite_chain = composite_chain;
4062         subcomposite_chain.push_back(0);
4063
4064         for (uint32_t i = 0; i < member_count; i++)
4065         {
4066                 uint32_t offset = type_struct_member_offset(type, i);
4067                 subchain.static_index = chain.static_index + offset;
4068                 subchain.basetype = type.member_types[i];
4069
4070                 subchain.matrix_stride = 0;
4071                 subchain.array_stride = 0;
4072                 subchain.row_major_matrix = false;
4073
4074                 auto &member_type = get<SPIRType>(subchain.basetype);
4075                 if (member_type.columns > 1)
4076                 {
4077                         subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
4078                         subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
4079                 }
4080
4081                 if (!member_type.array.empty())
4082                         subchain.array_stride = type_struct_member_array_stride(type, i);
4083
4084                 subcomposite_chain.back() = i;
4085                 write_access_chain(subchain, value, subcomposite_chain);
4086         }
4087 }
4088
4089 string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4090                                               bool enclose)
4091 {
4092         string ret;
4093         if (composite_chain.empty())
4094                 ret = to_expression(value);
4095         else
4096         {
4097                 AccessChainMeta meta;
4098                 ret = access_chain_internal(value, composite_chain.data(), uint32_t(composite_chain.size()),
4099                                             ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, &meta);
4100         }
4101
4102         if (enclose)
4103                 ret = enclose_expression(ret);
4104         return ret;
4105 }
4106
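// Store-side counterpart of read_access_chain(). Illustrative output:
//   pre-SM 6.2:  _buf.Store2(offset, asuint(value));
//   SM 6.2+:     _buf.Store<float2>(offset, value);
// As with loads, row-major matrix data is scattered with strided scalar Store
// calls, one element at a time.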
4107 void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4108                                       const SmallVector<uint32_t> &composite_chain)
4109 {
4110         auto &type = get<SPIRType>(chain.basetype);
4111
4112         // Make sure we trigger a read of the constituents in the access chain.
4113         track_expression_read(chain.self);
4114
4115         SPIRType target_type;
4116         target_type.basetype = SPIRType::UInt;
4117         target_type.vecsize = type.vecsize;
4118         target_type.columns = type.columns;
4119
4120         if (!type.array.empty())
4121         {
4122                 write_access_chain_array(chain, value, composite_chain);
4123                 register_write(chain.self);
4124                 return;
4125         }
4126         else if (type.basetype == SPIRType::Struct)
4127         {
4128                 write_access_chain_struct(chain, value, composite_chain);
4129                 register_write(chain.self);
4130                 return;
4131         }
4132         else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4133                 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4134                                   "native 16-bit types are enabled.");
4135
4136         bool templated_store = hlsl_options.shader_model >= 62;
4137
4138         auto base = chain.base;
4139         if (has_decoration(chain.self, DecorationNonUniform))
4140                 convert_non_uniform_expression(base, chain.self);
4141
4142         string template_expr;
4143         if (templated_store)
4144                 template_expr = join("<", type_to_glsl(type), ">");
4145
4146         if (type.columns == 1 && !chain.row_major_matrix)
4147         {
4148                 const char *store_op = nullptr;
4149                 switch (type.vecsize)
4150                 {
4151                 case 1:
4152                         store_op = "Store";
4153                         break;
4154                 case 2:
4155                         store_op = "Store2";
4156                         break;
4157                 case 3:
4158                         store_op = "Store3";
4159                         break;
4160                 case 4:
4161                         store_op = "Store4";
4162                         break;
4163                 default:
4164                         SPIRV_CROSS_THROW("Unknown vector size.");
4165                 }
4166
4167                 auto store_expr = write_access_chain_value(value, composite_chain, false);
4168
4169                 if (!templated_store)
4170                 {
4171                         auto bitcast_op = bitcast_glsl_op(target_type, type);
4172                         if (!bitcast_op.empty())
4173                                 store_expr = join(bitcast_op, "(", store_expr, ")");
4174                 }
4175                 else
4176                         store_op = "Store";
4177                 statement(base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ",
4178                           store_expr, ");");
4179         }
4180         else if (type.columns == 1)
4181         {
4182                 if (templated_store)
4183                 {
4184                         auto scalar_type = type;
4185                         scalar_type.vecsize = 1;
4186                         scalar_type.columns = 1;
4187                         template_expr = join("<", type_to_glsl(scalar_type), ">");
4188                 }
4189
4190                 // Strided store.
4191                 for (uint32_t r = 0; r < type.vecsize; r++)
4192                 {
4193                         auto store_expr = write_access_chain_value(value, composite_chain, true);
4194                         if (type.vecsize > 1)
4195                         {
4196                                 store_expr += ".";
4197                                 store_expr += index_to_swizzle(r);
4198                         }
4199                         remove_duplicate_swizzle(store_expr);
4200
4201                         if (!templated_store)
4202                         {
4203                                 auto bitcast_op = bitcast_glsl_op(target_type, type);
4204                                 if (!bitcast_op.empty())
4205                                         store_expr = join(bitcast_op, "(", store_expr, ")");
4206                         }
4207
4208                         statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4209                                   chain.static_index + chain.matrix_stride * r, ", ", store_expr, ");");
4210                 }
4211         }
4212         else if (!chain.row_major_matrix)
4213         {
4214                 const char *store_op = nullptr;
4215                 switch (type.vecsize)
4216                 {
4217                 case 1:
4218                         store_op = "Store";
4219                         break;
4220                 case 2:
4221                         store_op = "Store2";
4222                         break;
4223                 case 3:
4224                         store_op = "Store3";
4225                         break;
4226                 case 4:
4227                         store_op = "Store4";
4228                         break;
4229                 default:
4230                         SPIRV_CROSS_THROW("Unknown vector size.");
4231                 }
4232
4233                 if (templated_store)
4234                 {
4235                         store_op = "Store";
4236                         auto vector_type = type;
4237                         vector_type.columns = 1;
4238                         template_expr = join("<", type_to_glsl(vector_type), ">");
4239                 }
4240
4241                 for (uint32_t c = 0; c < type.columns; c++)
4242                 {
4243                         auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
4244
4245                         if (!templated_store)
4246                         {
4247                                 auto bitcast_op = bitcast_glsl_op(target_type, type);
4248                                 if (!bitcast_op.empty())
4249                                         store_expr = join(bitcast_op, "(", store_expr, ")");
4250                         }
4251
4252                         statement(base, ".", store_op, template_expr, "(", chain.dynamic_index,
4253                                   chain.static_index + c * chain.matrix_stride, ", ", store_expr, ");");
4254                 }
4255         }
4256         else
4257         {
4258                 if (templated_store)
4259                 {
4260                         auto scalar_type = type;
4261                         scalar_type.vecsize = 1;
4262                         scalar_type.columns = 1;
4263                         template_expr = join("<", type_to_glsl(scalar_type), ">");
4264                 }
4265
4266                 for (uint32_t r = 0; r < type.vecsize; r++)
4267                 {
4268                         for (uint32_t c = 0; c < type.columns; c++)
4269                         {
4270                                 auto store_expr =
4271                                     join(write_access_chain_value(value, composite_chain, true), "[", c, "].", index_to_swizzle(r));
4272                                 remove_duplicate_swizzle(store_expr);
4273                                 auto bitcast_op = bitcast_glsl_op(target_type, type);
4274                                 if (!bitcast_op.empty())
4275                                         store_expr = join(bitcast_op, "(", store_expr, ")");
4276                                 statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4277                                           chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
4278                         }
4279                 }
4280         }
4281
4282         register_write(chain.self);
4283 }
4284
4285 void CompilerHLSL::emit_store(const Instruction &instruction)
4286 {
4287         auto ops = stream(instruction);
4288         auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4289         if (chain)
4290                 write_access_chain(*chain, ops[1], {});
4291         else
4292                 CompilerGLSL::emit_instruction(instruction);
4293 }
4294
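// Decides whether an access chain must become a SPIRAccessChain doing explicit
// byte-address arithmetic: either it extends an existing SPIRAccessChain, or it
// indexes into an SSBO deeper than any outer array of blocks (i.e. it starts
// addressing members). Anything else stays a plain GLSL-style expression chain.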
4295 void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4296 {
4297         auto ops = stream(instruction);
4298         uint32_t length = instruction.length;
4299
4300         bool need_byte_access_chain = false;
4301         auto &type = expression_type(ops[2]);
4302         const auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4303
4304         if (chain)
4305         {
4306                 // Keep tacking on an existing access chain.
4307                 need_byte_access_chain = true;
4308         }
4309         else if (type.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock))
4310         {
4311                 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4312                 // to emit SPIRAccessChain rather than a plain SPIRExpression.
4313                 uint32_t chain_arguments = length - 3;
4314                 if (chain_arguments > type.array.size())
4315                         need_byte_access_chain = true;
4316         }
4317
4318         if (need_byte_access_chain)
4319         {
4320                 // If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4321                 // and not an array of SSBOs.
4322                 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4323
4324                 auto *backing_variable = maybe_get_backing_variable(ops[2]);
4325
4326                 string base;
4327                 if (to_plain_buffer_length != 0)
4328                         base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get<SPIRType>(ops[0]));
4329                 else if (chain)
4330                         base = chain->base;
4331                 else
4332                         base = to_expression(ops[2]);
4333
4334                 // Start traversing type hierarchy at the proper non-pointer types.
4335                 auto *basetype = &get_pointee_type(type);
4336
4337                 // Traverse the type hierarchy down to the actual buffer types.
4338                 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4339                 {
4340                         assert(basetype->parent_type);
4341                         basetype = &get<SPIRType>(basetype->parent_type);
4342                 }
4343
4344                 uint32_t matrix_stride = 0;
4345                 uint32_t array_stride = 0;
4346                 bool row_major_matrix = false;
4347
4348                 // Inherit matrix information.
4349                 if (chain)
4350                 {
4351                         matrix_stride = chain->matrix_stride;
4352                         row_major_matrix = chain->row_major_matrix;
4353                         array_stride = chain->array_stride;
4354                 }
4355
4356                 auto offsets = flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
4357                                                              length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix,
4358                                                              &matrix_stride, &array_stride);
4359
4360                 auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
4361                 e.row_major_matrix = row_major_matrix;
4362                 e.matrix_stride = matrix_stride;
4363                 e.array_stride = array_stride;
4364                 e.immutable = should_forward(ops[2]);
4365                 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4366
4367                 if (chain)
4368                 {
4369                         e.dynamic_index += chain->dynamic_index;
4370                         e.static_index += chain->static_index;
4371                 }
4372
4373                 for (uint32_t i = 2; i < length; i++)
4374                 {
4375                         inherit_expression_dependencies(ops[1], ops[i]);
4376                         add_implied_read_expression(e, ops[i]);
4377                 }
4378         }
4379         else
4380         {
4381                 CompilerGLSL::emit_instruction(instruction);
4382         }
4383 }
4384
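// Maps SPIR-V atomics onto HLSL Interlocked* intrinsics. Ops with no direct
// equivalent are emulated (illustrative):
//   OpAtomicLoad        -> InterlockedAdd(dest, 0, result)
//   OpAtomicIDecrement  -> InterlockedAdd(dest, -1, result)
//   OpAtomicISub        -> InterlockedAdd(dest, -value, result)
//   OpAtomicStore       -> InterlockedExchange(dest, value, dummy)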
4385 void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4386 {
4387         const char *atomic_op = nullptr;
4388
4389         string value_expr;
4390         if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4391                 value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
4392
4393         bool is_atomic_store = false;
4394
4395         switch (op)
4396         {
4397         case OpAtomicIIncrement:
4398                 atomic_op = "InterlockedAdd";
4399                 value_expr = "1";
4400                 break;
4401
4402         case OpAtomicIDecrement:
4403                 atomic_op = "InterlockedAdd";
4404                 value_expr = "-1";
4405                 break;
4406
4407         case OpAtomicLoad:
4408                 atomic_op = "InterlockedAdd";
4409                 value_expr = "0";
4410                 break;
4411
4412         case OpAtomicISub:
4413                 atomic_op = "InterlockedAdd";
4414                 value_expr = join("-", enclose_expression(value_expr));
4415                 break;
4416
4417         case OpAtomicSMin:
4418         case OpAtomicUMin:
4419                 atomic_op = "InterlockedMin";
4420                 break;
4421
4422         case OpAtomicSMax:
4423         case OpAtomicUMax:
4424                 atomic_op = "InterlockedMax";
4425                 break;
4426
4427         case OpAtomicAnd:
4428                 atomic_op = "InterlockedAnd";
4429                 break;
4430
4431         case OpAtomicOr:
4432                 atomic_op = "InterlockedOr";
4433                 break;
4434
4435         case OpAtomicXor:
4436                 atomic_op = "InterlockedXor";
4437                 break;
4438
4439         case OpAtomicIAdd:
4440                 atomic_op = "InterlockedAdd";
4441                 break;
4442
4443         case OpAtomicExchange:
4444                 atomic_op = "InterlockedExchange";
4445                 break;
4446
4447         case OpAtomicStore:
4448                 atomic_op = "InterlockedExchange";
4449                 is_atomic_store = true;
4450                 break;
4451
4452         case OpAtomicCompareExchange:
4453                 if (length < 8)
4454                         SPIRV_CROSS_THROW("Not enough data for opcode.");
4455                 atomic_op = "InterlockedCompareExchange";
4456                 value_expr = join(to_expression(ops[7]), ", ", value_expr);
4457                 break;
4458
4459         default:
4460                 SPIRV_CROSS_THROW("Unknown atomic opcode.");
4461         }
4462
4463         if (is_atomic_store)
4464         {
4465                 auto &data_type = expression_type(ops[0]);
4466                 auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4467
4468                 auto &tmp_id = extra_sub_expressions[ops[0]];
4469                 if (!tmp_id)
4470                 {
4471                         tmp_id = ir.increase_bound_by(1);
4472                         emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
4473                 }
4474
4475                 if (data_type.storage == StorageClassImage || !chain)
4476                 {
4477                         statement(atomic_op, "(", to_non_uniform_aware_expression(ops[0]), ", ",
4478                                   to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4479                 }
4480                 else
4481                 {
4482                         string base = chain->base;
4483                         if (has_decoration(chain->self, DecorationNonUniform))
4484                                 convert_non_uniform_expression(base, chain->self);
4485                         // RWByteAddressBuffer is always uint in its underlying type.
4486                         statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ",
4487                                   to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4488                 }
4489         }
4490         else
4491         {
4492                 uint32_t result_type = ops[0];
4493                 uint32_t id = ops[1];
4494                 forced_temporaries.insert(ops[1]);
4495
4496                 auto &type = get<SPIRType>(result_type);
4497                 statement(variable_decl(type, to_name(id)), ";");
4498
4499                 auto &data_type = expression_type(ops[2]);
4500                 auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4501                 SPIRType::BaseType expr_type;
4502                 if (data_type.storage == StorageClassImage || !chain)
4503                 {
4504                         statement(atomic_op, "(", to_non_uniform_aware_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
4505                         expr_type = data_type.basetype;
4506                 }
4507                 else
4508                 {
4509                         // RWByteAddressBuffer is always uint in its underlying type.
4510                         string base = chain->base;
4511                         if (has_decoration(chain->self, DecorationNonUniform))
4512                                 convert_non_uniform_expression(base, chain->self);
4513                         expr_type = SPIRType::UInt;
4514                         statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
4515                                   ", ", to_name(id), ");");
4516                 }
4517
4518                 auto expr = bitcast_expression(type, expr_type, to_name(id));
4519                 set<SPIRExpression>(id, expr, result_type, true);
4520         }
4521         flush_all_atomic_capable_variables();
4522 }
4523
4524 void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4525 {
4526         if (hlsl_options.shader_model < 60)
4527                 SPIRV_CROSS_THROW("Wave ops require SM 6.0 or higher.");
4528
4529         const uint32_t *ops = stream(i);
4530         auto op = static_cast<Op>(i.op);
4531
4532         uint32_t result_type = ops[0];
4533         uint32_t id = ops[1];
4534
4535         auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
4536         if (scope != ScopeSubgroup)
4537                 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4538
4539         const auto make_inclusive_Sum = [&](const string &expr) -> string {
4540                 return join(expr, " + ", to_expression(ops[4]));
4541         };
4542
4543         const auto make_inclusive_Product = [&](const string &expr) -> string {
4544                 return join(expr, " * ", to_expression(ops[4]));
4545         };
4546
4547         // If we need to do implicit bitcasts, make sure we do them with the correct type.
4548         uint32_t integer_width = get_integer_width_for_instruction(i);
4549         auto int_type = to_signed_basetype(integer_width);
4550         auto uint_type = to_unsigned_basetype(integer_width);
4551
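        // HLSL only exposes exclusive prefix ops (WavePrefixSum/WavePrefixProduct), so
        // an inclusive scan is formed by adding back the lane's own contribution:
        //   inclusive_sum(x) = WavePrefixSum(x) + x
        // The empty defines below exist only so HLSL_GROUP_OP expands for ops where
        // supports_scan is false; those scan branches can never be taken at runtime.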
4552 #define make_inclusive_BitAnd(expr) ""
4553 #define make_inclusive_BitOr(expr) ""
4554 #define make_inclusive_BitXor(expr) ""
4555 #define make_inclusive_Min(expr) ""
4556 #define make_inclusive_Max(expr) ""
4557
4558         switch (op)
4559         {
4560         case OpGroupNonUniformElect:
4561                 emit_op(result_type, id, "WaveIsFirstLane()", true);
4562                 break;
4563
4564         case OpGroupNonUniformBroadcast:
4565                 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4566                 break;
4567
4568         case OpGroupNonUniformBroadcastFirst:
4569                 emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
4570                 break;
4571
4572         case OpGroupNonUniformBallot:
4573                 emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
4574                 break;
4575
4576         case OpGroupNonUniformInverseBallot:
4577                 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4578                 break;
4579
4580         case OpGroupNonUniformBallotBitExtract:
4581                 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4582                 break;
4583
4584         case OpGroupNonUniformBallotFindLSB:
4585                 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4586                 break;
4587
4588         case OpGroupNonUniformBallotFindMSB:
4589                 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4590                 break;
4591
4592         case OpGroupNonUniformBallotBitCount:
4593         {
4594                 auto operation = static_cast<GroupOperation>(ops[3]);
4595                 if (operation == GroupOperationReduce)
4596                 {
4597                         bool forward = should_forward(ops[4]);
4598                         auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
4599                                          to_enclosed_expression(ops[4]), ".y)");
4600                         auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
4601                                           to_enclosed_expression(ops[4]), ".w)");
4602                         emit_op(result_type, id, join(left, " + ", right), forward);
4603                         inherit_expression_dependencies(id, ops[4]);
4604                 }
4605                 else if (operation == GroupOperationInclusiveScan)
4606                         SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4607                 else if (operation == GroupOperationExclusiveScan)
4608                         SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4609                 else
4610                         SPIRV_CROSS_THROW("Invalid BitCount operation.");
4611                 break;
4612         }
4613
4614         case OpGroupNonUniformShuffle:
4615                 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4616                 break;
4617         case OpGroupNonUniformShuffleXor:
4618         {
4619                 bool forward = should_forward(ops[3]);
4620                 emit_op(ops[0], ops[1],
4621                         join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4622                              "WaveGetLaneIndex() ^ ", to_enclosed_expression(ops[4]), ")"), forward);
4623                 inherit_expression_dependencies(ops[1], ops[3]);
4624                 break;
4625         }
4626         case OpGroupNonUniformShuffleUp:
4627         {
4628                 bool forward = should_forward(ops[3]);
4629                 emit_op(ops[0], ops[1],
4630                         join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4631                              "WaveGetLaneIndex() - ", to_enclosed_expression(ops[4]), ")"), forward);
4632                 inherit_expression_dependencies(ops[1], ops[3]);
4633                 break;
4634         }
4635         case OpGroupNonUniformShuffleDown:
4636         {
4637                 bool forward = should_forward(ops[3]);
4638                 emit_op(ops[0], ops[1],
4639                         join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4640                              "WaveGetLaneIndex() + ", to_enclosed_expression(ops[4]), ")"), forward);
4641                 inherit_expression_dependencies(ops[1], ops[3]);
4642                 break;
4643         }
4644
4645         case OpGroupNonUniformAll:
4646                 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
4647                 break;
4648
4649         case OpGroupNonUniformAny:
4650                 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
4651                 break;
4652
4653         case OpGroupNonUniformAllEqual:
4654                 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
4655                 break;
4656
4657         // clang-format off
4658 #define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4659 case OpGroupNonUniform##op: \
4660         { \
4661                 auto operation = static_cast<GroupOperation>(ops[3]); \
4662                 if (operation == GroupOperationReduce) \
4663                         emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4664                 else if (operation == GroupOperationInclusiveScan && supports_scan) \
4665         { \
4666                         bool forward = should_forward(ops[4]); \
4667                         emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4668                         inherit_expression_dependencies(id, ops[4]); \
4669         } \
4670                 else if (operation == GroupOperationExclusiveScan && supports_scan) \
4671                         emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4672                 else if (operation == GroupOperationClusteredReduce) \
4673                         SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4674                 else \
4675                         SPIRV_CROSS_THROW("Invalid group operation."); \
4676                 break; \
4677         }
4678
4679 #define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4680 case OpGroupNonUniform##op: \
4681         { \
4682                 auto operation = static_cast<GroupOperation>(ops[3]); \
4683                 if (operation == GroupOperationReduce) \
4684                         emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4685                 else \
4686                         SPIRV_CROSS_THROW("Invalid group operation."); \
4687                 break; \
4688         }
4689
4690         HLSL_GROUP_OP(FAdd, Sum, true)
4691         HLSL_GROUP_OP(FMul, Product, true)
4692         HLSL_GROUP_OP(FMin, Min, false)
4693         HLSL_GROUP_OP(FMax, Max, false)
4694         HLSL_GROUP_OP(IAdd, Sum, true)
4695         HLSL_GROUP_OP(IMul, Product, true)
4696         HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4697         HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4698         HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4699         HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4700         HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4701         HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4702         HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4703         HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
4704         HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
4705         HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
4706
4707 #undef HLSL_GROUP_OP
4708 #undef HLSL_GROUP_OP_CAST
4709                 // clang-format on
4710
4711         case OpGroupNonUniformQuadSwap:
4712         {
4713                 uint32_t direction = evaluate_constant_u32(ops[4]);
4714                 if (direction == 0)
4715                         emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
4716                 else if (direction == 1)
4717                         emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
4718                 else if (direction == 2)
4719                         emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
4720                 else
4721                         SPIRV_CROSS_THROW("Invalid quad swap direction.");
4722                 break;
4723         }
4724
4725         case OpGroupNonUniformQuadBroadcast:
4726         {
4727                 emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
4728                 break;
4729         }
4730
4731         default:
4732                 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4733         }
4734
4735         register_control_dependent_expression(id);
4736 }
4737
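// Main instruction dispatcher. The HLSL_* convenience macros below forward to the
// shared CompilerGLSL emitters with an HLSL intrinsic name; only opcodes whose
// semantics differ from GLSL are special-cased here, and everything else falls
// back to CompilerGLSL::emit_instruction().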
4738 void CompilerHLSL::emit_instruction(const Instruction &instruction)
4739 {
4740         auto ops = stream(instruction);
4741         auto opcode = static_cast<Op>(instruction.op);
4742
4743 #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4744 #define HLSL_BOP_CAST(op, type) \
4745         emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4746 #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4747 #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4748 #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4749 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4750 #define HLSL_BFOP_CAST(op, type) \
4751         emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4753 #define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
4754
4755         // If we need to do implicit bitcasts, make sure we do them with the correct type.
4756         uint32_t integer_width = get_integer_width_for_instruction(instruction);
4757         auto int_type = to_signed_basetype(integer_width);
4758         auto uint_type = to_unsigned_basetype(integer_width);
4759
4760         switch (opcode)
4761         {
4762         case OpAccessChain:
4763         case OpInBoundsAccessChain:
4764         {
4765                 emit_access_chain(instruction);
4766                 break;
4767         }
4768         case OpBitcast:
4769         {
4770                 auto bitcast_type = get_bitcast_type(ops[0], ops[2]);
4771                 if (bitcast_type == CompilerHLSL::TypeNormal)
4772                         CompilerGLSL::emit_instruction(instruction);
4773                 else
4774                 {
4775                         if (!requires_uint2_packing)
4776                         {
4777                                 requires_uint2_packing = true;
4778                                 force_recompile();
4779                         }
4780
4781                         if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4782                                 emit_unary_func_op(ops[0], ops[1], ops[2], "spvPackUint2x32");
4783                         else
4784                                 emit_unary_func_op(ops[0], ops[1], ops[2], "spvUnpackUint2x32");
4785                 }
4786
4787                 break;
4788         }
4789
4790         case OpStore:
4791         {
4792                 emit_store(instruction);
4793                 break;
4794         }
4795
4796         case OpLoad:
4797         {
4798                 emit_load(instruction);
4799                 break;
4800         }
4801
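        // This backend keeps all matrices in a transposed state, so SPIR-V's M * v
        // turns into mul(v, M) in HLSL (equivalent to transpose(M) * v), and likewise
        // for the other matrix products below. Illustrative:
        //   %r = OpMatrixTimesVector %m %v   ->   mul(v, m)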
4802 	case OpMatrixTimesVector:
4803 	case OpVectorTimesMatrix:
4804 	case OpMatrixTimesMatrix:
4805 	{
4806 		// Matrices are kept in a transposed state all the time; always flip the multiplication order.
4807 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4808 		break;
4809 	}
4822
4823         case OpOuterProduct:
4824         {
4825                 uint32_t result_type = ops[0];
4826                 uint32_t id = ops[1];
4827                 uint32_t a = ops[2];
4828                 uint32_t b = ops[3];
4829
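		// HLSL has no dedicated outer-product intrinsic; build the result column by column,
		// where column c of outerProduct(a, b) is a * b[c].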
4830                 auto &type = get<SPIRType>(result_type);
4831                 string expr = type_to_glsl_constructor(type);
4832                 expr += "(";
4833                 for (uint32_t col = 0; col < type.columns; col++)
4834                 {
4835                         expr += to_enclosed_expression(a);
4836                         expr += " * ";
4837                         expr += to_extract_component_expression(b, col);
4838                         if (col + 1 < type.columns)
4839                                 expr += ", ";
4840                 }
4841                 expr += ")";
4842                 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4843                 inherit_expression_dependencies(id, a);
4844                 inherit_expression_dependencies(id, b);
4845                 break;
4846         }
4847
4848         case OpFMod:
4849         {
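		// SPIR-V OpFMod takes its sign from the divisor (floored remainder), while HLSL's fmod()
		// truncates toward zero, so a GLSL-style mod() helper is emitted instead.
		// OpFRem below matches fmod() directly.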
4850                 if (!requires_op_fmod)
4851                 {
4852                         requires_op_fmod = true;
4853                         force_recompile();
4854                 }
4855                 CompilerGLSL::emit_instruction(instruction);
4856                 break;
4857         }
4858
4859         case OpFRem:
4860                 emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], "fmod");
4861                 break;
4862
4863         case OpImage:
4864         {
4865                 uint32_t result_type = ops[0];
4866                 uint32_t id = ops[1];
4867                 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4868
4869                 if (combined)
4870                 {
4871                         auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4872                         auto *var = maybe_get_backing_variable(combined->image);
4873                         if (var)
4874                                 e.loaded_from = var->self;
4875                 }
4876                 else
4877                 {
4878                         auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4879                         auto *var = maybe_get_backing_variable(ops[2]);
4880                         if (var)
4881                                 e.loaded_from = var->self;
4882                 }
4883                 break;
4884         }
4885
4886         case OpDPdx:
4887                 HLSL_UFOP(ddx);
4888                 register_control_dependent_expression(ops[1]);
4889                 break;
4890
4891         case OpDPdy:
4892                 HLSL_UFOP(ddy);
4893                 register_control_dependent_expression(ops[1]);
4894                 break;
4895
4896         case OpDPdxFine:
4897                 HLSL_UFOP(ddx_fine);
4898                 register_control_dependent_expression(ops[1]);
4899                 break;
4900
4901         case OpDPdyFine:
4902                 HLSL_UFOP(ddy_fine);
4903                 register_control_dependent_expression(ops[1]);
4904                 break;
4905
4906         case OpDPdxCoarse:
4907                 HLSL_UFOP(ddx_coarse);
4908                 register_control_dependent_expression(ops[1]);
4909                 break;
4910
4911         case OpDPdyCoarse:
4912                 HLSL_UFOP(ddy_coarse);
4913                 register_control_dependent_expression(ops[1]);
4914                 break;
4915
4916         case OpFwidth:
4917         case OpFwidthCoarse:
4918         case OpFwidthFine:
4919                 HLSL_UFOP(fwidth);
4920                 register_control_dependent_expression(ops[1]);
4921                 break;
4922
4923         case OpLogicalNot:
4924         {
4925                 auto result_type = ops[0];
4926                 auto id = ops[1];
4927                 auto &type = get<SPIRType>(result_type);
4928
4929                 if (type.vecsize > 1)
4930                         emit_unrolled_unary_op(result_type, id, ops[2], "!");
4931                 else
4932                         HLSL_UOP(!);
4933                 break;
4934         }
4935
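	// Comparisons on vectors are unrolled into per-component boolN(...) constructors via
	// emit_unrolled_binary_op(); scalars can use the plain operator directly.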
4936         case OpIEqual:
4937         {
4938                 auto result_type = ops[0];
4939                 auto id = ops[1];
4940
4941                 if (expression_type(ops[2]).vecsize > 1)
4942                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4943                 else
4944                         HLSL_BOP_CAST(==, int_type);
4945                 break;
4946         }
4947
4948         case OpLogicalEqual:
4949         case OpFOrdEqual:
4950         case OpFUnordEqual:
4951         {
4952                 // HLSL != operator is unordered.
4953                 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4954                 // isnan() is apparently implemented as x != x as well.
4955                 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
4956                 // HACK: FUnordEqual will be implemented as FOrdEqual.
4957
4958                 auto result_type = ops[0];
4959                 auto id = ops[1];
4960
4961                 if (expression_type(ops[2]).vecsize > 1)
4962                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4963                 else
4964                         HLSL_BOP(==);
4965                 break;
4966         }
4967
4968         case OpINotEqual:
4969         {
4970                 auto result_type = ops[0];
4971                 auto id = ops[1];
4972
4973                 if (expression_type(ops[2]).vecsize > 1)
4974                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
4975                 else
4976                         HLSL_BOP_CAST(!=, int_type);
4977                 break;
4978         }
4979
4980         case OpLogicalNotEqual:
4981         case OpFOrdNotEqual:
4982         case OpFUnordNotEqual:
4983         {
4984                 // HLSL != operator is unordered.
4985                 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4986                 // isnan() is apparently implemented as x != x as well.
4987
4988                 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
4989                 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
4990 		// Adding a lot of NaN checks would be a breaking change from a performance perspective.
4991                 // SPIR-V will generally use isnan() checks when this even matters.
4992                 // HACK: FOrdNotEqual will be implemented as FUnordEqual.
4993
4994                 auto result_type = ops[0];
4995                 auto id = ops[1];
4996
4997                 if (expression_type(ops[2]).vecsize > 1)
4998                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
4999                 else
5000                         HLSL_BOP(!=);
5001                 break;
5002         }
5003
5004         case OpUGreaterThan:
5005         case OpSGreaterThan:
5006         {
5007                 auto result_type = ops[0];
5008                 auto id = ops[1];
5009                 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5010
5011                 if (expression_type(ops[2]).vecsize > 1)
5012                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, type);
5013                 else
5014                         HLSL_BOP_CAST(>, type);
5015                 break;
5016         }
5017
5018         case OpFOrdGreaterThan:
5019         {
5020                 auto result_type = ops[0];
5021                 auto id = ops[1];
5022
5023                 if (expression_type(ops[2]).vecsize > 1)
5024                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, SPIRType::Unknown);
5025                 else
5026                         HLSL_BOP(>);
5027                 break;
5028         }
5029
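	// The unrolled unordered comparisons below are expressed as negated ordered comparisons:
	// any comparison involving NaN is false, so !(a <= b) is exactly FUnordGreaterThan(a, b).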
5030         case OpFUnordGreaterThan:
5031         {
5032                 auto result_type = ops[0];
5033                 auto id = ops[1];
5034
5035                 if (expression_type(ops[2]).vecsize > 1)
5036                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", true, SPIRType::Unknown);
5037                 else
5038                         CompilerGLSL::emit_instruction(instruction);
5039                 break;
5040         }
5041
5042         case OpUGreaterThanEqual:
5043         case OpSGreaterThanEqual:
5044         {
5045                 auto result_type = ops[0];
5046                 auto id = ops[1];
5047
5048                 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5049                 if (expression_type(ops[2]).vecsize > 1)
5050                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, type);
5051                 else
5052                         HLSL_BOP_CAST(>=, type);
5053                 break;
5054         }
5055
5056         case OpFOrdGreaterThanEqual:
5057         {
5058                 auto result_type = ops[0];
5059                 auto id = ops[1];
5060
5061                 if (expression_type(ops[2]).vecsize > 1)
5062                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, SPIRType::Unknown);
5063                 else
5064                         HLSL_BOP(>=);
5065                 break;
5066         }
5067
5068         case OpFUnordGreaterThanEqual:
5069         {
5070                 auto result_type = ops[0];
5071                 auto id = ops[1];
5072
5073                 if (expression_type(ops[2]).vecsize > 1)
5074                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", true, SPIRType::Unknown);
5075                 else
5076                         CompilerGLSL::emit_instruction(instruction);
5077                 break;
5078         }
5079
5080         case OpULessThan:
5081         case OpSLessThan:
5082         {
5083                 auto result_type = ops[0];
5084                 auto id = ops[1];
5085
5086                 auto type = opcode == OpULessThan ? uint_type : int_type;
5087                 if (expression_type(ops[2]).vecsize > 1)
5088                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, type);
5089                 else
5090                         HLSL_BOP_CAST(<, type);
5091                 break;
5092         }
5093
5094         case OpFOrdLessThan:
5095         {
5096                 auto result_type = ops[0];
5097                 auto id = ops[1];
5098
5099                 if (expression_type(ops[2]).vecsize > 1)
5100                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, SPIRType::Unknown);
5101                 else
5102                         HLSL_BOP(<);
5103                 break;
5104         }
5105
5106         case OpFUnordLessThan:
5107         {
5108                 auto result_type = ops[0];
5109                 auto id = ops[1];
5110
5111                 if (expression_type(ops[2]).vecsize > 1)
5112                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", true, SPIRType::Unknown);
5113                 else
5114                         CompilerGLSL::emit_instruction(instruction);
5115                 break;
5116         }
5117
5118         case OpULessThanEqual:
5119         case OpSLessThanEqual:
5120         {
5121                 auto result_type = ops[0];
5122                 auto id = ops[1];
5123
5124                 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5125                 if (expression_type(ops[2]).vecsize > 1)
5126                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, type);
5127                 else
5128                         HLSL_BOP_CAST(<=, type);
5129                 break;
5130         }
5131
5132         case OpFOrdLessThanEqual:
5133         {
5134                 auto result_type = ops[0];
5135                 auto id = ops[1];
5136
5137                 if (expression_type(ops[2]).vecsize > 1)
5138                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, SPIRType::Unknown);
5139                 else
5140                         HLSL_BOP(<=);
5141                 break;
5142         }
5143
5144         case OpFUnordLessThanEqual:
5145         {
5146                 auto result_type = ops[0];
5147                 auto id = ops[1];
5148
5149                 if (expression_type(ops[2]).vecsize > 1)
5150                         emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", true, SPIRType::Unknown);
5151                 else
5152                         CompilerGLSL::emit_instruction(instruction);
5153                 break;
5154         }
5155
5156         case OpImageQueryLod:
5157                 emit_texture_op(instruction, false);
5158                 break;
5159
5160         case OpImageQuerySizeLod:
5161         {
5162                 auto result_type = ops[0];
5163                 auto id = ops[1];
5164
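		// spvTextureSize() wraps GetDimensions(), which always takes an out-parameter for the
		// mip/sample count; a throwaway local absorbs it since only the size is needed here.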
5165                 require_texture_query_variant(ops[2]);
5166                 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5167                 statement("uint ", dummy_samples_levels, ";");
5168
5169                 auto expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", ",
5170                                  bitcast_expression(SPIRType::UInt, ops[3]), ", ", dummy_samples_levels, ")");
5171
5172                 auto &restype = get<SPIRType>(ops[0]);
5173                 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5174                 emit_op(result_type, id, expr, true);
5175                 break;
5176         }
5177
5178         case OpImageQuerySize:
5179         {
5180                 auto result_type = ops[0];
5181                 auto id = ops[1];
5182
5183                 require_texture_query_variant(ops[2]);
5184                 bool uav = expression_type(ops[2]).image.sampled == 2;
5185
5186                 if (const auto *var = maybe_get_backing_variable(ops[2]))
5187                         if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5188                                 uav = false;
5189
5190                 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5191                 statement("uint ", dummy_samples_levels, ";");
5192
5193                 string expr;
5194                 if (uav)
5195                         expr = join("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", dummy_samples_levels, ")");
5196                 else
5197                         expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
5198
5199                 auto &restype = get<SPIRType>(ops[0]);
5200                 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5201                 emit_op(result_type, id, expr, true);
5202                 break;
5203         }
5204
5205         case OpImageQuerySamples:
5206         case OpImageQueryLevels:
5207         {
5208                 auto result_type = ops[0];
5209                 auto id = ops[1];
5210
5211                 require_texture_query_variant(ops[2]);
5212                 bool uav = expression_type(ops[2]).image.sampled == 2;
5213                 if (opcode == OpImageQueryLevels && uav)
5214                         SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5215
5216                 if (const auto *var = maybe_get_backing_variable(ops[2]))
5217                         if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5218                                 uav = false;
5219
5220                 // Keep it simple and do not emit special variants to make this look nicer ...
5221                 // This stuff is barely, if ever, used.
5222                 forced_temporaries.insert(id);
5223                 auto &type = get<SPIRType>(result_type);
5224                 statement(variable_decl(type, to_name(id)), ";");
5225
5226                 if (uav)
5227                         statement("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", to_name(id), ");");
5228                 else
5229                         statement("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", to_name(id), ");");
5230
5231                 auto &restype = get<SPIRType>(ops[0]);
5232                 auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
5233                 set<SPIRExpression>(id, expr, result_type, true);
5234                 break;
5235         }
5236
5237         case OpImageRead:
5238         {
5239                 uint32_t result_type = ops[0];
5240                 uint32_t id = ops[1];
5241                 auto *var = maybe_get_backing_variable(ops[2]);
5242                 auto &type = expression_type(ops[2]);
5243                 bool subpass_data = type.image.dim == DimSubpassData;
5244                 bool pure = false;
5245
5246                 string imgexpr;
5247
5248                 if (subpass_data)
5249                 {
5250                         if (hlsl_options.shader_model < 40)
5251                                 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5252
5253                         // Similar to GLSL, implement subpass loads using texelFetch.
5254                         if (type.image.ms)
5255                         {
5256                                 uint32_t operands = ops[4];
5257                                 if (operands != ImageOperandsSampleMask || instruction.length != 6)
5258                                         SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5259                                 uint32_t sample = ops[5];
5260                                 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int2(gl_FragCoord.xy), ", to_expression(sample), ")");
5261                         }
5262                         else
5263                                 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int3(int2(gl_FragCoord.xy), 0))");
5264
5265                         pure = true;
5266                 }
5267                 else
5268                 {
5269                         imgexpr = join(to_non_uniform_aware_expression(ops[2]), "[", to_expression(ops[3]), "]");
5270 			// Unlike GLSL, where every image is effectively a "vec4" and the declared format only
5271 			// changes how the data is interpreted, the underlying image type in HLSL depends on the format.
5272
5273                         bool force_srv =
5274                             hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(var->self, DecorationNonWritable);
5275                         pure = force_srv;
5276
5277                         if (var && !subpass_data && !force_srv)
5278                                 imgexpr = remap_swizzle(get<SPIRType>(result_type),
5279                                                         image_format_to_components(get<SPIRType>(var->basetype).image.format), imgexpr);
5280                 }
5281
5282                 if (var && var->forwardable)
5283                 {
5284                         bool forward = forced_temporaries.find(id) == end(forced_temporaries);
5285                         auto &e = emit_op(result_type, id, imgexpr, forward);
5286
5287                         if (!pure)
5288                         {
5289                                 e.loaded_from = var->self;
5290                                 if (forward)
5291                                         var->dependees.push_back(id);
5292                         }
5293                 }
5294                 else
5295                         emit_op(result_type, id, imgexpr, false);
5296
5297                 inherit_expression_dependencies(id, ops[2]);
5298                 if (type.image.ms)
5299                         inherit_expression_dependencies(id, ops[5]);
5300                 break;
5301         }
5302
5303         case OpImageWrite:
5304         {
5305                 auto *var = maybe_get_backing_variable(ops[0]);
5306
5307 		// Unlike GLSL, where every image is effectively a "vec4" and the declared format only
5308 		// changes how the data is interpreted, the underlying image type in HLSL depends on the format.
5309                 auto value_expr = to_expression(ops[2]);
5310                 if (var)
5311                 {
5312                         auto &type = get<SPIRType>(var->basetype);
5313                         auto narrowed_type = get<SPIRType>(type.image.type);
5314                         narrowed_type.vecsize = image_format_to_components(type.image.format);
5315                         value_expr = remap_swizzle(narrowed_type, expression_type(ops[2]).vecsize, value_expr);
5316                 }
5317
5318                 statement(to_non_uniform_aware_expression(ops[0]), "[", to_expression(ops[1]), "] = ", value_expr, ";");
5319                 if (var && variable_storage_is_aliased(*var))
5320                         flush_all_aliased_variables();
5321                 break;
5322         }
5323
5324         case OpImageTexelPointer:
5325         {
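		// There is no pointer type to hand out in HLSL; emit the subscript expression itself
		// (image[coord]) and let the consuming atomic op use it in place.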
5326                 uint32_t result_type = ops[0];
5327                 uint32_t id = ops[1];
5328
5329                 auto expr = to_expression(ops[2]);
5330                 expr += join("[", to_expression(ops[3]), "]");
5331                 auto &e = set<SPIRExpression>(id, expr, result_type, true);
5332
5333                 // When using the pointer, we need to know which variable it is actually loaded from.
5334                 auto *var = maybe_get_backing_variable(ops[2]);
5335                 e.loaded_from = var ? var->self : ID(0);
5336                 inherit_expression_dependencies(id, ops[3]);
5337                 break;
5338         }
5339
5340         case OpAtomicCompareExchange:
5341         case OpAtomicExchange:
5342         case OpAtomicISub:
5343         case OpAtomicSMin:
5344         case OpAtomicUMin:
5345         case OpAtomicSMax:
5346         case OpAtomicUMax:
5347         case OpAtomicAnd:
5348         case OpAtomicOr:
5349         case OpAtomicXor:
5350         case OpAtomicIAdd:
5351         case OpAtomicIIncrement:
5352         case OpAtomicIDecrement:
5353         case OpAtomicLoad:
5354         case OpAtomicStore:
5355         {
5356                 emit_atomic(ops, instruction.length, opcode);
5357                 break;
5358         }
5359
5360         case OpControlBarrier:
5361         case OpMemoryBarrier:
5362         {
5363                 uint32_t memory;
5364                 uint32_t semantics;
5365
5366                 if (opcode == OpMemoryBarrier)
5367                 {
5368                         memory = evaluate_constant_u32(ops[0]);
5369                         semantics = evaluate_constant_u32(ops[1]);
5370                 }
5371                 else
5372                 {
5373                         memory = evaluate_constant_u32(ops[1]);
5374                         semantics = evaluate_constant_u32(ops[2]);
5375                 }
5376
5377                 if (memory == ScopeSubgroup)
5378                 {
5379                         // No Wave-barriers in HLSL.
5380                         break;
5381                 }
5382
5383 		// We only care about these flags; acquire/release and friends are not relevant to HLSL.
5384                 semantics = mask_relevant_memory_semantics(semantics);
5385
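		// For example, a memoryBarrierShared() immediately followed by barrier() folds away here,
		// since the GroupMemoryBarrierWithGroupSync() emitted for the control barrier covers it.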
5386                 if (opcode == OpMemoryBarrier)
5387                 {
5388 			// If we are a memory barrier and the next instruction is a control barrier, check whether
5389 			// that control barrier already covers what we need, so we can avoid a redundant barrier.
5390                         const Instruction *next = get_next_instruction_in_block(instruction);
5391                         if (next && next->op == OpControlBarrier)
5392                         {
5393                                 auto *next_ops = stream(*next);
5394                                 uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
5395                                 uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
5396                                 next_semantics = mask_relevant_memory_semantics(next_semantics);
5397
5398 				// There is no "pure execution barrier" in HLSL.
5399 				// If the next instruction carries no memory semantics, we imply that group shared memory is synced.
5400                                 if (next_semantics == 0)
5401                                         next_semantics = MemorySemanticsWorkgroupMemoryMask;
5402
5403                                 bool memory_scope_covered = false;
5404                                 if (next_memory == memory)
5405                                         memory_scope_covered = true;
5406                                 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5407                                 {
5408 					// If we only care about workgroup memory, either Device or Workgroup scope is fine;
5409 					// the scope does not have to match.
5410                                         if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5411                                             (memory == ScopeDevice || memory == ScopeWorkgroup))
5412                                         {
5413                                                 memory_scope_covered = true;
5414                                         }
5415                                 }
5416                                 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5417                                 {
5418                                         // The control barrier has device scope, but the memory barrier just has workgroup scope.
5419                                         memory_scope_covered = true;
5420                                 }
5421
5422                                 // If we have the same memory scope, and all memory types are covered, we're good.
5423                                 if (memory_scope_covered && (semantics & next_semantics) == semantics)
5424                                         break;
5425                         }
5426                 }
5427
5428                 // We are synchronizing some memory or syncing execution,
5429                 // so we cannot forward any loads beyond the memory barrier.
5430                 if (semantics || opcode == OpControlBarrier)
5431                 {
5432                         assert(current_emitting_block);
5433                         flush_control_dependent_expressions(current_emitting_block->self);
5434                         flush_all_active_variables();
5435                 }
5436
5437                 if (opcode == OpControlBarrier)
5438                 {
5439 			// We cannot emit a pure execution barrier; when there are no memory semantics, pick the cheapest option.
5440                         if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5441                                 statement("GroupMemoryBarrierWithGroupSync();");
5442                         else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5443                                 statement("DeviceMemoryBarrierWithGroupSync();");
5444                         else
5445                                 statement("AllMemoryBarrierWithGroupSync();");
5446                 }
5447                 else
5448                 {
5449                         if (semantics == MemorySemanticsWorkgroupMemoryMask)
5450                                 statement("GroupMemoryBarrier();");
5451                         else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5452                                 statement("DeviceMemoryBarrier();");
5453                         else
5454                                 statement("AllMemoryBarrier();");
5455                 }
5456                 break;
5457         }
5458
5459         case OpBitFieldInsert:
5460         {
5461                 if (!requires_bitfield_insert)
5462                 {
5463                         requires_bitfield_insert = true;
5464                         force_recompile();
5465                 }
5466
5467                 auto expr = join("spvBitfieldInsert(", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ",
5468                                  to_expression(ops[4]), ", ", to_expression(ops[5]), ")");
5469
5470                 bool forward =
5471                     should_forward(ops[2]) && should_forward(ops[3]) && should_forward(ops[4]) && should_forward(ops[5]);
5472
5473                 auto &restype = get<SPIRType>(ops[0]);
5474                 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5475                 emit_op(ops[0], ops[1], expr, forward);
5476                 break;
5477         }
5478
5479         case OpBitFieldSExtract:
5480         case OpBitFieldUExtract:
5481         {
5482                 if (!requires_bitfield_extract)
5483                 {
5484                         requires_bitfield_extract = true;
5485                         force_recompile();
5486                 }
5487
5488                 if (opcode == OpBitFieldSExtract)
5489                         HLSL_TFOP(spvBitfieldSExtract);
5490                 else
5491                         HLSL_TFOP(spvBitfieldUExtract);
5492                 break;
5493         }
5494
5495         case OpBitCount:
5496         {
5497                 auto basetype = expression_type(ops[2]).basetype;
5498                 emit_unary_func_op_cast(ops[0], ops[1], ops[2], "countbits", basetype, basetype);
5499                 break;
5500         }
5501
5502         case OpBitReverse:
5503                 HLSL_UFOP(reversebits);
5504                 break;
5505
5506         case OpArrayLength:
5507         {
5508                 auto *var = maybe_get_backing_variable(ops[2]);
5509                 if (!var)
5510                         SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5511
5512                 auto &type = get<SPIRType>(var->basetype);
5513                 if (!has_decoration(type.self, DecorationBlock) && !has_decoration(type.self, DecorationBufferBlock))
5514                         SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5515
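		// GetDimensions() on a (RW)ByteAddressBuffer reports the total size in bytes, so the
		// runtime array length is (byte_size - member_offset) / array_stride.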
5516                 // This must be 32-bit uint, so we're good to go.
5517                 emit_uninitialized_temporary_expression(ops[0], ops[1]);
5518                 statement(to_non_uniform_aware_expression(ops[2]), ".GetDimensions(", to_expression(ops[1]), ");");
5519                 uint32_t offset = type_struct_member_offset(type, ops[3]);
5520                 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
5521                 statement(to_expression(ops[1]), " = (", to_expression(ops[1]), " - ", offset, ") / ", stride, ";");
5522                 break;
5523         }
5524
5525         case OpIsHelperInvocationEXT:
5526                 SPIRV_CROSS_THROW("helperInvocationEXT() is not supported in HLSL.");
5527
5528         case OpBeginInvocationInterlockEXT:
5529         case OpEndInvocationInterlockEXT:
5530                 if (hlsl_options.shader_model < 51)
5531                         SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5532                 break; // Nothing to do in the body
5533
5534         default:
5535                 CompilerGLSL::emit_instruction(instruction);
5536                 break;
5537         }
5538 }
5539
5540 void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5541 {
5542         if (const auto *var = maybe_get_backing_variable(var_id))
5543                 var_id = var->self;
5544
5545         auto &type = expression_type(var_id);
5546         bool uav = type.image.sampled == 2;
5547         if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
5548                 uav = false;
5549
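	// Each dimension/arrayness/multisample combination maps to one bit in a 64-bit mask; a newly
	// set bit forces a recompile so the matching spvTextureSize()/spvImageSize() variant can be
	// emitted up front on the next pass.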
5550         uint32_t bit = 0;
5551         switch (type.image.dim)
5552         {
5553         case Dim1D:
5554                 bit = type.image.arrayed ? Query1DArray : Query1D;
5555                 break;
5556
5557         case Dim2D:
5558                 if (type.image.ms)
5559                         bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5560                 else
5561                         bit = type.image.arrayed ? Query2DArray : Query2D;
5562                 break;
5563
5564         case Dim3D:
5565                 bit = Query3D;
5566                 break;
5567
5568         case DimCube:
5569                 bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5570                 break;
5571
5572         case DimBuffer:
5573                 bit = QueryBuffer;
5574                 break;
5575
5576         default:
5577                 SPIRV_CROSS_THROW("Unsupported query type.");
5578         }
5579
5580         switch (get<SPIRType>(type.image.type).basetype)
5581         {
5582         case SPIRType::Float:
5583                 bit += QueryTypeFloat;
5584                 break;
5585
5586         case SPIRType::Int:
5587                 bit += QueryTypeInt;
5588                 break;
5589
5590         case SPIRType::UInt:
5591                 bit += QueryTypeUInt;
5592                 break;
5593
5594         default:
5595                 SPIRV_CROSS_THROW("Unsupported query type.");
5596         }
5597
5598         auto norm_state = image_format_to_normalized_state(type.image.format);
5599         auto &variant = uav ? required_texture_size_variants
5600                                   .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
5601                               required_texture_size_variants.srv;
5602
5603         uint64_t mask = 1ull << bit;
5604         if ((variant & mask) == 0)
5605         {
5606                 force_recompile();
5607                 variant |= mask;
5608         }
5609 }
5610
5611 void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5612 {
5613 	root_constants_layout = std::move(layout);
5614 }
5615
5616 void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5617 {
5618         remap_vertex_attributes.push_back(vertex_attributes);
5619 }
5620
5621 VariableID CompilerHLSL::remap_num_workgroups_builtin()
5622 {
5623         update_active_builtins();
5624
5625         if (!active_input_builtins.get(BuiltInNumWorkgroups))
5626                 return 0;
5627
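	// HLSL has no builtin for the dispatch size, so synthesize a uniform block for it; the caller
	// is expected to bind the returned variable and upload the workgroup count itself.
	// increase_bound_by(4) reserves fresh IDs for the uint3 type, block, pointer type and variable.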
5628         // Create a new, fake UBO.
5629         uint32_t offset = ir.increase_bound_by(4);
5630
5631         uint32_t uint_type_id = offset;
5632         uint32_t block_type_id = offset + 1;
5633         uint32_t block_pointer_type_id = offset + 2;
5634         uint32_t variable_id = offset + 3;
5635
5636         SPIRType uint_type;
5637         uint_type.basetype = SPIRType::UInt;
5638         uint_type.width = 32;
5639         uint_type.vecsize = 3;
5640         uint_type.columns = 1;
5641         set<SPIRType>(uint_type_id, uint_type);
5642
5643         SPIRType block_type;
5644         block_type.basetype = SPIRType::Struct;
5645         block_type.member_types.push_back(uint_type_id);
5646         set<SPIRType>(block_type_id, block_type);
5647         set_decoration(block_type_id, DecorationBlock);
5648         set_member_name(block_type_id, 0, "count");
5649         set_member_decoration(block_type_id, 0, DecorationOffset, 0);
5650
5651         SPIRType block_pointer_type = block_type;
5652         block_pointer_type.pointer = true;
5653         block_pointer_type.storage = StorageClassUniform;
5654         block_pointer_type.parent_type = block_type_id;
5655         auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);
5656
5657         // Preserve self.
5658         ptr_type.self = block_type_id;
5659
5660         set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
5661         ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5662
5663         num_workgroups_builtin = variable_id;
5664         return variable_id;
5665 }
5666
5667 void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5668 {
5669         resource_binding_flags = flags;
5670 }
5671
5672 void CompilerHLSL::validate_shader_model()
5673 {
5674         // Check for nonuniform qualifier.
5675         // Instead of looping over all decorations to find this, just look at capabilities.
5676         for (auto &cap : ir.declared_capabilities)
5677         {
5678                 switch (cap)
5679                 {
5680                 case CapabilityShaderNonUniformEXT:
5681                 case CapabilityRuntimeDescriptorArrayEXT:
5682                         if (hlsl_options.shader_model < 51)
5683                                 SPIRV_CROSS_THROW(
5684                                     "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5685                         break;
5686
5687                 case CapabilityVariablePointers:
5688                 case CapabilityVariablePointersStorageBuffer:
5689                         SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5690
5691                 default:
5692                         break;
5693                 }
5694         }
5695
5696         if (ir.addressing_model != AddressingModelLogical)
5697                 SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5698
5699         if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5700                 SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5701 }
5702
5703 string CompilerHLSL::compile()
5704 {
5705         ir.fixup_reserved_names();
5706
5707         // Do not deal with ES-isms like precision, older extensions and such.
5708         options.es = false;
5709         options.version = 450;
5710         options.vulkan_semantics = true;
5711         backend.float_literal_suffix = true;
5712         backend.double_literal_suffix = false;
5713         backend.long_long_literal_suffix = true;
5714         backend.uint32_t_literal_suffix = true;
5715         backend.int16_t_literal_suffix = "";
5716         backend.uint16_t_literal_suffix = "u";
5717         backend.basic_int_type = "int";
5718         backend.basic_uint_type = "uint";
5719         backend.demote_literal = "discard";
5720         backend.boolean_mix_function = "";
5721         backend.swizzle_is_function = false;
5722         backend.shared_is_implied = true;
5723         backend.unsized_array_supported = true;
5724         backend.explicit_struct_type = false;
5725         backend.use_initializer_list = true;
5726         backend.use_constructor_splatting = false;
5727         backend.can_swizzle_scalar = true;
5728         backend.can_declare_struct_inline = false;
5729         backend.can_declare_arrays_inline = false;
5730         backend.can_return_array = false;
5731         backend.nonuniform_qualifier = "NonUniformResourceIndex";
5732         backend.support_case_fallthrough = false;
5733
5734         // SM 4.1 does not support precise for some reason.
5735         backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
5736
5737         fixup_type_alias();
5738         reorder_type_alias();
5739         build_function_control_flow_graphs_and_analyze();
5740         validate_shader_model();
5741         update_active_builtins();
5742         analyze_image_and_sampler_usage();
5743         analyze_interlocked_resource_usage();
5744
5745         // Subpass input needs SV_Position.
5746         if (need_subpass_input)
5747                 active_input_builtins.set(BuiltInFragCoord);
5748
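	// Emission may discover mid-pass that a helper (spvTextureSize, bitfield ops, mod, ...) is
	// required; force_recompile() requests another pass so the helper can be declared up front.
	// Anything beyond three passes is treated as a bug.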
5749         uint32_t pass_count = 0;
5750         do
5751         {
5752                 if (pass_count >= 3)
5753                         SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
5754
5755                 reset();
5756
5757                 // Move constructor for this type is broken on GCC 4.9 ...
5758                 buffer.reset();
5759
5760                 emit_header();
5761                 emit_resources();
5762
5763                 emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
5764                 emit_hlsl_entry_point();
5765
5766                 pass_count++;
5767         } while (is_forcing_recompilation());
5768
5769         // Entry point in HLSL is always main() for the time being.
5770         get_entry_point().name = "main";
5771
5772         return buffer.str();
5773 }
5774
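// SPIR-V selection/loop control hints map one-to-one onto HLSL's bracketed statement attributes.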
5775 void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
5776 {
5777         switch (block.hint)
5778         {
5779         case SPIRBlock::HintFlatten:
5780                 statement("[flatten]");
5781                 break;
5782         case SPIRBlock::HintDontFlatten:
5783                 statement("[branch]");
5784                 break;
5785         case SPIRBlock::HintUnroll:
5786                 statement("[unroll]");
5787                 break;
5788         case SPIRBlock::HintDontUnroll:
5789                 statement("[loop]");
5790                 break;
5791         default:
5792                 break;
5793         }
5794 }
5795
5796 string CompilerHLSL::get_unique_identifier()
5797 {
5798         return join("_", unique_identifier_count++, "ident");
5799 }
5800
5801 void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
5802 {
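	// The bool in the pair starts out false and flips once the binding is actually consumed,
	// which is what is_hlsl_resource_binding_used() reports.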
5803         StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
5804         resource_bindings[tuple] = { binding, false };
5805 }
5806
5807 bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
5808 {
5809         StageSetBinding tuple = { model, desc_set, binding };
5810         auto itr = resource_bindings.find(tuple);
5811         return itr != end(resource_bindings) && itr->second.second;
5812 }
5813
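// Classifies OpBitcast operands for the uint2 <-> 64-bit packing paths handled in emit_instruction().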
5814 CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
5815 {
5816         auto &rslt_type = get<SPIRType>(result_type);
5817         auto &expr_type = expression_type(op0);
5818
5819         if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
5820             expr_type.vecsize == 2)
5821                 return BitcastType::TypePackUint2x32;
5822         else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
5823                  expr_type.basetype == SPIRType::BaseType::UInt64)
5824                 return BitcastType::TypeUnpackUint64;
5825
5826         return BitcastType::TypeNormal;
5827 }
5828
5829 bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
5830 {
5831         if (hlsl_options.force_storage_buffer_as_uav)
5832         {
5833                 return true;
5834         }
5835
5836         const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
5837         const uint32_t binding = get_decoration(id, spv::DecorationBinding);
5838
5839         return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
5840 }
5841
5842 void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
5843 {
5844         SetBindingPair pair = { desc_set, binding };
5845         force_uav_buffer_bindings.insert(pair);
5846 }
5847
5848 bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
5849 {
5850         return (builtin == BuiltInSampleMask);
5851 }