[dali_2.3.21] Merge branch 'devel/master'
[platform/core/uifw/dali-toolkit.git] / dali-physics / third-party / bullet3 / src / Bullet3OpenCL / BroadphaseCollision / kernels / parallelLinearBvh.cl
1 /*
2 This software is provided 'as-is', without any express or implied warranty.
3 In no event will the authors be held liable for any damages arising from the use of this software.
4 Permission is granted to anyone to use this software for any purpose,
5 including commercial applications, and to alter it and redistribute it freely,
6 subject to the following restrictions:
7
8 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required.
9 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software.
10 3. This notice may not be removed or altered from any source distribution.
11 */
12 //Initial Author Jackson Lee, 2014
13
/* Bullet3 math type names mapped onto OpenCL types. b3Vector3 is 4 lanes wide;
   the 3D helpers in this file (b3Vector3_dot / b3Vector3_length2) ignore the w lane. */
typedef float  b3Scalar;
typedef float4 b3Vector3;

/* Bullet3 math helper names mapped onto the OpenCL built-ins. */
#define b3Sqrt sqrt
#define b3Min min
#define b3Max max
19
/* Key/value record used for sorting. In this file m_key holds a morton code and
   m_value the index of the AABB (leaf node) the code was computed from. */
typedef struct
{
	unsigned int m_key;   /* sort key */
	unsigned int m_value; /* payload carried along with the key */
} SortDataCL;
25
/* Axis-aligned bounding box. Each corner can be viewed as a float4, as four floats,
   or as four ints. The kernels in this file read m_minIndices[3] (the w lane of the
   min corner, reinterpreted as an int) as the index written into output pairs. */
typedef struct 
{
	union
	{
		float4 m_min;
		float  m_minElems[4];
		int    m_minIndices[4];
	};
	union
	{
		float4 m_max;
		float  m_maxElems[4];
		int    m_maxIndices[4];
	};
} b3AabbCL;
41
42
/* Spreads the low 10 bits of x so that input bit i moves to output bit 3*i,
   leaving two zero bits between neighbouring input bits:
       input : ........ ........ ......98 76543210
       output: ....9..8 ..7..6.. 5..4..3. .2..1..0
   Bits above bit 9 are discarded. Each step copies groups of bits upward with a
   shift and keeps only the wanted lanes with a mask. The shifted copy never
   overlaps the bits already present, so OR-ing is the same as XOR-ing. */
unsigned int interleaveBits(unsigned int x)
{
	unsigned int bits = x & 0x3FFu;	/* only 10 bits per axis fit into a 30-bit morton code */

	bits = (bits | (bits << 16)) & 0xFF0000FFu;	/* ......98 ........ ........ 76543210 */
	bits = (bits | (bits <<  8)) & 0x0300F00Fu;	/* ......98 ........ 7654.... ....3210 */
	bits = (bits | (bits <<  4)) & 0x030C30C3u;	/* ......98 ....76.. ..54.... 32....10 */
	bits = (bits | (bits <<  2)) & 0x09249249u;	/* ....9..8 ..7..6.. 5..4..3. .2..1..0 */

	return bits;
}
/* Combines three 10-bit grid coordinates into a 30-bit morton (z-order) code:
   x occupies bits 0,3,6,..., y bits 1,4,7,... and z bits 2,5,8,... */
unsigned int getMortonCode(unsigned int x, unsigned int y, unsigned int z)
{
	unsigned int code = interleaveBits(z) << 2;
	code |= interleaveBits(y) << 1;
	code |= interleaveBits(x);
	return code;
}
78
/* Gather kernel: out_aabbs[i] = unseparatedAabbs[aabbIndices[i]] for every
   i < numAabbsToSeparate, one work-item per output element. */
