porting changes for OpenGL Subgroup tests
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsPartitionedTests.cpp
1 /*------------------------------------------------------------------------
2  * OpenGL Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017-2019 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  * Copyright (c) 2018-2019 NVIDIA Corporation
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  */ /*!
22  * \file
23  * \brief Subgroups Tests
24  */ /*--------------------------------------------------------------------*/
25
26 #include "glcSubgroupsPartitionedTests.hpp"
27 #include "glcSubgroupsTestsUtils.hpp"
28
29 #include <string>
30 #include <vector>
31
32 using namespace tcu;
33 using namespace std;
34
35 namespace glc
36 {
37 namespace subgroups
38 {
39 namespace
40 {
/*!
 * \brief Subgroup arithmetic operations exercised by these tests.
 *
 * The enumerators form three contiguous runs of seven operations
 * (reduce, inclusive scan, exclusive scan), each ordered
 * Add, Mul, Min, Max, And, Or, Xor. getTestString() relies on the
 * exclusive run being contiguous (OPTYPE_EXCLUSIVE_ADD..OPTYPE_EXCLUSIVE_XOR).
 * OPTYPE_LAST is an end marker, not a valid operation.
 */
enum OpType
{
	// Reduce over every participating invocation.
	OPTYPE_ADD = 0,
	OPTYPE_MUL,
	OPTYPE_MIN,
	OPTYPE_MAX,
	OPTYPE_AND,
	OPTYPE_OR,
	OPTYPE_XOR,

	// Inclusive scans.
	OPTYPE_INCLUSIVE_ADD,
	OPTYPE_INCLUSIVE_MUL,
	OPTYPE_INCLUSIVE_MIN,
	OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND,
	OPTYPE_INCLUSIVE_OR,
	OPTYPE_INCLUSIVE_XOR,

	// Exclusive scans.
	OPTYPE_EXCLUSIVE_ADD,
	OPTYPE_EXCLUSIVE_MUL,
	OPTYPE_EXCLUSIVE_MIN,
	OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND,
	OPTYPE_EXCLUSIVE_OR,
	OPTYPE_EXCLUSIVE_XOR,

	OPTYPE_LAST
};
66
67 static bool checkVertexPipelineStages(std::vector<const void*> datas,
68                                                                           deUint32 width, deUint32)
69 {
70         return glc::subgroups::check(datas, width, 0xFFFFFF);
71 }
72
73 static bool checkComputeStage(std::vector<const void*> datas,
74                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
75                                                  deUint32)
76 {
77         return glc::subgroups::checkCompute(datas, numWorkgroups, localSize, 0xFFFFFF);
78 }
79
80 std::string getOpTypeName(int opType)
81 {
82         switch (opType)
83         {
84                 default:
85                         DE_FATAL("Unsupported op type");
86                         return "";
87                 case OPTYPE_ADD:
88                         return "subgroupAdd";
89                 case OPTYPE_MUL:
90                         return "subgroupMul";
91                 case OPTYPE_MIN:
92                         return "subgroupMin";
93                 case OPTYPE_MAX:
94                         return "subgroupMax";
95                 case OPTYPE_AND:
96                         return "subgroupAnd";
97                 case OPTYPE_OR:
98                         return "subgroupOr";
99                 case OPTYPE_XOR:
100                         return "subgroupXor";
101                 case OPTYPE_INCLUSIVE_ADD:
102                         return "subgroupInclusiveAdd";
103                 case OPTYPE_INCLUSIVE_MUL:
104                         return "subgroupInclusiveMul";
105                 case OPTYPE_INCLUSIVE_MIN:
106                         return "subgroupInclusiveMin";
107                 case OPTYPE_INCLUSIVE_MAX:
108                         return "subgroupInclusiveMax";
109                 case OPTYPE_INCLUSIVE_AND:
110                         return "subgroupInclusiveAnd";
111                 case OPTYPE_INCLUSIVE_OR:
112                         return "subgroupInclusiveOr";
113                 case OPTYPE_INCLUSIVE_XOR:
114                         return "subgroupInclusiveXor";
115                 case OPTYPE_EXCLUSIVE_ADD:
116                         return "subgroupExclusiveAdd";
117                 case OPTYPE_EXCLUSIVE_MUL:
118                         return "subgroupExclusiveMul";
119                 case OPTYPE_EXCLUSIVE_MIN:
120                         return "subgroupExclusiveMin";
121                 case OPTYPE_EXCLUSIVE_MAX:
122                         return "subgroupExclusiveMax";
123                 case OPTYPE_EXCLUSIVE_AND:
124                         return "subgroupExclusiveAnd";
125                 case OPTYPE_EXCLUSIVE_OR:
126                         return "subgroupExclusiveOr";
127                 case OPTYPE_EXCLUSIVE_XOR:
128                         return "subgroupExclusiveXor";
129         }
130 }
131
132 std::string getOpTypeNamePartitioned(int opType)
133 {
134         switch (opType)
135         {
136                 default:
137                         DE_FATAL("Unsupported op type");
138                         return "";
139                 case OPTYPE_ADD:
140                         return "subgroupPartitionedAddNV";
141                 case OPTYPE_MUL:
142                         return "subgroupPartitionedMulNV";
143                 case OPTYPE_MIN:
144                         return "subgroupPartitionedMinNV";
145                 case OPTYPE_MAX:
146                         return "subgroupPartitionedMaxNV";
147                 case OPTYPE_AND:
148                         return "subgroupPartitionedAndNV";
149                 case OPTYPE_OR:
150                         return "subgroupPartitionedOrNV";
151                 case OPTYPE_XOR:
152                         return "subgroupPartitionedXorNV";
153                 case OPTYPE_INCLUSIVE_ADD:
154                         return "subgroupPartitionedInclusiveAddNV";
155                 case OPTYPE_INCLUSIVE_MUL:
156                         return "subgroupPartitionedInclusiveMulNV";
157                 case OPTYPE_INCLUSIVE_MIN:
158                         return "subgroupPartitionedInclusiveMinNV";
159                 case OPTYPE_INCLUSIVE_MAX:
160                         return "subgroupPartitionedInclusiveMaxNV";
161                 case OPTYPE_INCLUSIVE_AND:
162                         return "subgroupPartitionedInclusiveAndNV";
163                 case OPTYPE_INCLUSIVE_OR:
164                         return "subgroupPartitionedInclusiveOrNV";
165                 case OPTYPE_INCLUSIVE_XOR:
166                         return "subgroupPartitionedInclusiveXorNV";
167                 case OPTYPE_EXCLUSIVE_ADD:
168                         return "subgroupPartitionedExclusiveAddNV";
169                 case OPTYPE_EXCLUSIVE_MUL:
170                         return "subgroupPartitionedExclusiveMulNV";
171                 case OPTYPE_EXCLUSIVE_MIN:
172                         return "subgroupPartitionedExclusiveMinNV";
173                 case OPTYPE_EXCLUSIVE_MAX:
174                         return "subgroupPartitionedExclusiveMaxNV";
175                 case OPTYPE_EXCLUSIVE_AND:
176                         return "subgroupPartitionedExclusiveAndNV";
177                 case OPTYPE_EXCLUSIVE_OR:
178                         return "subgroupPartitionedExclusiveOrNV";
179                 case OPTYPE_EXCLUSIVE_XOR:
180                         return "subgroupPartitionedExclusiveXorNV";
181         }
182 }
183
184 std::string getIdentity(int opType, Format format)
185 {
186         bool isFloat = false;
187         bool isInt = false;
188         bool isUnsigned = false;
189
190         switch (format)
191         {
192                 default:
193                         DE_FATAL("Unhandled format!");
194                         return "";
195                 case FORMAT_R32_SINT:
196                 case FORMAT_R32G32_SINT:
197                 case FORMAT_R32G32B32_SINT:
198                 case FORMAT_R32G32B32A32_SINT:
199                         isInt = true;
200                         break;
201                 case FORMAT_R32_UINT:
202                 case FORMAT_R32G32_UINT:
203                 case FORMAT_R32G32B32_UINT:
204                 case FORMAT_R32G32B32A32_UINT:
205                         isUnsigned = true;
206                         break;
207                 case FORMAT_R32_SFLOAT:
208                 case FORMAT_R32G32_SFLOAT:
209                 case FORMAT_R32G32B32_SFLOAT:
210                 case FORMAT_R32G32B32A32_SFLOAT:
211                 case FORMAT_R64_SFLOAT:
212                 case FORMAT_R64G64_SFLOAT:
213                 case FORMAT_R64G64B64_SFLOAT:
214                 case FORMAT_R64G64B64A64_SFLOAT:
215                         isFloat = true;
216                         break;
217                 case FORMAT_R32_BOOL:
218                 case FORMAT_R32G32_BOOL:
219                 case FORMAT_R32G32B32_BOOL:
220                 case FORMAT_R32G32B32A32_BOOL:
221                         break; // bool types are not anything
222         }
223
224         switch (opType)
225         {
226                 default:
227                         DE_FATAL("Unsupported op type");
228                         return "";
229                 case OPTYPE_ADD:
230                 case OPTYPE_INCLUSIVE_ADD:
231                 case OPTYPE_EXCLUSIVE_ADD:
232                         return subgroups::getFormatNameForGLSL(format) + "(0)";
233                 case OPTYPE_MUL:
234                 case OPTYPE_INCLUSIVE_MUL:
235                 case OPTYPE_EXCLUSIVE_MUL:
236                         return subgroups::getFormatNameForGLSL(format) + "(1)";
237                 case OPTYPE_MIN:
238                 case OPTYPE_INCLUSIVE_MIN:
239                 case OPTYPE_EXCLUSIVE_MIN:
240                         if (isFloat)
241                         {
242                                 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
243                         }
244                         else if (isInt)
245                         {
246                                 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
247                         }
248                         else if (isUnsigned)
249                         {
250                                 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
251                         }
252                         else
253                         {
254                                 DE_FATAL("Unhandled case");
255                                 return "";
256                         }
257                 case OPTYPE_MAX:
258                 case OPTYPE_INCLUSIVE_MAX:
259                 case OPTYPE_EXCLUSIVE_MAX:
260                         if (isFloat)
261                         {
262                                 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
263                         }
264                         else if (isInt)
265                         {
266                                 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
267                         }
268                         else if (isUnsigned)
269                         {
270                                 return subgroups::getFormatNameForGLSL(format) + "(0)";
271                         }
272                         else
273                         {
274                                 DE_FATAL("Unhandled case");
275                                 return "";
276                         }
277                 case OPTYPE_AND:
278                 case OPTYPE_INCLUSIVE_AND:
279                 case OPTYPE_EXCLUSIVE_AND:
280                         return subgroups::getFormatNameForGLSL(format) + "(~0)";
281                 case OPTYPE_OR:
282                 case OPTYPE_INCLUSIVE_OR:
283                 case OPTYPE_EXCLUSIVE_OR:
284                         return subgroups::getFormatNameForGLSL(format) + "(0)";
285                 case OPTYPE_XOR:
286                 case OPTYPE_INCLUSIVE_XOR:
287                 case OPTYPE_EXCLUSIVE_XOR:
288                         return subgroups::getFormatNameForGLSL(format) + "(0)";
289         }
290 }
291
292 std::string getCompare(int opType, Format format, std::string lhs, std::string rhs)
293 {
294         std::string formatName = subgroups::getFormatNameForGLSL(format);
295         switch (format)
296         {
297                 default:
298                         return "all(equal(" + lhs + ", " + rhs + "))";
299                 case FORMAT_R32_BOOL:
300                 case FORMAT_R32_UINT:
301                 case FORMAT_R32_SINT:
302                         return "(" + lhs + " == " + rhs + ")";
303                 case FORMAT_R32_SFLOAT:
304                 case FORMAT_R64_SFLOAT:
305                         switch (opType)
306                         {
307                                 default:
308                                         return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
309                                 case OPTYPE_MIN:
310                                 case OPTYPE_INCLUSIVE_MIN:
311                                 case OPTYPE_EXCLUSIVE_MIN:
312                                 case OPTYPE_MAX:
313                                 case OPTYPE_INCLUSIVE_MAX:
314                                 case OPTYPE_EXCLUSIVE_MAX:
315                                         return "(" + lhs + " == " + rhs + ")";
316                         }
317                 case FORMAT_R32G32_SFLOAT:
318                 case FORMAT_R32G32B32_SFLOAT:
319                 case FORMAT_R32G32B32A32_SFLOAT:
320                 case FORMAT_R64G64_SFLOAT:
321                 case FORMAT_R64G64B64_SFLOAT:
322                 case FORMAT_R64G64B64A64_SFLOAT:
323                         switch (opType)
324                         {
325                                 default:
326                                         return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
327                                 case OPTYPE_MIN:
328                                 case OPTYPE_INCLUSIVE_MIN:
329                                 case OPTYPE_EXCLUSIVE_MIN:
330                                 case OPTYPE_MAX:
331                                 case OPTYPE_INCLUSIVE_MAX:
332                                 case OPTYPE_EXCLUSIVE_MAX:
333                                         return "all(equal(" + lhs + ", " + rhs + "))";
334                         }
335         }
336 }
337
338 struct CaseDefinition
339 {
340         int                                     opType;
341         ShaderStageFlags        shaderStage;
342         Format                          format;
343 };
344
345 string getTestString(const CaseDefinition &caseDef)
346 {
347     // NOTE: tempResult can't have anything in bits 31:24 to avoid int->float
348     // conversion overflow in framebuffer tests.
349     string fmt = subgroups::getFormatNameForGLSL(caseDef.format);
350         string bdy =
351                 "  uint tempResult = 0;\n"
352                 "  uint id = gl_SubgroupInvocationID;\n";
353
354     // Test the case where the partition has a single subset with all invocations in it.
355     // This should generate the same result as the non-partitioned function.
356     bdy +=
357         "  uvec4 allBallot = mask;\n"
358         "  " + fmt + " allResult = " + getOpTypeNamePartitioned(caseDef.opType) + "(data[gl_SubgroupInvocationID], allBallot);\n"
359         "  " + fmt + " refResult = " + getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID]);\n"
360         "  if (" + getCompare(caseDef.opType, caseDef.format, "allResult", "refResult") + ") {\n"
361         "      tempResult |= 0x1;\n"
362         "  }\n";
363
364     // The definition of a partition doesn't forbid bits corresponding to inactive
365     // invocations being in the subset with active invocations. In other words, test that
366     // bits corresponding to inactive invocations are ignored.
367     bdy +=
368             "  if (0 == (gl_SubgroupInvocationID % 2)) {\n"
369         "    " + fmt + " allResult = " + getOpTypeNamePartitioned(caseDef.opType) + "(data[gl_SubgroupInvocationID], allBallot);\n"
370         "    " + fmt + " refResult = " + getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID]);\n"
371         "    if (" + getCompare(caseDef.opType, caseDef.format, "allResult", "refResult") + ") {\n"
372         "        tempResult |= 0x2;\n"
373         "    }\n"
374         "  } else {\n"
375         "    tempResult |= 0x2;\n"
376         "  }\n";
377
378     // Test the case where the partition has each invocation in a unique subset. For
379     // exclusive ops, the result is identity. For reduce/inclusive, it's the original value.
380     string expectedSelfResult = "data[gl_SubgroupInvocationID]";
381     if (caseDef.opType >= OPTYPE_EXCLUSIVE_ADD &&
382         caseDef.opType <= OPTYPE_EXCLUSIVE_XOR) {
383         expectedSelfResult = getIdentity(caseDef.opType, caseDef.format);
384     }
385
386     bdy +=
387         "  uvec4 selfBallot = subgroupPartitionNV(gl_SubgroupInvocationID);\n"
388         "  " + fmt + " selfResult = " + getOpTypeNamePartitioned(caseDef.opType) + "(data[gl_SubgroupInvocationID], selfBallot);\n"
389         "  if (" + getCompare(caseDef.opType, caseDef.format, "selfResult", expectedSelfResult) + ") {\n"
390         "      tempResult |= 0x4;\n"
391         "  }\n";
392
393     // Test "random" partitions based on a hash of the invocation id.
394     // This "hash" function produces interesting/randomish partitions.
395     static const char *idhash = "((id%N)+(id%(N+1))-(id%2)+(id/2))%((N+1)/2)";
396
397     bdy +=
398                 "  for (uint N = 1; N < 16; ++N) {\n"
399                 "    " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
400                 "    uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
401                 "    " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(caseDef.opType) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
402                 "      for (uint i = 0; i < N; ++i) {\n"
403                 "        " + fmt + " iFmt = " + fmt + "(i);\n"
404         "        if (" + getCompare(caseDef.opType, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
405         "          " + fmt + " subsetResult = " + getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID]);\n"
406         "          tempResult |= " + getCompare(caseDef.opType, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x4 << N) : 0;\n"
407         "        }\n"
408         "      }\n"
409         "  }\n"
410         // tests in flow control:
411                 "  if (1 == (gl_SubgroupInvocationID % 2)) {\n"
412         "    for (uint N = 1; N < 7; ++N) {\n"
413                 "      " + fmt + " idhashFmt = " + fmt + "(" + idhash + ");\n"
414                 "      uvec4 partitionBallot = subgroupPartitionNV(idhashFmt) & mask;\n"
415         "      " + fmt + " partitionedResult = " + getOpTypeNamePartitioned(caseDef.opType) + "(data[gl_SubgroupInvocationID], partitionBallot);\n"
416         "        for (uint i = 0; i < N; ++i) {\n"
417                 "          " + fmt + " iFmt = " + fmt + "(i);\n"
418         "          if (" + getCompare(caseDef.opType, caseDef.format, "idhashFmt", "iFmt") + ") {\n"
419         "            " + fmt + " subsetResult = " + getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID]);\n"
420         "            tempResult |= " + getCompare(caseDef.opType, caseDef.format, "partitionedResult", "subsetResult") + " ? (0x20000 << N) : 0;\n"
421         "          }\n"
422         "        }\n"
423         "    }\n"
424         "  } else {\n"
425         "    tempResult |= 0xFC0000;\n"
426         "  }\n"
427         ;
428
429     return bdy;
430 }
431
432 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
433 {
434         std::ostringstream                              bdy;
435
436         subgroups::setFragmentShaderFrameBuffer(programCollection);
437
438         if (SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
439                 subgroups::setVertexShaderFrameBuffer(programCollection);
440
441         bdy << getTestString(caseDef);
442
443         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
444         {
445                 std::ostringstream vertexSrc;
446                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
447                         << "#extension GL_NV_shader_subgroup_partitioned: enable\n"
448                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
449                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
450                         << "layout(location = 0) in highp vec4 in_position;\n"
451                         << "layout(location = 0) out float out_color;\n"
452                         << "layout(binding = 0) uniform Buffer0\n"
453                         << "{\n"
454                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
455                         << "};\n"
456                         << "\n"
457                         << "void main (void)\n"
458                         << "{\n"
459                         << "  uvec4 mask = subgroupBallot(true);\n"
460                         << bdy.str()
461                         << "  out_color = float(tempResult);\n"
462                         << "  gl_Position = in_position;\n"
463                         << "  gl_PointSize = 1.0f;\n"
464                         << "}\n";
465                 programCollection.add("vert") << glu::VertexSource(vertexSrc.str());
466         }
467         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
468         {
469                 std::ostringstream geometry;
470
471                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
472                         << "#extension GL_NV_shader_subgroup_partitioned: enable\n"
473                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
474                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
475                         << "layout(points) in;\n"
476                         << "layout(points, max_vertices = 1) out;\n"
477                         << "layout(location = 0) out float out_color;\n"
478                         << "layout(binding = 0) uniform Buffer0\n"
479                         << "{\n"
480                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
481                         << "};\n"
482                         << "\n"
483                         << "void main (void)\n"
484                         << "{\n"
485                         << "  uvec4 mask = subgroupBallot(true);\n"
486                         << bdy.str()
487                         << "  out_color = float(tempResult);\n"
488                         << "  gl_Position = gl_in[0].gl_Position;\n"
489                         << "  EmitVertex();\n"
490                         << "  EndPrimitive();\n"
491                         << "}\n";
492
493                 programCollection.add("geometry") << glu::GeometrySource(geometry.str());
494         }
495         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
496         {
497                 std::ostringstream controlSource;
498                 controlSource  << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
499                         << "#extension GL_NV_shader_subgroup_partitioned: enable\n"
500                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
501                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
502                         << "layout(vertices = 2) out;\n"
503                         << "layout(location = 0) out float out_color[];\n"
504                         << "layout(binding = 0) uniform Buffer0\n"
505                         << "{\n"
506                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
507                         << "};\n"
508                         << "\n"
509                         << "void main (void)\n"
510                         << "{\n"
511                         << "  if (gl_InvocationID == 0)\n"
512                         <<"  {\n"
513                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
514                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
515                         << "  }\n"
516                         << "  uvec4 mask = subgroupBallot(true);\n"
517                         << bdy.str()
518                         << "  out_color[gl_InvocationID] = float(tempResult);"
519                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
520                         << "}\n";
521
522
523                 programCollection.add("tesc") << glu::TessellationControlSource(controlSource.str());
524                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
525         }
526         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
527         {
528
529                 std::ostringstream evaluationSource;
530                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
531                         << "#extension GL_NV_shader_subgroup_partitioned: enable\n"
532                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
533                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
534                         << "layout(isolines, equal_spacing, ccw ) in;\n"
535                         << "layout(location = 0) out float out_color;\n"
536                         << "layout(binding = 0) uniform Buffer0\n"
537                         << "{\n"
538                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
539                         << "};\n"
540                         << "\n"
541                         << "void main (void)\n"
542                         << "{\n"
543                         << "  uvec4 mask = subgroupBallot(true);\n"
544                         << bdy.str()
545                         << "  out_color = float(tempResult);\n"
546                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
547                         << "}\n";
548
549                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
550                 programCollection.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str());
551         }
552         else
553         {
554                 DE_FATAL("Unsupported shader stage");
555         }
556 }
557
558 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
559 {
560         const string bdy = getTestString(caseDef);
561
562         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
563         {
564                 std::ostringstream src;
565
566                 src << "#version 450\n"
567                         << "#extension GL_NV_shader_subgroup_partitioned: enable\n"
568                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
569                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
570                         << "layout (${LOCAL_SIZE_X}, ${LOCAL_SIZE_Y}, ${LOCAL_SIZE_Z}) in;\n"
571                         << "layout(binding = 0, std430) buffer Buffer0\n"
572                         << "{\n"
573                         << "  uint result[];\n"
574                         << "};\n"
575                         << "layout(binding = 1, std430) buffer Buffer1\n"
576                         << "{\n"
577                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
578                         << "};\n"
579                         << "\n"
580                         << "void main (void)\n"
581                         << "{\n"
582                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
583                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
584                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
585                         "gl_GlobalInvocationID.x;\n"
586                         << "  uvec4 mask = subgroupBallot(true);\n"
587                         << bdy
588                         << "  result[offset] = tempResult;\n"
589                         << "}\n";
590
591                 programCollection.add("comp") << glu::ComputeSource(src.str());
592         }
593         else
594         {
595                 {
596                         const std::string vertex =
597                                 "#version 450\n"
598                                 "#extension GL_NV_shader_subgroup_partitioned: enable\n"
599                             "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
600                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
601                                 "layout(binding = 0, std430) buffer Buffer0\n"
602                                 "{\n"
603                                 "  uint result[];\n"
604                                 "} b0;\n"
605                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
606                                 "{\n"
607                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
608                                 "};\n"
609                                 "\n"
610                                 "void main (void)\n"
611                                 "{\n"
612                                 "  uvec4 mask = subgroupBallot(true);\n"
613                                 + bdy+
614                                 "  b0.result[gl_VertexID] = tempResult;\n"
615                                 "  float pixelSize = 2.0f/1024.0f;\n"
616                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
617                                 "  gl_Position = vec4(float(gl_VertexID) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
618                                 "  gl_PointSize = 1.0f;\n"
619                                 "}\n";
620                         programCollection.add("vert") << glu::VertexSource(vertex);
621                 }
622
623                 {
624                         const std::string tesc =
625                                 "#version 450\n"
626                                 "#extension GL_NV_shader_subgroup_partitioned: enable\n"
627                             "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
628                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
629                                 "layout(vertices=1) out;\n"
630                                 "layout(binding = 1, std430) buffer Buffer1\n"
631                                 "{\n"
632                                 "  uint result[];\n"
633                                 "} b1;\n"
634                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
635                                 "{\n"
636                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
637                                 "};\n"
638                                 "\n"
639                                 "void main (void)\n"
640                                 "{\n"
641                                 "  uvec4 mask = subgroupBallot(true);\n"
642                                 + bdy +
643                                 "  b1.result[gl_PrimitiveID] = tempResult;\n"
644                                 "  if (gl_InvocationID == 0)\n"
645                                 "  {\n"
646                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
647                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
648                                 "  }\n"
649                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
650                                 "}\n";
651                         programCollection.add("tesc") << glu::TessellationControlSource(tesc);
652                 }
653
654                 {
655                         const std::string tese =
656                                 "#version 450\n"
657                                 "#extension GL_NV_shader_subgroup_partitioned: enable\n"
658                             "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
659                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
660                                 "layout(isolines) in;\n"
661                                 "layout(binding = 2, std430) buffer Buffer2\n"
662                                 "{\n"
663                                 "  uint result[];\n"
664                                 "} b2;\n"
665                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
666                                 "{\n"
667                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
668                                 "};\n"
669                                 "\n"
670                                 "void main (void)\n"
671                                 "{\n"
672                                 "  uvec4 mask = subgroupBallot(true);\n"
673                                 + bdy +
674                                 "  b2.result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
675                                 "  float pixelSize = 2.0f/1024.0f;\n"
676                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
677                                 "}\n";
678                         programCollection.add("tese") << glu::TessellationEvaluationSource(tese);
679                 }
680
681                 {
682                         const std::string geometry =
683                                 "#version 450\n"
684                                 "#extension GL_NV_shader_subgroup_partitioned: enable\n"
685                             "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
686                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
687                                 "layout(${TOPOLOGY}) in;\n"
688                                 "layout(points, max_vertices = 1) out;\n"
689                                 "layout(binding = 3, std430) buffer Buffer3\n"
690                                 "{\n"
691                                 "  uint result[];\n"
692                                 "} b3;\n"
693                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
694                                 "{\n"
695                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
696                                 "};\n"
697                                 "\n"
698                                 "void main (void)\n"
699                                 "{\n"
700                                 "  uvec4 mask = subgroupBallot(true);\n"
701                                  + bdy +
702                                 "  b3.result[gl_PrimitiveIDIn] = tempResult;\n"
703                                 "  gl_Position = gl_in[0].gl_Position;\n"
704                                 "  EmitVertex();\n"
705                                 "  EndPrimitive();\n"
706                                 "}\n";
707                         subgroups::addGeometryShadersFromTemplate(geometry, programCollection);
708                 }
709
710                 {
711                         const std::string fragment =
712                                 "#version 450\n"
713                                 "#extension GL_NV_shader_subgroup_partitioned: enable\n"
714                             "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
715                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
716                                 "layout(location = 0) out uint result;\n"
717                                 "layout(binding = 4, std430) readonly buffer Buffer4\n"
718                                 "{\n"
719                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
720                                 "};\n"
721                                 "void main (void)\n"
722                                 "{\n"
723                                 "  uvec4 mask = subgroupBallot(true);\n"
724                                 + bdy +
725                                 "  result = tempResult;\n"
726                                 "}\n";
727                         programCollection.add("fragment") << glu::FragmentSource(fragment);
728                 }
729                 subgroups::addNoSubgroupShader(programCollection);
730         }
731 }
732
733 void supportedCheck (Context& context, CaseDefinition caseDef)
734 {
735         if (!subgroups::isSubgroupSupported(context))
736                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
737
738         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, SUBGROUP_FEATURE_PARTITIONED_BIT_NV))
739         {
740                 TCU_THROW(NotSupportedError, "Device does not support subgroup partitioned operations");
741         }
742
743         if (subgroups::isDoubleFormat(caseDef.format) &&
744                         !subgroups::isDoubleSupportedForDevice(context))
745         {
746                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
747         }
748 }
749
750 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
751 {
752         if (!subgroups::areSubgroupOperationsSupportedForStage(
753                                 context, caseDef.shaderStage))
754         {
755                 if (subgroups::areSubgroupOperationsRequiredForStage(
756                                         caseDef.shaderStage))
757                 {
758                         return tcu::TestStatus::fail(
759                                            "Shader stage " +
760                                            subgroups::getShaderStageName(caseDef.shaderStage) +
761                                            " is required to support subgroup operations!");
762                 }
763                 else
764                 {
765                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
766                 }
767         }
768
769         subgroups::SSBOData inputData;
770         inputData.format = caseDef.format;
771         inputData.numElements = subgroups::maxSupportedSubgroupSize();
772         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
773         inputData.binding = 0u;
774
775         if (SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
776                 return subgroups::makeVertexFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
777         else if (SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
778                 return subgroups::makeGeometryFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
779         else if (SHADER_STAGE_TESS_CONTROL_BIT == caseDef.shaderStage)
780                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_CONTROL_BIT);
781         else if (SHADER_STAGE_TESS_EVALUATION_BIT == caseDef.shaderStage)
782                 return subgroups::makeTessellationEvaluationFrameBufferTest(context,  FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, SHADER_STAGE_TESS_EVALUATION_BIT);
783         else
784                 TCU_THROW(InternalError, "Unhandled shader stage");
785 }
786
787 bool checkShaderStages (Context& context, const CaseDefinition& caseDef)
788 {
789         if (!subgroups::areSubgroupOperationsSupportedForStage(
790                                 context, caseDef.shaderStage))
791         {
792                 if (subgroups::areSubgroupOperationsRequiredForStage(
793                                         caseDef.shaderStage))
794                 {
795                         return false;
796                 }
797                 else
798                 {
799                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
800                 }
801         }
802         return true;
803 }
804
805 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
806 {
807         if (SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
808         {
809                 if(!checkShaderStages(context,caseDef))
810                 {
811                         return tcu::TestStatus::fail(
812                                                         "Shader stage " +
813                                                         subgroups::getShaderStageName(caseDef.shaderStage) +
814                                                         " is required to support subgroup operations!");
815                 }
816                 subgroups::SSBOData inputData;
817                 inputData.format = caseDef.format;
818                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
819                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
820                 inputData.binding = 1u;
821
822                 return subgroups::makeComputeTest(context, FORMAT_R32_UINT, &inputData, 1, checkComputeStage);
823         }
824         else
825         {
826                 int supportedStages = context.getDeqpContext().getContextInfo().getInt(GL_SUBGROUP_SUPPORTED_STAGES_KHR);
827
828                 ShaderStageFlags stages = (ShaderStageFlags)(caseDef.shaderStage & supportedStages);
829
830                 if ( SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
831                 {
832                         if ( (stages & SHADER_STAGE_FRAGMENT_BIT) == 0)
833                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
834                         else
835                                 stages = SHADER_STAGE_FRAGMENT_BIT;
836                 }
837
838                 if ((ShaderStageFlags)0u == stages)
839                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
840
841                 subgroups::SSBOData inputData;
842                 inputData.format                        = caseDef.format;
843                 inputData.numElements           = subgroups::maxSupportedSubgroupSize();
844                 inputData.initializeType        = subgroups::SSBOData::InitializeNonZero;
845                 inputData.binding                       = 4u;
846                 inputData.stages                        = stages;
847
848                 return subgroups::allStages(context, FORMAT_R32_UINT, &inputData,
849                                                                                  1, checkVertexPipelineStages, stages);
850         }
851 }
852 }
853
854 deqp::TestCaseGroup* createSubgroupsPartitionedTests(deqp::Context& testCtx)
855 {
856         de::MovePtr<deqp::TestCaseGroup> graphicGroup(new deqp::TestCaseGroup(
857                 testCtx, "graphics", "Subgroup partitioned category tests: graphics"));
858         de::MovePtr<deqp::TestCaseGroup> computeGroup(new deqp::TestCaseGroup(
859                 testCtx, "compute", "Subgroup partitioned category tests: compute"));
860         de::MovePtr<deqp::TestCaseGroup> framebufferGroup(new deqp::TestCaseGroup(
861                 testCtx, "framebuffer", "Subgroup partitioned category tests: framebuffer"));
862
863
864         const ShaderStageFlags stages[] =
865         {
866                 SHADER_STAGE_VERTEX_BIT,
867                 SHADER_STAGE_TESS_EVALUATION_BIT,
868                 SHADER_STAGE_TESS_CONTROL_BIT,
869                 SHADER_STAGE_GEOMETRY_BIT,
870         };
871
872         const Format formats[] =
873         {
874                 FORMAT_R32_SINT, FORMAT_R32G32_SINT, FORMAT_R32G32B32_SINT,
875                 FORMAT_R32G32B32A32_SINT, FORMAT_R32_UINT, FORMAT_R32G32_UINT,
876                 FORMAT_R32G32B32_UINT, FORMAT_R32G32B32A32_UINT,
877                 FORMAT_R32_SFLOAT, FORMAT_R32G32_SFLOAT,
878                 FORMAT_R32G32B32_SFLOAT, FORMAT_R32G32B32A32_SFLOAT,
879                 FORMAT_R64_SFLOAT, FORMAT_R64G64_SFLOAT,
880                 FORMAT_R64G64B64_SFLOAT, FORMAT_R64G64B64A64_SFLOAT,
881                 FORMAT_R32_BOOL, FORMAT_R32G32_BOOL,
882                 FORMAT_R32G32B32_BOOL, FORMAT_R32G32B32A32_BOOL,
883         };
884
885         for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
886         {
887                 const Format format = formats[formatIndex];
888
889                 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
890                 {
891                         bool isBool = false;
892                         bool isFloat = false;
893
894                         switch (format)
895                         {
896                                 default:
897                                         break;
898                                 case FORMAT_R32_SFLOAT:
899                                 case FORMAT_R32G32_SFLOAT:
900                                 case FORMAT_R32G32B32_SFLOAT:
901                                 case FORMAT_R32G32B32A32_SFLOAT:
902                                 case FORMAT_R64_SFLOAT:
903                                 case FORMAT_R64G64_SFLOAT:
904                                 case FORMAT_R64G64B64_SFLOAT:
905                                 case FORMAT_R64G64B64A64_SFLOAT:
906                                         isFloat = true;
907                                         break;
908                                 case FORMAT_R32_BOOL:
909                                 case FORMAT_R32G32_BOOL:
910                                 case FORMAT_R32G32B32_BOOL:
911                                 case FORMAT_R32G32B32A32_BOOL:
912                                         isBool = true;
913                                         break;
914                         }
915
916                         bool isBitwiseOp = false;
917
918                         switch (opTypeIndex)
919                         {
920                                 default:
921                                         break;
922                                 case OPTYPE_AND:
923                                 case OPTYPE_INCLUSIVE_AND:
924                                 case OPTYPE_EXCLUSIVE_AND:
925                                 case OPTYPE_OR:
926                                 case OPTYPE_INCLUSIVE_OR:
927                                 case OPTYPE_EXCLUSIVE_OR:
928                                 case OPTYPE_XOR:
929                                 case OPTYPE_INCLUSIVE_XOR:
930                                 case OPTYPE_EXCLUSIVE_XOR:
931                                         isBitwiseOp = true;
932                                         break;
933                         }
934
935                         if (isFloat && isBitwiseOp)
936                         {
937                                 // Skip float with bitwise category.
938                                 continue;
939                         }
940
941                         if (isBool && !isBitwiseOp)
942                         {
943                                 // Skip bool when its not the bitwise category.
944                                 continue;
945                         }
946                         const std::string name = de::toLower(getOpTypeName(opTypeIndex)) + "_" +
947                                 subgroups::getFormatNameForGLSL(format);
948
949                         {
950                                 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_COMPUTE_BIT, format};
951                                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(computeGroup.get(),
952                                                                                         name, "", supportedCheck, initPrograms, test, caseDef);
953                         }
954
955                         {
956                                 const CaseDefinition caseDef = {opTypeIndex, SHADER_STAGE_ALL_GRAPHICS, format};
957                                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(graphicGroup.get(),
958                                                                                         name, "", supportedCheck, initPrograms, test, caseDef);
959                         }
960
961                         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
962                         {
963                                 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
964                                 SubgroupFactory<CaseDefinition>::addFunctionCaseWithPrograms(framebufferGroup.get(),
965                                                                                         name + "_" + getShaderStageName(caseDef.shaderStage), "",
966                                                                                         supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
967                         }
968                 }
969         }
970         de::MovePtr<deqp::TestCaseGroup> group(new deqp::TestCaseGroup(
971                         testCtx, "partitioned", "NV_shader_subgroup_partitioned category tests"));
972
973         group->addChild(graphicGroup.release());
974         group->addChild(computeGroup.release());
975         group->addChild(framebufferGroup.release());
976
977         return group.release();
978 }
979
980 } // subgroups
981 } // glc
982