Fix missing dependency on sparse binds
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / modules / vulkan / ray_tracing / vktRayTracingComplexControlFlowTests.cpp
1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *        http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Ray Tracing Complex Control Flow tests
22  *//*--------------------------------------------------------------------*/
23
24 #include "vktRayTracingComplexControlFlowTests.hpp"
25
26 #include "vkDefs.hpp"
27
28 #include "vktTestCase.hpp"
29 #include "vkCmdUtil.hpp"
30 #include "vkObjUtil.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkBarrierUtil.hpp"
33 #include "vkBufferWithMemory.hpp"
34 #include "vkImageWithMemory.hpp"
35 #include "vkTypeUtil.hpp"
36
37 #include "vkRayTracingUtil.hpp"
38
39 #include "tcuTestLog.hpp"
40
41 #include "deRandom.hpp"
42
43 namespace vkt
44 {
45 namespace RayTracing
46 {
47 namespace
48 {
49 using namespace vk;
50 using namespace std;
51
52 static const VkFlags    ALL_RAY_TRACING_STAGES  = VK_SHADER_STAGE_RAYGEN_BIT_KHR
53                                                                                                 | VK_SHADER_STAGE_ANY_HIT_BIT_KHR
54                                                                                                 | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
55                                                                                                 | VK_SHADER_STAGE_MISS_BIT_KHR
56                                                                                                 | VK_SHADER_STAGE_INTERSECTION_BIT_KHR
57                                                                                                 | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
58
59 #if defined(DE_DEBUG)
60 static const deUint32   PUSH_CONSTANTS_COUNT    = 6;
61 #endif
62 static const deUint32   DEFAULT_CLEAR_VALUE             = 999999;
63
// Identifies which test variant a case exercises. The enumerator names refer
// to control flow constructs; the code that interprets them is outside this
// part of the file.
enum TestType
{
	TEST_TYPE_IF						= 0,
	TEST_TYPE_LOOP						= 1,
	TEST_TYPE_SWITCH					= 2,
	TEST_TYPE_LOOP_DOUBLE_CALL			= 3,
	TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE	= 4,
	TEST_TYPE_NESTED_LOOP				= 5,
	TEST_TYPE_NESTED_LOOP_BEFORE		= 6,
	TEST_TYPE_NESTED_LOOP_AFTER			= 7,
	TEST_TYPE_FUNCTION_CALL				= 8,
	TEST_TYPE_NESTED_FUNCTION_CALL		= 9,
};
77
// Identifies the operation a test case is built around. The names correspond
// to executing a callable shader, tracing a ray and reporting an intersection.
enum TestOp
{
	TEST_OP_EXECUTE_CALLABLE	= 0,
	TEST_OP_TRACE_RAY			= 1,
	TEST_OP_REPORT_INTERSECTION	= 2,
};
84
// Fixed shader group indices. RAYGEN_GROUP aliases FIRST_GROUP and
// GROUP_COUNT is the number of groups listed here.
enum ShaderGroups
{
	FIRST_GROUP		= 0,
	RAYGEN_GROUP	= FIRST_GROUP,
	MISS_GROUP		= 1,
	HIT_GROUP		= 2,
	GROUP_COUNT
};
93
94 struct CaseDef
95 {
96         TestType                                testType;
97         TestOp                                  testOp;
98         VkShaderStageFlagBits   stage;
99         deUint32                                width;
100         deUint32                                height;
101 };
102
103 struct PushConstants
104 {
105         deUint32        a;
106         deUint32        b;
107         deUint32        c;
108         deUint32        d;
109         deUint32        hitOfs;
110         deUint32        miss;
111 };
112
113 deUint32 getShaderGroupSize (const InstanceInterface&   vki,
114                                                          const VkPhysicalDevice         physicalDevice)
115 {
116         de::MovePtr<RayTracingProperties>       rayTracingPropertiesKHR;
117
118         rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
119         return rayTracingPropertiesKHR->getShaderGroupHandleSize();
120 }
121
122 deUint32 getShaderGroupBaseAlignment (const InstanceInterface&  vki,
123                                                                           const VkPhysicalDevice        physicalDevice)
124 {
125         de::MovePtr<RayTracingProperties>       rayTracingPropertiesKHR;
126
127         rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
128         return rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
129 }
130
131 VkImageCreateInfo makeImageCreateInfo (deUint32 width, deUint32 height, deUint32 depth, VkFormat format)
132 {
133         const VkImageUsageFlags usage                   = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT;
134         const VkImageCreateInfo imageCreateInfo =
135         {
136                 VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO,    // VkStructureType                      sType;
137                 DE_NULL,                                                                // const void*                          pNext;
138                 (VkImageCreateFlags)0u,                                 // VkImageCreateFlags           flags;
139                 VK_IMAGE_TYPE_3D,                                               // VkImageType                          imageType;
140                 format,                                                                 // VkFormat                                     format;
141                 makeExtent3D(width, height, depth),             // VkExtent3D                           extent;
142                 1u,                                                                             // deUint32                                     mipLevels;
143                 1u,                                                                             // deUint32                                     arrayLayers;
144                 VK_SAMPLE_COUNT_1_BIT,                                  // VkSampleCountFlagBits        samples;
145                 VK_IMAGE_TILING_OPTIMAL,                                // VkImageTiling                        tiling;
146                 usage,                                                                  // VkImageUsageFlags            usage;
147                 VK_SHARING_MODE_EXCLUSIVE,                              // VkSharingMode                        sharingMode;
148                 0u,                                                                             // deUint32                                     queueFamilyIndexCount;
149                 DE_NULL,                                                                // const deUint32*                      pQueueFamilyIndices;
150                 VK_IMAGE_LAYOUT_UNDEFINED                               // VkImageLayout                        initialLayout;
151         };
152
153         return imageCreateInfo;
154 }
155
156 Move<VkPipelineLayout> makePipelineLayout (const DeviceInterface&               vk,
157                                                                                    const VkDevice                               device,
158                                                                                    const VkDescriptorSetLayout  descriptorSetLayout,
159                                                                                    const deUint32                               pushConstantsSize)
160 {
161         const VkDescriptorSetLayout*            descriptorSetLayoutPtr  = (descriptorSetLayout == DE_NULL) ? DE_NULL : &descriptorSetLayout;
162         const deUint32                                          setLayoutCount                  = (descriptorSetLayout == DE_NULL) ? 0u : 1u;
163         const VkPushConstantRange                       pushConstantRange               =
164         {
165                 ALL_RAY_TRACING_STAGES,         //  VkShaderStageFlags  stageFlags;
166                 0u,                                                     //  deUint32                    offset;
167                 pushConstantsSize,                      //  deUint32                    size;
168         };
169         const VkPushConstantRange*                      pPushConstantRanges             = (pushConstantsSize == 0) ? DE_NULL : &pushConstantRange;
170         const deUint32                                          pushConstantRangeCount  = (pushConstantsSize == 0) ? 0 : 1u;
171         const VkPipelineLayoutCreateInfo        pipelineLayoutParams    =
172         {
173                 VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,  // VkStructureType                                      sType;
174                 DE_NULL,                                                                                // const void*                                          pNext;
175                 0u,                                                                                             // VkPipelineLayoutCreateFlags          flags;
176                 setLayoutCount,                                                                 // deUint32                                                     setLayoutCount;
177                 descriptorSetLayoutPtr,                                                 // const VkDescriptorSetLayout*         pSetLayouts;
178                 pushConstantRangeCount,                                                 // deUint32                                                     pushConstantRangeCount;
179                 pPushConstantRanges,                                                    // const VkPushConstantRange*           pPushConstantRanges;
180         };
181
182         return createPipelineLayout(vk, device, &pipelineLayoutParams);
183 }
184
185 VkBuffer getVkBuffer (const de::MovePtr<BufferWithMemory>& buffer)
186 {
187         VkBuffer result = (buffer.get() == DE_NULL) ? DE_NULL : buffer->get();
188
189         return result;
190 }
191
192 VkStridedDeviceAddressRegionKHR makeStridedDeviceAddressRegion (const DeviceInterface& vkd, const VkDevice device, VkBuffer buffer, deUint32 stride, deUint32 count)
193 {
194         if (buffer == DE_NULL)
195         {
196                 return makeStridedDeviceAddressRegionKHR(0, 0, 0);
197         }
198         else
199         {
200                 return makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, buffer, 0), stride, stride * count);
201         }
202 }
203
// Returns a copy of str in which every non-overlapping occurrence of 'from' is
// replaced with 'to'. Occurrences are searched left to right, and text inserted
// by a replacement is never searched again, so 'to' may itself contain 'from'.
//
// An empty 'from' matches at every position without consuming any input, which
// would make the search loop run forever; in that case the input is returned
// unchanged.
static inline std::string replace(const std::string& str, const std::string& from, const std::string& to)
{
	std::string result(str);

	if (from.empty())
		return result;

	size_t start_pos = 0;
	while((start_pos = result.find(from, start_pos)) != std::string::npos)
	{
		result.replace(start_pos, from.length(), to);
		// Continue after the inserted text so it is not matched again.
		start_pos += to.length();
	}

	return result;
}
218
219
220 class RayTracingComplexControlFlowInstance : public TestInstance
221 {
222 public:
223                                                                                                                                 RayTracingComplexControlFlowInstance    (Context& context, const CaseDef& data);
224                                                                                                                                 ~RayTracingComplexControlFlowInstance   (void);
225         tcu::TestStatus                                                                                         iterate                                                                 (void);
226
227 protected:
228         void                                                                                                            calcShaderGroup                                                 (deUint32&                                      shaderGroupCounter,
229                                                                                                                                                                                                                  const VkShaderStageFlags       shaders1,
230                                                                                                                                                                                                                  const VkShaderStageFlags       shaders2,
231                                                                                                                                                                                                                  const VkShaderStageFlags       shaderStageFlags,
232                                                                                                                                                                                                                  deUint32&                                      shaderGroup,
233                                                                                                                                                                                                                  deUint32&                                      shaderGroupCount) const;
234         PushConstants                                                                                           getPushConstants                                                (void) const;
235         std::vector<deUint32>                                                                           getExpectedValues                                               (void) const;
236         de::MovePtr<BufferWithMemory>                                                           runTest                                                                 (void);
237         Move<VkPipeline>                                                                                        makePipeline                                                    (de::MovePtr<RayTracingPipeline>&                                                       rayTracingPipeline,
238                                                                                                                                                                                                                  VkPipelineLayout                                                                                       pipelineLayout);
239         de::MovePtr<BufferWithMemory>                                                           createShaderBindingTable                                 (const InstanceInterface&                                                                      vki,
240                                                                                                                                                                                                                  const DeviceInterface&                                                                         vkd,
241                                                                                                                                                                                                                  const VkDevice                                                                                         device,
242                                                                                                                                                                                                                  const VkPhysicalDevice                                                                         physicalDevice,
243                                                                                                                                                                                                                  const VkPipeline                                                                                       pipeline,
244                                                                                                                                                                                                                  Allocator&                                                                                                     allocator,
245                                                                                                                                                                                                                  de::MovePtr<RayTracingPipeline>&                                                       rayTracingPipeline,
246                                                                                                                                                                                                                  const deUint32                                                                                         group,
247                                                                                                                                                                                                                  const deUint32                                                                                         groupCount = 1u);
248         de::MovePtr<TopLevelAccelerationStructure>                                      initTopAccelerationStructure                    (VkCommandBuffer                                                                                        cmdBuffer,
249                                                                                                                                                                                                                  vector<de::SharedPtr<BottomLevelAccelerationStructure> >&      bottomLevelAccelerationStructures);
250         vector<de::SharedPtr<BottomLevelAccelerationStructure>  >       initBottomAccelerationStructures                (VkCommandBuffer                                                                                        cmdBuffer);
251         de::MovePtr<BottomLevelAccelerationStructure>                           initBottomAccelerationStructure                 (VkCommandBuffer                                                                                        cmdBuffer,
252                                                                                                                                                                                                                  tcu::UVec2&                                                                                            startPos);
253
254 private:
255         CaseDef                                                                                                         m_data;
256         VkShaderStageFlags                                                                                      m_shaders;
257         VkShaderStageFlags                                                                                      m_shaders2;
258         deUint32                                                                                                        m_raygenShaderGroup;
259         deUint32                                                                                                        m_missShaderGroup;
260         deUint32                                                                                                        m_hitShaderGroup;
261         deUint32                                                                                                        m_callableShaderGroup;
262         deUint32                                                                                                        m_raygenShaderGroupCount;
263         deUint32                                                                                                        m_missShaderGroupCount;
264         deUint32                                                                                                        m_hitShaderGroupCount;
265         deUint32                                                                                                        m_callableShaderGroupCount;
266         deUint32                                                                                                        m_shaderGroupCount;
267         deUint32                                                                                                        m_depth;
268         PushConstants                                                                                           m_pushConstants;
269 };
270
271 RayTracingComplexControlFlowInstance::RayTracingComplexControlFlowInstance (Context& context, const CaseDef& data)
272         : vkt::TestInstance                             (context)
273         , m_data                                                (data)
274         , m_shaders                                             (0)
275         , m_shaders2                                    (0)
276         , m_raygenShaderGroup                   (~0u)
277         , m_missShaderGroup                             (~0u)
278         , m_hitShaderGroup                              (~0u)
279         , m_callableShaderGroup                 (~0u)
280         , m_raygenShaderGroupCount              (0)
281         , m_missShaderGroupCount                (0)
282         , m_hitShaderGroupCount                 (0)
283         , m_callableShaderGroupCount    (0)
284         , m_shaderGroupCount                    (0)
285         , m_depth                                               (16)
286         , m_pushConstants                               (getPushConstants())
287 {
288         const VkShaderStageFlags        hitStages       = VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
289         BinaryCollection&                       collection      = m_context.getBinaryCollection();
290         deUint32                                        shaderCount     = 0;
291
292         if (collection.contains("rgen")) m_shaders |= VK_SHADER_STAGE_RAYGEN_BIT_KHR;
293         if (collection.contains("ahit")) m_shaders |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
294         if (collection.contains("chit")) m_shaders |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
295         if (collection.contains("miss")) m_shaders |= VK_SHADER_STAGE_MISS_BIT_KHR;
296         if (collection.contains("sect")) m_shaders |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
297         if (collection.contains("call")) m_shaders |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
298
299         if (collection.contains("ahit2")) m_shaders2 |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
300         if (collection.contains("chit2")) m_shaders2 |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
301         if (collection.contains("miss2")) m_shaders2 |= VK_SHADER_STAGE_MISS_BIT_KHR;
302         if (collection.contains("sect2")) m_shaders2 |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
303
304         if (collection.contains("cal0")) m_shaders2 |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
305
306         for (BinaryCollection::Iterator it = collection.begin(); it != collection.end(); ++it)
307                 shaderCount++;
308
309         if (shaderCount != (deUint32)dePop32(m_shaders) + (deUint32)dePop32(m_shaders2))
310                 TCU_THROW(InternalError, "Unused shaders detected in the collection");
311
312         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_RAYGEN_BIT_KHR,   m_raygenShaderGroup,   m_raygenShaderGroupCount);
313         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_MISS_BIT_KHR,     m_missShaderGroup,     m_missShaderGroupCount);
314         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, hitStages,                        m_hitShaderGroup,      m_hitShaderGroupCount);
315         calcShaderGroup(m_shaderGroupCount, m_shaders, m_shaders2, VK_SHADER_STAGE_CALLABLE_BIT_KHR, m_callableShaderGroup, m_callableShaderGroupCount);
316 }
317
318 RayTracingComplexControlFlowInstance::~RayTracingComplexControlFlowInstance (void)
319 {
320 }
321
322 void RayTracingComplexControlFlowInstance::calcShaderGroup (deUint32&                                   shaderGroupCounter,
323                                                                                                                         const VkShaderStageFlags        shaders1,
324                                                                                                                         const VkShaderStageFlags        shaders2,
325                                                                                                                         const VkShaderStageFlags        shaderStageFlags,
326                                                                                                                         deUint32&                                       shaderGroup,
327                                                                                                                         deUint32&                                       shaderGroupCount) const
328 {
329         const deUint32  shader1Count = ((shaders1 & shaderStageFlags) != 0) ? 1 : 0;
330         const deUint32  shader2Count = ((shaders2 & shaderStageFlags) != 0) ? 1 : 0;
331
332         shaderGroupCount = shader1Count + shader2Count;
333
334         if (shaderGroupCount != 0)
335         {
336                 shaderGroup                     = shaderGroupCounter;
337                 shaderGroupCounter += shaderGroupCount;
338         }
339 }
340
341 Move<VkPipeline> RayTracingComplexControlFlowInstance::makePipeline (de::MovePtr<RayTracingPipeline>&   rayTracingPipeline,
342                                                                                                                                           VkPipelineLayout                                      pipelineLayout)
343 {
344         const DeviceInterface&  vkd                     = m_context.getDeviceInterface();
345         const VkDevice                  device          = m_context.getDevice();
346         vk::BinaryCollection&   collection      = m_context.getBinaryCollection();
347
348         if (0 != (m_shaders & VK_SHADER_STAGE_RAYGEN_BIT_KHR))                  rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR            , createShaderModule(vkd, device, collection.get("rgen"), 0), m_raygenShaderGroup);
349         if (0 != (m_shaders & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))                 rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR           , createShaderModule(vkd, device, collection.get("ahit"), 0), m_hitShaderGroup);
350         if (0 != (m_shaders & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))             rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR       , createShaderModule(vkd, device, collection.get("chit"), 0), m_hitShaderGroup);
351         if (0 != (m_shaders & VK_SHADER_STAGE_MISS_BIT_KHR))                    rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR                      , createShaderModule(vkd, device, collection.get("miss"), 0), m_missShaderGroup);
352         if (0 != (m_shaders & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))    rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR      , createShaderModule(vkd, device, collection.get("sect"), 0), m_hitShaderGroup);
353         if (0 != (m_shaders & VK_SHADER_STAGE_CALLABLE_BIT_KHR))                rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR          , createShaderModule(vkd, device, collection.get("call"), 0), m_callableShaderGroup + 1);
354
355         if (0 != (m_shaders2 & VK_SHADER_STAGE_CALLABLE_BIT_KHR))               rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR          , createShaderModule(vkd, device, collection.get("cal0"), 0), m_callableShaderGroup);
356         if (0 != (m_shaders2 & VK_SHADER_STAGE_ANY_HIT_BIT_KHR))                rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR           , createShaderModule(vkd, device, collection.get("ahit2"), 0), m_hitShaderGroup + 1);
357         if (0 != (m_shaders2 & VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR))    rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR       , createShaderModule(vkd, device, collection.get("chit2"), 0), m_hitShaderGroup + 1);
358         if (0 != (m_shaders2 & VK_SHADER_STAGE_MISS_BIT_KHR))                   rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR                      , createShaderModule(vkd, device, collection.get("miss2"), 0), m_missShaderGroup + 1);
359         if (0 != (m_shaders2 & VK_SHADER_STAGE_INTERSECTION_BIT_KHR))   rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR      , createShaderModule(vkd, device, collection.get("sect2"), 0), m_hitShaderGroup + 1);
360
361         if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
362                 rayTracingPipeline->setMaxRecursionDepth(2);
363
364         Move<VkPipeline> pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout);
365
366         return pipeline;
367 }
368
369 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::createShaderBindingTable (const InstanceInterface&                  vki,
370                                                                                                                                                                                           const DeviceInterface&                        vkd,
371                                                                                                                                                                                           const VkDevice                                        device,
372                                                                                                                                                                                           const VkPhysicalDevice                        physicalDevice,
373                                                                                                                                                                                           const VkPipeline                                      pipeline,
374                                                                                                                                                                                           Allocator&                                            allocator,
375                                                                                                                                                                                           de::MovePtr<RayTracingPipeline>&      rayTracingPipeline,
376                                                                                                                                                                                           const deUint32                                        group,
377                                                                                                                                                                                           const deUint32                                        groupCount)
378 {
379         de::MovePtr<BufferWithMemory>   shaderBindingTable;
380
381         if (group < m_shaderGroupCount)
382         {
383                 const deUint32  shaderGroupHandleSize           = getShaderGroupSize(vki, physicalDevice);
384                 const deUint32  shaderGroupBaseAlignment        = getShaderGroupBaseAlignment(vki, physicalDevice);
385
386                 shaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment, group, groupCount);
387         }
388
389         return shaderBindingTable;
390 }
391
392
393 de::MovePtr<TopLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initTopAccelerationStructure (VkCommandBuffer                                                                                          cmdBuffer,
394                                                                                                                                                                                                                            vector<de::SharedPtr<BottomLevelAccelerationStructure> >&    bottomLevelAccelerationStructures)
395 {
396         const DeviceInterface&                                          vkd                     = m_context.getDeviceInterface();
397         const VkDevice                                                          device          = m_context.getDevice();
398         Allocator&                                                                      allocator       = m_context.getDefaultAllocator();
399         de::MovePtr<TopLevelAccelerationStructure>      result          = makeTopLevelAccelerationStructure();
400
401         result->setInstanceCount(bottomLevelAccelerationStructures.size());
402
403         for (size_t structNdx = 0; structNdx < bottomLevelAccelerationStructures.size(); ++structNdx)
404                 result->addInstance(bottomLevelAccelerationStructures[structNdx]);
405
406         result->createAndBuild(vkd, device, cmdBuffer, allocator);
407
408         return result;
409 }
410
411 de::MovePtr<BottomLevelAccelerationStructure> RayTracingComplexControlFlowInstance::initBottomAccelerationStructure (VkCommandBuffer    cmdBuffer,
412                                                                                                                                                                                                                                          tcu::UVec2&            startPos)
413 {
414         const DeviceInterface&                                                  vkd                             = m_context.getDeviceInterface();
415         const VkDevice                                                                  device                  = m_context.getDevice();
416         Allocator&                                                                              allocator               = m_context.getDefaultAllocator();
417         de::MovePtr<BottomLevelAccelerationStructure>   result                  = makeBottomLevelAccelerationStructure();
418         const float                                                                             z                               = (m_data.stage == VK_SHADER_STAGE_MISS_BIT_KHR) ? +1.0f : -1.0f;
419         std::vector<tcu::Vec3>                                                  geometryData;
420
421         DE_UNREF(startPos);
422
423         result->setGeometryCount(1);
424         geometryData.push_back(tcu::Vec3(0.0f, 0.0f, z));
425         geometryData.push_back(tcu::Vec3(1.0f, 1.0f, z));
426         result->addGeometry(geometryData, false);
427         result->createAndBuild(vkd, device, cmdBuffer, allocator);
428
429         return result;
430 }
431
432 vector<de::SharedPtr<BottomLevelAccelerationStructure> > RayTracingComplexControlFlowInstance::initBottomAccelerationStructures (VkCommandBuffer        cmdBuffer)
433 {
434         tcu::UVec2                                                                                                      startPos;
435         vector<de::SharedPtr<BottomLevelAccelerationStructure> >        result;
436         de::MovePtr<BottomLevelAccelerationStructure>                           bottomLevelAccelerationStructure        = initBottomAccelerationStructure(cmdBuffer, startPos);
437
438         result.push_back(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
439
440         return result;
441 }
442
443 PushConstants RayTracingComplexControlFlowInstance::getPushConstants (void) const
444 {
445         const                   deUint32        hitOfs  = 1;
446         const                   deUint32        miss    = 1;
447         PushConstants   result;
448
449         switch (m_data.testType)
450         {
451                 case TEST_TYPE_IF:
452                 {
453                         result = { 32 | 8 | 1, 10000, 0x0F, 0xF0, hitOfs, miss };
454
455                         break;
456                 }
457                 case TEST_TYPE_LOOP:
458                 {
459                         result = { 8, 10000, 0x0F, 100000, hitOfs, miss };
460
461                         break;
462                 }
463                 case TEST_TYPE_SWITCH:
464                 {
465                         result = { 3, 10000, 0x07, 100000, hitOfs, miss };
466
467                         break;
468                 }
469                 case TEST_TYPE_LOOP_DOUBLE_CALL:
470                 {
471                         result = { 7, 10000, 0x0F, 0xF0, hitOfs, miss };
472
473                         break;
474                 }
475                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
476                 {
477                         result = { 16, 5, 0x0F, 0xF0, hitOfs, miss };
478
479                         break;
480                 }
481                 case TEST_TYPE_NESTED_LOOP:
482                 {
483                         result = { 8, 5, 0x0F, 0x09, hitOfs, miss };
484
485                         break;
486                 }
487                 case TEST_TYPE_NESTED_LOOP_BEFORE:
488                 {
489                         result = { 9, 16, 0x0F, 10, hitOfs, miss };
490
491                         break;
492                 }
493                 case TEST_TYPE_NESTED_LOOP_AFTER:
494                 {
495                         result = { 9, 16, 0x0F, 10, hitOfs, miss };
496
497                         break;
498                 }
499                 case TEST_TYPE_FUNCTION_CALL:
500                 {
501                         result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
502
503                         break;
504                 }
505                 case TEST_TYPE_NESTED_FUNCTION_CALL:
506                 {
507                         result = { 0xFFB, 16, 10, 100000, hitOfs, miss };
508
509                         break;
510                 }
511
512                 default:
513                         TCU_THROW(InternalError, "Unknown testType");
514         }
515
516         return result;
517 }
518
519 de::MovePtr<BufferWithMemory> RayTracingComplexControlFlowInstance::runTest (void)
520 {
521         const InstanceInterface&                                vki                                                                     = m_context.getInstanceInterface();
522         const DeviceInterface&                                  vkd                                                                     = m_context.getDeviceInterface();
523         const VkDevice                                                  device                                                          = m_context.getDevice();
524         const VkPhysicalDevice                                  physicalDevice                                          = m_context.getPhysicalDevice();
525         const deUint32                                                  queueFamilyIndex                                        = m_context.getUniversalQueueFamilyIndex();
526         const VkQueue                                                   queue                                                           = m_context.getUniversalQueue();
527         Allocator&                                                              allocator                                                       = m_context.getDefaultAllocator();
528         const VkFormat                                                  format                                                          = VK_FORMAT_R32_UINT;
529         const deUint32                                                  pushConstants[]                                         = { m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
530         const deUint32                                                  pushConstantsSize                                       = sizeof(pushConstants);
531         const deUint32                                                  pixelCount                                                      = m_data.width * m_data.height * m_depth;
532         const deUint32                                                  shaderGroupHandleSize                           = getShaderGroupSize(vki, physicalDevice);
533
534         const Move<VkDescriptorSetLayout>               descriptorSetLayout                                     = DescriptorSetLayoutBuilder()
535                                                                                                                                                                                 .addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, ALL_RAY_TRACING_STAGES)
536                                                                                                                                                                                 .addSingleBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, ALL_RAY_TRACING_STAGES)
537                                                                                                                                                                                 .build(vkd, device);
538         const Move<VkDescriptorPool>                    descriptorPool                                          = DescriptorPoolBuilder()
539                                                                                                                                                                                 .addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE)
540                                                                                                                                                                                 .addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR)
541                                                                                                                                                                                 .build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
542         const Move<VkDescriptorSet>                             descriptorSet                                           = makeDescriptorSet(vkd, device, *descriptorPool, *descriptorSetLayout);
543         const Move<VkPipelineLayout>                    pipelineLayout                                          = makePipelineLayout(vkd, device, descriptorSetLayout.get(), pushConstantsSize);
544         const Move<VkCommandPool>                               cmdPool                                                         = createCommandPool(vkd, device, 0, queueFamilyIndex);
545         const Move<VkCommandBuffer>                             cmdBuffer                                                       = allocateCommandBuffer(vkd, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
546
547         de::MovePtr<RayTracingPipeline>                 rayTracingPipeline                                      = de::newMovePtr<RayTracingPipeline>();
548         const Move<VkPipeline>                                  pipeline                                                        = makePipeline(rayTracingPipeline, *pipelineLayout);
549         const de::MovePtr<BufferWithMemory>             raygenShaderBindingTable                        = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_raygenShaderGroup, m_raygenShaderGroupCount);
550         const de::MovePtr<BufferWithMemory>             missShaderBindingTable                          = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_missShaderGroup, m_missShaderGroupCount);
551         const de::MovePtr<BufferWithMemory>             hitShaderBindingTable                           = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_hitShaderGroup, m_hitShaderGroupCount);
552         const de::MovePtr<BufferWithMemory>             callableShaderBindingTable                      = createShaderBindingTable(vki, vkd, device, physicalDevice, *pipeline, allocator, rayTracingPipeline, m_callableShaderGroup, m_callableShaderGroupCount);
553
554         const VkStridedDeviceAddressRegionKHR   raygenShaderBindingTableRegion          = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(raygenShaderBindingTable),   shaderGroupHandleSize, m_raygenShaderGroupCount);
555         const VkStridedDeviceAddressRegionKHR   missShaderBindingTableRegion            = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(missShaderBindingTable),     shaderGroupHandleSize, m_missShaderGroupCount);
556         const VkStridedDeviceAddressRegionKHR   hitShaderBindingTableRegion                     = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(hitShaderBindingTable),      shaderGroupHandleSize, m_hitShaderGroupCount);
557         const VkStridedDeviceAddressRegionKHR   callableShaderBindingTableRegion        = makeStridedDeviceAddressRegion(vkd, device, getVkBuffer(callableShaderBindingTable), shaderGroupHandleSize, m_callableShaderGroupCount);
558
559         const VkImageCreateInfo                                 imageCreateInfo                                         = makeImageCreateInfo(m_data.width, m_data.height, m_depth, format);
560         const VkImageSubresourceRange                   imageSubresourceRange                           = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0, 1u);
561         const de::MovePtr<ImageWithMemory>              image                                                           = de::MovePtr<ImageWithMemory>(new ImageWithMemory(vkd, device, allocator, imageCreateInfo, MemoryRequirement::Any));
562         const Move<VkImageView>                                 imageView                                                       = makeImageView(vkd, device, **image, VK_IMAGE_VIEW_TYPE_3D, format, imageSubresourceRange);
563
564         const VkBufferCreateInfo                                bufferCreateInfo                                        = makeBufferCreateInfo(pixelCount*sizeof(deUint32), VK_BUFFER_USAGE_TRANSFER_DST_BIT);
565         const VkImageSubresourceLayers                  bufferImageSubresourceLayers            = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
566         const VkBufferImageCopy                                 bufferImageRegion                                       = makeBufferImageCopy(makeExtent3D(m_data.width, m_data.height, m_depth), bufferImageSubresourceLayers);
567         de::MovePtr<BufferWithMemory>                   buffer                                                          = de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, allocator, bufferCreateInfo, MemoryRequirement::HostVisible));
568
569         const VkDescriptorImageInfo                             descriptorImageInfo                                     = makeDescriptorImageInfo(DE_NULL, *imageView, VK_IMAGE_LAYOUT_GENERAL);
570
571         const VkImageMemoryBarrier                              preImageBarrier                                         = makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT,
572                                                                                                                                                                         VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
573                                                                                                                                                                         **image, imageSubresourceRange);
574         const VkImageMemoryBarrier                              postImageBarrier                                        = makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT,
575                                                                                                                                                                         VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL,
576                                                                                                                                                                         **image, imageSubresourceRange);
577         const VkMemoryBarrier                                   preTraceMemoryBarrier                           = makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
578         const VkMemoryBarrier                                   postTraceMemoryBarrier                          = makeMemoryBarrier(VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
579         const VkMemoryBarrier                                   postCopyMemoryBarrier                           = makeMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
580         const VkClearValue                                              clearValue                                                      = makeClearValueColorU32(DEFAULT_CLEAR_VALUE, 0u, 0u, 255u);
581
582         vector<de::SharedPtr<BottomLevelAccelerationStructure> >        bottomLevelAccelerationStructures;
583         de::MovePtr<TopLevelAccelerationStructure>                                      topLevelAccelerationStructure;
584
585         DE_ASSERT(DE_LENGTH_OF_ARRAY(pushConstants) == PUSH_CONSTANTS_COUNT);
586
587         beginCommandBuffer(vkd, *cmdBuffer, 0u);
588         {
589                 vkd.cmdPushConstants(*cmdBuffer, *pipelineLayout, ALL_RAY_TRACING_STAGES, 0, pushConstantsSize, &m_pushConstants);
590
591                 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, &preImageBarrier);
592                 vkd.cmdClearColorImage(*cmdBuffer, **image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, &clearValue.color, 1, &imageSubresourceRange);
593                 cmdPipelineImageMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &postImageBarrier);
594
595                 bottomLevelAccelerationStructures = initBottomAccelerationStructures(*cmdBuffer);
596                 topLevelAccelerationStructure = initTopAccelerationStructure(*cmdBuffer, bottomLevelAccelerationStructures);
597
598                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, ALL_RAY_TRACING_STAGES, &preTraceMemoryBarrier);
599
600                 const TopLevelAccelerationStructure*                    topLevelAccelerationStructurePtr                = topLevelAccelerationStructure.get();
601                 VkWriteDescriptorSetAccelerationStructureKHR    accelerationStructureWriteDescriptorSet =
602                 {
603                         VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,      //  VkStructureType                                             sType;
604                         DE_NULL,                                                                                                                        //  const void*                                                 pNext;
605                         1u,                                                                                                                                     //  deUint32                                                    accelerationStructureCount;
606                         topLevelAccelerationStructurePtr->getPtr(),                                                     //  const VkAccelerationStructureKHR*   pAccelerationStructures;
607                 };
608
609                 DescriptorSetUpdateBuilder()
610                         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &descriptorImageInfo)
611                         .writeSingle(*descriptorSet, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &accelerationStructureWriteDescriptorSet)
612                         .update(vkd, device);
613
614                 vkd.cmdBindDescriptorSets(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipelineLayout, 0, 1, &descriptorSet.get(), 0, DE_NULL);
615
616                 vkd.cmdBindPipeline(*cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, *pipeline);
617
618                 cmdTraceRays(vkd,
619                         *cmdBuffer,
620                         &raygenShaderBindingTableRegion,
621                         &missShaderBindingTableRegion,
622                         &hitShaderBindingTableRegion,
623                         &callableShaderBindingTableRegion,
624                         m_data.width, m_data.height, 1);
625
626                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, ALL_RAY_TRACING_STAGES, VK_PIPELINE_STAGE_TRANSFER_BIT, &postTraceMemoryBarrier);
627
628                 vkd.cmdCopyImageToBuffer(*cmdBuffer, **image, VK_IMAGE_LAYOUT_GENERAL, **buffer, 1u, &bufferImageRegion);
629
630                 cmdPipelineMemoryBarrier(vkd, *cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_HOST_BIT, &postCopyMemoryBarrier);
631         }
632         endCommandBuffer(vkd, *cmdBuffer);
633
634         submitCommandsAndWait(vkd, device, queue, cmdBuffer.get());
635
636         invalidateMappedMemoryRange(vkd, device, buffer->getAllocation().getMemory(), buffer->getAllocation().getOffset(), pixelCount * sizeof(deUint32));
637
638         return buffer;
639 }
640
641 std::vector<deUint32> RayTracingComplexControlFlowInstance::getExpectedValues (void) const
642 {
643         const deUint32                          plainSize               = m_data.width * m_data.height;
644         const deUint32                          plain8Ofs               = 8 * plainSize;
645         const struct PushConstants&     p                               = m_pushConstants;
646         const deUint32                          pushConstants[] = { 0, m_pushConstants.a, m_pushConstants.b, m_pushConstants.c, m_pushConstants.d, m_pushConstants.hitOfs, m_pushConstants.miss };
647         const deUint32                          resultSize              = plainSize * m_depth;
648         const bool                                      fixed                   = m_data.testOp == TEST_OP_REPORT_INTERSECTION;
649         std::vector<deUint32>           result                  (resultSize, DEFAULT_CLEAR_VALUE);
650         deUint32                                        v0;
651         deUint32                                        v1;
652         deUint32                                        v2;
653         deUint32                                        v3;
654
655         switch (m_data.testType)
656         {
657                 case TEST_TYPE_IF:
658                 {
659                         for (deUint32 id = 0; id < plainSize; ++id)
660                         {
661                                 v2 = v3 = p.b;
662
663                                 if ((p.a & id) != 0)
664                                 {
665                                         v0 = p.c & id;
666                                         v1 = (p.d & id) + 1;
667
668                                         result[plain8Ofs + id] = v0;
669                                         if (!fixed) v0++;
670                                 }
671                                 else
672                                 {
673                                         v0 = p.d & id;
674                                         v1 = (p.c & id) + 1;
675
676                                         if (!fixed)
677                                         {
678                                                 result[plain8Ofs + id] = v1;
679                                                 v1++;
680                                         }
681                                         else
682                                                 result[plain8Ofs + id] = v0;
683                                 }
684
685                                 result[id] = v0 + v1 + v2 + v3;
686                         }
687
688                         break;
689                 }
690                 case TEST_TYPE_LOOP:
691                 {
692                         for (deUint32 id = 0; id < plainSize; ++id)
693                         {
694                                 result[id] = 0;
695
696                                 v1 = v3 = p.b;
697
698                                 for (deUint32 n = 0; n < p.a; n++)
699                                 {
700                                         v0 = (p.c & id) + n;
701
702                                         result[((n % 8) + 8) * plainSize + id] = v0;
703                                         if (!fixed) v0++;
704
705                                         result[id] += v0 + v1 + v3;
706                                 }
707                         }
708
709                         break;
710                 }
711                 case TEST_TYPE_SWITCH:
712                 {
713                         for (deUint32 id = 0; id < plainSize; ++id)
714                         {
715                                 switch (p.a & id)
716                                 {
717                                         case 0: { v1 = v2 = v3 = p.b; v0 = p.c & id; break; }
718                                         case 1: { v0 = v2 = v3 = p.b; v1 = p.c & id; break; }
719                                         case 2: { v0 = v1 = v3 = p.b; v2 = p.c & id; break; }
720                                         case 3: { v0 = v1 = v2 = p.b; v3 = p.c & id; break; }
721                                         default: { v0 = v1 = v2 = v3 = 0; break; }
722                                 }
723
724                                 if (!fixed)
725                                         result[plain8Ofs + id] = p.c & id;
726                                 else
727                                         result[plain8Ofs + id] = v0;
728
729                                 result[id] = v0 + v1 + v2 + v3;
730
731                                 if (!fixed) result[id]++;
732                         }
733
734                         break;
735                 }
736                 case TEST_TYPE_LOOP_DOUBLE_CALL:
737                 {
738                         for (deUint32 id = 0; id < plainSize; ++id)
739                         {
740                                 result[id] = 0;
741
742                                 v3 = p.b;
743
744                                 for (deUint32 x = 0; x < p.a; x++)
745                                 {
746                                         v0 = (p.c & id) + x;
747                                         v1 = (p.d & id) + x + 1;
748
749                                         result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
750                                         if (!fixed) v0++;
751
752                                         if (!fixed)
753                                         {
754                                                 result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
755                                                 v1++;
756                                         }
757
758                                         result[id] += v0 + v1 + v3;
759                                 }
760                         }
761
762                         break;
763                 }
764                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
765                 {
766                         for (deUint32 id = 0; id < plainSize; ++id)
767                         {
768                                 result[id] = 0;
769
770                                 v3 = p.a + p.b;
771
772                                 for (deUint32 x = 0; x < p.a; x++)
773                                 {
774                                         if ((x & p.b) != 0)
775                                         {
776                                                 v0 = (p.c & id) + x;
777                                                 v1 = (p.d & id) + x + 1;
778
779                                                 result[(((2 * x + 0) % 8) + 8) * plainSize + id] = v0;
780                                                 if (!fixed) v0++;
781
782                                                 if (!fixed)
783                                                 {
784                                                         result[(((2 * x + 1) % 8) + 8) * plainSize + id] = v1;
785                                                         v1++;
786                                                 }
787
788                                                 result[id] += v0 + v1 + v3;
789                                         }
790                                 }
791                         }
792
793                         break;
794                 }
795                 case TEST_TYPE_NESTED_LOOP:
796                 {
797                         for (deUint32 id = 0; id < plainSize; ++id)
798                         {
799                                 result[id] = 0;
800
801                                 v1 = v3 = p.b;
802
803                                 for (deUint32 y = 0; y < p.a; y++)
804                                 for (deUint32 x = 0; x < p.a; x++)
805                                 {
806                                         const deUint32 n = x + y * p.a;
807
808                                         if ((n & p.d) != 0)
809                                         {
810                                                 v0 = (p.c & id) + n;
811
812                                                 result[((n % 8) + 8) * plainSize + id] = v0;
813                                                 if (!fixed) v0++;
814
815                                                 result[id] += v0 + v1 + v3;
816                                         }
817                                 }
818                         }
819
820                         break;
821                 }
822                 case TEST_TYPE_NESTED_LOOP_BEFORE:
823                 {
824                         for (deUint32 id = 0; id < plainSize; ++id)
825                         {
826                                 result[id] = 0;
827
828                                 for (deUint32 y = 0; y < p.d; y++)
829                                 for (deUint32 x = 0; x < p.d; x++)
830                                 {
831                                         if (((x + y * p.a) & p.b) != 0)
832                                                 result[id] += (x + y);
833                                 }
834
835                                 v1 = v3 = p.a;
836
837                                 for (deUint32 x = 0; x < p.b; x++)
838                                 {
839                                         if ((x & p.a) != 0)
840                                         {
841                                                 v0 = p.c & id;
842
843                                                 result[((x % 8) + 8) * plainSize + id] = v0;
844                                                 if (!fixed) v0++;
845
846                                                 result[id] += v0 + v1 + v3;
847                                         }
848                                 }
849                         }
850
851                         break;
852                 }
853                 case TEST_TYPE_NESTED_LOOP_AFTER:
854                 {
855                         for (deUint32 id = 0; id < plainSize; ++id)
856                         {
857                                 result[id] = 0;
858
859                                 v1 = v3 = p.a;
860
861                                 for (deUint32 x = 0; x < p.b; x++)
862                                 {
863                                         if ((x & p.a) != 0)
864                                         {
865                                                 v0 = p.c & id;
866
867                                                 result[((x % 8) + 8) * plainSize + id] = v0;
868                                                 if (!fixed) v0++;
869
870                                                 result[id] += v0 + v1 + v3;
871                                         }
872                                 }
873
874                                 for (deUint32 y = 0; y < p.d; y++)
875                                 for (deUint32 x = 0; x < p.d; x++)
876                                 {
877                                         if (((x + y * p.a) & p.b) != 0)
878                                                 result[id] += (x + y);
879                                 }
880                         }
881
882                         break;
883                 }
884                 case TEST_TYPE_FUNCTION_CALL:
885                 {
886                         deUint32 a[42];
887
888                         for (deUint32 id = 0; id < plainSize; ++id)
889                         {
890                                 deUint32 r = 0;
891                                 deUint32 i;
892
893                                 v0 = p.a & id;
894                                 v1 = v3 = p.d;
895
896                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
897                                         a[i] = p.c * i;
898
899                                 result[plain8Ofs + id] = v0;
900                                 if (!fixed) v0++;
901
902                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
903                                         r += a[i];
904
905                                 result[id] = (r + i) + v0 + v1 + v3;
906                         }
907
908                         break;
909                 }
910                 case TEST_TYPE_NESTED_FUNCTION_CALL:
911                 {
912                         deUint32 a[14];
913                         deUint32 b[256];
914
915                         for (deUint32 id = 0; id < plainSize; ++id)
916                         {
917                                 deUint32 r = 0;
918                                 deUint32 i;
919                                 deUint32 t = 0;
920                                 deUint32 j;
921
922                                 v0 = p.a & id;
923                                 v3 = p.d;
924
925                                 for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
926                                         b[j] = p.c * j;
927
928                                 v1 = p.b;
929
930                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
931                                         a[i] = p.c * i;
932
933                                 result[plain8Ofs + id] = v0;
934                                 if (!fixed) v0++;
935
936                                 for (i = 0; i < DE_LENGTH_OF_ARRAY(a); i++)
937                                         r += a[i];
938
939                                 for (j = 0; j < DE_LENGTH_OF_ARRAY(b); j++)
940                                         t += b[j];
941
942                                 result[id] = (r + i) + (t + j) + v0 + v1 + v3;
943                         }
944
945                         break;
946                 }
947
948                 default:
949                         TCU_THROW(InternalError, "Unknown testType");
950         }
951
952         {
953                 const deUint32  startOfs        = 7 * plainSize;
954
955                 for (deUint32 n = 0; n < plainSize; ++n)
956                         result[startOfs + n] = n;
957         }
958
959         for (deUint32 z = 1; z < DE_LENGTH_OF_ARRAY(pushConstants); ++z)
960         {
961                 const deUint32  startOfs                = z * plainSize;
962                 const deUint32  pushConstant    = pushConstants[z];
963
964                 for (deUint32 n = 0; n < plainSize; ++n)
965                         result[startOfs + n] = pushConstant;
966         }
967
968         return result;
969 }
970
971 tcu::TestStatus RayTracingComplexControlFlowInstance::iterate (void)
972 {
973         const de::MovePtr<BufferWithMemory>     buffer          = runTest();
974         const deUint32*                                         bufferPtr       = (deUint32*)buffer->getAllocation().getHostPtr();
975         const vector<deUint32>                          expected        = getExpectedValues();
976         tcu::TestLog&                                           log                     = m_context.getTestContext().getLog();
977         deUint32                                                        failures        = 0;
978         deUint32                                                        pos                     = 0;
979
980         for (deUint32 z = 0; z < m_depth; ++z)
981         for (deUint32 y = 0; y < m_data.height; ++y)
982         for (deUint32 x = 0; x < m_data.width; ++x)
983         {
984                 if (bufferPtr[pos] != expected[pos])
985                         failures++;
986
987                 ++pos;
988         }
989
990         if (failures != 0)
991         {
992                 deUint32                        pos0    = 0;
993                 deUint32                        pos1    = 0;
994                 std::stringstream       css;
995
996                 for (deUint32 z = 0; z < m_depth; ++z)
997                 {
998                         css << "z=" << z << std::endl;
999
1000                         for (deUint32 y = 0; y < m_data.height; ++y)
1001                         {
1002                                 for (deUint32 x = 0; x < m_data.width; ++x)
1003                                         css << std::setw(6) << bufferPtr[pos0++] << ' ';
1004
1005                                 css << "    ";
1006
1007                                 for (deUint32 x = 0; x < m_data.width; ++x)
1008                                         css << std::setw(6) << expected[pos1++] << ' ';
1009
1010                                 css << std::endl;
1011                         }
1012
1013                         css << std::endl;
1014                 }
1015
1016                 log << tcu::TestLog::Message << css.str() << tcu::TestLog::EndMessage;
1017         }
1018
1019         if (failures == 0)
1020                 return tcu::TestStatus::pass("Pass");
1021         else
1022                 return tcu::TestStatus::fail("failures=" + de::toString(failures));
1023 }
1024
1025 class ComplexControlFlowTestCase : public TestCase
1026 {
1027         public:
1028                                                                                 ComplexControlFlowTestCase      (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data);
1029                                                                                 ~ComplexControlFlowTestCase     (void);
1030
1031         virtual void                                            initPrograms                            (SourceCollections& programCollection) const;
1032         virtual TestInstance*                           createInstance                          (Context& context) const;
1033         virtual void                                            checkSupport                            (Context& context) const;
1034
1035 private:
1036         static inline const std::string         getIntersectionPassthrough      (void);
1037         static inline const std::string         getMissPassthrough                      (void);
1038         static inline const std::string         getHitPassthrough                       (void);
1039
1040         CaseDef                                                         m_data;
1041 };
1042
1043 ComplexControlFlowTestCase::ComplexControlFlowTestCase (tcu::TestContext& context, const char* name, const char* desc, const CaseDef data)
1044         : vkt::TestCase (context, name, desc)
1045         , m_data                (data)
1046 {
1047 }
1048
1049 ComplexControlFlowTestCase::~ComplexControlFlowTestCase (void)
1050 {
1051 }
1052
1053 void ComplexControlFlowTestCase::checkSupport (Context& context) const
1054 {
1055         context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
1056
1057         const VkPhysicalDeviceAccelerationStructureFeaturesKHR& accelerationStructureFeaturesKHR = context.getAccelerationStructureFeatures();
1058
1059         if (accelerationStructureFeaturesKHR.accelerationStructure == DE_FALSE)
1060                 TCU_THROW(TestError, "VK_KHR_ray_tracing_pipeline requires VkPhysicalDeviceAccelerationStructureFeaturesKHR.accelerationStructure");
1061
1062         context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
1063
1064         const VkPhysicalDeviceRayTracingPipelineFeaturesKHR&    rayTracingPipelineFeaturesKHR = context.getRayTracingPipelineFeatures();
1065
1066         if (rayTracingPipelineFeaturesKHR.rayTracingPipeline == DE_FALSE)
1067                 TCU_THROW(NotSupportedError, "Requires VkPhysicalDeviceRayTracingPipelineFeaturesKHR.rayTracingPipeline");
1068
1069         const VkPhysicalDeviceRayTracingPipelinePropertiesKHR&  rayTracingPipelinePropertiesKHR = context.getRayTracingPipelineProperties();
1070
1071         if (m_data.testOp == TEST_OP_TRACE_RAY && m_data.stage != VK_SHADER_STAGE_RAYGEN_BIT_KHR)
1072         {
1073                 if (rayTracingPipelinePropertiesKHR.maxRayRecursionDepth < 2)
1074                         TCU_THROW(NotSupportedError, "rayTracingPipelinePropertiesKHR.maxRayRecursionDepth is smaller than required");
1075         }
1076 }
1077
1078
1079 const std::string ComplexControlFlowTestCase::getIntersectionPassthrough (void)
1080 {
1081         const std::string intersectionPassthrough =
1082                 "#version 460 core\n"
1083                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1084                 "#extension GL_EXT_ray_tracing : require\n"
1085                 "hitAttributeEXT vec3 hitAttribute;\n"
1086                 "\n"
1087                 "void main()\n"
1088                 "{\n"
1089                 "  reportIntersectionEXT(0.95f, 0u);\n"
1090                 "}\n";
1091
1092         return intersectionPassthrough;
1093 }
1094
1095 const std::string ComplexControlFlowTestCase::getMissPassthrough (void)
1096 {
1097         const std::string missPassthrough =
1098                 "#version 460 core\n"
1099                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1100                 "#extension GL_EXT_ray_tracing : require\n"
1101                 "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1102                 "\n"
1103                 "void main()\n"
1104                 "{\n"
1105                 "}\n";
1106
1107         return missPassthrough;
1108 }
1109
1110 const std::string ComplexControlFlowTestCase::getHitPassthrough (void)
1111 {
1112         const std::string hitPassthrough =
1113                 "#version 460 core\n"
1114                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1115                 "#extension GL_EXT_ray_tracing : require\n"
1116                 "hitAttributeEXT vec3 attribs;\n"
1117                 "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1118                 "\n"
1119                 "void main()\n"
1120                 "{\n"
1121                 "}\n";
1122
1123         return hitPassthrough;
1124 }
1125
1126 void ComplexControlFlowTestCase::initPrograms (SourceCollections& programCollection) const
1127 {
1128         const vk::ShaderBuildOptions    buildOptions                    (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
1129         const std::string                               calleeMainPart                  =
1130                 "  uint z = (inValue.x % 8) + 8;\n"
1131                 "  uint v = inValue.y;\n"
1132                 "  uint n = gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y;\n"
1133                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, z), uvec4(v, 0, 0, 1));\n"
1134                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 7), uvec4(n, 0, 0, 1));\n";
1135         const std::string                               idTemplate                              = "$";
1136         const std::string                               shaderCallInstruction   = (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)    ? "executeCallableEXT(0, " + idTemplate + ")"
1137                                                                                                                         : (m_data.testOp == TEST_OP_TRACE_RAY)           ? "traceRayEXT(as, 0, 0xFF, p.hitOfs, 0, p.miss, vec3((gl_LaunchIDEXT.x) + vec3(0.5f)) / vec3(gl_LaunchSizeEXT), 1.0f, vec3(0.0f, 0.0f, 1.0f), 100.0f, " + idTemplate + ")"
1138                                                                                                                         : (m_data.testOp == TEST_OP_REPORT_INTERSECTION) ? "reportIntersectionEXT(1.0f, 0u)"
1139                                                                                                                         : "TEST_OP_NOT_IMPLEMENTED_FAILURE";
1140         std::string                                             declsPreMain                    =
1141                 "#version 460 core\n"
1142                 "#extension GL_EXT_nonuniform_qualifier : enable\n"
1143                 "#extension GL_EXT_ray_tracing : require\n"
1144                 "\n"
1145                 "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1146                 "layout(set = 0, binding = 1) uniform accelerationStructureEXT as;\n"
1147                 "\n"
1148                 "layout(push_constant) uniform TestParams\n"
1149                 "{\n"
1150                 "    uint a;\n"
1151                 "    uint b;\n"
1152                 "    uint c;\n"
1153                 "    uint d;\n"
1154                 "    uint hitOfs;\n"
1155                 "    uint miss;\n"
1156                 "} p;\n";
1157         std::string                                             declsInMainBeforeOp             =
1158                 "  uint result = 0;\n"
1159                 "  uint id = uint(gl_LaunchIDEXT.x + gl_LaunchSizeEXT.x * gl_LaunchIDEXT.y);\n";
1160         std::string                                             declsInMainAfterOp              =
1161                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 0), uvec4(result, 0, 0, 1));\n"
1162                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 1), uvec4(p.a, 0, 0, 1));\n"
1163                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 2), uvec4(p.b, 0, 0, 1));\n"
1164                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 3), uvec4(p.c, 0, 0, 1));\n"
1165                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 4), uvec4(p.d, 0, 0, 1));\n"
1166                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 5), uvec4(p.hitOfs, 0, 0, 1));\n"
1167                 "  imageStore(resultImage, ivec3(gl_LaunchIDEXT.x, gl_LaunchIDEXT.y, 6), uvec4(p.miss, 0, 0, 1));\n";
1168         std::string                                             opInMain                                = "";
1169         std::string                                             opPreMain                               = "";
1170
1171         DE_ASSERT(!declsPreMain.empty() && PUSH_CONSTANTS_COUNT == 6);
1172
1173         switch (m_data.testType)
1174         {
1175                 case TEST_TYPE_IF:
1176                 {
1177                         opInMain =
1178                                 "  v2 = v3 = uvec2(0, p.b);\n"
1179                                 "\n"
1180                                 "  if ((p.a & id) != 0)\n"
1181                                 "      { v0 = uvec2(0, p.c & id); v1 = uvec2(0, (p.d & id) + 1);" + replace(shaderCallInstruction, idTemplate, "0") + "; }\n"
1182                                 "  else\n"
1183                                 "      { v0 = uvec2(0, p.d & id); v1 = uvec2(0, (p.c & id) + 1);" + replace(shaderCallInstruction, idTemplate, "1") + "; }\n"
1184                                 "\n"
1185                                 "  result = v0.y + v1.y + v2.y + v3.y;\n";
1186
1187                         break;
1188                 }
1189                 case TEST_TYPE_LOOP:
1190                 {
1191                         opInMain =
1192                                 "  v1 = v3 = uvec2(0, p.b);\n"
1193                                 "\n"
1194                                 "  for (uint x = 0; x < p.a; x++)\n"
1195                                 "  {\n"
1196                                 "    v0 = uvec2(x, (p.c & id) + x);\n"
1197                                 "    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1198                                 "    result += v0.y + v1.y + v3.y;\n"
1199                                 "  }\n";
1200
1201                         break;
1202                 }
1203                 case TEST_TYPE_SWITCH:
1204                 {
1205                         opInMain =
1206                                 "  switch (p.a & id)\n"
1207                                 "  {\n"
1208                                 "    case 0: { v1 = v2 = v3 = uvec2(0, p.b); v0 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "0") + "; break; }\n"
1209                                 "    case 1: { v0 = v2 = v3 = uvec2(0, p.b); v1 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "1") + "; break; }\n"
1210                                 "    case 2: { v0 = v1 = v3 = uvec2(0, p.b); v2 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "2") + "; break; }\n"
1211                                 "    case 3: { v0 = v1 = v2 = uvec2(0, p.b); v3 = uvec2(0, p.c & id); " + replace(shaderCallInstruction, idTemplate, "3") + "; break; }\n"
1212                                 "    default: break;\n"
1213                                 "  }\n"
1214                                 "\n"
1215                                 "  result = v0.y + v1.y + v2.y + v3.y;\n";
1216
1217                         break;
1218                 }
1219                 case TEST_TYPE_LOOP_DOUBLE_CALL:
1220                 {
1221                         opInMain =
1222                                 "  v3 = uvec2(0, p.b);\n"
1223                                 "  for (uint x = 0; x < p.a; x++)\n"
1224                                 "  {\n"
1225                                 "    v0 = uvec2(2 * x + 0, (p.c & id) + x);\n"
1226                                 "    v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1227                                 "    " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1228                                 "    " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1229                                 "    result += v0.y + v1.y + v3.y;\n"
1230                                 "  }\n";
1231
1232                         break;
1233                 }
1234                 case TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE:
1235                 {
1236                         opInMain =
1237                                 "  v3 = uvec2(0, p.a + p.b);\n"
1238                                 "  for (uint x = 0; x < p.a; x++)\n"
1239                                 "    if ((x & p.b) != 0)\n"
1240                                 "    {\n"
1241                                 "      v0 = uvec2(2 * x + 0, (p.c & id) + x + 0);\n"
1242                                 "      v1 = uvec2(2 * x + 1, (p.d & id) + x + 1);\n"
1243                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1244                                 "      " + replace(shaderCallInstruction, idTemplate, "1") + ";\n"
1245                                 "      result += v0.y + v1.y + v3.y;\n"
1246                                 "    }\n"
1247                                 "\n";
1248
1249                         break;
1250                 }
1251                 case TEST_TYPE_NESTED_LOOP:
1252                 {
1253                         opInMain =
1254                                 "  v1 = v3 = uvec2(0, p.b);\n"
1255                                 "  for (uint y = 0; y < p.a; y++)\n"
1256                                 "  for (uint x = 0; x < p.a; x++)\n"
1257                                 "  {\n"
1258                                 "    uint n = x + y * p.a;\n"
1259                                 "    if ((n & p.d) != 0)\n"
1260                                 "    {\n"
1261                                 "      v0 = uvec2(n, (p.c & id) + (x + y * p.a));\n"
1262                                 "      "+ replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1263                                 "      result += v0.y + v1.y + v3.y;\n"
1264                                 "    }\n"
1265                                 "  }\n"
1266                                 "\n";
1267
1268                         break;
1269                 }
1270                 case TEST_TYPE_NESTED_LOOP_BEFORE:
1271                 {
1272                         opInMain =
1273                                 "  for (uint y = 0; y < p.d; y++)\n"
1274                                 "  for (uint x = 0; x < p.d; x++)\n"
1275                                 "    if (((x + y * p.a) & p.b) != 0)\n"
1276                                 "      result += (x + y);\n"
1277                                 "\n"
1278                                 "  v1 = v3 = uvec2(0, p.a);\n"
1279                                 "\n"
1280                                 "  for (uint x = 0; x < p.b; x++)\n"
1281                                 "    if ((x & p.a) != 0)\n"
1282                                 "    {\n"
1283                                 "      v0 = uvec2(x, p.c & id);\n"
1284                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1285                                 "      result += v0.y + v1.y + v3.y;\n"
1286                                 "    }\n";
1287
1288                         break;
1289                 }
1290                 case TEST_TYPE_NESTED_LOOP_AFTER:
1291                 {
1292                         opInMain =
1293                                 "  v1 = v3 = uvec2(0, p.a); \n"
1294                                 "  for (uint x = 0; x < p.b; x++)\n"
1295                                 "    if ((x & p.a) != 0)\n"
1296                                 "    {\n"
1297                                 "      v0 = uvec2(x, p.c & id);\n"
1298                                 "      " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1299                                 "      result += v0.y + v1.y + v3.y;\n"
1300                                 "    }\n"
1301                                 "\n"
1302                                 "  for (uint y = 0; y < p.d; y++)\n"
1303                                 "  for (uint x = 0; x < p.d; x++)\n"
1304                                 "    if (((x + y * p.a) & p.b) != 0)\n"
1305                                 "      result += x + y;\n";
1306
1307                         break;
1308                 }
1309                 case TEST_TYPE_FUNCTION_CALL:
1310                 {
1311                         opPreMain =
1312                                 "uint f1(void)\n"
1313                                 "{\n"
1314                                 "  uint i, r = 0;\n"
1315                                 "  uint a[42];\n"
1316                                 "\n"
1317                                 "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1318                                 "\n"
1319                                 "  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1320                                 "\n"
1321                                 "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1322                                 "\n"
1323                                 "  return r + i;\n"
1324                                 "}\n";
1325                         opInMain =
1326                                 "  v0 = uvec2(0, p.a & id); v1 = v3 = uvec2(0, p.d);\n"
1327                                 "  result = f1() + v0.y + v1.y + v3.y;\n";
1328
1329                         break;
1330                 }
1331                 case TEST_TYPE_NESTED_FUNCTION_CALL:
1332                 {
1333                         opPreMain =
1334                                 "uint f0(void)\n"
1335                                 "{\n"
1336                                 "  uint i, r = 0;\n"
1337                                 "  uint a[14];\n"
1338                                 "\n"
1339                                 "  for (i = 0; i < a.length(); i++) a[i] = p.c * i;\n"
1340                                 "\n"
1341                                 "  " + replace(shaderCallInstruction, idTemplate, "0") + ";\n"
1342                                 "\n"
1343                                 "  for (i = 0; i < a.length(); i++) r += a[i];\n"
1344                                 "\n"
1345                                 "  return r + i;\n"
1346                                 "}\n"
1347                                 "\n"
1348                                 "uint f1(void)\n"
1349                                 "{\n"
1350                                 "  uint j, t = 0;\n"
1351                                 "  uint b[256];\n"
1352                                 "\n"
1353                                 "  for (j = 0; j < b.length(); j++) b[j] = p.c * j;\n"
1354                                 "\n"
1355                                 "  v1 = uvec2(0, p.b);\n"
1356                                 "\n"
1357                                 "  t += f0();\n"
1358                                 "\n"
1359                                 "  for (j = 0; j < b.length(); j++) t += b[j];\n"
1360                                 "\n"
1361                                 "  return t + j;\n"
1362                                 "}\n";
1363                         opInMain =
1364                                 "  v0 = uvec2(0, p.a & id); v3 = uvec2(0, p.d);\n"
1365                                 "  result = f1() + v0.y + v1.y + v3.y;\n";
1366
1367                         break;
1368                 }
1369
1370                 default:
1371                         TCU_THROW(InternalError, "Unknown testType");
1372         }
1373
1374         if (m_data.testOp == TEST_OP_EXECUTE_CALLABLE)
1375         {
1376                 const std::string       calleeShader                    =
1377                         "#version 460 core\n"
1378                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1379                         "#extension GL_EXT_ray_tracing : require\n"
1380                         "\n"
1381                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1382                         "layout(location = 0) callableDataInEXT uvec2 inValue;\n"
1383                         "\n"
1384                         "void main()\n"
1385                         "{\n"
1386                         + calleeMainPart +
1387                         "  inValue.y++;\n"
1388                         "}\n";
1389
1390                 declsPreMain +=
1391                         "layout(location = 0) callableDataEXT uvec2 v0;\n"
1392                         "layout(location = 1) callableDataEXT uvec2 v1;\n"
1393                         "layout(location = 2) callableDataEXT uvec2 v2;\n"
1394                         "layout(location = 3) callableDataEXT uvec2 v3;\n"
1395                         "\n";
1396
1397                 switch (m_data.stage)
1398                 {
1399                         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1400                         {
1401                                 std::stringstream css;
1402                                 css << declsPreMain
1403                                         << opPreMain
1404                                         << "\n"
1405                                         << "void main()\n"
1406                                         << "{\n"
1407                                         << declsInMainBeforeOp
1408                                         << opInMain // executeCallableEXT
1409                                         << declsInMainAfterOp
1410                                         << "}\n";
1411
1412                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1413                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1414
1415                                 break;
1416                         }
1417
1418                         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1419                         {
1420                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1421
1422                                 std::stringstream css;
1423                                 css << declsPreMain
1424                                         << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1425                                         << "hitAttributeEXT vec3 attribs;\n"
1426                                         << "\n"
1427                                         << opPreMain
1428                                         << "\n"
1429                                         << "void main()\n"
1430                                         << "{\n"
1431                                         << declsInMainBeforeOp
1432                                         << opInMain // executeCallableEXT
1433                                         << declsInMainAfterOp
1434                                         << "}\n";
1435
1436                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1437                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1438
1439                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1440                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1441                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1442
1443                                 break;
1444                         }
1445
1446                         case VK_SHADER_STAGE_MISS_BIT_KHR:
1447                         {
1448                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1449
1450                                 std::stringstream css;
1451                                 css << declsPreMain
1452                                         << opPreMain
1453                                         << "\n"
1454                                         << "void main()\n"
1455                                         << "{\n"
1456                                         << declsInMainBeforeOp
1457                                         << opInMain // executeCallableEXT
1458                                         << declsInMainAfterOp
1459                                         << "}\n";
1460
1461                                 programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1462                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1463
1464                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1465                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1466                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1467
1468                                 break;
1469                         }
1470
1471                         case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1472                         {
1473                                 {
1474                                         std::stringstream css;
1475                                         css << "#version 460 core\n"
1476                                                 << "#extension GL_EXT_nonuniform_qualifier : enable\n"
1477                                                 << "#extension GL_EXT_ray_tracing : require\n"
1478                                                 << "\n"
1479                                                 << "layout(location = 4) callableDataEXT float dummy;\n"
1480                                                 << "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1481                                                 << "\n"
1482                                                 << "void main()\n"
1483                                                 << "{\n"
1484                                                 << "  executeCallableEXT(1, 4);\n"
1485                                                 << "}\n";
1486
1487                                         programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1488                                 }
1489
1490                                 {
1491                                         std::stringstream css;
1492                                         css << declsPreMain
1493                                                 << "layout(location = 4) callableDataInEXT float dummyIn;\n"
1494                                                 << opPreMain
1495                                                 << "\n"
1496                                                 << "void main()\n"
1497                                                 << "{\n"
1498                                                 << declsInMainBeforeOp
1499                                                 << opInMain // executeCallableEXT
1500                                                 << declsInMainAfterOp
1501                                                 << "}\n";
1502
1503                                         programCollection.glslSources.add("call") << glu::CallableSource(css.str()) << buildOptions;
1504                                 }
1505
1506                                 programCollection.glslSources.add("cal0") << glu::CallableSource(calleeShader) << buildOptions;
1507
1508                                 break;
1509                         }
1510
1511                         default:
1512                                 TCU_THROW(InternalError, "Unknown stage");
1513                 }
1514         }
1515         else if (m_data.testOp == TEST_OP_TRACE_RAY)
1516         {
1517                 const std::string       missShader      =
1518                         "#version 460 core\n"
1519                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1520                         "#extension GL_EXT_ray_tracing : require\n"
1521                         "\n"
1522                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1523                         "layout(location = 0) rayPayloadInEXT uvec2 inValue;\n"
1524                         "\n"
1525                         "void main()\n"
1526                         "{\n"
1527                         + calleeMainPart +
1528                         "  inValue.y++;\n"
1529                         "}\n";
1530
1531                 declsPreMain +=
1532                         "layout(location = 0) rayPayloadEXT uvec2 v0;\n"
1533                         "layout(location = 1) rayPayloadEXT uvec2 v1;\n"
1534                         "layout(location = 2) rayPayloadEXT uvec2 v2;\n"
1535                         "layout(location = 3) rayPayloadEXT uvec2 v3;\n";
1536
1537                 switch (m_data.stage)
1538                 {
1539                         case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1540                         {
1541                                 std::stringstream css;
1542                                 css << declsPreMain
1543                                         << opPreMain
1544                                         << "\n"
1545                                         << "void main()\n"
1546                                         << "{\n"
1547                                         << declsInMainBeforeOp
1548                                         << opInMain // traceRayEXT
1549                                         << declsInMainAfterOp
1550                                         << "}\n";
1551
1552                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(css.str()) << buildOptions;
1553
1554                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1555                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1556                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1557                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1558
1559                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1560                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1561                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1562                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1563
1564                                 break;
1565                         }
1566
1567                         case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1568                         {
1569                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1570
1571                                 std::stringstream css;
1572                                 css << declsPreMain
1573                                         << opPreMain
1574                                         << "\n"
1575                                         << "void main()\n"
1576                                         << "{\n"
1577                                         << declsInMainBeforeOp
1578                                         << opInMain // traceRayEXT
1579                                         << declsInMainAfterOp
1580                                         << "}\n";
1581
1582                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(css.str()) << buildOptions;
1583
1584                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1585                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1586                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1587
1588                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1589                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1590                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1591                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1592
1593                                 break;
1594                         }
1595
1596                         case VK_SHADER_STAGE_MISS_BIT_KHR:
1597                         {
1598                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1599
1600                                 std::stringstream css;
1601                                 css << declsPreMain
1602                                         << opPreMain
1603                                         << "\n"
1604                                         << "void main()\n"
1605                                         << "{\n"
1606                                         << declsInMainBeforeOp
1607                                         << opInMain // traceRayEXT
1608                                         << declsInMainAfterOp
1609                                         << "}\n";
1610
1611                                 programCollection.glslSources.add("miss") << glu::MissSource(css.str()) << buildOptions;
1612
1613                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1614                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1615                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1616
1617                                 programCollection.glslSources.add("miss2") << glu::MissSource(missShader) << buildOptions;
1618                                 programCollection.glslSources.add("ahit2") << glu::AnyHitSource(getHitPassthrough()) << buildOptions;
1619                                 programCollection.glslSources.add("chit2") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1620                                 programCollection.glslSources.add("sect2") << glu::IntersectionSource(getIntersectionPassthrough()) << buildOptions;
1621
1622                                 break;
1623                         }
1624
1625                         default:
1626                                 TCU_THROW(InternalError, "Unknown stage");
1627                 }
1628         }
1629         else if (m_data.testOp == TEST_OP_REPORT_INTERSECTION)
1630         {
1631                 const std::string       anyHitShader            =
1632                         "#version 460 core\n"
1633                         "#extension GL_EXT_nonuniform_qualifier : enable\n"
1634                         "#extension GL_EXT_ray_tracing : require\n"
1635                         "\n"
1636                         "layout(set = 0, binding = 0, r32ui) uniform uimage3D resultImage;\n"
1637                         "hitAttributeEXT block { uvec2 inValue; };\n"
1638                         "\n"
1639                         "void main()\n"
1640                         "{\n"
1641                         + calleeMainPart +
1642                         "}\n";
1643
1644                 declsPreMain +=
1645                         "hitAttributeEXT block { uvec2 v0; };\n"
1646                         "uvec2 v1;\n"
1647                         "uvec2 v2;\n"
1648                         "uvec2 v3;\n";
1649
1650                 switch (m_data.stage)
1651                 {
1652                         case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
1653                         {
1654                                 programCollection.glslSources.add("rgen") << glu::RaygenSource(getCommonRayGenerationShader()) << buildOptions;
1655
1656                                 std::stringstream css;
1657                                 css << declsPreMain
1658                                         << opPreMain
1659                                         << "\n"
1660                                         << "void main()\n"
1661                                         << "{\n"
1662                                         << declsInMainBeforeOp
1663                                         << opInMain // reportIntersectionEXT
1664                                         << declsInMainAfterOp
1665                                         << "}\n";
1666
1667                                 programCollection.glslSources.add("sect") << glu::IntersectionSource(css.str()) << buildOptions;
1668                                 programCollection.glslSources.add("ahit") << glu::AnyHitSource(anyHitShader) << buildOptions;
1669
1670                                 programCollection.glslSources.add("chit") << glu::ClosestHitSource(getHitPassthrough()) << buildOptions;
1671                                 programCollection.glslSources.add("miss") << glu::MissSource(getMissPassthrough()) << buildOptions;
1672
1673                                 break;
1674                         }
1675
1676                         default:
1677                                 TCU_THROW(InternalError, "Unknown stage");
1678                 }
1679         }
1680         else
1681         {
1682                 TCU_THROW(InternalError, "Unknown operation");
1683         }
1684 }
1685
1686 TestInstance* ComplexControlFlowTestCase::createInstance (Context& context) const
1687 {
1688         return new RayTracingComplexControlFlowInstance(context, m_data);
1689 }
1690
1691 }       // anonymous
1692
1693 tcu::TestCaseGroup*     createComplexControlFlowTests (tcu::TestContext& testCtx)
1694 {
1695         const VkShaderStageFlagBits     R       = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
1696         const VkShaderStageFlagBits     A       = VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
1697         const VkShaderStageFlagBits     C       = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
1698         const VkShaderStageFlagBits     M       = VK_SHADER_STAGE_MISS_BIT_KHR;
1699         const VkShaderStageFlagBits     I       = VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
1700         const VkShaderStageFlagBits     L       = VK_SHADER_STAGE_CALLABLE_BIT_KHR;
1701
1702         DE_UNREF(A);
1703
1704         static const struct
1705         {
1706                 const char*                             name;
1707                 VkShaderStageFlagBits   stage;
1708         }
1709         testStages[]
1710         {
1711                 { "rgen", VK_SHADER_STAGE_RAYGEN_BIT_KHR                },
1712                 { "chit", VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR   },
1713                 { "ahit", VK_SHADER_STAGE_ANY_HIT_BIT_KHR               },
1714                 { "sect", VK_SHADER_STAGE_INTERSECTION_BIT_KHR  },
1715                 { "miss", VK_SHADER_STAGE_MISS_BIT_KHR                  },
1716                 { "call", VK_SHADER_STAGE_CALLABLE_BIT_KHR              },
1717         };
1718         static const struct
1719         {
1720                 const char*                     name;
1721                 TestOp                          op;
1722                 VkShaderStageFlags      applicableInStages;
1723         }
1724         testOps[]
1725         {
1726                 { "execute_callable",           TEST_OP_EXECUTE_CALLABLE,               R |    C | M     | L },
1727                 { "trace_ray",                          TEST_OP_TRACE_RAY,                              R |    C | M         },
1728                 { "report_intersection",        TEST_OP_REPORT_INTERSECTION,                   I     },
1729         };
1730         static const struct
1731         {
1732                 const char*     name;
1733                 TestType        testType;
1734         }
1735         testTypes[]
1736         {
1737                 { "if",                                                 TEST_TYPE_IF                                            },
1738                 { "loop",                                               TEST_TYPE_LOOP                                          },
1739                 { "switch",                                             TEST_TYPE_SWITCH                                        },
1740                 { "loop_double_call",                   TEST_TYPE_LOOP_DOUBLE_CALL                      },
1741                 { "loop_double_call_sparse",    TEST_TYPE_LOOP_DOUBLE_CALL_SPARSE       },
1742                 { "nested_loop",                                TEST_TYPE_NESTED_LOOP                           },
1743                 { "nested_loop_loop_before",    TEST_TYPE_NESTED_LOOP_BEFORE            },
1744                 { "nested_loop_loop_after",             TEST_TYPE_NESTED_LOOP_AFTER                     },
1745                 { "function_call",                              TEST_TYPE_FUNCTION_CALL                         },
1746                 { "nested_function_call",               TEST_TYPE_NESTED_FUNCTION_CALL          },
1747         };
1748
1749         de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "complexcontrolflow", "Ray tracing complex control flow tests"));
1750
1751         for (size_t testTypeNdx = 0; testTypeNdx < DE_LENGTH_OF_ARRAY(testTypes); ++testTypeNdx)
1752         {
1753                 const TestType                                  testType                = testTypes[testTypeNdx].testType;
1754                 de::MovePtr<tcu::TestCaseGroup> testTypeGroup   (new tcu::TestCaseGroup(testCtx, testTypes[testTypeNdx].name, ""));
1755
1756                 for (size_t testOpNdx = 0; testOpNdx < DE_LENGTH_OF_ARRAY(testOps); ++testOpNdx)
1757                 {
1758                         const TestOp                                    testOp          = testOps[testOpNdx].op;
1759                         de::MovePtr<tcu::TestCaseGroup> testOpGroup     (new tcu::TestCaseGroup(testCtx, testOps[testOpNdx].name, ""));
1760
1761                         for (size_t testStagesNdx = 0; testStagesNdx < DE_LENGTH_OF_ARRAY(testStages); ++testStagesNdx)
1762                         {
1763                                 const VkShaderStageFlagBits     testStage                               = testStages[testStagesNdx].stage;
1764                                 const std::string                       testName                                = de::toString(testStages[testStagesNdx].name);
1765                                 const deUint32                          width                                   = 4u;
1766                                 const deUint32                          height                                  = 4u;
1767                                 const CaseDef                           caseDef                                 =
1768                                 {
1769                                         testType,                               //  TestType                            testType;
1770                                         testOp,                                 //  TestOp                                      testOp;
1771                                         testStage,                              //  VkShaderStageFlagBits       stage;
1772                                         width,                                  //  deUint32                            width;
1773                                         height,                                 //  deUint32                            height;
1774                                 };
1775
1776                                 if ((testOps[testOpNdx].applicableInStages & static_cast<VkShaderStageFlags>(testStage)) == 0)
1777                                         continue;
1778
1779                                 testOpGroup->addChild(new ComplexControlFlowTestCase(testCtx, testName.c_str(), "", caseDef));
1780                         }
1781
1782                         testTypeGroup->addChild(testOpGroup.release());
1783                 }
1784
1785                 group->addChild(testTypeGroup.release());
1786         }
1787
1788         return group.release();
1789 }
1790
1791 }       // RayTracing
1792 }       // vkt