6f2f6ba13ac376bf6929b4aaedc7f1a3f52385cf
[platform/upstream/VK-GL-CTS.git] / external / openglcts / modules / common / subgroups / glcSubgroupsArithmeticTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsArithmeticTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
/*!
 * \brief Subgroup arithmetic operations covered by this file.
 *
 * The enumerators come in three families of seven operations each
 * (add, mul, min, max, and, or, xor): the plain reduction, the inclusive
 * scan and the exclusive scan. OPTYPE_LAST is a terminator, not an operation.
 */
enum OpType
{
	// Reductions (subgroup<Op>).
	OPTYPE_ADD = 0,
	OPTYPE_MUL,
	OPTYPE_MIN,
	OPTYPE_MAX,
	OPTYPE_AND,
	OPTYPE_OR,
	OPTYPE_XOR,

	// Inclusive scans (subgroupInclusive<Op>).
	OPTYPE_INCLUSIVE_ADD,
	OPTYPE_INCLUSIVE_MUL,
	OPTYPE_INCLUSIVE_MIN,
	OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND,
	OPTYPE_INCLUSIVE_OR,
	OPTYPE_INCLUSIVE_XOR,

	// Exclusive scans (subgroupExclusive<Op>).
	OPTYPE_EXCLUSIVE_ADD,
	OPTYPE_EXCLUSIVE_MUL,
	OPTYPE_EXCLUSIVE_MIN,
	OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND,
	OPTYPE_EXCLUSIVE_OR,
	OPTYPE_EXCLUSIVE_XOR,

