Tests for VK_EXT_shader_module_identifier
[platform/upstream/VK-GL-CTS.git] / external / vulkancts / framework / vulkan / vkRayTracingUtil.hpp
1 #ifndef _VKRAYTRACINGUTIL_HPP
2 #define _VKRAYTRACINGUTIL_HPP
3 /*-------------------------------------------------------------------------
4  * Vulkan CTS Framework
5  * --------------------
6  *
7  * Copyright (c) 2020 The Khronos Group Inc.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  *//*!
22  * \file
23  * \brief Vulkan ray tracing utility.
24  *//*--------------------------------------------------------------------*/
25
#include "vkDefs.hpp"
#include "vkRef.hpp"
#include "vkMemUtil.hpp"
#include "vkBufferWithMemory.hpp"

#include "deFloat16.h"

#include "tcuVector.hpp"
#include "tcuVectorType.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>
39
40 namespace vk
41 {
42 constexpr VkShaderStageFlags    SHADER_STAGE_ALL_RAY_TRACING    = VK_SHADER_STAGE_RAYGEN_BIT_KHR
43                                                                                                                                 | VK_SHADER_STAGE_ANY_HIT_BIT_KHR
44                                                                                                                                 | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
45                                                                                                                                 | VK_SHADER_STAGE_MISS_BIT_KHR
46                                                                                                                                 | VK_SHADER_STAGE_INTERSECTION_BIT_KHR
47                                                                                                                                 | VK_SHADER_STAGE_CALLABLE_BIT_KHR;
48
49 const VkTransformMatrixKHR identityMatrix3x4 = { { { 1.0f, 0.0f, 0.0f, 0.0f }, { 0.0f, 1.0f, 0.0f, 0.0f }, { 0.0f, 0.0f, 1.0f, 0.0f } } };
50
51 template<typename T>
52 inline de::SharedPtr<Move<T>> makeVkSharedPtr(Move<T> move)
53 {
54         return de::SharedPtr<Move<T>>(new Move<T>(move));
55 }
56
57 template<typename T>
58 inline de::SharedPtr<de::MovePtr<T> > makeVkSharedPtr(de::MovePtr<T> movePtr)
59 {
60         return de::SharedPtr<de::MovePtr<T> >(new de::MovePtr<T>(movePtr));
61 }
62
// Currently an identity transform: returns an unmodified copy of the given GLSL source.
inline std::string updateRayTracingGLSL (const std::string& str)
{
	const std::string unchanged (str);
	return unchanged;
}
67
68 std::string getCommonRayGenerationShader (void);
69
70 // Get lowercase version of the format name with no VK_FORMAT_ prefix.
71 std::string getFormatSimpleName (vk::VkFormat format);
72
73 // Test whether given poin p belons to the triangle (p0, p1, p2)
74 bool pointInTriangle2D(const tcu::Vec3& p, const tcu::Vec3& p0, const tcu::Vec3& p1, const tcu::Vec3& p2);
75
76 // Checks the given vertex buffer format is valid for acceleration structures.
77 // Note: VK_KHR_get_physical_device_properties2 and VK_KHR_acceleration_structure are supposed to be supported.
78 void checkAccelerationStructureVertexBufferFormat (const vk::InstanceInterface &vki, vk::VkPhysicalDevice physicalDevice, vk::VkFormat format);
79
80 class RaytracedGeometryBase
81 {
82 public:
83                                                                 RaytracedGeometryBase                   ()                                                                              = delete;
84                                                                 RaytracedGeometryBase                   (const RaytracedGeometryBase& geometry) = delete;
85                                                                 RaytracedGeometryBase                   (VkGeometryTypeKHR geometryType, VkFormat vertexFormat, VkIndexType indexType);
86                                                                 virtual ~RaytracedGeometryBase  ();
87
88         inline VkGeometryTypeKHR        getGeometryType                                 (void) const                                                            { return m_geometryType; }
89         inline bool                                     isTrianglesType                                 (void) const                                                            { return m_geometryType == VK_GEOMETRY_TYPE_TRIANGLES_KHR; }
90         inline VkFormat                         getVertexFormat                                 (void) const                                                            { return m_vertexFormat; }
91         inline VkIndexType                      getIndexType                                    (void) const                                                            { return m_indexType; }
92         inline bool                                     usesIndices                                             (void) const                                                            { return m_indexType != VK_INDEX_TYPE_NONE_KHR; }
93         inline VkGeometryFlagsKHR       getGeometryFlags                                (void) const                                                            { return m_geometryFlags; }
94         inline void                                     setGeometryFlags                                (const VkGeometryFlagsKHR geometryFlags)        { m_geometryFlags = geometryFlags; }
95         virtual deUint32                        getVertexCount                                  (void) const                                                            = 0;
96         virtual const deUint8*          getVertexPointer                                (void) const                                                            = 0;
97         virtual VkDeviceSize            getVertexStride                                 (void) const                                                            = 0;
98         virtual VkDeviceSize            getAABBStride                                   (void) const                                                            = 0;
99         virtual size_t                          getVertexByteSize                               (void) const                                                            = 0;
100         virtual deUint32                        getIndexCount                                   (void) const                                                            = 0;
101         virtual const deUint8*          getIndexPointer                                 (void) const                                                            = 0;
102         virtual VkDeviceSize            getIndexStride                                  (void) const                                                            = 0;
103         virtual size_t                          getIndexByteSize                                (void) const                                                            = 0;
104         virtual deUint32                        getPrimitiveCount                               (void) const                                                            = 0;
105         virtual void                            addVertex                                               (const tcu::Vec3& vertex)                                       = 0;
106         virtual void                            addIndex                                                (const deUint32& index)                                         = 0;
107 private:
108         VkGeometryTypeKHR                       m_geometryType;
109         VkFormat                                        m_vertexFormat;
110         VkIndexType                                     m_indexType;
111         VkGeometryFlagsKHR                      m_geometryFlags;
112 };
113
// Converts a float to integer type T, rounding to nearest with ties to even and saturating to T's range.
// NaN converts to zero; infinities and out-of-range values saturate to T's minimum or maximum.
template <typename T>
inline T convertSatRte (float f)
{
	// \note Doesn't work for 64-bit types: the int64 intermediate must hold every T value plus one rounding step.
	static_assert(sizeof(T) < sizeof(std::int64_t), "convertSatRte does not support 64-bit types");
	// The odd/even test below relies on '%' truncating toward zero for negative operands.
	static_assert((-3 % 2 != 0) && (-4 % 2 == 0), "unexpected integer modulo semantics");

	const std::int64_t	minVal	= static_cast<std::int64_t>(std::numeric_limits<T>::min());
	const std::int64_t	maxVal	= static_cast<std::int64_t>(std::numeric_limits<T>::max());

	// NaN has no integer value. Map it to zero rather than reaching the float -> int64 cast below, which would be
	// undefined behavior.
	if (std::isnan(f))
		return static_cast<T>(0);

	// Saturate early. Any value at or beyond the bounds saturates anyway. Checking here also keeps infinities and
	// huge magnitudes away from the float -> int64 cast, which is undefined behavior when the value does not fit.
	if (f <= static_cast<float>(minVal))
		return static_cast<T>(minVal);
	if (f >= static_cast<float>(maxVal))
		return static_cast<T>(maxVal);

	// Split f into floor(f) and its fractional part q in [0, 1).
	const float		q		= f - std::floor(f);
	std::int64_t	intVal	= static_cast<std::int64_t>(f - q);

	// Rounding: exact halves go to the even neighbour; larger fractions round up.
	if (q == 0.5f)
	{
		if (intVal % 2 != 0)
			intVal++;
	}
	else if (q > 0.5f)
		intVal++;
	// else: don't add anything.

	// Saturate. This is defensive; the early returns above should already keep intVal in range.
	intVal = std::max(minVal, std::min(maxVal, intVal));

	return static_cast<T>(intVal);
}
141
142 // Converts float to signed integer with variable width.
143 // Source float is assumed to be in the [-1, 1] range.
144 template <typename T>
145 inline T deFloat32ToSNorm (float src)
146 {
147         DE_STATIC_ASSERT(std::numeric_limits<T>::is_integer && std::numeric_limits<T>::is_signed);
148         const T range   = std::numeric_limits<T>::max();
149         const T intVal  = convertSatRte<T>(src * static_cast<float>(range));
150         return de::clamp<T>(intVal, -range, range);
151 }
152
153 typedef tcu::Vector<deFloat16, 2>                       Vec2_16;
154 typedef tcu::Vector<deFloat16, 3>                       Vec3_16;
155 typedef tcu::Vector<deFloat16, 4>                       Vec4_16;
156 typedef tcu::Vector<deInt16, 2>                         Vec2_16SNorm;
157 typedef tcu::Vector<deInt16, 3>                         Vec3_16SNorm;
158 typedef tcu::Vector<deInt16, 4>                         Vec4_16SNorm;
159 typedef tcu::Vector<deInt8, 2>                          Vec2_8SNorm;
160 typedef tcu::Vector<deInt8, 3>                          Vec3_8SNorm;
161 typedef tcu::Vector<deInt8, 4>                          Vec4_8SNorm;
162
163 template<typename V>    VkFormat                        vertexFormatFromType                            ();
164 template<>                              inline VkFormat         vertexFormatFromType<tcu::Vec2>         ()                                                      { return VK_FORMAT_R32G32_SFLOAT; }
165 template<>                              inline VkFormat         vertexFormatFromType<tcu::Vec3>         ()                                                      { return VK_FORMAT_R32G32B32_SFLOAT; }
166 template<>                              inline VkFormat         vertexFormatFromType<tcu::Vec4>         ()                                                      { return VK_FORMAT_R32G32B32A32_SFLOAT; }
167 template<>                              inline VkFormat         vertexFormatFromType<Vec2_16>           ()                                                      { return VK_FORMAT_R16G16_SFLOAT; }
168 template<>                              inline VkFormat         vertexFormatFromType<Vec3_16>           ()                                                      { return VK_FORMAT_R16G16B16_SFLOAT; }
169 template<>                              inline VkFormat         vertexFormatFromType<Vec4_16>           ()                                                      { return VK_FORMAT_R16G16B16A16_SFLOAT; }
170 template<>                              inline VkFormat         vertexFormatFromType<Vec2_16SNorm>      ()                                                      { return VK_FORMAT_R16G16_SNORM; }
171 template<>                              inline VkFormat         vertexFormatFromType<Vec3_16SNorm>      ()                                                      { return VK_FORMAT_R16G16B16_SNORM; }
172 template<>                              inline VkFormat         vertexFormatFromType<Vec4_16SNorm>      ()                                                      { return VK_FORMAT_R16G16B16A16_SNORM; }
173 template<>                              inline VkFormat         vertexFormatFromType<tcu::DVec2>        ()                                                      { return VK_FORMAT_R64G64_SFLOAT; }
174 template<>                              inline VkFormat         vertexFormatFromType<tcu::DVec3>        ()                                                      { return VK_FORMAT_R64G64B64_SFLOAT; }
175 template<>                              inline VkFormat         vertexFormatFromType<tcu::DVec4>        ()                                                      { return VK_FORMAT_R64G64B64A64_SFLOAT; }
176 template<>                              inline VkFormat         vertexFormatFromType<Vec2_8SNorm>       ()                                                      { return VK_FORMAT_R8G8_SNORM; }
177 template<>                              inline VkFormat         vertexFormatFromType<Vec3_8SNorm>       ()                                                      { return VK_FORMAT_R8G8B8_SNORM; }
178 template<>                              inline VkFormat         vertexFormatFromType<Vec4_8SNorm>       ()                                                      { return VK_FORMAT_R8G8B8A8_SNORM; }
179
180 struct EmptyIndex {};
181 template<typename I>    VkIndexType                     indexTypeFromType                                       ();
182 template<>                              inline VkIndexType      indexTypeFromType<deUint16>                     ()                                                      { return VK_INDEX_TYPE_UINT16; }
183 template<>                              inline VkIndexType      indexTypeFromType<deUint32>                     ()                                                      { return VK_INDEX_TYPE_UINT32; }
184 template<>                              inline VkIndexType      indexTypeFromType<EmptyIndex>           ()                                                      { return VK_INDEX_TYPE_NONE_KHR; }
185
186 template<typename V>    V                                       convertFloatTo                                          (const tcu::Vec3& vertex);
187 template<>                              inline tcu::Vec2        convertFloatTo<tcu::Vec2>                       (const tcu::Vec3& vertex)       { return tcu::Vec2(vertex.x(), vertex.y()); }
188 template<>                              inline tcu::Vec3        convertFloatTo<tcu::Vec3>                       (const tcu::Vec3& vertex)       { return vertex; }
189 template<>                              inline tcu::Vec4        convertFloatTo<tcu::Vec4>                       (const tcu::Vec3& vertex)       { return tcu::Vec4(vertex.x(), vertex.y(), vertex.z(), 0.0f); }
190 template<>                              inline Vec2_16          convertFloatTo<Vec2_16>                         (const tcu::Vec3& vertex)       { return Vec2_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y())); }
191 template<>                              inline Vec3_16          convertFloatTo<Vec3_16>                         (const tcu::Vec3& vertex)       { return Vec3_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y()), deFloat32To16(vertex.z())); }
192 template<>                              inline Vec4_16          convertFloatTo<Vec4_16>                         (const tcu::Vec3& vertex)       { return Vec4_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y()), deFloat32To16(vertex.z()), deFloat32To16(0.0f)); }
193 template<>                              inline Vec2_16SNorm     convertFloatTo<Vec2_16SNorm>            (const tcu::Vec3& vertex)       { return Vec2_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y())); }
194 template<>                              inline Vec3_16SNorm     convertFloatTo<Vec3_16SNorm>            (const tcu::Vec3& vertex)       { return Vec3_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y()), deFloat32ToSNorm<deInt16>(vertex.z())); }
195 template<>                              inline Vec4_16SNorm     convertFloatTo<Vec4_16SNorm>            (const tcu::Vec3& vertex)       { return Vec4_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y()), deFloat32ToSNorm<deInt16>(vertex.z()), deFloat32ToSNorm<deInt16>(0.0f)); }
196 template<>                              inline tcu::DVec2       convertFloatTo<tcu::DVec2>                      (const tcu::Vec3& vertex)       { return tcu::DVec2(static_cast<double>(vertex.x()), static_cast<double>(vertex.y())); }
197 template<>                              inline tcu::DVec3       convertFloatTo<tcu::DVec3>                      (const tcu::Vec3& vertex)       { return tcu::DVec3(static_cast<double>(vertex.x()), static_cast<double>(vertex.y()), static_cast<double>(vertex.z())); }
198 template<>                              inline tcu::DVec4       convertFloatTo<tcu::DVec4>                      (const tcu::Vec3& vertex)       { return tcu::DVec4(static_cast<double>(vertex.x()), static_cast<double>(vertex.y()), static_cast<double>(vertex.z()), 0.0); }
199 template<>                              inline Vec2_8SNorm      convertFloatTo<Vec2_8SNorm>                     (const tcu::Vec3& vertex)       { return Vec2_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y())); }
200 template<>                              inline Vec3_8SNorm      convertFloatTo<Vec3_8SNorm>                     (const tcu::Vec3& vertex)       { return Vec3_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y()), deFloat32ToSNorm<deInt8>(vertex.z())); }
201 template<>                              inline Vec4_8SNorm      convertFloatTo<Vec4_8SNorm>                     (const tcu::Vec3& vertex)       { return Vec4_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y()), deFloat32ToSNorm<deInt8>(vertex.z()), deFloat32ToSNorm<deInt8>(0.0f)); }
202
203 template<typename V>    V                                       convertIndexTo                                          (deUint32 index);
204 template<>                              inline EmptyIndex       convertIndexTo<EmptyIndex>                      (deUint32 index)                        { DE_UNREF(index); TCU_THROW(TestError, "Cannot add empty index"); }
205 template<>                              inline deUint16         convertIndexTo<deUint16>                        (deUint32 index)                        { return static_cast<deUint16>(index); }
206 template<>                              inline deUint32         convertIndexTo<deUint32>                        (deUint32 index)                        { return index; }
207
208 template<typename V, typename I>
209 class RaytracedGeometry : public RaytracedGeometryBase
210 {
211 public:
212                                                 RaytracedGeometry                       ()                                                                      = delete;
213                                                 RaytracedGeometry                       (const RaytracedGeometry& geometry)     = delete;
214                                                 RaytracedGeometry                       (VkGeometryTypeKHR geometryType, deUint32 paddingBlocks = 0u);
215                                                 RaytracedGeometry                       (VkGeometryTypeKHR geometryType, const std::vector<V>& vertices, const std::vector<I>& indices = std::vector<I>(), deUint32 paddingBlocks = 0u);
216
217         deUint32                        getVertexCount                          (void) const override;
218         const deUint8*          getVertexPointer                        (void) const override;
219         VkDeviceSize            getVertexStride                         (void) const override;
220         VkDeviceSize            getAABBStride                           (void) const override;
221         size_t                          getVertexByteSize                       (void) const override;
222         deUint32                        getIndexCount                           (void) const override;
223         const deUint8*          getIndexPointer                         (void) const override;
224         VkDeviceSize            getIndexStride                          (void) const override;
225         size_t                          getIndexByteSize                        (void) const override;
226         deUint32                        getPrimitiveCount                       (void) const override;
227
228         void                            addVertex                                       (const tcu::Vec3& vertex) override;
229         void                            addIndex                                        (const deUint32& index) override;
230
231 private:
232         void                            init                                            ();                                     // To be run in constructors.
233         void                            checkGeometryType                       () const;                       // Checks geometry type is valid.
234         void                            calcBlockSize                           ();                                     // Calculates and saves vertex buffer block size.
235         size_t                          getBlockSize                            () const;                       // Return stored vertex buffer block size.
236         void                            addNativeVertex                         (const V& vertex);      // Adds new vertex in native format.
237
238         // The implementation below stores vertices as byte blocks to take the requested padding into account. m_vertices is the array
239         // of bytes containing vertex data.
240         //
241         // For triangles, the padding block has a size that is a multiple of the vertex size and each vertex is stored in a byte block
242         // equivalent to:
243         //
244         //      struct Vertex
245         //      {
246         //              V               vertex;
247         //              deUint8 padding[m_paddingBlocks * sizeof(V)];
248         //      };
249         //
250         // For AABBs, the padding block has a size that is a multiple of kAABBPadBaseSize (see below) and vertices are stored in pairs
251         // before the padding block. This is equivalent to:
252         //
253         //              struct VertexPair
254         //              {
255         //                      V               vertices[2];
256         //                      deUint8 padding[m_paddingBlocks * kAABBPadBaseSize];
257         //              };
258         //
259         // The size of each pseudo-structure above is saved to one of the correspoding union members below.
260         union BlockSize
261         {
262                 size_t trianglesBlockSize;
263                 size_t aabbsBlockSize;
264         };
265
266         const deUint32                  m_paddingBlocks;
267         size_t                                  m_vertexCount;
268         std::vector<deUint8>    m_vertices;                     // Vertices are stored as byte blocks.
269         std::vector<I>                  m_indices;                      // Indices are stored natively.
270         BlockSize                               m_blockSize;            // For m_vertices.
271
272         // Data sizes.
273         static constexpr size_t kVertexSize                     = sizeof(V);
274         static constexpr size_t kIndexSize                      = sizeof(I);
275         static constexpr size_t kAABBPadBaseSize        = 8; // As required by the spec.
276 };
277
278 template<typename V, typename I>
279 RaytracedGeometry<V, I>::RaytracedGeometry (VkGeometryTypeKHR geometryType, deUint32 paddingBlocks)
280         : RaytracedGeometryBase(geometryType, vertexFormatFromType<V>(), indexTypeFromType<I>())
281         , m_paddingBlocks(paddingBlocks)
282         , m_vertexCount(0)
283 {
284         init();
285 }
286
287 template<typename V, typename I>
288 RaytracedGeometry<V,I>::RaytracedGeometry (VkGeometryTypeKHR geometryType, const std::vector<V>& vertices, const std::vector<I>& indices, deUint32 paddingBlocks)
289         : RaytracedGeometryBase(geometryType, vertexFormatFromType<V>(), indexTypeFromType<I>())
290         , m_paddingBlocks(paddingBlocks)
291         , m_vertexCount(0)
292         , m_vertices()
293         , m_indices(indices)
294 {
295         init();
296         for (const auto& vertex : vertices)
297                 addNativeVertex(vertex);
298 }
299
300 template<typename V, typename I>
301 deUint32 RaytracedGeometry<V,I>::getVertexCount (void) const
302 {
303         return (isTrianglesType() ? static_cast<deUint32>(m_vertexCount) : 0u);
304 }
305
306 template<typename V, typename I>
307 const deUint8* RaytracedGeometry<V, I>::getVertexPointer (void) const
308 {
309         DE_ASSERT(!m_vertices.empty());
310         return reinterpret_cast<const deUint8*>(m_vertices.data());
311 }
312
313 template<typename V, typename I>
314 VkDeviceSize RaytracedGeometry<V,I>::getVertexStride (void) const
315 {
316         return ((!isTrianglesType()) ? 0ull : static_cast<VkDeviceSize>(getBlockSize()));
317 }
318
319 template<typename V, typename I>
320 VkDeviceSize RaytracedGeometry<V, I>::getAABBStride (void) const
321 {
322         return (isTrianglesType() ? 0ull : static_cast<VkDeviceSize>(getBlockSize()));
323 }
324
325 template<typename V, typename I>
326 size_t RaytracedGeometry<V, I>::getVertexByteSize (void) const
327 {
328         return m_vertices.size();
329 }
330
331 template<typename V, typename I>
332 deUint32 RaytracedGeometry<V, I>::getIndexCount (void) const
333 {
334         return static_cast<deUint32>(isTrianglesType() ? m_indices.size() : 0);
335 }
336
337 template<typename V, typename I>
338 const deUint8* RaytracedGeometry<V, I>::getIndexPointer (void) const
339 {
340         const auto indexCount = getIndexCount();
341         DE_UNREF(indexCount); // For release builds.
342         DE_ASSERT(indexCount > 0u);
343
344         return reinterpret_cast<const deUint8*>(m_indices.data());
345 }
346
347 template<typename V, typename I>
348 VkDeviceSize RaytracedGeometry<V, I>::getIndexStride (void) const
349 {
350         return static_cast<VkDeviceSize>(kIndexSize);
351 }
352
353 template<typename V, typename I>
354 size_t RaytracedGeometry<V, I>::getIndexByteSize (void) const
355 {
356         const auto indexCount = getIndexCount();
357         DE_ASSERT(indexCount > 0u);
358
359         return (indexCount * kIndexSize);
360 }
361
362 template<typename V, typename I>
363 deUint32 RaytracedGeometry<V,I>::getPrimitiveCount (void) const
364 {
365         return static_cast<deUint32>(isTrianglesType() ? (usesIndices() ? m_indices.size() / 3 : m_vertexCount / 3) : (m_vertexCount / 2));
366 }
367
368 template<typename V, typename I>
369 void RaytracedGeometry<V, I>::addVertex (const tcu::Vec3& vertex)
370 {
371         addNativeVertex(convertFloatTo<V>(vertex));
372 }
373
374 template<typename V, typename I>
375 void RaytracedGeometry<V, I>::addNativeVertex (const V& vertex)
376 {
377         const auto oldSize                      = m_vertices.size();
378         const auto blockSize            = getBlockSize();
379
380         if (isTrianglesType())
381         {
382                 // Reserve new block, copy vertex at the beginning of the new block.
383                 m_vertices.resize(oldSize + blockSize, deUint8{0});
384                 deMemcpy(&m_vertices[oldSize], &vertex, kVertexSize);
385         }
386         else // AABB
387         {
388                 if (m_vertexCount % 2 == 0)
389                 {
390                         // New block needed.
391                         m_vertices.resize(oldSize + blockSize, deUint8{0});
392                         deMemcpy(&m_vertices[oldSize], &vertex, kVertexSize);
393                 }
394                 else
395                 {
396                         // Insert in the second position of last existing block.
397                         //
398                         //                                                                                              Vertex Size
399                         //                                                                                              +-------+
400                         //      +-------------+------------+----------------------------------------+
401                         //      |             |            |      ...       | vertex vertex padding |
402                         //      +-------------+------------+----------------+-----------------------+
403                         //                                                                                              +-----------------------+
404                         //                                                                                                              Block Size
405                         //      +-------------------------------------------------------------------+
406                         //                                                      Old Size
407                         //
408                         deMemcpy(&m_vertices[oldSize - blockSize + kVertexSize], &vertex, kVertexSize);
409                 }
410         }
411
412         ++m_vertexCount;
413 }
414
415 template<typename V, typename I>
416 void RaytracedGeometry<V, I>::addIndex (const deUint32& index)
417 {
418         m_indices.push_back(convertIndexTo<I>(index));
419 }
420
421 template<typename V, typename I>
422 void RaytracedGeometry<V, I>::init ()
423 {
424         checkGeometryType();
425         calcBlockSize();
426 }
427
428 template<typename V, typename I>
429 void RaytracedGeometry<V, I>::checkGeometryType () const
430 {
431         const auto geometryType = getGeometryType();
432         DE_UNREF(geometryType); // For release builds.
433         DE_ASSERT(geometryType == VK_GEOMETRY_TYPE_TRIANGLES_KHR || geometryType == VK_GEOMETRY_TYPE_AABBS_KHR);
434 }
435
436 template<typename V, typename I>
437 void RaytracedGeometry<V, I>::calcBlockSize ()
438 {
439         if (isTrianglesType())
440                 m_blockSize.trianglesBlockSize = kVertexSize * static_cast<size_t>(1u + m_paddingBlocks);
441         else
442                 m_blockSize.aabbsBlockSize = 2 * kVertexSize + m_paddingBlocks * kAABBPadBaseSize;
443 }
444
445 template<typename V, typename I>
446 size_t RaytracedGeometry<V, I>::getBlockSize () const
447 {
448         return (isTrianglesType() ? m_blockSize.trianglesBlockSize : m_blockSize.aabbsBlockSize);
449 }
450
451 de::SharedPtr<RaytracedGeometryBase> makeRaytracedGeometry (VkGeometryTypeKHR geometryType, VkFormat vertexFormat, VkIndexType indexType, bool padVertices = false);
452
453 VkDeviceAddress getBufferDeviceAddress ( const DeviceInterface& vkd,
454                                                                                  const VkDevice                 device,
455                                                                                  const VkBuffer                 buffer,
456                                                                                  VkDeviceSize                   offset );
457
458 // type used for creating a deep serialization/deserialization of top-level acceleration structures
459 class SerialInfo
460 {
461         std::vector<deUint64>           m_addresses;
462         std::vector<VkDeviceSize>       m_sizes;
463 public:
464
465         SerialInfo() = default;
466
467         // addresses: { (owner-top-level AS address) [, (first bottom_level AS address), (second bottom_level AS address), ...] }
468         // sizes:     { (owner-top-level AS serial size) [, (first bottom_level AS serial size), (second bottom_level AS serial size), ...] }
469         SerialInfo(const std::vector<deUint64>& addresses, const std::vector<VkDeviceSize>& sizes)
470                 : m_addresses(addresses), m_sizes(sizes)
471         {
472                 DE_ASSERT(!addresses.empty() && addresses.size() == sizes.size());
473         }
474
475         const std::vector<deUint64>&            addresses                       () const        { return m_addresses; }
476         const std::vector<VkDeviceSize>&        sizes                           () const        { return m_sizes; }
477 };
478
479 class SerialStorage
480 {
481 public:
482         enum
483         {
484                 DE_SERIALIZED_FIELD(DRIVER_UUID,                VK_UUID_SIZE),          // VK_UUID_SIZE bytes of data matching VkPhysicalDeviceIDProperties::driverUUID
485                 DE_SERIALIZED_FIELD(COMPAT_UUID,                VK_UUID_SIZE),          // VK_UUID_SIZE bytes of data identifying the compatibility for comparison using vkGetDeviceAccelerationStructureCompatibilityKHR
486                 DE_SERIALIZED_FIELD(SERIALIZED_SIZE,    sizeof(deUint64)),      // A 64-bit integer of the total size matching the value queried using VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR
487                 DE_SERIALIZED_FIELD(DESERIALIZED_SIZE,  sizeof(deUint64)),      // A 64-bit integer of the deserialized size to be passed in to VkAccelerationStructureCreateInfoKHR::size
488                 DE_SERIALIZED_FIELD(HANDLES_COUNT,              sizeof(deUint64)),      // A 64-bit integer of the count of the number of acceleration structure handles following. This will be zero for a bottom-level acceleration structure.
489                 SERIAL_STORAGE_SIZE_MIN
490         };
491
492         // An old fashion C-style structure that simplifies an access to the AS header
493         struct alignas(16) AccelerationStructureHeader
494         {
495                 union {
496                         struct {
497                                 deUint8 driverUUID[VK_UUID_SIZE];
498                                 deUint8 compactUUID[VK_UUID_SIZE];
499                         };
500                         deUint8         uuids[VK_UUID_SIZE * 2];
501                 };
502                 deUint64                serializedSize;
503                 deUint64                deserializedSize;
504                 deUint64                handleCount;
505                 VkDeviceAddress handleArray[1];
506         };
507
508                                                                                         SerialStorage           () = delete;
509                                                                                         SerialStorage           (const DeviceInterface&                                         vk,
510                                                                                                                                  const VkDevice                                                         device,
511                                                                                                                                  Allocator&                                                                     allocator,
512                                                                                                                                  const VkAccelerationStructureBuildTypeKHR      buildType,
513                                                                                                                                  const VkDeviceSize                                                     storageSize);
514         // An additional constructor for creating a deep copy of top-level AS's.
515                                                                                         SerialStorage           (const DeviceInterface&                                         vk,
516                                                                                                                                  const VkDevice                                                         device,
517                                                                                                                                  Allocator&                                                                     allocator,
518                                                                                                                                  const VkAccelerationStructureBuildTypeKHR      buildType,
519                                                                                                                                  const SerialInfo&                                                      SerialInfo);
520
521         // below methods will return host addres if AS was build on cpu and device addres when it was build on gpu
522         VkDeviceOrHostAddressKHR                                getAddress                      (const DeviceInterface&                                         vk,
523                                                                                                                                  const VkDevice                                                         device,
524                                                                                                                                  const VkAccelerationStructureBuildTypeKHR      buildType);
525         VkDeviceOrHostAddressConstKHR                   getAddressConst         (const DeviceInterface&                                         vk,
526                                                                                                                                  const VkDevice                                                         device,
527                                                                                                                                  const VkAccelerationStructureBuildTypeKHR      buildType);
528
529         // this methods retun host address regardless of where AS was built
530         VkDeviceOrHostAddressKHR                                getHostAddress          (VkDeviceSize                   offset = 0);
531         VkDeviceOrHostAddressConstKHR                   getHostAddressConst     (VkDeviceSize                   offset = 0);
532
533         // works the similar way as getHostAddressConst() but returns more readable/intuitive object
534         AccelerationStructureHeader*                    getASHeader                     ();
535         bool                                                                    hasDeepFormat           () const;
536         de::SharedPtr<SerialStorage>                    getBottomStorage        (deUint32                       index) const;
537
538         VkDeviceSize                                                    getStorageSize          () const;
539         const SerialInfo&                                               getSerialInfo           () const;
540         deUint64                                                                getDeserializedSize     ();
541
542 protected:
543         const VkAccelerationStructureBuildTypeKHR       m_buildType;
544         const VkDeviceSize                                                      m_storageSize;
545         const SerialInfo                                                        m_serialInfo;
546         de::MovePtr<BufferWithMemory>                           m_buffer;
547         std::vector<de::SharedPtr<SerialStorage>>       m_bottoms;
548 };
549
550 class BottomLevelAccelerationStructure
551 {
552 public:
553         static deUint32                                                                         getRequiredAllocationCount                              (void);
554
555                                                                                                                 BottomLevelAccelerationStructure                ();
556                                                                                                                 BottomLevelAccelerationStructure                (const BottomLevelAccelerationStructure&                other) = delete;
557         virtual                                                                                         ~BottomLevelAccelerationStructure               ();
558
559         virtual void                                                                            setGeometryData                                                 (const std::vector<tcu::Vec3>&                                  geometryData,
560                                                                                                                                                                                                  const bool                                                                             triangles,
561                                                                                                                                                                                                  const VkGeometryFlagsKHR                                               geometryFlags                   = 0u );
562         virtual void                                                                            setDefaultGeometryData                                  (const VkShaderStageFlagBits                                    testStage,
563                                                                                                                                                                                                  const VkGeometryFlagsKHR                                               geometryFlags                   = 0u );
564         virtual void                                                                            setGeometryCount                                                (const size_t                                                                   geometryCount);
565         virtual void                                                                            addGeometry                                                             (de::SharedPtr<RaytracedGeometryBase>&                  raytracedGeometry);
566         virtual void                                                                            addGeometry                                                             (const std::vector<tcu::Vec3>&                                  geometryData,
567                                                                                                                                                                                                  const bool                                                                             triangles,
568                                                                                                                                                                                                  const VkGeometryFlagsKHR                                               geometryFlags                   = 0u );
569
570         virtual void                                                                            setBuildType                                                    (const VkAccelerationStructureBuildTypeKHR              buildType) = DE_NULL;
571         virtual void                                                                            setCreateFlags                                                  (const VkAccelerationStructureCreateFlagsKHR    createFlags) = DE_NULL;
572         virtual void                                                                            setCreateGeneric                                                (bool                                                                                   createGeneric) = 0;
573         virtual void                                                                            setBuildFlags                                                   (const VkBuildAccelerationStructureFlagsKHR             buildFlags) = DE_NULL;
574         virtual void                                                                            setBuildWithoutGeometries                               (bool                                                                                   buildWithoutGeometries) = 0;
575         virtual void                                                                            setBuildWithoutPrimitives                               (bool                                                                                   buildWithoutPrimitives) = 0;
576         virtual void                                                                            setDeferredOperation                                    (const bool                                                                             deferredOperation,
577                                                                                                                                                                                                  const deUint32                                                                 workerThreadCount               = 0u ) = DE_NULL;
578         virtual void                                                                            setUseArrayOfPointers                                   (const bool                                                                             useArrayOfPointers) = DE_NULL;
579         virtual void                                                                            setIndirectBuildParameters                              (const VkBuffer                                                                 indirectBuffer,
580                                                                                                                                                                                                  const VkDeviceSize                                                             indirectBufferOffset,
581                                                                                                                                                                                                  const deUint32                                                                 indirectBufferStride) = DE_NULL;
582         virtual VkBuildAccelerationStructureFlagsKHR            getBuildFlags                                                   () const = DE_NULL;
583         VkDeviceSize                                                                            getStructureSize                                                () const;
584
585         // methods specific for each acceleration structure
586         virtual void                                                                            create                                                                  (const DeviceInterface&                                                 vk,
587                                                                                                                                                                                                  const VkDevice                                                                 device,
588                                                                                                                                                                                                  Allocator&                                                                             allocator,
589                                                                                                                                                                                                  VkDeviceSize                                                                   structureSize,
590                                                                                                                                                                                                  VkDeviceAddress                                                                deviceAddress                   = 0u) = DE_NULL;
591         virtual void                                                                            build                                                                   (const DeviceInterface&                                                 vk,
592                                                                                                                                                                                                  const VkDevice                                                                 device,
593                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer) = DE_NULL;
594         virtual void                                                                            copyFrom                                                                (const DeviceInterface&                                                 vk,
595                                                                                                                                                                                                  const VkDevice                                                                 device,
596                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
597                                                                                                                                                                                                  BottomLevelAccelerationStructure*                              accelerationStructure,
598                                                                                                                                                                                                  bool                                                                                   compactCopy) = DE_NULL;
599
600         virtual void                                                                            serialize                                                               (const DeviceInterface&                                                 vk,
601                                                                                                                                                                                                  const VkDevice                                                                 device,
602                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
603                                                                                                                                                                                                  SerialStorage*                                                                 storage) = DE_NULL;
604         virtual void                                                                            deserialize                                                             (const DeviceInterface&                                                 vk,
605                                                                                                                                                                                                  const VkDevice                                                                 device,
606                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
607                                                                                                                                                                                                  SerialStorage*                                                                 storage) = DE_NULL;
608
609         // helper methods for typical acceleration structure creation tasks
610         void                                                                                            createAndBuild                                                  (const DeviceInterface&                                                 vk,
611                                                                                                                                                                                                  const VkDevice                                                                 device,
612                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
613                                                                                                                                                                                                  Allocator&                                                                             allocator,
614                                                                                                                                                                                                  VkDeviceAddress                                                                deviceAddress                   = 0u );
615         void                                                                                            createAndCopyFrom                                               (const DeviceInterface&                                                 vk,
616                                                                                                                                                                                                  const VkDevice                                                                 device,
617                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
618                                                                                                                                                                                                  Allocator&                                                                             allocator,
619                                                                                                                                                                                                  BottomLevelAccelerationStructure*                              accelerationStructure,
620                                                                                                                                                                                                  VkDeviceSize                                                                   compactCopySize                 = 0u,
621                                                                                                                                                                                                  VkDeviceAddress                                                                deviceAddress                   = 0u);
622         void                                                                                            createAndDeserializeFrom                                (const DeviceInterface&                                                 vk,
623                                                                                                                                                                                                  const VkDevice                                                                 device,
624                                                                                                                                                                                                  const VkCommandBuffer                                                  cmdBuffer,
625                                                                                                                                                                                                  Allocator&                                                                             allocator,
626                                                                                                                                                                                                  SerialStorage*                                                                 storage,
627                                                                                                                                                                                                  VkDeviceAddress                                                                deviceAddress                   = 0u);
628
629         virtual const VkAccelerationStructureKHR*                       getPtr                                                                  (void) const = DE_NULL;
630 protected:
631         std::vector<de::SharedPtr<RaytracedGeometryBase>>       m_geometriesData;
632         VkDeviceSize                                                                            m_structureSize;
633         VkDeviceSize                                                                            m_updateScratchSize;
634         VkDeviceSize                                                                            m_buildScratchSize;
635 };
636
// Returns a concrete BottomLevelAccelerationStructure implementation (the class itself is abstract).
// The result is owned by the returned de::MovePtr.
de::MovePtr<BottomLevelAccelerationStructure> makeBottomLevelAccelerationStructure ();
638
639 /**
640  * @brief Implements a pool of BottomLevelAccelerationStructure
641  */
642 class BottomLevelAccelerationStructurePool
643 {
644 public:
645         typedef de::SharedPtr<BottomLevelAccelerationStructure> BlasPtr;
646         struct BlasInfo {
647                 VkDeviceSize    structureSize;
648                 VkDeviceAddress deviceAddress;
649         };
650
651         BottomLevelAccelerationStructurePool();
652         virtual ~BottomLevelAccelerationStructurePool();
653
654         BlasPtr at                                      (deUint32 index) const  { return m_structs[index]; }
655         BlasPtr operator[]                      (deUint32 index) const  { return m_structs[index]; }
656         auto    structures                      () const -> const std::vector<BlasPtr>& { return m_structs; }
657         size_t  structCount                     () const { return m_structs.size(); }
658
659         size_t  batchStructCount        () const {return m_batchStructCount; }
660         void    batchStructCount        (const  size_t& value);
661
662         size_t  batchGeomCount          () const {return m_batchGeomCount; }
663         void    batchGeomCount          (const  size_t& value) { m_batchGeomCount = value; }
664
665         BlasPtr add                                     (VkDeviceSize                   structureSize = 0,
666                                                                  VkDeviceAddress                deviceAddress = 0);
667         /**
668          * @brief Creates previously added bottoms at a time.
669          * @note  All geometries must be known before call this method.
670          */
671         void    batchCreate                     (const DeviceInterface& vk,
672                                                                  const VkDevice                 device,
673                                                                  Allocator&                             allocator);
674         void    batchBuild                      (const DeviceInterface& vk,
675                                                                  const VkDevice                 device,
676                                                                  VkCommandBuffer                cmdBuffer);
677         size_t  getAllocationCount      () const;
678
679 protected:
680         size_t                                  m_batchStructCount; // default is 4
681         size_t                                  m_batchGeomCount; // default is 0, if zero then batchStructCount is used
682         std::vector<BlasInfo>   m_infos;
683         std::vector<BlasPtr>    m_structs;
684         bool                                    m_createOnce;
685
686 protected:
687         struct Impl;
688         Impl*                                   m_impl;
689 };
690
691 struct InstanceData
692 {
693                                                                 InstanceData (VkTransformMatrixKHR                                                      matrix_,
694                                                                                           deUint32                                                                              instanceCustomIndex_,
695                                                                                           deUint32                                                                              mask_,
696                                                                                           deUint32                                                                              instanceShaderBindingTableRecordOffset_,
697                                                                                           VkGeometryInstanceFlagsKHR                                    flags_)
698                                                                         : matrix(matrix_), instanceCustomIndex(instanceCustomIndex_), mask(mask_), instanceShaderBindingTableRecordOffset(instanceShaderBindingTableRecordOffset_), flags(flags_)
699                                                                 {
700                                                                 }
701         VkTransformMatrixKHR            matrix;
702         deUint32                                        instanceCustomIndex;
703         deUint32                                        mask;
704         deUint32                                        instanceShaderBindingTableRecordOffset;
705         VkGeometryInstanceFlagsKHR      flags;
706 };
707
708 class TopLevelAccelerationStructure
709 {
710 public:
711         static deUint32                                                                                                 getRequiredAllocationCount                      (void);
712
713                                                                                                                                         TopLevelAccelerationStructure           ();
714                                                                                                                                         TopLevelAccelerationStructure           (const TopLevelAccelerationStructure&                           other) = delete;
715         virtual                                                                                                                 ~TopLevelAccelerationStructure          ();
716
717         virtual void                                                                                                    setInstanceCount                                        (const size_t                                                                           instanceCount);
718         virtual void                                                                                                    addInstance                                                     (de::SharedPtr<BottomLevelAccelerationStructure>        bottomLevelStructure,
719                                                                                                                                                                                                                  const VkTransformMatrixKHR&                                            matrix                                                                  = identityMatrix3x4,
720                                                                                                                                                                                                                  deUint32                                                                                       instanceCustomIndex                                             = 0,
721                                                                                                                                                                                                                  deUint32                                                                                       mask                                                                    = 0xFF,
722                                                                                                                                                                                                                  deUint32                                                                                       instanceShaderBindingTableRecordOffset  = 0,
723                                                                                                                                                                                                                  VkGeometryInstanceFlagsKHR                                                     flags                                                                   = VkGeometryInstanceFlagBitsKHR(0u)     );
724
725         virtual void                                                                                                    setBuildType                                            (const VkAccelerationStructureBuildTypeKHR                      buildType) = DE_NULL;
726         virtual void                                                                                                    setCreateFlags                                          (const VkAccelerationStructureCreateFlagsKHR            createFlags) = DE_NULL;
727         virtual void                                                                                                    setCreateGeneric                                        (bool                                                                                           createGeneric) = 0;
728         virtual void                                                                                                    setBuildFlags                                           (const VkBuildAccelerationStructureFlagsKHR                     buildFlags) = DE_NULL;
729         virtual void                                                                                                    setBuildWithoutPrimitives                       (bool                                                                                           buildWithoutPrimitives) = 0;
730         virtual void                                                                                                    setInactiveInstances                            (bool                                                                                           inactiveInstances) = 0;
731         virtual void                                                                                                    setDeferredOperation                            (const bool                                                                                     deferredOperation,
732                                                                                                                                                                                                                  const deUint32                                                                         workerThreadCount = 0u) = DE_NULL;
733         virtual void                                                                                                    setUseArrayOfPointers                           (const bool                                                                                     useArrayOfPointers) = DE_NULL;
734         virtual void                                                                                                    setIndirectBuildParameters                      (const VkBuffer                                                                         indirectBuffer,
735                                                                                                                                                                                                                  const VkDeviceSize                                                                     indirectBufferOffset,
736                                                                                                                                                                                                                  const deUint32                                                                         indirectBufferStride) = DE_NULL;
737         virtual void                                                                                                    setUsePPGeometries                                      (const bool                                                                                     usePPGeometries) = 0;
738         virtual VkBuildAccelerationStructureFlagsKHR                                    getBuildFlags                                           () const = DE_NULL;
739         VkDeviceSize                                                                                                    getStructureSize                                        () const;
740
741         // methods specific for each acceleration structure
742         virtual void                                                                                                    create                                                          (const DeviceInterface&                                         vk,
743                                                                                                                                                                                                                  const VkDevice                                                         device,
744                                                                                                                                                                                                                  Allocator&                                                                     allocator,
745                                                                                                                                                                                                                  VkDeviceSize                                                           structureSize                   = 0u,
746                                                                                                                                                                                                                  VkDeviceAddress                                                        deviceAddress                   = 0u ) = DE_NULL;
747         virtual void                                                                                                    build                                                           (const DeviceInterface&                                         vk,
748                                                                                                                                                                                                                  const VkDevice                                                         device,
749                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer) = DE_NULL;
750         virtual void                                                                                                    copyFrom                                                        (const DeviceInterface&                                         vk,
751                                                                                                                                                                                                                  const VkDevice                                                         device,
752                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
753                                                                                                                                                                                                                  TopLevelAccelerationStructure*                         accelerationStructure,
754                                                                                                                                                                                                                  bool                                                                           compactCopy) = DE_NULL;
755
756         virtual void                                                                                                    serialize                                                       (const DeviceInterface&                                         vk,
757                                                                                                                                                                                                                  const VkDevice                                                         device,
758                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
759                                                                                                                                                                                                                  SerialStorage*                                                         storage) = DE_NULL;
760         virtual void                                                                                                    deserialize                                                     (const DeviceInterface&                                         vk,
761                                                                                                                                                                                                                  const VkDevice                                                         device,
762                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
763                                                                                                                                                                                                                  SerialStorage*                                                         storage) = DE_NULL;
764
765         virtual std::vector<VkDeviceSize>                                                               getSerializingSizes                                     (const DeviceInterface&                                         vk,
766                                                                                                                                                                                                                  const VkDevice                                                         device,
767                                                                                                                                                                                                                  const VkQueue                                                          queue,
768                                                                                                                                                                                                                  const deUint32                                                         queueFamilyIndex) = DE_NULL;
769
770         virtual std::vector<deUint64>                                                                   getSerializingAddresses                         (const DeviceInterface&                                         vk,
771                                                                                                                                                                                                                  const VkDevice                                                         device) const = DE_NULL;
772
773         // helper methods for typical acceleration structure creation tasks
774         void                                                                                                                    createAndBuild                                          (const DeviceInterface&                                         vk,
775                                                                                                                                                                                                                  const VkDevice                                                         device,
776                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
777                                                                                                                                                                                                                  Allocator&                                                                     allocator,
778                                                                                                                                                                                                                  VkDeviceAddress                                                        deviceAddress                   = 0u );
779         void                                                                                                                    createAndCopyFrom                                       (const DeviceInterface&                                         vk,
780                                                                                                                                                                                                                  const VkDevice                                                         device,
781                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
782                                                                                                                                                                                                                  Allocator&                                                                     allocator,
783                                                                                                                                                                                                                  TopLevelAccelerationStructure*                         accelerationStructure,
784                                                                                                                                                                                                                  VkDeviceSize                                                           compactCopySize                 = 0u,
785                                                                                                                                                                                                                  VkDeviceAddress                                                        deviceAddress                   = 0u);
786         void                                                                                                                    createAndDeserializeFrom                        (const DeviceInterface&                                         vk,
787                                                                                                                                                                                                                  const VkDevice                                                         device,
788                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
789                                                                                                                                                                                                                  Allocator&                                                                     allocator,
790                                                                                                                                                                                                                  SerialStorage*                                                         storage,
791                                                                                                                                                                                                                  VkDeviceAddress                                                        deviceAddress                   = 0u);
792
793         virtual const VkAccelerationStructureKHR*                                               getPtr                                                          (void) const = DE_NULL;
794
795         virtual void                                                                                                    updateInstanceMatrix                            (const DeviceInterface&                                         vk,
796                                                                                                                                                                                                                  const VkDevice                                                         device,
797                                                                                                                                                                                                                  size_t                                                                         instanceIndex,
798                                                                                                                                                                                                                  const VkTransformMatrixKHR&                            matrix) = 0;
799
800 protected:
801         std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >   m_bottomLevelInstances;
802         std::vector<InstanceData>                                                                               m_instanceData;
803         VkDeviceSize                                                                                                    m_structureSize;
804         VkDeviceSize                                                                                                    m_updateScratchSize;
805         VkDeviceSize                                                                                                    m_buildScratchSize;
806
807         virtual void                                                                                                    createAndDeserializeBottoms                     (const DeviceInterface&                                         vk,
808                                                                                                                                                                                                                  const VkDevice                                                         device,
809                                                                                                                                                                                                                  const VkCommandBuffer                                          cmdBuffer,
810                                                                                                                                                                                                                  Allocator&                                                                     allocator,
811                                                                                                                                                                                                                  SerialStorage*                                                         storage) = DE_NULL;
812 };
813
814 de::MovePtr<TopLevelAccelerationStructure> makeTopLevelAccelerationStructure ();
815
816 template<class ASType> de::MovePtr<ASType> makeAccelerationStructure ();
817 template<> inline de::MovePtr<BottomLevelAccelerationStructure> makeAccelerationStructure () { return makeBottomLevelAccelerationStructure(); }
818 template<> inline de::MovePtr<TopLevelAccelerationStructure>    makeAccelerationStructure () { return makeTopLevelAccelerationStructure(); }
819
820 bool queryAccelerationStructureSize (const DeviceInterface&                                                     vk,
821                                                                          const VkDevice                                                                 device,
822                                                                          const VkCommandBuffer                                                  cmdBuffer,
823                                                                          const std::vector<VkAccelerationStructureKHR>& accelerationStructureHandles,
824                                                                          VkAccelerationStructureBuildTypeKHR                    buildType,
825                                                                          const VkQueryPool                                                              queryPool,
826                                                                          VkQueryType                                                                    queryType,
827                                                                          deUint32                                                                               firstQuery,
828                                                                          std::vector<VkDeviceSize>&                                             results);
829
830 class RayTracingPipeline
831 {
832 public:
833         class CompileRequiredError : public std::runtime_error
834         {
835         public:
836                 CompileRequiredError (const std::string& error)
837                         : std::runtime_error(error)
838                         {}
839         };
840
841                                                                                                                                 RayTracingPipeline                      ();
842                                                                                                                                 ~RayTracingPipeline                     ();
843
844         void                                                                                                            addShader                                       (VkShaderStageFlagBits                                                                  shaderStage,
845                                                                                                                                                                                          Move<VkShaderModule>                                                                   shaderModule,
846                                                                                                                                                                                          deUint32                                                                                               group,
847                                                                                                                                                                                          const VkSpecializationInfo*                                                    specializationInfo = nullptr,
848                                                                                                                                                                                          const VkPipelineShaderStageCreateFlags                                 pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
849                                                                                                                                                                                          const void*                                                                                    pipelineShaderStageCreateInfopNext = nullptr);
850         void                                                                                                            addShader                                       (VkShaderStageFlagBits                                                                  shaderStage,
851                                                                                                                                                                                          de::SharedPtr<Move<VkShaderModule>>                                    shaderModule,
852                                                                                                                                                                                          deUint32                                                                                               group,
853                                                                                                                                                                                          const VkSpecializationInfo*                                                    specializationInfoPtr = nullptr,
854                                                                                                                                                                                          const VkPipelineShaderStageCreateFlags                                 pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
855                                                                                                                                                                                          const void*                                                                                    pipelineShaderStageCreateInfopNext = nullptr);
856         void                                                                                                            addShader                                       (VkShaderStageFlagBits                                                                  shaderStage,
857                                                                                                                                                                                          VkShaderModule                                                                         shaderModule,
858                                                                                                                                                                                          deUint32                                                                                               group,
859                                                                                                                                                                                          const VkSpecializationInfo*                                                    specializationInfo = nullptr,
860                                                                                                                                                                                          const VkPipelineShaderStageCreateFlags                                 pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
861                                                                                                                                                                                          const void*                                                                                    pipelineShaderStageCreateInfopNext = nullptr);
862         void                                                                                                            addLibrary                                      (de::SharedPtr<de::MovePtr<RayTracingPipeline>>                 pipelineLibrary);
863         Move<VkPipeline>                                                                                        createPipeline                          (const DeviceInterface&                                                                 vk,
864                                                                                                                                                                                          const VkDevice                                                                                 device,
865                                                                                                                                                                                          const VkPipelineLayout                                                                 pipelineLayout,
866                                                                                                                                                                                          const std::vector<de::SharedPtr<Move<VkPipeline>>>&    pipelineLibraries                       = std::vector<de::SharedPtr<Move<VkPipeline>>>());
867         Move<VkPipeline>                                                                                        createPipeline                          (const DeviceInterface&                                                                 vk,
868                                                                                                                                                                                          const VkDevice                                                                                 device,
869                                                                                                                                                                                          const VkPipelineLayout                                                                 pipelineLayout,
870                                                                                                                                                                                          const std::vector<VkPipeline>&                                                 pipelineLibraries,
871                                                                                                                                                                                          const VkPipelineCache                                                                  pipelineCache);
872         std::vector<de::SharedPtr<Move<VkPipeline>>>                            createPipelineWithLibraries     (const DeviceInterface&                                                                 vk,
873                                                                                                                                                                                          const VkDevice                                                                                 device,
874                                                                                                                                                                                          const VkPipelineLayout                                                                 pipelineLayout);
875         de::MovePtr<BufferWithMemory>                                                           createShaderBindingTable        (const DeviceInterface&                                                                 vk,
876                                                                                                                                                                                          const VkDevice                                                                                 device,
877                                                                                                                                                                                          const VkPipeline                                                                               pipeline,
878                                                                                                                                                                                          Allocator&                                                                                             allocator,
879                                                                                                                                                                                          const deUint32&                                                                                shaderGroupHandleSize,
880                                                                                                                                                                                          const deUint32                                                                                 shaderGroupBaseAlignment,
881                                                                                                                                                                                          const deUint32&                                                                                firstGroup,
882                                                                                                                                                                                          const deUint32&                                                                                groupCount,
883                                                                                                                                                                                          const VkBufferCreateFlags&                                                             additionalBufferCreateFlags     = VkBufferCreateFlags(0u),
884                                                                                                                                                                                          const VkBufferUsageFlags&                                                              additionalBufferUsageFlags      = VkBufferUsageFlags(0u),
885                                                                                                                                                                                          const MemoryRequirement&                                                               additionalMemoryRequirement     = MemoryRequirement::Any,
886                                                                                                                                                                                          const VkDeviceAddress&                                                                 opaqueCaptureAddress            = 0u,
887                                                                                                                                                                                          const deUint32                                                                                 shaderBindingTableOffset        = 0u,
888                                                                                                                                                                                          const deUint32                                                                                 shaderRecordSize                        = 0u,
889                                                                                                                                                                                          const void**                                                                                   shaderGroupDataPtrPerGroup      = nullptr);
890         void                                                                                                            setCreateFlags                          (const VkPipelineCreateFlags&                                                   pipelineCreateFlags);
891         void                                                                                                            setMaxRecursionDepth            (const deUint32&                                                                                maxRecursionDepth);
892         void                                                                                                            setMaxPayloadSize                       (const deUint32&                                                                                maxPayloadSize);
893         void                                                                                                            setMaxAttributeSize                     (const deUint32&                                                                                maxAttributeSize);
894         void                                                                                                            setDeferredOperation            (const bool                                                                                             deferredOperation,
895                                                                                                                                                                                          const deUint32                                                                                 workerThreadCount = 0);
896         void                                                                                                            addDynamicState                         (const VkDynamicState&                                                                  dynamicState);
897
898
899 protected:
900         Move<VkPipeline>                                                                                        createPipelineKHR                       (const DeviceInterface&                 vk,
901                                                                                                                                                                                          const VkDevice                                 device,
902                                                                                                                                                                                          const VkPipelineLayout                 pipelineLayout,
903                                                                                                                                                                                          const std::vector<VkPipeline>& pipelineLibraries,
904                                                                                                                                                                                          const VkPipelineCache                  pipelineCache = DE_NULL);
905
906         std::vector<de::SharedPtr<Move<VkShaderModule> > >                      m_shadersModules;
907         std::vector<de::SharedPtr<de::MovePtr<RayTracingPipeline>>>     m_pipelineLibraries;
908         std::vector<VkPipelineShaderStageCreateInfo>                            m_shaderCreateInfos;
909         std::vector<VkRayTracingShaderGroupCreateInfoKHR>                       m_shadersGroupCreateInfos;
910         VkPipelineCreateFlags                                                                           m_pipelineCreateFlags;
911         deUint32                                                                                                        m_maxRecursionDepth;
912         deUint32                                                                                                        m_maxPayloadSize;
913         deUint32                                                                                                        m_maxAttributeSize;
914         bool                                                                                                            m_deferredOperation;
915         deUint32                                                                                                        m_workerThreadCount;
916         std::vector<VkDynamicState>                                                                     m_dynamicStates;
917 };
918
919 class RayTracingProperties
920 {
921 protected:
922                                                                         RayTracingProperties                                            () {}
923
924 public:
925                                                                         RayTracingProperties                                            (const InstanceInterface&       vki,
926                                                                                                                                                                  const VkPhysicalDevice         physicalDevice) { DE_UNREF(vki); DE_UNREF(physicalDevice); }
927         virtual                                                 ~RayTracingProperties                                           () {}
928
929         virtual deUint32                                getShaderGroupHandleSize                                        (void)  = DE_NULL;
930         virtual deUint32                                getMaxRecursionDepth                                            (void)  = DE_NULL;
931         virtual deUint32                                getMaxShaderGroupStride                                         (void)  = DE_NULL;
932         virtual deUint32                                getShaderGroupBaseAlignment                                     (void)  = DE_NULL;
933         virtual deUint64                                getMaxGeometryCount                                                     (void)  = DE_NULL;
934         virtual deUint64                                getMaxInstanceCount                                                     (void)  = DE_NULL;
935         virtual deUint64                                getMaxPrimitiveCount                                            (void)  = DE_NULL;
936         virtual deUint32                                getMaxDescriptorSetAccelerationStructures       (void)  = DE_NULL;
937         virtual deUint32                                getMaxRayDispatchInvocationCount                        (void)  = DE_NULL;
938         virtual deUint32                                getMaxRayHitAttributeSize                                       (void)  = DE_NULL;
939 };
940
941 de::MovePtr<RayTracingProperties> makeRayTracingProperties (const InstanceInterface&    vki,
942                                                                                                                         const VkPhysicalDevice          physicalDevice);
943
944 void cmdTraceRays       (const DeviceInterface&                                 vk,
945                                          VkCommandBuffer                                                commandBuffer,
946                                          const VkStridedDeviceAddressRegionKHR* raygenShaderBindingTableRegion,
947                                          const VkStridedDeviceAddressRegionKHR* missShaderBindingTableRegion,
948                                          const VkStridedDeviceAddressRegionKHR* hitShaderBindingTableRegion,
949                                          const VkStridedDeviceAddressRegionKHR* callableShaderBindingTableRegion,
950                                          deUint32                                                               width,
951                                          deUint32                                                               height,
952                                          deUint32                                                               depth);
953
954 void cmdTraceRaysIndirect       (const DeviceInterface&                                 vk,
955                                                          VkCommandBuffer                                                commandBuffer,
956                                                          const VkStridedDeviceAddressRegionKHR* raygenShaderBindingTableRegion,
957                                                          const VkStridedDeviceAddressRegionKHR* missShaderBindingTableRegion,
958                                                          const VkStridedDeviceAddressRegionKHR* hitShaderBindingTableRegion,
959                                                          const VkStridedDeviceAddressRegionKHR* callableShaderBindingTableRegion,
960                                                          VkDeviceAddress                                                indirectDeviceAddress);
961
962 } // vk
963
964 #endif // _VKRAYTRACINGUTIL_HPP