Fix clang 10.0 build of ray tracing control flow tests
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / ray_tracing / vktRayTracingComplexControlFlowTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *        http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Complex Control Flow tests
22  *//*--------------------------------------------------------------------*/
23
24 #include "vktRayTracingComplexControlFlowTests.hpp"
25
26 #include "vkDefs.hpp"
27
28 #include "vktTestCase.hpp"
29 #include "vkCmdUtil.hpp"
30 #include "vkObjUtil.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkBarrierUtil.hpp"
33 #include "vkBufferWithMemory.hpp"
34 #include "vkImageWithMemory.hpp"
35 #include "vkTypeUtil.hpp"
36
37 #include "vkRayTracingUtil.hpp"
38
39 #include "tcuTestLog.hpp"
40
41 #include "deRandom.hpp"
42
43 namespace vkt
44 {
45 namespace RayTracing
46 {
47 namespace
48 {
49 using namespace vk;
50 using namespace std;
51
52 static const VkFlags    ALL_RAY_TRACING_STAGES  = VK_SHADER_STAGE_RAYGEN_BIT_KHR
53                                                                                                 | VK_SHADER_STAGE_ANY_HIT_BIT_KHR
54                                                                                                 | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
55                                                                                                 | VK_SHADER_STAGE_MISS_BIT_KHR
56                                                                                                 | VK_SHADER_STAGE_INTERSECTION_BIT_KHR
57                                                                                                 | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
58
59 #if defined(DE_DEBUG)
60 static const deUint32   PUSH_CONSTANTS_COUNT    = 6;
61 #endif
62 static const deUint32   DEFAULT_CLEAR_VALUE             = 999999;
63
// Selects which test variant a case runs; stored in CaseDef::testType.
// Enumerators are numbered consecutively starting at 0.
// NOTE(review): by name each value picks a control flow construct (if, loop,
// switch, nested loops, function calls) - confirm against the shader generators.
enum TestType
{
	TEST_TYPE_IF = 0,
	TEST_TYPE_LOOP,
	TEST_TYPE_SWITCH,
	TEST_TYPE_LOOP_DOUBLE_CALL,
	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE,
	TEST_TYPE_NESTED_LOOP,
	TEST_TYPE_NESTED_LOOP_BEFORE,
	TEST_TYPE_NESTED_LOOP_AFTER,
	TEST_TYPE_FUNCTION_CALL,
	TEST_TYPE_NESTED_FUNCTION_CALL,
};
77
// Selects the operation a case exercises; stored in CaseDef::testOp.
// NOTE(review): by name these correspond to executing a callable shader,
// tracing a ray and reporting an intersection - confirm at the use sites.
enum TestOp
{
	TEST_OP_EXECUTE_CALLABLE = 0,
	TEST_OP_TRACE_RAY,
	TEST_OP_REPORT_INTERSECTION,
};
84
// Fixed shader group numbering: raygen shares index 0 with FIRST_GROUP, miss
// and hit follow, and GROUP_COUNT is the number of groups listed here.
enum ShaderGroups
{
	FIRST_GROUP = 0,
	RAYGEN_GROUP = FIRST_GROUP,
	MISS_GROUP,
	HIT_GROUP,
	GROUP_COUNT
};
93
94 struct CaseDef
95 {
96         TestType                                testType;
97         TestOp                                  testOp;
98         VkShaderStageFlagBits   stage;
99         deUint32                                width;
100         deUint32                                height;
101 };
102
103 struct PushConstants
104 {
105         deUint32        a;
106         deUint32        b;
107         deUint32        c;
108         deUint32        d;
109         deUint32        hitOfs;
110         deUint32        miss;
111 };
112
113 deUint32 getShaderGroupSize (const InstanceInterface&   vki,
114                                                          const VkPhysicalDevice         physicalDevice)
115 {
116         de::MovePtr<RayTracingProperties>       rayTracingPropertiesKHR;
117
118         rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
119         return rayTracingPropertiesKHR->getShaderGroupHandleSize();
120 }
121
122 deUint32 getShaderGroupBaseAlignment (const InstanceInterface&  vki,
123                                                                           const VkPhysicalDevice        physicalDevice)
124 {
125         de::MovePtr<RayTracingProperties>       rayTracingPropertiesKHR;
126
127         rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
128         return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
129 }
130
131 VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, deUint32 depth, VkFormat format)
132 {
133         const VkImageUsageFlags usage                   = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
134         const VkImageCreateInfo imageCreateInfo =
135         {
136                 VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,    // VkStructureType                      sType;
137                 DE_NULL,                                                                // const void*                          pNext;
138                 (VkImageCreateFlags)0u,                                 // VkImageCreateFlags           flags;
139                 VK_IMAGE_TYPE_3D,                                               // VkImageType                          imageType;
140                 format,                                                                 // VkFormat                                     format;
141                 makeExtent3D(width, height, depth),             // VkExtent3D                           extent;
142                 1u,                                                                             // deUint32                                     mipLevels;
143                 1u,                                                                             // deUint32                                     arrayLayers;
144                 VK_SAMPLE_COUNT_1_BIT,                                  // VkSampleCountFlagBits        samples;
145                 VK_IMAGE_TILING_OPTIMAL,                                // VkImageTiling                        tiling;
146                 usage,                                                                  // VkImageUsageFlags            usage;
147                 VK_SHARING_MODE_EXCLUSIVE,                              // VkSharingMode                        sharingMode;
148                 0u,                                                                             // deUint32                                     queueFamilyIndexCount;
149                 DE_NULL,                                                                // const deUint32*                      pQueueFamilyIndices;
150                 VK_IMAGE_LAYOUT_UNDEFINED                               // VkImageLayout                        initialLayout;
151         };
152
153         return imageCreateInfo;
154 }
155
156 Move<VkPipelineLayout> makePipelineLayout (const DeviceInterface&               vk,
157                                                                                    const VkDevice                               device,
158                                                                                    const VkDescriptorSetLayout  descriptorSetLayout,
159                                                                                    const deUint32                               pushConstantsSize)
160 {
161         const VkDescriptorSetLayout*            descriptorSetLayoutPtr  = (descriptorSetLayout == DE_NULL) ? DE_NULL : &descriptorSetLayout;
162         const deUint32                                          setLayoutCount                  = (descriptorSetLayout == DE_NULL) ? 0u : 1u;
163         const VkPushConstantRange                       pushConstantRange               =
164         {
165                 ALL_RAY_TRACING_STAGES,         //  VkShaderStageFlags  stageFlags;
166                 0u,                                                     //  deUint32                    offset;
167                 pushConstantsSize,                      //  deUint32                    size;
168         };
169         const VkPushConstantRange*                      pPushConstantRanges             = (pushConstantsSize == 0) ? DE_NULL : &pushConstantRange;
170         const deUint32                                          pushConstantRangeCount  = (pushConstantsSize == 0) ? 0 : 1u;
171         const VkPipelineLayoutCreateInfo        pipelineLayoutParams    =
172         {
173                 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,  // VkStructureType                                      sType;
174                 DE_NULL,                                                                                // const void*                                          pNext;
175                 0u,                                                                                             // VkPipelineLayoutCreateFlags          flags;
176                 setLayoutCount,                                                                 // deUint32                                                     setLayoutCount;
177                 descriptorSetLayoutPtr,                                                 // const VkDescriptorSetLayout*         pSetLayouts;
178                 pushConstantRangeCount,                                                 // deUint32                                                     pushConstantRangeCount;
179                 pPushConstantRanges,                                                    // const VkPushConstantRange*           pPushConstantRanges;
180         };
181
182         return createPipelineLayout(vk, device, &pipelineLayoutParams);
183 }
184
185 VkBuffer getVkBuffer (const de::MovePtr<BufferWithMemory>& buffer)
186 {
187         VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
188
189         return result;
190 }
191
192 VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion (const DeviceInterface& vkd, const VkDevice device, VkBuffer buffer, deUint32 stride, deUint32 count)
193 {
194         if (buffer == DE_NULL)
195         {
196                 return makeStridedDeviceAddressRegionKHR(0, 0, 0);
197         }
198         else
199         {
200                 return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride, stride * count);
201         }
202 }
203
// Returns a copy of str in which every non-overlapping occurrence of `from`
// is replaced with `to`, scanning left to right. Text inserted by a
// replacement is never rescanned, so `to` may safely contain `from`.
//
// An empty `from` matches nothing and str is returned unchanged. Without this
// guard find() would report a match at every position and the loop below would
// never terminate.
static inline std::string replace(const std::string& str, const std::string& from, const std::string& to)
{
	std::string result(str);

	if (from.empty())
		return result;

	size_t start_pos = 0;
	while((start_pos = result.find(from, start_pos)) != std::string::npos)
	{
		result.replace(start_pos, from.length(), to);
		// Continue after the inserted text so it is not matched again.
		start_pos += to.length();
	}

	return result;
}
218
219
220 class RayTracingComplexControlFlowInstance : public TestInstance
221 {
222 public:
223                                                                                                                                 RayTracingComplexControlFlowInstance    (Context& context, const CaseDef& data);
224                                                                                                                                 ~RayTracingComplexControlFlowInstance   (void);
225         tcu::TestStatus                                                                                         iterate                                                                 (void);
226
227 protected:
228         void                                                                                                            calcShaderGroup                                                 (deUint32&                                      shaderGroupCounter,
229                                                                                                                                                                                                                  const VkShaderStageFlags       shaders1,
230                                                                                                                                                                                                                  const VkShaderStageFlags       shaders2,
231                                                                                                                                                                                                                  const VkShaderStageFlags       shaderStageFlags,
232                                                                                                                                                                                                                  deUint32&                                      shaderGroup,
233                                                                                                                                                                                                                  deUint32&                                      shaderGroupCount) const;
234         void                                                                                                            checkSupportInInstance                                  (void) const;
235         PushConstants                                                                                           getPushConstants                                                (void) const;
236         std::vector<deUint32>                                                                           getExpectedValues                                               (void) const;
237         de::MovePtr<BufferWithMemory>                                                           runTest                                                                 (void);
238         Move<VkPipeline>                                                                                        makePipeline                                                    (de::MovePtr<RayTracingPipeline>&                                                       rayTracingPipeline,
239                                                                                                                                                                                                                  VkPipelineLayout                                                                                       pipelineLayout);
240         de::MovePtr<BufferWithMemory>                                                           createShaderBindingTable                                 (const InstanceInterface&                                                                      vki,
241                                                                                                                                                                                                                  const DeviceInterface&                                                                         vkd,
242                                                                                                                                                                                                                  const VkDevice                                                                                         device,
243                                                                                                                                                                                                                  const VkPhysicalDevice                                                                         physicalDevice,
244                                                                                                                                                                                                                  const VkPipeline                                                                                       pipeline,
245                                                                                                                                                                                                                  Allocator&                                                                                                     allocator,
246                                                                                                                                                                                                                  de::MovePtr<RayTracingPipeline>&                                                       rayTracingPipeline,
247                                                                                                                                                                                                                  const deUint32                                                                                         group,
248                                                                                                                                                                                                                  const deUint32                                                                                         groupCount = 1u);
249         de::MovePtr<TopLevelAccelerationStructure>                                      initTopAccelerationStructure                    (VkCommandBuffer                                                                                        cmdBuffer,
250                                                                                                                                                                                                                  vector<de::SharedPtr<BottomLevelAccelerationStructure> >&      bottomLevelAccelerationStructures);
251         vector<de::SharedPtr<BottomLevelAccelerationStructure>  >       initBottomAccelerationStructures                (VkCommandBuffer                                                                                        cmdBuffer);
252         de::MovePtr<BottomLevelAccelerationStructure>                           initBottomAccelerationStructure                 (VkCommandBuffer                                                                                        cmdBuffer,
253                                                                                                                                                                                                                  tcu::UVec2&                                                                                            startPos);
254
255 private:
256         CaseDef                                                                                                         m_data;
257         VkShaderStageFlags                                                                                      m_shaders;
258         VkShaderStageFlags                                                                                      m_shaders2;
259         deUint32                                                                                                        m_raygenShaderGroup;
260         deUint32                                                                                                        m_missShaderGroup;
261         deUint32                                                                                                        m_hitShaderGroup;
262         deUint32                                                                                                        m_callableShaderGroup;
263         deUint32                                                                                                        m_raygenShaderGroupCount;
264         deUint32                                                                                                        m_missShaderGroupCount;
265         deUint32                                                                                                        m_hitShaderGroupCount;
266         deUint32                                                                                                        m_callableShaderGroupCount;
267         deUint32                                                                                                        m_shaderGroupCount;
268         deUint32                                                                                                        m_depth;
269         PushConstants                                                                                           m_pushConstants;
270 };
271
272 RayTracingComplexControlFlowInstance::RayTracingComplexControlFlowInstance (Context& context, const CaseDef& data)
273         : vkt::TestInstance                             (context)
274         , m_data                                                (data)
275         , m_shaders                                             (0)
276         , m_shaders2                                    (0)
277         , m_raygenShaderGroup                   (~0u)
278         , m_missShaderGroup                             (~0u)
279         , m_hitShaderGroup                              (~0u)
280         , m_callableShaderGroup                 (~0u)
281         , m_raygenShaderGroupCount              (0)
282         , m_missShaderGroupCount                (0)
283         , m_hitShaderGroupCount                 (0)
284         , m_callableShaderGroupCount    (0)
285         , m_shaderGroupCount                    (0)
286         , m_depth                                               (16)
287         , m_pushConstants                               (getPushConstants())
288 {
289         const VkShaderStageFlags        hitStages       = VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
290         BinaryCollection&                       collection      = m_context.getBinaryCollection();
291         deUint32                                        shaderCount     = 0;
292
293         if (collection.contains("rgen")) m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
294         if (collection.contains("ahit")) m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
295         if (collection.contains("chit")) m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
296         if (collection.contains("miss")) m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
297         if (collection.contains("sect")) m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
298         if (collection.contains("call")) m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
299
300         if (collection.contains("ahit2")) m_shaders2 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
301         if (collection.contains("chit2")) m_shaders2 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
302         if (collection.contains("miss2")) m_shaders2 |= VK_SHADER_STAGE_MISS_BIT_KHR;
303         if (collection.contains("sect2")) m_shaders2 |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
304
305         if (collection.contains("cal0")) m_shaders2 |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
306
307         for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
308                 shaderCount++;
309
310         if (shaderCount != (deUint32)dePop32(m_shaders) + (deUint32)dePop32(m_shaders2))
311                 TCU_THROW(InternalError, "Unused shaders detected in the collection");
312
313         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_RAYGEN_BIT_KHR,   m_raygenShaderGroup,   m_raygenShaderGroupCount);
314         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_MISS_BIT_KHR,     m_missShaderGroup,     m_missShaderGroupCount);
315         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, hitStages,                        m_hitShaderGroup,      m_hitShaderGroupCount);
316         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_CALLABLE_BIT_KHR, m_callableShaderGroup, m_callableShaderGroupCount);
317 }
318
319 RayTracingComplexControlFlowInstance::~RayTracingComplexControlFlowInstance (void)
320 {
321 }
322
323 void RayTracingComplexControlFlowInstance::calcShaderGroup (deUint32&                                   shaderGroupCounter,
324                                                                                                                         const VkShaderStageFlags        shaders1,
325                                                                                                                         const VkShaderStageFlags        shaders2,
326                                                                                                                         const VkShaderStageFlags        shaderStageFlags,
327                                                                                                                         deUint32&                                       shaderGroup,
328                                                                                                                         deUint32&                                       shaderGroupCount) const
329 {
330         const deUint32  shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
331         const deUint32  shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
332
333         shaderGroupCount = shader1Count + shader2Count;
334
335         if (shaderGroupCount != 0)
336         {
337                 shaderGroup                     = shaderGroupCounter;
338                 shaderGroupCounter += shaderGroupCount;
339         }
340 }
341
342 Move<VkPipeline> RayTracingComplexControlFlowInstance::makePipeline (de::MovePtr<RayTracingPipeline>&   rayTracingPipeline,
343                                                                                                                                           VkPipelineLayout                                      pipelineLayout)
344 {
345         const DeviceInterface&  vkd                     = m_context.getDeviceInterface();
346         const VkDevice                  device          = m_context.getDevice();
347         vk::BinaryCollection&   collection      = m_context.getBinaryCollection();
348
349         if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))                  rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR            , createShaderModule(vkd, device, collection.get("rgen"), 0), m_raygenShaderGroup);
350         if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))                 rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR           , createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
351         if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))             rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR       , createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
352         if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))                    rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR                      , createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
353         if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))    rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR      , createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
354         if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))                rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR          , createShaderModule(vkd, device, collection.get("call"), 0), m_callableShaderGroup + 1);
355
356         if (0 != (m_shaders2 & VK_SHADER_STAGE_CALLABLE_BIT_KHR))               rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR          , createShaderModule(vkd, device, collection.get("cal0"), 0), m_callableShaderGroup);
357         if (0 != (m_shaders2 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))                rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR           , createShaderModule(vkd, device, collection.get("ahit2"), 0), m_hitShaderGroup + 1);
358         if (0 != (m_shaders2 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))    rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR       , createShaderModule(vkd, device, collection.get("chit2"), 0), m_hitShaderGroup + 1);
359         if (0 != (m_shaders2 & VK_SHADER_STAGE_MISS_BIT_KHR))                   rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR                      , createShaderModule(vkd, device, collection.get("miss2"), 0), m_missShaderGroup + 1);
360         if (0 != (m_shaders2 & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))   rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR      , createShaderModule(vkd, device, collection.get("sect2"), 0), m_hitShaderGroup + 1);
361
362         Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout);
363
364         return pipeline;
365 }
366
367 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::createShaderBindingTable (const InstanceInterface&                  vki,
368                                                                                                                                                                                           const DeviceInterface&                        vkd,
369                                                                                                                                                                                           const VkDevice                                        device,
370                                                                                                                                                                                           const VkPhysicalDevice                        physicalDevice,
371                                                                                                                                                                                           const VkPipeline                                      pipeline,
372                                                                                                                                                                                           Allocator&                                            allocator,
373                                                                                                                                                                                           de::MovePtr<RayTracingPipeline>&      rayTracingPipeline,
374                                                                                                                                                                                           const deUint32                                        group,
375                                                                                                                                                                                           const deUint32                                        groupCount)
376 {
377         de::MovePtr<BufferWithMemory>   shaderBindingTable;
378
379         if (group < m_shaderGroupCount)
380         {
381                 const deUint32  shaderGroupHandleSize           = getShaderGroupSize(vki, physicalDevice);
382                 const deUint32  shaderGroupBaseAlignment        = getShaderGroupBaseAlignment(vki, physicalDevice);
383
384                 shaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
385         }
386
387         return shaderBindingTable;
388 }
389
390
391 de::MovePtr<TopLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initTopAccelerationStructure (VkCommandBuffer                                                                                          cmdBuffer,
392                                                                                                                                                                                                                            vector<de::SharedPtr<BottomLevelAccelerationStructure> >&    bottomLevelAccelerationStructures)
393 {
394         const DeviceInterface&                                          vkd                     = m_context.getDeviceInterface();
395         const VkDevice                                                          device          = m_context.getDevice();
396         Allocator&                                                                      allocator       = m_context.getDefaultAllocator();
397         de::MovePtr<TopLevelAccelerationStructure>      result          = makeTopLevelAccelerationStructure();
398
399         result->setInstanceCount(bottomLevelAccelerationStructures.size());
400
401         for (size_t structNdx = 0; structNdx < bottomLevelAccelerationStructures.size(); ++structNdx)
402                 result->addInstance(bottomLevelAccelerationStructures[structNdx]);
403
404         result->createAndBuild(vkd, device, cmdBuffer, allocator);
405
406         return result;
407 }
408
409 de::MovePtr<BottomLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initBottomAccelerationStructure (VkCommandBuffer    cmdBuffer,
410                                                                                                                                                                                                                                          tcu::UVec2&            startPos)
411 {
412         const DeviceInterface&                                                  vkd                             = m_context.getDeviceInterface();
413         const VkDevice                                                                  device                  = m_context.getDevice();
414         Allocator&                                                                              allocator               = m_context.getDefaultAllocator();
415         de::MovePtr<BottomLevelAccelerationStructure>   result                  = makeBottomLevelAccelerationStructure();
416         const float                                                                             z                               = (m_data.stage == VK_SHADER_STAGE_MISS_BIT_KHR) ? +1.0f : -1.0f;
417         std::vector<tcu::Vec3>                                                  geometryData;
418
419         DE_UNREF(startPos);
420
421         result->setGeometryCount(1);
422         geometryData.push_back(tcu::Vec3(0.0f, 0.0f, z));
423         geometryData.push_back(tcu::Vec3(1.0f, 1.0f, z));
424         result->addGeometry(geometryData, false);
425         result->createAndBuild(vkd, device, cmdBuffer, allocator);
426
427         return result;
428 }
429
430 vector<de::SharedPtr<BottomLevelAccelerationStructure> > RayTracingComplexControlFlowInstance::initBottomAccelerationStructures (VkCommandBuffer        cmdBuffer)
431 {
432         tcu::UVec2                                                                                                      startPos;
433         vector<de::SharedPtr<BottomLevelAccelerationStructure> >        result;
434         de::MovePtr<BottomLevelAccelerationStructure>                           bottomLevelAccelerationStructure        = initBottomAccelerationStructure(cmdBuffer, startPos);
435
436         result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
437
438         return result;
439 }
440
441 PushConstants RayTracingComplexControlFlowInstance::getPushConstants (void) const
442 {
443         const                   deUint32        hitOfs  = 1;
444         const                   deUint32        miss    = 1;
445         PushConstants   result;
446
447         switch (m_data.testType)
448         {
449                 case TEST_TYPE_IF:
450                 {
451                         result = { 32 | 8 | 1, 10000, 0x0F, 0xF0, hitOfs, miss };
452
453                         break;
454                 }
455                 case TEST_TYPE_LOOP:
456                 {
457                         result = { 8, 10000, 0x0F, 100000, hitOfs, miss };
458
459                         break;
460                 }
461                 case TEST_TYPE_SWITCH:
462                 {
463                         result = { 3, 10000, 0x07, 100000, hitOfs, miss };
464
465                         break;
466                 }
467                 case TEST_TYPE_LOOP_DOUBLE_CALL:
468                 {
469                         result = { 7, 10000, 0x0F, 0xF0, hitOfs, miss };
470
471                         break;
472                 }
473                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
474                 {
475                         result = { 16, 5, 0x0F, 0xF0, hitOfs, miss };
476
477                         break;
478                 }
479                 case TEST_TYPE_NESTED_LOOP:
480                 {
481                         result = { 8, 5, 0x0F, 0x09, hitOfs, miss };
482
483                         break;
484                 }
485                 case TEST_TYPE_NESTED_LOOP_BEFORE:
486                 {
487                         result = { 9, 16, 0x0F, 10, hitOfs, miss };
488
489                         break;
490                 }
491                 case TEST_TYPE_NESTED_LOOP_AFTER:
492                 {
493                         result = { 9, 16, 0x0F, 10, hitOfs, miss };
494
495                         break;
496                 }
497                 case TEST_TYPE_FUNCTION_CALL:
498                 {
499                         result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
500
501                         break;
502                 }
503                 case TEST_TYPE_NESTED_FUNCTION_CALL:
504                 {
505                         result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
506
507                         break;
508                 }
509
510                 default:
511                         TCU_THROW(InternalError, "Unknown testType");
512         }
513
514         return result;
515 }
516
517 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::runTest (void)
518 {
519         const InstanceInterface&                                vki                                                                     = m_context.getInstanceInterface();
520         const DeviceInterface&                                  vkd                                                                     = m_context.getDeviceInterface();
521         const VkDevice                                                  device                                                          = m_context.getDevice();
522         const VkPhysicalDevice                                  physicalDevice                                          = m_context.getPhysicalDevice();
523         const deUint32                                                  queueFamilyIndex                                        = m_context.getUniversalQueueFamilyIndex();
524         const VkQueue                                                   queue                                                           = m_context.getUniversalQueue();
525         Allocator&                                                              allocator                                                       = m_context.getDefaultAllocator();
526         const VkFormat                                                  format                                                          = VK_FORMAT_R32_UINT;
527         const deUint32                                                  pushConstants[]                                         = { m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
528         const deUint32                                                  pushConstantsSize                                       = sizeof(pushConstants);
529         const deUint32                                                  pixelCount                                                      = m_data.width * m_data.height * m_depth;
530         const deUint32                                                  shaderGroupHandleSize                           = getShaderGroupSize(vki, physicalDevice);
531
532         const Move<VkDescriptorSetLayout>               descriptorSetLayout                                     = DescriptorSetLayoutBuilder()
533                                                                                                                                                                                 .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
534                                                                                                                                                                                 .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
535                                                                                                                                                                                 .build(vkd, device);
536         const Move<VkDescriptorPool>                    descriptorPool                                          = DescriptorPoolBuilder()
537                                                                                                                                                                                 .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
538                                                                                                                                                                                 .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
539                                                                                                                                                                                 .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
540         const Move<VkDescriptorSet>                             descriptorSet                                           = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
541         const Move<VkPipelineLayout>                    pipelineLayout                                          = makePipelineLayout(vkd, device, descriptorSetLayout.get(), pushConstantsSize);
542         const Move<VkCommandPool>                               cmdPool                                                         = createCommandPool(vkd, device, 0, queueFamilyIndex);
543         const Move<VkCommandBuffer>                             cmdBuffer                                                       = allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
544
545         de::MovePtr<RayTracingPipeline>                 rayTracingPipeline                                      = de::newMovePtr<RayTracingPipeline>();
546         const Move<VkPipeline>                                  pipeline                                                        = makePipeline(rayTracingPipeline, *pipelineLayout);
547         const de::MovePtr<BufferWithMemory>             raygenShaderBindingTable                        = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_raygenShaderGroup, m_raygenShaderGroupCount);
548         const de::MovePtr<BufferWithMemory>             missShaderBindingTable                          = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_missShaderGroup, m_missShaderGroupCount);
549         const de::MovePtr<BufferWithMemory>             hitShaderBindingTable                           = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_hitShaderGroup, m_hitShaderGroupCount);
550         const de::MovePtr<BufferWithMemory>             callableShaderBindingTable                      = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_callableShaderGroup, m_callableShaderGroupCount);
551
552         const VkStridedDeviceAddressRegionKHR   raygenShaderBindingTableRegion          = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(raygenShaderBindingTable),   shaderGroupHandleSize, m_raygenShaderGroupCount);
553         const VkStridedDeviceAddressRegionKHR   missShaderBindingTableRegion            = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(missShaderBindingTable),     shaderGroupHandleSize, m_missShaderGroupCount);
554         const VkStridedDeviceAddressRegionKHR   hitShaderBindingTableRegion                     = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(hitShaderBindingTable),      shaderGroupHandleSize, m_hitShaderGroupCount);
555         const VkStridedDeviceAddressRegionKHR   callableShaderBindingTableRegion        = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(callableShaderBindingTable), shaderGroupHandleSize, m_callableShaderGroupCount);
556
557         const VkImageCreateInfo                                 imageCreateInfo                                         = makeImageCreateInfo(m_data.width, m_data.height, m_depth, format);
558         const VkImageSubresourceRange                   imageSubresourceRange                           = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0, 1u);
559         const de::MovePtr<ImageWithMemory>              image                                                           = de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
560         const Move<VkImageView>                                 imageView                                                       = makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, format, imageSubresourceRange);
561
562         const VkBufferCreateInfo                                bufferCreateInfo                                        = makeBufferCreateInfo(pixelCount*sizeof(deUint32), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
563         const VkImageSubresourceLayers                  bufferImageSubresourceLayers            = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
564         const VkBufferImageCopy                                 bufferImageRegion                                       = makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, m_depth), bufferImageSubresourceLayers);
565         de::MovePtr<BufferWithMemory>                   buffer                                                          = de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, bufferCreateInfo, MemoryRequirement::HostVisible));
566
567         const VkDescriptorImageInfo                             descriptorImageInfo                                     = makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
568
569         const VkImageMemoryBarrier                              preImageBarrier                                         = makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
570                                                                                                                                                                         VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
571                                                                                                                                                                         **image, imageSubresourceRange);
572         const VkImageMemoryBarrier                              postImageBarrier                                        = makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT,
573                                                                                                                                                                         VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
574                                                                                                                                                                         **image, imageSubresourceRange);
575         const VkMemoryBarrier                                   preTraceMemoryBarrier                           = makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
576         const VkMemoryBarrier                                   postTraceMemoryBarrier                          = makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
577         const VkMemoryBarrier                                   postCopyMemoryBarrier                           = makeMemoryBarrier(VK_ACCESS_TRANSFER_READ_BIT, 0);
578         const VkClearValue                                              clearValue                                                      = makeClearValueColorU32(DEFAULT_CLEAR_VALUE, 0u, 0u, 255u);
579
580         vector<de::SharedPtr<BottomLevelAccelerationStructure> >        bottomLevelAccelerationStructures;
581         de::MovePtr<TopLevelAccelerationStructure>                                      topLevelAccelerationStructure;
582
583         DE_ASSERT(DE_LENGTH_OF_ARRAY(pushConstants) == PUSH_CONSTANTS_COUNT);
584
585         beginCommandBuffer(vkd, *cmdBuffer, 0u);
586         {
587                 vkd.cmdPushConstants(*cmdBuffer, *pipelineLayout, ALL_RAY_TRACING_STAGES, 0, pushConstantsSize, &m_pushConstants);
588
589                 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
590                 vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
591                 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &postImageBarrier);
592
593                 bottomLevelAccelerationStructures = initBottomAccelerationStructures(*cmdBuffer);
594                 topLevelAccelerationStructure = initTopAccelerationStructure(*cmdBuffer, bottomLevelAccelerationStructures);
595
596                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &preTraceMemoryBarrier);
597
598                 const TopLevelAccelerationStructure*                    topLevelAccelerationStructurePtr                = topLevelAccelerationStructure.get();
599                 VkWriteDescriptorSetAccelerationStructureKHR    accelerationStructureWriteDescriptorSet =
600                 {
601                         VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,      //  VkStructureType                                             sType;
602                         DE_NULL,                                                                                                                        //  const void*                                                 pNext;
603                         1u,                                                                                                                                     //  deUint32                                                    accelerationStructureCount;
604                         topLevelAccelerationStructurePtr->getPtr(),                                                     //  const VkAccelerationStructureKHR*   pAccelerationStructures;
605                 };
606
607                 DescriptorSetUpdateBuilder()
608                         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
609                         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
610                         .update(vkd, device);
611
612                 vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
613
614                 vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
615
616                 cmdTraceRays(vkd,
617                         *cmdBuffer,
618                         &raygenShaderBindingTableRegion,
619                         &missShaderBindingTableRegion,
620                         &hitShaderBindingTableRegion,
621                         &callableShaderBindingTableRegion,
622                         m_data.width, m_data.height, 1);
623
624                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, ALL_RAY_TRACING_STAGES, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
625
626                 vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **buffer, 1u, &bufferImageRegion);
627
628                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
629         }
630         endCommandBuffer(vkd, *cmdBuffer);
631
632         submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
633
634         invalidateMappedMemoryRange(vkd, device, buffer->getAllocation().getMemory(), buffer->getAllocation().getOffset(), pixelCount * sizeof(deUint32));
635
636         return buffer;
637 }
638
639 void RayTracingComplexControlFlowInstance::checkSupportInInstance (void) const
640 {
641 }
642
643 std::vector<deUint32> RayTracingComplexControlFlowInstance::getExpectedValues (void) const
644 {
645         const deUint32                          plainSize               = m_data.width * m_data.height;
646         const deUint32                          plain8Ofs               = 8 * plainSize;
647         const struct PushConstants&     p                               = m_pushConstants;
648         const deUint32                          pushConstants[] = { 0, m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
649         const deUint32                          resultSize              = plainSize * m_depth;
650         const bool                                      fixed                   = m_data.testOp == TEST_OP_REPORT_INTERSECTION;
651         std::vector<deUint32>           result                  (resultSize, DEFAULT_CLEAR_VALUE);
652         deUint32                                        v0;
653         deUint32                                        v1;
654         deUint32                                        v2;
655         deUint32                                        v3;
656
657         switch (m_data.testType)
658         {
659                 case TEST_TYPE_IF:
660                 {
661                         for (deUint32 id = 0; id < plainSize; ++id)
662                         {
663                                 v2 = v3 = p.b;
664
665                                 if ((p.a & id) != 0)
666                                 {
667                                         v0 = p.c & id;
668                                         v1 = (p.d & id) + 1;
669
670                                         result[plain8Ofs + id] = v0;
671                                         if (!fixed) v0++;
672                                 }
673                                 else
674                                 {
675                                         v0 = p.d & id;
676                                         v1 = (p.c & id) + 1;
677
678                                         if (!fixed)
679                                         {
680                                                 result[plain8Ofs + id] = v1;
681                                                 v1++;
682                                         }
683                                         else
684                                                 result[plain8Ofs + id] = v0;
685                                 }
686
687                                 result[id] = v0 + v1 + v2 + v3;
688                         }
689
690                         break;
691                 }
692                 case TEST_TYPE_LOOP:
693                 {
694                         for (deUint32 id = 0; id < plainSize; ++id)
695                         {
696                                 result[id] = 0;
697
698                                 v1 = v3 = p.b;
699
700                                 for (deUint32 n = 0; n < p.a; n++)
701                                 {
702                                         v0 = (p.c & id) + n;
703
704                                         result[((n % 8) + 8) * plainSize + id] = v0;
705                                         if (!fixed) v0++;
706
707                                         result[id] += v0 + v1 + v3;
708                                 }
709                         }
710
711                         break;
712                 }
713                 case TEST_TYPE_SWITCH:
714                 {
715                         for (deUint32 id = 0; id < plainSize; ++id)
716                         {
717                                 switch (p.a & id)
718                                 {
719                                         case 0: { v1 = v2 = v3 = p.b; v0 = p.c & id; break; }
720                                         case 1: { v0 = v2 = v3 = p.b; v1 = p.c & id; break; }
721                                         case 2: { v0 = v1 = v3 = p.b; v2 = p.c & id; break; }
722                                         case 3: { v0 = v1 = v2 = p.b; v3 = p.c & id; break; }
723                                         default: { v0 = v1 = v2 = v3 = 0; break; }
724                                 };
725
726                                 if (!fixed)
727                                         result[plain8Ofs + id] = p.c & id;
728                                 else
729                                         result[plain8Ofs + id] = v0;
730
731                                 result[id] = v0 + v1 + v2 + v3;
732
733                                 if (!fixed) result[id]++;
734                         }
735
736                         break;
737                 }
738                 case TEST_TYPE_LOOP_DOUBLE_CALL:
739                 {
740                         for (deUint32 id = 0; id < plainSize; ++id)
741                         {
742                                 result[id] = 0;
743
744                                 v3 = p.b;
745
746                                 for (deUint32 x = 0; x < p.a; x++)
747                                 {
748                                         v0 = (p.c & id) + x;
749                                         v1 = (p.d & id) + x + 1;
750
751                                         result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
752                                         if (!fixed) v0++;
753
754                                         if (!fixed)
755                                         {
756                                                 result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
757                                                 v1++;
758                                         }
759
760                                         result[id] += v0 + v1 + v3;
761                                 }
762                         }
763
764                         break;
765                 }
766                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
767                 {
768                         for (deUint32 id = 0; id < plainSize; ++id)
769                         {
770                                 result[id] = 0;
771
772                                 v3 = p.a + p.b;
773
774                                 for (deUint32 x = 0; x < p.a; x++)
775                                 {
776                                         if ((x & p.b) != 0)
777                                         {
778                                                 v0 = (p.c & id) + x;
779                                                 v1 = (p.d & id) + x + 1;
780
781                                                 result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
782                                                 if (!fixed) v0++;
783
784                                                 if (!fixed)
785                                                 {
786                                                         result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
787                                                         v1++;
788                                                 }
789
790                                                 result[id] += v0 + v1 + v3;
791                                         }
792                                 }
793                         }
794
795                         break;
796                 }
797                 case TEST_TYPE_NESTED_LOOP:
798                 {
799                         for (deUint32 id = 0; id < plainSize; ++id)
800                         {
801                                 result[id] = 0;
802
803                                 v1 = v3 = p.b;
804
805                                 for (deUint32 y = 0; y < p.a; y++)
806                                 for (deUint32 x = 0; x < p.a; x++)
807                                 {
808                                         const deUint32 n = x + y * p.a;
809
810                                         if ((n & p.d) != 0)
811                                         {
812                                                 v0 = (p.c & id) + n;
813
814                                                 result[((n % 8) + 8) * plainSize + id] = v0;
815                                                 if (!fixed) v0++;
816
817                                                 result[id] += v0 + v1 + v3;
818                                         }
819                                 }
820                         }
821
822                         break;
823                 }
824                 case TEST_TYPE_NESTED_LOOP_BEFORE:
825                 {
826                         for (deUint32 id = 0; id < plainSize; ++id)
827                         {
828                                 result[id] = 0;
829
830                                 for (deUint32 y = 0; y < p.d; y++)
831                                 for (deUint32 x = 0; x < p.d; x++)
832                                 {
833                                         if (((x + y * p.a) & p.b) != 0)
834                                                 result[id] += (x + y);
835                                 }
836
837                                 v1 = v3 = p.a;
838
839                                 for (deUint32 x = 0; x < p.b; x++)
840                                 {
841                                         if ((x & p.a) != 0)
842                                         {
843                                                 v0 = p.c & id;
844
845                                                 result[((x % 8) + 8) * plainSize + id] = v0;
846                                                 if (!fixed) v0++;
847
848                                                 result[id] += v0 + v1 + v3;
849                                         }
850                                 }
851                         }
852
853                         break;
854                 }
855                 case TEST_TYPE_NESTED_LOOP_AFTER:
856                 {
857                         for (deUint32 id = 0; id < plainSize; ++id)
858                         {
859                                 result[id] = 0;
860
861                                 v1 = v3 = p.a;
862
863                                 for (deUint32 x = 0; x < p.b; x++)
864                                 {
865                                         if ((x & p.a) != 0)
866                                         {
867                                                 v0 = p.c & id;
868
869                                                 result[((x % 8) + 8) * plainSize + id] = v0;
870                                                 if (!fixed) v0++;
871
872                                                 result[id] += v0 + v1 + v3;
873                                         }
874                                 }
875
876                                 for (deUint32 y = 0; y < p.d; y++)
877                                 for (deUint32 x = 0; x < p.d; x++)
878                                 {
879                                         if (((x + y * p.a) & p.b) != 0)
880                                                 result[id] += (x + y);
881                                 }
882                         }
883
884                         break;
885                 }
886                 case TEST_TYPE_FUNCTION_CALL:
887                 {
888                         deUint32 a[42];
889
890                         for (deUint32 id = 0; id < plainSize; ++id)
891                         {
892                                 deUint32 r = 0;
893                                 deUint32 i;
894
895                                 v0 = p.a & id;
896                                 v1 = v3 = p.d;
897
898                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
899                                         a[i] = p.c * i;
900
901                                 result[plain8Ofs + id] = v0;
902                                 if (!fixed) v0++;
903
904                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
905                                         r += a[i];
906
907                                 result[id] = (r + i) + v0 + v1 + v3;
908                         }
909
910                         break;
911                 }
912                 case TEST_TYPE_NESTED_FUNCTION_CALL:
913                 {
914                         deUint32 a[14];
915                         deUint32 b[256];
916
917                         for (deUint32 id = 0; id < plainSize; ++id)
918                         {
919                                 deUint32 r = 0;
920                                 deUint32 i;
921                                 deUint32 t = 0;
922                                 deUint32 j;
923
924                                 v0 = p.a & id;
925                                 v3 = p.d;
926
927                                 for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
928                                         b[j] = p.c * j;
929
930                                 v1 = p.b;
931
932                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
933                                         a[i] = p.c * i;
934
935                                 result[plain8Ofs + id] = v0;
936                                 if (!fixed) v0++;
937
938                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
939                                         r += a[i];
940
941                                 for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
942                                         t += b[j];
943
944                                 result[id] = (r + i) + (t + j) + v0 + v1 + v3;
945                         }
946
947                         break;
948                 }
949
950                 default:
951                         TCU_THROW(InternalError, "Unknown testType");
952         }
953
954         {
955                 const deUint32  startOfs        = 7 * plainSize;
956
957                 for (deUint32 n = 0; n < plainSize; ++n)
958                         result[startOfs + n] = n;
959         }
960
961         for (deUint32 z = 1; z < DE_LENGTH_OF_ARRAY(pushConstants); ++z)
962         {
963                 const deUint32  startOfs                = z * plainSize;
964                 const deUint32  pushConstant    = pushConstants[z];
965
966                 for (deUint32 n = 0; n < plainSize; ++n)
967                         result[startOfs + n] = pushConstant;
968         }
969
970         return result;
971 }
972
973 tcu::TestStatus RayTracingComplexControlFlowInstance::iterate (void)
974 {
975         checkSupportInInstance();
976
977         const de::MovePtr<BufferWithMemory>     buffer          = runTest();
978         const deUint32*                                         bufferPtr       = (deUint32*)buffer->getAllocation().getHostPtr();
979         const vector<deUint32>                          expected        = getExpectedValues();
980         tcu::TestLog&                                           log                     = m_context.getTestContext().getLog();
981         deUint32                                                        failures        = 0;
982         deUint32                                                        pos                     = 0;
983
984         for (deUint32 z = 0; z < m_depth; ++z)
985         for (deUint32 y = 0; y < m_data.height; ++y)
986         for (deUint32 x = 0; x < m_data.width; ++x)
987         {
988                 if (bufferPtr[pos] != expected[pos])
989                         failures++;
990
991                 ++pos;
992         }
993
994         if (failures != 0)
995         {
996                 deUint32                        pos0    = 0;
997                 deUint32                        pos1    = 0;
998                 std::stringstream       css;
999
1000                 for (deUint32 z = 0; z < m_depth; ++z)
1001                 {
1002                         css << "z=" << z << std::endl;
1003
1004                         for (deUint32 y = 0; y < m_data.height; ++y)
1005                         {
1006                                 for (deUint32 x = 0; x < m_data.width; ++x)
1007                                         css << std::setw(6) << bufferPtr[pos0++] << ' ';
1008
1009                                 css << "    ";
1010
1011                                 for (deUint32 x = 0; x < m_data.width; ++x)
1012                                         css << std::setw(6) << expected[pos1++] << ' ';
1013
1014                                 css << std::endl;
1015                         }
1016
1017                         css << std::endl;
1018                 }
1019
1020                 log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
1021         }
1022
1023         if (failures == 0)
1024                 return tcu::TestStatus::pass("Pass");
1025         else
1026                 return tcu::TestStatus::fail("failures=" + de::toString(failures));
1027 }
1028
1029 class ComplexControlFlowTestCase : public TestCase
1030 {
1031         public:
1032                                                                                 ComplexControlFlowTestCase      (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
1033                                                                                 ~ComplexControlFlowTestCase     (void);
1034
1035         virtual void                                            initPrograms                            (SourceCollections& programCollection) const;
1036         virtual TestInstance*                           createInstance                          (Context& context) const;
1037         virtual void                                            checkSupport                            (Context& context) const;
1038
1039 private:
1040         static inline const std::string         getIntersectionPassthrough      (void);
1041         static inline const std::string         getMissPassthrough                      (void);
1042         static inline const std::string         getHitPassthrough                       (void);
1043
1044         CaseDef                                 m_data;
1045 };
1046
1047 ComplexControlFlowTestCase::ComplexControlFlowTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
1048         : vkt::TestCase (context, name, desc)
1049         , m_data                (data)
1050 {
1051 }
1052
1053 ComplexControlFlowTestCase::~ComplexControlFlowTestCase (void)
1054 {
1055 }
1056
1057 void ComplexControlFlowTestCase::checkSupport (Context& context) const
1058 {
1059         context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1060         const VkPhysicalDeviceAccelerationStructureFeaturesKHR& accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
1061         if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
1062                 TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1063
1064         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1065         const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&    rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
1066         if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE)
1067                 TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1068
1069 }
1070
1071
1072 const std::string ComplexControlFlowTestCase::getIntersectionPassthrough (void)
1073 {
1074         const std::string intersectionPassthrough =
1075                 "#version 460 core\n"
1076                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1077                 "#extension GL_EXT_ray_tracing : require\n"
1078                 "hitAttributeEXT vec3 hitAttribute;\n"
1079                 "\n"
1080                 "void main()\n"
1081                 "{\n"
1082                 "  reportIntersectionEXT(0.95f, 0u);\n"
1083                 "}\n";
1084
1085         return intersectionPassthrough;
1086 }
1087
1088 const std::string ComplexControlFlowTestCase::getMissPassthrough (void)
1089 {
1090         const std::string missPassthrough =
1091                 "#version 460 core\n"
1092                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1093                 "#extension GL_EXT_ray_tracing : require\n"
1094                 "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1095                 "\n"
1096                 "void main()\n"
1097                 "{\n"
1098                 "}\n";
1099
1100         return missPassthrough;
1101 }
1102
1103 const std::string ComplexControlFlowTestCase::getHitPassthrough (void)
1104 {
1105         const std::string hitPassthrough =
1106                 "#version 460 core\n"
1107                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1108                 "#extension GL_EXT_ray_tracing : require\n"
1109                 "hitAttributeEXT vec3 attribs;\n"
1110                 "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1111                 "\n"
1112                 "void main()\n"
1113                 "{\n"
1114                 "}\n";
1115
1116         return hitPassthrough;
1117 }
1118
1119 void ComplexControlFlowTestCase::initPrograms (SourceCollections& programCollection) const
1120 {
1121         const vk::ShaderBuildOptions    buildOptions                    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1122         const std::string                               calleeMainPart                  =
1123                 "  uint z = (inValue.x % 8) + 8;\n"
1124                 "  uint v = inValue.y;\n"
1125                 "  uint n = gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y;\n"
1126                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, z), uvec4(v, 0, 0, 1));\n"
1127                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 7), uvec4(n, 0, 0, 1));\n";
1128         const std::string                               idTemplate                              = "$";
1129         const std::string                               shaderCallInstruction   = (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)    ? "executeCallableEXT(0, " + idTemplate + ")"
1130                                                                                                                         : (m_data.testOp == TEST_OP_TRACE_RAY)           ? "traceRayEXT(as, 0, 0xFF, p.hitOfs, 0, p.miss, vec3((gl_LaunchIDEXT.x) + vec3(0.5f)) / vec3(gl_LaunchSizeEXT), 1.0f, vec3(0.0f, 0.0f, 1.0f), 100.0f, " + idTemplate + ")"
1131                                                                                                                         : (m_data.testOp == TEST_OP_REPORT_INTERSECTION) ? "reportIntersectionEXT(1.0f, 0u)"
1132                                                                                                                         : "TEST_OP_NOT_IMPLEMENTED_FAILURE";
1133         std::string                                             declsPreMain                    =
1134                 "#version 460 core\n"
1135                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1136                 "#extension GL_EXT_ray_tracing : require\n"
1137                 "\n"
1138                 "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1139                 "layout(set = 0, binding = 1) uniform accelerationStructureEXT as;\n"
1140                 "\n"
1141                 "layout(push_constant) uniform TestParams\n"
1142                 "{\n"
1143                 "    uint a;\n"
1144                 "    uint b;\n"
1145                 "    uint c;\n"
1146                 "    uint d;\n"
1147                 "    uint hitOfs;\n"
1148                 "    uint miss;\n"
1149                 "} p;\n";
1150         std::string                                             declsInMainBeforeOp             =
1151                 "  uint result = 0;\n"
1152                 "  uint id = uint(gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y);\n";
1153         std::string                                             declsInMainAfterOp              =
1154                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 0), uvec4(result, 0, 0, 1));\n"
1155                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 1), uvec4(p.a, 0, 0, 1));\n"
1156                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 2), uvec4(p.b, 0, 0, 1));\n"
1157                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 3), uvec4(p.c, 0, 0, 1));\n"
1158                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 4), uvec4(p.d, 0, 0, 1));\n"
1159                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 5), uvec4(p.hitOfs, 0, 0, 1));\n"
1160                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 6), uvec4(p.miss, 0, 0, 1));\n";
1161         std::string                                             opInMain                                = "";
1162         std::string                                             opPreMain                               = "";
1163
1164         DE_ASSERT(!declsPreMain.empty() && PUSH_CONSTANTS_COUNT == 6);
1165
1166         switch (m_data.testType)
1167         {
1168                 case TEST_TYPE_IF:
1169                 {
1170                         opInMain =
1171                                 "  v2 = v3 = uvec2(0, p.b);\n"
1172                                 "\n"
1173                                 "  if ((p.a & id) != 0)\n"
1174                                 "      { v0 = uvec2(0, p.c & id); v1 = uvec2(0, (p.d & id) + 1);" + replace(shaderCallInstruction, idTemplate, "0") + "; }\n"
1175                                 "  else\n"
1176                                 "      { v0 = uvec2(0, p.d & id); v1 = uvec2(0, (p.c & id) + 1);" + replace(shaderCallInstruction, idTemplate, "1") + "; }\n"
1177                                 "\n"
1178                                 "  result = v0.y + v1.y + v2.y + v3.y;\n";
1179
1180                         break;
1181                 }
1182                 case TEST_TYPE_LOOP:
1183                 {
1184                         opInMain =
1185                                 "  v1 = v3 = uvec2(0, p.b);\n"
1186                                 "\n"
1187                                 "  for (uint x = 0; x < p.a; x++)\n"
1188                                 "  {\n"
1189                                 "    v0 = uvec2(x, (p.c & id) + x);\n"
1190                                 "    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1191                                 "    result += v0.y + v1.y + v3.y;\n"
1192                                 "  }\n";
1193
1194                         break;
1195                 }
1196                 case TEST_TYPE_SWITCH:
1197                 {
1198                         opInMain =
1199                                 "  switch (p.a & id)\n"
1200                                 "  {\n"
1201                                 "    case 0: { v1 = v2 = v3 = uvec2(0, p.b); v0 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "0") + "; break; }\n"
1202                                 "    case 1: { v0 = v2 = v3 = uvec2(0, p.b); v1 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "1") + "; break; }\n"
1203                                 "    case 2: { v0 = v1 = v3 = uvec2(0, p.b); v2 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "2") + "; break; }\n"
1204                                 "    case 3: { v0 = v1 = v2 = uvec2(0, p.b); v3 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "3") + "; break; }\n"
1205                                 "    default: break;\n"
1206                                 "  }\n"
1207                                 "\n"
1208                                 "  result = v0.y + v1.y + v2.y + v3.y;\n";
1209
1210                         break;
1211                 }
1212                 case TEST_TYPE_LOOP_DOUBLE_CALL:
1213                 {
1214                         opInMain =
1215                                 "  v3 = uvec2(0, p.b);\n"
1216                                 "  for (uint x = 0; x < p.a; x++)\n"
1217                                 "  {\n"
1218                                 "    v0 = uvec2(2 * x + 0, (p.c & id) + x);\n"
1219                                 "    v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1220                                 "    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1221                                 "    " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1222                                 "    result += v0.y + v1.y + v3.y;\n"
1223                                 "  }\n";
1224
1225                         break;
1226                 }
1227                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
1228                 {
1229                         opInMain =
1230                                 "  v3 = uvec2(0, p.a + p.b);\n"
1231                                 "  for (uint x = 0; x < p.a; x++)\n"
1232                                 "    if ((x & p.b) != 0)\n"
1233                                 "    {\n"
1234                                 "      v0 = uvec2(2 * x + 0, (p.c & id) + x + 0);\n"
1235                                 "      v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1236                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1237                                 "      " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1238                                 "      result += v0.y + v1.y + v3.y;\n"
1239                                 "    }\n"
1240                                 "\n";
1241
1242                         break;
1243                 }
1244                 case TEST_TYPE_NESTED_LOOP:
1245                 {
1246                         opInMain =
1247                                 "  v1 = v3 = uvec2(0, p.b);\n"
1248                                 "  for (uint y = 0; y < p.a; y++)\n"
1249                                 "  for (uint x = 0; x < p.a; x++)\n"
1250                                 "  {\n"
1251                                 "    uint n = x + y * p.a;\n"
1252                                 "    if ((n & p.d) != 0)\n"
1253                                 "    {\n"
1254                                 "      v0 = uvec2(n, (p.c & id) + (x + y * p.a));\n"
1255                                 "      "+ replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1256                                 "      result += v0.y + v1.y + v3.y;\n"
1257                                 "    }\n"
1258                                 "  }\n"
1259                                 "\n";
1260
1261                         break;
1262                 }
1263                 case TEST_TYPE_NESTED_LOOP_BEFORE:
1264                 {
1265                         opInMain =
1266                                 "  for (uint y = 0; y < p.d; y++)\n"
1267                                 "  for (uint x = 0; x < p.d; x++)\n"
1268                                 "    if (((x + y * p.a) & p.b) != 0)\n"
1269                                 "      result += (x + y);\n"
1270                                 "\n"
1271                                 "  v1 = v3 = uvec2(0, p.a);\n"
1272                                 "\n"
1273                                 "  for (uint x = 0; x < p.b; x++)\n"
1274                                 "    if ((x & p.a) != 0)\n"
1275                                 "    {\n"
1276                                 "      v0 = uvec2(x, p.c & id);\n"
1277                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1278                                 "      result += v0.y + v1.y + v3.y;\n"
1279                                 "    }\n";
1280
1281                         break;
1282                 }
1283                 case TEST_TYPE_NESTED_LOOP_AFTER:
1284                 {
1285                         opInMain =
1286                                 "  v1 = v3 = uvec2(0, p.a); \n"
1287                                 "  for (uint x = 0; x < p.b; x++)\n"
1288                                 "    if ((x & p.a) != 0)\n"
1289                                 "    {\n"
1290                                 "      v0 = uvec2(x, p.c & id);\n"
1291                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1292                                 "      result += v0.y + v1.y + v3.y;\n"
1293                                 "    }\n"
1294                                 "\n"
1295                                 "  for (uint y = 0; y < p.d; y++)\n"
1296                                 "  for (uint x = 0; x < p.d; x++)\n"
1297                                 "    if (((x + y * p.a) & p.b) != 0)\n"
1298                                 "      result += x + y;\n";
1299
1300                         break;
1301                 }
1302                 case TEST_TYPE_FUNCTION_CALL:
1303                 {
1304                         opPreMain =
1305                                 "uint f1(void)\n"
1306                                 "{\n"
1307                                 "  uint i, r = 0;\n"
1308                                 "  uint a[42];\n"
1309                                 "\n"
1310                                 "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1311                                 "\n"
1312                                 "  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1313                                 "\n"
1314                                 "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1315                                 "\n"
1316                                 "  return r + i;\n"
1317                                 "}\n";
1318                         opInMain =
1319                                 "  v0 = uvec2(0, p.a & id); v1 = v3 = uvec2(0, p.d);\n"
1320                                 "  result = f1() + v0.y + v1.y + v3.y;\n";
1321
1322                         break;
1323                 }
1324                 case TEST_TYPE_NESTED_FUNCTION_CALL:
1325                 {
1326                         opPreMain =
1327                                 "uint f0(void)\n"
1328                                 "{\n"
1329                                 "  uint i, r = 0;\n"
1330                                 "  uint a[14];\n"
1331                                 "\n"
1332                                 "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1333                                 "\n"
1334                                 "  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1335                                 "\n"
1336                                 "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1337                                 "\n"
1338                                 "  return r + i;\n"
1339                                 "}\n"
1340                                 "\n"
1341                                 "uint f1(void)\n"
1342                                 "{\n"
1343                                 "  uint j, t = 0;\n"
1344                                 "  uint b[256];\n"
1345                                 "\n"
1346                                 "  for (j = 0; j < b.length(); j++) b[j] = p.c * j;\n"
1347                                 "\n"
1348                                 "  v1 = uvec2(0, p.b);\n"
1349                                 "\n"
1350                                 "  t += f0();\n"
1351                                 "\n"
1352                                 "  for (j = 0; j < b.length(); j++) t += b[j];\n"
1353                                 "\n"
1354                                 "  return t + j;\n"
1355                                 "}\n";
1356                         opInMain =
1357                                 "  v0 = uvec2(0, p.a & id); v3 = uvec2(0, p.d);\n"
1358                                 "  result = f1() + v0.y + v1.y + v3.y;\n";
1359
1360                         break;
1361                 }
1362
1363                 default:
1364                         TCU_THROW(InternalError, "Unknown testType");
1365         }
1366
1367         if (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)
1368         {
1369                 const std::string       calleeShader                    =
1370                         "#version 460 core\n"
1371                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1372                         "#extension GL_EXT_ray_tracing : require\n"
1373                         "\n"
1374                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1375                         "layout(location = 0) callableDataInEXT uvec2 inValue;\n"
1376                         "\n"
1377                         "void main()\n"
1378                         "{\n"
1379                         + calleeMainPart +
1380                         "  inValue.y++;\n"
1381                         "}\n";
1382
1383                 declsPreMain +=
1384                         "layout(location = 0) callableDataEXT uvec2 v0;\n"
1385                         "layout(location = 1) callableDataEXT uvec2 v1;\n"
1386                         "layout(location = 2) callableDataEXT uvec2 v2;\n"
1387                         "layout(location = 3) callableDataEXT uvec2 v3;\n"
1388                         "\n";
1389
1390                 switch (m_data.stage)
1391                 {
1392                         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1393                         {
1394                                 std::stringstream css;
1395                                 css << declsPreMain
1396                                         << opPreMain
1397                                         << "\n"
1398                                         << "void main()\n"
1399                                         << "{\n"
1400                                         << declsInMainBeforeOp
1401                                         << opInMain // executeCallableEXT
1402                                         << declsInMainAfterOp
1403                                         << "}\n";
1404
1405                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1406                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1407
1408                                 break;
1409                         }
1410
1411                         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1412                         {
1413                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1414
1415                                 std::stringstream css;
1416                                 css << declsPreMain
1417                                         << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1418                                         << "hitAttributeEXT vec3 attribs;\n"
1419                                         << "\n"
1420                                         << opPreMain
1421                                         << "\n"
1422                                         << "void main()\n"
1423                                         << "{\n"
1424                                         << declsInMainBeforeOp
1425                                         << opInMain // executeCallableEXT
1426                                         << declsInMainAfterOp
1427                                         << "}\n";
1428
1429                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1430                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1431
1432                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1433                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1434                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1435
1436                                 break;
1437                         }
1438
1439                         case VK_SHADER_STAGE_MISS_BIT_KHR:
1440                         {
1441                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1442
1443                                 std::stringstream css;
1444                                 css << declsPreMain
1445                                         << opPreMain
1446                                         << "\n"
1447                                         << "void main()\n"
1448                                         << "{\n"
1449                                         << declsInMainBeforeOp
1450                                         << opInMain // executeCallableEXT
1451                                         << declsInMainAfterOp
1452                                         << "}\n";
1453
1454                                 programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1455                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1456
1457                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1458                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1459                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1460
1461                                 break;
1462                         }
1463
1464                         case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1465                         {
1466                                 {
1467                                         std::stringstream css;
1468                                         css << "#version 460 core\n"
1469                                                 << "#extension GL_EXT_nonuniform_qualifier : enable\n"
1470                                                 << "#extension GL_EXT_ray_tracing : require\n"
1471                                                 << "\n"
1472                                                 << "layout(location = 4) callableDataEXT float dummy;\n"
1473                                                 << "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1474                                                 << "\n"
1475                                                 << "void main()\n"
1476                                                 << "{\n"
1477                                                 << "  executeCallableEXT(1, 4);\n"
1478                                                 << "}\n";
1479
1480                                         programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1481                                 }
1482
1483                                 {
1484                                         std::stringstream css;
1485                                         css << declsPreMain
1486                                                 << "layout(location = 4) callableDataInEXT float dummyIn;\n"
1487                                                 << opPreMain
1488                                                 << "\n"
1489                                                 << "void main()\n"
1490                                                 << "{\n"
1491                                                 << declsInMainBeforeOp
1492                                                 << opInMain // executeCallableEXT
1493                                                 << declsInMainAfterOp
1494                                                 << "}\n";
1495
1496                                         programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1497                                 }
1498
1499                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1500
1501                                 break;
1502                         }
1503
1504                         default:
1505                                 TCU_THROW(InternalError, "Unknown stage");
1506                 }
1507         }
1508         else if (m_data.testOp == TEST_OP_TRACE_RAY)
1509         {
1510                 const std::string       missShader      =
1511                         "#version 460 core\n"
1512                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1513                         "#extension GL_EXT_ray_tracing : require\n"
1514                         "\n"
1515                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1516                         "layout(location = 0) rayPayloadInEXT uvec2 inValue;\n"
1517                         "\n"
1518                         "void main()\n"
1519                         "{\n"
1520                         + calleeMainPart +
1521                         "  inValue.y++;\n"
1522                         "}\n";
1523
1524                 declsPreMain +=
1525                         "layout(location = 0) rayPayloadEXT uvec2 v0;\n"
1526                         "layout(location = 1) rayPayloadEXT uvec2 v1;\n"
1527                         "layout(location = 2) rayPayloadEXT uvec2 v2;\n"
1528                         "layout(location = 3) rayPayloadEXT uvec2 v3;\n";
1529
1530                 switch (m_data.stage)
1531                 {
1532                         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1533                         {
1534                                 std::stringstream css;
1535                                 css << declsPreMain
1536                                         << opPreMain
1537                                         << "\n"
1538                                         << "void main()\n"
1539                                         << "{\n"
1540                                         << declsInMainBeforeOp
1541                                         << opInMain // traceRayEXT
1542                                         << declsInMainAfterOp
1543                                         << "}\n";
1544
1545                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1546
1547                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1548                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1549                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1550                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1551
1552                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1553                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1554                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1555                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1556
1557                                 break;
1558                         }
1559
1560                         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1561                         {
1562                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1563
1564                                 std::stringstream css;
1565                                 css << declsPreMain
1566                                         << opPreMain
1567                                         << "\n"
1568                                         << "void main()\n"
1569                                         << "{\n"
1570                                         << declsInMainBeforeOp
1571                                         << opInMain // traceRayEXT
1572                                         << declsInMainAfterOp
1573                                         << "}\n";
1574
1575                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1576
1577                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1578                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1579                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1580
1581                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1582                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1583                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1584                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1585
1586                                 break;
1587                         }
1588
1589                         case VK_SHADER_STAGE_MISS_BIT_KHR:
1590                         {
1591                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1592
1593                                 std::stringstream css;
1594                                 css << declsPreMain
1595                                         << opPreMain
1596                                         << "\n"
1597                                         << "void main()\n"
1598                                         << "{\n"
1599                                         << declsInMainBeforeOp
1600                                         << opInMain // traceRayEXT
1601                                         << declsInMainAfterOp
1602                                         << "}\n";
1603
1604                                 programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1605
1606                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1607                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1608                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1609
1610                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1611                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1612                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1613                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1614
1615                                 break;
1616                         }
1617
1618                         default:
1619                                 TCU_THROW(InternalError, "Unknown stage");
1620                 }
1621         }
1622         else if (m_data.testOp == TEST_OP_REPORT_INTERSECTION)
1623         {
1624                 const std::string       anyHitShader            =
1625                         "#version 460 core\n"
1626                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1627                         "#extension GL_EXT_ray_tracing : require\n"
1628                         "\n"
1629                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1630                         "hitAttributeEXT block { uvec2 inValue; };\n"
1631                         "\n"
1632                         "void main()\n"
1633                         "{\n"
1634                         + calleeMainPart +
1635                         "}\n";
1636
1637                 declsPreMain +=
1638                         "hitAttributeEXT block { uvec2 v0; };\n"
1639                         "uvec2 v1;\n"
1640                         "uvec2 v2;\n"
1641                         "uvec2 v3;\n";
1642
1643                 switch (m_data.stage)
1644                 {
1645                         case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1646                         {
1647                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1648
1649                                 std::stringstream css;
1650                                 css << declsPreMain
1651                                         << opPreMain
1652                                         << "\n"
1653                                         << "void main()\n"
1654                                         << "{\n"
1655                                         << declsInMainBeforeOp
1656                                         << opInMain // reportIntersectionEXT
1657                                         << declsInMainAfterOp
1658                                         << "}\n";
1659
1660                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1661                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(anyHitShader) << buildOptions;
1662
1663                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1664                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1665
1666                                 break;
1667                         }
1668
1669                         default:
1670                                 TCU_THROW(InternalError, "Unknown stage");
1671                 }
1672         }
1673         else
1674         {
1675                 TCU_THROW(InternalError, "Unknown operation");
1676         }
1677 }
1678
1679 TestInstance* ComplexControlFlowTestCase::createInstance (Context& context) const
1680 {
1681         return new RayTracingComplexControlFlowInstance(context, m_data);
1682 }
1683
1684 }       // anonymous
1685
1686 tcu::TestCaseGroup*     createComplexControlFlowTests (tcu::TestContext& testCtx)
1687 {
1688         static const struct
1689         {
1690                 const char*                             name;
1691                 VkShaderStageFlagBits   stage;
1692         }
1693         testStages[]
1694         {
1695                 { "rgen", VK_SHADER_STAGE_RAYGEN_BIT_KHR                },
1696                 { "chit", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR   },
1697                 { "ahit", VK_SHADER_STAGE_ANY_HIT_BIT_KHR               },
1698                 { "sect", VK_SHADER_STAGE_INTERSECTION_BIT_KHR  },
1699                 { "miss", VK_SHADER_STAGE_MISS_BIT_KHR                  },
1700                 { "call", VK_SHADER_STAGE_CALLABLE_BIT_KHR              },
1701         };
1702         static const struct
1703         {
1704                 const char*     name;
1705                 TestOp          op;
1706         }
1707         testOps[]
1708         {
1709                 { "execute_callable",           TEST_OP_EXECUTE_CALLABLE        },
1710                 { "trace_ray",                          TEST_OP_TRACE_RAY                       },
1711                 { "report_intersection",        TEST_OP_REPORT_INTERSECTION     },
1712         };
1713         static const struct
1714         {
1715                 const char*     name;
1716                 TestType        testType;
1717         }
1718         testTypes[]
1719         {
1720                 { "if",                                                 TEST_TYPE_IF                                            },
1721                 { "loop",                                               TEST_TYPE_LOOP                                          },
1722                 { "switch",                                             TEST_TYPE_SWITCH                                        },
1723                 { "loop_double_call",                   TEST_TYPE_LOOP_DOUBLE_CALL                      },
1724                 { "loop_double_call_sparse",    TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE       },
1725                 { "nested_loop",                                TEST_TYPE_NESTED_LOOP                           },
1726                 { "nested_loop_loop_before",    TEST_TYPE_NESTED_LOOP_BEFORE            },
1727                 { "nested_loop_loop_after",             TEST_TYPE_NESTED_LOOP_AFTER                     },
1728                 { "function_call",                              TEST_TYPE_FUNCTION_CALL                         },
1729                 { "nested_function_call",               TEST_TYPE_NESTED_FUNCTION_CALL          },
1730         };
1731
1732         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "complexcontrolflow", "Ray tracing complex control flow tests"));
1733
1734         for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
1735         {
1736                 const TestType                                  testType                = testTypes[testTypeNdx].testType;
1737                 de::MovePtr<tcu::TestCaseGroup> testTypeGroup   (new tcu::TestCaseGroup(testCtx, testTypes[testTypeNdx].name, ""));
1738
1739                 for (size_t testOpNdx = 0; testOpNdx < DE_LENGTH_OF_ARRAY(testOps); ++testOpNdx)
1740                 {
1741                         const TestOp                                    testOp          = testOps[testOpNdx].op;
1742                         de::MovePtr<tcu::TestCaseGroup> testOpGroup     (new tcu::TestCaseGroup(testCtx, testOps[testOpNdx].name, ""));
1743
1744                         for (size_t testStagesNdx = 0; testStagesNdx < DE_LENGTH_OF_ARRAY(testStages); ++testStagesNdx)
1745                         {
1746                                 const VkShaderStageFlagBits     testStage                               = testStages[testStagesNdx].stage;
1747                                 const std::string                       testName                                = de::toString(testStages[testStagesNdx].name);
1748                                 const deUint32                          width                                   = 4u;
1749                                 const deUint32                          height                                  = 4u;
1750                                 const CaseDef                           caseDef                                 =
1751                                 {
1752                                         testType,                               //  TestType                            testType;
1753                                         testOp,                                 //  TestOp                                      testOp;
1754                                         testStage,                              //  VkShaderStageFlagBits       stage;
1755                                         width,                                  //  deUint32                            width;
1756                                         height,                                 //  deUint32                            height;
1757                                 };
1758
1759                                 if (testOp == TEST_OP_REPORT_INTERSECTION && testStage != VK_SHADER_STAGE_INTERSECTION_BIT_KHR)
1760                                         continue;
1761
1762                                 if (testOp == TEST_OP_TRACE_RAY)
1763                                 {
1764                                         switch (testStage)
1765                                         {
1766                                                 case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1767                                                 case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1768                                                 case VK_SHADER_STAGE_MISS_BIT_KHR:
1769                                                         break;
1770                                                 default:
1771                                                         continue;
1772                                         }
1773                                 }
1774
1775                                 if (testOp == TEST_OP_EXECUTE_CALLABLE)
1776                                 {
1777                                         switch (testStage)
1778                                         {
1779                                                 case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1780                                                 case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1781                                                 case VK_SHADER_STAGE_MISS_BIT_KHR:
1782                                                 case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1783                                                         break;
1784                                                 default:
1785                                                         continue;
1786                                         }
1787                                 }
1788
1789                                 testOpGroup->addChild(new ComplexControlFlowTestCase(testCtx, testName.c_str(), "", caseDef));
1790                         }
1791
1792                         testTypeGroup->addChild(testOpGroup.release());
1793                 }
1794
1795                 group->addChild(testTypeGroup.release());
1796         }
1797
1798         return group.release();
1799 }
1800
1801 }       // RayTracing
1802 }       // vkt