HLSL: fix for flattening assignments from non-symbol R-values.
authorsteve-lunarg <steve_gh@khasekhemwy.net>
Mon, 3 Oct 2016 04:13:22 +0000 (22:13 -0600)
committersteve-lunarg <steve_gh@khasekhemwy.net>
Tue, 4 Oct 2016 23:07:45 +0000 (17:07 -0600)
If a member-wise assignment from a non-flattened struct to a flattened struct sees a complex R-value
(not a symbol), it now creates a temporary to hold that value, to avoid repeating the R-value.
This avoids, e.g, duplicating a whole function call.  Also, it avoids re-using the AST node, making a
new one for each member inside the member loop.

The latter (re-use of AST node) was also an issue in the GetDimensions intrinsic decomposition,
so this PR fixes that one too.

Test/baseResults/hlsl.flatten.return.frag.out [new file with mode: 0644]
Test/hlsl.flatten.return.frag [new file with mode: 0644]
glslang/MachineIndependent/Intermediate.cpp
glslang/MachineIndependent/localintermediate.h
gtests/Hlsl.FromFile.cpp
hlsl/hlslParseHelper.cpp

diff --git a/Test/baseResults/hlsl.flatten.return.frag.out b/Test/baseResults/hlsl.flatten.return.frag.out
new file mode 100644 (file)
index 0000000..39fbf0e
--- /dev/null
@@ -0,0 +1,187 @@
+hlsl.flatten.return.frag
+Shader version: 450
+gl_FragCoord origin is upper left
+0:? Sequence
+0:11  Function Definition: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:11    Function Parameters: 
+0:?     Sequence
+0:12      Branch: Return with expression
+0:?         Constant:
+0:?           1.000000
+0:?           1.000000
+0:?           1.000000
+0:?           1.000000
+0:?           2.000000
+0:?           3.000000
+0:?           4.000000
+0:16  Function Definition: main( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:16    Function Parameters: 
+0:?     Sequence
+0:17      Sequence
+0:17        Sequence
+0:17          move second child to first child (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17            'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17            Function Call: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17          move second child to first child (temp 4-component vector of float)
+0:?             'color' (layout(location=0 ) out 4-component vector of float)
+0:17            color: direct index for structure (temp 4-component vector of float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                0 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member1' (layout(location=1 ) out float)
+0:17            other_struct_member1: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                1 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member2' (layout(location=2 ) out float)
+0:17            other_struct_member2: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                2 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member3' (layout(location=3 ) out float)
+0:17            other_struct_member3: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                3 (const int)
+0:17        Branch: Return
+0:?   Linker Objects
+0:?     'color' (layout(location=0 ) out 4-component vector of float)
+0:?     'other_struct_member1' (layout(location=1 ) out float)
+0:?     'other_struct_member2' (layout(location=2 ) out float)
+0:?     'other_struct_member3' (layout(location=3 ) out float)
+
+
+Linked fragment stage:
+
+
+Shader version: 450
+gl_FragCoord origin is upper left
+0:? Sequence
+0:11  Function Definition: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:11    Function Parameters: 
+0:?     Sequence
+0:12      Branch: Return with expression
+0:?         Constant:
+0:?           1.000000
+0:?           1.000000
+0:?           1.000000
+0:?           1.000000
+0:?           2.000000
+0:?           3.000000
+0:?           4.000000
+0:16  Function Definition: main( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:16    Function Parameters: 
+0:?     Sequence
+0:17      Sequence
+0:17        Sequence
+0:17          move second child to first child (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17            'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17            Function Call: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17          move second child to first child (temp 4-component vector of float)
+0:?             'color' (layout(location=0 ) out 4-component vector of float)
+0:17            color: direct index for structure (temp 4-component vector of float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                0 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member1' (layout(location=1 ) out float)
+0:17            other_struct_member1: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                1 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member2' (layout(location=2 ) out float)
+0:17            other_struct_member2: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                2 (const int)
+0:17          move second child to first child (temp float)
+0:?             'other_struct_member3' (layout(location=3 ) out float)
+0:17            other_struct_member3: direct index for structure (temp float)
+0:17              'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3})
+0:17              Constant:
+0:17                3 (const int)
+0:17        Branch: Return
+0:?   Linker Objects
+0:?     'color' (layout(location=0 ) out 4-component vector of float)
+0:?     'other_struct_member1' (layout(location=1 ) out float)
+0:?     'other_struct_member2' (layout(location=2 ) out float)
+0:?     'other_struct_member3' (layout(location=3 ) out float)
+
+// Module Version 10000
+// Generated by (magic number): 80001
+// Id's are bound by 45
+
+                              Capability Shader
+               1:             ExtInstImport  "GLSL.std.450"
+                              MemoryModel Logical GLSL450
+                              EntryPoint Fragment 4  "main" 24 31 36 40
+                              ExecutionMode 4 OriginUpperLeft
+                              Name 4  "main"
+                              Name 8  "PS_OUTPUT"
+                              MemberName 8(PS_OUTPUT) 0  "color"
+                              MemberName 8(PS_OUTPUT) 1  "other_struct_member1"
+                              MemberName 8(PS_OUTPUT) 2  "other_struct_member2"
+                              MemberName 8(PS_OUTPUT) 3  "other_struct_member3"
+                              Name 10  "Func1("
+                              Name 21  "flattenTemp"
+                              Name 24  "color"
+                              Name 31  "other_struct_member1"
+                              Name 36  "other_struct_member2"
+                              Name 40  "other_struct_member3"
+                              Decorate 24(color) Location 0
+                              Decorate 31(other_struct_member1) Location 1
+                              Decorate 36(other_struct_member2) Location 2
+                              Decorate 40(other_struct_member3) Location 3
+               2:             TypeVoid
+               3:             TypeFunction 2
+               6:             TypeFloat 32
+               7:             TypeVector 6(float) 4
+    8(PS_OUTPUT):             TypeStruct 7(fvec4) 6(float) 6(float) 6(float)
+               9:             TypeFunction 8(PS_OUTPUT)
+              12:    6(float) Constant 1065353216
+              13:    7(fvec4) ConstantComposite 12 12 12 12
+              14:    6(float) Constant 1073741824
+              15:    6(float) Constant 1077936128
+              16:    6(float) Constant 1082130432
+              17:8(PS_OUTPUT) ConstantComposite 13 14 15 16
+              20:             TypePointer Function 8(PS_OUTPUT)
+              23:             TypePointer Output 7(fvec4)
+       24(color):     23(ptr) Variable Output
+              25:             TypeInt 32 1
+              26:     25(int) Constant 0
+              27:             TypePointer Function 7(fvec4)
+              30:             TypePointer Output 6(float)
+31(other_struct_member1):     30(ptr) Variable Output
+              32:     25(int) Constant 1
+              33:             TypePointer Function 6(float)
+36(other_struct_member2):     30(ptr) Variable Output
+              37:     25(int) Constant 2
+40(other_struct_member3):     30(ptr) Variable Output
+              41:     25(int) Constant 3
+         4(main):           2 Function None 3
+               5:             Label
+ 21(flattenTemp):     20(ptr) Variable Function
+              22:8(PS_OUTPUT) FunctionCall 10(Func1()
+                              Store 21(flattenTemp) 22
+              28:     27(ptr) AccessChain 21(flattenTemp) 26
+              29:    7(fvec4) Load 28
+                              Store 24(color) 29
+              34:     33(ptr) AccessChain 21(flattenTemp) 32
+              35:    6(float) Load 34
+                              Store 31(other_struct_member1) 35
+              38:     33(ptr) AccessChain 21(flattenTemp) 37
+              39:    6(float) Load 38
+                              Store 36(other_struct_member2) 39
+              42:     33(ptr) AccessChain 21(flattenTemp) 41
+              43:    6(float) Load 42
+                              Store 40(other_struct_member3) 43
+                              Return
+                              FunctionEnd
+      10(Func1():8(PS_OUTPUT) Function None 9
+              11:             Label
+                              ReturnValue 17
+                              FunctionEnd
diff --git a/Test/hlsl.flatten.return.frag b/Test/hlsl.flatten.return.frag
new file mode 100644 (file)
index 0000000..c633e67
--- /dev/null
@@ -0,0 +1,18 @@
+
+struct PS_OUTPUT
+{
+    float4 color : SV_Target0;
+    float other_struct_member1;
+    float other_struct_member2;
+    float other_struct_member3;
+};
+
+PS_OUTPUT Func1()
+{
+    return PS_OUTPUT(float4(1), 2, 3, 4);
+}
+
+PS_OUTPUT main()
+{
+    return Func1();
+}
index cababc3..9755620 100644 (file)
@@ -73,6 +73,16 @@ TIntermSymbol* TIntermediate::addSymbol(int id, const TString& name, const TType
     return node;
 }
 
+TIntermSymbol* TIntermediate::addSymbol(const TIntermSymbol& intermSymbol)
+{
+    return addSymbol(intermSymbol.getId(),
+                     intermSymbol.getName(),
+                     intermSymbol.getType(),
+                     intermSymbol.getConstArray(),
+                     intermSymbol.getConstSubtree(),
+                     intermSymbol.getLoc());
+}
+
 TIntermSymbol* TIntermediate::addSymbol(const TVariable& variable)
 {
     glslang::TSourceLoc loc; // just a null location
index 14b8a00..acfafb1 100644 (file)
@@ -201,6 +201,7 @@ public:
     TIntermSymbol* addSymbol(const TVariable&);
     TIntermSymbol* addSymbol(const TVariable&, const TSourceLoc&);
     TIntermSymbol* addSymbol(const TType&, const TSourceLoc&);
+    TIntermSymbol* addSymbol(const TIntermSymbol&);
     TIntermTyped* addConversion(TOperator, const TType&, TIntermTyped*) const;
     TIntermTyped* addShapeConversion(TOperator, const TType&, TIntermTyped*);
     TIntermTyped* addBinaryMath(TOperator, TIntermTyped* left, TIntermTyped* right, TSourceLoc);
index 7467eb7..51be19d 100644 (file)
@@ -99,6 +99,7 @@ INSTANTIATE_TEST_CASE_P(
         {"hlsl.entry-out.frag", "PixelShaderFunction"},
         {"hlsl.float1.frag", "PixelShaderFunction"},
         {"hlsl.float4.frag", "PixelShaderFunction"},
+        {"hlsl.flatten.return.frag", "main"},
         {"hlsl.forLoop.frag", "PixelShaderFunction"},
         {"hlsl.gather.array.dx10.frag", "main"},
         {"hlsl.gather.basic.dx10.frag", "main"},
index 082a496..08ffd58 100755 (executable)
@@ -952,10 +952,53 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
     const TVector<TVariable*>* leftVariables = nullptr;
     const TVector<TVariable*>* rightVariables = nullptr;
 
+    // A temporary to store the right node's value, so we don't keep indirecting into it
+    // if it's not a simple symbol.
+    TVariable*     rhsTempVar   = nullptr;
+
+    // If the RHS is a simple symbol node, we'll copy it for each member.
+    TIntermSymbol* cloneSymNode = nullptr;
+
+    // Array structs are not yet handled in flattening.  (Compilation error upstream, so
+    // this should never fire).
+    assert(!(left->getType().isStruct() && left->getType().isArray()));
+
+    int memberCount = 0;
+
+    // Track how many items there are to copy.
+    if (left->getType().isStruct())
+        memberCount = left->getType().getStruct()->size();
+    if (left->getType().isArray())
+        memberCount = left->getType().getCumulativeArraySize();
+
     if (flattenLeft)
         leftVariables = &flattenMap.find(left->getAsSymbolNode()->getId())->second;
-    if (flattenRight)
+
+    if (flattenRight) {
         rightVariables = &flattenMap.find(right->getAsSymbolNode()->getId())->second;
+    } else {
+        // The RHS is not flattened.  There are several cases:
+        // 1. 1 item to copy:  Use the RHS directly.
+        // 2. >1 item, simple symbol RHS: we'll create a new TIntermSymbol node for each, but no assign to temp.
+        // 3. >1 item, complex RHS: assign it to a new temp variable, and create a TIntermSymbol for each member.
+        
+        if (memberCount <= 1) {
+            // case 1: we'll use the symbol directly below.  Nothing to do.
+        } else {
+            if (right->getAsSymbolNode() != nullptr) {
+                // case 2: we'll copy the symbol per iteration below.
+                cloneSymNode = right->getAsSymbolNode();
+            } else {
+                // case 3: assign to a temp, and indirect into that.
+                rhsTempVar = makeInternalVariable("flattenTemp", right->getType());
+                rhsTempVar->getWritableType().getQualifier().makeTemporary();
+                TIntermTyped* noFlattenRHS = intermediate.addSymbol(*rhsTempVar, loc);
+
+                // Add this to the aggregate being built.
+                assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, noFlattenRHS, right, loc), loc);
+            }
+        }
+    }
 
     const auto getMember = [&](bool flatten, TIntermTyped* node,
                                const TVector<TVariable*>& memberVariables, int member,
@@ -971,6 +1014,14 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
         return subTree;
     };
 
+    // Return the proper RHS node: a new symbol from a TVariable, copy
+    // of an TIntermSymbol node, or sometimes the right node directly.
+    const auto getRHS = [&]() {
+        return rhsTempVar   ? intermediate.addSymbol(*rhsTempVar, loc) :
+               cloneSymNode ? intermediate.addSymbol(*cloneSymNode) :
+                              right;
+    };
+
     // Handle struct assignment
     if (left->getType().isStruct()) {
         // If we get here, we are assigning to or from a whole struct that must be
@@ -978,7 +1029,7 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
         const auto& members = *left->getType().getStruct();
 
         for (int member = 0; member < (int)members.size(); ++member) {
-            TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, member,
+            TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, member,
                                                EOpIndexDirectStruct, *members[member].type);
             TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, member,
                                               EOpIndexDirectStruct, *members[member].type);
@@ -992,10 +1043,10 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
         // flattened, so have to do member-by-member assignment:
 
         const TType dereferencedType(left->getType(), 0);
-        const int size = left->getType().getCumulativeArraySize();
         
-        for (int element=0; element < size; ++element) {
-            TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, element,
+        for (int element=0; element < memberCount; ++element) {
+            // Add a new AST symbol node if we have a temp variable holding a complex RHS.
+            TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, element,
                                                EOpIndexDirect, dereferencedType);
             TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, element,
                                               EOpIndexDirect, dereferencedType);
@@ -1235,9 +1286,9 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType
             // Return value from size query
             TVariable* tempArg = makeInternalVariable("sizeQueryTemp", sizeQuery->getType());
             tempArg->getWritableType().getQualifier().makeTemporary();
-            TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc);
-
-            TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign, sizeQueryReturn, sizeQuery, loc);
+            TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign,
+                                                                   intermediate.addSymbol(*tempArg, loc),
+                                                                   sizeQuery, loc);
 
             // Compound statement for assigning outputs
             TIntermAggregate* compoundStatement = intermediate.makeAggregate(sizeQueryAssign, loc);
@@ -1246,6 +1297,7 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType
 
             for (int compNum = 0; compNum < numDims; ++compNum) {
                 TIntermTyped* indexedOut = nullptr;
+                TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc);
 
                 if (numDims > 1) {
                     TIntermTyped* component = intermediate.addConstantUnion(compNum, loc, true);