	OPTYPE_LAST
};
63
64 static bool checkVertexPipelineStages(std::vector<const void*> datas,
65                                                                           deUint32 width, deUint32)
66 {
67         return vkt::subgroups::check(datas, width, 0x3);
68 }
69
70 static bool checkCompute(std::vector<const void*> datas,
71                                                  const deUint32 numWorkgroups[3], const deUint32 localSize[3],
72                                                  deUint32)
73 {
74         return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x3);
75 }
76
77 std::string getOpTypeName(int opType)
78 {
79         switch (opType)
80         {
81                 default:
82                         DE_FATAL("Unsupported op type");
83                         return "";
84                 case OPTYPE_ADD:
85                         return "subgroupAdd";
86                 case OPTYPE_MUL:
87                         return "subgroupMul";
88                 case OPTYPE_MIN:
89                         return "subgroupMin";
90                 case OPTYPE_MAX:
91                         return "subgroupMax";
92                 case OPTYPE_AND:
93                         return "subgroupAnd";
94                 case OPTYPE_OR:
95                         return "subgroupOr";
96                 case OPTYPE_XOR:
97                         return "subgroupXor";
98                 case OPTYPE_INCLUSIVE_ADD:
99                         return "subgroupInclusiveAdd";
100                 case OPTYPE_INCLUSIVE_MUL:
101                         return "subgroupInclusiveMul";
102                 case OPTYPE_INCLUSIVE_MIN:
103                         return "subgroupInclusiveMin";
104                 case OPTYPE_INCLUSIVE_MAX:
105                         return "subgroupInclusiveMax";
106                 case OPTYPE_INCLUSIVE_AND:
107                         return "subgroupInclusiveAnd";
108                 case OPTYPE_INCLUSIVE_OR:
109                         return "subgroupInclusiveOr";
110                 case OPTYPE_INCLUSIVE_XOR:
111                         return "subgroupInclusiveXor";
112                 case OPTYPE_EXCLUSIVE_ADD:
113                         return "subgroupExclusiveAdd";
114                 case OPTYPE_EXCLUSIVE_MUL:
115                         return "subgroupExclusiveMul";
116                 case OPTYPE_EXCLUSIVE_MIN:
117                         return "subgroupExclusiveMin";
118                 case OPTYPE_EXCLUSIVE_MAX:
119                         return "subgroupExclusiveMax";
120                 case OPTYPE_EXCLUSIVE_AND:
121                         return "subgroupExclusiveAnd";
122                 case OPTYPE_EXCLUSIVE_OR:
123                         return "subgroupExclusiveOr";
124                 case OPTYPE_EXCLUSIVE_XOR:
125                         return "subgroupExclusiveXor";
126         }
127 }
128
129 std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
130 {
131         switch (opType)
132         {
133                 default:
134                         DE_FATAL("Unsupported op type");
135                         return "";
136                 case OPTYPE_ADD:
137                 case OPTYPE_INCLUSIVE_ADD:
138                 case OPTYPE_EXCLUSIVE_ADD:
139                         return lhs + " + " + rhs;
140                 case OPTYPE_MUL:
141                 case OPTYPE_INCLUSIVE_MUL:
142                 case OPTYPE_EXCLUSIVE_MUL:
143                         return lhs + " * " + rhs;
144                 case OPTYPE_MIN:
145                 case OPTYPE_INCLUSIVE_MIN:
146                 case OPTYPE_EXCLUSIVE_MIN:
147                         switch (format)
148                         {
149                                 default:
150                                         return "min(" + lhs + ", " + rhs + ")";
151                                 case VK_FORMAT_R32_SFLOAT:
152                                 case VK_FORMAT_R64_SFLOAT:
153                                         return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
154                                 case VK_FORMAT_R32G32_SFLOAT:
155                                 case VK_FORMAT_R32G32B32_SFLOAT:
156                                 case VK_FORMAT_R32G32B32A32_SFLOAT:
157                                 case VK_FORMAT_R64G64_SFLOAT:
158                                 case VK_FORMAT_R64G64B64_SFLOAT:
159                                 case VK_FORMAT_R64G64B64A64_SFLOAT:
160                                         return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
161                         }
162                 case OPTYPE_MAX:
163                 case OPTYPE_INCLUSIVE_MAX:
164                 case OPTYPE_EXCLUSIVE_MAX:
165                         switch (format)
166                         {
167                                 default:
168                                         return "max(" + lhs + ", " + rhs + ")";
169                                 case VK_FORMAT_R32_SFLOAT:
170                                 case VK_FORMAT_R64_SFLOAT:
171                                         return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
172                                 case VK_FORMAT_R32G32_SFLOAT:
173                                 case VK_FORMAT_R32G32B32_SFLOAT:
174                                 case VK_FORMAT_R32G32B32A32_SFLOAT:
175                                 case VK_FORMAT_R64G64_SFLOAT:
176                                 case VK_FORMAT_R64G64B64_SFLOAT:
177                                 case VK_FORMAT_R64G64B64A64_SFLOAT:
178                                         return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
179                         }
180                 case OPTYPE_AND:
181                 case OPTYPE_INCLUSIVE_AND:
182                 case OPTYPE_EXCLUSIVE_AND:
183                         switch (format)
184                         {
185                                 default:
186                                         return lhs + " & " + rhs;
187                                 case VK_FORMAT_R8_USCALED:
188                                         return lhs + " && " + rhs;
189                                 case VK_FORMAT_R8G8_USCALED:
190                                         return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
191                                 case VK_FORMAT_R8G8B8_USCALED:
192                                         return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
193                                 case VK_FORMAT_R8G8B8A8_USCALED:
194                                         return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
195                         }
196                 case OPTYPE_OR:
197                 case OPTYPE_INCLUSIVE_OR:
198                 case OPTYPE_EXCLUSIVE_OR:
199                         switch (format)
200                         {
201                                 default:
202                                         return lhs + " | " + rhs;
203                                 case VK_FORMAT_R8_USCALED:
204                                         return lhs + " || " + rhs;
205                                 case VK_FORMAT_R8G8_USCALED:
206                                         return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
207                                 case VK_FORMAT_R8G8B8_USCALED:
208                                         return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
209                                 case VK_FORMAT_R8G8B8A8_USCALED:
210                                         return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
211                         }
212                 case OPTYPE_XOR:
213                 case OPTYPE_INCLUSIVE_XOR:
214                 case OPTYPE_EXCLUSIVE_XOR:
215                         switch (format)
216                         {
217                                 default:
218                                         return lhs + " ^ " + rhs;
219                                 case VK_FORMAT_R8_USCALED:
220                                         return lhs + " ^^ " + rhs;
221                                 case VK_FORMAT_R8G8_USCALED:
222                                         return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
223                                 case VK_FORMAT_R8G8B8_USCALED:
224                                         return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
225                                 case VK_FORMAT_R8G8B8A8_USCALED:
226                                         return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
227                         }
228         }
229 }
230
231 std::string getIdentity(int opType, vk::VkFormat format)
232 {
233         bool isFloat = false;
234         bool isInt = false;
235         bool isUnsigned = false;
236
237         switch (format)
238         {
239                 default:
240                         DE_FATAL("Unhandled format!");
241                         break;
242                 case VK_FORMAT_R32_SINT:
243                 case VK_FORMAT_R32G32_SINT:
244                 case VK_FORMAT_R32G32B32_SINT:
245                 case VK_FORMAT_R32G32B32A32_SINT:
246                         isInt = true;
247                         break;
248                 case VK_FORMAT_R32_UINT:
249                 case VK_FORMAT_R32G32_UINT:
250                 case VK_FORMAT_R32G32B32_UINT:
251                 case VK_FORMAT_R32G32B32A32_UINT:
252                         isUnsigned = true;
253                         break;
254                 case VK_FORMAT_R32_SFLOAT:
255                 case VK_FORMAT_R32G32_SFLOAT:
256                 case VK_FORMAT_R32G32B32_SFLOAT:
257                 case VK_FORMAT_R32G32B32A32_SFLOAT:
258                 case VK_FORMAT_R64_SFLOAT:
259                 case VK_FORMAT_R64G64_SFLOAT:
260                 case VK_FORMAT_R64G64B64_SFLOAT:
261                 case VK_FORMAT_R64G64B64A64_SFLOAT:
262                         isFloat = true;
263                         break;
264                 case VK_FORMAT_R8_USCALED:
265                 case VK_FORMAT_R8G8_USCALED:
266                 case VK_FORMAT_R8G8B8_USCALED:
267                 case VK_FORMAT_R8G8B8A8_USCALED:
268                         break; // bool types are not anything
269         }
270
271         switch (opType)
272         {
273                 default:
274                         DE_FATAL("Unsupported op type");
275                         return "";
276                 case OPTYPE_ADD:
277                 case OPTYPE_INCLUSIVE_ADD:
278                 case OPTYPE_EXCLUSIVE_ADD:
279                         return subgroups::getFormatNameForGLSL(format) + "(0)";
280                 case OPTYPE_MUL:
281                 case OPTYPE_INCLUSIVE_MUL:
282                 case OPTYPE_EXCLUSIVE_MUL:
283                         return subgroups::getFormatNameForGLSL(format) + "(1)";
284                 case OPTYPE_MIN:
285                 case OPTYPE_INCLUSIVE_MIN:
286                 case OPTYPE_EXCLUSIVE_MIN:
287                         if (isFloat)
288                         {
289                                 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
290                         }
291                         else if (isInt)
292                         {
293                                 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
294                         }
295                         else if (isUnsigned)
296                         {
297                                 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
298                         }
299                         else
300                         {
301                                 DE_FATAL("Unhandled case");
302                                 return "";
303                         }
304                 case OPTYPE_MAX:
305                 case OPTYPE_INCLUSIVE_MAX:
306                 case OPTYPE_EXCLUSIVE_MAX:
307                         if (isFloat)
308                         {
309                                 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
310                         }
311                         else if (isInt)
312                         {
313                                 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
314                         }
315                         else if (isUnsigned)
316                         {
317                                 return subgroups::getFormatNameForGLSL(format) + "(0)";
318                         }
319                         else
320                         {
321                                 DE_FATAL("Unhandled case");
322                                 return "";
323                         }
324                 case OPTYPE_AND:
325                 case OPTYPE_INCLUSIVE_AND:
326                 case OPTYPE_EXCLUSIVE_AND:
327                         return subgroups::getFormatNameForGLSL(format) + "(~0)";
328                 case OPTYPE_OR:
329                 case OPTYPE_INCLUSIVE_OR:
330                 case OPTYPE_EXCLUSIVE_OR:
331                         return subgroups::getFormatNameForGLSL(format) + "(0)";
332                 case OPTYPE_XOR:
333                 case OPTYPE_INCLUSIVE_XOR:
334                 case OPTYPE_EXCLUSIVE_XOR:
335                         return subgroups::getFormatNameForGLSL(format) + "(0)";
336         }
337 }
338
339 std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
340 {
341         std::string formatName = subgroups::getFormatNameForGLSL(format);
342         switch (format)
343         {
344                 default:
345                         return "all(equal(" + lhs + ", " + rhs + "))";
346                 case VK_FORMAT_R8_USCALED:
347                 case VK_FORMAT_R32_UINT:
348                 case VK_FORMAT_R32_SINT:
349                         return "(" + lhs + " == " + rhs + ")";
350                 case VK_FORMAT_R32_SFLOAT:
351                 case VK_FORMAT_R64_SFLOAT:
352                         switch (opType)
353                         {
354                                 default:
355                                         return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
356                                 case OPTYPE_MIN:
357                                 case OPTYPE_INCLUSIVE_MIN:
358                                 case OPTYPE_EXCLUSIVE_MIN:
359                                 case OPTYPE_MAX:
360                                 case OPTYPE_INCLUSIVE_MAX:
361                                 case OPTYPE_EXCLUSIVE_MAX:
362                                         return "(" + lhs + " == " + rhs + ")";
363                         }
364                 case VK_FORMAT_R32G32_SFLOAT:
365                 case VK_FORMAT_R32G32B32_SFLOAT:
366                 case VK_FORMAT_R32G32B32A32_SFLOAT:
367                 case VK_FORMAT_R64G64_SFLOAT:
368                 case VK_FORMAT_R64G64B64_SFLOAT:
369                 case VK_FORMAT_R64G64B64A64_SFLOAT:
370                         switch (opType)
371                         {
372                                 default:
373                                         return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
374                                 case OPTYPE_MIN:
375                                 case OPTYPE_INCLUSIVE_MIN:
376                                 case OPTYPE_EXCLUSIVE_MIN:
377                                 case OPTYPE_MAX:
378                                 case OPTYPE_INCLUSIVE_MAX:
379                                 case OPTYPE_EXCLUSIVE_MAX:
380                                         return "all(equal(" + lhs + ", " + rhs + "))";
381                         }
382         }
383 }
384
385 struct CaseDefinition
386 {
387         int                                     opType;
388         VkShaderStageFlags      shaderStage;
389         VkFormat                        format;
390 };
391
392 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
393 {
394         const vk::ShaderBuildOptions    buildOptions    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
395         std::string                                             indexVars;
396         std::ostringstream                              bdy;
397
398         subgroups::setFragmentShaderFrameBuffer(programCollection);
399
400         if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
401                 subgroups::setVertexShaderFrameBuffer(programCollection);
402
403         switch (caseDef.opType)
404         {
405                 default:
406                         indexVars = "  uint start = 0, end = gl_SubgroupSize;\n";
407                         break;
408                 case OPTYPE_INCLUSIVE_ADD:
409                 case OPTYPE_INCLUSIVE_MUL:
410                 case OPTYPE_INCLUSIVE_MIN:
411                 case OPTYPE_INCLUSIVE_MAX:
412                 case OPTYPE_INCLUSIVE_AND:
413                 case OPTYPE_INCLUSIVE_OR:
414                 case OPTYPE_INCLUSIVE_XOR:
415                         indexVars = "  uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
416                         break;
417                 case OPTYPE_EXCLUSIVE_ADD:
418                 case OPTYPE_EXCLUSIVE_MUL:
419                 case OPTYPE_EXCLUSIVE_MIN:
420                 case OPTYPE_EXCLUSIVE_MAX:
421                 case OPTYPE_EXCLUSIVE_AND:
422                 case OPTYPE_EXCLUSIVE_OR:
423                 case OPTYPE_EXCLUSIVE_XOR:
424                         indexVars = "  uint start = 0, end = gl_SubgroupInvocationID;\n";
425                         break;
426         }
427
428         bdy << indexVars
429                 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
430                 << getIdentity(caseDef.opType, caseDef.format) << ";\n"
431                 << "  uint tempResult = 0;\n"
432                 << "  for (uint index = start; index < end; index++)\n"
433                 << "  {\n"
434                 << "    if (subgroupBallotBitExtract(mask, index))\n"
435                 << "    {\n"
436                 << "      ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
437                 << "    }\n"
438                 << "  }\n"
439                 << "  tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref",
440                                                                                         getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n"
441                 << "  if (1 == (gl_SubgroupInvocationID % 2))\n"
442                 << "  {\n"
443                 << "    mask = subgroupBallot(true);\n"
444                 << "    ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
445                 << "    for (uint index = start; index < end; index++)\n"
446                 << "    {\n"
447                 << "      if (subgroupBallotBitExtract(mask, index))\n"
448                 << "      {\n"
449                 << "        ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
450                 << "      }\n"
451                 << "    }\n"
452                 << "    tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref",
453                                 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n"
454                 << "  }\n"
455                 << "  else\n"
456                 << "  {\n"
457                 << "    tempResult |= 0x2;\n"
458                 << "  }\n";
459
460         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
461         {
462                 std::ostringstream vertexSrc;
463                 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
464                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
465                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
466                         << "layout(location = 0) in highp vec4 in_position;\n"
467                         << "layout(location = 0) out float out_color;\n"
468                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
469                         << "{\n"
470                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
471                         << "};\n"
472                         << "\n"
473                         << "void main (void)\n"
474                         << "{\n"
475                         << "  uvec4 mask = subgroupBallot(true);\n"
476                         << bdy.str()
477                         << "  out_color = float(tempResult);\n"
478                         << "  gl_Position = in_position;\n"
479                         << "  gl_PointSize = 1.0f;\n"
480                         << "}\n";
481                 programCollection.glslSources.add("vert")
482                         << glu::VertexSource(vertexSrc.str()) << buildOptions;
483         }
484         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
485         {
486                 std::ostringstream geometry;
487
488                 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
489                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
490                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
491                         << "layout(points) in;\n"
492                         << "layout(points, max_vertices = 1) out;\n"
493                         << "layout(location = 0) out float out_color;\n"
494                         << "layout(set = 0, binding = 0) uniform Buffer\n"
495                         << "{\n"
496                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
497                         << "};\n"
498                         << "\n"
499                         << "void main (void)\n"
500                         << "{\n"
501                         << "  uvec4 mask = subgroupBallot(true);\n"
502                         << bdy.str()
503                         << "  out_color = float(tempResult);\n"
504                         << "  gl_Position = gl_in[0].gl_Position;\n"
505                         << "  EmitVertex();\n"
506                         << "  EndPrimitive();\n"
507                         << "}\n";
508
509                 programCollection.glslSources.add("geometry")
510                                 << glu::GeometrySource(geometry.str()) << buildOptions;
511         }
512         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
513         {
514                 std::ostringstream controlSource;
515                 controlSource  << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
516                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
517                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
518                         << "layout(vertices = 2) out;\n"
519                         << "layout(location = 0) out float out_color[];\n"
520                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
521                         << "{\n"
522                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
523                         << "};\n"
524                         << "\n"
525                         << "void main (void)\n"
526                         << "{\n"
527                         << "  if (gl_InvocationID == 0)\n"
528                         <<"  {\n"
529                         << "    gl_TessLevelOuter[0] = 1.0f;\n"
530                         << "    gl_TessLevelOuter[1] = 1.0f;\n"
531                         << "  }\n"
532                         << "  uvec4 mask = subgroupBallot(true);\n"
533                         << bdy.str()
534                         << "  out_color[gl_InvocationID] = float(tempResult);"
535                         << "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
536                         << "}\n";
537
538
539                 programCollection.glslSources.add("tesc")
540                         << glu::TessellationControlSource(controlSource.str()) << buildOptions;
541                 subgroups::setTesEvalShaderFrameBuffer(programCollection);
542         }
543         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
544         {
545
546                 std::ostringstream evaluationSource;
547                 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
548                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
549                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
550                         << "layout(isolines, equal_spacing, ccw ) in;\n"
551                         << "layout(location = 0) out float out_color;\n"
552                         << "layout(set = 0, binding = 0) uniform Buffer1\n"
553                         << "{\n"
554                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
555                         << "};\n"
556                         << "\n"
557                         << "void main (void)\n"
558                         << "{\n"
559                         << "  uvec4 mask = subgroupBallot(true);\n"
560                         << bdy.str()
561                         << "  out_color = float(tempResult);\n"
562                         << "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
563                         << "}\n";
564
565                 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
566                 programCollection.glslSources.add("tese") << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
567         }
568         else
569         {
570                 DE_FATAL("Unsupported shader stage");
571         }
572 }
573
574 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
575 {
576         std::string indexVars;
577         switch (caseDef.opType)
578         {
579                 default:
580                         indexVars = "  uint start = 0, end = gl_SubgroupSize;\n";
581                         break;
582                 case OPTYPE_INCLUSIVE_ADD:
583                 case OPTYPE_INCLUSIVE_MUL:
584                 case OPTYPE_INCLUSIVE_MIN:
585                 case OPTYPE_INCLUSIVE_MAX:
586                 case OPTYPE_INCLUSIVE_AND:
587                 case OPTYPE_INCLUSIVE_OR:
588                 case OPTYPE_INCLUSIVE_XOR:
589                         indexVars = "  uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
590                         break;
591                 case OPTYPE_EXCLUSIVE_ADD:
592                 case OPTYPE_EXCLUSIVE_MUL:
593                 case OPTYPE_EXCLUSIVE_MIN:
594                 case OPTYPE_EXCLUSIVE_MAX:
595                 case OPTYPE_EXCLUSIVE_AND:
596                 case OPTYPE_EXCLUSIVE_OR:
597                 case OPTYPE_EXCLUSIVE_XOR:
598                         indexVars = "  uint start = 0, end = gl_SubgroupInvocationID;\n";
599                         break;
600         }
601
602         const string bdy =
603                 indexVars +
604                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " ref = "
605                 + getIdentity(caseDef.opType, caseDef.format) + ";\n"
606                 "  uint tempResult = 0;\n"
607                 "  for (uint index = start; index < end; index++)\n"
608                 "  {\n"
609                 "    if (subgroupBallotBitExtract(mask, index))\n"
610                 "    {\n"
611                 "      ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
612                 "    }\n"
613                 "  }\n"
614                 "  tempResult = " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x1 : 0;\n"
615                 "  if (1 == (gl_SubgroupInvocationID % 2))\n"
616                 "  {\n"
617                 "    mask = subgroupBallot(true);\n"
618                 "    ref = " + getIdentity(caseDef.opType, caseDef.format) + ";\n"
619                 "    for (uint index = start; index < end; index++)\n"
620                 "    {\n"
621                 "      if (subgroupBallotBitExtract(mask, index))\n"
622                 "      {\n"
623                 "        ref = " + getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") + ";\n"
624                 "      }\n"
625                 "    }\n"
626                 "    tempResult |= " + getCompare(caseDef.opType, caseDef.format, "ref", getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") + " ? 0x2 : 0;\n"
627                 "  }\n"
628                 "  else\n"
629                 "  {\n"
630                 "    tempResult |= 0x2;\n"
631                 "  }\n";
632
633         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
634         {
635                 std::ostringstream src;
636
637                 src << "#version 450\n"
638                         << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
639                         << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
640                         << "layout (local_size_x_id = 0, local_size_y_id = 1, "
641                         "local_size_z_id = 2) in;\n"
642                         << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
643                         << "{\n"
644                         << "  uint result[];\n"
645                         << "};\n"
646                         << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
647                         << "{\n"
648                         << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
649                         << "};\n"
650                         << "\n"
651                         << "void main (void)\n"
652                         << "{\n"
653                         << "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
654                         << "  highp uint offset = globalSize.x * ((globalSize.y * "
655                         "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
656                         "gl_GlobalInvocationID.x;\n"
657                         << "  uvec4 mask = subgroupBallot(true);\n"
658                         << bdy
659                         << "  result[offset] = tempResult;\n"
660                         << "}\n";
661
662                 programCollection.glslSources.add("comp")
663                                 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
664         }
665         else
666         {
667                 {
668                         const std::string vertex =
669                                 "#version 450\n"
670                                 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
671                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
672                                 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
673                                 "{\n"
674                                 "  uint result[];\n"
675                                 "};\n"
676                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
677                                 "{\n"
678                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
679                                 "};\n"
680                                 "\n"
681                                 "void main (void)\n"
682                                 "{\n"
683                                 "  uvec4 mask = subgroupBallot(true);\n"
684                                 + bdy+
685                                 "  result[gl_VertexIndex] = tempResult;\n"
686                                 "  float pixelSize = 2.0f/1024.0f;\n"
687                                 "  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
688                                 "  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
689                                 "  gl_PointSize = 1.0f;\n"
690                                 "}\n";
691                         programCollection.glslSources.add("vert")
692                                         << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
693                 }
694
695                 {
696                         const std::string tesc =
697                                 "#version 450\n"
698                                 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
699                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
700                                 "layout(vertices=1) out;\n"
701                                 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
702                                 "{\n"
703                                 "  uint result[];\n"
704                                 "};\n"
705                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
706                                 "{\n"
707                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
708                                 "};\n"
709                                 "\n"
710                                 "void main (void)\n"
711                                 "{\n"
712                                 "  uvec4 mask = subgroupBallot(true);\n"
713                                 + bdy +
714                                 "  result[gl_PrimitiveID] = tempResult;\n"
715                                 "  if (gl_InvocationID == 0)\n"
716                                 "  {\n"
717                                 "    gl_TessLevelOuter[0] = 1.0f;\n"
718                                 "    gl_TessLevelOuter[1] = 1.0f;\n"
719                                 "  }\n"
720                                 "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
721                                 "}\n";
722                         programCollection.glslSources.add("tesc")
723                                 << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
724                 }
725
726                 {
727                         const std::string tese =
728                                 "#version 450\n"
729                                 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
730                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
731                                 "layout(isolines) in;\n"
732                                 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
733                                 "{\n"
734                                 "  uint result[];\n"
735                                 "};\n"
736                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
737                                 "{\n"
738                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
739                                 "};\n"
740                                 "\n"
741                                 "void main (void)\n"
742                                 "{\n"
743                                 "  uvec4 mask = subgroupBallot(true);\n"
744                                 + bdy +
745                                 "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
746                                 "  float pixelSize = 2.0f/1024.0f;\n"
747                                 "  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
748                                 "}\n";
749                         programCollection.glslSources.add("tese")
750                                 << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
751                 }
752
753                 {
754                         const std::string geometry =
755                                 "#version 450\n"
756                                 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
757                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
758                                 "layout(${TOPOLOGY}) in;\n"
759                                 "layout(points, max_vertices = 1) out;\n"
760                                 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
761                                 "{\n"
762                                 "  uint result[];\n"
763                                 "};\n"
764                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
765                                 "{\n"
766                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
767                                 "};\n"
768                                 "\n"
769                                 "void main (void)\n"
770                                 "{\n"
771                                 "  uvec4 mask = subgroupBallot(true);\n"
772                                  + bdy +
773                                 "  result[gl_PrimitiveIDIn] = tempResult;\n"
774                                 "  gl_Position = gl_in[0].gl_Position;\n"
775                                 "  EmitVertex();\n"
776                                 "  EndPrimitive();\n"
777                                 "}\n";
778                         subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
779                                                                                                           programCollection.glslSources);
780                 }
781
782                 {
783                         const std::string fragment =
784                                 "#version 450\n"
785                                 "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
786                                 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
787                                 "layout(location = 0) out uint result;\n"
788                                 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
789                                 "{\n"
790                                 "  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
791                                 "};\n"
792                                 "void main (void)\n"
793                                 "{\n"
794                                 "  uvec4 mask = subgroupBallot(true);\n"
795                                 + bdy +
796                                 "  result = tempResult;\n"
797                                 "}\n";
798                         programCollection.glslSources.add("fragment")
799                                 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
800                 }
801                 subgroups::addNoSubgroupShader(programCollection);
802         }
803 }
804
805 void supportedCheck (Context& context, CaseDefinition caseDef)
806 {
807         if (!subgroups::isSubgroupSupported(context))
808                 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
809
810         if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_ARITHMETIC_BIT))
811         {
812                 TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
813         }
814
815         if (subgroups::isDoubleFormat(caseDef.format) &&
816                         !subgroups::isDoubleSupportedForDevice(context))
817         {
818                 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
819         }
820 }
821
822 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
823 {
824         if (!subgroups::areSubgroupOperationsSupportedForStage(
825                                 context, caseDef.shaderStage))
826         {
827                 if (subgroups::areSubgroupOperationsRequiredForStage(
828                                         caseDef.shaderStage))
829                 {
830                         return tcu::TestStatus::fail(
831                                            "Shader stage " +
832                                            subgroups::getShaderStageName(caseDef.shaderStage) +
833                                            " is required to support subgroup operations!");
834                 }
835                 else
836                 {
837                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
838                 }
839         }
840
841         subgroups::SSBOData inputData;
842         inputData.format = caseDef.format;
843         inputData.numElements = subgroups::maxSupportedSubgroupSize();
844         inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
845
846         if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
847                 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
848         else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
849                 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
850         else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
851                 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
852         else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
853                 return subgroups::makeTessellationEvaluationFrameBufferTest(context,  VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
854         else
855                 TCU_THROW(InternalError, "Unhandled shader stage");
856 }
857
858 bool checkShaderStages (Context& context, const CaseDefinition& caseDef)
859 {
860         if (!subgroups::areSubgroupOperationsSupportedForStage(
861                                 context, caseDef.shaderStage))
862         {
863                 if (subgroups::areSubgroupOperationsRequiredForStage(
864                                         caseDef.shaderStage))
865                 {
866                         return false;
867                 }
868                 else
869                 {
870                         TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
871                 }
872         }
873         return true;
874 }
875
876 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
877 {
878         if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
879         {
880                 if(!checkShaderStages(context,caseDef))
881                 {
882                         return tcu::TestStatus::fail(
883                                                         "Shader stage " +
884                                                         subgroups::getShaderStageName(caseDef.shaderStage) +
885                                                         " is required to support subgroup operations!");
886                 }
887                 subgroups::SSBOData inputData;
888                 inputData.format = caseDef.format;
889                 inputData.numElements = subgroups::maxSupportedSubgroupSize();
890                 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
891
892                 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
893         }
894         else
895         {
896                 VkPhysicalDeviceSubgroupProperties subgroupProperties;
897                 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
898                 subgroupProperties.pNext = DE_NULL;
899
900                 VkPhysicalDeviceProperties2 properties;
901                 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
902                 properties.pNext = &subgroupProperties;
903
904                 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
905
906                 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
907
908                 if ( VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
909                 {
910                         if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
911                                 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
912                         else
913                                 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
914                 }
915
916                 if ((VkShaderStageFlagBits)0u == stages)
917                         TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
918
919                 subgroups::SSBOData inputData;
920                 inputData.format                        = caseDef.format;
921                 inputData.numElements           = subgroups::maxSupportedSubgroupSize();
922                 inputData.initializeType        = subgroups::SSBOData::InitializeNonZero;
923                 inputData.binding                       = 4u;
924                 inputData.stages                        = stages;
925
926                 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData,
927                                                                                  1, checkVertexPipelineStages, stages);
928         }
929 }
930 }
931
932 namespace vkt
933 {
934 namespace subgroups
935 {
936 tcu::TestCaseGroup* createSubgroupsArithmeticTests(tcu::TestContext& testCtx)
937 {
938         de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
939                 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
940         de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
941                 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
942         de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
943                 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
944
945         const VkShaderStageFlags stages[] =
946         {
947                 VK_SHADER_STAGE_VERTEX_BIT,
948                 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
949                 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
950                 VK_SHADER_STAGE_GEOMETRY_BIT,
951         };
952
953         const VkFormat formats[] =
954         {
955                 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
956                 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
957                 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
958                 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
959                 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
960                 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
961                 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
962                 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
963                 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
964         };
965
966         for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
967         {
968                 const VkFormat format = formats[formatIndex];
969
970                 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
971                 {
972                         bool isBool = false;
973                         bool isFloat = false;
974
975                         switch (format)
976                         {
977                                 default:
978                                         break;
979                                 case VK_FORMAT_R32_SFLOAT:
980                                 case VK_FORMAT_R32G32_SFLOAT:
981                                 case VK_FORMAT_R32G32B32_SFLOAT:
982                                 case VK_FORMAT_R32G32B32A32_SFLOAT:
983                                 case VK_FORMAT_R64_SFLOAT:
984                                 case VK_FORMAT_R64G64_SFLOAT:
985                                 case VK_FORMAT_R64G64B64_SFLOAT:
986                                 case VK_FORMAT_R64G64B64A64_SFLOAT:
987                                         isFloat = true;
988                                         break;
989                                 case VK_FORMAT_R8_USCALED:
990                                 case VK_FORMAT_R8G8_USCALED:
991                                 case VK_FORMAT_R8G8B8_USCALED:
992                                 case VK_FORMAT_R8G8B8A8_USCALED:
993                                         isBool = true;
994                                         break;
995                         }
996
997                         bool isBitwiseOp = false;
998
999                         switch (opTypeIndex)
1000                         {
1001                                 default:
1002                                         break;
1003                                 case OPTYPE_AND:
1004                                 case OPTYPE_INCLUSIVE_AND:
1005                                 case OPTYPE_EXCLUSIVE_AND:
1006                                 case OPTYPE_OR:
1007                                 case OPTYPE_INCLUSIVE_OR:
1008                                 case OPTYPE_EXCLUSIVE_OR:
1009                                 case OPTYPE_XOR:
1010                                 case OPTYPE_INCLUSIVE_XOR:
1011                                 case OPTYPE_EXCLUSIVE_XOR:
1012                                         isBitwiseOp = true;
1013                                         break;
1014                         }
1015
1016                         if (isFloat && isBitwiseOp)
1017                         {
1018                                 // Skip float with bitwise category.
1019                                 continue;
1020                         }
1021
1022                         if (isBool && !isBitwiseOp)
1023                         {
1024                                 // Skip bool when its not the bitwise category.
1025                                 continue;
1026                         }
1027                         std::string op = getOpTypeName(opTypeIndex);
1028
1029                         {
1030                                 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
1031                                 addFunctionCaseWithPrograms(computeGroup.get(),
1032                                                                                         de::toLower(op) + "_" +
1033                                                                                         subgroups::getFormatNameForGLSL(format),
1034                                                                                         "", supportedCheck, initPrograms, test, caseDef);
1035                         }
1036
1037                         {
1038                                 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format};
1039                                 addFunctionCaseWithPrograms(graphicGroup.get(),
1040                                                                                         de::toLower(op) + "_" +
1041                                                                                         subgroups::getFormatNameForGLSL(format),
1042                                                                                         "", supportedCheck, initPrograms, test, caseDef);
1043                         }
1044
1045                         for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
1046                         {
1047                                 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
1048                                 addFunctionCaseWithPrograms(framebufferGroup.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) +
1049                                                                                         "_" + getShaderStageName(caseDef.shaderStage), "",
1050                                                                                         supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
1051                         }
1052                 }
1053         }
1054
1055         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
1056                 testCtx, "arithmetic", "Subgroup arithmetic category tests"));
1057
1058         group->addChild(graphicGroup.release());
1059         group->addChild(computeGroup.release());
1060         group->addChild(framebufferGroup.release());
1061
1062         return group.release();
1063 }
1064
1065 } // subgroups
1066 } // vkt