__kernel void separateAabbs(__global b3AabbCL* unseparatedAabbs, __global int* aabbIndices, __global b3AabbCL* out_aabbs, int numAabbsToSeparate)
{
	int dst = get_global_id(0);
	if (dst < numAabbsToSeparate)
	{
		int src = aabbIndices[dst];
		out_aabbs[dst] = unseparatedAabbs[src];
	}
}
87
//One pass of an in-place pairwise reduction over out_mergedAabb[0 .. numAabbsNeedingMerge).
//Each time this kernel is added to the command queue, the number of AABBs needing
//to be merged is halved (rounded up): element i is merged with element i + ceil(n / 2)
//for i < floor(n / 2); when n is odd the middle element is carried over unchanged.
//
//Example with 159 AABBs: 79 merges ([0, 78] with [80, 158]), element 79 untouched,
//80 AABBs remain for the next pass.
//
//Work-items only write indices below ceil(n / 2) and only read the partner from
//[ceil(n / 2), n), so no work-item reads an element another one writes.
//Should replace with an optimized parallel reduction.
__kernel void findAllNodesMergedAabb(__global b3AabbCL* out_mergedAabb, int numAabbsNeedingMerge)
{
	int numMergedAabbs = numAabbsNeedingMerge / 2;					//floor(n / 2)
	int numRemainingAabbs = numAabbsNeedingMerge - numMergedAabbs;	//ceil(n / 2)

	int i = get_global_id(0);
	if (i < numMergedAabbs)
	{
		b3AabbCL lhs = out_mergedAabb[i];
		b3AabbCL rhs = out_mergedAabb[i + numRemainingAabbs];

		b3AabbCL merged;
		merged.m_max = b3Max(lhs.m_max, rhs.m_max);
		merged.m_min = b3Min(lhs.m_min, rhs.m_min);
		out_mergedAabb[i] = merged;
	}
}
115
//Assigns every AABB a 30-bit morton code and pairs it with its own index, ready for sorting.
//
//The merged AABB mergedAabbOfAllNodes[0] is divided into a 1024 x 1024 x 1024 grid centered on
//its center. The grid cell containing each AABB's center gives 10 bits per axis, which
//getMortonCode() interleaves.
//
//worldSpaceAabbs                - input AABBs; element i belongs to leaf node i
//mergedAabbOfAllNodes           - element 0 holds the merged AABB that defines the grid
//out_mortonCodesAndAabbIndices  - out: m_key = morton code, m_value = AABB/leaf index
//numAabbs                       - number of AABBs; excess work-items return immediately
__kernel void assignMortonCodesAndAabbIndicies(__global b3AabbCL* worldSpaceAabbs, __global b3AabbCL* mergedAabbOfAllNodes, 
												__global SortDataCL* out_mortonCodesAndAabbIndices, int numAabbs)
{
	int leafNodeIndex = get_global_id(0);	//Leaf node index == AABB index
	if(leafNodeIndex >= numAabbs) return;
	
	b3AabbCL mergedAabb = mergedAabbOfAllNodes[0];
	b3Vector3 gridCenter = (mergedAabb.m_min + mergedAabb.m_max) * 0.5f;
	
	//If the merged AABB is flat along an axis, its cell size would be 0 and a center lying on the
	//grid center would compute 0/0 = NaN, whose conversion to int is undefined. Clamping to FLT_MIN
	//keeps the division well defined (0 / FLT_MIN == 0) and leaves normal cell sizes unchanged.
	b3Vector3 gridCellSize = fmax( (mergedAabb.m_max - mergedAabb.m_min) / (float)1024, (b3Vector3)(FLT_MIN) );
	
	b3AabbCL aabb = worldSpaceAabbs[leafNodeIndex];
	b3Vector3 aabbCenter = (aabb.m_min + aabb.m_max) * 0.5f;
	b3Vector3 aabbCenterRelativeToGrid = aabbCenter - gridCenter;
	
	//Position in units of grid cells relative to the grid center
	//(within [-512, 512] when the AABB lies inside the merged AABB)
	b3Vector3 gridPosition = aabbCenterRelativeToGrid / gridCellSize;
	
	//Quantize into integer coordinates.
	//Clamp into [-512, 511] while still in floating point, so the float -> int conversion never sees
	//an out-of-range value (undefined); fmax()/fmin() also map a NaN operand to the bound.
	//floor() (instead of truncation towards zero) is needed to prevent the center cell, at (0,0,0),
	//from being twice the size. The w lane is not a coordinate; it is clamped too but never used.
	b3Vector3 clampedGridPosition = fmin( fmax(gridPosition, (b3Vector3)(-512.0f)), (b3Vector3)(511.0f) );
	int4 discretePosition = convert_int4( floor(clampedGridPosition) );
	
	//Convert range from [-512, 511] to [0, 1023]
	discretePosition += 512;
	
	//Interleave bits(assign a morton code, also known as a z-curve)
	unsigned int mortonCode = getMortonCode(discretePosition.x, discretePosition.y, discretePosition.z);
	
	SortDataCL mortonCodeIndexPair;
	mortonCodeIndexPair.m_key = mortonCode;
	mortonCodeIndexPair.m_value = leafNodeIndex;
	
	out_mortonCodesAndAabbIndices[leafNodeIndex] = mortonCodeIndexPair;
}
153
//Capacity (in node indices) of the per-work-item traversal stack used by the BVH traversal kernels.
#define B3_PLVBH_TRAVERSE_MAX_STACK_SIZE 128

//Node indices stored in the BVH are tagged with the most significant bit (0x80000000) of an int32:
//set means the index refers to an internal node, clear means it refers to a leaf node.
//In both cases the tag must be cleared to obtain the actual node index.
//Only int-typed constants are used, so no unsigned <-> signed conversions take place.
int isLeafNode(int index) { return index >= 0; }	//sign bit clear -> leaf
int getIndexWithInternalNodeMarkerRemoved(int index) { return index & 0x7FFFFFFF; }
int getIndexWithInternalNodeMarkerSet(int isLeaf, int index)
{
	if (isLeaf) return index;
	return index | ~0x7FFFFFFF;	//~0x7FFFFFFF has only the sign bit set
}
162
//From sap.cl
#define NEW_PAIR_MARKER -1

