HLSL: Do structure conversion for return type struct-punning on non-entry-point funct...
authorJohn Kessenich <cepheus@frii.com>
Thu, 6 Oct 2016 18:59:51 +0000 (12:59 -0600)
committerJohn Kessenich <cepheus@frii.com>
Thu, 6 Oct 2016 19:06:13 +0000 (13:06 -0600)
SPIRV/GlslangToSpv.cpp
Test/baseResults/hlsl.multiReturn.frag.out [new file with mode: 0755]
Test/baseResults/remap.hlsl.templatetypes.everything.frag.out
Test/baseResults/remap.hlsl.templatetypes.none.frag.out
Test/hlsl.multiReturn.frag [new file with mode: 0755]
Test/remap.hlsl.templatetypes.everything.frag
Test/remap.hlsl.templatetypes.none.frag
glslang/Include/revision.h
gtests/Hlsl.FromFile.cpp

index 40ff9b4..5ecb6ab 100755 (executable)
@@ -172,6 +172,7 @@ protected:
     spv::Id getExtBuiltins(const char* name);
 
     spv::Function* shaderEntry;
+    spv::Function* currentFunction;
     spv::Instruction* entryPoint;
     int sequenceDepth;
 
@@ -733,7 +734,8 @@ bool HasNonLayoutQualifiers(const glslang::TType& type, const glslang::TQualifie
 //
 
 TGlslangToSpvTraverser::TGlslangToSpvTraverser(const glslang::TIntermediate* glslangIntermediate, spv::SpvBuildLogger* buildLogger)
-    : TIntermTraverser(true, false, true), shaderEntry(0), sequenceDepth(0), logger(buildLogger),
+    : TIntermTraverser(true, false, true), shaderEntry(nullptr), currentFunction(nullptr),
+      sequenceDepth(0), logger(buildLogger),
       builder((glslang::GetKhronosToolId() << 16) | GeneratorVersion, logger),
       inMain(false), mainTerminated(false), linkageOnly(false),
       glslangIntermediate(glslangIntermediate)
@@ -1351,6 +1353,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
             if (isShaderEntryPoint(node)) {
                 inMain = true;
                 builder.setBuildPoint(shaderEntry->getLastBlock());
+                currentFunction = shaderEntry;
             } else {
                 handleFunctionEntry(node);
             }
@@ -1858,9 +1861,18 @@ bool TGlslangToSpvTraverser::visitBranch(glslang::TVisit /* visit */, glslang::T
         builder.createLoopContinue();
         break;
     case glslang::EOpReturn:
-        if (node->getExpression())
-            builder.makeReturn(false, accessChainLoad(node->getExpression()->getType()));
-        else
+        if (node->getExpression()) {
+            const glslang::TType& glslangReturnType = node->getExpression()->getType();
+            spv::Id returnId = accessChainLoad(glslangReturnType);
+            if (builder.getTypeId(returnId) != currentFunction->getReturnType()) {
+                builder.clearAccessChain();
+                spv::Id copyId = builder.createVariable(spv::StorageClassFunction, currentFunction->getReturnType());
+                builder.setAccessChainLValue(copyId);
+                multiTypeStore(glslangReturnType, returnId);
+                returnId = builder.createLoad(copyId);
+            }
+            builder.makeReturn(false, returnId);
+        } else
             builder.makeReturn(false);
 
         builder.clearAccessChain();
@@ -2332,7 +2344,7 @@ void TGlslangToSpvTraverser::accessChainStore(const glslang::TType& type, spv::I
 // SPIR-V level.
 //
 // This especially happens when a single glslang type expands to multiple
-// SPIR-V types, like a struct that is used in an member-undecorated way as well
+// SPIR-V types, like a struct that is used in a member-undecorated way as well
 // as in a member-decorated way.
 //
 // NOTE: This function can handle any store request; if it's not special it
@@ -2599,8 +2611,8 @@ void TGlslangToSpvTraverser::handleFunctionEntry(const glslang::TIntermAggregate
 {
     // SPIR-V functions should already be in the functionMap from the prepass
     // that called makeFunctions().
-    spv::Function* function = functionMap[node->getName().c_str()];
-    spv::Block* functionBlock = function->getEntryBlock();
+    currentFunction = functionMap[node->getName().c_str()];
+    spv::Block* functionBlock = currentFunction->getEntryBlock();
     builder.setBuildPoint(functionBlock);
 }
 
diff --git a/Test/baseResults/hlsl.multiReturn.frag.out b/Test/baseResults/hlsl.multiReturn.frag.out
new file mode 100755 (executable)
index 0000000..80d7f16
--- /dev/null
@@ -0,0 +1,113 @@
+hlsl.multiReturn.frag
+Shader version: 450
+gl_FragCoord origin is upper left
+0:? Sequence
+0:12  Function Definition: foo( (temp structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:12    Function Parameters: 
+0:?     Sequence
+0:13      Branch: Return with expression
+0:13        s: direct index for structure (layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:13          'anon@0' (layout(row_major std140 ) uniform block{layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m} s})
+0:13          Constant:
+0:13            0 (const uint)
+0:17  Function Definition: main( (temp void)
+0:17    Function Parameters: 
+0:?     Sequence
+0:18      Function Call: foo( (temp structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:?   Linker Objects
+0:?     'anon@0' (layout(row_major std140 ) uniform block{layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m} s})
+
+
+Linked fragment stage:
+
+
+Shader version: 450
+gl_FragCoord origin is upper left
+0:? Sequence
+0:12  Function Definition: foo( (temp structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:12    Function Parameters: 
+0:?     Sequence
+0:13      Branch: Return with expression
+0:13        s: direct index for structure (layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:13          'anon@0' (layout(row_major std140 ) uniform block{layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m} s})
+0:13          Constant:
+0:13            0 (const uint)
+0:17  Function Definition: main( (temp void)
+0:17    Function Parameters: 
+0:?     Sequence
+0:18      Function Call: foo( (temp structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m})
+0:?   Linker Objects
+0:?     'anon@0' (layout(row_major std140 ) uniform block{layout(row_major std140 ) uniform structure{temp float f, temp 3-component vector of float v, temp 3X3 matrix of float m} s})
+
+// Module Version 10000
+// Generated by (magic number): 80001
+// Id's are bound by 39
+
+                              Capability Shader
+               1:             ExtInstImport  "GLSL.std.450"
+                              MemoryModel Logical GLSL450
+                              EntryPoint Fragment 4  "main"
+                              ExecutionMode 4 OriginUpperLeft
+                              Name 4  "main"
+                              Name 9  "S"
+                              MemberName 9(S) 0  "f"
+                              MemberName 9(S) 1  "v"
+                              MemberName 9(S) 2  "m"
+                              Name 11  "foo("
+                              Name 13  "S"
+                              MemberName 13(S) 0  "f"
+                              MemberName 13(S) 1  "v"
+                              MemberName 13(S) 2  "m"
+                              Name 14  "bufName"
+                              MemberName 14(bufName) 0  "s"
+                              Name 16  ""
+                              MemberDecorate 13(S) 0 Offset 0
+                              MemberDecorate 13(S) 1 Offset 16
+                              MemberDecorate 13(S) 2 RowMajor
+                              MemberDecorate 13(S) 2 Offset 32
+                              MemberDecorate 13(S) 2 MatrixStride 16
+                              MemberDecorate 14(bufName) 0 Offset 0
+                              Decorate 14(bufName) Block
+                              Decorate 16 DescriptorSet 0
+               2:             TypeVoid
+               3:             TypeFunction 2
+               6:             TypeFloat 32
+               7:             TypeVector 6(float) 3
+               8:             TypeMatrix 7(fvec3) 3
+            9(S):             TypeStruct 6(float) 7(fvec3) 8
+              10:             TypeFunction 9(S)
+           13(S):             TypeStruct 6(float) 7(fvec3) 8
+     14(bufName):             TypeStruct 13(S)
+              15:             TypePointer Uniform 14(bufName)
+              16:     15(ptr) Variable Uniform
+              17:             TypeInt 32 1
+              18:     17(int) Constant 0
+              19:             TypePointer Uniform 13(S)
+              22:             TypePointer Function 9(S)
+              25:             TypePointer Function 6(float)
+              28:     17(int) Constant 1
+              29:             TypePointer Function 7(fvec3)
+              32:     17(int) Constant 2
+              33:             TypePointer Function 8
+         4(main):           2 Function None 3
+               5:             Label
+              38:        9(S) FunctionCall 11(foo()
+                              Return
+                              FunctionEnd
+        11(foo():        9(S) Function None 10
+              12:             Label
+              23:     22(ptr) Variable Function
+              20:     19(ptr) AccessChain 16 18
+              21:       13(S) Load 20
+              24:    6(float) CompositeExtract 21 0
+              26:     25(ptr) AccessChain 23 18
+                              Store 26 24
+              27:    7(fvec3) CompositeExtract 21 1
+              30:     29(ptr) AccessChain 23 28
+                              Store 30 27
+              31:           8 CompositeExtract 21 2
+              34:     33(ptr) AccessChain 23 32
+                              Store 34 31
+              35:        9(S) Load 23
+                              ReturnValue 35
+                              FunctionEnd
index 7a40f94..63eb6cb 100644 (file)
@@ -5,7 +5,7 @@ Linked fragment stage:
 
 // Module Version 10000
 // Generated by (magic number): 80001
-// Id's are bound by 16123
+// Id's are bound by 9012
 
                               Capability Shader
                               Capability Float64
@@ -20,11 +20,12 @@ Linked fragment stage:
               13:             TypeFloat 32
               29:             TypeVector 13(float) 4
             2572:   13(float) Constant 0
-             666:             TypePointer Output 29(fvec4)
-            4045:    666(ptr) Variable Output
-             667:             TypePointer Input 29(fvec4)
-            4872:    667(ptr) Variable Input
+             650:             TypePointer Output 13(float)
+            4045:    650(ptr) Variable Output
+             666:             TypePointer Input 29(fvec4)
+            4872:    666(ptr) Variable Input
             5663:           8 Function None 1282
-           16122:             Label
-                              ReturnValue 2572
+            9011:             Label
+                              Store 4045 2572
+                              Return
                               FunctionEnd
index 340198b..df0dcf5 100644 (file)
@@ -159,7 +159,7 @@ Linked fragment stage:
              142:   69(fvec3) ConstantComposite 109 111 112
              143:   69(fvec3) ConstantComposite 113 114 116
              144:         139 ConstantComposite 72 126 142 143
-             145:             TypePointer Output 7(fvec4)
+             145:             TypePointer Output 6(float)
 146(@entryPointOutput):    145(ptr) Variable Output
              147:             TypePointer Input 7(fvec4)
       148(input):    147(ptr) Variable Input
@@ -221,5 +221,6 @@ Linked fragment stage:
                               Store 130(r62) 133
                               Store 136(r65) 138
                               Store 141(r66) 144
-                              ReturnValue 106
+                              Store 146(@entryPointOutput) 106
+                              Return
                               FunctionEnd
diff --git a/Test/hlsl.multiReturn.frag b/Test/hlsl.multiReturn.frag
new file mode 100755 (executable)
index 0000000..fdab772
--- /dev/null
@@ -0,0 +1,19 @@
+struct S {\r
+    float f;\r
+    float3 v;\r
+    float3x3 m;\r
+};\r
+\r
+cbuffer bufName {\r
+    S s;\r
+};\r
+\r
+S foo()\r
+{\r
+    return s;\r
+}\r
+\r
+void main()\r
+{\r
+    foo();\r
+}\r
index bacd6f5..f48c98a 100644 (file)
@@ -1,5 +1,5 @@
 
-float4 main(float4 input) : COLOR0
+float main(float4 input) : COLOR0
 {
     vector r00 = float4(1,2,3,4);  // vector means float4
     float4 r01 = vector(2,3,4,5);  // vector means float4
index bacd6f5..f48c98a 100644 (file)
@@ -1,5 +1,5 @@
 
-float4 main(float4 input) : COLOR0
+float main(float4 input) : COLOR0
 {
     vector r00 = float4(1,2,3,4);  // vector means float4
     float4 r01 = vector(2,3,4,5);  // vector means float4
index 304acc3..4b7b8a2 100644 (file)
@@ -2,5 +2,5 @@
 // For the version, it uses the latest git tag followed by the number of commits.
 // For the date, it uses the current date (when then script is run).
 
-#define GLSLANG_REVISION "Overload400-PrecQual.1553"
-#define GLSLANG_DATE "05-Oct-2016"
+#define GLSLANG_REVISION "Overload400-PrecQual.1556"
+#define GLSLANG_DATE "06-Oct-2016"
index 90b62d1..6d2a954 100644 (file)
@@ -140,6 +140,7 @@ INSTANTIATE_TEST_CASE_P(
         {"hlsl.load.offset.dx10.frag", "main"},
         {"hlsl.load.offsetarray.dx10.frag", "main"},
         {"hlsl.multiEntry.vert", "RealEntrypoint"},
+        {"hlsl.multiReturn.frag", "main"},
         {"hlsl.matrixindex.frag", "main"},
         {"hlsl.numericsuffixes.frag", "main"},
         {"hlsl.overload.frag", "PixelShaderFunction"},