//Returns true unless the AABBs are separated along the x, y or z axis; touching boxes count as overlapping.
//An axis only separates when one of its strict comparisons is true, so NaN coordinates never separate.
bool TestAabbAgainstAabb2(const b3AabbCL* aabb1, const b3AabbCL* aabb2)
{
	bool separatedX = aabb1->m_min.x > aabb2->m_max.x || aabb1->m_max.x < aabb2->m_min.x;
	bool separatedY = aabb1->m_min.y > aabb2->m_max.y || aabb1->m_max.y < aabb2->m_min.y;
	bool separatedZ = aabb1->m_min.z > aabb2->m_max.z || aabb1->m_max.z < aabb2->m_min.z;
	return !(separatedX || separatedY || separatedZ);
}
//From sap.cl
175
//Finds overlapping AABB pairs by traversing the BVH once per query AABB.
//Query i is the rigid AABB of sorted leaf i (mortonCodesAndAabbIndices[i].m_value).
//Only leaves with a higher sorted index than the query are reported, so each pair is found once
//and no AABB is paired with itself.
//Every overlap emits (query m_minIndices[3], other m_minIndices[3], NEW_PAIR_MARKER, NEW_PAIR_MARKER).
//out_numPairs is incremented for every overlap found, so it may end up larger than maxPairs;
//only the first maxPairs pairs are stored.
__kernel void plbvhCalculateOverlappingPairs(__global b3AabbCL* rigidAabbs, 

											__global int* rootNodeIndex, 
											__global int2* internalNodeChildIndices, 
											__global b3AabbCL* internalNodeAabbs,
											__global int2* internalNodeLeafIndexRanges,
											
											__global SortDataCL* mortonCodesAndAabbIndices,
											__global int* out_numPairs, __global int4* out_overlappingPairs, 
											int maxPairs, int numQueryAabbs)
{
	//Using get_group_id()/get_local_id() is faster than get_global_id(0) since
	//mortonCodesAndAabbIndices[] contains rigid body indices sorted along the z-curve (more spatially coherent)
	int queryBvhNodeIndex = get_group_id(0) * get_local_size(0) + get_local_id(0);
	if (queryBvhNodeIndex >= numQueryAabbs) return;

	int queryRigidIndex = mortonCodesAndAabbIndices[queryBvhNodeIndex].m_value;
	b3AabbCL queryAabb = rigidAabbs[queryRigidIndex];

	//Depth-first traversal with an explicit stack of tagged node indices
	int stack[B3_PLVBH_TRAVERSE_MAX_STACK_SIZE];
	int top = 0;
	stack[top++] = *rootNodeIndex;

	while (top > 0)
	{
		int taggedIndex = stack[--top];
		int isLeaf = isLeafNode(taggedIndex);
		int nodeIndex = getIndexWithInternalNodeMarkerRemoved(taggedIndex);

		//In a binary radix tree each internal node covers a contiguous range of leaves
		//(internalNodeLeafIndexRanges[]). Subtrees whose highest leaf does not come after the query
		//are skipped, which avoids testing each pair twice and testing a leaf against itself.
		int lastLeafInSubtree = isLeaf ? nodeIndex : internalNodeLeafIndexRanges[nodeIndex].y;
		if (lastLeafInSubtree <= queryBvhNodeIndex) continue;

		b3AabbCL nodeAabb;
		if (isLeaf)
		{
			int leafRigidIndex = mortonCodesAndAabbIndices[nodeIndex].m_value;
			nodeAabb = rigidAabbs[leafRigidIndex];
		}
		else
		{
			nodeAabb = internalNodeAabbs[nodeIndex];
		}

		if (!TestAabbAgainstAabb2(&queryAabb, &nodeAabb)) continue;

		if (isLeaf)
		{
			int4 pair;
			pair.x = queryAabb.m_minIndices[3];
			pair.y = nodeAabb.m_minIndices[3];
			pair.z = NEW_PAIR_MARKER;
			pair.w = NEW_PAIR_MARKER;

			int pairIndex = atomic_inc(out_numPairs);
			if (pairIndex < maxPairs) out_overlappingPairs[pairIndex] = pair;
		}
		else if (top + 2 <= B3_PLVBH_TRAVERSE_MAX_STACK_SIZE)
		{
			int2 children = internalNodeChildIndices[nodeIndex];
			stack[top++] = children.x;
			stack[top++] = children.y;
		}
		//else: no room for both children; the subtree is skipped without reporting an error
	}
}
250
251
//From rayCastKernels.cl
//A ray segment, given by its start point and its end point.
typedef struct
{
	float4 m_from;	//start of the segment
	float4 m_to;	//end of the segment
} b3RayInfo;
//From rayCastKernels.cl
259
//Returns v scaled to unit length, using only its x, y, z lanes.
//OpenCL normalize() works on all four lanes, so w is zeroed before normalizing.
b3Vector3 b3Vector3_normalize(b3Vector3 v)
{
	b3Vector3 xyz0 = v;
	xyz0.w = 0.0f;
	return normalize(xyz0);
}
//Dot product of the x, y, z lanes (w is ignored).
b3Scalar b3Vector3_dot(b3Vector3 a, b3Vector3 b) { return a.x*b.x + a.y*b.y + a.z*b.z; }
//Squared length of the x, y, z part of v (w is ignored).
b3Scalar b3Vector3_length2(b3Vector3 v) { return b3Vector3_dot(v, v); }
267
//Slab test: returns nonzero if the segment that starts at rayOrigin and runs along
//rayNormalizedDirection for rayLength units intersects (or touches) the AABB.
//The ray parameter range considered is [0, rayLength]; since that range is compared
//against distances, rayNormalizedDirection is expected to be unit length.
int rayIntersectsAabb(b3Vector3 rayOrigin, b3Scalar rayLength, b3Vector3 rayNormalizedDirection, b3AabbCL aabb)
{
	//AABB is considered as 3 pairs of 2 planes( {x_min, x_max}, {y_min, y_max}, {z_min, z_max} ).
	//t_min is the point of intersection with the closer plane, t_max is the point of intersection with the farther plane.
	//
	//if (rayNormalizedDirection.x < 0.0f), then max.x will be the near plane 
	//and min.x will be the far plane; otherwise, it is reversed.
	//
	//In order for there to be a collision, the t_min and t_max of each pair must overlap.
	//This can be tested for by selecting the highest t_min and lowest t_max and comparing them.
	
	int4 isNegative = isless( rayNormalizedDirection, ((b3Vector3){0.0f, 0.0f, 0.0f, 0.0f}) );	//per lane: (x < y) ? -1 : 0
	
	//When using vector types, the select() function checks the most significant bit of each lane.
	//For vector arguments isless() already returns -1 (all bits set) for true and 0 for false,
	//so this shift is not strictly required; true lanes keep their most significant bit set.
	isNegative <<= 31;

	//select(b, a, condition) == condition ? a : b
	//When using select() with vector types, (condition[i]) is true if its most significant bit is 1
	//A zero direction lane gives +/-inf here, or NaN (0/0) when the origin lies on that plane;
	//the fmin()/fmax() reductions below discard NaN lanes.
	b3Vector3 t_min = ( select(aabb.m_min, aabb.m_max, isNegative) - rayOrigin ) / rayNormalizedDirection;
	b3Vector3 t_max = ( select(aabb.m_max, aabb.m_min, isNegative) - rayOrigin ) / rayNormalizedDirection;
	
	//Start from the segment's own parameter range [0, rayLength]
	b3Scalar t_min_final = 0.0f;
	b3Scalar t_max_final = rayLength;
	
	//Must use fmin()/fmax(); if one of the parameters is NaN, then the parameter that is not NaN is returned. 
	//Behavior of min()/max() with NaNs is undefined. (See OpenCL Specification 1.2 [6.12.2] and [6.12.4])
	//Since the innermost fmin()/fmax() is always not NaN, this should never return NaN.
	t_min_final = fmax( t_min.z, fmax(t_min.y, fmax(t_min.x, t_min_final)) );
	t_max_final = fmin( t_max.z, fmin(t_max.y, fmin(t_max.x, t_max_final)) );
	
	return (t_min_final <= t_max_final);
}
301
//Casts one ray per work-item through the BVH and records every (ray, rigid) hit.
//Each hit stores (rayIndex, m_minIndices[3] of the hit rigid AABB). out_numRayRigidPairs is
//incremented for every hit, so it may exceed maxRayRigidPairs; only the first
//maxRayRigidPairs hits are stored. internalNodeLeafIndexRanges is not used by this kernel.
__kernel void plbvhRayTraverse(__global b3AabbCL* rigidAabbs,

								__global int* rootNodeIndex, 
								__global int2* internalNodeChildIndices, 
								__global b3AabbCL* internalNodeAabbs,
								__global int2* internalNodeLeafIndexRanges,
								__global SortDataCL* mortonCodesAndAabbIndices,
								
								__global b3RayInfo* rays,
								
								__global int* out_numRayRigidPairs, 
								__global int2* out_rayRigidPairs,
								int maxRayRigidPairs, int numRays)
{
	int rayIndex = get_global_id(0);
	if (rayIndex >= numRays) return;

	b3Vector3 rayFrom = rays[rayIndex].m_from;
	b3Vector3 rayDelta = rays[rayIndex].m_to - rayFrom;
	b3Vector3 rayNormalizedDirection = b3Vector3_normalize(rayDelta);
	b3Scalar rayLength = b3Sqrt( b3Vector3_length2(rayDelta) );

	//Depth-first traversal with an explicit stack of tagged node indices
	int stack[B3_PLVBH_TRAVERSE_MAX_STACK_SIZE];
	int top = 0;
	stack[top++] = *rootNodeIndex;

	while (top > 0)
	{
		int taggedIndex = stack[--top];
		int isLeaf = isLeafNode(taggedIndex);
		int nodeIndex = getIndexWithInternalNodeMarkerRemoved(taggedIndex);

		b3AabbCL nodeAabb;
		if (isLeaf)
		{
			int leafRigidIndex = mortonCodesAndAabbIndices[nodeIndex].m_value;
			nodeAabb = rigidAabbs[leafRigidIndex];
		}
		else
		{
			nodeAabb = internalNodeAabbs[nodeIndex];
		}

		if (!rayIntersectsAabb(rayFrom, rayLength, rayNormalizedDirection, nodeAabb)) continue;

		if (isLeaf)
		{
			int2 rayRigidPair;
			rayRigidPair.x = rayIndex;
			rayRigidPair.y = nodeAabb.m_minIndices[3];

			int pairIndex = atomic_inc(out_numRayRigidPairs);
			if (pairIndex < maxRayRigidPairs) out_rayRigidPairs[pairIndex] = rayRigidPair;
		}
		else if (top + 2 <= B3_PLVBH_TRAVERSE_MAX_STACK_SIZE)
		{
			int2 children = internalNodeChildIndices[nodeIndex];
			stack[top++] = children.x;
			stack[top++] = children.y;
		}
		//else: no room for both children; the subtree is skipped without reporting an error
	}
}
370
//Brute-force overlap test of one small AABB (per work-item) against every large AABB.
//Each overlap emits (large m_minIndices[3], small m_minIndices[3], NEW_PAIR_MARKER, NEW_PAIR_MARKER).
//out_numPairs counts every overlap; only the first maxPairs pairs are stored.
__kernel void plbvhLargeAabbAabbTest(__global b3AabbCL* smallAabbs, __global b3AabbCL* largeAabbs, 
									__global int* out_numPairs, __global int4* out_overlappingPairs, 
									int maxPairs, int numLargeAabbRigids, int numSmallAabbRigids)
{
	int smallAabbIndex = get_global_id(0);
	if (smallAabbIndex >= numSmallAabbRigids) return;

	b3AabbCL smallAabb = smallAabbs[smallAabbIndex];
	int smallRigidIndex = smallAabb.m_minIndices[3];

	for (int largeAabbIndex = 0; largeAabbIndex < numLargeAabbRigids; ++largeAabbIndex)
	{
		b3AabbCL largeAabb = largeAabbs[largeAabbIndex];
		if (!TestAabbAgainstAabb2(&smallAabb, &largeAabb)) continue;

		int4 pair = (int4)(largeAabb.m_minIndices[3], smallRigidIndex, NEW_PAIR_MARKER, NEW_PAIR_MARKER);

		int pairIndex = atomic_inc(out_numPairs);
		if (pairIndex < maxPairs) out_overlappingPairs[pairIndex] = pair;
	}
}
//Brute-force test of one ray (per work-item) against every large rigid AABB.
//Each hit emits (rayIndex, m_minIndices[3] of the hit AABB); out_numRayRigidPairs counts
//every hit, and only the first maxRayRigidPairs hits are stored.
__kernel void plbvhLargeAabbRayTest(__global b3AabbCL* largeRigidAabbs, __global b3RayInfo* rays,
									__global int* out_numRayRigidPairs,  __global int2* out_rayRigidPairs,
									int numLargeAabbRigids, int maxRayRigidPairs, int numRays)
{
	int rayIndex = get_global_id(0);
	if (rayIndex >= numRays) return;

	b3Vector3 rayFrom = rays[rayIndex].m_from;
	b3Vector3 rayDelta = rays[rayIndex].m_to - rayFrom;
	b3Vector3 rayNormalizedDirection = b3Vector3_normalize(rayDelta);
	b3Scalar rayLength = b3Sqrt( b3Vector3_length2(rayDelta) );

	for (int largeAabbIndex = 0; largeAabbIndex < numLargeAabbRigids; ++largeAabbIndex)
	{
		b3AabbCL rigidAabb = largeRigidAabbs[largeAabbIndex];
		if (!rayIntersectsAabb(rayFrom, rayLength, rayNormalizedDirection, rigidAabb)) continue;

		int2 rayRigidPair = (int2)(rayIndex, rigidAabb.m_minIndices[3]);

		int pairIndex = atomic_inc(out_numRayRigidPairs);
		if (pairIndex < maxRayRigidPairs) out_rayRigidPairs[pairIndex] = rayRigidPair;
	}
}
421
422
//Sentinel common prefix length for a missing neighbour split.
//Set so that it is always greater than the actual common prefixes, and never selected as a parent node.
//If there are no duplicates, then the highest common prefix is 32 or 64, depending on the number of bits used for the z-curve.
//Duplicate common prefixes increase the highest common prefix at most by the number of bits used to index the leaf node.
//Since 32 bit ints are used to index leaf nodes, the max prefix is 64 (32 + 32 bit z-curve) or 96 (32 + 64 bit z-curve).
#define B3_PLBVH_INVALID_COMMON_PREFIX 128

//Marker value for the root node's parent (presumably used by code past this view; confirm)
#define B3_PLBVH_ROOT_NODE_MARKER -1

//64-bit signed integer (OpenCL long) used for the (morton code << 32 | leaf index) keys and their prefixes
#define b3Int64 long
432
//Number of leading bits shared by i and j (64 when i == j).
int computeCommonPrefixLength(b3Int64 i, b3Int64 j) { return (int)clz(i ^ j); }

//Returns the bits shared by i and j within their common prefix, with every bit after the prefix cleared.
//The algorithm only needs the bits inside the prefix (i & j would suffice); masking the remaining
//bits makes the stored prefixes easier to read while debugging.
b3Int64 computeCommonPrefix(b3Int64 i, b3Int64 j) 
{
	int commonPrefixLength = computeCommonPrefixLength(i, j);
	
	//Build the mask from an unsigned all-ones value, because left-shifting a negative signed value
	//is undefined. A prefix length of 0 is handled explicitly: OpenCL masks shift counts to the operand
	//width, so a shift by 64 would act as a shift by 0 and wrongly keep every bit.
	ulong bitmask = (commonPrefixLength == 0) ? 0UL : (~0UL << (64 - commonPrefixLength));
	
	ulong sharedBits = (ulong)(i & j);
	return (b3Int64)(sharedBits & bitmask);
}
446
//Same as computeCommonPrefixLength(), but for prefixes that carry their own lengths:
//the result is never larger than the shorter of the two prefix lengths.
int getSharedPrefixLength(b3Int64 prefixA, int prefixLengthA, b3Int64 prefixB, int prefixLengthB)
{
	int shorterLength = b3Min(prefixLengthA, prefixLengthB);
	int matchingBits = computeCommonPrefixLength(prefixA, prefixB);
	return b3Min(matchingBits, shorterLength);
}
452
//For each internal node k (one work-item each), computes the common prefix and the common prefix
//length of the keys of adjacent sorted leaves k and k + 1.
//Each key is the leaf's morton code in the upper 32 bits and the leaf index in the lower 32 bits.
__kernel void computeAdjacentPairCommonPrefix(__global SortDataCL* mortonCodesAndAabbIndices,
											__global b3Int64* out_commonPrefixes,
											__global int* out_commonPrefixLengths,
											int numInternalNodes)
{
	int internalNodeIndex = get_global_id(0);
	if (internalNodeIndex >= numInternalNodes) return;

	//The right leaf index is always valid: there is one internal node fewer than there are leaves
	int leftLeafIndex = internalNodeIndex;
	int rightLeafIndex = leftLeafIndex + 1;

	int leftMortonCode = mortonCodesAndAabbIndices[leftLeafIndex].m_key;
	int rightMortonCode = mortonCodesAndAabbIndices[rightLeafIndex].m_key;

	//The binary radix tree construction needs unique keys in ascending order. Appending the leaf index
	//to each (sorted) morton code removes duplicates and keeps the order, since leftLeafIndex < rightLeafIndex.
	//upsample(a, b) == ( ((b3Int64)a) << 32) | b
	b3Int64 leftKey = upsample(leftMortonCode, leftLeafIndex);
	b3Int64 rightKey = upsample(rightMortonCode, rightLeafIndex);

	out_commonPrefixLengths[internalNodeIndex] = computeCommonPrefixLength(leftKey, rightKey);
	out_commonPrefixes[internalNodeIndex] = computeCommonPrefix(leftKey, rightKey);
}
481
482
//Links every leaf node (one work-item each) to its parent internal node in the binary radix tree.
//commonPrefixLengths[k] is the common prefix length of leaves k and k + 1, i.e. of internal node k.
//A leaf's parent is the adjacent split (k - 1 or k) with the higher common prefix length.
//out_leafNodeParentNodes[leaf] receives that parent index. The leaf index (without the internal-node
//marker) is stored as the parent's right child if the parent is the left split, otherwise as its left child.
__kernel void buildBinaryRadixTreeLeafNodes(__global int* commonPrefixLengths, __global int* out_leafNodeParentNodes,
											__global int2* out_childNodes, int numLeafNodes)
{
	int leafNodeIndex = get_global_id(0);
	if (leafNodeIndex >= numLeafNodes) return;
	
	int numInternalNodes = numLeafNodes - 1;
	
	//With a single leaf there is no internal node: both neighbouring splits are missing and the
	//parent index below would be -1, making the child write out_childNodes[-1] out of bounds.
	//Store the same parent value (-1, the root marker) and skip the child write.
	if (numInternalNodes <= 0)
	{
		out_leafNodeParentNodes[leafNodeIndex] = B3_PLBVH_ROOT_NODE_MARKER;
		return;
	}
	
	int leftSplitIndex = leafNodeIndex - 1;
	int rightSplitIndex = leafNodeIndex;
	
	int leftCommonPrefix = (leftSplitIndex >= 0) ? commonPrefixLengths[leftSplitIndex] : B3_PLBVH_INVALID_COMMON_PREFIX;
	int rightCommonPrefix = (rightSplitIndex < numInternalNodes) ? commonPrefixLengths[rightSplitIndex] : B3_PLBVH_INVALID_COMMON_PREFIX;
	
	//Parent node is the highest adjacent common prefix that is lower than the node's common prefix
	//Leaf nodes are considered as having the highest common prefix
	int isLeftHigherCommonPrefix = (leftCommonPrefix > rightCommonPrefix);
	
	//Handle cases for the edge nodes; the first and last node
	//With at least two leaves, leftCommonPrefix and rightCommonPrefix are never both B3_PLBVH_INVALID_COMMON_PREFIX
	if(leftCommonPrefix == B3_PLBVH_INVALID_COMMON_PREFIX) isLeftHigherCommonPrefix = false;
	if(rightCommonPrefix == B3_PLBVH_INVALID_COMMON_PREFIX) isLeftHigherCommonPrefix = true;
	
	int parentNodeIndex = (isLeftHigherCommonPrefix) ? leftSplitIndex : rightSplitIndex;
	out_leafNodeParentNodes[leafNodeIndex] = parentNodeIndex;
	
	int isRightChild = (isLeftHigherCommonPrefix);	//If the left node is the parent, then this node is its right child and vice versa
	
	//out_childNodesAsInt[0] == int2.x == left child
	//out_childNodesAsInt[1] == int2.y == right child
	int isLeaf = 1;
	__global int* out_childNodesAsInt = (__global int*)(&out_childNodes[parentNodeIndex]);
	out_childNodesAsInt[isRightChild] = getIndexWithInternalNodeMarkerSet(isLeaf, leafNodeIndex);
}
517
518 __kernel void buildBinaryRadixTreeInternalNodes(__global b3Int64* commonPrefixes, __global int* commonPrefixLengths,
519                                                                                                 __global int2* out_childNodes,
520                                                                                                 __global int* out_internalNodeParentNodes, __global int* out_rootNodeIndex,
521                                                                                                 int numInternalNodes)
522 {
523         int internalNodeIndex = get_group_id(0) * get_local_size(0) + get_local_id(0);
524         if(internalNodeIndex >= numInternalNodes) return;
525         
526         b3Int64 nodePrefix = commonPrefixes[internalNodeIndex];
527         int nodePrefixLength = commonPrefixLengths[internalNodeIndex];
528         
529 //#define USE_LINEAR_SEARCH
530 #ifdef USE_LINEAR_SEARCH
531         int leftIndex = -1;
532         int rightIndex = -1;
533         
534         //Find nearest element to left with a lower common prefix
535         for(int i = internalNodeIndex - 1; i >= 0; --i)
536         {
537                 int nodeLeftSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, commonPrefixes[i], commonPrefixLengths[i]);
538                 if(nodeLeftSharedPrefixLength < nodePrefixLength)
539                 {
540                         leftIndex = i;
541                         break;
542                 }
543         }
544         
545         //Find nearest element to right with a lower common prefix
546         for(int i = internalNodeIndex + 1; i < numInternalNodes; ++i)
547         {
548                 int nodeRightSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, commonPrefixes[i], commonPrefixLengths[i]);
549                 if(nodeRightSharedPrefixLength < nodePrefixLength)
550                 {
551                         rightIndex = i;
552                         break;
553                 }
554         }
555         
556 #else //Use binary search
557
558         //Find nearest element to left with a lower common prefix
559         int leftIndex = -1;
560         {
561                 int lower = 0;
562                 int upper = internalNodeIndex - 1;
563                 
564                 while(lower <= upper)
565                 {
566                         int mid = (lower + upper) / 2;
567                         b3Int64 midPrefix = commonPrefixes[mid];
568                         int midPrefixLength = commonPrefixLengths[mid];
569                         
570                         int nodeMidSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, midPrefix, midPrefixLength);
571                         if(nodeMidSharedPrefixLength < nodePrefixLength) 
572                         {
573                                 int right = mid + 1;
574                                 if(right < internalNodeIndex)
575                                 {
576                                         b3Int64 rightPrefix = commonPrefixes[right];
577                                         int rightPrefixLength = commonPrefixLengths[right];
578                                         
579                                         int nodeRightSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, rightPrefix, rightPrefixLength);
580                                         if(nodeRightSharedPrefixLength < nodePrefixLength) 
581                                         {
582                                                 lower = right;
583                                                 leftIndex = right;
584                                         }
585                                         else 
586                                         {
587                                                 leftIndex = mid;
588                                                 break;
589                                         }
590                                 }
591                                 else 
592                                 {
593                                         leftIndex = mid;
594                                         break;
595                                 }
596                         }
597                         else upper = mid - 1;
598                 }
599         }
600         
601         //Find nearest element to right with a lower common prefix
602         int rightIndex = -1;
603         {
604                 int lower = internalNodeIndex + 1;
605                 int upper = numInternalNodes - 1;
606                 
607                 while(lower <= upper)
608                 {
609                         int mid = (lower + upper) / 2;
610                         b3Int64 midPrefix = commonPrefixes[mid];
611                         int midPrefixLength = commonPrefixLengths[mid];
612                         
613                         int nodeMidSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, midPrefix, midPrefixLength);
614                         if(nodeMidSharedPrefixLength < nodePrefixLength) 
615                         {
616                                 int left = mid - 1;
617                                 if(left > internalNodeIndex)
618                                 {
619                                         b3Int64 leftPrefix = commonPrefixes[left];
620                                         int leftPrefixLength = commonPrefixLengths[left];
621                                 
622                                         int nodeLeftSharedPrefixLength = getSharedPrefixLength(nodePrefix, nodePrefixLength, leftPrefix, leftPrefixLength);
623                                         if(nodeLeftSharedPrefixLength < nodePrefixLength) 
624                                         {
625                                                 upper = left;
626                                                 rightIndex = left;
627                                         }
628                                         else 
629                                         {
630                                                 rightIndex = mid;
631                                                 break;
632                                         }
633                                 }
634                                 else 
635                                 {
636                                         rightIndex = mid;
637                                         break;
638                                 }
639                         }
640                         else lower = mid + 1;
641                 }
642         }
643 #endif
644         
645         //Select parent
646         {
647                 int leftPrefixLength = (leftIndex != -1) ? commonPrefixLengths[leftIndex] : B3_PLBVH_INVALID_COMMON_PREFIX;
648                 int rightPrefixLength =  (rightIndex != -1) ? commonPrefixLengths[rightIndex] : B3_PLBVH_INVALID_COMMON_PREFIX;
649                 
650                 int isLeftHigherPrefixLength = (leftPrefixLength > rightPrefixLength);
651                 
652                 if(leftPrefixLength == B3_PLBVH_INVALID_COMMON_PREFIX) isLeftHigherPrefixLength = false;
653                 else if(rightPrefixLength == B3_PLBVH_INVALID_COMMON_PREFIX) isLeftHigherPrefixLength = true;
654                 
655                 int parentNodeIndex = (isLeftHigherPrefixLength) ? leftIndex : rightIndex;
656                 
657                 int isRootNode = (leftIndex == -1 && rightIndex == -1);
658                 out_internalNodeParentNodes[internalNodeIndex] = (!isRootNode) ? parentNodeIndex : B3_PLBVH_ROOT_NODE_MARKER;
659                 
660                 int isLeaf = 0;
661                 if(!isRootNode)
662                 {
663                         int isRightChild = (isLeftHigherPrefixLength);  //If the left node is the parent, then this node is its right child and vice versa
664                         
665                         //out_childNodesAsInt[0] == int2.x == left child
666                         //out_childNodesAsInt[1] == int2.y == right child
667                         __global int* out_childNodesAsInt = (__global int*)(&out_childNodes[parentNodeIndex]);
668                         out_childNodesAsInt[isRightChild] = getIndexWithInternalNodeMarkerSet(isLeaf, internalNodeIndex);
669                 }
670                 else *out_rootNodeIndex = getIndexWithInternalNodeMarkerSet(isLeaf, internalNodeIndex);
671         }
672 }
673
/*
 * For every internal node, count the parent hops separating it from the root
 * (the root itself gets 0) and record the largest count in *out_maxDistanceFromRoot.
 *
 * rootNodeIndex           - not read by this kernel; kept for interface compatibility.
 * internalNodeParentNodes - parent index of each internal node; the chain ends at
 *                           B3_PLBVH_ROOT_NODE_MARKER.
 * out_maxDistanceFromRoot - single int; reset to 0 by global work-item 0, then raised
 *                           with atomic_max using each work-group's local maximum.
 * out_distanceFromRoot    - per-internal-node hop count.
 * numInternalNodes        - number of valid nodes; work-items beyond it are padding.
 *
 * Padding work-items must NOT return early: every work-item of a work-group has to
 * reach each barrier() that the others execute, otherwise behaviour is undefined.
 * They therefore skip only the per-node work and still take part in both barriers.
 *
 * NOTE(review): the atomic_xchg reset is not ordered against other work-groups. If a
 * different group finishes its atomic_max before global work-item 0 performs the
 * reset, that group's maximum is lost. Zeroing the buffer on the host before launch
 * would remove this race; confirm against the host code.
 */
__kernel void findDistanceFromRoot(__global int* rootNodeIndex, __global int* internalNodeParentNodes,
									__global int* out_maxDistanceFromRoot, __global int* out_distanceFromRoot, int numInternalNodes)
{
	if( get_global_id(0) == 0 ) atomic_xchg(out_maxDistanceFromRoot, 0);

	int internalNodeIndex = get_global_id(0);
	int isValidNode = (internalNodeIndex < numInternalNodes);
	
	//Walk the parent links up to the root; padding work-items keep a distance of 0
	int distanceFromRoot = 0;
	if(isValidNode)
	{
		int parentIndex = internalNodeParentNodes[internalNodeIndex];
		while(parentIndex != B3_PLBVH_ROOT_NODE_MARKER)
		{
			parentIndex = internalNodeParentNodes[parentIndex];
			++distanceFromRoot;
		}
		out_distanceFromRoot[internalNodeIndex] = distanceFromRoot;
	}
	
	//Reduce within the work-group first so only one global atomic is issued per group.
	//All work-items (including padding ones) must reach both barriers below.
	__local int localMaxDistanceFromRoot;
	if( get_local_id(0) == 0 ) localMaxDistanceFromRoot = 0;
	barrier(CLK_LOCAL_MEM_FENCE);
	
	if(isValidNode) atomic_max(&localMaxDistanceFromRoot, distanceFromRoot);
	barrier(CLK_LOCAL_MEM_FENCE);
	
	if( get_local_id(0) == 0 ) atomic_max(out_maxDistanceFromRoot, localMaxDistanceFromRoot);
}
704
/*
 * One level of the bottom-up AABB pass: every internal node whose distance from
 * the root equals processedDistance gets the union of its two children's AABBs.
 *
 * Child references in childNodes carry an "internal node" marker:
 * - For a leaf child, the stripped index selects an entry of
 *   mortonCodesAndAabbIndices, whose m_value selects the box in leafNodeAabbs.
 * - For an internal child, the stripped index selects internalNodeAabbs directly.
 *   NOTE(review): this presumes the host already ran the pass for the deeper level,
 *   so the child's box is filled in. The launch order is not visible here.
 *
 * maxDistanceFromRoot is not read by this kernel.
 */
__kernel void buildBinaryRadixTreeAabbsRecursive(__global int* distanceFromRoot, __global SortDataCL* mortonCodesAndAabbIndices,
												__global int2* childNodes,
												__global b3AabbCL* leafNodeAabbs, __global b3AabbCL* internalNodeAabbs,
												int maxDistanceFromRoot, int processedDistance, int numInternalNodes)
{
	int nodeIndex = get_global_id(0);
	
	//Ignore padding work-items and nodes that belong to a different tree level
	if(nodeIndex >= numInternalNodes || distanceFromRoot[nodeIndex] != processedDistance) return;
	
	int leftEncoded = childNodes[nodeIndex].x;
	int rightEncoded = childNodes[nodeIndex].y;
	
	int leftIndex = getIndexWithInternalNodeMarkerRemoved(leftEncoded);
	int rightIndex = getIndexWithInternalNodeMarkerRemoved(rightEncoded);
	
	b3AabbCL leftAabb;
	if( isLeafNode(leftEncoded) ) leftAabb = leafNodeAabbs[ (int)mortonCodesAndAabbIndices[leftIndex].m_value ];
	else leftAabb = internalNodeAabbs[leftIndex];
	
	b3AabbCL rightAabb;
	if( isLeafNode(rightEncoded) ) rightAabb = leafNodeAabbs[ (int)mortonCodesAndAabbIndices[rightIndex].m_value ];
	else rightAabb = internalNodeAabbs[rightIndex];
	
	//Component-wise union of the two boxes
	b3AabbCL unionAabb;
	unionAabb.m_min = b3Min(leftAabb.m_min, rightAabb.m_min);
	unionAabb.m_max = b3Max(leftAabb.m_max, rightAabb.m_max);
	internalNodeAabbs[nodeIndex] = unionAabb;
}
739
/*
 * For each internal node, find the lowest and highest leaf index beneath it.
 *
 * internalNodeChildNodes - per internal node: .x = left child, .y = right child,
 *                          each either a leaf index or an internal-node index
 *                          carrying the internal-node marker.
 * out_leafIndexRanges    - per internal node: .x = min leaf index, .y = max leaf index.
 * numInternalNodes       - number of valid nodes; extra work-items exit immediately.
 *
 * The minimum is reached by repeatedly following left children until a leaf is
 * hit, and the maximum by following right children.
 * NOTE(review): this assumes each subtree's leaves form a contiguous, ordered run,
 * as in a binary radix tree over sorted keys.
 */
__kernel void findLeafIndexRanges(__global int2* internalNodeChildNodes, __global int2* out_leafIndexRanges, int numInternalNodes)
{
	int internalNodeIndex = get_global_id(0);
	if(internalNodeIndex >= numInternalNodes) return;
	
	int2 childNodes = internalNodeChildNodes[internalNodeIndex];
	
	//Lowest leaf: descend through left children (.x) until a leaf is reached
	int lowestIndex = childNodes.x;
	while( !isLeafNode(lowestIndex) ) lowestIndex = internalNodeChildNodes[ getIndexWithInternalNodeMarkerRemoved(lowestIndex) ].x;
	
	//Highest leaf: descend through right children (.y) until a leaf is reached
	int highestIndex = childNodes.y;
	while( !isLeafNode(highestIndex) ) highestIndex = internalNodeChildNodes[ getIndexWithInternalNodeMarkerRemoved(highestIndex) ].y;
	
	out_leafIndexRanges[internalNodeIndex] = (int2)(lowestIndex, highestIndex);
}